use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use parking_lot::{Mutex, RwLock};
use rustc_hash::FxHashMap;
use tokio::net::TcpListener;
use tokio::sync::{Notify, oneshot};
use tokio::time::timeout;
use tracing::{debug, error, info, warn};
use crate::error::{Error, Result};
use crate::identifiers::SessionId;
use crate::protocol::{Request, Response};
use crate::transport::Connection;
use crate::transport::connection::ReadyData;
/// IP address used by `ConnectionPool::new` and `ConnectionPool::with_port`
/// when the caller does not choose one (the IPv4 loopback address).
const DEFAULT_BIND_IP: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
/// Upper bound on how long `ConnectionPool::wait_for_session` waits for a
/// session's READY data before reporting a connection timeout.
const SESSION_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// A set of WebSocket connections keyed by [`SessionId`], all accepted
/// through one TCP listener owned by a background accept loop.
pub struct ConnectionPool {
/// Port reported by the listener's local address after binding.
port: u16,
/// WebSocket URL built from the listener's bound address.
ws_url: String,
/// Connections registered after their READY data has been received.
connections: RwLock<FxHashMap<SessionId, Arc<Connection>>>,
/// One-shot senders for callers currently blocked in `wait_for_session`,
/// keyed by the session they are waiting for.
waiters: Mutex<FxHashMap<SessionId, oneshot::Sender<ReadyData>>>,
/// Notified by `shutdown` to make the accept loop exit.
shutdown_notify: Arc<Notify>,
}
impl ConnectionPool {
    /// Creates a pool listening on the default loopback address.
    ///
    /// Port `0` is requested, so the effective port is whatever the listener
    /// reports after binding; read it back with [`ConnectionPool::port`].
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot be bound or its local address
    /// cannot be queried.
    pub async fn new() -> Result<Arc<Self>> {
        Self::with_ip_port(DEFAULT_BIND_IP, 0).await
    }

    /// Creates a pool listening on the default loopback address and `port`.
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot be bound or its local address
    /// cannot be queried.
    pub async fn with_port(port: u16) -> Result<Arc<Self>> {
        Self::with_ip_port(DEFAULT_BIND_IP, port).await
    }

    /// Binds a listener on `ip:port`, builds the pool and spawns its accept
    /// loop on the current Tokio runtime.
    ///
    /// The spawned accept loop keeps its own `Arc` to the pool, so the pool
    /// stays alive until [`ConnectionPool::shutdown`] stops that loop.
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot be bound or its local address
    /// cannot be queried.
    pub async fn with_ip_port(ip: IpAddr, port: u16) -> Result<Arc<Self>> {
        let listener = TcpListener::bind(SocketAddr::new(ip, port)).await?;
        let bound_addr = listener.local_addr()?;
        let actual_port = bound_addr.port();
        debug!(port = actual_port, "ConnectionPool WebSocket server bound");

        // An IPv6 literal has to be wrapped in brackets inside a URL
        // authority; otherwise its colons cannot be told apart from the port
        // separator (e.g. `ws://::1:8080`). IPv4 output is unchanged.
        let ws_url = match bound_addr.ip() {
            IpAddr::V4(v4) => format!("ws://{}:{}", v4, actual_port),
            IpAddr::V6(v6) => format!("ws://[{}]:{}", v6, actual_port),
        };

        let pool = Arc::new(Self {
            port: actual_port,
            ws_url,
            connections: RwLock::new(FxHashMap::default()),
            waiters: Mutex::new(FxHashMap::default()),
            shutdown_notify: Arc::new(Notify::new()),
        });

        let pool_clone = Arc::clone(&pool);
        tokio::spawn(async move {
            pool_clone.accept_loop(listener).await;
        });

        info!(port = actual_port, "ConnectionPool started");
        Ok(pool)
    }
}
impl ConnectionPool {
#[inline]
#[must_use]
pub fn ws_url(&self) -> &str {
&self.ws_url
}
#[inline]
#[must_use]
pub fn port(&self) -> u16 {
self.port
}
#[inline]
#[must_use]
pub fn connection_count(&self) -> usize {
self.connections.read().len()
}
pub async fn wait_for_session(&self, session_id: SessionId) -> Result<ReadyData> {
let (tx, rx) = oneshot::channel();
{
let mut waiters = self.waiters.lock();
waiters.insert(session_id, tx);
}
match timeout(SESSION_CONNECT_TIMEOUT, rx).await {
Ok(Ok(ready_data)) => {
debug!(session_id = %session_id, "Session connected");
Ok(ready_data)
}
Ok(Err(_)) => {
self.waiters.lock().remove(&session_id);
Err(Error::connection("Session waiter channel closed"))
}
Err(_) => {
self.waiters.lock().remove(&session_id);
Err(Error::connection_timeout(
SESSION_CONNECT_TIMEOUT.as_millis() as u64,
))
}
}
}
pub async fn send(&self, session_id: SessionId, request: Request) -> Result<Response> {
let connection = {
let connections = self.connections.read();
connections
.get(&session_id)
.ok_or_else(|| Error::session_not_found(session_id))?
.clone()
};
connection.send(request).await
}
pub async fn send_with_timeout(
&self,
session_id: SessionId,
request: Request,
request_timeout: Duration,
) -> Result<Response> {
let connection = {
let connections = self.connections.read();
connections
.get(&session_id)
.ok_or_else(|| Error::session_not_found(session_id))?
.clone()
};
connection.send_with_timeout(request, request_timeout).await
}
}
impl ConnectionPool {
    /// Registers `handler` under `key` on the connection for `session_id`.
    ///
    /// Does nothing when no connection is registered for that session.
    pub fn add_event_handler(
        &self,
        session_id: SessionId,
        key: String,
        handler: crate::transport::EventHandler,
    ) {
        // The read guard is a temporary of the `if let` scrutinee, so it
        // stays held for the duration of the call on the connection.
        if let Some(target) = self.connections.read().get(&session_id) {
            target.add_event_handler(key, handler);
        }
    }

