firefox_webdriver/transport/
pool.rs

1//! Connection pool for multiplexed WebSocket connections.
2//!
3//! Manages multiple WebSocket connections keyed by SessionId.
//! All Firefox windows connect to the same port; messages are routed by session ID.
5//!
6//! # Architecture
7//!
8//! ```text
9//! ┌─────────────────────────────────────────┐
10//! │           ConnectionPool                │
11//! │           (single port)                 │
12//! │  ┌─────────────────────────────────┐   │
13//! │  │ SessionId=1 → Connection 1      │   │
14//! │  │ SessionId=2 → Connection 2      │   │
15//! │  │ SessionId=3 → Connection 3      │   │
16//! │  └─────────────────────────────────┘   │
17//! └─────────────────────────────────────────┘
18//! ```
19
20// ============================================================================
21// Imports
22// ============================================================================
23
24use std::net::{IpAddr, Ipv4Addr, SocketAddr};
25use std::sync::Arc;
26use std::sync::atomic::{AtomicBool, Ordering};
27use std::time::Duration;
28
29use parking_lot::{Mutex, RwLock};
30use rustc_hash::FxHashMap;
31use tokio::net::TcpListener;
32use tokio::sync::oneshot;
33use tokio::time::timeout;
34use tracing::{debug, error, info, warn};
35
36use crate::error::{Error, Result};
37use crate::identifiers::SessionId;
38use crate::protocol::{Request, Response};
39use crate::transport::Connection;
40use crate::transport::connection::ReadyData;
41
42// ============================================================================
43// Constants
44// ============================================================================
45
/// Loopback address the WebSocket server listens on unless a caller supplies another IP.
const DEFAULT_BIND_IP: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);

