Skip to main content

firefox_webdriver/transport/
connection.rs

1//! WebSocket connection and event loop.
2//!
3//! This module handles the WebSocket connection to Firefox extension,
4//! including request/response correlation and event routing.
5//!
6//! See ARCHITECTURE.md Section 3.5-3.6 for event loop specification.
7//!
8//! # Event Loop
9//!
10//! The connection spawns a tokio task that handles:
11//!
12//! - Incoming messages from extension (responses, events)
13//! - Outgoing commands from Rust API
14//! - Request/response correlation by UUID
15//! - Multi-handler event callbacks
16
17// ============================================================================
18// Imports
19// ============================================================================
20
21use std::sync::Arc;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use std::time::Duration;
24
25use futures_util::{SinkExt, StreamExt};
26use parking_lot::Mutex;
27use rustc_hash::FxHashMap;
28use serde_json::{from_value, to_string};
29use tokio::net::TcpStream;
30use tokio::sync::{mpsc, oneshot};
31use tokio::time::timeout;
32use tokio_tungstenite::WebSocketStream;
33use tokio_tungstenite::tungstenite::Message;
34use tracing::{debug, error, trace, warn};
35
36use crate::error::{Error, Result};
37use crate::identifiers::RequestId;
38use crate::protocol::{Event, EventReply, Request, Response};
39
40// ============================================================================
41// Constants
42// ============================================================================
43
/// Default timeout for command execution (30s per spec).
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Upper bound on in-flight requests; new requests beyond it are rejected.
const MAX_PENDING_REQUESTS: usize = 100;

