// firefox_webdriver/transport/pool.rs

//! Connection pool for multiplexed WebSocket connections.
//!
//! Manages multiple WebSocket connections keyed by [`SessionId`].
//! All Firefox windows connect to the same port; messages are routed by session.
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────┐
//! │           ConnectionPool                │
//! │           (single port)                 │
//! │  ┌─────────────────────────────────────┐│
//! │  │ SessionId=1 → Arc<Connection> 1     ││
//! │  │ SessionId=2 → Arc<Connection> 2     ││
//! │  │ SessionId=3 → Arc<Connection> 3     ││
//! │  └─────────────────────────────────────┘│
//! └─────────────────────────────────────────┘
//! ```
19
20// ============================================================================
21// Imports
22// ============================================================================
23
24use std::net::{IpAddr, Ipv4Addr, SocketAddr};
25use std::sync::Arc;
26use std::time::Duration;
27
28use parking_lot::{Mutex, RwLock};
29use rustc_hash::FxHashMap;
30use tokio::net::TcpListener;
31use tokio::sync::{Notify, oneshot};
32use tokio::time::timeout;
33use tracing::{debug, error, info, warn};
34
35use crate::error::{Error, Result};
36use crate::identifiers::SessionId;
37use crate::protocol::{Request, Response};
38use crate::transport::Connection;
39use crate::transport::connection::ReadyData;
40
41// ============================================================================
42// Constants
43// ============================================================================
44
/// Loopback address (`127.0.0.1`) the WebSocket server binds to unless the
/// caller chooses another one.
const DEFAULT_BIND_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

