use fast_pool::{Manager, Pool};
use std::fmt::Display;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
/// Stateless connection manager used by these tests; it hands out `TestConnection`s.
#[derive(Clone, Debug)]
pub struct TestManager {}
/// In-memory stand-in for a pooled connection.
///
/// The test managers in this file treat a connection as healthy while `inner` is empty;
/// tests write a non-empty string into it to make the next health check fail.
#[derive(Debug, Clone, Default)]
pub struct TestConnection {
    /// Connection payload; any non-empty value marks the connection as broken.
    pub inner: String,
}
impl TestConnection {
    /// Creates a healthy connection (empty `inner`).
    ///
    /// Equivalent to `TestConnection::default()`; `Default` is derived so the type also
    /// works with `Default`-based APIs (and satisfies `clippy::new_without_default`).
    pub fn new() -> TestConnection {
        Self::default()
    }
}
/// Renders the connection as its raw `inner` string, without any decoration.
impl Display for TestConnection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.inner.as_str())
    }
}
/// Exposes the wrapped `inner` string for read-only `String` operations.
impl Deref for TestConnection {
    type Target = String;
    fn deref(&self) -> &Self::Target {
        let Self { inner, .. } = self;
        inner
    }
}
/// Exposes the wrapped `inner` string for in-place `String` mutation.
impl DerefMut for TestConnection {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { inner, .. } = self;
        inner
    }
}
/// Pool manager for [`TestConnection`]s: a connection passes `check` only while its
/// `inner` string is empty.
impl Manager for TestManager {
    type Connection = TestConnection;
    type Error = String;

    /// Opens a fresh connection with an empty `inner` string.
    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
        Ok(TestConnection::new())
    }

    /// Succeeds when the connection's `inner` string is empty.
    ///
    /// # Errors
    ///
    /// Returns the connection's `Display` text (its `inner` string) when it is non-empty.
    async fn check(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
        if conn.inner.is_empty() {
            Ok(())
        } else {
            // `to_string` already yields an owned `String`; no second copy is needed.
            Err(conn.to_string())
        }
    }
}
/// Test manager that wraps a [`TestManager`] and counts how often `check` is invoked.
///
/// The counter lives behind an `Arc`, so clones share it: a test can keep one clone to
/// read the count while another clone is moved into a pool.
#[derive(Debug, Clone)]
pub struct CheckCounterManager {
    /// Manager that performs the real connect/check work.
    manager: TestManager,
    /// Number of `check` calls observed so far, shared among all clones.
    check_count: Arc<AtomicUsize>,
}
impl CheckCounterManager {
    /// Wraps a fresh [`TestManager`] with a counter starting at zero.
    fn new() -> Self {
        let manager = TestManager {};
        let check_count = Arc::new(AtomicUsize::new(0));
        Self {
            manager,
            check_count,
        }
    }

    /// Returns how many times `check` has been called on this manager or on any clone of
    /// it (clones share the same `Arc` counter).
    fn get_check_count(&self) -> usize {
        let counter: &AtomicUsize = &self.check_count;
        counter.load(Ordering::SeqCst)
    }
}
/// Delegates to the wrapped [`TestManager`], recording every `check` call on the way.
impl Manager for CheckCounterManager {
    type Connection = TestConnection;
    type Error = String;

    /// Forwards to the inner manager; does not touch the counter.
    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
        let inner = &self.manager;
        inner.connect().await
    }

    /// Bumps the shared counter, then runs the inner manager's check.
    ///
    /// The increment happens before the inner check, so failed checks are counted too.
    async fn check(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
        let counter = &self.check_count;
        counter.fetch_add(1, Ordering::SeqCst);
        self.manager.check(conn).await
    }
}
struct DurationManager<M: Manager> {
pub manager: M,
pub duration: Duration,
pub last_check: std::sync::Mutex<Option<std::time::SystemTime>>,
}
impl<M: Manager> DurationManager<M> {
fn new(manager: M, duration: Duration) -> Self {
Self {
manager,
duration,
last_check: std::sync::Mutex::new(None),
}
}
}
impl<M: Manager> Manager for DurationManager<M> {
type Connection = M::Connection;
type Error = M::Error;
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
self.manager.connect().await
}
async fn check(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
let now = std::time::SystemTime::now();
let should_check = {
let mut last_check = self.last_check.lock().unwrap();
let should_check = match *last_check {
None => true,
Some(time) => {
match now.duration_since(time) {
Ok(elapsed) => elapsed >= self.duration,
Err(_) => true, }
}
};
if should_check {
*last_check = Some(now);
}
should_check
};
if should_check {
self.manager.check(conn).await
} else {
Ok(())
}
}
}
/// Asserts that a check runs on the first checkout, is skipped for an immediate second
/// checkout inside the 500 ms window, and runs again once the window has elapsed.
#[tokio::test]
async fn test_check_duration_manager_basic() {
    let counter = CheckCounterManager::new();
    // The clone handed to the pool shares its `Arc` counter with `counter`.
    let window = Duration::from_millis(500);
    let pool = Pool::new(DurationManager::new(counter.clone(), window));

    drop(pool.get().await.unwrap());
    assert_eq!(counter.get_check_count(), 1, "First check should happen");

    drop(pool.get().await.unwrap());
    assert_eq!(
        counter.get_check_count(),
        1,
        "Second immediate check should be skipped"
    );

    // Sleep a bit longer than the window so the next checkout is due a check.
    tokio::time::sleep(Duration::from_millis(550)).await;
    drop(pool.get().await.unwrap());
    assert_eq!(
        counter.get_check_count(),
        2,
        "Check should happen after duration"
    );
}
#[tokio::test]
async fn test_check_duration_manager_concurrent() {
let counter_manager = CheckCounterManager::new();
let check_count_tracker = counter_manager.check_count.clone();
let duration_manager = DurationManager::new(counter_manager, Duration::from_millis(200));
let pool = Pool::new(duration_manager);
for _ in 0..10 {
let conn = pool.get().await.unwrap();
drop(conn);
}
assert_eq!(
check_count_tracker.load(Ordering::SeqCst),
1,
"Multiple rapid connections should only trigger one check"
);
tokio::time::sleep(Duration::from_millis(250)).await;
let conn = pool.get().await.unwrap();
drop(conn);
assert_eq!(
check_count_tracker.load(Ordering::SeqCst),
2,
"Check should happen after duration expires"
);
}
#[tokio::test]
async fn test_check_duration_manager_invalid_connection() {
let counter_manager = CheckCounterManager::new();
let check_count_tracker = counter_manager.check_count.clone();
let duration_manager = DurationManager::new(counter_manager, Duration::from_millis(100));
let pool = Pool::new(duration_manager);
let mut conn = pool.get().await.unwrap();
(*conn).inner = "invalid".to_string();
drop(conn);
assert_eq!(check_count_tracker.load(Ordering::SeqCst), 1);
let conn = pool.get().await.unwrap();
assert_eq!(
conn.inner.as_ref().unwrap().inner,
"invalid",
"Connection should still be invalid since check was skipped"
);
assert_eq!(
check_count_tracker.load(Ordering::SeqCst),
1,
"Check should be skipped due to duration limit"
);
tokio::time::sleep(Duration::from_millis(150)).await;
drop(conn);
let conn = pool.get().await.unwrap();
assert_eq!(
check_count_tracker.load(Ordering::SeqCst),
2,
"Check should happen after duration expires"
);
assert_eq!(
conn.inner.as_ref().unwrap().inner,
"",
"Connection should be valid after check is performed"
);
}