1use anyhow::Result;
2use rusqlite::{Connection, OpenFlags};
3use std::collections::VecDeque;
4use std::path::{Path, PathBuf};
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7
8#[derive(Debug)]
10pub struct PooledConnection {
11 pub connection: Connection,
12 created_at: Instant,
13 last_used: Instant,
14 use_count: usize,
15}
16
17impl PooledConnection {
18 fn new(connection: Connection) -> Self {
19 let now = Instant::now();
20 Self {
21 connection,
22 created_at: now,
23 last_used: now,
24 use_count: 0,
25 }
26 }
27
28 fn mark_used(&mut self) {
29 self.last_used = Instant::now();
30 self.use_count += 1;
31 }
32
33 fn is_expired(&self, max_lifetime: Duration) -> bool {
34 self.created_at.elapsed() > max_lifetime
35 }
36
37 fn is_idle_too_long(&self, max_idle: Duration) -> bool {
38 self.last_used.elapsed() > max_idle
39 }
40}
41
/// Limits and timeouts that govern a `DatabasePool`.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on connections open at once (checked out plus idle).
    pub max_connections: usize,
    /// Connections opened eagerly when the pool is created.
    pub min_connections: usize,
    /// Age after which a connection is no longer reused.
    pub max_lifetime: Duration,
    /// Idle time after which a pooled connection is evicted.
    pub max_idle_time: Duration,
    /// How long `get_connection` keeps retrying before giving up.
    pub connection_timeout: Duration,
}

