1use std::future::Future;
36use std::ops::Deref;
37use std::pin::Pin;
38use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
39use std::sync::Arc;
40use std::time::Duration;
41
42use crossbeam::queue::ArrayQueue;
43use tokio::sync::Notify;
44use tokio::time::timeout;
45
/// Pinned, boxed, `Send` future — the type-erased future produced by a connection factory.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Connection factory: every call starts building one new connection and resolves to either
/// the connection or an error message.
pub type CreateFn<T> = Box<dyn Fn() -> BoxFuture<'static, Result<T, String>> + Sync + Send>;

/// Health check run on an idle connection before it is handed out again; returning `false`
/// makes the pool discard that connection.
pub type ValidateFn<T> = Box<dyn Fn(&T) -> bool + Sync + Send>;
49
/// Ways in which [`LockFreePool::acquire`] can fail.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PoolError {
    /// No connection could be obtained within the configured wait timeout
    /// (immediately, when that timeout is zero).
    Timeout,
    /// The pool has been closed.
    Closed,
    /// The connection factory returned an error or exceeded the creation timeout; carries the
    /// factory's message (or `"timeout"`).
    CreateFailed(String),
}

impl std::fmt::Display for PoolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use PoolError::{Closed, CreateFailed, Timeout};
        match self {
            Timeout => f.write_str("pool: timeout waiting for connection"),
            Closed => f.write_str("pool: closed"),
            CreateFailed(m) => write!(f, "pool: create failed: {m}"),
        }
    }
}

impl std::error::Error for PoolError {}
70
/// Construction parameters for [`LockFreePool::new`].
#[derive(Clone, Debug)]
pub struct PoolConfig {
    /// Upper bound on connections accounted to the pool (checked out, idle, or being created).
    pub max_size: u32,
    /// Time budget for a single invocation of the connection factory.
    pub create_timeout: Duration,
    /// How long `acquire` may wait once the pool is at capacity; `Duration::ZERO` means
    /// "fail immediately with `PoolError::Timeout`".
    pub wait_timeout: Duration,
}

impl Default for PoolConfig {
    /// 20 connections, 5 s creation timeout, 10 s wait timeout.
    fn default() -> Self {
        let max_size = 20;
        let create_timeout = Duration::from_millis(5_000);
        let wait_timeout = Duration::from_millis(10_000);
        Self {
            max_size,
            create_timeout,
            wait_timeout,
        }
    }
}
89
/// Point-in-time snapshot of pool occupancy, as produced by [`LockFreePool::status`].
///
/// Each field comes from an independent atomic read, so under concurrent use the values are
/// not guaranteed to be mutually consistent.
#[derive(Clone, Debug)]
pub struct PoolStatus {
    /// Connections currently accounted to the pool (checked out + idle + being created).
    pub size: u32,
    /// Connections parked in the idle queue.
    pub idle: u32,
    /// Configured upper bound on `size`.
    pub max_size: u32,
    /// Whether the pool has been closed.
    pub closed: bool,
}
103
104pub struct LockFreePool<T: Send + 'static> {
107 inner: Arc<PoolInner<T>>,
108}
109
110unsafe impl<T: Send + 'static> Send for LockFreePool<T> {}
113unsafe impl<T: Send + 'static> Sync for LockFreePool<T> {}
114
115impl<T: Send + 'static> Clone for LockFreePool<T> {
116 fn clone(&self) -> Self {
117 Self {
118 inner: self.inner.clone(),
119 }
120 }
121}
122
123pub struct PooledConnection<T: Send + 'static> {
134 inner: Option<T>,
135 pool: LockFreePool<T>,
136}
137
138unsafe impl<T: Send + 'static> Send for PooledConnection<T> {}
141unsafe impl<T: Send + 'static> Sync for PooledConnection<T> {}
142
143impl<T: Send + 'static> std::fmt::Debug for PooledConnection<T> {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 f.debug_struct("PooledConnection")
146 .field("connected", &self.inner.is_some())
147 .finish()
148 }
149}
150
151impl<T: Send + 'static> Deref for PooledConnection<T> {
152 type Target = T;
153 #[inline(always)]
154 fn deref(&self) -> &T {
155 unsafe { self.inner.as_ref().unwrap_unchecked() }
157 }
158}
159
160impl<T: Send + 'static> AsRef<T> for PooledConnection<T> {
161 #[inline(always)]
162 fn as_ref(&self) -> &T {
163 self.deref()
164 }
165}
166
167impl<T: Send + 'static> PooledConnection<T> {
168 pub fn take(mut self) -> T {
171 let conn = self.inner.take().unwrap();
172 self.pool.inner.size.0.fetch_sub(1, Ordering::Release);
173 conn
174 }
175
176 pub fn pool_status(&self) -> PoolStatus {
178 self.pool.status()
179 }
180}
181
182impl<T: Send + 'static> Drop for PooledConnection<T> {
186 #[inline]
187 fn drop(&mut self) {
188 if let Some(item) = self.inner.take() {
189 self.pool.inner.return_conn(item);
190 }
191 }
192}
193
/// Pool connection counter, stored in its own 64-byte-aligned block.
// NOTE(review): the 64-byte alignment presumably targets cache-line isolation (avoiding false
// sharing with the neighbouring `AlignedClosed`) — confirm against the target CPUs.
#[repr(C, align(64))]
struct AlignedSize(AtomicU32);