/// Upper bound on how long a session may take to connect and send READY.
const SESSION_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
50
51// ============================================================================
52// ConnectionPool
53// ============================================================================
54
55/// Manages multiple WebSocket connections keyed by SessionId.
56///
57/// Thread-safe, supports concurrent access from multiple Windows.
58/// All Firefox windows connect to the same port, messages routed by session.
59///
60/// # Example
61///
62/// ```ignore
63/// let pool = ConnectionPool::new().await?;
64/// println!("WebSocket URL: {}", pool.ws_url());
65///
66/// // Wait for a specific session to connect
67/// let ready_data = pool.wait_for_session(session_id).await?;
68///
69/// // Send a request to that session
70/// let response = pool.send(session_id, request).await?;
71/// ```
72pub struct ConnectionPool {
73    /// WebSocket server port.
74    port: u16,
75
76    /// Precomputed WebSocket URL string (e.g., `ws://127.0.0.1:12345`).
77    ws_url: String,
78
79    /// Active connections by session ID (wrapped in Arc for cheap cloning).
80    connections: RwLock<FxHashMap<SessionId, Arc<Connection>>>,
81
82    /// Waiters for pending sessions (spawn_window waiting for Firefox to connect).
83    waiters: Mutex<FxHashMap<SessionId, oneshot::Sender<ReadyData>>>,
84
85    /// Shutdown notification for the accept loop.
86    shutdown_notify: Arc<Notify>,
87}
88
89// ============================================================================
90// ConnectionPool - Constructor
91// ============================================================================
92
93impl ConnectionPool {
94    /// Creates a new connection pool and starts the accept loop.
95    ///
96    /// Binds to `localhost:0` (random available port).
97    ///
98    /// # Errors
99    ///
100    /// Returns [`Error::Io`] if binding fails.
101    pub async fn new() -> Result<Arc<Self>> {
102        Self::with_ip_port(DEFAULT_BIND_IP, 0).await
103    }
104
105    /// Creates a new connection pool bound to a specific port.
106    ///
107    /// # Arguments
108    ///
109    /// * `port` - Port to bind to (0 for random)
110    ///
111    /// # Errors
112    ///
113    /// Returns [`Error::Io`] if binding fails.
114    pub async fn with_port(port: u16) -> Result<Arc<Self>> {
115        Self::with_ip_port(DEFAULT_BIND_IP, port).await
116    }
117
118    /// Creates a new connection pool bound to a specific IP and port.
119    ///
120    /// # Arguments
121    ///
122    /// * `ip` - IP address to bind to
123    /// * `port` - Port to bind to (0 for random)
124    ///
125    /// # Errors
126    ///
127    /// Returns [`Error::Io`] if binding fails.
128    pub async fn with_ip_port(ip: IpAddr, port: u16) -> Result<Arc<Self>> {
129        let addr = SocketAddr::new(ip, port);
130        let listener = TcpListener::bind(addr).await?;
131        let bound_addr = listener.local_addr()?;
132        let actual_port = bound_addr.port();
133
134        debug!(port = actual_port, "ConnectionPool WebSocket server bound");
135
136        let shutdown_notify = Arc::new(Notify::new());
137        let ws_url = format!("ws://{}:{}", bound_addr.ip(), bound_addr.port());
138
139        let pool = Arc::new(Self {
140            port: actual_port,
141            ws_url,
142            connections: RwLock::new(FxHashMap::default()),
143            waiters: Mutex::new(FxHashMap::default()),
144            shutdown_notify: Arc::clone(&shutdown_notify),
145        });
146
147        // Spawn accept loop
148        let pool_clone = Arc::clone(&pool);
149        tokio::spawn(async move {
150            pool_clone.accept_loop(listener).await;
151        });
152
153        info!(port = actual_port, "ConnectionPool started");
154
155        Ok(pool)
156    }
157}
158
159// ============================================================================
160// ConnectionPool - Public API
161// ============================================================================
162
163impl ConnectionPool {
164    /// Returns the WebSocket URL for this pool.
165    ///
166    /// Uses the actual bound IP address instead of hardcoding 127.0.0.1.
167    ///
168    /// Format: `ws://{bound_ip}:{port}`
169    #[inline]
170    #[must_use]
171    pub fn ws_url(&self) -> &str {
172        &self.ws_url
173    }
174
175    /// Returns the port the pool is bound to.
176    #[inline]
177    #[must_use]
178    pub fn port(&self) -> u16 {
179        self.port
180    }
181
182    /// Returns the number of active connections.
183    #[inline]
184    #[must_use]
185    pub fn connection_count(&self) -> usize {
186        self.connections.read().len()
187    }
188
189    /// Waits for a specific session to connect.
190    ///
191    /// Called by `spawn_window` after launching Firefox.
192    /// Returns when Firefox with this sessionId connects and sends READY.
193    ///
194    /// # Arguments
195    ///
196    /// * `session_id` - The session ID to wait for
197    ///
198    /// # Errors
199    ///
200    /// - [`Error::ConnectionTimeout`] if session doesn't connect within 30s
201    pub async fn wait_for_session(&self, session_id: SessionId) -> Result<ReadyData> {
202        let (tx, rx) = oneshot::channel();
203
204        // Register waiter
205        {
206            let mut waiters = self.waiters.lock();
207            waiters.insert(session_id, tx);
208        }
209
210        // Wait with timeout
211        match timeout(SESSION_CONNECT_TIMEOUT, rx).await {
212            Ok(Ok(ready_data)) => {
213                debug!(session_id = %session_id, "Session connected");
214                Ok(ready_data)
215            }
216            Ok(Err(_)) => {
217                // Channel closed without sending - shouldn't happen
218                self.waiters.lock().remove(&session_id);
219                Err(Error::connection("Session waiter channel closed"))
220            }
221            Err(_) => {
222                // Timeout
223                self.waiters.lock().remove(&session_id);
224                Err(Error::connection_timeout(
225                    SESSION_CONNECT_TIMEOUT.as_millis() as u64,
226                ))
227            }
228        }
229    }
230
231    /// Sends a request to a specific session.
232    ///
233    /// # Arguments
234    ///
235    /// * `session_id` - Target session
236    /// * `request` - Request to send
237    ///
238    /// # Errors
239    ///
240    /// - [`Error::SessionNotFound`] if session doesn't exist
241    /// - [`Error::ConnectionClosed`] if connection is closed
242    /// - [`Error::RequestTimeout`] if response not received within timeout
243    pub async fn send(&self, session_id: SessionId, request: Request) -> Result<Response> {
244        let connection = {
245            let connections = self.connections.read();
246            connections
247                .get(&session_id)
248                .ok_or_else(|| Error::session_not_found(session_id))?
249                .clone()
250        };
251
252        connection.send(request).await
253    }
254
255    /// Sends a request with custom timeout.
256    ///
257    /// # Arguments
258    ///
259    /// * `session_id` - Target session
260    /// * `request` - Request to send
261    /// * `timeout` - Maximum time to wait for response
262    ///
263    /// # Errors
264    ///
265    /// - [`Error::SessionNotFound`] if session doesn't exist
266    /// - [`Error::ConnectionClosed`] if connection is closed
267    /// - [`Error::RequestTimeout`] if response not received within timeout
268    pub async fn send_with_timeout(
269        &self,
270        session_id: SessionId,
271        request: Request,
272        request_timeout: Duration,
273    ) -> Result<Response> {
274        let connection = {
275            let connections = self.connections.read();
276            connections
277                .get(&session_id)
278                .ok_or_else(|| Error::session_not_found(session_id))?
279                .clone()
280        };
281
282        connection.send_with_timeout(request, request_timeout).await
283    }
284}
285
286// ============================================================================
287// ConnectionPool - Event Handlers
288// ============================================================================
289
290impl ConnectionPool {
291    /// Adds an event handler for a session with a key label.
292    ///
293    /// Multiple handlers can be registered for the same session.
294    /// If a handler with the same key already exists, it is replaced.
295    ///
296    /// # Arguments
297    ///
298    /// * `session_id` - Target session
299    /// * `key` - Unique key for this handler (used for removal)
300    /// * `handler` - Event handler callback
301    pub fn add_event_handler(
302        &self,
303        session_id: SessionId,
304        key: String,
305        handler: crate::transport::EventHandler,
306    ) {
307        let connections = self.connections.read();
308        if let Some(connection) = connections.get(&session_id) {
309            connection.add_event_handler(key, handler);
310        }
311    }
312
313    /// Removes an event handler for a session by key.
314    ///
315    /// # Arguments
316    ///
317    /// * `session_id` - Target session
318    /// * `key` - Key of the handler to remove
319    pub fn remove_event_handler(&self, session_id: SessionId, key: &str) {
320        let connections = self.connections.read();
321        if let Some(connection) = connections.get(&session_id) {
322            connection.remove_event_handler(key);
323        }
324    }
325
326    /// Clears all event handlers for a session (for shutdown).
327    ///
328    /// # Arguments
329    ///
330    /// * `session_id` - Target session
331    pub fn clear_all_event_handlers(&self, session_id: SessionId) {
332        let connections = self.connections.read();
333        if let Some(connection) = connections.get(&session_id) {
334            connection.clear_all_event_handlers();
335        }
336    }
337}
338
339// ============================================================================
340// ConnectionPool - Lifecycle
341// ============================================================================
342
343impl ConnectionPool {
344    /// Removes a session from the pool.
345    ///
346    /// Called when a Window closes.
347    ///
348    /// # Arguments
349    ///
350    /// * `session_id` - Session to remove
351    pub fn remove(&self, session_id: SessionId) {
352        let removed = {
353            let mut connections = self.connections.write();
354            connections.remove(&session_id)
355        };
356
357        if let Some(connection) = removed {
358            connection.shutdown();
359            debug!(session_id = %session_id, "Session removed from pool");
360        }
361    }
362
363    /// Shuts down the pool and all connections.
364    pub async fn shutdown(&self) {
365        info!("ConnectionPool shutting down");
366
367        // Signal accept loop to stop
368        self.shutdown_notify.notify_one();
369
370        // Close all connections
371        let connections: Vec<_> = {
372            let mut map = self.connections.write();
373            map.drain().collect()
374        };
375
376        for (session_id, connection) in connections {
377            connection.shutdown();
378            debug!(session_id = %session_id, "Connection closed during shutdown");
379        }
380
381        // Cancel all waiters
382        let waiters: Vec<_> = {
383            let mut map = self.waiters.lock();
384            map.drain().collect()
385        };
386
387        drop(waiters); // Dropping senders will cause receivers to error
388
389        info!("ConnectionPool shutdown complete");
390    }
391}
392
393// ============================================================================
394// ConnectionPool - Accept Loop
395// ============================================================================
396
397impl ConnectionPool {
398    /// Background task that accepts new connections.
399    ///
400    /// Uses `tokio::select!` with a shutdown notification instead of
401    /// busy-polling with 100ms timeout.
402    async fn accept_loop(self: Arc<Self>, listener: TcpListener) {
403        debug!("Accept loop started");
404
405        loop {
406            tokio::select! {
407                result = listener.accept() => {
408                    match result {
409                        Ok((stream, addr)) => {
410                            let pool = Arc::clone(&self);
411                            tokio::spawn(async move {
412                                if let Err(e) = pool.handle_connection(stream, addr).await {
413                                    warn!(error = %e, ?addr, "Connection handling failed");
414                                }
415                            });
416                        }
417                        Err(e) => {
418                            error!(error = %e, "Accept failed");
419                        }
420                    }
421                }
422                _ = self.shutdown_notify.notified() => {
423                    debug!("Accept loop shutting down via notify");
424                    break;
425                }
426            }
427        }
428
429        debug!("Accept loop terminated");
430    }
431
432    /// Handles a single incoming connection.
433    async fn handle_connection(
434        &self,
435        stream: tokio::net::TcpStream,
436        addr: SocketAddr,
437    ) -> Result<()> {
438        debug!(?addr, "New TCP connection");
439
440        // Upgrade to WebSocket
441        let ws_stream = tokio_tungstenite::accept_async(stream)
442            .await
443            .map_err(|e| Error::connection(format!("WebSocket upgrade failed: {e}")))?;
444
445        info!(?addr, "WebSocket connection established");
446
447        // Create Connection and wait for READY
448        let connection = Connection::new(ws_stream);
449        let ready_data = connection.wait_ready().await?;
450
451        let session_id = SessionId::from_u32(ready_data.session_id)
452            .ok_or_else(|| Error::protocol("Invalid session_id in READY (must be > 0)"))?;
453
454        info!(session_id = %session_id, ?addr, "Session READY received");
455
456        // Store connection wrapped in Arc
457        {
458            let mut connections = self.connections.write();
459            connections.insert(session_id, Arc::new(connection));
460        }
461
462        // Notify waiter if any
463        {
464            let mut waiters = self.waiters.lock();
465            if let Some(tx) = waiters.remove(&session_id) {
466                let _ = tx.send(ready_data);
467            }
468        }
469
470        Ok(())
471    }
472}
473
474// ============================================================================
475// Tests
476// ============================================================================
477
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal status request addressed to the main frame of tab 1.
    fn status_request() -> Request {
        Request::new(
            crate::identifiers::TabId::new(1).unwrap(),
            crate::identifiers::FrameId::main(),
            crate::protocol::Command::Session(crate::protocol::SessionCommand::Status),
        )
    }

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        assert!(pool.port() > 0);
        assert!(pool.ws_url().starts_with("ws://127.0.0.1:"));
        assert_eq!(pool.connection_count(), 0);
        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_pool_ws_url_format() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let expected = format!("ws://127.0.0.1:{}", pool.port());
        assert_eq!(pool.ws_url(), expected);
        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_send_to_unknown_session() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();

        assert!(pool.send(session_id, status_request()).await.is_err());
        assert!(
            pool.send_with_timeout(session_id, status_request(), Duration::from_secs(1))
                .await
                .is_err()
        );

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_shutdown_cancels_pending_waiters() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();

        // Register a waiter directly: going through `wait_for_session` would
        // block for the full connect timeout.
        let (tx, rx) = oneshot::channel::<ReadyData>();
        pool.waiters.lock().insert(session_id, tx);

        pool.shutdown().await;

        assert!(pool.waiters.lock().is_empty(), "shutdown must drain all waiters");
        assert!(rx.await.is_err(), "pending waiter must observe a closed channel");
    }

    #[tokio::test]
    async fn test_remove_nonexistent_session() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();

        // Should not panic
        pool.remove(session_id);
        assert_eq!(pool.connection_count(), 0);

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_handler_ops_on_unknown_session_are_noops() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();

        // Should not panic
        pool.remove_event_handler(session_id, "missing");
        pool.clear_all_event_handlers(session_id);
        assert_eq!(pool.connection_count(), 0);

        pool.shutdown().await;
    }
}