firefox_webdriver/transport/pool.rs
1//! Connection pool for multiplexed WebSocket connections.
2//!
3//! Manages multiple WebSocket connections keyed by SessionId.
4//! All Firefox windows connect to the same port, messages routed by session.
5//!
6//! # Architecture
7//!
8//! ```text
9//! ┌─────────────────────────────────────────┐
10//! │ ConnectionPool │
11//! │ (single port) │
12//! │ ┌─────────────────────────────────────┐│
13//! │ │ SessionId=1 → Arc<Connection> 1 ││
14//! │ │ SessionId=2 → Arc<Connection> 2 ││
15//! │ │ SessionId=3 → Arc<Connection> 3 ││
16//! │ └─────────────────────────────────────┘│
17//! └─────────────────────────────────────────┘
18//! ```
19
20// ============================================================================
21// Imports
22// ============================================================================
23
24use std::net::{IpAddr, Ipv4Addr, SocketAddr};
25use std::sync::Arc;
26use std::time::Duration;
27
28use parking_lot::{Mutex, RwLock};
29use rustc_hash::FxHashMap;
30use tokio::net::TcpListener;
31use tokio::sync::{Notify, oneshot};
32use tokio::time::timeout;
33use tracing::{debug, error, info, warn};
34
35use crate::error::{Error, Result};
36use crate::identifiers::SessionId;
37use crate::protocol::{Request, Response};
38use crate::transport::Connection;
39use crate::transport::connection::ReadyData;
40
41// ============================================================================
42// Constants
43// ============================================================================
44
/// Address the listener binds to when the caller does not choose one:
/// the IPv4 loopback address.
const DEFAULT_BIND_IP: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);

