use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Semaphore, Mutex};
use tokio::time::timeout;
use crate::client::{Client, ClientBuilder};
use crate::error::{Error, Result};
/// Tunable limits and timeouts for a connection pool.
#[derive(Clone, Debug)]
pub struct PoolConfig {
    /// Number of connections opened eagerly when the pool is constructed.
    pub min_connections: usize,
    /// Upper bound on concurrently checked-out clients and on retained idle connections.
    pub max_connections: usize,
    /// Longest time an acquire call waits for a free slot before failing.
    pub acquire_timeout: Duration,
    /// Longest time a single connection attempt may take before failing.
    pub connect_timeout: Duration,
    /// Idle age after which pruning discards a connection; `None` disables pruning.
    pub max_idle_time: Option<Duration>,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
min_connections: 1,
max_connections: 10,
acquire_timeout: Duration::from_secs(30),
connect_timeout: Duration::from_secs(10),
max_idle_time: Some(Duration::from_secs(300)), }
}
}
impl PoolConfig {
    /// Starts from the default configuration; chain the setters below to adjust it.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets how many connections are opened eagerly at pool construction.
    pub fn min_connections(self, min: usize) -> Self {
        Self {
            min_connections: min,
            ..self
        }
    }

    /// Sets the upper bound on pooled connections.
    pub fn max_connections(self, max: usize) -> Self {
        Self {
            max_connections: max,
            ..self
        }
    }

    /// Sets how long an acquire call may wait for a free slot.
    pub fn acquire_timeout(self, timeout: Duration) -> Self {
        Self {
            acquire_timeout: timeout,
            ..self
        }
    }

    /// Sets how long a single connection attempt may take.
    pub fn connect_timeout(self, timeout: Duration) -> Self {
        Self {
            connect_timeout: timeout,
            ..self
        }
    }

    /// Enables idle pruning with the given maximum idle age.
    pub fn max_idle_time(self, idle_time: Duration) -> Self {
        Self {
            max_idle_time: Some(idle_time),
            ..self
        }
    }
}
/// An idle connection held by the pool together with its bookkeeping timestamps.
struct PooledConnection {
    /// The underlying client handed out to callers.
    client: Client,
    /// Instant at which this entry was created (currently not read anywhere).
    _created_at: tokio::time::Instant,
    /// Instant of the most recent use; compared against the idle limit when pruning.
    last_used: tokio::time::Instant,
}
impl PooledConnection {
fn new(client: Client) -> Self {
let now = tokio::time::Instant::now();
Self {
client,
_created_at: now,
last_used: now,
}
}
fn touch(&mut self) {
self.last_used = tokio::time::Instant::now();
}
fn is_idle_too_long(&self, max_idle: Duration) -> bool {
self.last_used.elapsed() > max_idle
}
}
/// Mutable pool state that lives behind the pool's mutex.
struct PoolInner {
    /// Idle connections available for reuse; used as a stack (push/pop at the end).
    connections: Vec<PooledConnection>,
    /// Limits and timeouts the pool was created with.
    config: PoolConfig,
    /// Template builder cloned for every new connection attempt.
    builder: ClientBuilder,
}
impl PoolInner {
fn new(config: PoolConfig, builder: ClientBuilder) -> Self {
Self {
connections: Vec::with_capacity(config.max_connections),
config,
builder,
}
}
async fn create_connection(&self) -> Result<Client> {
timeout(
self.config.connect_timeout,
self.builder.clone().build(),
)
.await
.map_err(|_| Error::Connection("Connection timeout".to_string()))?
}
fn prune_idle_connections(&mut self) {
if let Some(max_idle) = self.config.max_idle_time {
self.connections.retain(|conn| !conn.is_idle_too_long(max_idle));
}
}
fn total_connections(&self) -> usize {
self.connections.len()
}
}
/// Handle to a shared connection pool.
pub struct Pool {
    /// Shared idle list, builder and configuration, guarded by an async mutex.
    inner: Arc<Mutex<PoolInner>>,
    /// Permit source created with `max_connections` permits; one permit per checked-out client.
    semaphore: Arc<Semaphore>,
    /// Copy of the configuration readable without taking the mutex.
    config: PoolConfig,
}
impl Pool {
pub async fn new(config: PoolConfig, builder: ClientBuilder) -> Result<Self> {
if config.min_connections > config.max_connections {
return Err(Error::Configuration(
"min_connections cannot exceed max_connections".to_string(),
));
}
let semaphore = Arc::new(Semaphore::new(config.max_connections));
let inner = Arc::new(Mutex::new(PoolInner::new(config.clone(), builder)));
let pool = Self {
inner,
semaphore,
config,
};
pool.initialize_min_connections().await?;
Ok(pool)
}
async fn initialize_min_connections(&self) -> Result<()> {
let mut inner = self.inner.lock().await;
for _ in 0..self.config.min_connections {
let client = inner.create_connection().await?;
inner.connections.push(PooledConnection::new(client));
}
Ok(())
}
pub async fn acquire(&self) -> Result<PooledClient> {
let permit = timeout(
self.config.acquire_timeout,
self.semaphore.clone().acquire_owned(),
)
.await
.map_err(|_| Error::Connection("Pool acquire timeout".to_string()))?
.map_err(|_| Error::Connection("Semaphore closed".to_string()))?;
let mut inner = self.inner.lock().await;
if let Some(mut conn) = inner.connections.pop() {
conn.touch();
return Ok(PooledClient {
client: Some(conn.client),
pool: self.inner.clone(),
_permit: permit,
});
}
let client = inner.create_connection().await?;
Ok(PooledClient {
client: Some(client),
pool: self.inner.clone(),
_permit: permit,
})
}
pub async fn size(&self) -> usize {
self.inner.lock().await.total_connections()
}
pub async fn prune(&self) {
let mut inner = self.inner.lock().await;
inner.prune_idle_connections();
}
pub async fn close(&self) {
let mut inner = self.inner.lock().await;
inner.connections.clear();
}
}
impl Clone for Pool {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
semaphore: Arc::clone(&self.semaphore),
config: self.config.clone(),
}
}
}
/// Guard around a checked-out client; dropping it offers the client back to the pool.
pub struct PooledClient {
    /// The checked-out client; `None` only after `Drop` has taken it.
    client: Option<Client>,
    /// Shared pool state the client is handed back to on drop.
    pool: Arc<Mutex<PoolInner>>,
    /// Semaphore permit reserving this guard's slot; released when the guard is dropped.
    _permit: tokio::sync::OwnedSemaphorePermit,
}
impl PooledClient {
    /// Borrows the checked-out client.
    ///
    /// # Panics
    ///
    /// Panics if the client slot is empty, which only happens once the guard is
    /// being dropped.
    pub fn client(&self) -> &Client {
        let slot: Option<&Client> = self.client.as_ref();
        slot.expect("Client should always be present")
    }

