1use std::future::Future;
36use std::ops::Deref;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
40use std::time::Duration;
41
42use crossbeam::queue::ArrayQueue;
43use tokio::sync::Notify;
44use tokio::time::timeout;
45
/// Boxed, `Send` future; the connection factory returns one of these per call.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Connection factory: each call starts creating one connection, resolving to the
/// connection or to an error message.
pub type CreateFn<T> = Box<dyn Fn() -> BoxFuture<'static, Result<T, String>> + Send + Sync>;
/// Check run on an idle connection before it is handed out again; `false` discards it.
pub type ValidateFn<T> = Box<dyn Fn(&T) -> bool + Send + Sync>;
49
/// Errors returned by `LockFreePool::acquire`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PoolError {
    /// No connection became available within the pool's `wait_timeout`.
    Timeout,
    /// The pool has been closed.
    Closed,
    /// The connection factory failed (its message is carried verbatim) or exceeded the
    /// pool's `create_timeout` (carried as `"timeout"`).
    CreateFailed(String),
}

impl std::fmt::Display for PoolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Timeout => f.write_str("pool: timeout waiting for connection"),
            Self::Closed => f.write_str("pool: closed"),
            Self::CreateFailed(reason) => write!(f, "pool: create failed: {reason}"),
        }
    }
}

impl std::error::Error for PoolError {}
70
/// Sizing and timeout settings for a `LockFreePool`.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on connections the pool accounts for (idle plus checked out).
    pub max_size: u32,
    /// Maximum duration of a single call to the connection factory.
    pub create_timeout: Duration,
    /// How long `acquire` may wait when the pool is full; `Duration::ZERO` fails at once.
    pub wait_timeout: Duration,
}

