firefox_webdriver/transport/pool.rs
//! Connection pool for multiplexed WebSocket connections.
//!
//! Manages multiple WebSocket connections keyed by `SessionId`.
//! All Firefox windows connect to the same port; messages are routed by session ID.
//!
//! # Architecture
//!
//! ```text
//! ┌───────────────────────────────────────────┐
//! │              ConnectionPool               │
//! │               (single port)               │
//! │  ┌─────────────────────────────────────┐  │
//! │  │  SessionId=1  →  Connection 1       │  │
//! │  │  SessionId=2  →  Connection 2       │  │
//! │  │  SessionId=3  →  Connection 3       │  │
//! │  └─────────────────────────────────────┘  │
//! └───────────────────────────────────────────┘
//! ```
19
20// ============================================================================
21// Imports
22// ============================================================================
23
24use std::net::{IpAddr, Ipv4Addr, SocketAddr};
25use std::sync::Arc;
26use std::sync::atomic::{AtomicBool, Ordering};
27use std::time::Duration;
28
29use parking_lot::{Mutex, RwLock};
30use rustc_hash::FxHashMap;
31use tokio::net::TcpListener;
32use tokio::sync::oneshot;
33use tokio::time::timeout;
34use tracing::{debug, error, info, warn};
35
36use crate::error::{Error, Result};
37use crate::identifiers::SessionId;
38use crate::protocol::{Request, Response};
39use crate::transport::Connection;
40use crate::transport::connection::ReadyData;
41
42// ============================================================================
43// Constants
44// ============================================================================
45
/// Address the WebSocket server binds to when the caller does not pick one:
/// the IPv4 loopback interface (`127.0.0.1`).
const DEFAULT_BIND_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

