use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use super::session::{SshConfig, SshSession};
/// Tuning knobs for a [`ConnectionPool`].
///
/// Start from [`PoolConfig::new`] (identical to `Default`) and adjust the
/// values with the chainable, `const` setters.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on pooled sessions that share one `user@host:port` key.
    pub max_per_host: usize,
    /// Upper bound on pooled sessions across every key.
    pub max_total: usize,
    /// Idle threshold consulted by `ConnectionPool::cleanup`.
    pub idle_timeout: Duration,
    /// Whether idle pooled sessions may be handed out again.
    pub reuse_connections: bool,
    /// Whether a reused session must still report being connected before it is handed out.
    pub validate_on_checkout: bool,
    /// Sessions older than this are not reused and are evicted by cleanup; `None` disables the limit.
    pub max_connection_age: Option<Duration>,
}

impl Default for PoolConfig {
    /// Five sessions per host, twenty overall, a five-minute idle timeout,
    /// reuse and checkout validation enabled, and a one-hour maximum age.
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
        Self {
            max_per_host: 5,
            max_total: 20,
            idle_timeout: FIVE_MINUTES,
            reuse_connections: true,
            validate_on_checkout: true,
            max_connection_age: Some(ONE_HOUR),
        }
    }
}

impl PoolConfig {
    /// Returns the default configuration; shorthand for `PoolConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets [`PoolConfig::max_per_host`].
    #[must_use]
    pub const fn max_per_host(self, max: usize) -> Self {
        Self {
            max_per_host: max,
            ..self
        }
    }

    /// Sets [`PoolConfig::max_total`].
    #[must_use]
    pub const fn max_total(self, max: usize) -> Self {
        Self {
            max_total: max,
            ..self
        }
    }

    /// Sets [`PoolConfig::idle_timeout`].
    #[must_use]
    pub const fn idle_timeout(self, timeout: Duration) -> Self {
        Self {
            idle_timeout: timeout,
            ..self
        }
    }

    /// Sets [`PoolConfig::reuse_connections`].
    #[must_use]
    pub const fn reuse_connections(self, reuse: bool) -> Self {
        Self {
            reuse_connections: reuse,
            ..self
        }
    }

    /// Sets [`PoolConfig::validate_on_checkout`].
    #[must_use]
    pub const fn validate_on_checkout(self, validate: bool) -> Self {
        Self {
            validate_on_checkout: validate,
            ..self
        }
    }

