firefox_webdriver/transport/connection.rs
1//! WebSocket connection and event loop.
2//!
//! This module handles the WebSocket connection to the Firefox extension,
//! including request/response correlation and event routing.
5//!
6//! See ARCHITECTURE.md Section 3.5-3.6 for event loop specification.
7//!
8//! # Event Loop
9//!
10//! The connection spawns a tokio task that handles:
11//!
12//! - Incoming messages from extension (responses, events)
13//! - Outgoing commands from Rust API
14//! - Request/response correlation by UUID
15//! - Event handler callbacks
16
17// ============================================================================
18// Imports
19// ============================================================================
20
21use std::sync::Arc;
22use std::time::Duration;
23
24use futures_util::{SinkExt, StreamExt};
25use parking_lot::Mutex;
26use rustc_hash::FxHashMap;
27use serde_json::{from_str, to_string};
28use tokio::net::TcpStream;
29use tokio::sync::{mpsc, oneshot};
30use tokio::time::timeout;
31use tokio_tungstenite::WebSocketStream;
32use tokio_tungstenite::tungstenite::Message;
33use tracing::{debug, error, trace, warn};
34
35use crate::error::{Error, Result};
36use crate::identifiers::RequestId;
37use crate::protocol::{Event, EventReply, Request, Response};
38
39// ============================================================================
40// Constants
41// ============================================================================
42
/// How long a plain `send` waits for the extension's response (30s per spec).
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);

/// Upper bound on requests awaiting a response; new sends are rejected with a
/// protocol error while this many are outstanding.
const MAX_PENDING_REQUESTS: usize = 100;

