1use anyhow::Result;
2use rusqlite::{Connection, OpenFlags};
3use std::collections::VecDeque;
4use std::path::{Path, PathBuf};
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7
8#[derive(Debug)]
10pub struct PooledConnection {
11 pub connection: Connection,
12 created_at: Instant,
13 last_used: Instant,
14 use_count: usize,
15}
16
17impl PooledConnection {
18 fn new(connection: Connection) -> Self {
19 let now = Instant::now();
20 Self {
21 connection,
22 created_at: now,
23 last_used: now,
24 use_count: 0,
25 }
26 }
27
28 fn mark_used(&mut self) {
29 self.last_used = Instant::now();
30 self.use_count += 1;
31 }
32
33 fn is_expired(&self, max_lifetime: Duration) -> bool {
34 self.created_at.elapsed() > max_lifetime
35 }
36
37 fn is_idle_too_long(&self, max_idle: Duration) -> bool {
38 self.last_used.elapsed() > max_idle
39 }
40}
41
/// Tunable limits for a [`DatabasePool`].
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on checked-out plus idle connections before new requests
    /// have to wait.
    pub max_connections: usize,
    /// Number of connections opened eagerly when the pool is constructed.
    pub min_connections: usize,
    /// Age beyond which an idle connection is discarded instead of reused.
    pub max_lifetime: Duration,
    /// Time since last use beyond which an idle connection is discarded.
    pub max_idle_time: Duration,
    /// How long a request keeps retrying before it fails with a timeout.
    pub connection_timeout: Duration,
}
51
52impl Default for PoolConfig {
53 fn default() -> Self {
54 Self {
55 max_connections: 10,
56 min_connections: 2,
57 max_lifetime: Duration::from_secs(3600), max_idle_time: Duration::from_secs(600), connection_timeout: Duration::from_secs(30),
60 }
61 }
62}
63
64pub struct DatabasePool {
66 db_path: PathBuf,
67 pool: Arc<Mutex<VecDeque<PooledConnection>>>,
68 config: PoolConfig,
69 stats: Arc<Mutex<PoolStats>>,
70}
71
/// Counters describing how a [`DatabasePool`] has been used so far.
#[derive(Debug, Default)]
pub struct PoolStats {
    /// Connections opened over the lifetime of the pool.
    pub total_connections_created: usize,
    /// Connections currently checked out.
    pub active_connections: usize,
    /// Connections currently sitting idle in the pool.
    pub connections_in_pool: usize,
    /// Number of connection requests received.
    pub connection_requests: usize,
    /// Number of requests that gave up after the configured timeout.
    pub connection_timeouts: usize,
}
80
81impl DatabasePool {
82 pub fn new<P: AsRef<Path>>(db_path: P, config: PoolConfig) -> Result<Self> {
84 let db_path = db_path.as_ref().to_path_buf();
85
86 if let Some(parent) = db_path.parent() {
88 std::fs::create_dir_all(parent)?;
89 }
90
91 let pool = Self {
92 db_path,
93 pool: Arc::new(Mutex::new(VecDeque::new())),
94 config,
95 stats: Arc::new(Mutex::new(PoolStats::default())),
96 };
97
98 pool.ensure_min_connections()?;
100
101 Ok(pool)
102 }
103
104 pub fn new_with_defaults<P: AsRef<Path>>(db_path: P) -> Result<Self> {
106 Self::new(db_path, PoolConfig::default())
107 }
108
109 pub async fn get_connection(&self) -> Result<PooledConnectionGuard> {
111 let start = Instant::now();
112
113 {
115 let mut stats = self
116 .stats
117 .lock()
118 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
119 stats.connection_requests += 1;
120 }
121
122 loop {
123 if let Some(mut conn) = self.try_get_from_pool()? {
125 conn.mark_used();
126
127 {
129 let mut stats = self
130 .stats
131 .lock()
132 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
133 stats.active_connections += 1;
134 stats.connections_in_pool = stats.connections_in_pool.saturating_sub(1);
135 }
136
137 return Ok(PooledConnectionGuard::new(
138 conn,
139 self.pool.clone(),
140 self.stats.clone(),
141 ));
142 }
143
144 if self.can_create_new_connection()? {
146 let conn = self.create_connection()?;
147
148 {
150 let mut stats = self
151 .stats
152 .lock()
153 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
154 stats.total_connections_created += 1;
155 stats.active_connections += 1;
156 }
157
158 return Ok(PooledConnectionGuard::new(
159 conn,
160 self.pool.clone(),
161 self.stats.clone(),
162 ));
163 }
164
165 if start.elapsed() > self.config.connection_timeout {
167 let mut stats = self
168 .stats
169 .lock()
170 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
171 stats.connection_timeouts += 1;
172 return Err(anyhow::anyhow!(
173 "Connection timeout after {:?}",
174 self.config.connection_timeout
175 ));
176 }
177
178 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
180 }
181 }
182
183 fn try_get_from_pool(&self) -> Result<Option<PooledConnection>> {
185 let mut pool = self
186 .pool
187 .lock()
188 .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
189
190 self.cleanup_connections(&mut pool)?;
192
193 Ok(pool.pop_front())
195 }
196
197 fn can_create_new_connection(&self) -> Result<bool> {
199 let stats = self
200 .stats
201 .lock()
202 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
203 Ok(stats.active_connections + stats.connections_in_pool < self.config.max_connections)
204 }
205
206 fn create_connection(&self) -> Result<PooledConnection> {
208 let connection = Connection::open_with_flags(
209 &self.db_path,
210 OpenFlags::SQLITE_OPEN_READ_WRITE
211 | OpenFlags::SQLITE_OPEN_CREATE
212 | OpenFlags::SQLITE_OPEN_NO_MUTEX,
213 )?;
214
215 connection.pragma_update(None, "foreign_keys", "ON")?;
217 connection.pragma_update(None, "journal_mode", "WAL")?;
218 connection.pragma_update(None, "synchronous", "NORMAL")?;
219 connection.pragma_update(None, "cache_size", "-64000")?;
220
221 crate::db::migrations::run_migrations(&connection)?;
223
224 Ok(PooledConnection::new(connection))
225 }
226
227 fn cleanup_connections(&self, pool: &mut VecDeque<PooledConnection>) -> Result<()> {
229 let mut to_remove = Vec::new();
230
231 for (index, conn) in pool.iter().enumerate() {
232 if conn.is_expired(self.config.max_lifetime)
233 || conn.is_idle_too_long(self.config.max_idle_time)
234 {
235 to_remove.push(index);
236 }
237 }
238
239 for index in to_remove.iter().rev() {
241 pool.remove(*index);
242 }
243
244 Ok(())
245 }
246
247 fn ensure_min_connections(&self) -> Result<()> {
249 let mut pool = self
250 .pool
251 .lock()
252 .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
253
254 while pool.len() < self.config.min_connections {
255 let conn = self.create_connection()?;
256 pool.push_back(conn);
257
258 let mut stats = self
260 .stats
261 .lock()
262 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
263 stats.total_connections_created += 1;
264 stats.connections_in_pool += 1;
265 }
266
267 Ok(())
268 }
269
270 #[allow(dead_code)]
272 fn return_connection(&self, conn: PooledConnection) -> Result<()> {
273 let mut pool = self
274 .pool
275 .lock()
276 .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
277
278 if !conn.is_expired(self.config.max_lifetime) && pool.len() < self.config.max_connections {
280 pool.push_back(conn);
281
282 let mut stats = self
284 .stats
285 .lock()
286 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
287 stats.connections_in_pool += 1;
288 stats.active_connections = stats.active_connections.saturating_sub(1);
289 } else {
290 let mut stats = self
292 .stats
293 .lock()
294 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
295 stats.active_connections = stats.active_connections.saturating_sub(1);
296 }
297
298 Ok(())
299 }
300
301 pub fn stats(&self) -> Result<PoolStats> {
303 let stats = self
304 .stats
305 .lock()
306 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
307 Ok(PoolStats {
308 total_connections_created: stats.total_connections_created,
309 active_connections: stats.active_connections,
310 connections_in_pool: stats.connections_in_pool,
311 connection_requests: stats.connection_requests,
312 connection_timeouts: stats.connection_timeouts,
313 })
314 }
315
316 pub fn close(&self) -> Result<()> {
318 let mut pool = self
319 .pool
320 .lock()
321 .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
322 pool.clear();
323
324 let mut stats = self
325 .stats
326 .lock()
327 .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
328 stats.connections_in_pool = 0;
329
330 Ok(())
331 }
332}
333
334pub struct PooledConnectionGuard {
336 connection: Option<PooledConnection>,
337 pool: Arc<Mutex<VecDeque<PooledConnection>>>,
338 stats: Arc<Mutex<PoolStats>>,
339}
340
341impl PooledConnectionGuard {
342 fn new(
343 connection: PooledConnection,
344 pool: Arc<Mutex<VecDeque<PooledConnection>>>,
345 stats: Arc<Mutex<PoolStats>>,
346 ) -> Self {
347 Self {
348 connection: Some(connection),
349 pool,
350 stats,
351 }
352 }
353
354 pub fn connection(&self) -> &Connection {
356 &self.connection.as_ref().unwrap().connection
357 }
358}
359
360impl Drop for PooledConnectionGuard {
361 fn drop(&mut self) {
362 if let Some(conn) = self.connection.take() {
363 let mut pool = match self.pool.lock() {
365 Ok(pool) => pool,
366 Err(_) => {
367 if let Ok(mut stats) = self.stats.lock() {
369 stats.active_connections = stats.active_connections.saturating_sub(1);
370 }
371 return;
372 }
373 };
374
375 if !conn.is_expired(Duration::from_secs(3600)) && pool.len() < 10 {
377 pool.push_back(conn);
378 if let Ok(mut stats) = self.stats.lock() {
379 stats.connections_in_pool += 1;
380 stats.active_connections = stats.active_connections.saturating_sub(1);
381 }
382 } else {
383 if let Ok(mut stats) = self.stats.lock() {
385 stats.active_connections = stats.active_connections.saturating_sub(1);
386 }
387 }
388 }
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Opens a default-configured pool backed by `test.db` inside `dir`.
    fn open_pool(dir: &std::path::Path) -> DatabasePool {
        DatabasePool::new_with_defaults(dir.join("test.db")).unwrap()
    }

    /// A new pool starts out with its minimum number of idle connections.
    #[test]
    fn test_pool_creation() {
        let dir = tempdir().unwrap();
        let pool = open_pool(dir.path());

        let snapshot = pool.stats().unwrap();
        assert!(snapshot.total_connections_created >= 2);
        assert_eq!(snapshot.connections_in_pool, 2);
    }

    /// A checked-out connection is usable and counted as active.
    #[tokio::test]
    async fn test_get_connection() {
        let dir = tempdir().unwrap();
        let pool = open_pool(dir.path());

        let guard = pool.get_connection().await.unwrap();
        guard
            .connection()
            .execute("CREATE TABLE test (id INTEGER)", [])
            .unwrap();

        assert_eq!(pool.stats().unwrap().active_connections, 1);
    }

    /// Dropping the guard moves the connection from active back to idle.
    #[tokio::test]
    async fn test_connection_return() {
        let dir = tempdir().unwrap();
        let pool = open_pool(dir.path());

        let guard = pool.get_connection().await.unwrap();
        assert_eq!(pool.stats().unwrap().active_connections, 1);
        drop(guard);

        let after = pool.stats().unwrap();
        assert_eq!(after.active_connections, 0);
        assert!(after.connections_in_pool > 0);
    }
}