    /// Mutably borrows the checked-out client.
    ///
    /// # Panics
    ///
    /// Panics if the client slot is empty, which only happens once the guard is
    /// being dropped.
    pub fn client_mut(&mut self) -> &mut Client {
        let slot: Option<&mut Client> = self.client.as_mut();
        slot.expect("Client should always be present")
    }
}
impl std::ops::Deref for PooledClient {
type Target = Client;
fn deref(&self) -> &Self::Target {
self.client()
}
}
impl std::ops::DerefMut for PooledClient {
fn deref_mut(&mut self) -> &mut Self::Target {
self.client_mut()
}
}
impl Drop for PooledClient {
    /// Hands the client back to the pool's idle list (when there is room) and lets
    /// the semaphore permit go.
    ///
    /// The client is dropped instead of pooled when the idle list already holds
    /// `max_connections` entries, or when the pool lock is contended and no Tokio
    /// runtime is available to finish the hand-back asynchronously.
    fn drop(&mut self) {
        if let Some(client) = self.client.take() {
            // Fast path: re-insert synchronously. Struct fields (including
            // `_permit`) are dropped only after this function returns, so the
            // connection is already in the idle list by the time a waiter obtains
            // the permit and it will not open a redundant connection.
            let client = match self.pool.try_lock() {
                Ok(mut inner) => {
                    if inner.total_connections() < inner.config.max_connections {
                        inner.connections.push(PooledConnection::new(client));
                    }
                    return;
                }
                Err(_) => client,
            };

            // Slow path: the lock is busy, so finish on the runtime. Use
            // `try_current` because `tokio::spawn` panics when called outside a
            // runtime (e.g. a guard dropped on a plain thread or during shutdown);
            // in that case the client is simply dropped.
            if let Ok(handle) = tokio::runtime::Handle::try_current() {
                let pool = Arc::clone(&self.pool);
                handle.spawn(async move {
                    let mut inner = pool.lock().await;
                    if inner.total_connections() < inner.config.max_connections {
                        inner.connections.push(PooledConnection::new(client));
                    }
                });
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the default configuration has its documented value.
    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::default();
        assert_eq!(config.min_connections, 1);
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.max_idle_time, Some(Duration::from_secs(300)));
    }

    /// Each builder setter overrides exactly the field it is named after.
    #[test]
    fn test_pool_config_builder() {
        let config = PoolConfig::new()
            .min_connections(5)
            .max_connections(20)
            .acquire_timeout(Duration::from_secs(10))
            .connect_timeout(Duration::from_secs(5))
            .max_idle_time(Duration::from_secs(60));
        assert_eq!(config.min_connections, 5);
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.acquire_timeout, Duration::from_secs(10));
        assert_eq!(config.connect_timeout, Duration::from_secs(5));
        assert_eq!(config.max_idle_time, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_pooled_connection_idle_check() {
        // NOTE(review): placeholder — this does not yet exercise
        // `PooledConnection::is_idle_too_long`, which needs a constructed `Client`.
        // The underscore binding avoids the unused-variable warning until then.
        let _builder = Client::builder();
    }
}