use std::collections::HashMap;
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::connection::Connection;
use crate::error::IiopError;
/// Tuning knobs for a [`Connector`] and the connections it opens.
#[derive(Debug, Clone)]
pub struct ConnectorConfig {
    /// Timeout for establishing a new TCP connection. `None` uses a plain
    /// blocking `TcpStream::connect`.
    pub connect_timeout: Option<Duration>,
    /// Read timeout applied to each newly opened connection.
    pub read_timeout: Option<Duration>,
    /// Write timeout applied to each newly opened connection.
    pub write_timeout: Option<Duration>,
    /// Value passed to `set_nodelay` on each newly opened connection.
    pub nodelay: bool,
    /// Upper bound on connections per resolved endpoint. When no idle connection
    /// is available and this many are counted as in use, `Connector::connect`
    /// fails with `IiopError::PoolExhausted`.
    pub max_connections_per_endpoint: usize,
}

impl Default for ConnectorConfig {
    /// The defaults are a 10 s connect timeout, 30 s read/write timeouts,
    /// `TCP_NODELAY` enabled, and at most 16 connections per endpoint.
    fn default() -> Self {
        let io_timeout = Some(Duration::from_secs(30));
        ConnectorConfig {
            connect_timeout: Some(Duration::from_secs(10)),
            read_timeout: io_timeout,
            write_timeout: io_timeout,
            nodelay: true,
            max_connections_per_endpoint: 16,
        }
    }
}
/// A connection checked out from a [`Connector`]'s pool.
///
/// Dropping the guard releases its checkout slot and hands the connection back
/// to the pool, unless [`PooledConnection::invalidate`] was called first.
pub struct PooledConnection {
    /// The live connection; only taken out in `Drop`.
    inner: Option<Connection>,
    /// Shared pool state the connection is returned to.
    pool: Arc<Mutex<PoolInner>>,
    /// Resolved address this connection belongs to (the pool key).
    endpoint: SocketAddr,
    /// Cleared by `invalidate` so the connection is discarded instead of reused.
    return_to_pool: bool,
}

impl PooledConnection {
    /// Returns the checked-out connection, or `None` if it has already been taken.
    #[must_use]
    pub fn connection(&mut self) -> Option<&mut Connection> {
        self.inner.as_mut()
    }

    /// Marks the connection as unusable (e.g. after an I/O error) so it is
    /// dropped instead of being returned to the idle pool.
    pub fn invalidate(&mut self) {
        self.return_to_pool = false;
    }
}

impl Drop for PooledConnection {
    /// Releases this checkout. It frees the endpoint's in-use slot and, unless
    /// the connection was invalidated, parks it in the idle list for reuse.
    /// Invalidated connections are dropped rather than pooled.
    ///
    /// Best-effort: if the pool mutex is poisoned, the connection is simply dropped.
    fn drop(&mut self) {
        let Ok(mut pool) = self.pool.lock() else {
            return;
        };
        // Every checkout incremented `in_use_count`. Undo that here whether or not
        // the connection is reused. Otherwise the count only grows, and the endpoint
        // eventually reports `PoolExhausted` with nothing checked out.
        if let Some(in_use) = pool.in_use_count.get_mut(&self.endpoint) {
            *in_use = in_use.saturating_sub(1);
        }
        if self.return_to_pool {
            if let Some(conn) = self.inner.take() {
                pool.idle.entry(self.endpoint).or_default().push(conn);
            }
        }
    }
}
/// Mutable pool state shared between a [`Connector`] and its checked-out guards.
#[derive(Default)]
struct PoolInner {
    /// Connections returned to the pool, keyed by resolved endpoint.
    idle: HashMap<SocketAddr, Vec<Connection>>,
    /// Per-endpoint count of checkout reservations. `Connector::connect` increments
    /// it for every connection it hands out and checks it against
    /// `max_connections_per_endpoint` before opening a new one.
    in_use_count: HashMap<SocketAddr, usize>,
}

/// Opens IIOP connections and pools them per resolved endpoint.
pub struct Connector {
    config: ConnectorConfig,
    pool: Arc<Mutex<PoolInner>>,
}

impl Connector {
    /// Creates a connector with an empty pool governed by `config`.
    #[must_use]
    pub fn new(config: ConnectorConfig) -> Self {
        Self {
            config,
            pool: Arc::new(Mutex::new(PoolInner::default())),
        }
    }

