use std::collections::HashMap;
use std::io::{BufReader, Read, Write};
use std::net::TcpStream;
use std::sync::{Mutex, OnceLock};
use crate::tls::TlsStream;
/// Upper bound on pooled connections held for any single `Key`.
const POOL_PER_KEY_CAP: usize = 4;
/// Upper bound on pooled connections held across all keys combined.
const POOL_GLOBAL_CAP: usize = 32;
/// Lookup key for pooled connections.
///
/// Equality and hashing are derived over all three fields, so two keys match
/// only when scheme, host and port are all identical.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub(crate) struct Key {
    /// Scheme component (e.g. `"http"`).
    pub scheme: String,
    /// Host component.
    pub host: String,
    /// Port component.
    pub port: u16,
}
/// A pool of buffered connections awaiting reuse, grouped by [`Key`].
///
/// Each key holds at most `POOL_PER_KEY_CAP` connections and the whole pool
/// at most `POOL_GLOBAL_CAP`; connections offered beyond either cap are
/// dropped by [`Pool::release`].
pub(crate) struct Pool<S: Read + Write> {
    /// Pooled connections per key. `checkout` removes a bucket as soon as
    /// its last connection is taken.
    entries: HashMap<Key, Vec<BufReader<S>>>,
}

impl<S: Read + Write> Pool<S> {
    /// Creates an empty pool.
    fn new() -> Self {
        Self {
            entries: HashMap::new(),
        }
    }

    /// Takes the most recently released connection for `key`, if any.
    ///
    /// Returns `None` when nothing is pooled for `key`. Once the last
    /// connection for a key is taken, the key's bucket is removed so the map
    /// does not accumulate empty entries.
    pub(crate) fn checkout(&mut self, key: &Key) -> Option<BufReader<S>> {
        let bucket = self.entries.get_mut(key)?;
        let conn = bucket.pop();
        if bucket.is_empty() {
            self.entries.remove(key);
        }
        conn
    }

    /// Offers `conn` back to the pool under `key`.
    ///
    /// If the pool already holds `POOL_GLOBAL_CAP` connections in total, or
    /// this key already holds `POOL_PER_KEY_CAP`, `conn` is not stored and is
    /// dropped when this function returns.
    pub(crate) fn release(&mut self, key: Key, conn: BufReader<S>) {
        // Checked before touching `entries` so a rejected key never gets a bucket.
        if self.total_len() >= POOL_GLOBAL_CAP {
            return;
        }
        let bucket = self.entries.entry(key).or_default();
        if bucket.len() < POOL_PER_KEY_CAP {
            bucket.push(conn);
        }
    }

    /// Number of pooled connections across all keys.
    ///
    /// Used by `release` for the global-cap check and by the tests, so both
    /// rely on the same count.
    fn total_len(&self) -> usize {
        self.entries.values().map(Vec::len).sum()
    }
}
/// Shared pool of plain `TcpStream` connections, created lazily by `plain()`.
static POOL_PLAIN: OnceLock<Mutex<Pool<TcpStream>>> = OnceLock::new();
/// Shared pool of `TlsStream<TcpStream>` connections, created lazily by `tls()`.
static POOL_TLS: OnceLock<Mutex<Pool<TlsStream<TcpStream>>>> = OnceLock::new();

/// Returns the shared plain-TCP pool, initialising it empty on first call.
pub(crate) fn plain() -> &'static Mutex<Pool<TcpStream>> {
    POOL_PLAIN.get_or_init(|| {
        let empty = Pool::new();
        Mutex::new(empty)
    })
}

/// Returns the shared TLS pool, initialising it empty on first call.
pub(crate) fn tls() -> &'static Mutex<Pool<TlsStream<TcpStream>>> {
    POOL_TLS.get_or_init(|| {
        let empty = Pool::new();
        Mutex::new(empty)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Result as IoResult, Write};

    /// In-memory stand-in for a socket. The tag lets tests tell individual
    /// connections apart so ordering and eviction can be asserted.
    struct Stub(u8);

    impl Read for Stub {
        fn read(&mut self, _buf: &mut [u8]) -> IoResult<usize> {
            Ok(0)
        }
    }

    impl Write for Stub {
        fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
            Ok(buf.len())
        }
        fn flush(&mut self) -> IoResult<()> {
            Ok(())
        }
    }

    fn k(host: &str, port: u16) -> Key {
        Key {
            scheme: "http".into(),
            host: host.into(),
            port,
        }
    }

    /// Builds a buffered stub connection carrying `tag`.
    fn conn(tag: u8) -> BufReader<Stub> {
        BufReader::new(Stub(tag))
    }

    /// Reads back the tag of a pooled stub connection.
    fn tag(c: &BufReader<Stub>) -> u8 {
        c.get_ref().0
    }

    #[test]
    fn lifo_checkout_after_two_releases() {
        let mut p: Pool<Stub> = Pool::new();
        p.release(k("h", 80), conn(1));
        p.release(k("h", 80), conn(2));
        // The most recently released connection must come back first.
        let first = p.checkout(&k("h", 80)).expect("two connections pooled");
        assert_eq!(tag(&first), 2);
        let second = p.checkout(&k("h", 80)).expect("one connection pooled");
        assert_eq!(tag(&second), 1);
        assert!(p.checkout(&k("h", 80)).is_none());
        assert_eq!(p.total_len(), 0);
    }

    #[test]
    fn checkout_of_unknown_key_is_none() {
        let mut p: Pool<Stub> = Pool::new();
        p.release(k("h", 80), conn(1));
        assert!(p.checkout(&k("h", 81)).is_none());
        assert!(p.checkout(&k("other", 80)).is_none());
        assert_eq!(p.total_len(), 1);
    }

    #[test]
    fn per_key_cap_enforced() {
        let mut p: Pool<Stub> = Pool::new();
        for i in 0..(POOL_PER_KEY_CAP + 2) {
            // Small loop bound: the cast cannot truncate.
            p.release(k("h", 80), conn(i as u8));
        }
        assert_eq!(p.total_len(), POOL_PER_KEY_CAP);
        // Overflowing connections are dropped, not swapped in: the newest
        // retained one is the last that fit under the cap.
        let newest = p.checkout(&k("h", 80)).expect("bucket is full");
        assert_eq!(usize::from(tag(&newest)), POOL_PER_KEY_CAP - 1);
    }

    #[test]
    fn global_cap_enforced_across_keys() {
        let mut p: Pool<Stub> = Pool::new();
        let attempts = POOL_GLOBAL_CAP + 5;
        for i in 0..attempts {
            // Small loop bound: the cast cannot truncate.
            p.release(k("h", i as u16), conn(0));
        }
        assert_eq!(p.total_len(), POOL_GLOBAL_CAP);
        // Keys offered after the pool filled up were rejected outright.
        assert!(p.checkout(&k("h", (attempts - 1) as u16)).is_none());
    }
}