/// How long `wait_ready` waits for the extension's READY handshake.
const READY_TIMEOUT: Duration = Duration::from_secs(30);
51
52// ============================================================================
53// Types
54// ============================================================================
55
56/// Map of request IDs to response channels.
57type CorrelationMap = FxHashMap<RequestId, oneshot::Sender<Result<Response>>>;
58
59/// Event handler callback type.
60///
61/// Called for each event received from the extension.
62/// Return `Some(EventReply)` to send a reply (for network interception).
63pub type EventHandler = Box<dyn Fn(Event) -> Option<EventReply> + Send + Sync>;
64
65// ============================================================================
66// ReadyData
67// ============================================================================
68
/// Data received in the READY handshake message.
///
/// The extension sends this immediately after connecting to provide the
/// initial tab and session information.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReadyData {
    /// Initial tab ID reported by Firefox.
    pub tab_id: u32,
    /// Session ID reported by the extension.
    pub session_id: u32,
}
80
81// ============================================================================
82// ConnectionCommand
83// ============================================================================
84
85/// Internal commands for the event loop.
86enum ConnectionCommand {
87 /// Send a request and wait for response.
88 Send {
89 request: Request,
90 response_tx: oneshot::Sender<Result<Response>>,
91 },
92 /// Remove a timed-out correlation entry.
93 RemoveCorrelation(RequestId),
94 /// Shutdown the connection.
95 Shutdown,
96}
97
98// ============================================================================
99// Connection
100// ============================================================================
101
102/// WebSocket connection to Firefox extension.
103///
104/// Handles request/response correlation and event routing.
105/// The connection spawns an internal event loop task.
106///
107/// # Thread Safety
108///
109/// `Connection` is `Send + Sync` and can be shared across tasks.
110/// All operations are non-blocking.
111pub struct Connection {
112 /// Channel for sending commands to the event loop.
113 command_tx: mpsc::UnboundedSender<ConnectionCommand>,
114 /// Correlation map (shared with event loop).
115 correlation: Arc<Mutex<CorrelationMap>>,
116 /// Event handler (shared with event loop).
117 event_handler: Arc<Mutex<Option<EventHandler>>>,
118}
119
120impl Clone for Connection {
121 fn clone(&self) -> Self {
122 Self {
123 command_tx: self.command_tx.clone(),
124 correlation: Arc::clone(&self.correlation),
125 event_handler: Arc::clone(&self.event_handler),
126 }
127 }
128}
129
130impl Connection {
131 /// Creates a new connection from a WebSocket stream.
132 ///
133 /// Spawns the event loop task internally.
134 pub(crate) fn new(ws_stream: WebSocketStream<TcpStream>) -> Self {
135 let (command_tx, command_rx) = mpsc::unbounded_channel();
136 let correlation = Arc::new(Mutex::new(CorrelationMap::default()));
137 let event_handler: Arc<Mutex<Option<EventHandler>>> = Arc::new(Mutex::new(None));
138
139 // Spawn event loop task
140 let correlation_clone = Arc::clone(&correlation);
141 let event_handler_clone = Arc::clone(&event_handler);
142
143 tokio::spawn(Self::run_event_loop(
144 ws_stream,
145 command_rx,
146 correlation_clone,
147 event_handler_clone,
148 ));
149
150 Self {
151 command_tx,
152 correlation,
153 event_handler,
154 }
155 }
156
157 /// Waits for the READY handshake message.
158 ///
159 /// Must be called after connection is established.
160 /// The extension sends READY with nil UUID immediately after connecting.
161 ///
162 /// # Errors
163 ///
164 /// - [`Error::ConnectionTimeout`] if READY not received within 30s
165 /// - [`Error::ConnectionClosed`] if connection closes before READY
166 pub async fn wait_ready(&self) -> Result<ReadyData> {
167 let (tx, rx) = oneshot::channel();
168
169 // Register correlation for READY (nil UUID)
170 {
171 let mut correlation = self.correlation.lock();
172 correlation.insert(RequestId::ready(), tx);
173 }
174
175 // Wait for READY with timeout
176 let response = timeout(READY_TIMEOUT, rx)
177 .await
178 .map_err(|_| Error::connection_timeout(READY_TIMEOUT.as_millis() as u64))??;
179
180 let response = response?;
181
182 // Extract data from READY response using helper methods
183 let tab_id = response.get_u64("tabId").max(1) as u32;
184 let session_id = response.get_u64("sessionId").max(1) as u32;
185
186 debug!(tab_id, session_id, "READY handshake completed");
187
188 Ok(ReadyData { tab_id, session_id })
189 }
190
191 /// Sets the event handler callback.
192 ///
193 /// The handler is called for each event received from the extension.
194 /// Return `Some(EventReply)` to send a reply back.
195 pub fn set_event_handler(&self, handler: EventHandler) {
196 let mut guard = self.event_handler.lock();
197 *guard = Some(handler);
198 }
199
200 /// Clears the event handler.
201 pub fn clear_event_handler(&self) {
202 let mut guard = self.event_handler.lock();
203 *guard = None;
204 }
205
206 /// Sends a request and waits for response with default timeout (30s).
207 ///
208 /// # Errors
209 ///
210 /// - [`Error::ConnectionClosed`] if connection is closed
211 /// - [`Error::RequestTimeout`] if response not received within timeout
212 /// - [`Error::Protocol`] if too many pending requests
213 pub async fn send(&self, request: Request) -> Result<Response> {
214 self.send_with_timeout(request, DEFAULT_COMMAND_TIMEOUT)
215 .await
216 }
217
218 /// Sends a request and waits for response with custom timeout.
219 ///
220 /// # Arguments
221 ///
222 /// * `request` - The request to send
223 /// * `request_timeout` - Maximum time to wait for response
224 ///
225 /// # Errors
226 ///
227 /// - [`Error::ConnectionClosed`] if connection is closed
228 /// - [`Error::RequestTimeout`] if response not received within timeout
229 /// - [`Error::Protocol`] if too many pending requests
230 pub async fn send_with_timeout(
231 &self,
232 request: Request,
233 request_timeout: Duration,
234 ) -> Result<Response> {
235 let request_id = request.id;
236
237 // Check pending request limit
238 {
239 let correlation = self.correlation.lock();
240 if correlation.len() >= MAX_PENDING_REQUESTS {
241 warn!(
242 pending = correlation.len(),
243 max = MAX_PENDING_REQUESTS,
244 "Too many pending requests"
245 );
246 return Err(Error::protocol(format!(
247 "Too many pending requests: {}/{}",
248 correlation.len(),
249 MAX_PENDING_REQUESTS
250 )));
251 }
252 }
253
254 // Create response channel
255 let (response_tx, response_rx) = oneshot::channel();
256
257 // Send command to event loop
258 self.command_tx
259 .send(ConnectionCommand::Send {
260 request,
261 response_tx,
262 })
263 .map_err(|_| Error::ConnectionClosed)?;
264
265 // Wait for response with timeout
266 match timeout(request_timeout, response_rx).await {
267 Ok(Ok(result)) => result,
268 Ok(Err(_)) => Err(Error::ConnectionClosed),
269 Err(_) => {
270 // Timeout - clean up correlation entry
271 let _ = self
272 .command_tx
273 .send(ConnectionCommand::RemoveCorrelation(request_id));
274
275 Err(Error::request_timeout(
276 request_id,
277 request_timeout.as_millis() as u64,
278 ))
279 }
280 }
281 }
282
283 /// Returns the number of pending requests.
284 #[inline]
285 #[must_use]
286 pub fn pending_count(&self) -> usize {
287 self.correlation.lock().len()
288 }
289
290 /// Shuts down the connection gracefully.
291 ///
292 /// This is called automatically on drop.
293 pub fn shutdown(&self) {
294 let _ = self.command_tx.send(ConnectionCommand::Shutdown);
295 }
296
297 /// Event loop that handles WebSocket I/O.
298 async fn run_event_loop(
299 ws_stream: WebSocketStream<TcpStream>,
300 mut command_rx: mpsc::UnboundedReceiver<ConnectionCommand>,
301 correlation: Arc<Mutex<CorrelationMap>>,
302 event_handler: Arc<Mutex<Option<EventHandler>>>,
303 ) {
304 let (mut ws_write, mut ws_read) = ws_stream.split();
305
306 loop {
307 tokio::select! {
308 // Incoming messages from extension
309 message = ws_read.next() => {
310 match message {
311 Some(Ok(Message::Text(text))) => {
312 let reply = Self::handle_incoming_message(
313 &text,
314 &correlation,
315 &event_handler,
316 );
317
318 // Send event reply if needed
319 if let Some(reply) = reply
320 && let Ok(json) = to_string(&reply)
321 && let Err(e) = ws_write.send(Message::Text(json.into())).await
322 {
323 warn!(error = %e, "Failed to send event reply");
324 }
325 }
326
327 Some(Ok(Message::Close(_))) => {
328 debug!("WebSocket closed by remote");
329 break;
330 }
331
332 Some(Err(e)) => {
333 error!(error = %e, "WebSocket error");
334 break;
335 }
336
337 None => {
338 debug!("WebSocket stream ended");
339 break;
340 }
341
342 // Ignore Binary, Ping, Pong
343 _ => {}
344 }
345 }
346
347 // Commands from Rust API
348 command = command_rx.recv() => {
349 match command {
350 Some(ConnectionCommand::Send { request, response_tx }) => {
351 Self::handle_send_command(
352 request,
353 response_tx,
354 &mut ws_write,
355 &correlation,
356 ).await;
357 }
358
359 Some(ConnectionCommand::RemoveCorrelation(request_id)) => {
360 correlation.lock().remove(&request_id);
361 debug!(?request_id, "Removed timed-out correlation");
362 }
363
364 Some(ConnectionCommand::Shutdown) => {
365 debug!("Shutdown command received");
366 let _ = ws_write.close().await;
367 break;
368 }
369
370 None => {
371 debug!("Command channel closed");
372 break;
373 }
374 }
375 }
376 }
377 }
378
379 // Fail all pending requests on shutdown
380 Self::fail_pending_requests(&correlation);
381
382 debug!("Event loop terminated");
383 }
384
385 /// Handles an incoming text message from the extension.
386 fn handle_incoming_message(
387 text: &str,
388 correlation: &Arc<Mutex<CorrelationMap>>,
389 event_handler: &Arc<Mutex<Option<EventHandler>>>,
390 ) -> Option<EventReply> {
391 // Try to parse as Response first
392 if let Ok(response) = from_str::<Response>(text) {
393 let tx = correlation.lock().remove(&response.id);
394
395 if let Some(tx) = tx {
396 let _ = tx.send(Ok(response));
397 } else {
398 warn!(id = %response.id, "Response for unknown request");
399 }
400
401 return None;
402 }
403
404 // Try to parse as Event
405 if let Ok(event) = from_str::<Event>(text) {
406 let handler = event_handler.lock();
407 if let Some(ref handler) = *handler {
408 return handler(event);
409 }
410 return None;
411 }
412
413 warn!(text = %text, "Failed to parse incoming message");
414 None
415 }
416
417 /// Handles a send command from the Rust API.
418 async fn handle_send_command(
419 request: Request,
420 response_tx: oneshot::Sender<Result<Response>>,
421 ws_write: &mut futures_util::stream::SplitSink<WebSocketStream<TcpStream>, Message>,
422 correlation: &Arc<Mutex<CorrelationMap>>,
423 ) {
424 let request_id = request.id;
425
426 // Serialize request
427 let json = match to_string(&request) {
428 Ok(j) => j,
429 Err(e) => {
430 let _ = response_tx.send(Err(Error::Json(e)));
431 return;
432 }
433 };
434
435 // Store correlation before sending
436 correlation.lock().insert(request_id, response_tx);
437
438 // Send over WebSocket
439 if let Err(e) = ws_write.send(Message::Text(json.into())).await {
440 // Remove correlation and notify caller
441 if let Some(tx) = correlation.lock().remove(&request_id) {
442 let _ = tx.send(Err(Error::connection(e.to_string())));
443 }
444 }
445
446 trace!(?request_id, "Request sent");
447 }
448
449 /// Fails all pending requests with ConnectionClosed error.
450 fn fail_pending_requests(correlation: &Arc<Mutex<CorrelationMap>>) {
451 let pending: Vec<_> = correlation.lock().drain().collect();
452 let count = pending.len();
453
454 for (_, tx) in pending {
455 let _ = tx.send(Err(Error::ConnectionClosed));
456 }
457
458 if count > 0 {
459 debug!(count, "Failed pending requests on shutdown");
460 }
461 }
462}
463
464impl Drop for Connection {
465 fn drop(&mut self) {
466 // Only shutdown if this is the last reference
467 // Since command_tx is cloned, we can check if we're the only sender
468 // Actually, we can't easily check this, so we should NOT auto-shutdown on drop
469 // The pool.remove() will explicitly call shutdown()
470 //
471 // DO NOT call shutdown here - it breaks cloned connections!
472 }
473}
474
475// ============================================================================
476// Tests
477// ============================================================================
478
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constants() {
        for timeout_secs in [DEFAULT_COMMAND_TIMEOUT.as_secs(), READY_TIMEOUT.as_secs()] {
            assert_eq!(timeout_secs, 30);
        }
        assert_eq!(MAX_PENDING_REQUESTS, 100);
    }

    #[test]
    fn test_ready_data() {
        let ReadyData { tab_id, session_id } = ReadyData {
            tab_id: 1,
            session_id: 2,
        };
        assert_eq!((tab_id, session_id), (1, 2));
    }
}