1use std::future::Future;
36use std::ops::Deref;
37use std::pin::Pin;
38use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
39use std::sync::Arc;
40use std::time::Duration;
41
42use crossbeam::queue::ArrayQueue;
43use tokio::sync::Notify;
44use tokio::time::timeout;
45
/// Heap-allocated, pinned, `Send` future that resolves to `T` and may borrow
/// data for at most `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Thread-safe connection factory: every call returns a `'static` boxed future
/// resolving to either a new `T` or an error message.
pub type CreateFn<T> = Box<dyn Fn() -> BoxFuture<'static, Result<T, String>> + Send + Sync>;
/// Thread-safe predicate over a borrowed `T`, used by the pool to decide
/// whether an idle connection may be handed out again.
pub type ValidateFn<T> = Box<dyn Fn(&T) -> bool + Send + Sync>;
49
/// Reasons an attempt to obtain a pooled connection can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PoolError {
    /// No connection became available within the configured wait time.
    Timeout,
    /// The pool has been closed.
    Closed,
    /// The factory failed (or timed out); carries the reported message.
    CreateFailed(String),
}

impl std::fmt::Display for PoolError {
    /// Renders the human-readable, `pool: `-prefixed description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The two fixed messages need no formatting machinery.
            Self::Timeout => f.write_str("pool: timeout waiting for connection"),
            Self::Closed => f.write_str("pool: closed"),
            Self::CreateFailed(reason) => write!(f, "pool: create failed: {reason}"),
        }
    }
}

impl std::error::Error for PoolError {}
70
/// Tunables for a connection pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on the number of connections the pool accounts for
    /// (checked out plus idle); also the capacity of the idle queue.
    pub max_size: u32,
    /// Maximum time a single factory call may take before it is abandoned.
    pub create_timeout: Duration,
    /// How long `acquire` sleeps waiting for a connection when the pool is at
    /// capacity; `Duration::ZERO` makes it fail immediately instead.
    pub wait_timeout: Duration,
}

