1use std::sync::Arc;
7use tokio::sync::Semaphore;
8
/// Caps how many permits for the database at `db_path` may be held at the same time.
///
/// Cloning the pool clones the `Arc` around the semaphore, so every clone draws from the
/// same set of permits.
#[derive(Clone, Debug)]
pub struct ConnectionPool {
    /// Path of the database that permits from this pool refer to.
    pub db_path: std::path::PathBuf,
    /// Tracks the free slots; shared between clones of the pool.
    semaphore: Arc<Semaphore>,
    /// Capacity the pool was created with.
    ///
    /// NOTE(review): the semaphore is sized from this value when the pool is built; assigning
    /// to this public field afterwards does not change how many permits the semaphore holds.
    pub max_connections: usize,
}
38
39impl ConnectionPool {
40 pub fn new(db_path: impl AsRef<std::path::Path>, max_connections: usize) -> Self {
55 Self {
56 db_path: db_path.as_ref().to_path_buf(),
57 semaphore: Arc::new(Semaphore::new(max_connections)),
58 max_connections,
59 }
60 }
61
62 pub async fn acquire(&self) -> anyhow::Result<ConnectionPermit> {
85 let permit = self.semaphore.clone().acquire_owned().await?;
86 Ok(ConnectionPermit {
87 _permit: permit,
88 db_path: self.db_path.clone(),
89 })
90 }
91
92 pub fn available_connections(&self) -> usize {
103 self.semaphore.available_permits()
104 }
105
106 pub async fn try_acquire(&self) -> Option<ConnectionPermit> {
129 self.semaphore
130 .clone()
131 .try_acquire_owned()
132 .ok()
133 .map(|permit| ConnectionPermit {
134 _permit: permit,
135 db_path: self.db_path.clone(),
136 })
137 }
138}
139
140pub struct ConnectionPermit {
144 _permit: tokio::sync::OwnedSemaphorePermit,
145 db_path: std::path::PathBuf,
146}
147
148impl std::fmt::Debug for ConnectionPermit {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("ConnectionPermit")
151 .field("db_path", &self.db_path)
152 .finish()
153 }
154}
155
156impl ConnectionPermit {
157 pub fn db_path(&self) -> &std::path::Path {
173 &self.db_path
174 }
175}
176
177impl std::fmt::Display for ConnectionPermit {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 write!(f, "ConnectionPermit({})", self.db_path.display())
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = ConnectionPool::new("/tmp/test.db", 5);
        assert_eq!(pool.max_connections, 5);
        assert_eq!(pool.available_connections(), 5);
    }

    #[tokio::test]
    async fn test_pool_acquire() {
        let pool = ConnectionPool::new("/tmp/test.db", 2);

        let permit1 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 1);

        let permit2 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 0);

        drop(permit1);
        assert_eq!(pool.available_connections(), 1);

        drop(permit2);
        assert_eq!(pool.available_connections(), 2);
    }

    #[tokio::test]
    async fn test_pool_try_acquire() {
        let pool = ConnectionPool::new("/tmp/test.db", 1);

        let permit1 = pool.try_acquire().await;
        assert!(permit1.is_some());
        assert_eq!(pool.available_connections(), 0);

        let permit2 = pool.try_acquire().await;
        assert!(permit2.is_none());

        drop(permit1);
        assert_eq!(pool.available_connections(), 1);
    }

    #[tokio::test]
    async fn test_pool_db_path() {
        let pool = ConnectionPool::new("/tmp/test.db", 5);
        assert_eq!(pool.db_path, std::path::PathBuf::from("/tmp/test.db"));

        let permit = pool.acquire().await.unwrap();
        assert_eq!(permit.db_path(), std::path::Path::new("/tmp/test.db"));
    }

    #[tokio::test]
    async fn test_pool_concurrent_acquires() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use tokio::sync::Barrier;

        const LIMIT: usize = 5;
        const TASKS: usize = 10;

        let pool = Arc::new(ConnectionPool::new("/tmp/test.db", LIMIT));
        let barrier = Arc::new(Barrier::new(TASKS));
        // Count the tasks that currently hold a permit and the highest count ever seen. This
        // checks that the limit is enforced while tasks run, not just the final state.
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..TASKS)
            .map(|_| {
                let pool = Arc::clone(&pool);
                let barrier = Arc::clone(&barrier);
                let in_flight = Arc::clone(&in_flight);
                let peak = Arc::clone(&peak);
                tokio::spawn(async move {
                    barrier.wait().await;
                    let _permit = pool.acquire().await.unwrap();
                    let now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(now, Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                    // Decrement while still holding the permit, so `in_flight` never
                    // overcounts the real number of holders.
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        let peak = peak.load(Ordering::SeqCst);
        assert!(peak >= 1);
        assert!(
            peak <= LIMIT,
            "observed {} concurrent permit holders, limit is {}",
            peak,
            LIMIT
        );
        assert_eq!(in_flight.load(Ordering::SeqCst), 0);
        assert_eq!(pool.available_connections(), LIMIT);
    }

    #[tokio::test]
    async fn test_pool_timeout_behavior() {
        use tokio::time::{timeout, Duration};

        let pool = ConnectionPool::new("/tmp/test.db", 1);

        let permit1 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 0);

        let start = std::time::Instant::now();
        let result = timeout(Duration::from_millis(100), pool.acquire()).await;
        let elapsed = start.elapsed();

        // The second acquire must still have been waiting when the timeout fired.
        assert!(result.is_err());
        assert!(elapsed >= Duration::from_millis(90));
        // Generous upper bound: a tight bound (e.g. 200 ms) flakes on loaded CI machines.
        assert!(elapsed < Duration::from_secs(2));

        // The cancelled acquire must not have consumed or leaked a permit.
        assert_eq!(pool.available_connections(), 0);
        drop(permit1);
        assert_eq!(pool.available_connections(), 1);
    }

    #[tokio::test]
    async fn test_pool_permit_drop_returns() {
        let pool = ConnectionPool::new("/tmp/test.db", 3);
        assert_eq!(pool.available_connections(), 3);

        let permit = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 2);

        drop(permit);
        assert_eq!(pool.available_connections(), 3);
    }

    #[tokio::test]
    async fn test_pool_stress() {
        let pool = ConnectionPool::new("/tmp/test.db", 10);

        for _ in 0..100 {
            let permit = pool.acquire().await.unwrap();
            assert_eq!(permit.db_path(), std::path::Path::new("/tmp/test.db"));
            drop(permit);
        }

        assert_eq!(pool.available_connections(), 10);
    }

    #[tokio::test]
    async fn test_pool_all_permits_acquired() {
        let pool = ConnectionPool::new("/tmp/test.db", 3);

        let permit1 = pool.acquire().await.unwrap();
        let permit2 = pool.acquire().await.unwrap();
        let permit3 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 0);

        let permit4 = pool.try_acquire().await;
        assert!(permit4.is_none());

        drop(permit1);
        assert_eq!(pool.available_connections(), 1);

        let permit5 = pool.try_acquire().await;
        assert!(permit5.is_some());
        assert_eq!(pool.available_connections(), 0);

        drop(permit2);
        drop(permit3);
        drop(permit5);
        assert_eq!(pool.available_connections(), 3);
    }

    #[tokio::test]
    async fn test_pool_available_count() {
        let pool = ConnectionPool::new("/tmp/test.db", 5);

        assert_eq!(pool.available_connections(), 5);

        let permit1 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 4);

        let permit2 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 3);

        let permit3 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 2);

        drop(permit1);
        assert_eq!(pool.available_connections(), 3);

        drop(permit2);
        assert_eq!(pool.available_connections(), 4);

        drop(permit3);
        assert_eq!(pool.available_connections(), 5);
    }
}