use std::future::Future;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use crossbeam::queue::ArrayQueue;
use tokio::sync::Notify;
use tokio::time::timeout;
/// A pinned, boxed, `Send` future yielding `T` that may borrow data for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Connection factory: each call returns a future that resolves to a new `T`
/// or to an error message. The pool wraps that message in
/// `PoolError::CreateFailed`.
pub type CreateFn<T> = Box<dyn Fn() -> BoxFuture<'static, Result<T, String>> + Send + Sync>;
/// Health check run on a connection popped from the idle queue; returning
/// `false` makes the pool discard that connection.
pub type ValidateFn<T> = Box<dyn Fn(&T) -> bool + Send + Sync>;
/// Errors returned by pool operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PoolError {
    /// No connection became available within the configured wait time.
    Timeout,
    /// The pool has been closed.
    Closed,
    /// The connection factory failed; the payload is its message.
    CreateFailed(String),
}

impl std::fmt::Display for PoolError {
    /// Writes the human-readable, `pool: `-prefixed message for the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the creation failure carries dynamic text; the other two
        // variants map to fixed strings.
        let fixed = match self {
            PoolError::CreateFailed(reason) => {
                return write!(f, "pool: create failed: {reason}");
            }
            PoolError::Timeout => "pool: timeout waiting for connection",
            PoolError::Closed => "pool: closed",
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for PoolError {}
/// Tunable limits for a pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on the number of connections the pool accounts for.
    pub max_size: u32,
    /// Time limit applied to a single run of the connection factory.
    pub create_timeout: Duration,
    /// Time an `acquire` call may wait for a connection when the pool is full.
    pub wait_timeout: Duration,
}

impl Default for PoolConfig {
    /// Returns the stock configuration: 20 connections, 5 s to create,
    /// 10 s to wait.
    fn default() -> Self {
        let max_size = 20;
        let create_timeout = Duration::from_secs(5);
        let wait_timeout = Duration::from_secs(10);
        Self {
            max_size,
            create_timeout,
            wait_timeout,
        }
    }
}
/// Point-in-time view of a pool's counters.
///
/// The fields are read one after another rather than as one atomic snapshot,
/// so they may be mutually inconsistent while the pool is in use.
#[derive(Debug, Clone)]
pub struct PoolStatus {
/// Current value of the pool's size counter.
pub size: u32,
/// Number of connections sitting in the idle queue.
pub idle: u32,
/// Configured maximum size of the pool.
pub max_size: u32,
/// Whether the pool has been closed.
pub closed: bool,
}
/// Cloneable handle to a shared connection pool.
///
/// The handle holds nothing but an `Arc` to the shared `PoolInner`.
pub struct LockFreePool<T: Send + 'static> {
inner: Arc<PoolInner<T>>,
}
// SAFETY: the only field is `Arc<PoolInner<T>>`, and `PoolInner<T>` is itself
// declared `Send + Sync` for `T: Send` further down in this file.
// NOTE(review): these impls look redundant with what the auto traits would
// derive from that field — confirm before removing.
unsafe impl<T: Send + 'static> Send for LockFreePool<T> {}
unsafe impl<T: Send + 'static> Sync for LockFreePool<T> {}
impl<T: Send + 'static> Clone for LockFreePool<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
/// Guard around a connection checked out of a `LockFreePool`.
///
/// Dropping the guard hands the connection back to the pool; `take` removes
/// it from the pool for good.
pub struct PooledConnection<T: Send + 'static> {
// `Some` for the guard's whole usable life; emptied only by `take` and `drop`.
inner: Option<T>,
// Handle to the pool the connection is returned to.
pool: LockFreePool<T>,
}
// SAFETY: the guard owns its `T` (which is `Send`) plus a pool handle that is
// `Send` for `T: Send`, so moving the guard to another thread is sound.
unsafe impl<T: Send + 'static> Send for PooledConnection<T> {}
// SAFETY: a shared `&PooledConnection<T>` exposes `&T` through `Deref` and
// `AsRef`, so the guard may only be shared between threads when `T` itself is
// `Sync`. The previous `T: Send` bound allowed concurrent `&T` access to types
// such as `Cell`-based connections, which is a data race.
unsafe impl<T: Send + Sync + 'static> Sync for PooledConnection<T> {}
impl<T: Send + 'static> std::fmt::Debug for PooledConnection<T> {
    /// Prints only whether a connection is still held, so `T` itself does not
    /// need to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let connected = self.inner.is_some();
        let mut out = f.debug_struct("PooledConnection");
        out.field("connected", &connected);
        out.finish()
    }
}
// Gives transparent read access to the pooled connection.
impl<T: Send + 'static> Deref for PooledConnection<T> {
type Target = T;
#[inline(always)]
fn deref(&self) -> &T {
// SAFETY: within this file every `PooledConnection` is built with
// `inner: Some(..)`, and `inner` is emptied only by `take(self)` and by
// `Drop::drop`, both of which end the guard's life. No `&self` can therefore
// observe `None` here.
unsafe { self.inner.as_ref().unwrap_unchecked() }
}
}
impl<T: Send + 'static> AsRef<T> for PooledConnection<T> {
    /// Borrows the pooled connection; equivalent to dereferencing the guard.
    #[inline(always)]
    fn as_ref(&self) -> &T {
        &**self
    }
}
impl<T: Send + 'static> PooledConnection<T> {
    /// Detaches the connection from the pool and returns it to the caller.
    ///
    /// The pool stops accounting for the connection: its size counter is
    /// decremented and one waiting `acquire` call is woken, because the freed
    /// slot allows a replacement to be created. The subsequent drop of the
    /// guard is a no-op since `inner` is already empty.
    pub fn take(mut self) -> T {
        let conn = self
            .inner
            .take()
            .expect("PooledConnection holds a connection until it is consumed");
        self.pool.inner.size.0.fetch_sub(1, Ordering::Release);
        // Every other path that lowers `size` wakes a waiter; without this a
        // task blocked at capacity would sleep until its timeout even though
        // a slot is now free.
        self.pool.inner.notify.notify_one();
        conn
    }

    /// Returns the current status of the pool this connection belongs to.
    pub fn pool_status(&self) -> PoolStatus {
        self.pool.status()
    }
}
impl<T: Send + 'static> Drop for PooledConnection<T> {
    /// Hands the connection back to the pool unless it was already removed
    /// (for example by `take`).
    #[inline]
    fn drop(&mut self) {
        let conn = match self.inner.take() {
            Some(conn) => conn,
            None => return,
        };
        self.pool.inner.return_conn(conn);
    }
}
// Size counter forced to 64-byte alignment.
// NOTE(review): presumably so the hot counter sits on its own cache line and
// avoids false sharing with neighbouring fields — confirm the intent.
#[repr(C, align(64))]
struct AlignedSize(AtomicU32);
// Closed flag and the immutable capacity limit, also 64-byte aligned and kept
// together in one unit.
#[repr(C, align(64))]
struct AlignedClosed {
// Set to `true` by `LockFreePool::close`.
closed: AtomicBool,
// Copied from `PoolConfig::max_size` when the pool is built.
max_size: u32,
}
// State shared by every `LockFreePool` handle and `PooledConnection` guard.
#[repr(C)]
struct PoolInner<T: Send + 'static> {
/// Number of connections the pool accounts for. `acquire` increments it
/// before creating a connection; it is decremented when a connection is
/// discarded, taken out of the pool, or fails to be created.
size: AlignedSize,
/// Closed flag together with the configured maximum size.
closed: AlignedClosed,
/// Factory used to open new connections.
create: CreateFn<T>,
/// Check applied to connections popped from `idle`.
validate: ValidateFn<T>,
/// Bounded queue of idle connections, with capacity `max_size`.
idle: ArrayQueue<T>,
/// Wakes `acquire` callers waiting for a connection or a free slot.
notify: Notify,
/// Time limit for one run of `create`.
create_timeout: Duration,
/// Time an `acquire` call may wait when the pool is at capacity.
wait_timeout: Duration,
}
// SAFETY: within this file, values of `T` only move in and out through `idle`
// (push/pop by value), and `&T` is only passed to `validate` for an item the
// calling thread has just popped, so `T: Send` is the only requirement on `T`.
// NOTE(review): every field appears to be `Send + Sync` already under
// `T: Send`, so these impls may be redundant with the auto traits — confirm.
unsafe impl<T: Send + 'static> Send for PoolInner<T> {}
unsafe impl<T: Send + 'static> Sync for PoolInner<T> {}
impl<T: Send + 'static> LockFreePool<T> {
pub fn new(
create: CreateFn<T>,
validate: ValidateFn<T>,
config: &PoolConfig,
) -> Self {
let idle = ArrayQueue::new(config.max_size as usize);
Self {
inner: Arc::new(PoolInner {
size: AlignedSize(AtomicU32::new(0)),
closed: AlignedClosed {
closed: AtomicBool::new(false),
max_size: config.max_size,
},
create,
validate,
idle,
notify: Notify::new(),
create_timeout: config.create_timeout,
wait_timeout: config.wait_timeout,
}),
}
}
#[inline]
pub async fn acquire(&self) -> Result<PooledConnection<T>, PoolError> {
if self.inner.closed.closed.load(Ordering::Acquire) {
return Err(PoolError::Closed);
}
if let Some(item) = self.inner.idle.pop() {
if (self.inner.validate)(&item) {
return Ok(PooledConnection {
inner: Some(item),
pool: self.clone(),
});
}
drop(item);
self.inner.size.0.fetch_sub(1, Ordering::Release);
}
loop {
if self.inner.closed.closed.load(Ordering::Acquire) {
return Err(PoolError::Closed);
}
let current = self.inner.size.0.load(Ordering::Acquire);
if current < self.inner.closed.max_size {
if self.inner.size.0.compare_exchange_weak(
current,
current + 1,
Ordering::AcqRel,
Ordering::Relaxed,
).is_ok() {
return match self.create_one().await {
Ok(item) => Ok(PooledConnection {
inner: Some(item),
pool: self.clone(),
}),
Err(e) => {
self.inner.size.0.fetch_sub(1, Ordering::Release);
self.inner.notify.notify_one();
Err(e)
}
};
}
continue;
}
if self.inner.wait_timeout == Duration::ZERO {
return Err(PoolError::Timeout);
}
let notified = self.inner.notify.notified();
tokio::select! {
_ = notified => {
if let Some(item) = self.inner.idle.pop() {
if (self.inner.validate)(&item) {
return Ok(PooledConnection {
inner: Some(item),
pool: self.clone(),
});
}
drop(item);
self.inner.size.0.fetch_sub(1, Ordering::Release);
}
continue;
}
_ = tokio::time::sleep(self.inner.wait_timeout) => {
if let Some(item) = self.inner.idle.pop() {
if (self.inner.validate)(&item) {
return Ok(PooledConnection {
inner: Some(item),
pool: self.clone(),
});
}
drop(item);
self.inner.size.0.fetch_sub(1, Ordering::Release);
}
return Err(PoolError::Timeout);
}
}
}
}
#[inline]
async fn create_one(&self) -> Result<T, PoolError> {
if self.inner.closed.closed.load(Ordering::Acquire) {
self.inner.size.0.fetch_sub(1, Ordering::Release);
return Err(PoolError::Closed);
}
match timeout(self.inner.create_timeout, (self.inner.create)()).await {
Ok(Ok(item)) => Ok(item),
Ok(Err(msg)) => Err(PoolError::CreateFailed(msg)),
Err(_) => Err(PoolError::CreateFailed("timeout".into())),
}
}
pub fn close(&self) {
self.inner.closed.closed.store(true, Ordering::Release);
self.inner.notify.notify_waiters();
while self.inner.idle.pop().is_some() {
self.inner.size.0.fetch_sub(1, Ordering::Relaxed);
}
}
pub fn is_closed(&self) -> bool {
self.inner.closed.closed.load(Ordering::Acquire)
}
#[inline]
pub fn status(&self) -> PoolStatus {
self.inner.status()
}
pub fn max_size(&self) -> u32 {
self.inner.closed.max_size
}
}
impl<T: Send + 'static> PoolInner<T> {
    /// Takes back a connection released by a guard.
    ///
    /// While the pool is open the connection is queued as idle and one waiter
    /// is woken. If the pool is closed, or the idle queue refuses the item,
    /// the connection is dropped instead, the size counter is lowered and one
    /// waiter is woken so it can use the freed slot.
    #[inline]
    fn return_conn(&self, item: T) {
        let pool_open = !self.closed.closed.load(Ordering::Acquire);
        if pool_open {
            if let Err(overflow) = self.idle.push(item) {
                drop(overflow);
            } else {
                self.notify.notify_one();
                return;
            }
        }
        // Reached when the connection was not queued: account for its removal.
        self.size.0.fetch_sub(1, Ordering::Release);
        self.notify.notify_one();
    }

    /// Reads the pool counters one after another (not as one atomic snapshot).
    #[inline]
    fn status(&self) -> PoolStatus {
        let size = self.size.0.load(Ordering::Acquire);
        let idle = self.idle.len() as u32;
        let max_size = self.closed.max_size;
        let closed = self.closed.closed.load(Ordering::Acquire);
        PoolStatus {
            size,
            idle,
            max_size,
            closed,
        }
    }
}
impl<T: Send + 'static> Drop for PoolInner<T> {
    /// Empties the idle queue, dropping every connection still parked in it.
    fn drop(&mut self) {
        while let Some(conn) = self.idle.pop() {
            drop(conn);
        }
    }
}
#[cfg(test)]
pub(crate) mod test_helpers {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering as AtomicOrdering};

    /// Minimal stand-in for a real connection.
    pub struct TestConnection {
        /// Sequence number assigned by the factory.
        pub id: u32,
        /// Value the validator reports for this connection.
        pub valid: bool,
    }

    impl Drop for TestConnection {
        fn drop(&mut self) {}
    }

    /// Builds a pool of `TestConnection`s with a 5 s create timeout and a
    /// 10 s wait timeout.
    ///
    /// * `max_size` - capacity of the pool.
    /// * `fail_create` - when set, every factory call fails with
    ///   `"create failed"`.
    /// * `fail_validate` - when set, created connections are marked invalid.
    pub fn create_test_pool(
        max_size: u32,
        fail_create: bool,
        fail_validate: bool,
    ) -> LockFreePool<TestConnection> {
        // The counter advances on every factory call, including failing ones.
        let next_id = Arc::new(AtomicU32::new(0));
        let create: CreateFn<TestConnection> = Box::new(move || {
            let id = next_id.fetch_add(1, AtomicOrdering::Relaxed);
            let fut: BoxFuture<'static, Result<TestConnection, String>> =
                Box::pin(async move {
                    if fail_create {
                        Err(String::from("create failed"))
                    } else {
                        Ok(TestConnection {
                            id,
                            valid: !fail_validate,
                        })
                    }
                });
            fut
        });
        let validate: ValidateFn<TestConnection> =
            Box::new(|conn: &TestConnection| conn.valid);
        let config = PoolConfig {
            max_size,
            create_timeout: Duration::from_secs(5),
            wait_timeout: Duration::from_secs(10),
        };
        LockFreePool::new(create, validate, &config)
    }
}
// Async tests for `LockFreePool`; each runs on its own tokio test runtime.
#[cfg(test)]
mod tests {
use super::test_helpers::*;
use super::*;
use std::sync::atomic::{AtomicU32, Ordering as AtomicOrdering};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
// A single acquire creates connection 0; dropping it leaves one idle entry.
#[tokio::test]
async fn test_acquire_release_one() {
let pool = create_test_pool(5, false, false);
assert!(!pool.is_closed());
let conn = pool.acquire().await.unwrap();
assert_eq!(conn.id, 0);
assert!(conn.valid);
let status = pool.status();
assert_eq!(status.size, 1);
assert_eq!(status.idle, 0);
drop(conn);
sleep(Duration::from_millis(10)).await;
let status = pool.status();
assert_eq!(status.idle, 1);
}
// A released connection is handed out again instead of creating a new one.
#[tokio::test]
async fn test_acquire_release_reuse() {
let pool = create_test_pool(5, false, false);
let conn1 = pool.acquire().await.unwrap();
let id1 = conn1.id;
drop(conn1);
sleep(Duration::from_millis(10)).await;
let conn2 = pool.acquire().await.unwrap();
assert_eq!(conn2.id, id1, "should reuse the same connection");
}
// Filling the pool to capacity reports size 5 with nothing idle.
#[tokio::test]
async fn test_multiple_connections() {
let pool = create_test_pool(5, false, false);
let mut conns = Vec::new();
for _ in 0..5 {
let conn = pool.acquire().await.unwrap();
conns.push(conn);
}
assert_eq!(pool.status().size, 5);
assert_eq!(pool.status().idle, 0);
drop(conns);
}
// After releasing five connections, later acquires draw from the same set.
#[tokio::test]
async fn test_acquire_multiple_release_reuse() {
let pool = create_test_pool(5, false, false);
let mut conns = Vec::new();
for _ in 0..5 {
conns.push(pool.acquire().await.unwrap());
}
let ids: Vec<u32> = conns.iter().map(|c| c.id).collect();
drop(conns);
sleep(Duration::from_millis(10)).await;
let mut reused = 0;
for _ in 0..5 {
let conn = pool.acquire().await.unwrap();
if ids.contains(&conn.id) {
reused += 1;
}
drop(conn);
}
assert!(reused >= 4, "most connections should be reused");
}
// With the only slot taken, a second acquire fails with `Timeout` after
// the 100 ms wait.
#[tokio::test]
async fn test_pool_exhaustion_short_timeout() {
let config = PoolConfig {
max_size: 1,
create_timeout: Duration::from_secs(1),
wait_timeout: Duration::from_millis(100),
};
let pool = LockFreePool::new(
Box::new(|| {
Box::pin(async { Ok(TestConnection { id: 0, valid: true }) })
as BoxFuture<'static, Result<TestConnection, String>>
}) as CreateFn<TestConnection>,
Box::new(|_conn: &TestConnection| true) as ValidateFn<TestConnection>,
&config,
);
let conn1 = pool.acquire().await.unwrap();
let result = pool.acquire().await;
assert!(result.is_err());
assert_eq!(result.unwrap_err(), PoolError::Timeout);
drop(conn1);
}
// A waiter blocked on a full pool succeeds once the held connection is
// dropped.
#[tokio::test]
async fn test_acquire_no_timeout_when_conn_returned() {
let config = PoolConfig {
max_size: 1,
create_timeout: Duration::from_secs(1),
wait_timeout: Duration::from_secs(5),
};
let pool = Arc::new(LockFreePool::new(
Box::new(|| {
Box::pin(async { Ok(TestConnection { id: 0, valid: true }) })
as BoxFuture<'static, Result<TestConnection, String>>
}) as CreateFn<TestConnection>,
Box::new(|_conn: &TestConnection| true) as ValidateFn<TestConnection>,
&config,
));
let conn1 = pool.acquire().await.unwrap();
let pool_clone = pool.clone();
let handle = tokio::spawn(async move {
pool_clone.acquire().await
});
sleep(Duration::from_millis(50)).await;
drop(conn1);
let result = handle.await.unwrap();
assert!(result.is_ok(), "returned conn should unblock waiter");
}
// A validator that always rejects forces the idle connection to be
// replaced by a new one; the validator runs exactly once, on the idle pop.
#[tokio::test]
async fn test_validation_rejects_invalid_connections() {
let reject_count = Arc::new(AtomicU32::new(0));
let create_count = Arc::new(AtomicU32::new(0));
let create = {
let cc = create_count.clone();
Box::new(move || {
let id = cc.fetch_add(1, AtomicOrdering::Relaxed);
Box::pin(async move {
Ok(TestConnection { id, valid: true })
}) as BoxFuture<'static, Result<TestConnection, String>>
}) as CreateFn<TestConnection>
};
let validate = {
let rc = reject_count.clone();
Box::new(move |_conn: &TestConnection| {
rc.fetch_add(1, AtomicOrdering::Relaxed);
false
}) as ValidateFn<TestConnection>
};
let config = PoolConfig {
max_size: 5,
create_timeout: Duration::from_secs(5),
wait_timeout: Duration::from_secs(1),
};
let pool = LockFreePool::new(create, validate, &config);
let conn1 = pool.acquire().await.unwrap();
assert_eq!(conn1.id, 0);
drop(conn1);
let conn2 = pool.acquire().await.unwrap();
assert_eq!(conn2.id, 1, "rejected idle conn should be replaced");
let rejected = reject_count.load(AtomicOrdering::Relaxed);
assert_eq!(rejected, 1, "validator should be called exactly once");
drop(conn2);
}
// Acquiring from a closed pool fails with `Closed`, even while a
// connection is still checked out.
#[tokio::test]
async fn test_close() {
let pool = create_test_pool(5, false, false);
let conn = pool.acquire().await.unwrap();
assert!(!pool.is_closed());
pool.close();
assert!(pool.is_closed());
let result = pool.acquire().await;
assert!(result.is_err());
assert_eq!(result.unwrap_err(), PoolError::Closed);
drop(conn); }
// Closing the pool wakes a blocked waiter, which then reports `Closed`.
#[tokio::test]
async fn test_close_with_waiter() {
let config = PoolConfig {
max_size: 1,
create_timeout: Duration::from_secs(1),
wait_timeout: Duration::from_secs(10),
};
let pool = Arc::new(LockFreePool::new(
Box::new(|| {
Box::pin(async { Ok(TestConnection { id: 0, valid: true }) })
as BoxFuture<'static, Result<TestConnection, String>>
}) as CreateFn<TestConnection>,
Box::new(|_conn: &TestConnection| true) as ValidateFn<TestConnection>,
&config,
));
let conn1 = pool.acquire().await.unwrap();
let pool_clone = pool.clone();
let handle = tokio::spawn(async move {
pool_clone.acquire().await
});
sleep(Duration::from_millis(50)).await;
pool.close();
let result = handle.await.unwrap();
assert!(result.is_err());
assert_eq!(result.unwrap_err(), PoolError::Closed);
drop(conn1);
}
// 16 tasks sharing 8 slots: every acquire must succeed and the size
// counter must stay within the limit.
#[tokio::test]
async fn test_concurrent_acquire_release() {
let pool = Arc::new(create_test_pool(8, false, false));
let mut handles = Vec::new();
for _ in 0..16 {
let pool = pool.clone();
handles.push(tokio::spawn(async move {
for _ in 0..10 {
let conn = pool.acquire().await.unwrap();
sleep(Duration::from_millis(5)).await;
drop(conn); }
}));
}
for h in handles {
h.await.unwrap();
}
let status = pool.status();
assert!(status.size <= 8, "pool should not exceed max_size");
}
// 32 tasks contending for 4 slots: `Timeout` is tolerated, any other
// error fails the test.
#[tokio::test]
async fn test_concurrent_stress_high_contention() {
let pool = Arc::new(create_test_pool(4, false, false));
let mut handles = Vec::new();
for _ in 0..32 {
let pool = pool.clone();
handles.push(tokio::spawn(async move {
for _ in 0..25 {
match pool.acquire().await {
Ok(conn) => {
tokio::task::yield_now().await;
drop(conn);
}
Err(PoolError::Timeout) => {
tokio::task::yield_now().await;
}
Err(e) => panic!("Unexpected error: {e}"),
}
}
}));
}
for h in handles {
h.await.unwrap();
}
let status = pool.status();
assert!(status.size <= 4, "pool exceeded max_size: {}", status.size);
assert!(!status.closed);
}
// A zero wait timeout makes acquire on a full pool fail with `Timeout`.
#[tokio::test]
async fn test_zero_wait_timeout() {
let config = PoolConfig {
max_size: 1,
create_timeout: Duration::from_secs(1),
wait_timeout: Duration::ZERO,
};
let pool = LockFreePool::new(
Box::new(|| {
Box::pin(async { Ok(TestConnection { id: 0, valid: true }) })
as BoxFuture<'static, Result<TestConnection, String>>
}) as CreateFn<TestConnection>,
Box::new(|_conn: &TestConnection| true) as ValidateFn<TestConnection>,
&config,
);
let _conn = pool.acquire().await.unwrap();
let result = pool.acquire().await;
assert_eq!(result.unwrap_err(), PoolError::Timeout);
}
// A failing factory surfaces as `CreateFailed`.
#[tokio::test]
async fn test_create_failure() {
let pool = create_test_pool(5, true, false);
let result = pool.acquire().await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), PoolError::CreateFailed(_)));
}
// `take` removes the connection from the pool's accounting (size drops to 0).
#[tokio::test]
async fn test_take_connection() {
let pool = create_test_pool(5, false, false);
let conn = pool.acquire().await.unwrap();
let id = conn.id;
let taken = PooledConnection::take(conn);
assert_eq!(taken.id, id);
let status = pool.status();
assert_eq!(status.size, 0); }
// A cloned handle can acquire from the same pool.
#[tokio::test]
async fn test_pool_clone() {
let pool = create_test_pool(5, false, false);
let pool2 = pool.clone();
let conn = pool2.acquire().await.unwrap();
assert!(conn.valid);
drop(conn);
}
// Closing with connections checked out rejects new acquires; the
// outstanding guards can still be dropped afterwards.
#[tokio::test]
async fn test_close_with_active_connections() {
let pool = create_test_pool(5, false, false);
let conn1 = pool.acquire().await.unwrap();
let conn2 = pool.acquire().await.unwrap();
pool.close();
assert!(pool.is_closed());
let result = pool.acquire().await;
assert_eq!(result.unwrap_err(), PoolError::Closed);
drop(conn1);
drop(conn2);
}
}