    /// Sets [`PoolConfig::max_connection_age`]; pass `None` to disable the age limit.
    #[must_use]
    pub const fn max_connection_age(self, age: Option<Duration>) -> Self {
        Self {
            max_connection_age: age,
            ..self
        }
    }
}
#[derive(Debug)]
struct SharedSession {
session: RwLock<SshSession>,
in_use: AtomicBool,
created: Instant,
last_used: AtomicU64,
config: SshConfig,
}
impl SharedSession {
fn new(session: SshSession, config: SshConfig) -> Self {
Self {
session: RwLock::new(session),
in_use: AtomicBool::new(false),
created: Instant::now(),
last_used: AtomicU64::new(0),
config,
}
}
fn mark_used(&self) {
let now = Instant::now().elapsed().as_millis() as u64;
self.last_used.store(now, Ordering::Relaxed);
}
fn acquire(&self) -> bool {
self.in_use
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_ok()
}
fn release(&self) {
self.in_use.store(false, Ordering::Release);
self.mark_used();
}
fn is_in_use(&self) -> bool {
self.in_use.load(Ordering::Relaxed)
}
fn age(&self) -> Duration {
self.created.elapsed()
}
fn is_connected(&self) -> bool {
self.session
.read()
.map(|s| s.is_connected())
.unwrap_or(false)
}
}
type PoolEntry = Arc<SharedSession>;
#[derive(Debug)]
pub struct ConnectionPool {
config: PoolConfig,
connections: Arc<Mutex<HashMap<String, Vec<PoolEntry>>>>,
}
impl ConnectionPool {
#[must_use]
pub fn new(config: PoolConfig) -> Self {
Self {
config,
connections: Arc::new(Mutex::new(HashMap::new())),
}
}
#[must_use]
pub fn with_defaults() -> Self {
Self::new(PoolConfig::default())
}
fn make_key(ssh_config: &SshConfig) -> String {
format!(
"{}@{}:{}",
ssh_config.credentials.username, ssh_config.host, ssh_config.port
)
}
pub fn get(&self, ssh_config: &SshConfig) -> crate::error::Result<PooledConnection> {
let key = Self::make_key(ssh_config);
if self.config.reuse_connections
&& let Some(conn) = self.try_acquire_existing(&key)
{
if self.config.validate_on_checkout && !conn.is_connected() {
conn.release();
} else {
return Ok(PooledConnection::new(conn));
}
}
{
let connections = self
.connections
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
if let Some(entries) = connections.get(&key)
&& entries.len() >= self.config.max_per_host
{
return Err(crate::error::ExpectError::config(format!(
"Maximum connections per host ({}) exceeded for {}",
self.config.max_per_host, key
)));
}
}
let mut session = SshSession::new(ssh_config.clone());
session.connect()?;
let entry = Arc::new(SharedSession::new(session, ssh_config.clone()));
entry.acquire();
{
let mut connections = self
.connections
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
connections.entry(key).or_default().push(Arc::clone(&entry));
}
Ok(PooledConnection::new(entry))
}
#[cfg(feature = "ssh")]
pub async fn get_async(
&self,
ssh_config: &SshConfig,
) -> crate::error::Result<PooledConnection> {
let key = Self::make_key(ssh_config);
if self.config.reuse_connections
&& let Some(conn) = self.try_acquire_existing(&key)
{
if self.config.validate_on_checkout && !conn.is_connected() {
conn.release();
} else {
return Ok(PooledConnection::new(conn));
}
}
{
let connections = self
.connections
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let total: usize = connections.values().map(std::vec::Vec::len).sum();
if total >= self.config.max_total {
return Err(crate::error::ExpectError::config(format!(
"Maximum total connections ({}) exceeded",
self.config.max_total
)));
}
if let Some(entries) = connections.get(&key)
&& entries.len() >= self.config.max_per_host
{
return Err(crate::error::ExpectError::config(format!(
"Maximum connections per host ({}) exceeded for {}",
self.config.max_per_host, key
)));
}
}
let mut session = SshSession::new(ssh_config.clone());
session.connect_async().await?;
let entry = Arc::new(SharedSession::new(session, ssh_config.clone()));
entry.acquire();
{
let mut connections = self
.connections
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
connections.entry(key).or_default().push(Arc::clone(&entry));
}
Ok(PooledConnection::new(entry))
}
#[allow(clippy::significant_drop_tightening)]
fn try_acquire_existing(&self, key: &str) -> Option<PoolEntry> {
let connections = self
.connections
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
if let Some(entries) = connections.get(key) {
for entry in entries {
if let Some(max_age) = self.config.max_connection_age
&& entry.age() > max_age
{
continue;
}
if entry.acquire() {
return Some(Arc::clone(entry));
}
}
}
None
}
pub fn cleanup(&self) {
let mut connections = self
.connections
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let idle_timeout = self.config.idle_timeout;
let max_age = self.config.max_connection_age;
for entries in connections.values_mut() {
entries.retain(|entry| {
if entry.is_in_use() {
return true;
}
if !entry.is_connected() {
return false;
}
if let Some(max) = max_age
&& entry.age() > max
{
return false;
}
if entry.age() > idle_timeout {
return false;
}
true
});
}
connections.retain(|_, entries| !entries.is_empty());
}
#[must_use]
pub fn stats(&self) -> PoolStats {
let connections = self
.connections
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let mut total = 0;
let mut active = 0;
let mut idle = 0;
let mut connected = 0;
for entries in connections.values() {
for entry in entries {
total += 1;
if entry.is_in_use() {
active += 1;
} else {
idle += 1;
}
if entry.is_connected() {
connected += 1;
}
}
}
PoolStats {
total,
active,
idle,
connected,
hosts: connections.len(),
}
}
pub fn close_all(&self) {
let mut connections = self
.connections
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
for entries in connections.values() {
for entry in entries {
if let Ok(mut session) = entry.session.write() {
session.disconnect();
}
}
}
connections.clear();
}
#[must_use]
pub const fn config(&self) -> &PoolConfig {
&self.config
}
}
impl Default for ConnectionPool {
fn default() -> Self {
Self::with_defaults()
}
}
/// Point-in-time snapshot of a connection pool's occupancy.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolStats {
    /// Number of pooled sessions.
    pub total: usize,
    /// Sessions currently checked out.
    pub active: usize,
    /// Sessions sitting in the pool unused.
    pub idle: usize,
    /// Sessions whose SSH session reports being connected.
    pub connected: usize,
    /// Number of distinct pool keys.
    pub hosts: usize,
}

impl PoolStats {
    /// Returns `true` when the snapshot contains no sessions.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        matches!(self.total, 0)
    }

    /// Share of sessions that are checked out, as a percentage; `0.0` when empty.
    #[must_use]
    pub fn utilization(&self) -> f64 {
        if self.is_empty() {
            return 0.0;
        }
        let share = self.active as f64 / self.total as f64;
        share * 100.0
    }
}
/// A session checked out of a connection pool.
///
/// The underlying entry is marked available again when this handle is dropped
/// or when [`PooledConnection::release`] is called.
pub struct PooledConnection {
    entry: PoolEntry,
}

