use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::LazyLock;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
/// Maximum number of idle connections kept for a single `PoolKey`.
const MAX_IDLE_PER_KEY: usize = 4;
/// Maximum number of idle connections kept across all keys combined.
const MAX_IDLE_TOTAL: usize = 8;
/// How long a connection may sit unused in the pool; past this it is dropped
/// instead of being reused.
// NOTE(review): 55 s is presumably chosen to stay below a 60 s server-side idle
// timeout — confirm against the upstream service.
const IDLE_TIMEOUT: Duration = Duration::from_secs(55);
/// Maximum age of a connection, measured from its creation instant; older
/// connections are neither accepted back into the pool nor handed out.
const MAX_LIFETIME: Duration = Duration::from_secs(300);
/// WebSocket stream type stored in and handed out by the pool: a
/// tokio-tungstenite stream over a `MaybeTlsStream`-wrapped tokio TCP stream.
pub type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
/// Identifies a bucket of interchangeable pooled connections.
///
/// Two keys compare equal when both the endpoint URL and the hash of the API
/// key match. Only a 64-bit hash of the API key is stored, never the key text.
#[derive(Clone, Hash, Eq, PartialEq)]
pub struct PoolKey {
    /// Endpoint URL, copied verbatim from the constructor argument.
    url: String,
    /// `DefaultHasher` digest of the API key.
    key_hash: u64,
}