impl Default for PoolConfig {
    /// 10 max / 2 min connections, 1 h lifetime, 10 min idle limit and a
    /// 30 s checkout timeout.
    fn default() -> Self {
        const SECS_PER_MINUTE: u64 = 60;
        Self {
            max_connections: 10,
            min_connections: 2,
            max_lifetime: Duration::from_secs(60 * SECS_PER_MINUTE),
            max_idle_time: Duration::from_secs(10 * SECS_PER_MINUTE),
            connection_timeout: Duration::from_secs(30),
        }
    }
}
63
64pub struct DatabasePool {
66 db_path: PathBuf,
67 pool: Arc<Mutex<VecDeque<PooledConnection>>>,
68 config: PoolConfig,
69 stats: Arc<Mutex<PoolStats>>,
70}
71
/// Counters describing pool usage.
///
/// `Clone`/`PartialEq` let callers keep and compare snapshots returned by
/// the pool.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct PoolStats {
    /// Total connections ever opened by the pool.
    pub total_connections_created: usize,
    /// Connections currently checked out.
    pub active_connections: usize,
    /// Idle connections held by the pool, as last recorded.
    pub connections_in_pool: usize,
    /// Number of connection requests received.
    pub connection_requests: usize,
    /// Requests that gave up after the configured timeout.
    pub connection_timeouts: usize,
}
80
81impl DatabasePool {
82 pub fn new<P: AsRef<Path>>(db_path: P, config: PoolConfig) -> Result<Self> {
84 let db_path = db_path.as_ref().to_path_buf();
85
86 if let Some(parent) = db_path.parent() {
88 std::fs::create_dir_all(parent)?;
89 }
90
91 let pool = Self {
92 db_path,
93 pool: Arc::new(Mutex::new(VecDeque::new())),
94 config,
95 stats: Arc::new(Mutex::new(PoolStats::default())),
96 };
97
98 pool.ensure_min_connections()?;
100
101 Ok(pool)
102 }
103
104 pub fn new_with_defaults<P: AsRef<Path>>(db_path: P) -> Result<Self> {
106 Self::new(db_path, PoolConfig::default())
107 }
108
109 pub async fn get_connection(&self) -> Result<PooledConnectionGuard> {
111 let start = Instant::now();
112
113 {
115 let mut stats = self
116 .stats
117 .lock()
118 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
119 stats.connection_requests += 1;
120 }
121
122 loop {
123 if let Some(mut conn) = self.try_get_from_pool()? {
125 conn.mark_used();
126
127 {
129 let mut stats = self
130 .stats
131 .lock()
132 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
133 stats.active_connections += 1;
134 stats.connections_in_pool = stats.connections_in_pool.saturating_sub(1);
135 }
136
137 return Ok(PooledConnectionGuard::new(
138 conn,
139 self.pool.clone(),
140 self.stats.clone(),
141 ));
142 }
143
144 if self.can_create_new_connection()? {
146 let conn = self.create_connection()?;
147
148 {
150 let mut stats = self
151 .stats
152 .lock()
153 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
154 stats.total_connections_created += 1;
155 stats.active_connections += 1;
156 }
157
158 return Ok(PooledConnectionGuard::new(
159 conn,
160 self.pool.clone(),
161 self.stats.clone(),
162 ));
163 }
164
165 if start.elapsed() > self.config.connection_timeout {
167 let mut stats = self
168 .stats
169 .lock()
170 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
171 stats.connection_timeouts += 1;
172 return Err(anyhow::anyhow!(
173 "Connection timeout after {:?}",
174 self.config.connection_timeout
175 ));
176 }
177
178 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
180 }
181 }
182
183 fn try_get_from_pool(&self) -> Result<Option<PooledConnection>> {
185 let mut pool = self
186 .pool
187 .lock()
188 .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
189
190 self.cleanup_connections(&mut pool)?;
192
193 Ok(pool.pop_front())
195 }
196
197 fn can_create_new_connection(&self) -> Result<bool> {
199 let stats = self
200 .stats
201 .lock()
202 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
203 Ok(stats.active_connections + stats.connections_in_pool < self.config.max_connections)
204 }
205
206 fn create_connection(&self) -> Result<PooledConnection> {
208 let connection = Connection::open_with_flags(
209 &self.db_path,
210 OpenFlags::SQLITE_OPEN_READ_WRITE
211 | OpenFlags::SQLITE_OPEN_CREATE
212 | OpenFlags::SQLITE_OPEN_NO_MUTEX,
213 )?;
214
215 connection.pragma_update(None, "foreign_keys", "ON")?;
217 connection.pragma_update(None, "journal_mode", "WAL")?;
218 connection.pragma_update(None, "synchronous", "NORMAL")?;
219 connection.pragma_update(None, "cache_size", "-64000")?;
220
221 crate::db::migrations::run_migrations(&connection)?;
223
224 Ok(PooledConnection::new(connection))
225 }
226
227 fn cleanup_connections(&self, pool: &mut VecDeque<PooledConnection>) -> Result<()> {
229 let mut to_remove = Vec::new();
230
231 for (index, conn) in pool.iter().enumerate() {
232 if conn.is_expired(self.config.max_lifetime)
233 || conn.is_idle_too_long(self.config.max_idle_time)
234 {
235 to_remove.push(index);
236 }
237 }
238
239 for index in to_remove.iter().rev() {
241 pool.remove(*index);
242 }
243
244 Ok(())
245 }
246
247 fn ensure_min_connections(&self) -> Result<()> {
249 let mut pool = self
250 .pool
251 .lock()
252 .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
253
254 while pool.len() < self.config.min_connections {
255 let conn = self.create_connection()?;
256 pool.push_back(conn);
257
258 let mut stats = self
260 .stats
261 .lock()
262 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
263 stats.total_connections_created += 1;
264 stats.connections_in_pool += 1;
265 }
266
267 Ok(())
268 }
269
270 fn return_connection(&self, conn: PooledConnection) -> Result<()> {
272 let mut pool = self
273 .pool
274 .lock()
275 .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
276
277 if !conn.is_expired(self.config.max_lifetime) && pool.len() < self.config.max_connections {
279 pool.push_back(conn);
280
281 let mut stats = self
283 .stats
284 .lock()
285 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
286 stats.connections_in_pool += 1;
287 stats.active_connections = stats.active_connections.saturating_sub(1);
288 } else {
289 let mut stats = self
291 .stats
292 .lock()
293 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
294 stats.active_connections = stats.active_connections.saturating_sub(1);
295 }
296
297 Ok(())
298 }
299
300 pub fn stats(&self) -> Result<PoolStats> {
302 let stats = self
303 .stats
304 .lock()
305 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
306 Ok(PoolStats {
307 total_connections_created: stats.total_connections_created,
308 active_connections: stats.active_connections,
309 connections_in_pool: stats.connections_in_pool,
310 connection_requests: stats.connection_requests,
311 connection_timeouts: stats.connection_timeouts,
312 })
313 }
314
315 pub fn close(&self) -> Result<()> {
317 let mut pool = self
318 .pool
319 .lock()
320 .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
321 pool.clear();
322
323 let mut stats = self
324 .stats
325 .lock()
326 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
327 stats.connections_in_pool = 0;
328
329 Ok(())
330 }
331}
332
333pub struct PooledConnectionGuard {
335 connection: Option<PooledConnection>,
336 pool: Arc<Mutex<VecDeque<PooledConnection>>>,
337 stats: Arc<Mutex<PoolStats>>,
338}
339
340impl PooledConnectionGuard {
341 fn new(
342 connection: PooledConnection,
343 pool: Arc<Mutex<VecDeque<PooledConnection>>>,
344 stats: Arc<Mutex<PoolStats>>,
345 ) -> Self {
346 Self {
347 connection: Some(connection),
348 pool,
349 stats,
350 }
351 }
352
353 pub fn connection(&self) -> &Connection {
355 &self.connection.as_ref().unwrap().connection
356 }
357}
358
359impl Drop for PooledConnectionGuard {
360 fn drop(&mut self) {
361 if let Some(conn) = self.connection.take() {
362 let mut pool = match self.pool.lock() {
364 Ok(pool) => pool,
365 Err(_) => {
366 if let Ok(mut stats) = self.stats.lock() {
368 stats.active_connections = stats.active_connections.saturating_sub(1);
369 }
370 return;
371 }
372 };
373
374 if !conn.is_expired(Duration::from_secs(3600)) && pool.len() < 10 {
376 pool.push_back(conn);
377 if let Ok(mut stats) = self.stats.lock() {
378 stats.connections_in_pool += 1;
379 stats.active_connections = stats.active_connections.saturating_sub(1);
380 }
381 } else {
382 if let Ok(mut stats) = self.stats.lock() {
384 stats.active_connections = stats.active_connections.saturating_sub(1);
385 }
386 }
387 }
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Opens a default-configured pool on `test.db` inside `dir`.
    fn open_pool(dir: &Path) -> DatabasePool {
        DatabasePool::new_with_defaults(dir.join("test.db")).unwrap()
    }

    #[test]
    fn test_pool_creation() {
        let dir = tempdir().unwrap();
        let pool = open_pool(dir.path());

        let snapshot = pool.stats().unwrap();
        // The default config eagerly opens two connections.
        assert!(snapshot.total_connections_created >= 2);
        assert_eq!(snapshot.connections_in_pool, 2);
    }

    #[tokio::test]
    async fn test_get_connection() {
        let dir = tempdir().unwrap();
        let pool = open_pool(dir.path());

        let guard = pool.get_connection().await.unwrap();
        guard
            .connection()
            .execute("CREATE TABLE test (id INTEGER)", [])
            .unwrap();

        // The guard is still alive, so its connection counts as active.
        assert_eq!(pool.stats().unwrap().active_connections, 1);
    }

    #[tokio::test]
    async fn test_connection_return() {
        let dir = tempdir().unwrap();
        let pool = open_pool(dir.path());

        {
            let _guard = pool.get_connection().await.unwrap();
            assert_eq!(pool.stats().unwrap().active_connections, 1);
        }

        // Dropping the guard hands the connection back to the idle queue.
        let after = pool.stats().unwrap();
        assert_eq!(after.active_connections, 0);
        assert!(after.connections_in_pool > 0);
    }
}