impl std::fmt::Debug for PooledConnection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let connected = self.is_connected();
        let age = self.age();
        f.debug_struct("PooledConnection")
            .field("connected", &connected)
            .field("age", &age)
            .finish()
    }
}

impl PooledConnection {
    /// Wraps an already-claimed entry, stamping its last-use time.
    fn new(entry: PoolEntry) -> Self {
        entry.mark_used();
        Self { entry }
    }

    /// Whether the underlying session reports being connected.
    #[must_use]
    pub fn is_connected(&self) -> bool {
        self.entry.is_connected()
    }

    /// Time since the pooled entry was created.
    #[must_use]
    pub fn age(&self) -> Duration {
        self.entry.age()
    }

    /// The SSH configuration the session was opened with.
    #[must_use]
    pub fn config(&self) -> &SshConfig {
        &self.entry.config
    }

    /// Runs `f` with shared access to the session; `None` if the session lock is poisoned.
    pub fn with_session<F, R>(&self, f: F) -> Option<R>
    where
        F: FnOnce(&SshSession) -> R,
    {
        let guard = self.session()?;
        Some(f(&guard))
    }

    /// Runs `f` with exclusive access to the session; `None` if the session lock is poisoned.
    pub fn with_session_mut<F, R>(&self, f: F) -> Option<R>
    where
        F: FnOnce(&mut SshSession) -> R,
    {
        let mut guard = self.session_mut()?;
        Some(f(&mut guard))
    }

    /// Shared lock guard over the session; `None` if the lock is poisoned.
    #[must_use]
    pub fn session(&self) -> Option<std::sync::RwLockReadGuard<'_, SshSession>> {
        self.entry.session.read().ok()
    }

    /// Exclusive lock guard over the session; `None` if the lock is poisoned.
    #[must_use]
    pub fn session_mut(&self) -> Option<std::sync::RwLockWriteGuard<'_, SshSession>> {
        self.entry.session.write().ok()
    }

    /// Hands the session back to the pool immediately instead of at scope exit.
    pub fn release(self) {
        std::mem::drop(self);
    }
}

impl Drop for PooledConnection {
    /// Marks the pooled entry as available again.
    fn drop(&mut self) {
        self.entry.release();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pool_stats_empty() {
        let snapshot = ConnectionPool::with_defaults().stats();
        assert_eq!(
            (
                snapshot.total,
                snapshot.active,
                snapshot.idle,
                snapshot.connected
            ),
            (0, 0, 0, 0)
        );
        assert!(snapshot.is_empty());
        assert!(snapshot.utilization().abs() < f64::EPSILON);
    }

    #[test]
    fn pool_config_builder() {
        let pool = ConnectionPool::new(
            PoolConfig::new()
                .max_per_host(10)
                .max_total(50)
                .idle_timeout(Duration::from_secs(600))
                .reuse_connections(true)
                .validate_on_checkout(false)
                .max_connection_age(Some(Duration::from_secs(7200))),
        );
        let cfg = pool.config();
        assert_eq!(cfg.max_per_host, 10);
        assert_eq!(cfg.max_total, 50);
        assert_eq!(cfg.idle_timeout, Duration::from_secs(600));
        assert!(cfg.reuse_connections);
        assert!(!cfg.validate_on_checkout);
        assert_eq!(cfg.max_connection_age, Some(Duration::from_secs(7200)));
    }

    #[test]
    fn pool_config_defaults() {
        let cfg = PoolConfig::default();
        assert_eq!((cfg.max_per_host, cfg.max_total), (5, 20));
        assert_eq!(cfg.idle_timeout, Duration::from_secs(300));
        assert!(cfg.reuse_connections && cfg.validate_on_checkout);
        assert_eq!(cfg.max_connection_age, Some(Duration::from_secs(3600)));
    }

    #[test]
    fn pool_cleanup_empty() {
        let pool = ConnectionPool::with_defaults();
        pool.cleanup();
        let snapshot = pool.stats();
        assert!(snapshot.is_empty());
    }

    #[test]
    fn pool_close_all_empty() {
        let pool = ConnectionPool::with_defaults();
        pool.close_all();
        let snapshot = pool.stats();
        assert!(snapshot.is_empty());
    }

    #[test]
    fn pool_stats_utilization() {
        let snapshot = PoolStats {
            total: 10,
            active: 5,
            idle: 5,
            connected: 8,
            hosts: 2,
        };
        let percent = snapshot.utilization();
        assert!((percent - 50.0).abs() < f64::EPSILON);
    }
}