impl PoolKey {
    /// Builds a key from an endpoint `url` and the `api_key` used with it.
    ///
    /// The API key is reduced to a `u64` digest with a freshly created
    /// `DefaultHasher`, so equal API keys always produce equal digests.
    pub fn new(url: &str, api_key: &str) -> Self {
        let key_hash = {
            let mut state = DefaultHasher::new();
            api_key.hash(&mut state);
            state.finish()
        };

        Self {
            url: url.to_owned(),
            key_hash,
        }
    }
}
/// A pooled connection together with the timestamps used for expiry checks.
struct IdleConnection {
/// The idle WebSocket stream itself.
stream: WsStream,
/// When the connection was handed back to the pool; compared against `IDLE_TIMEOUT`.
returned_at: Instant,
/// Creation instant supplied by the caller on return; compared against `MAX_LIFETIME`.
created_at: Instant,
}
/// Pool of idle WebSocket connections grouped by `PoolKey`.
pub struct WsPool {
/// Idle connections per key, guarded by an async (tokio) mutex. Each bucket is
/// kept in return order: `return_conn` pushes to the back and `checkout` pops
/// from the back, so the most recently returned connection is tried first.
connections: Mutex<HashMap<PoolKey, Vec<IdleConnection>>>,
}
impl WsPool {
    /// Returns the process-wide pool, created lazily on first access.
    pub fn shared() -> &'static Self {
        static POOL: LazyLock<WsPool> = LazyLock::new(WsPool::new);
        &POOL
    }

    /// Creates an empty pool.
    fn new() -> Self {
        Self {
            connections: Mutex::new(HashMap::new()),
        }
    }

    /// Takes a reusable connection for `key` out of the pool, if one exists.
    ///
    /// Candidates are tried most-recently-returned first. Any candidate that
    /// has been idle longer than `IDLE_TIMEOUT` or is older than
    /// `MAX_LIFETIME` is dropped (closing it) and the next one is tried.
    ///
    /// Returns the stream together with its original creation instant, which
    /// the caller must pass back to [`WsPool::return_conn`] so the lifetime
    /// limit keeps being measured from creation. Returns `None` when no usable
    /// connection is pooled for `key`.
    pub async fn checkout(&self, key: &PoolKey) -> Option<(WsStream, Instant)> {
        let mut map = self.connections.lock().await;
        let bucket = map.get_mut(key)?;
        let now = Instant::now();

        let mut reusable = None;
        while let Some(candidate) = bucket.pop() {
            if now.duration_since(candidate.returned_at) > IDLE_TIMEOUT {
                tracing::debug!("ws_pool: dropping idle-timeout connection");
            } else if now.duration_since(candidate.created_at) > MAX_LIFETIME {
                tracing::debug!("ws_pool: dropping max-lifetime connection");
            } else {
                reusable = Some(candidate);
                break;
            }
        }

        // Never leave an empty bucket behind, whether or not a connection was found.
        if bucket.is_empty() {
            map.remove(key);
        }

        let entry = reusable?;
        tracing::debug!("ws_pool: reusing pooled connection");
        Some((entry.stream, entry.created_at))
    }

    /// Hands a connection back to the pool for later reuse under `key`.
    ///
    /// `created_at` is the instant the connection was originally opened (as
    /// returned by [`WsPool::checkout`] for reused connections). A connection
    /// already older than `MAX_LIFETIME` is dropped instead of being pooled.
    ///
    /// Before the caps are enforced, every already-expired idle connection is
    /// purged from the whole pool. Without this, cap eviction (which picks
    /// victims purely by return time) could discard a still-usable connection
    /// while one past its maximum lifetime kept its slot, and expired sockets
    /// for keys that are never checked out again would stay open indefinitely.
    ///
    /// After purging, at most `MAX_IDLE_PER_KEY` connections are kept for
    /// `key` and at most `MAX_IDLE_TOTAL` overall; the least recently returned
    /// connection is evicted when a cap is exceeded.
    pub async fn return_conn(&self, key: PoolKey, stream: WsStream, created_at: Instant) {
        let now = Instant::now();
        if now.duration_since(created_at) > MAX_LIFETIME {
            tracing::debug!("ws_pool: not returning max-lifetime connection");
            return;
        }

        let mut map = self.connections.lock().await;
        Self::purge_expired(&mut map, now);

        let bucket = map.entry(key).or_default();
        if bucket.len() >= MAX_IDLE_PER_KEY {
            tracing::debug!("ws_pool: per-key cap reached, dropping oldest");
            // Buckets are in return order, so index 0 is the least recently returned.
            bucket.remove(0);
        }
        bucket.push(IdleConnection {
            stream,
            returned_at: now,
            created_at,
        });

        let total: usize = map.values().map(Vec::len).sum();
        if total > MAX_IDLE_TOTAL {
            tracing::debug!("ws_pool: global cap reached, evicting oldest");
            Self::evict_oldest(&mut map);
        }
    }

    /// Drops every pooled connection that is past `IDLE_TIMEOUT` or
    /// `MAX_LIFETIME` as of `now`, and removes buckets left empty.
    fn purge_expired(map: &mut HashMap<PoolKey, Vec<IdleConnection>>, now: Instant) {
        map.retain(|_, bucket| {
            let before = bucket.len();
            bucket.retain(|conn| {
                now.duration_since(conn.returned_at) <= IDLE_TIMEOUT
                    && now.duration_since(conn.created_at) <= MAX_LIFETIME
            });
            let dropped = before - bucket.len();
            if dropped > 0 {
                tracing::debug!("ws_pool: purged {} expired connection(s)", dropped);
            }
            !bucket.is_empty()
        });
    }

    /// Removes the single connection with the earliest `returned_at` across
    /// all buckets, deleting its bucket if that leaves it empty.
    ///
    /// Does nothing when the pool holds no connections.
    fn evict_oldest(map: &mut HashMap<PoolKey, Vec<IdleConnection>>) {
        // Locate the victim first; the key is cloned only once, for the winner,
        // so the shared borrow of `map` ends before it is mutated.
        let victim = map
            .iter()
            .flat_map(|(key, bucket)| {
                bucket
                    .iter()
                    .enumerate()
                    .map(move |(idx, conn)| (conn.returned_at, key, idx))
            })
            .min_by_key(|&(returned_at, _, _)| returned_at)
            .map(|(_, key, idx)| (key.clone(), idx));

        let Some((key, idx)) = victim else {
            return;
        };

        if let Some(bucket) = map.get_mut(&key) {
            bucket.remove(idx);
            if bucket.is_empty() {
                map.remove(&key);
            }
        }
    }
}
// Tests for the connection pool. They open real loopback WebSocket connections,
// so they need a tokio runtime and permission to bind 127.0.0.1.
#[cfg(test)]
mod tests {
use super::*;
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::tungstenite::Message;
// Server half of a test connection: accepted directly on a plain TCP stream.
type ServerWsStream = WebSocketStream<tokio::net::TcpStream>;
/// Opens a loopback WebSocket connection and returns `(client, server)` ends.
///
/// The listener binds to port 0 so the OS picks a free port. The accept side
/// runs in a spawned task so the WebSocket handshake can complete while this
/// function drives the client side of it.
async fn make_ws_pair() -> (WsStream, ServerWsStream) {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind");
let addr = listener.local_addr().expect("addr");
let server_handle = tokio::spawn(async move {
let (tcp, _) = listener.accept().await.expect("accept");
tokio_tungstenite::accept_async(tcp)
.await
.expect("ws accept")
});
let url = format!("ws://127.0.0.1:{}", addr.port());
let (client, _) = tokio_tungstenite::connect_async(&url)
.await
.expect("connect");
let server = server_handle.await.expect("join server");
(client, server)
}
/// Shorthand for building a `PoolKey` in tests.
fn test_key(url: &str, secret: &str) -> PoolKey {
PoolKey::new(url, secret)
}
/// Checking out from a pool that has nothing stored for the key yields `None`.
#[tokio::test]
async fn checkout_empty_returns_none() {
let pool = WsPool::new();
let key = test_key("wss://api.openai.com/v1/responses", "sk-test");
assert!(pool.checkout(&key).await.is_none());
}
/// A returned connection can be checked out once, keeps its creation instant,
/// is still usable, and is gone from the pool afterwards.
#[tokio::test]
async fn return_then_checkout() {
let pool = WsPool::new();
let key = test_key("wss://api.openai.com/v1/responses", "sk-test");
let (client, mut server) = make_ws_pair().await;
let created = Instant::now();
pool.return_conn(key.clone(), client, created).await;
let (mut stream, checkout_created) =
pool.checkout(&key).await.expect("should get connection");
assert_eq!(checkout_created, created);
// Prove the checked-out stream is the live client end by round-tripping a frame.
stream
.send(Message::Text("hello".into()))
.await
.expect("send");
let msg = server.next().await.expect("recv").expect("frame");
assert_eq!(msg, Message::Text("hello".into()));
// The only pooled connection was taken, so a second checkout finds nothing.
assert!(pool.checkout(&key).await.is_none());
}
/// A connection idle for longer than `IDLE_TIMEOUT` is not handed out.
#[tokio::test]
async fn idle_timeout_eviction() {
let pool = WsPool::new();
let key = test_key("wss://api.openai.com/v1/responses", "sk-test");
let (client, _server) = make_ws_pair().await;
let created = Instant::now();
{
// Insert directly into the map so `returned_at` can be backdated.
// NOTE(review): `Instant - Duration` panics if the result is not representable
// (e.g. a monotonic clock that started very recently) — confirm this is
// acceptable on all CI targets.
let mut map = pool.connections.lock().await;
map.entry(key.clone()).or_default().push(IdleConnection {
stream: client,
returned_at: created - (IDLE_TIMEOUT + Duration::from_secs(1)),
created_at: created,
});
}
assert!(pool.checkout(&key).await.is_none());
}
/// A connection older than `MAX_LIFETIME` is not handed out even if it was
/// returned just now.
#[tokio::test]
async fn max_lifetime_eviction() {
let pool = WsPool::new();
let key = test_key("wss://api.openai.com/v1/responses", "sk-test");
let (client, _server) = make_ws_pair().await;
let old_created = Instant::now() - (MAX_LIFETIME + Duration::from_secs(1));
{
// Bypass `return_conn`, which would refuse a connection this old.
let mut map = pool.connections.lock().await;
map.entry(key.clone()).or_default().push(IdleConnection {
stream: client,
returned_at: Instant::now(),
created_at: old_created,
});
}
assert!(pool.checkout(&key).await.is_none());
}
/// Connections pooled under one API key are never handed out for another,
/// even when the URL is identical.
#[tokio::test]
async fn different_keys_isolated() {
let pool = WsPool::new();
let key_a = test_key("wss://api.openai.com/v1/responses", "sk-aaa");
let key_b = test_key("wss://api.openai.com/v1/responses", "sk-bbb");
let (client_a, _server_a) = make_ws_pair().await;
pool.return_conn(key_a.clone(), client_a, Instant::now())
.await;
assert!(pool.checkout(&key_b).await.is_none());
assert!(pool.checkout(&key_a).await.is_some());
}
/// `return_conn` drops a connection already past `MAX_LIFETIME` instead of
/// storing it, leaving no bucket for the key.
#[tokio::test]
async fn max_lifetime_rejected_on_return() {
let pool = WsPool::new();
let key = test_key("wss://api.openai.com/v1/responses", "sk-test");
let (client, _server) = make_ws_pair().await;
let old_created = Instant::now() - (MAX_LIFETIME + Duration::from_secs(1));
pool.return_conn(key.clone(), client, old_created).await;
let map = pool.connections.lock().await;
assert!(map.get(&key).is_none());
}
}