firefox_webdriver/transport/connection.rs
1//! WebSocket connection and event loop.
2//!
//! This module handles the WebSocket connection to the Firefox extension,
4//! including request/response correlation and event routing.
5//!
6//! See ARCHITECTURE.md Section 3.5-3.6 for event loop specification.
7//!
8//! # Event Loop
9//!
10//! The connection spawns a tokio task that handles:
11//!
12//! - Incoming messages from extension (responses, events)
13//! - Outgoing commands from Rust API
14//! - Request/response correlation by UUID
15//! - Multi-handler event callbacks
16
17// ============================================================================
18// Imports
19// ============================================================================
20
21use std::sync::Arc;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use std::time::Duration;
24
25use futures_util::{SinkExt, StreamExt};
26use parking_lot::Mutex;
27use rustc_hash::FxHashMap;
28use serde_json::{from_value, to_string};
29use tokio::net::TcpStream;
30use tokio::sync::{mpsc, oneshot};
31use tokio::time::timeout;
32use tokio_tungstenite::WebSocketStream;
33use tokio_tungstenite::tungstenite::Message;
34use tracing::{debug, error, trace, warn};
35
36use crate::error::{Error, Result};
37use crate::identifiers::RequestId;
38use crate::protocol::{Event, EventReply, Request, Response};
39
40// ============================================================================
41// Constants
42// ============================================================================
43
/// Default timeout for command execution (30s per spec).
///
/// Applied by `Connection::send`, which forwards to `send_with_timeout`.
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);

/// Maximum pending requests before rejecting new ones.
///
/// `send_with_timeout` compares the shared pending counter against this
/// value and returns a protocol error once the limit is reached.
const MAX_PENDING_REQUESTS: usize = 100;

/// Timeout for READY handshake.
///
/// `wait_ready` gives up with a connection-timeout error after this long.
const READY_TIMEOUT: Duration = Duration::from_secs(30);

/// Capacity for the bounded command channel.
///
/// This is the buffer size of the `mpsc` channel between the public API and
/// the event loop; when it is full, `try_send` calls fail instead of waiting.
const COMMAND_CHANNEL_CAPACITY: usize = 256;
55
56// ============================================================================
57// Types
58// ============================================================================
59
/// Map of request IDs to response channels.
///
/// Each entry holds the `oneshot` sender that completes the caller waiting
/// for the response with that ID.
type CorrelationMap = FxHashMap<RequestId, oneshot::Sender<Result<Response>>>;

/// Event handler callback type.
///
/// Called for each event received from the extension.
/// Return `Some(EventReply)` to send a reply (for network interception).
pub type EventHandler = Box<dyn Fn(Event) -> Option<EventReply> + Send + Sync>;

/// A labeled event handler entry: `(key, handler)`.
///
/// The handler is stored behind an `Arc` so the event loop can call it
/// without keeping the handler list locked.
type HandlerEntry = (String, Arc<dyn Fn(Event) -> Option<EventReply> + Send + Sync>);