impl Default for PoolConfig {
    /// 20 connections, 5 s to create one, 10 s to wait for one.
    fn default() -> Self {
        let create_timeout = Duration::from_secs(5);
        let wait_timeout = Duration::from_secs(10);
        Self {
            max_size: 20,
            create_timeout,
            wait_timeout,
        }
    }
}
89
/// Point-in-time snapshot of a pool's counters.
///
/// The fields are read one after another without a common lock, so under
/// concurrent use they may not describe a single consistent instant.
/// The type is a plain value (`Copy`, comparable) so snapshots can be passed
/// around and asserted on as a whole.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStatus {
    /// Connections currently accounted for (checked out plus idle).
    pub size: u32,
    /// Connections sitting in the idle queue.
    pub idle: u32,
    /// Configured upper bound for `size`.
    pub max_size: u32,
    /// Whether the pool has been closed.
    pub closed: bool,
}
103
104pub struct LockFreePool<T: Send + 'static> {
107 inner: Arc<PoolInner<T>>,
108}
109
110unsafe impl<T: Send + 'static> Send for LockFreePool<T> {}
113unsafe impl<T: Send + 'static> Sync for LockFreePool<T> {}
114
115impl<T: Send + 'static> Clone for LockFreePool<T> {
116 fn clone(&self) -> Self {
117 Self {
118 inner: self.inner.clone(),
119 }
120 }
121}
122
123pub struct PooledConnection<T: Send + 'static> {
134 inner: Option<T>,
135 pool: LockFreePool<T>,
136}
137
138unsafe impl<T: Send + 'static> Send for PooledConnection<T> {}
141unsafe impl<T: Send + 'static> Sync for PooledConnection<T> {}
142
143impl<T: Send + 'static> std::fmt::Debug for PooledConnection<T> {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 f.debug_struct("PooledConnection")
146 .field("connected", &self.inner.is_some())
147 .finish()
148 }
149}
150
151impl<T: Send + 'static> Deref for PooledConnection<T> {
152 type Target = T;
153 #[inline(always)]
154 fn deref(&self) -> &T {
155 unsafe { self.inner.as_ref().unwrap_unchecked() }
157 }
158}
159
160impl<T: Send + 'static> AsRef<T> for PooledConnection<T> {
161 #[inline(always)]
162 fn as_ref(&self) -> &T {
163 self.deref()
164 }
165}
166
167impl<T: Send + 'static> PooledConnection<T> {
168 pub fn take(mut self) -> T {
171 let conn = self.inner.take().unwrap();
172 self.pool.inner.size.fetch_sub(1, Ordering::Release);
173 conn
174 }
175
176 pub fn pool_status(&self) -> PoolStatus {
178 self.pool.status()
179 }
180}
181
182impl<T: Send + 'static> Drop for PooledConnection<T> {
186 #[inline]
187 fn drop(&mut self) {
188 if let Some(item) = self.inner.take() {
189 self.pool.inner.return_conn(item);
190 }
191 }
192}
193
194struct PoolInner<T: Send + 'static> {
197 create: CreateFn<T>,
199
200 validate: ValidateFn<T>,
202
203 idle: ArrayQueue<T>,
206
207 size: AtomicU32,
210
211 max_size: u32,
213
214 closed: AtomicBool,
216
217 notify: Notify,
221
222 create_timeout: Duration,
224
225 wait_timeout: Duration,
227}
228
229unsafe impl<T: Send + 'static> Send for PoolInner<T> {}
231unsafe impl<T: Send + 'static> Sync for PoolInner<T> {}
232
233impl<T: Send + 'static> LockFreePool<T> {
234 pub fn new(
239 create: CreateFn<T>,
240 validate: ValidateFn<T>,
241 config: PoolConfig,
242 ) -> Self {
243 let idle = ArrayQueue::new(config.max_size as usize);
245 Self {
246 inner: Arc::new(PoolInner {
247 create,
248 validate,
249 idle,
250 size: AtomicU32::new(0),
251 max_size: config.max_size,
252 closed: AtomicBool::new(false),
253 notify: Notify::new(),
254 create_timeout: config.create_timeout,
255 wait_timeout: config.wait_timeout,
256 }),
257 }
258 }
259
260 #[inline]
273 pub async fn acquire(&self) -> Result<PooledConnection<T>, PoolError> {
274 if self.inner.closed.load(Ordering::Acquire) {
279 return Err(PoolError::Closed);
280 }
281
282 if let Some(item) = self.inner.idle.pop() {
285 if (self.inner.validate)(&item) {
288 return Ok(PooledConnection {
289 inner: Some(item),
290 pool: self.clone(),
291 });
292 }
293 drop(item);
296 self.inner.size.fetch_sub(1, Ordering::Release);
297 }
299
300 loop {
302 if self.inner.closed.load(Ordering::Acquire) {
303 return Err(PoolError::Closed);
304 }
305
306 let current = self.inner.size.load(Ordering::Acquire);
308 if current < self.inner.max_size {
309 if self.inner.size.compare_exchange_weak(
313 current,
314 current + 1,
315 Ordering::AcqRel,
316 Ordering::Relaxed,
317 ).is_ok() {
318 return match self.create_one().await {
320 Ok(item) => Ok(PooledConnection {
321 inner: Some(item),
322 pool: self.clone(),
323 }),
324 Err(e) => {
325 self.inner.size.fetch_sub(1, Ordering::Release);
327 self.inner.notify.notify_one();
328 Err(e)
329 }
330 };
331 }
332 continue;
334 }
335
336 if self.inner.wait_timeout == Duration::ZERO {
339 return Err(PoolError::Timeout);
340 }
341
342 let notified = self.inner.notify.notified();
345 tokio::select! {
346 _ = notified => {
347 if let Some(item) = self.inner.idle.pop() {
350 if (self.inner.validate)(&item) {
351 return Ok(PooledConnection {
352 inner: Some(item),
353 pool: self.clone(),
354 });
355 }
356 drop(item);
357 self.inner.size.fetch_sub(1, Ordering::Release);
358 }
359 continue;
364 }
365 _ = tokio::time::sleep(self.inner.wait_timeout) => {
366 if let Some(item) = self.inner.idle.pop() {
368 if (self.inner.validate)(&item) {
369 return Ok(PooledConnection {
370 inner: Some(item),
371 pool: self.clone(),
372 });
373 }
374 drop(item);
375 self.inner.size.fetch_sub(1, Ordering::Release);
376 }
377 return Err(PoolError::Timeout);
378 }
379 }
380 }
381 }
382
383 #[inline]
385 async fn create_one(&self) -> Result<T, PoolError> {
386 if self.inner.closed.load(Ordering::Acquire) {
387 self.inner.size.fetch_sub(1, Ordering::Release);
388 return Err(PoolError::Closed);
389 }
390 match timeout(self.inner.create_timeout, (self.inner.create)()).await {
391 Ok(Ok(item)) => Ok(item),
392 Ok(Err(msg)) => Err(PoolError::CreateFailed(msg)),
393 Err(_) => Err(PoolError::CreateFailed("timeout".into())),
394 }
395 }
396
397 pub fn close(&self) {
398 self.inner.closed.store(true, Ordering::Release);
399 self.inner.notify.notify_waiters();
400 while self.inner.idle.pop().is_some() {
401 self.inner.size.fetch_sub(1, Ordering::Relaxed);
402 }
403 }
404
405 pub fn is_closed(&self) -> bool {
406 self.inner.closed.load(Ordering::Acquire)
407 }
408
409 #[inline]
410 pub fn status(&self) -> PoolStatus {
411 self.inner.status()
412 }
413
414 pub fn max_size(&self) -> u32 {
415 self.inner.max_size
416 }
417}
418
419impl<T: Send + 'static> PoolInner<T> {
420 #[inline]
430 fn return_conn(&self, item: T) {
431 let closed = self.closed.load(Ordering::Acquire);
432 if !closed {
433 match self.idle.push(item) {
434 Ok(()) => {
435 self.notify.notify_one();
436 return;
437 }
438 Err(dropped) => {
439 drop(dropped);
441 }
442 }
443 }
444 self.size.fetch_sub(1, Ordering::Release);
445 self.notify.notify_one();
446 }
447
448 #[inline]
449 fn status(&self) -> PoolStatus {
450 let size = self.size.load(Ordering::Acquire);
451 let idle = self.idle.len();
452 PoolStatus {
453 size,
454 idle: idle as u32,
455 max_size: self.max_size,
456 closed: self.closed.load(Ordering::Acquire),
457 }
458 }
459}
460
461impl<T: Send + 'static> Drop for PoolInner<T> {
464 fn drop(&mut self) {
465 while self.idle.pop().is_some() {}
467 }
468}
469
#[cfg(test)]
pub(crate) mod test_helpers {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering as AtomicOrdering};

    /// Minimal stand-in for a real connection.
    pub struct TestConnection {
        /// Sequence number handed out by the test factory.
        pub id: u32,
        /// Value the test validator reports for this connection.
        pub valid: bool,
    }

    impl Drop for TestConnection {
        // Intentionally empty: nothing to release.
        fn drop(&mut self) {}
    }

    /// Builds a pool of `TestConnection`s with 5 s create / 10 s wait timeouts.
    ///
    /// * `max_size` - pool capacity.
    /// * `fail_create` - when set, every factory call yields `"create failed"`.
    /// * `fail_validate` - when set, connections are created with
    ///   `valid == false`, so the validator rejects them on reuse.
    ///
    /// Connection ids count up from 0 in factory-call order.
    pub fn create_test_pool(
        max_size: u32,
        fail_create: bool,
        fail_validate: bool,
    ) -> LockFreePool<TestConnection> {
        let next_id = Arc::new(AtomicU32::new(0));
        let ids = next_id.clone();

        let create = Box::new(move || {
            // The id is consumed even when creation is configured to fail.
            let id = ids.fetch_add(1, AtomicOrdering::Relaxed);
            Box::pin(async move {
                if fail_create {
                    return Err("create failed".into());
                }
                Ok(TestConnection {
                    id,
                    valid: !fail_validate,
                })
            }) as BoxFuture<'static, Result<TestConnection, String>>
        }) as CreateFn<TestConnection>;

        let validate =
            Box::new(move |conn: &TestConnection| conn.valid) as ValidateFn<TestConnection>;

        LockFreePool::new(
            create,
            validate,
            PoolConfig {
                max_size,
                create_timeout: Duration::from_secs(5),
                wait_timeout: Duration::from_secs(10),
            },
        )
    }
}
524
#[cfg(test)]
mod tests {
    use super::test_helpers::*;
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering as AtomicOrdering};
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::time::sleep;

    /// One-slot pool whose factory always yields connection 0 and whose
    /// validator accepts everything; creation is limited to one second.
    fn single_slot_pool(wait_timeout: Duration) -> LockFreePool<TestConnection> {
        let config = PoolConfig {
            max_size: 1,
            create_timeout: Duration::from_secs(1),
            wait_timeout,
        };
        LockFreePool::new(
            Box::new(|| {
                Box::pin(async { Ok(TestConnection { id: 0, valid: true }) })
                    as BoxFuture<'static, Result<TestConnection, String>>
            }) as CreateFn<TestConnection>,
            Box::new(|_conn: &TestConnection| true) as ValidateFn<TestConnection>,
            config,
        )
    }

    /// A single acquire creates connection 0; dropping it parks it as idle.
    #[tokio::test]
    async fn test_acquire_release_one() {
        let pool = create_test_pool(5, false, false);
        assert!(!pool.is_closed());

        let conn = pool.acquire().await.unwrap();
        assert_eq!(conn.id, 0);
        assert!(conn.valid);

        let before = pool.status();
        assert_eq!(before.size, 1);
        assert_eq!(before.idle, 0);

        drop(conn);
        sleep(Duration::from_millis(10)).await;

        let after = pool.status();
        assert_eq!(after.idle, 1);
    }

    /// A released connection is the one handed out next.
    #[tokio::test]
    async fn test_acquire_release_reuse() {
        let pool = create_test_pool(5, false, false);

        let first_id = {
            let first = pool.acquire().await.unwrap();
            first.id
            // `first` goes back to the pool at the end of this scope.
        };

        sleep(Duration::from_millis(10)).await;

        let second = pool.acquire().await.unwrap();
        assert_eq!(second.id, first_id, "should reuse the same connection");
    }

    /// The pool can be filled up to `max_size` with nothing left idle.
    #[tokio::test]
    async fn test_multiple_connections() {
        let pool = create_test_pool(5, false, false);
        let mut held = Vec::new();
        while held.len() < 5 {
            held.push(pool.acquire().await.unwrap());
        }
        assert_eq!(pool.status().size, 5);
        assert_eq!(pool.status().idle, 0);
        drop(held);
    }

    /// After a full batch is released, later acquires come from that batch.
    #[tokio::test]
    async fn test_acquire_multiple_release_reuse() {
        let pool = create_test_pool(5, false, false);

        let mut batch = Vec::new();
        for _ in 0..5 {
            batch.push(pool.acquire().await.unwrap());
        }
        let ids: Vec<u32> = batch.iter().map(|c| c.id).collect();
        drop(batch);

        sleep(Duration::from_millis(10)).await;

        let mut reused = 0;
        for _ in 0..5 {
            let conn = pool.acquire().await.unwrap();
            let seen_before = ids.contains(&conn.id);
            if seen_before {
                reused += 1;
            }
            drop(conn);
        }
        assert!(reused >= 4, "most connections should be reused");
    }

    /// With the only slot taken, a second acquire times out.
    #[tokio::test]
    async fn test_pool_exhaustion_short_timeout() {
        let pool = single_slot_pool(Duration::from_millis(100));

        let held = pool.acquire().await.unwrap();
        let second = pool.acquire().await;
        assert!(second.is_err());
        assert_eq!(second.unwrap_err(), PoolError::Timeout);
        drop(held);
    }

    /// A waiter is served as soon as the held connection is returned.
    #[tokio::test]
    async fn test_acquire_no_timeout_when_conn_returned() {
        let pool = Arc::new(single_slot_pool(Duration::from_secs(5)));

        let held = pool.acquire().await.unwrap();

        let waiter = {
            let pool = pool.clone();
            tokio::spawn(async move { pool.acquire().await })
        };

        // Give the waiter time to block before releasing the connection.
        sleep(Duration::from_millis(50)).await;
        drop(held);

        let outcome = waiter.await.unwrap();
        assert!(outcome.is_ok(), "returned conn should unblock waiter");
    }

    /// An idle connection rejected by the validator is replaced by a new one.
    #[tokio::test]
    async fn test_validation_rejects_invalid_connections() {
        let reject_count = Arc::new(AtomicU32::new(0));
        let create_count = Arc::new(AtomicU32::new(0));

        let ids = create_count.clone();
        let create = Box::new(move || {
            let id = ids.fetch_add(1, AtomicOrdering::Relaxed);
            Box::pin(async move { Ok(TestConnection { id, valid: true }) })
                as BoxFuture<'static, Result<TestConnection, String>>
        }) as CreateFn<TestConnection>;

        let rejections = reject_count.clone();
        let validate = Box::new(move |_conn: &TestConnection| {
            // Count every call, reject every connection.
            rejections.fetch_add(1, AtomicOrdering::Relaxed);
            false
        }) as ValidateFn<TestConnection>;

        let config = PoolConfig {
            max_size: 5,
            create_timeout: Duration::from_secs(5),
            wait_timeout: Duration::from_secs(1),
        };
        let pool = LockFreePool::new(create, validate, config);

        let conn1 = pool.acquire().await.unwrap();
        assert_eq!(conn1.id, 0);
        drop(conn1);

        let conn2 = pool.acquire().await.unwrap();
        assert_eq!(conn2.id, 1, "rejected idle conn should be replaced");

        let rejected = reject_count.load(AtomicOrdering::Relaxed);
        assert_eq!(rejected, 1, "validator should be called exactly once");

        drop(conn2);
    }

    /// Closing flips the flag and makes later acquires fail with `Closed`.
    #[tokio::test]
    async fn test_close() {
        let pool = create_test_pool(5, false, false);
        let conn = pool.acquire().await.unwrap();
        assert!(!pool.is_closed());

        pool.close();
        assert!(pool.is_closed());

        let after_close = pool.acquire().await;
        assert!(after_close.is_err());
        assert_eq!(after_close.unwrap_err(), PoolError::Closed);
        drop(conn);
    }

    /// A task blocked in `acquire` is released with `Closed` when the pool
    /// is closed.
    #[tokio::test]
    async fn test_close_with_waiter() {
        let pool = Arc::new(single_slot_pool(Duration::from_secs(10)));

        let held = pool.acquire().await.unwrap();

        let waiter = {
            let pool = pool.clone();
            tokio::spawn(async move { pool.acquire().await })
        };

        // Let the waiter block before closing.
        sleep(Duration::from_millis(50)).await;

        pool.close();
        let outcome = waiter.await.unwrap();
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), PoolError::Closed);
        drop(held);
    }

    /// 16 tasks sharing 8 slots never push the pool past its limit.
    #[tokio::test]
    async fn test_concurrent_acquire_release() {
        let pool = Arc::new(create_test_pool(8, false, false));

        let workers: Vec<_> = (0..16)
            .map(|_| {
                let pool = pool.clone();
                tokio::spawn(async move {
                    for _ in 0..10 {
                        let conn = pool.acquire().await.unwrap();
                        sleep(Duration::from_millis(5)).await;
                        drop(conn);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.await.unwrap();
        }

        let status = pool.status();
        assert!(status.size <= 8, "pool should not exceed max_size");
    }

    /// 32 tasks hammering 4 slots: only timeouts are tolerated as errors and
    /// the size limit holds.
    #[tokio::test]
    async fn test_concurrent_stress_high_contention() {
        let pool = Arc::new(create_test_pool(4, false, false));

        let workers: Vec<_> = (0..32)
            .map(|_| {
                let pool = pool.clone();
                tokio::spawn(async move {
                    for _ in 0..25 {
                        match pool.acquire().await {
                            Ok(conn) => {
                                tokio::task::yield_now().await;
                                drop(conn);
                            }
                            Err(PoolError::Timeout) => {
                                tokio::task::yield_now().await;
                            }
                            Err(e) => panic!("Unexpected error: {e}"),
                        }
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.await.unwrap();
        }

        let status = pool.status();
        assert!(status.size <= 4, "pool exceeded max_size: {}", status.size);
        assert!(!status.closed);
    }

    /// A zero wait timeout turns "pool full" into an immediate `Timeout`.
    #[tokio::test]
    async fn test_zero_wait_timeout() {
        let pool = single_slot_pool(Duration::ZERO);

        let _held = pool.acquire().await.unwrap();
        let second = pool.acquire().await;
        assert_eq!(second.unwrap_err(), PoolError::Timeout);
    }

    /// A failing factory surfaces as `CreateFailed`.
    #[tokio::test]
    async fn test_create_failure() {
        let pool = create_test_pool(5, true, false);
        let outcome = pool.acquire().await;
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), PoolError::CreateFailed(_)));
    }

    /// `take` hands over the connection and the pool stops counting it.
    #[tokio::test]
    async fn test_take_connection() {
        let pool = create_test_pool(5, false, false);
        let conn = pool.acquire().await.unwrap();
        let expected_id = conn.id;

        let taken = PooledConnection::take(conn);
        assert_eq!(taken.id, expected_id);

        let status = pool.status();
        assert_eq!(status.size, 0);
    }

    /// A cloned handle can be used to acquire connections.
    #[tokio::test]
    async fn test_pool_clone() {
        let pool = create_test_pool(5, false, false);
        let other_handle = pool.clone();
        let conn = other_handle.acquire().await.unwrap();
        assert!(conn.valid);
        drop(conn);
    }

    /// Closing with connections checked out still rejects new acquires, and
    /// the outstanding connections can be dropped afterwards.
    #[tokio::test]
    async fn test_close_with_active_connections() {
        let pool = create_test_pool(5, false, false);
        let conn1 = pool.acquire().await.unwrap();
        let conn2 = pool.acquire().await.unwrap();

        pool.close();
        assert!(pool.is_closed());

        let after_close = pool.acquire().await;
        assert_eq!(after_close.unwrap_err(), PoolError::Closed);

        drop(conn1);
        drop(conn2);
    }
}