impl Default for PoolConfig {
    /// 20 connections, 5 s create timeout, 10 s wait timeout.
    fn default() -> Self {
        let create_timeout = Duration::from_secs(5);
        let wait_timeout = Duration::from_secs(10);
        PoolConfig {
            max_size: 20,
            create_timeout,
            wait_timeout,
        }
    }
}
89
/// Point-in-time snapshot of the pool's counters.
///
/// Each field is read independently without a lock, so under concurrent use the
/// values may be momentarily inconsistent with one another.
#[derive(Debug, Clone)]
pub struct PoolStatus {
    /// Connections accounted to the pool: idle, checked out, and being created.
    pub size: u32,
    /// Connections currently parked in the idle queue.
    pub idle: u32,
    /// Configured upper bound on `size`.
    pub max_size: u32,
    /// Whether the pool has been closed.
    pub closed: bool,
}
103
/// Async connection pool built on a lock-free idle queue and an atomic size counter.
///
/// Cloning is cheap (an `Arc` bump); all clones share the same pool state.
pub struct LockFreePool<T: Send + 'static> {
    inner: Arc<PoolInner<T>>,
}
109
110unsafe impl<T: Send + 'static> Send for LockFreePool<T> {}
113unsafe impl<T: Send + 'static> Sync for LockFreePool<T> {}
114
115impl<T: Send + 'static> Clone for LockFreePool<T> {
116 fn clone(&self) -> Self {
117 Self {
118 inner: self.inner.clone(),
119 }
120 }
121}
122
/// Guard for a connection checked out of a `LockFreePool`.
///
/// Dereferences to `T`. When dropped, the connection is handed back to the pool;
/// `take` detaches it instead.
pub struct PooledConnection<T: Send + 'static> {
    /// `Some` while the guard is usable; emptied only by `take` or by `Drop`.
    inner: Option<T>,
    /// Handle to the owning pool, used to return or release the connection.
    pool: LockFreePool<T>,
}
137
138unsafe impl<T: Send + 'static> Send for PooledConnection<T> {}
141unsafe impl<T: Send + 'static> Sync for PooledConnection<T> {}
142
143impl<T: Send + 'static> std::fmt::Debug for PooledConnection<T> {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 f.debug_struct("PooledConnection")
146 .field("connected", &self.inner.is_some())
147 .finish()
148 }
149}
150
151impl<T: Send + 'static> Deref for PooledConnection<T> {
152 type Target = T;
153 #[inline(always)]
154 fn deref(&self) -> &T {
155 unsafe { self.inner.as_ref().unwrap_unchecked() }
157 }
158}
159
160impl<T: Send + 'static> AsRef<T> for PooledConnection<T> {
161 #[inline(always)]
162 fn as_ref(&self) -> &T {
163 self.deref()
164 }
165}
166
167impl<T: Send + 'static> PooledConnection<T> {
168 pub fn take(mut self) -> T {
171 let conn = self.inner.take().unwrap();
172 self.pool.inner.size.0.fetch_sub(1, Ordering::Release);
173 conn
174 }
175
176 pub fn pool_status(&self) -> PoolStatus {
178 self.pool.status()
179 }
180}
181
182impl<T: Send + 'static> Drop for PooledConnection<T> {
186 #[inline]
187 fn drop(&mut self) {
188 if let Some(item) = self.inner.take() {
189 self.pool.inner.return_conn(item);
190 }
191 }
192}
193
/// Atomic count of connections the pool accounts for (idle, checked out, being created).
///
/// `align(64)` gives the counter its own 64-byte-aligned slot. This is presumably to keep
/// this frequently written atomic off the cache line of the read-mostly fields that follow
/// it in `PoolInner`, avoiding false sharing (assumes 64-byte cache lines).
#[repr(C, align(64))]
struct AlignedSize(AtomicU32);
206
/// Read-mostly pool settings (the `closed` flag and the fixed `max_size`), kept in their
/// own 64-byte-aligned slot apart from the `size` counter.
#[repr(C, align(64))]
struct AlignedClosed {
    /// Set to `true` by `close`; never reset.
    closed: AtomicBool,
    /// Capacity limit copied from `PoolConfig::max_size`.
    max_size: u32,
}
214
/// Shared state behind every `LockFreePool` clone and every `PooledConnection`.
///
/// `repr(C)` keeps the declared field order, so the two 64-byte-aligned headers
/// (`size`, `closed`) sit at the front of the struct in separate slots.
#[repr(C)]
struct PoolInner<T: Send + 'static> {
    /// Connections accounted to the pool; a slot is reserved here before `create` runs.
    size: AlignedSize,

    /// Closed flag plus the `max_size` limit.
    closed: AlignedClosed,

    /// Factory for new connections.
    create: CreateFn<T>,

    /// Check applied to idle connections before they are reused.
    validate: ValidateFn<T>,

    /// Bounded lock-free queue of idle connections; its capacity is derived from `max_size`.
    idle: ArrayQueue<T>,

    /// Wakes tasks waiting in `acquire`: `notify_one` when a connection is returned or a
    /// slot is released, `notify_waiters` when the pool is closed.
    notify: Notify,

    /// Time limit for a single `create` call.
    create_timeout: Duration,

    /// How long `acquire` waits when the pool is full.
    wait_timeout: Duration,
}
256
257unsafe impl<T: Send + 'static> Send for PoolInner<T> {}
259unsafe impl<T: Send + 'static> Sync for PoolInner<T> {}
260
261impl<T: Send + 'static> LockFreePool<T> {
262 pub fn new(create: CreateFn<T>, validate: ValidateFn<T>, config: &PoolConfig) -> Self {
267 let idle = ArrayQueue::new(config.max_size as usize);
269 Self {
270 inner: Arc::new(PoolInner {
271 size: AlignedSize(AtomicU32::new(0)),
272 closed: AlignedClosed {
273 closed: AtomicBool::new(false),
274 max_size: config.max_size,
275 },
276 create,
277 validate,
278 idle,
279 notify: Notify::new(),
280 create_timeout: config.create_timeout,
281 wait_timeout: config.wait_timeout,
282 }),
283 }
284 }
285
286 #[inline]
299 pub async fn acquire(&self) -> Result<PooledConnection<T>, PoolError> {
300 if self.inner.closed.closed.load(Ordering::Acquire) {
305 return Err(PoolError::Closed);
306 }
307
308 if let Some(item) = self.inner.idle.pop() {
311 if (self.inner.validate)(&item) {
314 return Ok(PooledConnection {
315 inner: Some(item),
316 pool: self.clone(),
317 });
318 }
319 drop(item);
322 self.inner.size.0.fetch_sub(1, Ordering::Release);
323 }
325
326 loop {
328 if self.inner.closed.closed.load(Ordering::Acquire) {
329 return Err(PoolError::Closed);
330 }
331
332 let current = self.inner.size.0.load(Ordering::Acquire);
334 if current < self.inner.closed.max_size {
335 if self
339 .inner
340 .size
341 .0
342 .compare_exchange_weak(
343 current,
344 current + 1,
345 Ordering::AcqRel,
346 Ordering::Relaxed,
347 )
348 .is_ok()
349 {
350 return match self.create_one().await {
352 Ok(item) => Ok(PooledConnection {
353 inner: Some(item),
354 pool: self.clone(),
355 }),
356 Err(e) => {
357 self.inner.size.0.fetch_sub(1, Ordering::Release);
359 self.inner.notify.notify_one();
360 Err(e)
361 }
362 };
363 }
364 continue;
366 }
367
368 if self.inner.wait_timeout == Duration::ZERO {
371 return Err(PoolError::Timeout);
372 }
373
374 let notified = self.inner.notify.notified();
377 tokio::select! {
378 _ = notified => {
379 if let Some(item) = self.inner.idle.pop() {
382 if (self.inner.validate)(&item) {
383 return Ok(PooledConnection {
384 inner: Some(item),
385 pool: self.clone(),
386 });
387 }
388 drop(item);
389 self.inner.size.0.fetch_sub(1, Ordering::Release);
390 }
391 continue;
396 }
397 _ = tokio::time::sleep(self.inner.wait_timeout) => {
398 if let Some(item) = self.inner.idle.pop() {
400 if (self.inner.validate)(&item) {
401 return Ok(PooledConnection {
402 inner: Some(item),
403 pool: self.clone(),
404 });
405 }
406 drop(item);
407 self.inner.size.0.fetch_sub(1, Ordering::Release);
408 }
409 return Err(PoolError::Timeout);
410 }
411 }
412 }
413 }
414
415 #[inline]
417 async fn create_one(&self) -> Result<T, PoolError> {
418 if self.inner.closed.closed.load(Ordering::Acquire) {
419 self.inner.size.0.fetch_sub(1, Ordering::Release);
420 return Err(PoolError::Closed);
421 }
422 match timeout(self.inner.create_timeout, (self.inner.create)()).await {
423 Ok(Ok(item)) => Ok(item),
424 Ok(Err(msg)) => Err(PoolError::CreateFailed(msg)),
425 Err(_) => Err(PoolError::CreateFailed("timeout".into())),
426 }
427 }
428
429 pub fn close(&self) {
430 self.inner.closed.closed.store(true, Ordering::Release);
431 self.inner.notify.notify_waiters();
432 while self.inner.idle.pop().is_some() {
433 self.inner.size.0.fetch_sub(1, Ordering::Relaxed);
434 }
435 }
436
437 pub fn is_closed(&self) -> bool {
438 self.inner.closed.closed.load(Ordering::Acquire)
439 }
440
441 #[inline]
442 pub fn status(&self) -> PoolStatus {
443 self.inner.status()
444 }
445
446 pub fn max_size(&self) -> u32 {
447 self.inner.closed.max_size
448 }
449}
450
451impl<T: Send + 'static> PoolInner<T> {
452 #[inline]
462 fn return_conn(&self, item: T) {
463 let closed = self.closed.closed.load(Ordering::Acquire);
464 if !closed {
465 match self.idle.push(item) {
466 Ok(()) => {
467 self.notify.notify_one();
468 return;
469 }
470 Err(dropped) => {
471 drop(dropped);
473 }
474 }
475 }
476 self.size.0.fetch_sub(1, Ordering::Release);
477 self.notify.notify_one();
478 }
479
480 #[inline]
481 fn status(&self) -> PoolStatus {
482 let size = self.size.0.load(Ordering::Acquire);
483 let idle = self.idle.len();
484 PoolStatus {
485 size,
486 idle: idle as u32,
487 max_size: self.closed.max_size,
488 closed: self.closed.closed.load(Ordering::Acquire),
489 }
490 }
491}
492
493impl<T: Send + 'static> Drop for PoolInner<T> {
496 fn drop(&mut self) {
497 while self.idle.pop().is_some() {}
499 }
500}
501
#[cfg(test)]
pub(crate) mod test_helpers {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering as AtomicOrdering};

    /// Dummy connection handed out by [`create_test_pool`].
    pub struct TestConnection {
        /// Sequence number assigned by the factory, starting at 0.
        pub id: u32,
        /// Value the pool's validator reports for this connection.
        pub valid: bool,
    }

    // Explicit (empty) destructor, so the type has drop glue like a real resource
    // (e.g. fields cannot be moved out by destructuring).
    impl Drop for TestConnection {
        fn drop(&mut self) {}
    }

    /// Builds a pool of [`TestConnection`]s with a 5 s create timeout and a 10 s wait timeout.
    ///
    /// * `fail_create` — every factory call resolves to `Err("create failed")`.
    /// * `fail_validate` — created connections carry `valid == false`, so the validator
    ///   rejects them when they are taken from the idle queue.
    pub fn create_test_pool(
        max_size: u32,
        fail_create: bool,
        fail_validate: bool,
    ) -> LockFreePool<TestConnection> {
        let next_id = Arc::new(AtomicU32::new(0));

        let create: CreateFn<TestConnection> =
            Box::new(move || -> BoxFuture<'static, Result<TestConnection, String>> {
                // The id is claimed when the factory is called, not when the future is polled.
                let id = next_id.fetch_add(1, AtomicOrdering::Relaxed);
                Box::pin(async move {
                    if fail_create {
                        return Err(String::from("create failed"));
                    }
                    Ok(TestConnection {
                        id,
                        valid: !fail_validate,
                    })
                })
            });

        let validate: ValidateFn<TestConnection> = Box::new(|conn: &TestConnection| conn.valid);

        LockFreePool::new(
            create,
            validate,
            &PoolConfig {
                max_size,
                create_timeout: Duration::from_secs(5),
                wait_timeout: Duration::from_secs(10),
            },
        )
    }
}
557
#[cfg(test)]
mod tests {
    use super::test_helpers::*;
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering as AtomicOrdering};
    use std::time::Duration;
    use tokio::time::sleep;

    // `PooledConnection`'s `Drop` hands the connection back synchronously, so these tests
    // inspect pool state right after `drop(conn)` without sleeping or yielding.

    /// One checkout is counted in `size`; dropping it parks the connection as idle.
    #[tokio::test]
    async fn test_acquire_release_one() {
        let pool = create_test_pool(5, false, false);
        assert!(!pool.is_closed());

        let conn = pool.acquire().await.unwrap();
        assert_eq!(conn.id, 0);
        assert!(conn.valid);

        let status = pool.status();
        assert_eq!(status.size, 1);
        assert_eq!(status.idle, 0);

        drop(conn);

        let status = pool.status();
        assert_eq!(status.idle, 1);
    }

    /// A released connection is handed out again instead of a new one being created.
    #[tokio::test]
    async fn test_acquire_release_reuse() {
        let pool = create_test_pool(5, false, false);

        let conn1 = pool.acquire().await.unwrap();
        let id1 = conn1.id;
        drop(conn1);

        let conn2 = pool.acquire().await.unwrap();
        assert_eq!(conn2.id, id1, "should reuse the same connection");
    }

    /// The pool can have `max_size` connections checked out at once.
    #[tokio::test]
    async fn test_multiple_connections() {
        let pool = create_test_pool(5, false, false);
        let mut conns = Vec::new();
        for _ in 0..5 {
            let conn = pool.acquire().await.unwrap();
            conns.push(conn);
        }
        assert_eq!(pool.status().size, 5);
        assert_eq!(pool.status().idle, 0);
        drop(conns);
    }

    /// After a full pool is released, later checkouts reuse the existing connections.
    #[tokio::test]
    async fn test_acquire_multiple_release_reuse() {
        let pool = create_test_pool(5, false, false);
        let mut conns = Vec::new();

        for _ in 0..5 {
            conns.push(pool.acquire().await.unwrap());
        }
        let ids: Vec<u32> = conns.iter().map(|c| c.id).collect();
        drop(conns);

        let mut reused = 0;
        for _ in 0..5 {
            let conn = pool.acquire().await.unwrap();
            if ids.contains(&conn.id) {
                reused += 1;
            }
            drop(conn);
        }
        assert!(reused >= 4, "most connections should be reused");
    }

    /// With the pool exhausted and a short `wait_timeout`, `acquire` fails with `Timeout`.
    #[tokio::test]
    async fn test_pool_exhaustion_short_timeout() {
        let config = PoolConfig {
            max_size: 1,
            create_timeout: Duration::from_secs(1),
            wait_timeout: Duration::from_millis(100),
        };
        let pool = LockFreePool::new(
            Box::new(|| {
                Box::pin(async { Ok(TestConnection { id: 0, valid: true }) })
                    as BoxFuture<'static, Result<TestConnection, String>>
            }) as CreateFn<TestConnection>,
            Box::new(|_conn: &TestConnection| true) as ValidateFn<TestConnection>,
            &config,
        );

        let conn1 = pool.acquire().await.unwrap();
        let result = pool.acquire().await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), PoolError::Timeout);
        drop(conn1);
    }

    /// A task waiting on a full pool is woken when a connection is returned.
    #[tokio::test]
    async fn test_acquire_no_timeout_when_conn_returned() {
        let config = PoolConfig {
            max_size: 1,
            create_timeout: Duration::from_secs(1),
            wait_timeout: Duration::from_secs(5),
        };
        let pool = Arc::new(LockFreePool::new(
            Box::new(|| {
                Box::pin(async { Ok(TestConnection { id: 0, valid: true }) })
                    as BoxFuture<'static, Result<TestConnection, String>>
            }) as CreateFn<TestConnection>,
            Box::new(|_conn: &TestConnection| true) as ValidateFn<TestConnection>,
            &config,
        ));

        let conn1 = pool.acquire().await.unwrap();
        let pool_clone = pool.clone();

        let handle = tokio::spawn(async move { pool_clone.acquire().await });

        // Give the spawned task time to start waiting before the connection comes back.
        sleep(Duration::from_millis(50)).await;
        drop(conn1);

        let result = handle.await.unwrap();
        assert!(result.is_ok(), "returned conn should unblock waiter");
    }

    /// An idle connection failing validation is discarded and replaced by a new one.
    #[tokio::test]
    async fn test_validation_rejects_invalid_connections() {
        let reject_count = Arc::new(AtomicU32::new(0));
        let create_count = Arc::new(AtomicU32::new(0));

        let create = {
            let cc = create_count.clone();
            Box::new(move || {
                let id = cc.fetch_add(1, AtomicOrdering::Relaxed);
                Box::pin(async move { Ok(TestConnection { id, valid: true }) })
                    as BoxFuture<'static, Result<TestConnection, String>>
            }) as CreateFn<TestConnection>
        };

        let validate = {
            let rc = reject_count.clone();
            Box::new(move |_conn: &TestConnection| {
                rc.fetch_add(1, AtomicOrdering::Relaxed);
                false
            }) as ValidateFn<TestConnection>
        };

        let config = PoolConfig {
            max_size: 5,
            create_timeout: Duration::from_secs(5),
            wait_timeout: Duration::from_secs(1),
        };
        let pool = LockFreePool::new(create, validate, &config);

        let conn1 = pool.acquire().await.unwrap();
        assert_eq!(conn1.id, 0);
        drop(conn1);

        let conn2 = pool.acquire().await.unwrap();
        assert_eq!(conn2.id, 1, "rejected idle conn should be replaced");

        let rejected = reject_count.load(AtomicOrdering::Relaxed);
        assert_eq!(rejected, 1, "validator should be called exactly once");

        drop(conn2);
    }

    /// After `close`, `acquire` fails with `Closed`.
    #[tokio::test]
    async fn test_close() {
        let pool = create_test_pool(5, false, false);
        let conn = pool.acquire().await.unwrap();
        assert!(!pool.is_closed());
        pool.close();
        assert!(pool.is_closed());
        let result = pool.acquire().await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), PoolError::Closed);
        drop(conn);
    }

    /// `close` wakes a task waiting on a full pool, which then reports `Closed`.
    #[tokio::test]
    async fn test_close_with_waiter() {
        let config = PoolConfig {
            max_size: 1,
            create_timeout: Duration::from_secs(1),
            wait_timeout: Duration::from_secs(10),
        };
        let pool = Arc::new(LockFreePool::new(
            Box::new(|| {
                Box::pin(async { Ok(TestConnection { id: 0, valid: true }) })
                    as BoxFuture<'static, Result<TestConnection, String>>
            }) as CreateFn<TestConnection>,
            Box::new(|_conn: &TestConnection| true) as ValidateFn<TestConnection>,
            &config,
        ));

        let conn1 = pool.acquire().await.unwrap();
        let pool_clone = pool.clone();

        let handle = tokio::spawn(async move { pool_clone.acquire().await });

        // Give the spawned task time to start waiting before the pool is closed.
        sleep(Duration::from_millis(50)).await;

        pool.close();
        let result = handle.await.unwrap();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), PoolError::Closed);
        drop(conn1);
    }

    /// Interleaved tasks (tokio's default test runtime is single-threaded) never push
    /// `size` beyond `max_size`.
    #[tokio::test]
    async fn test_concurrent_acquire_release() {
        let pool = Arc::new(create_test_pool(8, false, false));
        let mut handles = Vec::new();

        for _ in 0..16 {
            let pool = pool.clone();
            handles.push(tokio::spawn(async move {
                for _ in 0..10 {
                    let conn = pool.acquire().await.unwrap();
                    sleep(Duration::from_millis(5)).await;
                    drop(conn);
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let status = pool.status();
        assert!(status.size <= 8, "pool should not exceed max_size");
    }

    /// Heavy contention (32 tasks sharing 4 slots) never exceeds `max_size` or closes
    /// the pool.
    #[tokio::test]
    async fn test_concurrent_stress_high_contention() {
        let pool = Arc::new(create_test_pool(4, false, false));
        let mut handles = Vec::new();

        for _ in 0..32 {
            let pool = pool.clone();
            handles.push(tokio::spawn(async move {
                for _ in 0..25 {
                    match pool.acquire().await {
                        Ok(conn) => {
                            tokio::task::yield_now().await;
                            drop(conn);
                        }
                        Err(PoolError::Timeout) => {
                            tokio::task::yield_now().await;
                        }
                        Err(e) => panic!("Unexpected error: {e}"),
                    }
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let status = pool.status();
        assert!(status.size <= 4, "pool exceeded max_size: {}", status.size);
        assert!(!status.closed);
    }

    /// A zero `wait_timeout` fails immediately with `Timeout` when the pool is full.
    #[tokio::test]
    async fn test_zero_wait_timeout() {
        let config = PoolConfig {
            max_size: 1,
            create_timeout: Duration::from_secs(1),
            wait_timeout: Duration::ZERO,
        };
        let pool = LockFreePool::new(
            Box::new(|| {
                Box::pin(async { Ok(TestConnection { id: 0, valid: true }) })
                    as BoxFuture<'static, Result<TestConnection, String>>
            }) as CreateFn<TestConnection>,
            Box::new(|_conn: &TestConnection| true) as ValidateFn<TestConnection>,
            &config,
        );

        let _conn = pool.acquire().await.unwrap();
        let result = pool.acquire().await;
        assert_eq!(result.unwrap_err(), PoolError::Timeout);
    }

    /// A factory error surfaces as `CreateFailed`.
    #[tokio::test]
    async fn test_create_failure() {
        let pool = create_test_pool(5, true, false);
        let result = pool.acquire().await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), PoolError::CreateFailed(_)));
    }

    /// `take` detaches the connection and releases its slot.
    #[tokio::test]
    async fn test_take_connection() {
        let pool = create_test_pool(5, false, false);
        let conn = pool.acquire().await.unwrap();
        let id = conn.id;
        let taken = PooledConnection::take(conn);
        assert_eq!(taken.id, id);
        let status = pool.status();
        assert_eq!(status.size, 0);
    }

    /// Clones share the same underlying pool.
    #[tokio::test]
    async fn test_pool_clone() {
        let pool = create_test_pool(5, false, false);
        let pool2 = pool.clone();
        let conn = pool2.acquire().await.unwrap();
        assert!(conn.valid);
        drop(conn);
    }

    /// Closing while connections are checked out is allowed; dropping them afterwards
    /// must not panic.
    #[tokio::test]
    async fn test_close_with_active_connections() {
        let pool = create_test_pool(5, false, false);
        let conn1 = pool.acquire().await.unwrap();
        let conn2 = pool.acquire().await.unwrap();
        pool.close();
        assert!(pool.is_closed());
        let result = pool.acquire().await;
        assert_eq!(result.unwrap_err(), PoolError::Closed);
        drop(conn1);
        drop(conn2);
    }
}