firefox_webdriver/transport/
connection.rs

1//! WebSocket connection and event loop.
2//!
//! This module manages the WebSocket connection to the Firefox extension,
//! including request/response correlation and event routing.
5//!
6//! See ARCHITECTURE.md Section 3.5-3.6 for event loop specification.
7//!
8//! # Event Loop
9//!
10//! The connection spawns a tokio task that handles:
11//!
//! - Incoming messages from the extension (responses and events)
//! - Outgoing commands from the Rust API
//! - Request/response correlation by UUID
//! - Event handler callbacks
16
17// ============================================================================
18// Imports
19// ============================================================================
20
21use std::sync::Arc;
22use std::time::Duration;
23
24use futures_util::{SinkExt, StreamExt};
25use parking_lot::Mutex;
26use rustc_hash::FxHashMap;
27use serde_json::{from_str, to_string};
28use tokio::net::TcpStream;
29use tokio::sync::{mpsc, oneshot};
30use tokio::time::timeout;
31use tokio_tungstenite::WebSocketStream;
32use tokio_tungstenite::tungstenite::Message;
33use tracing::{debug, error, trace, warn};
34
35use crate::error::{Error, Result};
36use crate::identifiers::RequestId;
37use crate::protocol::{Event, EventReply, Request, Response};
38
39// ============================================================================
40// Constants
41// ============================================================================
42
/// Default timeout for command execution (30s per spec).
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Pending-request limit: new requests are rejected once this many are waiting.
const MAX_PENDING_REQUESTS: usize = 100;