/// How long `wait_for_session` waits for a Firefox instance to connect and send READY.
const SESSION_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
51
52// ============================================================================
53// ConnectionPool
54// ============================================================================
55
56/// Manages multiple WebSocket connections keyed by SessionId.
57///
58/// Thread-safe, supports concurrent access from multiple Windows.
59/// All Firefox windows connect to the same port, messages routed by session.
60///
61/// # Example
62///
63/// ```ignore
64/// let pool = ConnectionPool::new().await?;
65/// println!("WebSocket URL: {}", pool.ws_url());
66///
67/// // Wait for a specific session to connect
68/// let ready_data = pool.wait_for_session(session_id).await?;
69///
70/// // Send a request to that session
71/// let response = pool.send(session_id, request).await?;
72/// ```
73pub struct ConnectionPool {
74    /// WebSocket server port.
75    port: u16,
76
77    /// Active connections by session ID.
78    connections: RwLock<FxHashMap<SessionId, Connection>>,
79
80    /// Waiters for pending sessions (spawn_window waiting for Firefox to connect).
81    waiters: Mutex<FxHashMap<SessionId, oneshot::Sender<ReadyData>>>,
82
83    /// Shutdown flag.
84    shutdown: AtomicBool,
85}
86
87// ============================================================================
88// ConnectionPool - Constructor
89// ============================================================================
90
91impl ConnectionPool {
92    /// Creates a new connection pool and starts the accept loop.
93    ///
94    /// Binds to `localhost:0` (random available port).
95    ///
96    /// # Errors
97    ///
98    /// Returns [`Error::Io`] if binding fails.
99    pub async fn new() -> Result<Arc<Self>> {
100        Self::with_ip_port(DEFAULT_BIND_IP, 0).await
101    }
102
103    /// Creates a new connection pool bound to a specific port.
104    ///
105    /// # Arguments
106    ///
107    /// * `port` - Port to bind to (0 for random)
108    ///
109    /// # Errors
110    ///
111    /// Returns [`Error::Io`] if binding fails.
112    pub async fn with_port(port: u16) -> Result<Arc<Self>> {
113        Self::with_ip_port(DEFAULT_BIND_IP, port).await
114    }
115
116    /// Creates a new connection pool bound to a specific IP and port.
117    ///
118    /// # Arguments
119    ///
120    /// * `ip` - IP address to bind to
121    /// * `port` - Port to bind to (0 for random)
122    ///
123    /// # Errors
124    ///
125    /// Returns [`Error::Io`] if binding fails.
126    pub async fn with_ip_port(ip: IpAddr, port: u16) -> Result<Arc<Self>> {
127        let addr = SocketAddr::new(ip, port);
128        let listener = TcpListener::bind(addr).await?;
129        let actual_port = listener.local_addr()?.port();
130
131        debug!(port = actual_port, "ConnectionPool WebSocket server bound");
132
133        let pool = Arc::new(Self {
134            port: actual_port,
135            connections: RwLock::new(FxHashMap::default()),
136            waiters: Mutex::new(FxHashMap::default()),
137            shutdown: AtomicBool::new(false),
138        });
139
140        // Spawn accept loop
141        let pool_clone = Arc::clone(&pool);
142        tokio::spawn(async move {
143            pool_clone.accept_loop(listener).await;
144        });
145
146        info!(port = actual_port, "ConnectionPool started");
147
148        Ok(pool)
149    }
150}
151
152// ============================================================================
153// ConnectionPool - Public API
154// ============================================================================
155
156impl ConnectionPool {
157    /// Returns the WebSocket URL for this pool.
158    ///
159    /// Format: `ws://127.0.0.1:{port}`
160    #[inline]
161    #[must_use]
162    pub fn ws_url(&self) -> String {
163        format!("ws://127.0.0.1:{}", self.port)
164    }
165
166    /// Returns the port the pool is bound to.
167    #[inline]
168    #[must_use]
169    pub fn port(&self) -> u16 {
170        self.port
171    }
172
173    /// Returns the number of active connections.
174    #[inline]
175    #[must_use]
176    pub fn connection_count(&self) -> usize {
177        self.connections.read().len()
178    }
179
180    /// Waits for a specific session to connect.
181    ///
182    /// Called by `spawn_window` after launching Firefox.
183    /// Returns when Firefox with this sessionId connects and sends READY.
184    ///
185    /// # Arguments
186    ///
187    /// * `session_id` - The session ID to wait for
188    ///
189    /// # Errors
190    ///
191    /// - [`Error::ConnectionTimeout`] if session doesn't connect within 30s
192    pub async fn wait_for_session(&self, session_id: SessionId) -> Result<ReadyData> {
193        let (tx, rx) = oneshot::channel();
194
195        // Register waiter
196        {
197            let mut waiters = self.waiters.lock();
198            waiters.insert(session_id, tx);
199        }
200
201        // Wait with timeout
202        match timeout(SESSION_CONNECT_TIMEOUT, rx).await {
203            Ok(Ok(ready_data)) => {
204                debug!(session_id = %session_id, "Session connected");
205                Ok(ready_data)
206            }
207            Ok(Err(_)) => {
208                // Channel closed without sending - shouldn't happen
209                self.waiters.lock().remove(&session_id);
210                Err(Error::connection("Session waiter channel closed"))
211            }
212            Err(_) => {
213                // Timeout
214                self.waiters.lock().remove(&session_id);
215                Err(Error::connection_timeout(
216                    SESSION_CONNECT_TIMEOUT.as_millis() as u64,
217                ))
218            }
219        }
220    }
221
222    /// Sends a request to a specific session.
223    ///
224    /// # Arguments
225    ///
226    /// * `session_id` - Target session
227    /// * `request` - Request to send
228    ///
229    /// # Errors
230    ///
231    /// - [`Error::SessionNotFound`] if session doesn't exist
232    /// - [`Error::ConnectionClosed`] if connection is closed
233    /// - [`Error::RequestTimeout`] if response not received within timeout
234    pub async fn send(&self, session_id: SessionId, request: Request) -> Result<Response> {
235        let connection = {
236            let connections = self.connections.read();
237            connections
238                .get(&session_id)
239                .ok_or_else(|| Error::session_not_found(session_id))?
240                .clone()
241        };
242
243        connection.send(request).await
244    }
245
246    /// Sends a request with custom timeout.
247    ///
248    /// # Arguments
249    ///
250    /// * `session_id` - Target session
251    /// * `request` - Request to send
252    /// * `timeout` - Maximum time to wait for response
253    ///
254    /// # Errors
255    ///
256    /// - [`Error::SessionNotFound`] if session doesn't exist
257    /// - [`Error::ConnectionClosed`] if connection is closed
258    /// - [`Error::RequestTimeout`] if response not received within timeout
259    pub async fn send_with_timeout(
260        &self,
261        session_id: SessionId,
262        request: Request,
263        request_timeout: Duration,
264    ) -> Result<Response> {
265        let connection = {
266            let connections = self.connections.read();
267            connections
268                .get(&session_id)
269                .ok_or_else(|| Error::session_not_found(session_id))?
270                .clone()
271        };
272
273        connection.send_with_timeout(request, request_timeout).await
274    }
275}
276
277// ============================================================================
278// ConnectionPool - Event Handlers
279// ============================================================================
280
281impl ConnectionPool {
282    /// Sets the event handler for a session.
283    ///
284    /// # Arguments
285    ///
286    /// * `session_id` - Target session
287    /// * `handler` - Event handler callback
288    pub fn set_event_handler(
289        &self,
290        session_id: SessionId,
291        handler: crate::transport::EventHandler,
292    ) {
293        let connections = self.connections.read();
294        if let Some(connection) = connections.get(&session_id) {
295            connection.set_event_handler(handler);
296        }
297    }
298
299    /// Clears the event handler for a session.
300    ///
301    /// # Arguments
302    ///
303    /// * `session_id` - Target session
304    pub fn clear_event_handler(&self, session_id: SessionId) {
305        let connections = self.connections.read();
306        if let Some(connection) = connections.get(&session_id) {
307            connection.clear_event_handler();
308        }
309    }
310}
311
312// ============================================================================
313// ConnectionPool - Lifecycle
314// ============================================================================
315
316impl ConnectionPool {
317    /// Removes a session from the pool.
318    ///
319    /// Called when a Window closes.
320    ///
321    /// # Arguments
322    ///
323    /// * `session_id` - Session to remove
324    pub fn remove(&self, session_id: SessionId) {
325        let removed = {
326            let mut connections = self.connections.write();
327            connections.remove(&session_id)
328        };
329
330        if let Some(connection) = removed {
331            connection.shutdown();
332            debug!(session_id = %session_id, "Session removed from pool");
333        }
334    }
335
336    /// Shuts down the pool and all connections.
337    pub async fn shutdown(&self) {
338        info!("ConnectionPool shutting down");
339
340        // Signal accept loop to stop
341        self.shutdown.store(true, Ordering::SeqCst);
342
343        // Close all connections
344        let connections: Vec<_> = {
345            let mut map = self.connections.write();
346            map.drain().collect()
347        };
348
349        for (session_id, connection) in connections {
350            connection.shutdown();
351            debug!(session_id = %session_id, "Connection closed during shutdown");
352        }
353
354        // Cancel all waiters
355        let waiters: Vec<_> = {
356            let mut map = self.waiters.lock();
357            map.drain().collect()
358        };
359
360        drop(waiters); // Dropping senders will cause receivers to error
361
362        info!("ConnectionPool shutdown complete");
363    }
364}
365
366// ============================================================================
367// ConnectionPool - Accept Loop
368// ============================================================================
369
370impl ConnectionPool {
371    /// Background task that accepts new connections.
372    async fn accept_loop(self: Arc<Self>, listener: TcpListener) {
373        debug!("Accept loop started");
374
375        loop {
376            // Check shutdown flag
377            if self.shutdown.load(Ordering::SeqCst) {
378                debug!("Accept loop shutting down");
379                break;
380            }
381
382            // Accept with timeout to allow checking shutdown flag
383            match timeout(Duration::from_millis(100), listener.accept()).await {
384                Ok(Ok((stream, addr))) => {
385                    let pool = Arc::clone(&self);
386                    tokio::spawn(async move {
387                        if let Err(e) = pool.handle_connection(stream, addr).await {
388                            warn!(error = %e, ?addr, "Connection handling failed");
389                        }
390                    });
391                }
392                Ok(Err(e)) => {
393                    error!(error = %e, "Accept failed");
394                }
395                Err(_) => {
396                    // Timeout - just continue to check shutdown flag
397                    continue;
398                }
399            }
400        }
401
402        debug!("Accept loop terminated");
403    }
404
405    /// Handles a single incoming connection.
406    async fn handle_connection(
407        &self,
408        stream: tokio::net::TcpStream,
409        addr: SocketAddr,
410    ) -> Result<()> {
411        debug!(?addr, "New TCP connection");
412
413        // Upgrade to WebSocket
414        let ws_stream = tokio_tungstenite::accept_async(stream)
415            .await
416            .map_err(|e| Error::connection(format!("WebSocket upgrade failed: {e}")))?;
417
418        info!(?addr, "WebSocket connection established");
419
420        // Create Connection and wait for READY
421        let connection = Connection::new(ws_stream);
422        let ready_data = connection.wait_ready().await?;
423
424        let session_id = SessionId::from_u32(ready_data.session_id)
425            .ok_or_else(|| Error::protocol("Invalid session_id in READY (must be > 0)"))?;
426
427        info!(session_id = %session_id, ?addr, "Session READY received");
428
429        // Store connection
430        {
431            let mut connections = self.connections.write();
432            connections.insert(session_id, connection);
433        }
434
435        // Notify waiter if any
436        {
437            let mut waiters = self.waiters.lock();
438            if let Some(tx) = waiters.remove(&session_id) {
439                let _ = tx.send(ready_data);
440            }
441        }
442
443        Ok(())
444    }
445}
446
447// ============================================================================
448// Tests
449// ============================================================================
450
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        assert!(pool.port() > 0);
        assert!(pool.ws_url().starts_with("ws://127.0.0.1:"));
        assert_eq!(pool.connection_count(), 0);
        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_pool_ws_url_format() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let url = pool.ws_url();
        let expected = format!("ws://127.0.0.1:{}", pool.port());
        assert_eq!(url, expected);
        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_send_to_unknown_session() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();
        let request = crate::protocol::Request::new(
            crate::identifiers::TabId::new(1).unwrap(),
            crate::identifiers::FrameId::main(),
            crate::protocol::Command::Session(crate::protocol::SessionCommand::Status),
        );

        let result = pool.send(session_id, request).await;
        assert!(result.is_err());

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_shutdown_cancels_pending_waiters() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();

        // Register a waiter directly so the test does not have to sit through
        // the 30s connect timeout of `wait_for_session`.
        let (tx, rx) = oneshot::channel::<ReadyData>();
        pool.waiters.lock().insert(session_id, tx);

        pool.shutdown().await;

        assert!(pool.waiters.lock().is_empty(), "shutdown must drain all waiters");
        assert!(rx.await.is_err(), "a pending waiter must observe cancellation");
    }

    #[tokio::test]
    async fn test_remove_nonexistent_session() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();

        // Should not panic
        pool.remove(session_id);
        assert_eq!(pool.connection_count(), 0);

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_shutdown_is_idempotent() {
        let pool = ConnectionPool::new().await.expect("pool creation");

        pool.shutdown().await;
        // A second shutdown must be a harmless no-op.
        pool.shutdown().await;

        assert_eq!(pool.connection_count(), 0);
    }
}