/// Multi-handler storage: a vec of labeled handlers.
///
/// Handlers are tried in vec order, so registration order is call order.
type HandlerVec = Vec<HandlerEntry>;
74
75// ============================================================================
76// ReadyData
77// ============================================================================
78
/// Data received in the READY handshake message.
///
/// The extension sends this immediately after connecting to provide
/// initial tab and session information. Values are produced by
/// `Connection::wait_ready`.
#[derive(Debug, Clone)]
pub struct ReadyData {
    /// Initial tab ID from Firefox (taken from the `tabId` field of READY).
    pub tab_id: u32,
    /// Session ID (taken from the `sessionId` field of READY).
    pub session_id: u32,
}
90
91// ============================================================================
92// ConnectionCommand
93// ============================================================================
94
/// Internal commands for the event loop.
///
/// Sent from `Connection` methods over the bounded command channel and
/// consumed one at a time by `Connection::run_event_loop`.
enum ConnectionCommand {
    /// Send a request and wait for response.
    Send {
        /// The request to serialize and write to the WebSocket.
        request: Request,
        /// Channel on which the event loop delivers the response or an error.
        response_tx: oneshot::Sender<Result<Response>>,
    },
    /// Remove a timed-out correlation entry.
    RemoveCorrelation(RequestId),
    /// Shutdown the connection.
    Shutdown,
}
107
108// ============================================================================
109// Connection
110// ============================================================================
111
/// WebSocket connection to Firefox extension.
///
/// Handles request/response correlation and event routing.
/// The connection spawns an internal event loop task.
///
/// # Thread Safety
///
/// `Connection` is `Send + Sync` and can be shared across tasks.
/// Commands are handed to the event loop with `try_send`, so the API fails
/// fast rather than waiting when the command channel is full.
pub struct Connection {
    /// Channel for sending commands to the event loop.
    command_tx: mpsc::Sender<ConnectionCommand>,
    /// Correlation map (shared with event loop).
    correlation: Arc<Mutex<CorrelationMap>>,
    /// Multi-handler event handlers (shared with event loop).
    event_handlers: Arc<Mutex<HandlerVec>>,
    /// Atomic counter for pending requests (avoids locking correlation map
    /// just to check the count).
    pending_count: Arc<AtomicUsize>,
}
132
133impl Connection {
134 /// Creates a new connection from a WebSocket stream.
135 ///
136 /// Spawns the event loop task internally.
137 pub(crate) fn new(ws_stream: WebSocketStream<TcpStream>) -> Self {
138 let (command_tx, command_rx) = mpsc::channel(COMMAND_CHANNEL_CAPACITY);
139 let correlation = Arc::new(Mutex::new(CorrelationMap::default()));
140 let event_handlers: Arc<Mutex<HandlerVec>> = Arc::new(Mutex::new(Vec::new()));
141 let pending_count = Arc::new(AtomicUsize::new(0));
142
143 // Spawn event loop task
144 let correlation_clone = Arc::clone(&correlation);
145 let event_handlers_clone = Arc::clone(&event_handlers);
146 let pending_count_clone = Arc::clone(&pending_count);
147
148 tokio::spawn(Self::run_event_loop(
149 ws_stream,
150 command_rx,
151 correlation_clone,
152 event_handlers_clone,
153 pending_count_clone,
154 ));
155
156 Self {
157 command_tx,
158 correlation,
159 event_handlers,
160 pending_count,
161 }
162 }
163
164 /// Waits for the READY handshake message.
165 ///
166 /// Must be called after connection is established.
167 /// The extension sends READY with nil UUID immediately after connecting.
168 ///
169 /// # Errors
170 ///
171 /// - [`Error::ConnectionTimeout`] if READY not received within 30s
172 /// - [`Error::ConnectionClosed`] if connection closes before READY
173 pub async fn wait_ready(&self) -> Result<ReadyData> {
174 let (tx, rx) = oneshot::channel();
175
176 // Register correlation for READY (nil UUID)
177 {
178 let mut correlation = self.correlation.lock();
179 correlation.insert(RequestId::ready(), tx);
180 }
181 self.pending_count.fetch_add(1, Ordering::Relaxed);
182
183 // Wait for READY with timeout
184 let response = timeout(READY_TIMEOUT, rx)
185 .await
186 .map_err(|_| Error::connection_timeout(READY_TIMEOUT.as_millis() as u64))??;
187
188 let response = response?;
189
190 // Extract data from READY response using helper methods
191 let tab_id = response.get_u64("tabId").max(1) as u32;
192 let session_id = response.get_u64("sessionId").max(1) as u32;
193
194 debug!(tab_id, session_id, "READY handshake completed");
195
196 Ok(ReadyData { tab_id, session_id })
197 }
198
199 /// Adds an event handler with a key label.
200 ///
201 /// Multiple handlers can be registered simultaneously.
202 /// When an event arrives, handlers are iterated in order until
203 /// one returns `Some(EventReply)`.
204 ///
205 /// If a handler with the same key already exists, it is replaced.
206 pub fn add_event_handler(&self, key: String, handler: EventHandler) {
207 let handler: Arc<dyn Fn(Event) -> Option<EventReply> + Send + Sync> = Arc::from(handler);
208 let mut guard = self.event_handlers.lock();
209 // Replace existing handler with same key
210 if let Some(entry) = guard.iter_mut().find(|(k, _)| k == &key) {
211 entry.1 = handler;
212 } else {
213 guard.push((key, handler));
214 }
215 }
216
217 /// Removes an event handler by key.
218 pub fn remove_event_handler(&self, key: &str) {
219 let mut guard = self.event_handlers.lock();
220 guard.retain(|(k, _)| k != key);
221 }
222
223 /// Clears all event handlers (for shutdown).
224 pub fn clear_all_event_handlers(&self) {
225 let mut guard = self.event_handlers.lock();
226 guard.clear();
227 }
228
229 /// Sends a request and waits for response with default timeout (30s).
230 ///
231 /// # Errors
232 ///
233 /// - [`Error::ConnectionClosed`] if connection is closed
234 /// - [`Error::RequestTimeout`] if response not received within timeout
235 /// - [`Error::Protocol`] if too many pending requests
236 pub async fn send(&self, request: Request) -> Result<Response> {
237 self.send_with_timeout(request, DEFAULT_COMMAND_TIMEOUT)
238 .await
239 }
240
241 /// Sends a request and waits for response with custom timeout.
242 ///
243 /// # Arguments
244 ///
245 /// * `request` - The request to send
246 /// * `request_timeout` - Maximum time to wait for response
247 ///
248 /// # Errors
249 ///
250 /// - [`Error::ConnectionClosed`] if connection is closed
251 /// - [`Error::RequestTimeout`] if response not received within timeout
252 /// - [`Error::Protocol`] if too many pending requests
253 pub async fn send_with_timeout(
254 &self,
255 request: Request,
256 request_timeout: Duration,
257 ) -> Result<Response> {
258 let request_id = request.id;
259
260 // Check pending request limit using atomic counter (no lock needed)
261 let pending = self.pending_count.load(Ordering::Relaxed);
262 if pending >= MAX_PENDING_REQUESTS {
263 warn!(
264 pending = pending,
265 max = MAX_PENDING_REQUESTS,
266 "Too many pending requests"
267 );
268 return Err(Error::protocol(format!(
269 "Too many pending requests: {}/{}",
270 pending, MAX_PENDING_REQUESTS
271 )));
272 }
273
274 // Create response channel
275 let (response_tx, response_rx) = oneshot::channel();
276
277 // Use try_send to avoid blocking in synchronous-like contexts.
278 self.command_tx
279 .try_send(ConnectionCommand::Send {
280 request,
281 response_tx,
282 })
283 .map_err(|e| match e {
284 mpsc::error::TrySendError::Full(_) => {
285 Error::protocol("Command channel full (backpressure)")
286 }
287 mpsc::error::TrySendError::Closed(_) => Error::ConnectionClosed,
288 })?;
289
290 // Wait for response with timeout
291 match timeout(request_timeout, response_rx).await {
292 Ok(Ok(result)) => result,
293 Ok(Err(_)) => Err(Error::ConnectionClosed),
294 Err(_) => {
295 // Timeout - clean up correlation entry
296 let _ = self
297 .command_tx
298 .try_send(ConnectionCommand::RemoveCorrelation(request_id));
299
300 Err(Error::request_timeout(
301 request_id,
302 request_timeout.as_millis() as u64,
303 ))
304 }
305 }
306 }
307
308 /// Returns the number of pending requests.
309 #[inline]
310 #[must_use]
311 pub fn pending_count(&self) -> usize {
312 self.pending_count.load(Ordering::Relaxed)
313 }
314
315 /// Shuts down the connection gracefully.
316 ///
317 /// This is called automatically on drop.
318 pub fn shutdown(&self) {
319 let _ = self.command_tx.try_send(ConnectionCommand::Shutdown);
320 }
321
322 /// Event loop that handles WebSocket I/O.
323 async fn run_event_loop(
324 ws_stream: WebSocketStream<TcpStream>,
325 mut command_rx: mpsc::Receiver<ConnectionCommand>,
326 correlation: Arc<Mutex<CorrelationMap>>,
327 event_handlers: Arc<Mutex<HandlerVec>>,
328 pending_count: Arc<AtomicUsize>,
329 ) {
330 let (mut ws_write, mut ws_read) = ws_stream.split();
331
332 loop {
333 tokio::select! {
334 // Incoming messages from extension
335 message = ws_read.next() => {
336 match message {
337 Some(Ok(Message::Text(text))) => {
338 let reply = Self::handle_incoming_message(
339 &text,
340 &correlation,
341 &event_handlers,
342 &pending_count,
343 );
344
345 // Send event reply if needed
346 if let Some(reply) = reply
347 && let Ok(json) = to_string(&reply)
348 && let Err(e) = ws_write.send(Message::Text(json.into())).await
349 {
350 warn!(error = %e, "Failed to send event reply");
351 }
352 }
353
354 Some(Ok(Message::Close(_))) => {
355 debug!("WebSocket closed by remote");
356 break;
357 }
358
359 Some(Err(e)) => {
360 error!(error = %e, "WebSocket error");
361 break;
362 }
363
364 None => {
365 debug!("WebSocket stream ended");
366 break;
367 }
368
369 // Ignore Binary, Ping, Pong
370 _ => {}
371 }
372 }
373
374 // Commands from Rust API
375 command = command_rx.recv() => {
376 match command {
377 Some(ConnectionCommand::Send { request, response_tx }) => {
378 Self::handle_send_command(
379 request,
380 response_tx,
381 &mut ws_write,
382 &correlation,
383 &pending_count,
384 ).await;
385 }
386
387 Some(ConnectionCommand::RemoveCorrelation(request_id)) => {
388 if correlation.lock().remove(&request_id).is_some() {
389 pending_count.fetch_sub(1, Ordering::Relaxed);
390 }
391 debug!(?request_id, "Removed timed-out correlation");
392 }
393
394 Some(ConnectionCommand::Shutdown) => {
395 debug!("Shutdown command received");
396 let _ = ws_write.close().await;
397 break;
398 }
399
400 None => {
401 debug!("Command channel closed");
402 break;
403 }
404 }
405 }
406 }
407 }
408
409 // Fail all pending requests on shutdown
410 Self::fail_pending_requests(&correlation, &pending_count);
411
412 debug!("Event loop terminated");
413 }
414
415 /// Handles an incoming text message from the extension.
416 ///
417 /// Parses JSON once, then discriminates between Response and Event
418 /// based on the presence of "type" or "method" fields.
419 fn handle_incoming_message(
420 text: &str,
421 correlation: &Arc<Mutex<CorrelationMap>>,
422 event_handlers: &Arc<Mutex<HandlerVec>>,
423 pending_count: &Arc<AtomicUsize>,
424 ) -> Option<EventReply> {
425 // Parse once to serde_json::Value
426 let value: serde_json::Value = match serde_json::from_str(text) {
427 Ok(v) => v,
428 Err(e) => {
429 warn!(error = %e, text = %text, "Failed to parse incoming message as JSON");
430 return None;
431 }
432 };
433
434 // Check discriminator: Response has "type" = "success" or "error"
435 if value
436 .get("type")
437 .and_then(|v| v.as_str())
438 .is_some_and(|t| t == "success" || t == "error")
439 {
440 // It's a Response - convert from Value (no re-parse)
441 let response: Response = match from_value(value) {
442 Ok(r) => r,
443 Err(e) => {
444 warn!(error = %e, "Failed to deserialize Response from Value");
445 return None;
446 }
447 };
448
449 let tx = correlation.lock().remove(&response.id);
450
451 if let Some(tx) = tx {
452 pending_count.fetch_sub(1, Ordering::Relaxed);
453 let _ = tx.send(Ok(response));
454 } else {
455 warn!(id = %response.id, "Response for unknown request");
456 }
457
458 return None;
459 }
460
461 // Check for Event: has "method" field
462 if value.get("method").is_some() {
463 // It's an Event - convert from Value (no re-parse)
464 let event: Event = match from_value(value) {
465 Ok(e) => e,
466 Err(e) => {
467 warn!(error = %e, "Failed to deserialize Event from Value");
468 return None;
469 }
470 };
471
472 // Clone handlers vec to avoid holding lock during callback execution
473 let handlers: Vec<HandlerEntry> = {
474 let guard = event_handlers.lock();
475 guard.clone()
476 };
477
478 // Iterate all handlers until one returns Some(EventReply)
479 for (_key, handler) in &handlers {
480 if let Some(reply) = handler(event.clone()) {
481 return Some(reply);
482 }
483 }
484
485 return None;
486 }
487
488 warn!(text = %text, "Failed to parse incoming message: no type or method field");
489 None
490 }
491
492 /// Handles a send command from the Rust API.
493 async fn handle_send_command(
494 request: Request,
495 response_tx: oneshot::Sender<Result<Response>>,
496 ws_write: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, Message>,
497 correlation: &Arc<Mutex<CorrelationMap>>,
498 pending_count: &Arc<AtomicUsize>,
499 ) {
500 let request_id = request.id;
501
502 // Serialize request
503 let json = match to_string(&request) {
504 Ok(j) => j,
505 Err(e) => {
506 let _ = response_tx.send(Err(Error::Json(e)));
507 return;
508 }
509 };
510
511 // Store correlation before sending and increment counter
512 correlation.lock().insert(request_id, response_tx);
513 pending_count.fetch_add(1, Ordering::Relaxed);
514
515 // Send over WebSocket
516 if let Err(e) = ws_write.send(Message::Text(json.into())).await {
517 // Remove correlation and notify caller
518 if let Some(tx) = correlation.lock().remove(&request_id) {
519 pending_count.fetch_sub(1, Ordering::Relaxed);
520 let _ = tx.send(Err(Error::connection(e.to_string())));
521 }
522 }
523
524 trace!(?request_id, "Request sent");
525 }
526
527 /// Fails all pending requests with ConnectionClosed error.
528 fn fail_pending_requests(
529 correlation: &Arc<Mutex<CorrelationMap>>,
530 pending_count: &Arc<AtomicUsize>,
531 ) {
532 let pending: Vec<_> = correlation.lock().drain().collect();
533 let count = pending.len();
534
535 for (_, tx) in pending {
536 let _ = tx.send(Err(Error::ConnectionClosed));
537 }
538
539 // Reset counter
540 pending_count.store(0, Ordering::Relaxed);
541
542 if count > 0 {
543 debug!(count, "Failed pending requests on shutdown");
544 }
545 }
546}
547
impl Drop for Connection {
    /// Intentionally a no-op: dropping a `Connection` does not send a
    /// `Shutdown` command.
    fn drop(&mut self) {
        // DO NOT call shutdown here - per the original author's note it
        // breaks connections that are still in use through another handle.
        //
        // There is no easy way to tell from here whether this is the last
        // user of the command channel, so shutdown is left to the owner:
        // `pool.remove()` (defined outside this file - TODO confirm) is
        // expected to call `shutdown()` explicitly.
        //
        // Once every sender of the command channel has been dropped, the
        // event loop sees the channel close and terminates on its own.
    }
}
558
559// ============================================================================
560// Tests
561// ============================================================================
562
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the timeout and limit constants to their documented values.
    #[test]
    fn test_constants() {
        for limit in [DEFAULT_COMMAND_TIMEOUT, READY_TIMEOUT] {
            assert_eq!(limit.as_secs(), 30);
        }
        assert_eq!(MAX_PENDING_REQUESTS, 100);
    }

    /// `ReadyData` stores the values it was built with.
    #[test]
    fn test_ready_data() {
        let ReadyData { tab_id, session_id } = ReadyData {
            tab_id: 1,
            session_id: 2,
        };
        assert_eq!((tab_id, session_id), (1, 2));
    }
}
583}