/// How long `Connection::wait_ready` waits for the READY handshake.
const READY_TIMEOUT: Duration = Duration::from_millis(30_000);
51
52// ============================================================================
53// Types
54// ============================================================================
55
56/// Map of request IDs to response channels.
57type CorrelationMap = FxHashMap<RequestId, oneshot::Sender<Result<Response>>>;
58
59/// Event handler callback type.
60///
61/// Called for each event received from the extension.
62/// Return `Some(EventReply)` to send a reply (for network interception).
63pub type EventHandler = Box<dyn Fn(Event) -> Option<EventReply> + Send + Sync>;
64
65// ============================================================================
66// ReadyData
67// ============================================================================
68
/// Payload of the READY handshake message.
///
/// The extension sends READY right after it connects, carrying the initial tab
/// and session information. Equality is derived so handshake results can be
/// compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReadyData {
    /// Initial tab ID from Firefox.
    pub tab_id: u32,
    /// Session ID.
    pub session_id: u32,
}
80
81// ============================================================================
82// ConnectionCommand
83// ============================================================================
84
85/// Internal commands for the event loop.
86enum ConnectionCommand {
87    /// Send a request and wait for response.
88    Send {
89        request: Request,
90        response_tx: oneshot::Sender<Result<Response>>,
91    },
92    /// Remove a timed-out correlation entry.
93    RemoveCorrelation(RequestId),
94    /// Shutdown the connection.
95    Shutdown,
96}
97
98// ============================================================================
99// Connection
100// ============================================================================
101
102/// WebSocket connection to Firefox extension.
103///
104/// Handles request/response correlation and event routing.
105/// The connection spawns an internal event loop task.
106///
107/// # Thread Safety
108///
109/// `Connection` is `Send + Sync` and can be shared across tasks.
110/// All operations are non-blocking.
111pub struct Connection {
112    /// Channel for sending commands to the event loop.
113    command_tx: mpsc::UnboundedSender<ConnectionCommand>,
114    /// Correlation map (shared with event loop).
115    correlation: Arc<Mutex<CorrelationMap>>,
116    /// Event handler (shared with event loop).
117    event_handler: Arc<Mutex<Option<EventHandler>>>,
118}
119
120impl Clone for Connection {
121    fn clone(&self) -> Self {
122        Self {
123            command_tx: self.command_tx.clone(),
124            correlation: Arc::clone(&self.correlation),
125            event_handler: Arc::clone(&self.event_handler),
126        }
127    }
128}
129
130impl Connection {
131    /// Creates a new connection from a WebSocket stream.
132    ///
133    /// Spawns the event loop task internally.
134    pub(crate) fn new(ws_stream: WebSocketStream<TcpStream>) -> Self {
135        let (command_tx, command_rx) = mpsc::unbounded_channel();
136        let correlation = Arc::new(Mutex::new(CorrelationMap::default()));
137        let event_handler: Arc<Mutex<Option<EventHandler>>> = Arc::new(Mutex::new(None));
138
139        // Spawn event loop task
140        let correlation_clone = Arc::clone(&correlation);
141        let event_handler_clone = Arc::clone(&event_handler);
142
143        tokio::spawn(Self::run_event_loop(
144            ws_stream,
145            command_rx,
146            correlation_clone,
147            event_handler_clone,
148        ));
149
150        Self {
151            command_tx,
152            correlation,
153            event_handler,
154        }
155    }
156
157    /// Waits for the READY handshake message.
158    ///
159    /// Must be called after connection is established.
160    /// The extension sends READY with nil UUID immediately after connecting.
161    ///
162    /// # Errors
163    ///
164    /// - [`Error::ConnectionTimeout`] if READY not received within 30s
165    /// - [`Error::ConnectionClosed`] if connection closes before READY
166    pub async fn wait_ready(&self) -> Result<ReadyData> {
167        let (tx, rx) = oneshot::channel();
168
169        // Register correlation for READY (nil UUID)
170        {
171            let mut correlation = self.correlation.lock();
172            correlation.insert(RequestId::ready(), tx);
173        }
174
175        // Wait for READY with timeout
176        let response = timeout(READY_TIMEOUT, rx)
177            .await
178            .map_err(|_| Error::connection_timeout(READY_TIMEOUT.as_millis() as u64))??;
179
180        let response = response?;
181
182        // Extract data from READY response using helper methods
183        let tab_id = response.get_u64("tabId").max(1) as u32;
184        let session_id = response.get_u64("sessionId").max(1) as u32;
185
186        debug!(tab_id, session_id, "READY handshake completed");
187
188        Ok(ReadyData { tab_id, session_id })
189    }
190
191    /// Sets the event handler callback.
192    ///
193    /// The handler is called for each event received from the extension.
194    /// Return `Some(EventReply)` to send a reply back.
195    pub fn set_event_handler(&self, handler: EventHandler) {
196        let mut guard = self.event_handler.lock();
197        *guard = Some(handler);
198    }
199
200    /// Clears the event handler.
201    pub fn clear_event_handler(&self) {
202        let mut guard = self.event_handler.lock();
203        *guard = None;
204    }
205
206    /// Sends a request and waits for response with default timeout (30s).
207    ///
208    /// # Errors
209    ///
210    /// - [`Error::ConnectionClosed`] if connection is closed
211    /// - [`Error::RequestTimeout`] if response not received within timeout
212    /// - [`Error::Protocol`] if too many pending requests
213    pub async fn send(&self, request: Request) -> Result<Response> {
214        self.send_with_timeout(request, DEFAULT_COMMAND_TIMEOUT)
215            .await
216    }
217
218    /// Sends a request and waits for response with custom timeout.
219    ///
220    /// # Arguments
221    ///
222    /// * `request` - The request to send
223    /// * `request_timeout` - Maximum time to wait for response
224    ///
225    /// # Errors
226    ///
227    /// - [`Error::ConnectionClosed`] if connection is closed
228    /// - [`Error::RequestTimeout`] if response not received within timeout
229    /// - [`Error::Protocol`] if too many pending requests
230    pub async fn send_with_timeout(
231        &self,
232        request: Request,
233        request_timeout: Duration,
234    ) -> Result<Response> {
235        let request_id = request.id;
236
237        // Check pending request limit
238        {
239            let correlation = self.correlation.lock();
240            if correlation.len() >= MAX_PENDING_REQUESTS {
241                warn!(
242                    pending = correlation.len(),
243                    max = MAX_PENDING_REQUESTS,
244                    "Too many pending requests"
245                );
246                return Err(Error::protocol(format!(
247                    "Too many pending requests: {}/{}",
248                    correlation.len(),
249                    MAX_PENDING_REQUESTS
250                )));
251            }
252        }
253
254        // Create response channel
255        let (response_tx, response_rx) = oneshot::channel();
256
257        // Send command to event loop
258        self.command_tx
259            .send(ConnectionCommand::Send {
260                request,
261                response_tx,
262            })
263            .map_err(|_| Error::ConnectionClosed)?;
264
265        // Wait for response with timeout
266        match timeout(request_timeout, response_rx).await {
267            Ok(Ok(result)) => result,
268            Ok(Err(_)) => Err(Error::ConnectionClosed),
269            Err(_) => {
270                // Timeout - clean up correlation entry
271                let _ = self
272                    .command_tx
273                    .send(ConnectionCommand::RemoveCorrelation(request_id));
274
275                Err(Error::request_timeout(
276                    request_id,
277                    request_timeout.as_millis() as u64,
278                ))
279            }
280        }
281    }
282
283    /// Returns the number of pending requests.
284    #[inline]
285    #[must_use]
286    pub fn pending_count(&self) -> usize {
287        self.correlation.lock().len()
288    }
289
290    /// Shuts down the connection gracefully.
291    ///
292    /// This is called automatically on drop.
293    pub fn shutdown(&self) {
294        let _ = self.command_tx.send(ConnectionCommand::Shutdown);
295    }
296
297    /// Event loop that handles WebSocket I/O.
298    async fn run_event_loop(
299        ws_stream: WebSocketStream<TcpStream>,
300        mut command_rx: mpsc::UnboundedReceiver<ConnectionCommand>,
301        correlation: Arc<Mutex<CorrelationMap>>,
302        event_handler: Arc<Mutex<Option<EventHandler>>>,
303    ) {
304        let (mut ws_write, mut ws_read) = ws_stream.split();
305
306        loop {
307            tokio::select! {
308                // Incoming messages from extension
309                message = ws_read.next() => {
310                    match message {
311                        Some(Ok(Message::Text(text))) => {
312                            let reply = Self::handle_incoming_message(
313                                &text,
314                                &correlation,
315                                &event_handler,
316                            );
317
318                            // Send event reply if needed
319                            if let Some(reply) = reply
320                                && let Ok(json) = to_string(&reply)
321                                && let Err(e) = ws_write.send(Message::Text(json.into())).await
322                            {
323                                warn!(error = %e, "Failed to send event reply");
324                            }
325                        }
326
327                        Some(Ok(Message::Close(_))) => {
328                            debug!("WebSocket closed by remote");
329                            break;
330                        }
331
332                        Some(Err(e)) => {
333                            error!(error = %e, "WebSocket error");
334                            break;
335                        }
336
337                        None => {
338                            debug!("WebSocket stream ended");
339                            break;
340                        }
341
342                        // Ignore Binary, Ping, Pong
343                        _ => {}
344                    }
345                }
346
347                // Commands from Rust API
348                command = command_rx.recv() => {
349                    match command {
350                        Some(ConnectionCommand::Send { request, response_tx }) => {
351                            Self::handle_send_command(
352                                request,
353                                response_tx,
354                                &mut ws_write,
355                                &correlation,
356                            ).await;
357                        }
358
359                        Some(ConnectionCommand::RemoveCorrelation(request_id)) => {
360                            correlation.lock().remove(&request_id);
361                            debug!(?request_id, "Removed timed-out correlation");
362                        }
363
364                        Some(ConnectionCommand::Shutdown) => {
365                            debug!("Shutdown command received");
366                            let _ = ws_write.close().await;
367                            break;
368                        }
369
370                        None => {
371                            debug!("Command channel closed");
372                            break;
373                        }
374                    }
375                }
376            }
377        }
378
379        // Fail all pending requests on shutdown
380        Self::fail_pending_requests(&correlation);
381
382        debug!("Event loop terminated");
383    }
384
385    /// Handles an incoming text message from the extension.
386    fn handle_incoming_message(
387        text: &str,
388        correlation: &Arc<Mutex<CorrelationMap>>,
389        event_handler: &Arc<Mutex<Option<EventHandler>>>,
390    ) -> Option<EventReply> {
391        // Try to parse as Response first
392        if let Ok(response) = from_str::<Response>(text) {
393            let tx = correlation.lock().remove(&response.id);
394
395            if let Some(tx) = tx {
396                let _ = tx.send(Ok(response));
397            } else {
398                warn!(id = %response.id, "Response for unknown request");
399            }
400
401            return None;
402        }
403
404        // Try to parse as Event
405        if let Ok(event) = from_str::<Event>(text) {
406            let handler = event_handler.lock();
407            if let Some(ref handler) = *handler {
408                return handler(event);
409            }
410            return None;
411        }
412
413        warn!(text = %text, "Failed to parse incoming message");
414        None
415    }
416
417    /// Handles a send command from the Rust API.
418    async fn handle_send_command(
419        request: Request,
420        response_tx: oneshot::Sender<Result<Response>>,
421        ws_write: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, Message>,
422        correlation: &Arc<Mutex<CorrelationMap>>,
423    ) {
424        let request_id = request.id;
425
426        // Serialize request
427        let json = match to_string(&request) {
428            Ok(j) => j,
429            Err(e) => {
430                let _ = response_tx.send(Err(Error::Json(e)));
431                return;
432            }
433        };
434
435        // Store correlation before sending
436        correlation.lock().insert(request_id, response_tx);
437
438        // Send over WebSocket
439        if let Err(e) = ws_write.send(Message::Text(json.into())).await {
440            // Remove correlation and notify caller
441            if let Some(tx) = correlation.lock().remove(&request_id) {
442                let _ = tx.send(Err(Error::connection(e.to_string())));
443            }
444        }
445
446        trace!(?request_id, "Request sent");
447    }
448
449    /// Fails all pending requests with ConnectionClosed error.
450    fn fail_pending_requests(correlation: &Arc<Mutex<CorrelationMap>>) {
451        let pending: Vec<_> = correlation.lock().drain().collect();
452        let count = pending.len();
453
454        for (_, tx) in pending {
455            let _ = tx.send(Err(Error::ConnectionClosed));
456        }
457
458        if count > 0 {
459            debug!(count, "Failed pending requests on shutdown");
460        }
461    }
462}
463
464impl Drop for Connection {
465    fn drop(&mut self) {
466        // Only shutdown if this is the last reference
467        // Since command_tx is cloned, we can check if we're the only sender
468        // Actually, we can't easily check this, so we should NOT auto-shutdown on drop
469        // The pool.remove() will explicitly call shutdown()
470        //
471        // DO NOT call shutdown here - it breaks cloned connections!
472    }
473}
474
475// ============================================================================
476// Tests
477// ============================================================================
478
#[cfg(test)]
mod tests {
    use super::*;

    /// Both timeouts are 30s and the pending-request cap is 100.
    #[test]
    fn test_constants() {
        for limit in [DEFAULT_COMMAND_TIMEOUT, READY_TIMEOUT] {
            assert_eq!(limit.as_secs(), 30);
        }
        assert_eq!(MAX_PENDING_REQUESTS, 100);
    }

    /// Fields of `ReadyData` round-trip unchanged.
    #[test]
    fn test_ready_data() {
        let ReadyData { tab_id, session_id } = ReadyData {
            tab_id: 1,
            session_id: 2,
        };
        assert_eq!((tab_id, session_id), (1, 2));
    }
}