/// The pool's closed flag together with its size limit, in a separate 64-byte-aligned block.
#[repr(C, align(64))]
struct AlignedClosed {
    /// Set by `LockFreePool::close`; read by `acquire`, `create_one` and `return_conn`.
    closed: AtomicBool,
    /// Upper bound on the connection counter; written only at construction.
    max_size: u32,
}
214
215#[repr(C)]
218struct PoolInner<T: Send + 'static> {
219 size: AlignedSize,
224
225 closed: AlignedClosed,
230
231 create: CreateFn<T>,
237
238 validate: ValidateFn<T>,
240
241 idle: ArrayQueue<T>,
245
246 notify: Notify,
250
251 create_timeout: Duration,
253
254 wait_timeout: Duration,
256}
257
258unsafe impl<T: Send + 'static> Send for PoolInner<T> {}
260unsafe impl<T: Send + 'static> Sync for PoolInner<T> {}
261
262impl<T: Send + 'static> LockFreePool<T> {
263 pub fn new(
268 create: CreateFn<T>,
269 validate: ValidateFn<T>,
270 config: &PoolConfig,
271 ) -> Self {
272 let idle = ArrayQueue::new(config.max_size as usize);
274 Self {
275 inner: Arc::new(PoolInner {
276 size: AlignedSize(AtomicU32::new(0)),
277 closed: AlignedClosed {
278 closed: AtomicBool::new(false),
279 max_size: config.max_size,
280 },
281 create,
282 validate,
283 idle,
284 notify: Notify::new(),
285 create_timeout: config.create_timeout,
286 wait_timeout: config.wait_timeout,
287 }),
288 }
289 }
290
291 #[inline]
304 pub async fn acquire(&self) -> Result<PooledConnection<T>, PoolError> {
305 if self.inner.closed.closed.load(Ordering::Acquire) {
310 return Err(PoolError::Closed);
311 }
312
313 if let Some(item) = self.inner.idle.pop() {
316 if (self.inner.validate)(&item) {
319 return Ok(PooledConnection {
320 inner: Some(item),
321 pool: self.clone(),
322 });
323 }
324 drop(item);
327 self.inner.size.0.fetch_sub(1, Ordering::Release);
328 }
330
331 loop {
333 if self.inner.closed.closed.load(Ordering::Acquire) {
334 return Err(PoolError::Closed);
335 }
336
337 let current = self.inner.size.0.load(Ordering::Acquire);
339 if current < self.inner.closed.max_size {
340 if self.inner.size.0.compare_exchange_weak(
344 current,
345 current + 1,
346 Ordering::AcqRel,
347 Ordering::Relaxed,
348 ).is_ok() {
349 return match self.create_one().await {
351 Ok(item) => Ok(PooledConnection {
352 inner: Some(item),
353 pool: self.clone(),
354 }),
355 Err(e) => {
356 self.inner.size.0.fetch_sub(1, Ordering::Release);
358 self.inner.notify.notify_one();
359 Err(e)
360 }
361 };
362 }
363 continue;
365 }
366
367 if self.inner.wait_timeout == Duration::ZERO {
370 return Err(PoolError::Timeout);
371 }
372
373 let notified = self.inner.notify.notified();
376 tokio::select! {
377 _ = notified => {
378 if let Some(item) = self.inner.idle.pop() {
381 if (self.inner.validate)(&item) {
382 return Ok(PooledConnection {
383 inner: Some(item),
384 pool: self.clone(),
385 });
386 }
387 drop(item);
388 self.inner.size.0.fetch_sub(1, Ordering::Release);
389 }
390 continue;
395 }
396 _ = tokio::time::sleep(self.inner.wait_timeout) => {
397 if let Some(item) = self.inner.idle.pop() {
399 if (self.inner.validate)(&item) {
400 return Ok(PooledConnection {
401 inner: Some(item),
402 pool: self.clone(),
403 });
404 }
405 drop(item);
406 self.inner.size.0.fetch_sub(1, Ordering::Release);
407 }
408 return Err(PoolError::Timeout);
409 }
410 }
411 }
412 }
413
414 #[inline]
416 async fn create_one(&self) -> Result<T, PoolError> {
417 if self.inner.closed.closed.load(Ordering::Acquire) {
418 self.inner.size.0.fetch_sub(1, Ordering::Release);
419 return Err(PoolError::Closed);
420 }
421 match timeout(self.inner.create_timeout, (self.inner.create)()).await {
422 Ok(Ok(item)) => Ok(item),
423 Ok(Err(msg)) => Err(PoolError::CreateFailed(msg)),
424 Err(_) => Err(PoolError::CreateFailed("timeout".into())),
425 }
426 }
427
428 pub fn close(&self) {
429 self.inner.closed.closed.store(true, Ordering::Release);
430 self.inner.notify.notify_waiters();
431 while self.inner.idle.pop().is_some() {
432 self.inner.size.0.fetch_sub(1, Ordering::Relaxed);
433 }
434 }
435
436 pub fn is_closed(&self) -> bool {
437 self.inner.closed.closed.load(Ordering::Acquire)
438 }
439
440 #[inline]
441 pub fn status(&self) -> PoolStatus {
442 self.inner.status()
443 }
444
445 pub fn max_size(&self) -> u32 {
446 self.inner.closed.max_size
447 }
448}
449
450impl<T: Send + 'static> PoolInner<T> {
451 #[inline]
461 fn return_conn(&self, item: T) {
462 let closed = self.closed.closed.load(Ordering::Acquire);
463 if !closed {
464 match self.idle.push(item) {
465 Ok(()) => {
466 self.notify.notify_one();
467 return;
468 }
469 Err(dropped) => {
470 drop(dropped);
472 }
473 }
474 }
475 self.size.0.fetch_sub(1, Ordering::Release);
476 self.notify.notify_one();
477 }
478
479 #[inline]
480 fn status(&self) -> PoolStatus {
481 let size = self.size.0.load(Ordering::Acquire);
482 let idle = self.idle.len();
483 PoolStatus {
484 size,
485 idle: idle as u32,
486 max_size: self.closed.max_size,
487 closed: self.closed.closed.load(Ordering::Acquire),
488 }
489 }
490}
491
492impl<T: Send + 'static> Drop for PoolInner<T> {
495 fn drop(&mut self) {
496 while self.idle.pop().is_some() {}
498 }
499}
500
#[cfg(test)]
pub(crate) mod test_helpers {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering as AtomicOrdering};

    /// Minimal stand-in connection for pool tests.
    pub struct TestConnection {
        /// Sequence number handed out by the factory (0, 1, 2, ...).
        pub id: u32,
        /// Value the test validator reports for this connection.
        pub valid: bool,
    }

    impl Drop for TestConnection {
        // Intentionally empty; presumably kept as a hook for tracking destruction in tests.
        fn drop(&mut self) {}
    }

    /// Builds a pool of [`TestConnection`]s.
    ///
    /// * `fail_create` — every factory call fails with `"create failed"`.
    /// * `fail_validate` — created connections carry `valid == false`, so the validator
    ///   rejects them once they are idle.
    ///
    /// Timeouts are fixed at 5 s for creation and 10 s for waiting.
    pub fn create_test_pool(
        max_size: u32,
        fail_create: bool,
        fail_validate: bool,
    ) -> LockFreePool<TestConnection> {
        let config = PoolConfig {
            max_size,
            create_timeout: Duration::from_secs(5),
            wait_timeout: Duration::from_secs(10),
        };

        let next_id = Arc::new(AtomicU32::new(0));
        let create: CreateFn<TestConnection> =
            Box::new(move || -> BoxFuture<'static, Result<TestConnection, String>> {
                let id = next_id.fetch_add(1, AtomicOrdering::Relaxed);
                Box::pin(async move {
                    if fail_create {
                        return Err(String::from("create failed"));
                    }
                    Ok(TestConnection {
                        id,
                        valid: !fail_validate,
                    })
                })
            });

        let validate: ValidateFn<TestConnection> =
            Box::new(|conn: &TestConnection| conn.valid);

        LockFreePool::new(create, validate, &config)
    }
}
555
/// Async tests for `LockFreePool` covering reuse, exhaustion, validation, closing and
/// concurrency. Several tests rely on short sleeps to order events between tasks.
#[cfg(test)]
mod tests {
    use super::test_helpers::*;
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering as AtomicOrdering};
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::time::sleep;

    /// One acquire creates connection 0; dropping it parks it in the idle queue.
    #[tokio::test]
    async fn test_acquire_release_one() {
        let pool = create_test_pool(5, false, false);
        assert!(!pool.is_closed());

        let conn = pool.acquire().await.unwrap();
        assert_eq!(conn.id, 0);
        assert!(conn.valid);

        let status = pool.status();
        assert_eq!(status.size, 1);
        assert_eq!(status.idle, 0);

        drop(conn);
        // The return happens synchronously inside the handle's `Drop`; the short sleep is
        // presumably just a safety margin.
        sleep(Duration::from_millis(10)).await;

        let status = pool.status();
        assert_eq!(status.idle, 1);
    }

    /// A returned connection is handed out again instead of a new one being created.
    #[tokio::test]
    async fn test_acquire_release_reuse() {
        let pool = create_test_pool(5, false, false);

        let conn1 = pool.acquire().await.unwrap();
        let id1 = conn1.id;
        drop(conn1);

        sleep(Duration::from_millis(10)).await;

        let conn2 = pool.acquire().await.unwrap();
        assert_eq!(conn2.id, id1, "should reuse the same connection");
    }

    /// The pool can hand out exactly `max_size` simultaneous connections.
    #[tokio::test]
    async fn test_multiple_connections() {
        let pool = create_test_pool(5, false, false);
        let mut conns = Vec::new();
        for _ in 0..5 {
            let conn = pool.acquire().await.unwrap();
            conns.push(conn);
        }
        assert_eq!(pool.status().size, 5);
        assert_eq!(pool.status().idle, 0);
        drop(conns);
    }

    /// After returning a full set of connections, later acquires mostly reuse them.
    #[tokio::test]
    async fn test_acquire_multiple_release_reuse() {
        let pool = create_test_pool(5, false, false);
        let mut conns = Vec::new();

        for _ in 0..5 {
            conns.push(pool.acquire().await.unwrap());
        }
        let ids: Vec<u32> = conns.iter().map(|c| c.id).collect();
        drop(conns);

        sleep(Duration::from_millis(10)).await;

        let mut reused = 0;
        for _ in 0..5 {
            let conn = pool.acquire().await.unwrap();
            if ids.contains(&conn.id) {
                reused += 1;
            }
            drop(conn);
        }
        // Lenient bound: only "most" reuse is asserted.
        assert!(reused >= 4, "most connections should be reused");
    }

    /// With `max_size == 1` and a 100 ms wait budget, a second acquire times out.
    #[tokio::test]
    async fn test_pool_exhaustion_short_timeout() {
        let config = PoolConfig {
            max_size: 1,
            create_timeout: Duration::from_secs(1),
            wait_timeout: Duration::from_millis(100),
        };
        let pool = LockFreePool::new(
            Box::new(|| {
                Box::pin(async { Ok(TestConnection { id: 0, valid: true }) })
                    as BoxFuture<'static, Result<TestConnection, String>>
            }) as CreateFn<TestConnection>,
            Box::new(|_conn: &TestConnection| true) as ValidateFn<TestConnection>,
            &config,
        );

        let conn1 = pool.acquire().await.unwrap();
        let result = pool.acquire().await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), PoolError::Timeout);
        drop(conn1);
    }

    /// A task waiting on a full pool is woken when the only connection is returned.
    #[tokio::test]
    async fn test_acquire_no_timeout_when_conn_returned() {
        let config = PoolConfig {
            max_size: 1,
            create_timeout: Duration::from_secs(1),
            wait_timeout: Duration::from_secs(5),
        };
        let pool = Arc::new(LockFreePool::new(
            Box::new(|| {
                Box::pin(async { Ok(TestConnection { id: 0, valid: true }) })
                    as BoxFuture<'static, Result<TestConnection, String>>
            }) as CreateFn<TestConnection>,
            Box::new(|_conn: &TestConnection| true) as ValidateFn<TestConnection>,
            &config,
        ));

        let conn1 = pool.acquire().await.unwrap();
        let pool_clone = pool.clone();

        let handle = tokio::spawn(async move {
            pool_clone.acquire().await
        });

        // Give the spawned task time to start waiting before the connection is returned.
        sleep(Duration::from_millis(50)).await;
        drop(conn1);

        let result = handle.await.unwrap();
        assert!(result.is_ok(), "returned conn should unblock waiter");
    }

    /// An idle connection rejected by the validator is discarded and replaced with a new one.
    /// The validator is not called on freshly created connections, only on idle ones.
    #[tokio::test]
    async fn test_validation_rejects_invalid_connections() {
        let reject_count = Arc::new(AtomicU32::new(0));
        let create_count = Arc::new(AtomicU32::new(0));

        let create = {
            let cc = create_count.clone();
            Box::new(move || {
                let id = cc.fetch_add(1, AtomicOrdering::Relaxed);
                Box::pin(async move {
                    Ok(TestConnection { id, valid: true })
                }) as BoxFuture<'static, Result<TestConnection, String>>
            }) as CreateFn<TestConnection>
        };

        // Rejects everything and counts how often it was consulted.
        let validate = {
            let rc = reject_count.clone();
            Box::new(move |_conn: &TestConnection| {
                rc.fetch_add(1, AtomicOrdering::Relaxed);
                false
            }) as ValidateFn<TestConnection>
        };

        let config = PoolConfig {
            max_size: 5,
            create_timeout: Duration::from_secs(5),
            wait_timeout: Duration::from_secs(1),
        };
        let pool = LockFreePool::new(create, validate, &config);

        let conn1 = pool.acquire().await.unwrap();
        assert_eq!(conn1.id, 0);
        drop(conn1);
        let conn2 = pool.acquire().await.unwrap();
        assert_eq!(conn2.id, 1, "rejected idle conn should be replaced");

        let rejected = reject_count.load(AtomicOrdering::Relaxed);
        assert_eq!(rejected, 1, "validator should be called exactly once");

        drop(conn2);
    }

    /// After `close`, `acquire` fails with `Closed`. A connection still checked out is
    /// dropped afterwards without issue.
    #[tokio::test]
    async fn test_close() {
        let pool = create_test_pool(5, false, false);
        let conn = pool.acquire().await.unwrap();
        assert!(!pool.is_closed());
        pool.close();
        assert!(pool.is_closed());
        let result = pool.acquire().await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), PoolError::Closed);
        drop(conn);
    }

    /// Closing the pool wakes a task blocked in `acquire`, which then reports `Closed`
    /// rather than waiting out its 10 s timeout.
    #[tokio::test]
    async fn test_close_with_waiter() {
        let config = PoolConfig {
            max_size: 1,
            create_timeout: Duration::from_secs(1),
            wait_timeout: Duration::from_secs(10),
        };
        let pool = Arc::new(LockFreePool::new(
            Box::new(|| {
                Box::pin(async { Ok(TestConnection { id: 0, valid: true }) })
                    as BoxFuture<'static, Result<TestConnection, String>>
            }) as CreateFn<TestConnection>,
            Box::new(|_conn: &TestConnection| true) as ValidateFn<TestConnection>,
            &config,
        ));

        let conn1 = pool.acquire().await.unwrap();
        let pool_clone = pool.clone();

        let handle = tokio::spawn(async move {
            pool_clone.acquire().await
        });

        // Let the spawned task reach its wait before closing.
        sleep(Duration::from_millis(50)).await;

        pool.close();
        let result = handle.await.unwrap();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), PoolError::Closed);
        drop(conn1);
    }

    /// 16 tasks x 10 acquire/hold/release cycles on an 8-slot pool; the size bound must hold.
    #[tokio::test]
    async fn test_concurrent_acquire_release() {
        let pool = Arc::new(create_test_pool(8, false, false));
        let mut handles = Vec::new();

        for _ in 0..16 {
            let pool = pool.clone();
            handles.push(tokio::spawn(async move {
                for _ in 0..10 {
                    let conn = pool.acquire().await.unwrap();
                    sleep(Duration::from_millis(5)).await;
                    drop(conn);
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let status = pool.status();
        assert!(status.size <= 8, "pool should not exceed max_size");
    }

    /// 32 tasks x 25 iterations on a 4-slot pool. `Timeout` is tolerated (the helper pool
    /// waits up to 10 s); any other error fails the test. The size bound must still hold.
    #[tokio::test]
    async fn test_concurrent_stress_high_contention() {
        let pool = Arc::new(create_test_pool(4, false, false));
        let mut handles = Vec::new();

        for _ in 0..32 {
            let pool = pool.clone();
            handles.push(tokio::spawn(async move {
                for _ in 0..25 {
                    match pool.acquire().await {
                        Ok(conn) => {
                            tokio::task::yield_now().await;
                            drop(conn);
                        }
                        Err(PoolError::Timeout) => {
                            tokio::task::yield_now().await;
                        }
                        Err(e) => panic!("Unexpected error: {e}"),
                    }
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let status = pool.status();
        assert!(status.size <= 4, "pool exceeded max_size: {}", status.size);
        assert!(!status.closed);
    }

    /// A zero wait timeout makes `acquire` on a full pool fail immediately with `Timeout`.
    #[tokio::test]
    async fn test_zero_wait_timeout() {
        let config = PoolConfig {
            max_size: 1,
            create_timeout: Duration::from_secs(1),
            wait_timeout: Duration::ZERO,
        };
        let pool = LockFreePool::new(
            Box::new(|| {
                Box::pin(async { Ok(TestConnection { id: 0, valid: true }) })
                    as BoxFuture<'static, Result<TestConnection, String>>
            }) as CreateFn<TestConnection>,
            Box::new(|_conn: &TestConnection| true) as ValidateFn<TestConnection>,
            &config,
        );

        let _conn = pool.acquire().await.unwrap();
        let result = pool.acquire().await;
        assert_eq!(result.unwrap_err(), PoolError::Timeout);
    }

    /// A failing factory surfaces as `CreateFailed`.
    #[tokio::test]
    async fn test_create_failure() {
        let pool = create_test_pool(5, true, false);
        let result = pool.acquire().await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), PoolError::CreateFailed(_)));
    }

    /// `take` detaches the connection and releases its slot in the pool's counter.
    #[tokio::test]
    async fn test_take_connection() {
        let pool = create_test_pool(5, false, false);
        let conn = pool.acquire().await.unwrap();
        let id = conn.id;
        let taken = PooledConnection::take(conn);
        assert_eq!(taken.id, id);
        let status = pool.status();
        assert_eq!(status.size, 0);
    }

    /// A cloned handle works against the same pool.
    #[tokio::test]
    async fn test_pool_clone() {
        let pool = create_test_pool(5, false, false);
        let pool2 = pool.clone();
        let conn = pool2.acquire().await.unwrap();
        assert!(conn.valid);
        drop(conn);
    }

    /// Closing with several connections checked out: new acquires fail, and the outstanding
    /// handles can still be dropped afterwards.
    #[tokio::test]
    async fn test_close_with_active_connections() {
        let pool = create_test_pool(5, false, false);
        let conn1 = pool.acquire().await.unwrap();
        let conn2 = pool.acquire().await.unwrap();
        pool.close();
        assert!(pool.is_closed());
        let result = pool.acquire().await;
        assert_eq!(result.unwrap_err(), PoolError::Closed);
        drop(conn1);
        drop(conn2);
    }
}
935}