/// Upper bound on how long `wait_for_session` waits for a session to connect
/// and announce itself (30 seconds).
const SESSION_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
51
52// ============================================================================
53// ConnectionPool
54// ============================================================================
55
56/// Manages multiple WebSocket connections keyed by SessionId.
57///
58/// Thread-safe, supports concurrent access from multiple Windows.
59/// All Firefox windows connect to the same port, messages routed by session.
60///
61/// # Example
62///
63/// ```ignore
64/// let pool = ConnectionPool::new().await?;
65/// println!("WebSocket URL: {}", pool.ws_url());
66///
67/// // Wait for a specific session to connect
68/// let ready_data = pool.wait_for_session(session_id).await?;
69///
70/// // Send a request to that session
71/// let response = pool.send(session_id, request).await?;
72/// ```
73pub struct ConnectionPool {
74 /// WebSocket server port.
75 port: u16,
76
77 /// Active connections by session ID.
78 connections: RwLock<FxHashMap<SessionId, Connection>>,
79
80 /// Waiters for pending sessions (spawn_window waiting for Firefox to connect).
81 waiters: Mutex<FxHashMap<SessionId, oneshot::Sender<ReadyData>>>,
82
83 /// Shutdown flag.
84 shutdown: AtomicBool,
85}
86
87// ============================================================================
88// ConnectionPool - Constructor
89// ============================================================================
90
91impl ConnectionPool {
92 /// Creates a new connection pool and starts the accept loop.
93 ///
94 /// Binds to `localhost:0` (random available port).
95 ///
96 /// # Errors
97 ///
98 /// Returns [`Error::Io`] if binding fails.
99 pub async fn new() -> Result<Arc<Self>> {
100 Self::with_ip_port(DEFAULT_BIND_IP, 0).await
101 }
102
103 /// Creates a new connection pool bound to a specific port.
104 ///
105 /// # Arguments
106 ///
107 /// * `port` - Port to bind to (0 for random)
108 ///
109 /// # Errors
110 ///
111 /// Returns [`Error::Io`] if binding fails.
112 pub async fn with_port(port: u16) -> Result<Arc<Self>> {
113 Self::with_ip_port(DEFAULT_BIND_IP, port).await
114 }
115
116 /// Creates a new connection pool bound to a specific IP and port.
117 ///
118 /// # Arguments
119 ///
120 /// * `ip` - IP address to bind to
121 /// * `port` - Port to bind to (0 for random)
122 ///
123 /// # Errors
124 ///
125 /// Returns [`Error::Io`] if binding fails.
126 pub async fn with_ip_port(ip: IpAddr, port: u16) -> Result<Arc<Self>> {
127 let addr = SocketAddr::new(ip, port);
128 let listener = TcpListener::bind(addr).await?;
129 let actual_port = listener.local_addr()?.port();
130
131 debug!(port = actual_port, "ConnectionPool WebSocket server bound");
132
133 let pool = Arc::new(Self {
134 port: actual_port,
135 connections: RwLock::new(FxHashMap::default()),
136 waiters: Mutex::new(FxHashMap::default()),
137 shutdown: AtomicBool::new(false),
138 });
139
140 // Spawn accept loop
141 let pool_clone = Arc::clone(&pool);
142 tokio::spawn(async move {
143 pool_clone.accept_loop(listener).await;
144 });
145
146 info!(port = actual_port, "ConnectionPool started");
147
148 Ok(pool)
149 }
150}
151
152// ============================================================================
153// ConnectionPool - Public API
154// ============================================================================
155
156impl ConnectionPool {
157 /// Returns the WebSocket URL for this pool.
158 ///
159 /// Format: `ws://127.0.0.1:{port}`
160 #[inline]
161 #[must_use]
162 pub fn ws_url(&self) -> String {
163 format!("ws://127.0.0.1:{}", self.port)
164 }
165
166 /// Returns the port the pool is bound to.
167 #[inline]
168 #[must_use]
169 pub fn port(&self) -> u16 {
170 self.port
171 }
172
173 /// Returns the number of active connections.
174 #[inline]
175 #[must_use]
176 pub fn connection_count(&self) -> usize {
177 self.connections.read().len()
178 }
179
180 /// Waits for a specific session to connect.
181 ///
182 /// Called by `spawn_window` after launching Firefox.
183 /// Returns when Firefox with this sessionId connects and sends READY.
184 ///
185 /// # Arguments
186 ///
187 /// * `session_id` - The session ID to wait for
188 ///
189 /// # Errors
190 ///
191 /// - [`Error::ConnectionTimeout`] if session doesn't connect within 30s
192 pub async fn wait_for_session(&self, session_id: SessionId) -> Result<ReadyData> {
193 let (tx, rx) = oneshot::channel();
194
195 // Register waiter
196 {
197 let mut waiters = self.waiters.lock();
198 waiters.insert(session_id, tx);
199 }
200
201 // Wait with timeout
202 match timeout(SESSION_CONNECT_TIMEOUT, rx).await {
203 Ok(Ok(ready_data)) => {
204 debug!(session_id = %session_id, "Session connected");
205 Ok(ready_data)
206 }
207 Ok(Err(_)) => {
208 // Channel closed without sending - shouldn't happen
209 self.waiters.lock().remove(&session_id);
210 Err(Error::connection("Session waiter channel closed"))
211 }
212 Err(_) => {
213 // Timeout
214 self.waiters.lock().remove(&session_id);
215 Err(Error::connection_timeout(
216 SESSION_CONNECT_TIMEOUT.as_millis() as u64,
217 ))
218 }
219 }
220 }
221
222 /// Sends a request to a specific session.
223 ///
224 /// # Arguments
225 ///
226 /// * `session_id` - Target session
227 /// * `request` - Request to send
228 ///
229 /// # Errors
230 ///
231 /// - [`Error::SessionNotFound`] if session doesn't exist
232 /// - [`Error::ConnectionClosed`] if connection is closed
233 /// - [`Error::RequestTimeout`] if response not received within timeout
234 pub async fn send(&self, session_id: SessionId, request: Request) -> Result<Response> {
235 let connection = {
236 let connections = self.connections.read();
237 connections
238 .get(&session_id)
239 .ok_or_else(|| Error::session_not_found(session_id))?
240 .clone()
241 };
242
243 connection.send(request).await
244 }
245
246 /// Sends a request with custom timeout.
247 ///
248 /// # Arguments
249 ///
250 /// * `session_id` - Target session
251 /// * `request` - Request to send
252 /// * `timeout` - Maximum time to wait for response
253 ///
254 /// # Errors
255 ///
256 /// - [`Error::SessionNotFound`] if session doesn't exist
257 /// - [`Error::ConnectionClosed`] if connection is closed
258 /// - [`Error::RequestTimeout`] if response not received within timeout
259 pub async fn send_with_timeout(
260 &self,
261 session_id: SessionId,
262 request: Request,
263 request_timeout: Duration,
264 ) -> Result<Response> {
265 let connection = {
266 let connections = self.connections.read();
267 connections
268 .get(&session_id)
269 .ok_or_else(|| Error::session_not_found(session_id))?
270 .clone()
271 };
272
273 connection.send_with_timeout(request, request_timeout).await
274 }
275}
276
277// ============================================================================
278// ConnectionPool - Event Handlers
279// ============================================================================
280
281impl ConnectionPool {
282 /// Sets the event handler for a session.
283 ///
284 /// # Arguments
285 ///
286 /// * `session_id` - Target session
287 /// * `handler` - Event handler callback
288 pub fn set_event_handler(
289 &self,
290 session_id: SessionId,
291 handler: crate::transport::EventHandler,
292 ) {
293 let connections = self.connections.read();
294 if let Some(connection) = connections.get(&session_id) {
295 connection.set_event_handler(handler);
296 }
297 }
298
299 /// Clears the event handler for a session.
300 ///
301 /// # Arguments
302 ///
303 /// * `session_id` - Target session
304 pub fn clear_event_handler(&self, session_id: SessionId) {
305 let connections = self.connections.read();
306 if let Some(connection) = connections.get(&session_id) {
307 connection.clear_event_handler();
308 }
309 }
310}
311
312// ============================================================================
313// ConnectionPool - Lifecycle
314// ============================================================================
315
316impl ConnectionPool {
317 /// Removes a session from the pool.
318 ///
319 /// Called when a Window closes.
320 ///
321 /// # Arguments
322 ///
323 /// * `session_id` - Session to remove
324 pub fn remove(&self, session_id: SessionId) {
325 let removed = {
326 let mut connections = self.connections.write();
327 connections.remove(&session_id)
328 };
329
330 if let Some(connection) = removed {
331 connection.shutdown();
332 debug!(session_id = %session_id, "Session removed from pool");
333 }
334 }
335
336 /// Shuts down the pool and all connections.
337 pub async fn shutdown(&self) {
338 info!("ConnectionPool shutting down");
339
340 // Signal accept loop to stop
341 self.shutdown.store(true, Ordering::SeqCst);
342
343 // Close all connections
344 let connections: Vec<_> = {
345 let mut map = self.connections.write();
346 map.drain().collect()
347 };
348
349 for (session_id, connection) in connections {
350 connection.shutdown();
351 debug!(session_id = %session_id, "Connection closed during shutdown");
352 }
353
354 // Cancel all waiters
355 let waiters: Vec<_> = {
356 let mut map = self.waiters.lock();
357 map.drain().collect()
358 };
359
360 drop(waiters); // Dropping senders will cause receivers to error
361
362 info!("ConnectionPool shutdown complete");
363 }
364}
365
366// ============================================================================
367// ConnectionPool - Accept Loop
368// ============================================================================
369
370impl ConnectionPool {
371 /// Background task that accepts new connections.
372 async fn accept_loop(self: Arc<Self>, listener: TcpListener) {
373 debug!("Accept loop started");
374
375 loop {
376 // Check shutdown flag
377 if self.shutdown.load(Ordering::SeqCst) {
378 debug!("Accept loop shutting down");
379 break;
380 }
381
382 // Accept with timeout to allow checking shutdown flag
383 match timeout(Duration::from_millis(100), listener.accept()).await {
384 Ok(Ok((stream, addr))) => {
385 let pool = Arc::clone(&self);
386 tokio::spawn(async move {
387 if let Err(e) = pool.handle_connection(stream, addr).await {
388 warn!(error = %e, ?addr, "Connection handling failed");
389 }
390 });
391 }
392 Ok(Err(e)) => {
393 error!(error = %e, "Accept failed");
394 }
395 Err(_) => {
396 // Timeout - just continue to check shutdown flag
397 continue;
398 }
399 }
400 }
401
402 debug!("Accept loop terminated");
403 }
404
405 /// Handles a single incoming connection.
406 async fn handle_connection(
407 &self,
408 stream: tokio::net::TcpStream,
409 addr: SocketAddr,
410 ) -> Result<()> {
411 debug!(?addr, "New TCP connection");
412
413 // Upgrade to WebSocket
414 let ws_stream = tokio_tungstenite::accept_async(stream)
415 .await
416 .map_err(|e| Error::connection(format!("WebSocket upgrade failed: {e}")))?;
417
418 info!(?addr, "WebSocket connection established");
419
420 // Create Connection and wait for READY
421 let connection = Connection::new(ws_stream);
422 let ready_data = connection.wait_ready().await?;
423
424 let session_id = SessionId::from_u32(ready_data.session_id)
425 .ok_or_else(|| Error::protocol("Invalid session_id in READY (must be > 0)"))?;
426
427 info!(session_id = %session_id, ?addr, "Session READY received");
428
429 // Store connection
430 {
431 let mut connections = self.connections.write();
432 connections.insert(session_id, connection);
433 }
434
435 // Notify waiter if any
436 {
437 let mut waiters = self.waiters.lock();
438 if let Some(tx) = waiters.remove(&session_id) {
439 let _ = tx.send(ready_data);
440 }
441 }
442
443 Ok(())
444 }
445}
446
447// ============================================================================
448// Tests
449// ============================================================================
450
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        assert!(pool.port() > 0);
        assert!(pool.ws_url().starts_with("ws://127.0.0.1:"));
        assert_eq!(pool.connection_count(), 0);
        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_pool_ws_url_format() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let url = pool.ws_url();
        let expected = format!("ws://127.0.0.1:{}", pool.port());
        assert_eq!(url, expected);
        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_send_to_unknown_session() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();
        let request = crate::protocol::Request::new(
            crate::identifiers::TabId::new(1).unwrap(),
            crate::identifiers::FrameId::main(),
            crate::protocol::Command::Session(crate::protocol::SessionCommand::Status),
        );

        let result = pool.send(session_id, request).await;
        assert!(result.is_err());

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn test_shutdown_cancels_waiters() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();

        // Register a waiter by hand: going through `wait_for_session` would
        // block for the full 30s connect timeout.
        let (tx, rx) = oneshot::channel::<ReadyData>();
        pool.waiters.lock().insert(session_id, tx);

        pool.shutdown().await;

        // Shutdown drains the waiter map and drops the sender, which wakes the
        // receiver with an error instead of leaving it hanging.
        assert!(pool.waiters.lock().is_empty());
        assert!(rx.await.is_err());
    }

    #[tokio::test]
    async fn test_remove_nonexistent_session() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();

        // Should not panic
        pool.remove(session_id);
        assert_eq!(pool.connection_count(), 0);

        pool.shutdown().await;
    }
}