use crate::backpressure::MkBackpressure;
use std::collections::VecDeque;
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Instant;
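
/// A fixed-capacity pool of reusable items that can be shared across
/// async tasks. Clones are cheap and share the same storage; the
/// configured `MkBackpressure` policy decides how `acquire` behaves
/// when the pool is empty.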
pub struct MkAsyncPool<T> {
    inner: Arc<PoolInner<T>>,
}

struct PoolInner<T> {
    items: Mutex<VecDeque<T>>,
    capacity: usize,
    available: AtomicUsize,
    backpressure: MkBackpressure,
}

impl<T> MkAsyncPool<T> {
    /// Creates an empty pool that can hold at most `capacity` items.
    pub fn new(capacity: usize, backpressure: MkBackpressure) -> Self {
        Self {
            inner: Arc::new(PoolInner {
                items: Mutex::new(VecDeque::with_capacity(capacity)),
                capacity,
                available: AtomicUsize::new(0),
                backpressure,
            }),
        }
    }

    /// Adds an item to the pool, handing it back as `Err(item)` if the
    /// pool is already at capacity.
    pub fn add(&self, item: T) -> Result<(), T> {
        let mut items = self.inner.items.lock().unwrap();
        if items.len() >= self.inner.capacity {
            return Err(item);
        }
        items.push_back(item);
        self.inner.available.fetch_add(1, Ordering::Release);
        Ok(())
    }

    /// Acquires an item, applying the configured backpressure policy
    /// while the pool is empty.
    pub async fn acquire(&self) -> Option<MkPoolGuard<T>> {
        let start = Instant::now();
        loop {
            if let Some(guard) = self.try_acquire() {
                return Some(guard);
            }

            match self.inner.backpressure {
                // Fail fast on an empty pool.
                MkBackpressure::Fail => return None,
                // Yield to the executor, then retry the acquire.
                MkBackpressure::Wait => {
                    YieldOnce::new().await;
                }
                // Retry until the deadline passes, then give up.
                MkBackpressure::Timeout(duration) => {
                    if start.elapsed() >= duration {
                        return None;
                    }
                    YieldOnce::new().await;
                }
                // Eviction has nothing to reclaim from an empty pool;
                // behave like `Fail`.
                MkBackpressure::Evict => return None,
            }
        }
    }

    /// Attempts to take an item without waiting, returning `None` if
    /// the pool is currently empty.
    pub fn try_acquire(&self) -> Option<MkPoolGuard<T>> {
        let mut items = self.inner.items.lock().unwrap();
        if let Some(item) = items.pop_front() {
            self.inner.available.fetch_sub(1, Ordering::Acquire);
            Some(MkPoolGuard {
                item: Some(item),
                pool: Arc::clone(&self.inner),
            })
        } else {
            None
        }
    }

    /// The maximum number of items the pool can hold.
    pub fn capacity(&self) -> usize {
        self.inner.capacity
    }

    /// The number of items currently in the pool. The count is
    /// advisory: it may be stale by the time the caller reads it.
    pub fn available(&self) -> usize {
        self.inner.available.load(Ordering::Relaxed)
    }
}

impl<T> Clone for MkAsyncPool<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}
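
/// An RAII guard over a pooled item. The guard dereferences to the
/// item, and dropping it returns the item to the pool (or drops the
/// item outright if the pool has meanwhile been refilled to capacity).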
pub struct MkPoolGuard<T> {
    item: Option<T>,
    pool: Arc<PoolInner<T>>,
}

impl<T> Deref for MkPoolGuard<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // `item` is only `None` after `drop` has taken it, so the
        // unwrap cannot fire while the guard is alive.
        self.item.as_ref().unwrap()
    }
}

impl<T> DerefMut for MkPoolGuard<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.item.as_mut().unwrap()
    }
}

impl<T> Drop for MkPoolGuard<T> {
    fn drop(&mut self) {
        if let Some(item) = self.item.take() {
            let mut items = self.pool.items.lock().unwrap();
            if items.len() < self.pool.capacity {
                items.push_back(item);
                self.pool.available.fetch_add(1, Ordering::Release);
            }
            // If the pool has been refilled to capacity in the
            // meantime, the item is simply dropped here.
        }
    }
}
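
/// A future that returns `Pending` exactly once, waking itself first
/// so the executor polls it again promptly. Used by `acquire` to
/// yield cooperatively between retries.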
struct YieldOnce(bool);

impl YieldOnce {
    fn new() -> Self {
        Self(false)
    }
}

impl Future for YieldOnce {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.0 {
            Poll::Ready(())
        } else {
            self.0 = true;
            // Wake immediately so the task is rescheduled rather than
            // parked waiting for an external wake.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_sync() {
        let pool: MkAsyncPool<u32> = MkAsyncPool::new(3, MkBackpressure::Fail);

        pool.add(1).unwrap();
        pool.add(2).unwrap();
        pool.add(3).unwrap();
        assert!(pool.add(4).is_err());

        assert_eq!(pool.available(), 3);

        let guard = pool.try_acquire().unwrap();
        assert_eq!(*guard, 1);
        assert_eq!(pool.available(), 2);

        drop(guard);
        assert_eq!(pool.available(), 3);
    }
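
    // A minimal busy-poll executor so the async path can be tested
    // without pulling in a runtime. This is a test-only sketch; it
    // assumes no async runtime is available as a dev-dependency.
    fn block_on<F: Future>(mut fut: F) -> F::Output {
        use std::task::{RawWaker, RawWakerVTable, Waker};

        // A waker that does nothing: the loop below polls continuously,
        // so wake-ups are unnecessary.
        fn noop_raw_waker() -> RawWaker {
            const VTABLE: RawWakerVTable =
                RawWakerVTable::new(|_| noop_raw_waker(), |_| {}, |_| {}, |_| {});
            RawWaker::new(std::ptr::null(), &VTABLE)
        }

        let waker = unsafe { Waker::from_raw(noop_raw_waker()) };
        let mut cx = Context::from_waker(&waker);
        // SAFETY: `fut` is a local that is shadowed by this pinned
        // borrow and can never be moved afterwards.
        let mut fut = unsafe { Pin::new_unchecked(&mut fut) };
        loop {
            if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
                return out;
            }
        }
    }

    // Exercises `acquire` under the `Fail` policy: an empty pool
    // resolves to `None`, and a stocked pool hands the item back.
    #[test]
    fn test_pool_async_fail() {
        let pool: MkAsyncPool<u32> = MkAsyncPool::new(1, MkBackpressure::Fail);
        assert!(block_on(pool.acquire()).is_none());

        pool.add(7).unwrap();
        let guard = block_on(pool.acquire()).expect("item should be available");
        assert_eq!(*guard, 7);
        assert_eq!(pool.available(), 0);
    }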
}