/// How long to wait for the READY handshake.
const READY_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Capacity of the bounded command channel feeding the event loop.
const COMMAND_CHANNEL_CAPACITY: usize = 256;
55
56// ============================================================================
57// Types
58// ============================================================================
59
60/// Map of request IDs to response channels.
61type CorrelationMap = FxHashMap<RequestId, oneshot::Sender<Result<Response>>>;
62
63/// Event handler callback type.
64///
65/// Called for each event received from the extension.
66/// Return `Some(EventReply)` to send a reply (for network interception).
67pub type EventHandler = Box<dyn Fn(Event) -> Option<EventReply> + Send + Sync>;
68
69/// A labeled event handler entry: `(key, handler)`.
70type HandlerEntry = (String, Arc<dyn Fn(Event) -> Option<EventReply> + Send + Sync>);
71
72/// Multi-handler storage: a vec of labeled handlers.
73type HandlerVec = Vec<HandlerEntry>;
74
75// ============================================================================
76// ReadyData
77// ============================================================================
78
/// Data received in the READY handshake message.
///
/// The extension sends this immediately after connecting to provide
/// initial tab and session information. It is plain data, so it supports
/// equality comparison and hashing in addition to `Debug`/`Clone`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ReadyData {
    /// Initial tab ID from Firefox.
    pub tab_id: u32,
    /// Session ID.
    pub session_id: u32,
}
90
91// ============================================================================
92// ConnectionCommand
93// ============================================================================
94
95/// Internal commands for the event loop.
96enum ConnectionCommand {
97    /// Send a request and wait for response.
98    Send {
99        request: Request,
100        response_tx: oneshot::Sender<Result<Response>>,
101    },
102    /// Remove a timed-out correlation entry.
103    RemoveCorrelation(RequestId),
104    /// Shutdown the connection.
105    Shutdown,
106}
107
108// ============================================================================
109// Connection
110// ============================================================================
111
112/// WebSocket connection to Firefox extension.
113///
114/// Handles request/response correlation and event routing.
115/// The connection spawns an internal event loop task.
116///
117/// # Thread Safety
118///
119/// `Connection` is `Send + Sync` and can be shared across tasks.
120/// All operations are non-blocking.
121pub struct Connection {
122    /// Channel for sending commands to the event loop.
123    command_tx: mpsc::Sender<ConnectionCommand>,
124    /// Correlation map (shared with event loop).
125    correlation: Arc<Mutex<CorrelationMap>>,
126    /// Multi-handler event handlers (shared with event loop).
127    event_handlers: Arc<Mutex<HandlerVec>>,
128    /// Atomic counter for pending requests (avoids locking correlation map
129    /// just to check the count).
130    pending_count: Arc<AtomicUsize>,
131}
132
133impl Connection {
134    /// Creates a new connection from a WebSocket stream.
135    ///
136    /// Spawns the event loop task internally.
137    pub(crate) fn new(ws_stream: WebSocketStream<TcpStream>) -> Self {
138        let (command_tx, command_rx) = mpsc::channel(COMMAND_CHANNEL_CAPACITY);
139        let correlation = Arc::new(Mutex::new(CorrelationMap::default()));
140        let event_handlers: Arc<Mutex<HandlerVec>> = Arc::new(Mutex::new(Vec::new()));
141        let pending_count = Arc::new(AtomicUsize::new(0));
142
143        // Spawn event loop task
144        let correlation_clone = Arc::clone(&correlation);
145        let event_handlers_clone = Arc::clone(&event_handlers);
146        let pending_count_clone = Arc::clone(&pending_count);
147
148        tokio::spawn(Self::run_event_loop(
149            ws_stream,
150            command_rx,
151            correlation_clone,
152            event_handlers_clone,
153            pending_count_clone,
154        ));
155
156        Self {
157            command_tx,
158            correlation,
159            event_handlers,
160            pending_count,
161        }
162    }
163
164    /// Waits for the READY handshake message.
165    ///
166    /// Must be called after connection is established.
167    /// The extension sends READY with nil UUID immediately after connecting.
168    ///
169    /// # Errors
170    ///
171    /// - [`Error::ConnectionTimeout`] if READY not received within 30s
172    /// - [`Error::ConnectionClosed`] if connection closes before READY
173    pub async fn wait_ready(&self) -> Result<ReadyData> {
174        let (tx, rx) = oneshot::channel();
175
176        // Register correlation for READY (nil UUID)
177        {
178            let mut correlation = self.correlation.lock();
179            correlation.insert(RequestId::ready(), tx);
180        }
181        self.pending_count.fetch_add(1, Ordering::Relaxed);
182
183        // Wait for READY with timeout
184        let response = timeout(READY_TIMEOUT, rx)
185            .await
186            .map_err(|_| Error::connection_timeout(READY_TIMEOUT.as_millis() as u64))??;
187
188        let response = response?;
189
190        // Extract data from READY response using helper methods
191        let tab_id = response.get_u64("tabId").max(1) as u32;
192        let session_id = response.get_u64("sessionId").max(1) as u32;
193
194        debug!(tab_id, session_id, "READY handshake completed");
195
196        Ok(ReadyData { tab_id, session_id })
197    }
198
199    /// Adds an event handler with a key label.
200    ///
201    /// Multiple handlers can be registered simultaneously.
202    /// When an event arrives, handlers are iterated in order until
203    /// one returns `Some(EventReply)`.
204    ///
205    /// If a handler with the same key already exists, it is replaced.
206    pub fn add_event_handler(&self, key: String, handler: EventHandler) {
207        let handler: Arc<dyn Fn(Event) -> Option<EventReply> + Send + Sync> = Arc::from(handler);
208        let mut guard = self.event_handlers.lock();
209        // Replace existing handler with same key
210        if let Some(entry) = guard.iter_mut().find(|(k, _)| k == &key) {
211            entry.1 = handler;
212        } else {
213            guard.push((key, handler));
214        }
215    }
216
217    /// Removes an event handler by key.
218    pub fn remove_event_handler(&self, key: &str) {
219        let mut guard = self.event_handlers.lock();
220        guard.retain(|(k, _)| k != key);
221    }
222
223    /// Clears all event handlers (for shutdown).
224    pub fn clear_all_event_handlers(&self) {
225        let mut guard = self.event_handlers.lock();
226        guard.clear();
227    }
228
229    /// Sends a request and waits for response with default timeout (30s).
230    ///
231    /// # Errors
232    ///
233    /// - [`Error::ConnectionClosed`] if connection is closed
234    /// - [`Error::RequestTimeout`] if response not received within timeout
235    /// - [`Error::Protocol`] if too many pending requests
236    pub async fn send(&self, request: Request) -> Result<Response> {
237        self.send_with_timeout(request, DEFAULT_COMMAND_TIMEOUT)
238            .await
239    }
240
241    /// Sends a request and waits for response with custom timeout.
242    ///
243    /// # Arguments
244    ///
245    /// * `request` - The request to send
246    /// * `request_timeout` - Maximum time to wait for response
247    ///
248    /// # Errors
249    ///
250    /// - [`Error::ConnectionClosed`] if connection is closed
251    /// - [`Error::RequestTimeout`] if response not received within timeout
252    /// - [`Error::Protocol`] if too many pending requests
253    pub async fn send_with_timeout(
254        &self,
255        request: Request,
256        request_timeout: Duration,
257    ) -> Result<Response> {
258        let request_id = request.id;
259
260        // Check pending request limit using atomic counter (no lock needed)
261        let pending = self.pending_count.load(Ordering::Relaxed);
262        if pending >= MAX_PENDING_REQUESTS {
263            warn!(
264                pending = pending,
265                max = MAX_PENDING_REQUESTS,
266                "Too many pending requests"
267            );
268            return Err(Error::protocol(format!(
269                "Too many pending requests: {}/{}",
270                pending, MAX_PENDING_REQUESTS
271            )));
272        }
273
274        // Create response channel
275        let (response_tx, response_rx) = oneshot::channel();
276
277        // Use try_send to avoid blocking in synchronous-like contexts.
278        self.command_tx
279            .try_send(ConnectionCommand::Send {
280                request,
281                response_tx,
282            })
283            .map_err(|e| match e {
284                mpsc::error::TrySendError::Full(_) => {
285                    Error::protocol("Command channel full (backpressure)")
286                }
287                mpsc::error::TrySendError::Closed(_) => Error::ConnectionClosed,
288            })?;
289
290        // Wait for response with timeout
291        match timeout(request_timeout, response_rx).await {
292            Ok(Ok(result)) => result,
293            Ok(Err(_)) => Err(Error::ConnectionClosed),
294            Err(_) => {
295                // Timeout - clean up correlation entry
296                let _ = self
297                    .command_tx
298                    .try_send(ConnectionCommand::RemoveCorrelation(request_id));
299
300                Err(Error::request_timeout(
301                    request_id,
302                    request_timeout.as_millis() as u64,
303                ))
304            }
305        }
306    }
307
308    /// Returns the number of pending requests.
309    #[inline]
310    #[must_use]
311    pub fn pending_count(&self) -> usize {
312        self.pending_count.load(Ordering::Relaxed)
313    }
314
315    /// Shuts down the connection gracefully.
316    ///
317    /// This is called automatically on drop.
318    pub fn shutdown(&self) {
319        let _ = self.command_tx.try_send(ConnectionCommand::Shutdown);
320    }
321
322    /// Event loop that handles WebSocket I/O.
323    async fn run_event_loop(
324        ws_stream: WebSocketStream<TcpStream>,
325        mut command_rx: mpsc::Receiver<ConnectionCommand>,
326        correlation: Arc<Mutex<CorrelationMap>>,
327        event_handlers: Arc<Mutex<HandlerVec>>,
328        pending_count: Arc<AtomicUsize>,
329    ) {
330        let (mut ws_write, mut ws_read) = ws_stream.split();
331
332        loop {
333            tokio::select! {
334                // Incoming messages from extension
335                message = ws_read.next() => {
336                    match message {
337                        Some(Ok(Message::Text(text))) => {
338                            let reply = Self::handle_incoming_message(
339                                &text,
340                                &correlation,
341                                &event_handlers,
342                                &pending_count,
343                            );
344
345                            // Send event reply if needed
346                            if let Some(reply) = reply
347                                && let Ok(json) = to_string(&reply)
348                                && let Err(e) = ws_write.send(Message::Text(json.into())).await
349                            {
350                                warn!(error = %e, "Failed to send event reply");
351                            }
352                        }
353
354                        Some(Ok(Message::Close(_))) => {
355                            debug!("WebSocket closed by remote");
356                            break;
357                        }
358
359                        Some(Err(e)) => {
360                            error!(error = %e, "WebSocket error");
361                            break;
362                        }
363
364                        None => {
365                            debug!("WebSocket stream ended");
366                            break;
367                        }
368
369                        // Ignore Binary, Ping, Pong
370                        _ => {}
371                    }
372                }
373
374                // Commands from Rust API
375                command = command_rx.recv() => {
376                    match command {
377                        Some(ConnectionCommand::Send { request, response_tx }) => {
378                            Self::handle_send_command(
379                                request,
380                                response_tx,
381                                &mut ws_write,
382                                &correlation,
383                                &pending_count,
384                            ).await;
385                        }
386
387                        Some(ConnectionCommand::RemoveCorrelation(request_id)) => {
388                            if correlation.lock().remove(&request_id).is_some() {
389                                pending_count.fetch_sub(1, Ordering::Relaxed);
390                            }
391                            debug!(?request_id, "Removed timed-out correlation");
392                        }
393
394                        Some(ConnectionCommand::Shutdown) => {
395                            debug!("Shutdown command received");
396                            let _ = ws_write.close().await;
397                            break;
398                        }
399
400                        None => {
401                            debug!("Command channel closed");
402                            break;
403                        }
404                    }
405                }
406            }
407        }
408
409        // Fail all pending requests on shutdown
410        Self::fail_pending_requests(&correlation, &pending_count);
411
412        debug!("Event loop terminated");
413    }
414
415    /// Handles an incoming text message from the extension.
416    ///
417    /// Parses JSON once, then discriminates between Response and Event
418    /// based on the presence of "type" or "method" fields.
419    fn handle_incoming_message(
420        text: &str,
421        correlation: &Arc<Mutex<CorrelationMap>>,
422        event_handlers: &Arc<Mutex<HandlerVec>>,
423        pending_count: &Arc<AtomicUsize>,
424    ) -> Option<EventReply> {
425        // Parse once to serde_json::Value
426        let value: serde_json::Value = match serde_json::from_str(text) {
427            Ok(v) => v,
428            Err(e) => {
429                warn!(error = %e, text = %text, "Failed to parse incoming message as JSON");
430                return None;
431            }
432        };
433
434        // Check discriminator: Response has "type" = "success" or "error"
435        if value
436            .get("type")
437            .and_then(|v| v.as_str())
438            .is_some_and(|t| t == "success" || t == "error")
439        {
440            // It's a Response - convert from Value (no re-parse)
441            let response: Response = match from_value(value) {
442                Ok(r) => r,
443                Err(e) => {
444                    warn!(error = %e, "Failed to deserialize Response from Value");
445                    return None;
446                }
447            };
448
449            let tx = correlation.lock().remove(&response.id);
450
451            if let Some(tx) = tx {
452                pending_count.fetch_sub(1, Ordering::Relaxed);
453                let _ = tx.send(Ok(response));
454            } else {
455                warn!(id = %response.id, "Response for unknown request");
456            }
457
458            return None;
459        }
460
461        // Check for Event: has "method" field
462        if value.get("method").is_some() {
463            // It's an Event - convert from Value (no re-parse)
464            let event: Event = match from_value(value) {
465                Ok(e) => e,
466                Err(e) => {
467                    warn!(error = %e, "Failed to deserialize Event from Value");
468                    return None;
469                }
470            };
471
472            // Clone handlers vec to avoid holding lock during callback execution
473            let handlers: Vec<HandlerEntry> = {
474                let guard = event_handlers.lock();
475                guard.clone()
476            };
477
478            // Iterate all handlers until one returns Some(EventReply)
479            for (_key, handler) in &handlers {
480                if let Some(reply) = handler(event.clone()) {
481                    return Some(reply);
482                }
483            }
484
485            return None;
486        }
487
488        warn!(text = %text, "Failed to parse incoming message: no type or method field");
489        None
490    }
491
492    /// Handles a send command from the Rust API.
493    async fn handle_send_command(
494        request: Request,
495        response_tx: oneshot::Sender<Result<Response>>,
496        ws_write: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, Message>,
497        correlation: &Arc<Mutex<CorrelationMap>>,
498        pending_count: &Arc<AtomicUsize>,
499    ) {
500        let request_id = request.id;
501
502        // Serialize request
503        let json = match to_string(&request) {
504            Ok(j) => j,
505            Err(e) => {
506                let _ = response_tx.send(Err(Error::Json(e)));
507                return;
508            }
509        };
510
511        // Store correlation before sending and increment counter
512        correlation.lock().insert(request_id, response_tx);
513        pending_count.fetch_add(1, Ordering::Relaxed);
514
515        // Send over WebSocket
516        if let Err(e) = ws_write.send(Message::Text(json.into())).await {
517            // Remove correlation and notify caller
518            if let Some(tx) = correlation.lock().remove(&request_id) {
519                pending_count.fetch_sub(1, Ordering::Relaxed);
520                let _ = tx.send(Err(Error::connection(e.to_string())));
521            }
522        }
523
524        trace!(?request_id, "Request sent");
525    }
526
527    /// Fails all pending requests with ConnectionClosed error.
528    fn fail_pending_requests(
529        correlation: &Arc<Mutex<CorrelationMap>>,
530        pending_count: &Arc<AtomicUsize>,
531    ) {
532        let pending: Vec<_> = correlation.lock().drain().collect();
533        let count = pending.len();
534
535        for (_, tx) in pending {
536            let _ = tx.send(Err(Error::ConnectionClosed));
537        }
538
539        // Reset counter
540        pending_count.store(0, Ordering::Relaxed);
541
542        if count > 0 {
543            debug!(count, "Failed pending requests on shutdown");
544        }
545    }
546}
547
548impl Drop for Connection {
549    fn drop(&mut self) {
550        // Only shutdown if this is the last reference
551        // Since command_tx is cloned, we can check if we're the only sender
552        // Actually, we can't easily check this, so we should NOT auto-shutdown on drop
553        // The pool.remove() will explicitly call shutdown()
554        //
555        // DO NOT call shutdown here - it breaks cloned connections!
556    }
557}
558
559// ============================================================================
560// Tests
561// ============================================================================
562
#[cfg(test)]
mod tests {
    use super::*;

    /// Timeouts and the pending-request cap match their documented values.
    #[test]
    fn test_constants() {
        let expected_timeout_secs = 30;
        assert_eq!(DEFAULT_COMMAND_TIMEOUT.as_secs(), expected_timeout_secs);
        assert_eq!(READY_TIMEOUT.as_secs(), expected_timeout_secs);
        assert_eq!(MAX_PENDING_REQUESTS, 100);
    }

    /// `ReadyData` exposes exactly the values it was built with.
    #[test]
    fn test_ready_data() {
        let ReadyData { tab_id, session_id } = ReadyData {
            tab_id: 1,
            session_id: 2,
        };
        assert_eq!(tab_id, 1);
        assert_eq!(session_id, 2);
    }
}