    /// Checks out a connection to `host:port`. An idle pooled connection is
    /// reused when one exists; otherwise a new TCP stream is opened.
    ///
    /// Only the first address produced by name resolution is used.
    ///
    /// # Errors
    ///
    /// - Any resolution error, or `IiopError::Other` if resolution yields no address.
    /// - `IiopError::Other` if the pool mutex is poisoned.
    /// - `IiopError::PoolExhausted` if no idle connection is available and
    ///   `max_connections_per_endpoint` connections are already counted as in use.
    /// - Any error from connecting or configuring a new stream. In that case the
    ///   slot reserved for it is released again.
    pub fn connect(&self, host: &str, port: u16) -> Result<PooledConnection, IiopError> {
        let endpoint = (host, port).to_socket_addrs()?.next().ok_or_else(|| {
            IiopError::Other(format!("no address resolved for {host}:{port}"))
        })?;
        {
            let mut pool = self
                .pool
                .lock()
                .map_err(|_| IiopError::Other("connector pool mutex poisoned".into()))?;
            if let Some(conn) = pool.idle.get_mut(&endpoint).and_then(Vec::pop) {
                *pool.in_use_count.entry(endpoint).or_insert(0) += 1;
                return Ok(self.checked_out(conn, endpoint));
            }
            let in_use = pool.in_use_count.get(&endpoint).copied().unwrap_or(0);
            if in_use >= self.config.max_connections_per_endpoint {
                return Err(IiopError::PoolExhausted);
            }
            // Reserve the slot before releasing the lock so concurrent callers
            // cannot overshoot the limit while this one is connecting.
            *pool.in_use_count.entry(endpoint).or_insert(0) += 1;
        }
        match self.open_connection(endpoint) {
            Ok(conn) => Ok(self.checked_out(conn, endpoint)),
            Err(err) => {
                // No guard will ever exist for this reservation, so give the slot
                // back here; otherwise each failed connect permanently shrinks the
                // endpoint's capacity.
                self.release_reservation(endpoint);
                Err(err)
            }
        }
    }

    /// Opens a fresh TCP connection to `endpoint` and applies the configured
    /// read/write timeouts and `nodelay` setting.
    fn open_connection(&self, endpoint: SocketAddr) -> Result<Connection, IiopError> {
        let stream = match self.config.connect_timeout {
            Some(timeout) => TcpStream::connect_timeout(&endpoint, timeout)?,
            None => TcpStream::connect(endpoint)?,
        };
        let conn = Connection::from_stream(stream)?;
        conn.set_read_timeout(self.config.read_timeout)?;
        conn.set_write_timeout(self.config.write_timeout)?;
        conn.set_nodelay(self.config.nodelay)?;
        Ok(conn)
    }

    /// Undoes one `in_use_count` reservation for `endpoint`. Best-effort: does
    /// nothing if the pool mutex is poisoned.
    fn release_reservation(&self, endpoint: SocketAddr) {
        if let Ok(mut pool) = self.pool.lock() {
            if let Some(in_use) = pool.in_use_count.get_mut(&endpoint) {
                *in_use = in_use.saturating_sub(1);
            }
        }
    }

    /// Wraps `conn` in a guard tied to this connector's pool.
    fn checked_out(&self, conn: Connection, endpoint: SocketAddr) -> PooledConnection {
        PooledConnection {
            inner: Some(conn),
            pool: Arc::clone(&self.pool),
            endpoint,
            return_to_pool: true,
        }
    }

    /// Returns how many idle connections are pooled for the first address
    /// `host:port` resolves to. Returns 0 if resolution fails, yields nothing,
    /// or the pool mutex is poisoned.
    #[must_use]
    pub fn idle_count(&self, host: &str, port: u16) -> usize {
        let endpoint = (host, port)
            .to_socket_addrs()
            .ok()
            .and_then(|mut addrs| addrs.next());
        let Some(endpoint) = endpoint else {
            return 0;
        };
        self.pool
            .lock()
            .map_or(0, |p| p.idle.get(&endpoint).map_or(0, Vec::len))
    }
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::net::{TcpListener, TcpStream};
    use std::thread;

    /// Copies everything read from `stream` back to it until EOF or an I/O error.
    fn echo(mut stream: TcpStream) {
        use std::io::{Read, Write};
        let mut buf = [0u8; 4096];
        while let Ok(n) = stream.read(&mut buf) {
            if n == 0 || stream.write_all(&buf[..n]).is_err() {
                break;
            }
        }
    }

    /// Accepts connections until `accept` fails, echoing each on its own thread.
    fn echo_server(listener: TcpListener) {
        while let Ok((stream, _)) = listener.accept() {
            thread::spawn(move || echo(stream));
        }
    }

    /// Starts an echo server on an ephemeral localhost port and returns its
    /// host string and port.
    fn spawn_echo_server() -> (String, u16) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        thread::spawn(move || echo_server(listener));
        (addr.ip().to_string(), addr.port())
    }

    #[test]
    fn connect_reuses_pooled_connection() {
        let (host, port) = spawn_echo_server();
        let connector = Connector::new(ConnectorConfig::default());
        drop(connector.connect(&host, port).unwrap());
        assert_eq!(connector.idle_count(&host, port), 1);
        let _second = connector.connect(&host, port).unwrap();
        assert_eq!(connector.idle_count(&host, port), 0);
    }

    #[test]
    fn invalidated_connection_is_not_returned_to_pool() {
        let (host, port) = spawn_echo_server();
        let connector = Connector::new(ConnectorConfig::default());
        let mut conn = connector.connect(&host, port).unwrap();
        conn.invalidate();
        drop(conn);
        assert_eq!(connector.idle_count(&host, port), 0);
    }

    #[test]
    fn max_connections_per_endpoint_is_enforced() {
        let (host, port) = spawn_echo_server();
        let connector = Connector::new(ConnectorConfig {
            max_connections_per_endpoint: 1,
            ..ConnectorConfig::default()
        });
        let _first = connector
            .connect(&host, port)
            .unwrap_or_else(|e| panic!("first connect: {e}"));
        match connector.connect(&host, port) {
            Err(IiopError::PoolExhausted) => {}
            Err(other) => panic!("expected PoolExhausted, got {other}"),
            Ok(_) => panic!("expected PoolExhausted"),
        }
    }
}