use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use nodedb_types::error::{NodeDbError, NodeDbResult};
use nodedb_types::protocol::AuthMethod;
use tokio::sync::{Semaphore, SemaphorePermit};
use super::connection::NativeConnection;
/// Configuration for a [`Pool`].
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Server address handed to `NativeConnection::connect` / `connect_tls`.
    pub addr: String,
    /// Maximum number of connections that may be checked out at once (the size of the
    /// pool's semaphore). Also caps how many idle connections are kept for reuse.
    pub max_size: usize,
    /// Bound on waiting for a free pool slot and on establishing a new connection.
    pub connect_timeout: Duration,
    /// Intended maximum idle lifetime of a pooled connection.
    /// NOTE(review): this field is not read anywhere in this file, so idle connections
    /// are never evicted by age here. Confirm whether it is enforced elsewhere.
    pub idle_timeout: Duration,
    /// Credentials passed to `NativeConnection::authenticate` for every new connection.
    pub auth: AuthMethod,
    /// TLS settings. When `tls.enabled` is set, new connections use `connect_tls`.
    pub tls: super::connection::TlsConfig,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
addr: "127.0.0.1:6433".into(),
max_size: 10,
connect_timeout: Duration::from_secs(5),
idle_timeout: Duration::from_secs(300),
auth: AuthMethod::Trust {
username: "admin".into(),
},
tls: Default::default(),
}
}
}
/// State shared (through an `Arc`) between a [`Pool`] and every [`PooledConnection`]
/// it hands out, so that dropped guards can give their connection back.
struct PoolInner {
    /// Connections waiting to be reused: `acquire` pops from the front, dropped guards
    /// push to the back. This is a `std::sync::Mutex`, so its guard must not be held
    /// across an `.await`.
    idle: Mutex<VecDeque<NativeConnection>>,
    /// Upper bound on `idle.len()` when a connection is returned; copied from
    /// `PoolConfig::max_size`.
    max_size: usize,
}
/// A bounded pool of `NativeConnection`s.
///
/// At most `config.max_size` connections can be checked out at once. Each
/// [`PooledConnection`] holds one semaphore permit borrowed from this pool, so the guard
/// cannot outlive the pool.
pub struct Pool {
    /// Settings used when acquiring and opening connections.
    config: PoolConfig,
    /// Idle queue shared with outstanding guards so they can return their connection on drop.
    inner: Arc<PoolInner>,
    /// Limits concurrent checkouts to `config.max_size`.
    semaphore: Semaphore,
}
impl Pool {
pub fn new(config: PoolConfig) -> Self {
let max_size = config.max_size;
let semaphore = Semaphore::new(max_size);
Self {
config,
inner: Arc::new(PoolInner {
idle: Mutex::new(VecDeque::new()),
max_size,
}),
semaphore,
}
}
pub async fn acquire(&self) -> NodeDbResult<PooledConnection<'_>> {
let permit = tokio::time::timeout(self.config.connect_timeout, self.semaphore.acquire())
.await
.map_err(|_| NodeDbError::sync_connection_failed("pool acquire timeout"))?
.map_err(|_| NodeDbError::sync_connection_failed("pool closed"))?;
let idle_conn = {
let mut idle = self.inner.idle.lock().unwrap_or_else(|e| e.into_inner());
idle.pop_front()
};
if let Some(mut conn) = idle_conn {
if conn.ping().await.is_ok() {
return Ok(PooledConnection {
conn: Some(conn),
inner: Arc::clone(&self.inner),
_permit: permit,
});
}
}
let addr = self.config.addr.clone();
let tls_cfg = self.config.tls.clone();
let timeout = self.config.connect_timeout;
let mut conn = tokio::time::timeout(timeout, async move {
if tls_cfg.enabled {
NativeConnection::connect_tls(&addr, &tls_cfg).await
} else {
NativeConnection::connect(&addr).await
}
})
.await
.map_err(|_| NodeDbError::sync_connection_failed("connect timeout"))??;
conn.authenticate(self.config.auth.clone()).await?;
Ok(PooledConnection {
conn: Some(conn),
inner: Arc::clone(&self.inner),
_permit: permit,
})
}
}
/// A connection checked out from a [`Pool`].
///
/// Dereferences to `NativeConnection`. When dropped, the connection is pushed back onto
/// the pool's idle queue if it has room, and the pool slot is then released.
pub struct PooledConnection<'a> {
    /// `Some` for the guard's whole life within this file. `Drop` takes it out to hand
    /// it back to the pool.
    conn: Option<NativeConnection>,
    /// Shared pool state holding the idle queue.
    inner: Arc<PoolInner>,
    /// Reserves one of the pool's `max_size` slots. Fields drop after `Drop::drop` runs,
    /// so the permit is released only after the connection has been requeued.
    _permit: SemaphorePermit<'a>,
}
impl std::ops::Deref for PooledConnection<'_> {
    type Target = NativeConnection;

    /// Borrows the checked-out connection.
    ///
    /// # Panics
    ///
    /// Panics with "connection taken" if the connection slot is empty. In this file the
    /// slot is only emptied by `Drop`.
    fn deref(&self) -> &Self::Target {
        let slot = &self.conn;
        slot.as_ref().expect("connection taken")
    }
}
impl std::ops::DerefMut for PooledConnection<'_> {
fn deref_mut(&mut self) -> &mut NativeConnection {
self.conn.as_mut().expect("connection taken")
}
}
impl Drop for PooledConnection<'_> {
    /// Returns the connection to the pool's idle queue, unless the queue already holds
    /// `max_size` connections, in which case the connection is simply dropped.
    ///
    /// A poisoned lock is recovered rather than propagated, so returning a connection
    /// never panics here. The semaphore permit field is dropped after this body
    /// finishes, so a waiter woken by the permit can already see the requeued connection.
    fn drop(&mut self) {
        let returned = match self.conn.take() {
            Some(conn) => conn,
            None => return,
        };

        let mut queue = match self.inner.idle.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        if queue.len() >= self.inner.max_size {
            // Queue is full: `queue` (the guard) then `returned` drop on return.
            return;
        }
        queue.push_back(returned);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every plain-value default of `PoolConfig`.
    #[test]
    fn pool_config_defaults() {
        let cfg = PoolConfig::default();
        assert_eq!(cfg.addr, "127.0.0.1:6433");
        assert_eq!(cfg.max_size, 10);
        assert_eq!(cfg.connect_timeout, Duration::from_secs(5));
        assert_eq!(cfg.idle_timeout, Duration::from_secs(300));
    }

    /// The semaphore is sized from `max_size`, so all slots are free initially.
    #[test]
    fn pool_creates_semaphore() {
        let pool = Pool::new(PoolConfig {
            max_size: 5,
            ..Default::default()
        });
        assert_eq!(pool.semaphore.available_permits(), 5);
    }

    /// A new pool opens nothing eagerly, and its shared state mirrors `max_size`.
    #[test]
    fn new_pool_starts_with_no_idle_connections() {
        let pool = Pool::new(PoolConfig {
            max_size: 3,
            ..Default::default()
        });
        assert_eq!(pool.inner.max_size, 3);
        let idle = pool.inner.idle.lock().unwrap();
        assert!(idle.is_empty());
    }
}