/// Upper bound on how long `wait_for_session` waits for a session's READY.
const SESSION_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
50
51// ============================================================================
52// ConnectionPool
53// ============================================================================
54
55/// Manages multiple WebSocket connections keyed by SessionId.
56///
57/// Thread-safe, supports concurrent access from multiple Windows.
58/// All Firefox windows connect to the same port, messages routed by session.
59///
60/// # Example
61///
62/// ```ignore
63/// let pool = ConnectionPool::new().await?;
64/// println!("WebSocket URL: {}", pool.ws_url());
65///
66/// // Wait for a specific session to connect
67/// let ready_data = pool.wait_for_session(session_id).await?;
68///
69/// // Send a request to that session
70/// let response = pool.send(session_id, request).await?;
71/// ```
72pub struct ConnectionPool {
73 /// WebSocket server port.
74 port: u16,
75
76 /// Precomputed WebSocket URL string (e.g., `ws://127.0.0.1:12345`).
77 ws_url: String,
78
79 /// Active connections by session ID (wrapped in Arc for cheap cloning).
80 connections: RwLock<FxHashMap<SessionId, Arc<Connection>>>,
81
82 /// Waiters for pending sessions (spawn_window waiting for Firefox to connect).
83 waiters: Mutex<FxHashMap<SessionId, oneshot::Sender<ReadyData>>>,
84
85 /// Shutdown notification for the accept loop.
86 shutdown_notify: Arc<Notify>,
87}
88
89// ============================================================================
90// ConnectionPool - Constructor
91// ============================================================================
92
93impl ConnectionPool {
94 /// Creates a new connection pool and starts the accept loop.
95 ///
96 /// Binds to `localhost:0` (random available port).
97 ///
98 /// # Errors
99 ///
100 /// Returns [`Error::Io`] if binding fails.
101 pub async fn new() -> Result<Arc<Self>> {
102 Self::with_ip_port(DEFAULT_BIND_IP, 0).await
103 }
104
105 /// Creates a new connection pool bound to a specific port.
106 ///
107 /// # Arguments
108 ///
109 /// * `port` - Port to bind to (0 for random)
110 ///
111 /// # Errors
112 ///
113 /// Returns [`Error::Io`] if binding fails.
114 pub async fn with_port(port: u16) -> Result<Arc<Self>> {
115 Self::with_ip_port(DEFAULT_BIND_IP, port).await
116 }
117
118 /// Creates a new connection pool bound to a specific IP and port.
119 ///
120 /// # Arguments
121 ///
122 /// * `ip` - IP address to bind to
123 /// * `port` - Port to bind to (0 for random)
124 ///
125 /// # Errors
126 ///
127 /// Returns [`Error::Io`] if binding fails.
128 pub async fn with_ip_port(ip: IpAddr, port: u16) -> Result<Arc<Self>> {
129 let addr = SocketAddr::new(ip, port);
130 let listener = TcpListener::bind(addr).await?;
131 let bound_addr = listener.local_addr()?;
132 let actual_port = bound_addr.port();
133
134 debug!(port = actual_port, "ConnectionPool WebSocket server bound");
135
136 let shutdown_notify = Arc::new(Notify::new());
137 let ws_url = format!("ws://{}:{}", bound_addr.ip(), bound_addr.port());
138
139 let pool = Arc::new(Self {
140 port: actual_port,
141 ws_url,
142 connections: RwLock::new(FxHashMap::default()),
143 waiters: Mutex::new(FxHashMap::default()),
144 shutdown_notify: Arc::clone(&shutdown_notify),
145 });
146
147 // Spawn accept loop
148 let pool_clone = Arc::clone(&pool);
149 tokio::spawn(async move {
150 pool_clone.accept_loop(listener).await;
151 });
152
153 info!(port = actual_port, "ConnectionPool started");
154
155 Ok(pool)
156 }
157}
158
159// ============================================================================
160// ConnectionPool - Public API
161// ============================================================================
162
163impl ConnectionPool {
164 /// Returns the WebSocket URL for this pool.
165 ///
166 /// Uses the actual bound IP address instead of hardcoding 127.0.0.1.
167 ///
168 /// Format: `ws://{bound_ip}:{port}`
169 #[inline]
170 #[must_use]
171 pub fn ws_url(&self) -> &str {
172 &self.ws_url
173 }
174
175 /// Returns the port the pool is bound to.
176 #[inline]
177 #[must_use]
178 pub fn port(&self) -> u16 {
179 self.port
180 }
181
182 /// Returns the number of active connections.
183 #[inline]
184 #[must_use]
185 pub fn connection_count(&self) -> usize {
186 self.connections.read().len()
187 }
188
189 /// Waits for a specific session to connect.
190 ///
191 /// Called by `spawn_window` after launching Firefox.
192 /// Returns when Firefox with this sessionId connects and sends READY.
193 ///
194 /// # Arguments
195 ///
196 /// * `session_id` - The session ID to wait for
197 ///
198 /// # Errors
199 ///
200 /// - [`Error::ConnectionTimeout`] if session doesn't connect within 30s
201 pub async fn wait_for_session(&self, session_id: SessionId) -> Result<ReadyData> {
202 let (tx, rx) = oneshot::channel();
203
204 // Register waiter
205 {
206 let mut waiters = self.waiters.lock();
207 waiters.insert(session_id, tx);
208 }
209
210 // Wait with timeout
211 match timeout(SESSION_CONNECT_TIMEOUT, rx).await {
212 Ok(Ok(ready_data)) => {
213 debug!(session_id = %session_id, "Session connected");
214 Ok(ready_data)
215 }
216 Ok(Err(_)) => {
217 // Channel closed without sending - shouldn't happen
218 self.waiters.lock().remove(&session_id);
219 Err(Error::connection("Session waiter channel closed"))
220 }
221 Err(_) => {
222 // Timeout
223 self.waiters.lock().remove(&session_id);
224 Err(Error::connection_timeout(
225 SESSION_CONNECT_TIMEOUT.as_millis() as u64,
226 ))
227 }
228 }
229 }
230
231 /// Sends a request to a specific session.
232 ///
233 /// # Arguments
234 ///
235 /// * `session_id` - Target session
236 /// * `request` - Request to send
237 ///
238 /// # Errors
239 ///
240 /// - [`Error::SessionNotFound`] if session doesn't exist
241 /// - [`Error::ConnectionClosed`] if connection is closed
242 /// - [`Error::RequestTimeout`] if response not received within timeout
243 pub async fn send(&self, session_id: SessionId, request: Request) -> Result<Response> {
244 let connection = {
245 let connections = self.connections.read();
246 connections
247 .get(&session_id)
248 .ok_or_else(|| Error::session_not_found(session_id))?
249 .clone()
250 };
251
252 connection.send(request).await
253 }
254
255 /// Sends a request with custom timeout.
256 ///
257 /// # Arguments
258 ///
259 /// * `session_id` - Target session
260 /// * `request` - Request to send
261 /// * `timeout` - Maximum time to wait for response
262 ///
263 /// # Errors
264 ///
265 /// - [`Error::SessionNotFound`] if session doesn't exist
266 /// - [`Error::ConnectionClosed`] if connection is closed
267 /// - [`Error::RequestTimeout`] if response not received within timeout
268 pub async fn send_with_timeout(
269 &self,
270 session_id: SessionId,
271 request: Request,
272 request_timeout: Duration,
273 ) -> Result<Response> {
274 let connection = {
275 let connections = self.connections.read();
276 connections
277 .get(&session_id)
278 .ok_or_else(|| Error::session_not_found(session_id))?
279 .clone()
280 };
281
282 connection.send_with_timeout(request, request_timeout).await
283 }
284}
285
286// ============================================================================
287// ConnectionPool - Event Handlers
288// ============================================================================
289
290impl ConnectionPool {
291 /// Adds an event handler for a session with a key label.
292 ///
293 /// Multiple handlers can be registered for the same session.
294 /// If a handler with the same key already exists, it is replaced.
295 ///
296 /// # Arguments
297 ///
298 /// * `session_id` - Target session
299 /// * `key` - Unique key for this handler (used for removal)
300 /// * `handler` - Event handler callback
301 pub fn add_event_handler(
302 &self,
303 session_id: SessionId,
304 key: String,
305 handler: crate::transport::EventHandler,
306 ) {
307 let connections = self.connections.read();
308 if let Some(connection) = connections.get(&session_id) {
309 connection.add_event_handler(key, handler);
310 }
311 }
312
313 /// Removes an event handler for a session by key.
314 ///
315 /// # Arguments
316 ///
317 /// * `session_id` - Target session
318 /// * `key` - Key of the handler to remove
319 pub fn remove_event_handler(&self, session_id: SessionId, key: &str) {
320 let connections = self.connections.read();
321 if let Some(connection) = connections.get(&session_id) {
322 connection.remove_event_handler(key);
323 }
324 }
325
326 /// Clears all event handlers for a session (for shutdown).
327 ///
328 /// # Arguments
329 ///
330 /// * `session_id` - Target session
331 pub fn clear_all_event_handlers(&self, session_id: SessionId) {
332 let connections = self.connections.read();
333 if let Some(connection) = connections.get(&session_id) {
334 connection.clear_all_event_handlers();
335 }
336 }
337}
338
339// ============================================================================
340// ConnectionPool - Lifecycle
341// ============================================================================
342
343impl ConnectionPool {
344 /// Removes a session from the pool.
345 ///
346 /// Called when a Window closes.
347 ///
348 /// # Arguments
349 ///
350 /// * `session_id` - Session to remove
351 pub fn remove(&self, session_id: SessionId) {
352 let removed = {
353 let mut connections = self.connections.write();
354 connections.remove(&session_id)
355 };
356
357 if let Some(connection) = removed {
358 connection.shutdown();
359 debug!(session_id = %session_id, "Session removed from pool");
360 }
361 }
362
363 /// Shuts down the pool and all connections.
364 pub async fn shutdown(&self) {
365 info!("ConnectionPool shutting down");
366
367 // Signal accept loop to stop
368 self.shutdown_notify.notify_one();
369
370 // Close all connections
371 let connections: Vec<_> = {
372 let mut map = self.connections.write();
373 map.drain().collect()
374 };
375
376 for (session_id, connection) in connections {
377 connection.shutdown();
378 debug!(session_id = %session_id, "Connection closed during shutdown");
379 }
380
381 // Cancel all waiters
382 let waiters: Vec<_> = {
383 let mut map = self.waiters.lock();
384 map.drain().collect()
385 };
386
387 drop(waiters); // Dropping senders will cause receivers to error
388
389 info!("ConnectionPool shutdown complete");
390 }
391}
392
393// ============================================================================
394// ConnectionPool - Accept Loop
395// ============================================================================
396
397impl ConnectionPool {
398 /// Background task that accepts new connections.
399 ///
400 /// Uses `tokio::select!` with a shutdown notification instead of
401 /// busy-polling with 100ms timeout.
402 async fn accept_loop(self: Arc<Self>, listener: TcpListener) {
403 debug!("Accept loop started");
404
405 loop {
406 tokio::select! {
407 result = listener.accept() => {
408 match result {
409 Ok((stream, addr)) => {
410 let pool = Arc::clone(&self);
411 tokio::spawn(async move {
412 if let Err(e) = pool.handle_connection(stream, addr).await {
413 warn!(error = %e, ?addr, "Connection handling failed");
414 }
415 });
416 }
417 Err(e) => {
418 error!(error = %e, "Accept failed");
419 }
420 }
421 }
422 _ = self.shutdown_notify.notified() => {
423 debug!("Accept loop shutting down via notify");
424 break;
425 }
426 }
427 }
428
429 debug!("Accept loop terminated");
430 }
431
432 /// Handles a single incoming connection.
433 async fn handle_connection(
434 &self,
435 stream: tokio::net::TcpStream,
436 addr: SocketAddr,
437 ) -> Result<()> {
438 debug!(?addr, "New TCP connection");
439
440 // Upgrade to WebSocket
441 let ws_stream = tokio_tungstenite::accept_async(stream)
442 .await
443 .map_err(|e| Error::connection(format!("WebSocket upgrade failed: {e}")))?;
444
445 info!(?addr, "WebSocket connection established");
446
447 // Create Connection and wait for READY
448 let connection = Connection::new(ws_stream);
449 let ready_data = connection.wait_ready().await?;
450
451 let session_id = SessionId::from_u32(ready_data.session_id)
452 .ok_or_else(|| Error::protocol("Invalid session_id in READY (must be > 0)"))?;
453
454 info!(session_id = %session_id, ?addr, "Session READY received");
455
456 // Store connection wrapped in Arc
457 {
458 let mut connections = self.connections.write();
459 connections.insert(session_id, Arc::new(connection));
460 }
461
462 // Notify waiter if any
463 {
464 let mut waiters = self.waiters.lock();
465 if let Some(tx) = waiters.remove(&session_id) {
466 let _ = tx.send(ready_data);
467 }
468 }
469
470 Ok(())
471 }
472}
473
474// ============================================================================
475// Tests
476// ============================================================================
477
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh pool is bound to a real port on loopback and holds no sessions.
    #[tokio::test]
    async fn test_pool_creation() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        assert!(pool.port() > 0);
        assert!(pool.ws_url().starts_with("ws://127.0.0.1:"));
        assert_eq!(pool.connection_count(), 0);
        pool.shutdown().await;
    }

    /// The URL is exactly `ws://<loopback>:<bound port>`.
    #[tokio::test]
    async fn test_pool_ws_url_format() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let url = pool.ws_url();
        let expected = format!("ws://127.0.0.1:{}", pool.port());
        assert_eq!(url, expected);
        pool.shutdown().await;
    }

    /// Sending to a session that never connected fails instead of hanging.
    #[tokio::test]
    async fn test_send_to_unknown_session() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();
        let request = crate::protocol::Request::new(
            crate::identifiers::TabId::new(1).unwrap(),
            crate::identifiers::FrameId::main(),
            crate::protocol::Command::Session(crate::protocol::SessionCommand::Status),
        );

        let result = pool.send(session_id, request).await;
        assert!(result.is_err());

        pool.shutdown().await;
    }

    /// Shutdown drops every registered waiter.
    ///
    /// The real 30s timeout of `wait_for_session` is not exercised here; the
    /// waiter is registered by hand to keep the test fast.
    #[tokio::test]
    async fn test_shutdown_clears_pending_waiters() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();

        let (tx, rx) = oneshot::channel::<ReadyData>();
        pool.waiters.lock().insert(session_id, tx);
        assert_eq!(pool.waiters.lock().len(), 1);

        // Simulate a caller that stopped waiting before the session connected.
        drop(rx);

        pool.shutdown().await;
        assert!(pool.waiters.lock().is_empty());
    }

    /// Removing a session that was never added is a harmless no-op.
    #[tokio::test]
    async fn test_remove_nonexistent_session() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();

        // Should not panic
        pool.remove(session_id);
        assert_eq!(pool.connection_count(), 0);

        pool.shutdown().await;
    }
}
543}