    /// Unregisters the handler stored under `key` on the connection for
    /// `session_id`.
    ///
    /// Does nothing when no connection is registered for that session.
    pub fn remove_event_handler(&self, session_id: SessionId, key: &str) {
        if let Some(target) = self.connections.read().get(&session_id) {
            target.remove_event_handler(key);
        }
    }

    /// Clears every event handler on the connection for `session_id`.
    ///
    /// Does nothing when no connection is registered for that session.
    pub fn clear_all_event_handlers(&self, session_id: SessionId) {
        if let Some(target) = self.connections.read().get(&session_id) {
            target.clear_all_event_handlers();
        }
    }
}
impl ConnectionPool {
pub fn remove(&self, session_id: SessionId) {
let removed = {
let mut connections = self.connections.write();
connections.remove(&session_id)
};
if let Some(connection) = removed {
connection.shutdown();
debug!(session_id = %session_id, "Session removed from pool");
}
}
pub async fn shutdown(&self) {
info!("ConnectionPool shutting down");
self.shutdown_notify.notify_one();
let connections: Vec<_> = {
let mut map = self.connections.write();
map.drain().collect()
};
for (session_id, connection) in connections {
connection.shutdown();
debug!(session_id = %session_id, "Connection closed during shutdown");
}
let waiters: Vec<_> = {
let mut map = self.waiters.lock();
map.drain().collect()
};
drop(waiters);
info!("ConnectionPool shutdown complete");
}
}
impl ConnectionPool {
async fn accept_loop(self: Arc<Self>, listener: TcpListener) {
debug!("Accept loop started");
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, addr)) => {
let pool = Arc::clone(&self);
tokio::spawn(async move {
if let Err(e) = pool.handle_connection(stream, addr).await {
warn!(error = %e, ?addr, "Connection handling failed");
}
});
}
Err(e) => {
error!(error = %e, "Accept failed");
}
}
}
_ = self.shutdown_notify.notified() => {
debug!("Accept loop shutting down via notify");
break;
}
}
}
debug!("Accept loop terminated");
}
async fn handle_connection(
&self,
stream: tokio::net::TcpStream,
addr: SocketAddr,
) -> Result<()> {
debug!(?addr, "New TCP connection");
let ws_stream = tokio_tungstenite::accept_async(stream)
.await
.map_err(|e| Error::connection(format!("WebSocket upgrade failed: {e}")))?;
info!(?addr, "WebSocket connection established");
let connection = Connection::new(ws_stream);
let ready_data = connection.wait_ready().await?;
let session_id = SessionId::from_u32(ready_data.session_id)
.ok_or_else(|| Error::protocol("Invalid session_id in READY (must be > 0)"))?;
info!(session_id = %session_id, ?addr, "Session READY received");
{
let mut connections = self.connections.write();
connections.insert(session_id, Arc::new(connection));
}
{
let mut waiters = self.waiters.lock();
if let Some(tx) = waiters.remove(&session_id) {
let _ = tx.send(ready_data);
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created pool reports a real port, a loopback URL and no
    /// connections.
    #[tokio::test]
    async fn test_pool_creation() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        assert!(pool.port() > 0);
        assert!(pool.ws_url().starts_with("ws://127.0.0.1:"));
        assert_eq!(pool.connection_count(), 0);
        pool.shutdown().await;
    }

    /// The URL is exactly `ws://127.0.0.1:<bound port>`.
    #[tokio::test]
    async fn test_pool_ws_url_format() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let url = pool.ws_url();
        let expected = format!("ws://127.0.0.1:{}", pool.port());
        assert_eq!(url, expected);
        pool.shutdown().await;
    }

    /// Sending to a session that never connected fails instead of hanging.
    #[tokio::test]
    async fn test_send_to_unknown_session() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();
        let request = crate::protocol::Request::new(
            crate::identifiers::TabId::new(1).unwrap(),
            crate::identifiers::FrameId::main(),
            crate::protocol::Command::Session(crate::protocol::SessionCommand::Status),
        );
        let result = pool.send(session_id, request).await;
        assert!(result.is_err());
        pool.shutdown().await;
    }

    /// Despite its name, this test does not wait for the 30 s timeout: it
    /// registers a waiter whose receiver is already gone and checks that
    /// `shutdown` clears it out of the pool.
    #[tokio::test]
    async fn test_wait_for_session_timeout() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();
        let (tx, rx) = oneshot::channel::<ReadyData>();
        pool.waiters.lock().insert(session_id, tx);
        drop(rx);
        assert_eq!(pool.waiters.lock().len(), 1);
        pool.shutdown().await;
        assert!(pool.waiters.lock().is_empty());
    }

    /// Removing a session that was never registered is a harmless no-op.
    #[tokio::test]
    async fn test_remove_nonexistent_session() {
        let pool = ConnectionPool::new().await.expect("pool creation");
        let session_id = SessionId::next();
        pool.remove(session_id);
        assert_eq!(pool.connection_count(), 0);
        pool.shutdown().await;
    }
}