Skip to main content

forgekit_core/
pool.rs

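//! Connection pool for concurrent database access.
//!
//! This module provides a semaphore-based pool that limits how many tasks may
//! use a database at the same time. The pool does not open connections itself:
//! each acquired permit only carries the database path, and its slot is returned
//! to the pool when the permit is dropped.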
5
6use std::sync::Arc;
7use tokio::sync::Semaphore;
8
9/// Connection pool for database connections.
10///
11/// The pool limits the number of concurrent connections
12/// using a semaphore permit system.
13///
14/// # Examples
15///
16/// ```no_run
17/// use forgekit_core::pool::ConnectionPool;
18///
19/// # #[tokio::main]
20/// # async fn main() -> anyhow::Result<()> {
21/// let pool = ConnectionPool::new("/path/to/db.sqlite", 10);
22///
23/// // Acquire a connection
24/// let _permit = pool.acquire().await?;
25/// // Use connection here
26/// #     Ok(())
27/// # }
28/// ```
29#[derive(Clone, Debug)]
30pub struct ConnectionPool {
31    /// Path to the database file.
32    pub db_path: std::path::PathBuf,
33    /// Semaphore for limiting connections.
34    semaphore: Arc<Semaphore>,
35    /// Maximum number of connections.
36    pub max_connections: usize,
37}
38
39impl ConnectionPool {
40    /// Creates a new connection pool.
41    ///
42    /// # Arguments
43    ///
44    /// * `db_path` - Path to the SQLite database file
45    /// * `max_connections` - Maximum number of concurrent connections
46    ///
47    /// # Examples
48    ///
49    /// ```no_run
50    /// use forgekit_core::pool::ConnectionPool;
51    ///
52    /// let pool = ConnectionPool::new("./db.sqlite", 10);
53    /// ```
54    pub fn new(db_path: impl AsRef<std::path::Path>, max_connections: usize) -> Self {
55        Self {
56            db_path: db_path.as_ref().to_path_buf(),
57            semaphore: Arc::new(Semaphore::new(max_connections)),
58            max_connections,
59        }
60    }
61
62    /// Acquires a permit from the pool.
63    ///
64    /// This will wait until a connection is available.
65    /// The permit is released when dropped.
66    ///
67    /// # Returns
68    ///
69    /// A `ConnectionPermit` that represents the acquired connection.
70    ///
71    /// # Examples
72    ///
73    /// ```no_run
74    /// # use forgekit_core::pool::ConnectionPool;
75    /// # #[tokio::main]
76    /// # async fn main() -> anyhow::Result<()> {
77    /// # let pool = ConnectionPool::new("./db.sqlite", 10);
78    /// let permit = pool.acquire().await?;
79    /// // Use connection
80    /// drop(permit); // Release back to pool
81    /// #     Ok(())
82    /// # }
83    /// ```
84    pub async fn acquire(&self) -> anyhow::Result<ConnectionPermit> {
85        let permit = self.semaphore.clone().acquire_owned().await?;
86        Ok(ConnectionPermit {
87            _permit: permit,
88            db_path: self.db_path.clone(),
89        })
90    }
91
92    /// Returns the current number of available connections.
93    ///
94    /// # Examples
95    ///
96    /// ```no_run
97    /// # use forgekit_core::pool::ConnectionPool;
98    /// # let pool = ConnectionPool::new("./db.sqlite", 10);
99    /// let available = pool.available_connections();
100    /// println!("Available connections: {}", available);
101    /// ```
102    pub fn available_connections(&self) -> usize {
103        self.semaphore.available_permits()
104    }
105
106    /// Tries to acquire a permit without waiting.
107    ///
108    /// # Returns
109    ///
110    /// - `Some(permit)` if a connection is immediately available
111    /// - `None` if all connections are in use
112    ///
113    /// # Examples
114    ///
115    /// ```no_run
116    /// # use forgekit_core::pool::ConnectionPool;
117    /// # #[tokio::main]
118    /// # async fn main() -> anyhow::Result<()> {
119    /// # let pool = ConnectionPool::new("./db.sqlite", 10);
120    /// if let Some(permit) = pool.try_acquire().await {
121    ///     // Use connection
122    /// } else {
123    ///     // No connection available
124    /// }
125    /// #     Ok(())
126    /// # }
127    /// ```
128    pub async fn try_acquire(&self) -> Option<ConnectionPermit> {
129        self.semaphore
130            .clone()
131            .try_acquire_owned()
132            .ok()
133            .map(|permit| ConnectionPermit {
134                _permit: permit,
135                db_path: self.db_path.clone(),
136            })
137    }
138}
139
140/// A permit representing an acquired connection.
141///
142/// When dropped, the connection is returned to the pool.
143pub struct ConnectionPermit {
144    _permit: tokio::sync::OwnedSemaphorePermit,
145    db_path: std::path::PathBuf,
146}
147
148impl std::fmt::Debug for ConnectionPermit {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("ConnectionPermit")
151            .field("db_path", &self.db_path)
152            .finish()
153    }
154}
155
156impl ConnectionPermit {
157    /// Returns the path to the database file.
158    ///
159    /// # Examples
160    ///
161    /// ```no_run
162    /// # use forgekit_core::pool::ConnectionPool;
163    /// # #[tokio::main]
164    /// # async fn main() -> anyhow::Result<()> {
165    /// # let pool = ConnectionPool::new("./db.sqlite", 10);
166    /// # let permit = pool.acquire().await?;
167    /// let db_path = permit.db_path();
168    /// println!("Connected to: {:?}", db_path);
169    /// #     Ok(())
170    /// # }
171    /// ```
172    pub fn db_path(&self) -> &std::path::Path {
173        &self.db_path
174    }
175}
176
177impl std::fmt::Display for ConnectionPermit {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        write!(f, "ConnectionPermit({})", self.db_path.display())
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = ConnectionPool::new("/tmp/test.db", 5);
        assert_eq!(pool.max_connections, 5);
        assert_eq!(pool.available_connections(), 5);
    }

    #[tokio::test]
    async fn test_pool_acquire() {
        let pool = ConnectionPool::new("/tmp/test.db", 2);

        let permit1 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 1);

        let permit2 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 0);

        // Dropping permit returns it to pool
        drop(permit1);
        assert_eq!(pool.available_connections(), 1);

        drop(permit2);
        assert_eq!(pool.available_connections(), 2);
    }

    #[tokio::test]
    async fn test_pool_try_acquire() {
        let pool = ConnectionPool::new("/tmp/test.db", 1);

        let permit1 = pool.try_acquire().await;
        assert!(permit1.is_some());
        assert_eq!(pool.available_connections(), 0);

        // Second acquire fails
        let permit2 = pool.try_acquire().await;
        assert!(permit2.is_none());

        drop(permit1);
        assert_eq!(pool.available_connections(), 1);
    }

    #[tokio::test]
    async fn test_pool_db_path() {
        let pool = ConnectionPool::new("/tmp/test.db", 5);
        assert_eq!(pool.db_path, std::path::PathBuf::from("/tmp/test.db"));

        let permit = pool.acquire().await.unwrap();
        assert_eq!(permit.db_path(), std::path::Path::new("/tmp/test.db"));
    }

    #[tokio::test]
    async fn test_pool_concurrent_acquires() {
        use tokio::sync::Barrier;

        const TASKS: usize = 10;
        const LIMIT: usize = 5;

        let pool = Arc::new(ConnectionPool::new("/tmp/test.db", LIMIT));
        let barrier = Arc::new(Barrier::new(TASKS));
        // Tasks currently holding a permit, and the highest count ever observed.
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let completed = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::with_capacity(TASKS);

        for _ in 0..TASKS {
            let pool = Arc::clone(&pool);
            let barrier = Arc::clone(&barrier);
            let in_flight = Arc::clone(&in_flight);
            let peak = Arc::clone(&peak);
            let completed = Arc::clone(&completed);
            handles.push(tokio::spawn(async move {
                barrier.wait().await; // Coordinate start
                let permit = pool.acquire().await.unwrap();
                let now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                peak.fetch_max(now, Ordering::SeqCst);
                // Hold the permit briefly so the remaining tasks must contend for it.
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                // Leave the in-flight count before releasing the permit, so the
                // counter never exceeds the number of permits actually held.
                in_flight.fetch_sub(1, Ordering::SeqCst);
                drop(permit);
                completed.fetch_add(1, Ordering::SeqCst);
            }));
        }

        // Wait for all to complete
        for handle in handles {
            handle.await.unwrap();
        }

        // Every task finished, and the pool never admitted more than LIMIT holders.
        assert_eq!(completed.load(Ordering::SeqCst), TASKS);
        let peak = peak.load(Ordering::SeqCst);
        assert!(
            peak <= LIMIT,
            "pool allowed {} concurrent holders, limit is {}",
            peak,
            LIMIT
        );
        assert!(peak >= 1);
        assert_eq!(pool.available_connections(), LIMIT);
    }

    #[tokio::test]
    async fn test_pool_timeout_behavior() {
        use tokio::time::{timeout, Duration};

        let pool = ConnectionPool::new("/tmp/test.db", 1);

        // Acquire the only permit
        let permit1 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 0);

        // Try to acquire another - should time out
        let start = std::time::Instant::now();
        let result = timeout(Duration::from_millis(100), pool.acquire()).await;
        let elapsed = start.elapsed();

        // Should have timed out instead of obtaining a second permit
        assert!(result.is_err());
        // The timeout must not fire early. No upper bound is asserted: on a
        // loaded machine the wake-up can be arbitrarily late, which made the
        // previous `< 200ms` check flaky.
        assert!(elapsed >= Duration::from_millis(90));

        // The cancelled acquire must not have consumed a permit.
        assert_eq!(pool.available_connections(), 0);
        drop(permit1);
        assert_eq!(pool.available_connections(), 1);
    }

    #[tokio::test]
    async fn test_pool_permit_drop_returns() {
        let pool = ConnectionPool::new("/tmp/test.db", 3);
        assert_eq!(pool.available_connections(), 3);

        // Acquire permit
        let permit = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 2);

        // Drop permit
        drop(permit);
        assert_eq!(pool.available_connections(), 3);
    }

    #[tokio::test]
    async fn test_pool_stress() {
        let pool = ConnectionPool::new("/tmp/test.db", 10);

        // Run 100 acquire/release cycles
        for _ in 0..100 {
            let permit = pool.acquire().await.unwrap();
            // Verify no deadlocks
            assert_eq!(permit.db_path(), std::path::Path::new("/tmp/test.db"));
            drop(permit);
        }

        // Verify final available equals max
        assert_eq!(pool.available_connections(), 10);
    }

    #[tokio::test]
    async fn test_pool_all_permits_acquired() {
        let pool = ConnectionPool::new("/tmp/test.db", 3);

        // Acquire all permits up to max
        let permit1 = pool.acquire().await.unwrap();
        let permit2 = pool.acquire().await.unwrap();
        let permit3 = pool.acquire().await.unwrap();

        // Verify available is 0
        assert_eq!(pool.available_connections(), 0);

        // Verify try_acquire returns None
        let permit4 = pool.try_acquire().await;
        assert!(permit4.is_none());

        // Release one permit
        drop(permit1);

        // Verify try_acquire now works
        let permit5 = pool.try_acquire().await;
        assert!(permit5.is_some());

        // Clean up
        drop(permit2);
        drop(permit3);
        drop(permit5);
    }

    #[tokio::test]
    async fn test_pool_available_count() {
        let pool = ConnectionPool::new("/tmp/test.db", 5);

        // Initial available should be max
        assert_eq!(pool.available_connections(), 5);

        // Acquire varying number of permits
        let permit1 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 4);

        let permit2 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 3);

        let permit3 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 2);

        // Drop permits and verify available increases
        drop(permit1);
        assert_eq!(pool.available_connections(), 3);

        drop(permit2);
        assert_eq!(pool.available_connections(), 4);

        drop(permit3);
        assert_eq!(pool.available_connections(), 5);
    }

    #[tokio::test]
    async fn test_pool_clones_share_limit() {
        let pool = ConnectionPool::new("/tmp/test.db", 2);
        let clone = pool.clone();

        // A permit taken through the clone is visible from the original pool.
        let permit = clone.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 1);
        assert_eq!(clone.available_connections(), 1);

        drop(permit);
        assert_eq!(pool.available_connections(), 2);
    }

    #[tokio::test]
    async fn test_permit_formatting() {
        let pool = ConnectionPool::new("/tmp/test.db", 1);
        let permit = pool.acquire().await.unwrap();

        assert_eq!(format!("{}", permit), "ConnectionPermit(/tmp/test.db)");

        let debug = format!("{:?}", permit);
        assert!(debug.starts_with("ConnectionPermit"));
        assert!(debug.contains("db_path"));
    }
}