// tempo_cli/db/pool.rs

use std::collections::VecDeque;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::{Duration, Instant};

use anyhow::Result;
use rusqlite::{Connection, OpenFlags};

8/// A database connection with metadata
9#[derive(Debug)]
10pub struct PooledConnection {
11    pub connection: Connection,
12    created_at: Instant,
13    last_used: Instant,
14    use_count: usize,
15}
16
17impl PooledConnection {
18    fn new(connection: Connection) -> Self {
19        let now = Instant::now();
20        Self {
21            connection,
22            created_at: now,
23            last_used: now,
24            use_count: 0,
25        }
26    }
27
28    fn mark_used(&mut self) {
29        self.last_used = Instant::now();
30        self.use_count += 1;
31    }
32
33    fn is_expired(&self, max_lifetime: Duration) -> bool {
34        self.created_at.elapsed() > max_lifetime
35    }
36
37    fn is_idle_too_long(&self, max_idle: Duration) -> bool {
38        self.last_used.elapsed() > max_idle
39    }
40}
41
/// Tuning knobs for the SQLite connection pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on checked-out plus idle connections when deciding whether
    /// a new connection may be opened.
    pub max_connections: usize,
    /// Number of connections opened eagerly when the pool is created.
    pub min_connections: usize,
    /// Connections older than this are discarded instead of being reused.
    pub max_lifetime: Duration,
    /// Idle connections left unused for longer than this are evicted.
    pub max_idle_time: Duration,
    /// How long a connection request keeps retrying before it gives up.
    pub connection_timeout: Duration,
}

impl Default for PoolConfig {
    /// Defaults: 10 connections max, 2 pre-opened, 1 hour lifetime,
    /// 10 minute idle limit, 30 second request timeout.
    fn default() -> Self {
        const MINUTE_SECS: u64 = 60;
        PoolConfig {
            max_connections: 10,
            min_connections: 2,
            max_lifetime: Duration::from_secs(60 * MINUTE_SECS),
            max_idle_time: Duration::from_secs(10 * MINUTE_SECS),
            connection_timeout: Duration::from_secs(30),
        }
    }
}

64/// A connection pool for SQLite databases
65pub struct DatabasePool {
66    db_path: PathBuf,
67    pool: Arc<Mutex<VecDeque<PooledConnection>>>,
68    config: PoolConfig,
69    stats: Arc<Mutex<PoolStats>>,
70}
71
/// Counters describing pool usage. The pool hands out snapshots of this record.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers can keep and compare snapshots.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct PoolStats {
    /// Connections opened over the pool's lifetime, including pre-opened ones.
    pub total_connections_created: usize,
    /// Connections currently checked out by callers.
    pub active_connections: usize,
    /// Idle connections sitting in the pool, as recorded by the pool.
    pub connections_in_pool: usize,
    /// Number of connection requests made to the pool.
    pub connection_requests: usize,
    /// Connection requests that gave up after the configured timeout.
    pub connection_timeouts: usize,
}

81impl DatabasePool {
82    /// Create a new database pool
83    pub fn new<P: AsRef<Path>>(db_path: P, config: PoolConfig) -> Result<Self> {
84        let db_path = db_path.as_ref().to_path_buf();
85        
86        // Create parent directory if it doesn't exist
87        if let Some(parent) = db_path.parent() {
88            std::fs::create_dir_all(parent)?;
89        }
90
91        let pool = Self {
92            db_path,
93            pool: Arc::new(Mutex::new(VecDeque::new())),
94            config,
95            stats: Arc::new(Mutex::new(PoolStats::default())),
96        };
97
98        // Pre-populate with minimum connections
99        pool.ensure_min_connections()?;
100
101        Ok(pool)
102    }
103
104    /// Create a new database pool with default configuration
105    pub fn new_with_defaults<P: AsRef<Path>>(db_path: P) -> Result<Self> {
106        Self::new(db_path, PoolConfig::default())
107    }
108
109    /// Get a connection from the pool
110    pub async fn get_connection(&self) -> Result<PooledConnectionGuard> {
111        let start = Instant::now();
112        
113        // Update stats
114        {
115            let mut stats = self.stats.lock()
116                .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
117            stats.connection_requests += 1;
118        }
119
120        loop {
121            // Try to get a connection from the pool
122            if let Some(mut conn) = self.try_get_from_pool()? {
123                conn.mark_used();
124                
125                // Update stats
126                {
127                    let mut stats = self.stats.lock()
128                        .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
129                    stats.active_connections += 1;
130                    stats.connections_in_pool = stats.connections_in_pool.saturating_sub(1);
131                }
132
133                return Ok(PooledConnectionGuard::new(conn, self.pool.clone(), self.stats.clone()));
134            }
135
136            // If no connection available, try to create a new one
137            if self.can_create_new_connection()? {
138                let conn = self.create_connection()?;
139                
140                // Update stats
141                {
142                    let mut stats = self.stats.lock()
143                        .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
144                    stats.total_connections_created += 1;
145                    stats.active_connections += 1;
146                }
147
148                return Ok(PooledConnectionGuard::new(conn, self.pool.clone(), self.stats.clone()));
149            }
150
151            // Check for timeout
152            if start.elapsed() > self.config.connection_timeout {
153                let mut stats = self.stats.lock()
154                    .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
155                stats.connection_timeouts += 1;
156                return Err(anyhow::anyhow!("Connection timeout after {:?}", self.config.connection_timeout));
157            }
158
159            // Wait a bit before retrying  
160            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
161        }
162    }
163
164    /// Try to get a connection from the existing pool
165    fn try_get_from_pool(&self) -> Result<Option<PooledConnection>> {
166        let mut pool = self.pool.lock()
167            .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
168
169        // Clean up expired/idle connections first
170        self.cleanup_connections(&mut pool)?;
171
172        // Try to get a connection
173        Ok(pool.pop_front())
174    }
175
176    /// Check if we can create a new connection
177    fn can_create_new_connection(&self) -> Result<bool> {
178        let stats = self.stats.lock()
179            .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
180        Ok(stats.active_connections + stats.connections_in_pool < self.config.max_connections)
181    }
182
183    /// Create a new database connection
184    fn create_connection(&self) -> Result<PooledConnection> {
185        let connection = Connection::open_with_flags(
186            &self.db_path,
187            OpenFlags::SQLITE_OPEN_READ_WRITE
188                | OpenFlags::SQLITE_OPEN_CREATE
189                | OpenFlags::SQLITE_OPEN_NO_MUTEX,
190        )?;
191
192        // Configure the connection
193        connection.pragma_update(None, "foreign_keys", "ON")?;
194        connection.pragma_update(None, "journal_mode", "WAL")?;
195        connection.pragma_update(None, "synchronous", "NORMAL")?;
196        connection.pragma_update(None, "cache_size", "-64000")?;
197
198        // Run migrations
199        crate::db::migrations::run_migrations(&connection)?;
200
201        Ok(PooledConnection::new(connection))
202    }
203
204    /// Clean up expired and idle connections
205    fn cleanup_connections(&self, pool: &mut VecDeque<PooledConnection>) -> Result<()> {
206        let mut to_remove = Vec::new();
207        
208        for (index, conn) in pool.iter().enumerate() {
209            if conn.is_expired(self.config.max_lifetime) || 
210               conn.is_idle_too_long(self.config.max_idle_time) {
211                to_remove.push(index);
212            }
213        }
214
215        // Remove connections in reverse order to maintain indices
216        for index in to_remove.iter().rev() {
217            pool.remove(*index);
218        }
219
220        Ok(())
221    }
222
223    /// Ensure minimum number of connections are available
224    fn ensure_min_connections(&self) -> Result<()> {
225        let mut pool = self.pool.lock()
226            .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
227
228        while pool.len() < self.config.min_connections {
229            let conn = self.create_connection()?;
230            pool.push_back(conn);
231            
232            // Update stats
233            let mut stats = self.stats.lock()
234                .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
235            stats.total_connections_created += 1;
236            stats.connections_in_pool += 1;
237        }
238
239        Ok(())
240    }
241
242    /// Return a connection to the pool
243    fn return_connection(&self, conn: PooledConnection) -> Result<()> {
244        let mut pool = self.pool.lock()
245            .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
246
247        // Check if we should keep this connection
248        if !conn.is_expired(self.config.max_lifetime) && 
249           pool.len() < self.config.max_connections {
250            pool.push_back(conn);
251            
252            // Update stats
253            let mut stats = self.stats.lock()
254                .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
255            stats.connections_in_pool += 1;
256            stats.active_connections = stats.active_connections.saturating_sub(1);
257        } else {
258            // Update stats - connection is being dropped
259            let mut stats = self.stats.lock()
260                .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
261            stats.active_connections = stats.active_connections.saturating_sub(1);
262        }
263
264        Ok(())
265    }
266
267    /// Get current pool statistics
268    pub fn stats(&self) -> Result<PoolStats> {
269        let stats = self.stats.lock()
270            .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
271        Ok(PoolStats {
272            total_connections_created: stats.total_connections_created,
273            active_connections: stats.active_connections,
274            connections_in_pool: stats.connections_in_pool,
275            connection_requests: stats.connection_requests,
276            connection_timeouts: stats.connection_timeouts,
277        })
278    }
279
280    /// Close all connections in the pool
281    pub fn close(&self) -> Result<()> {
282        let mut pool = self.pool.lock()
283            .map_err(|e| anyhow::anyhow!("Failed to acquire pool lock: {}", e))?;
284        pool.clear();
285        
286        let mut stats = self.stats.lock()
287            .map_err(|e| anyhow::anyhow!("Failed to acquire stats lock: {}", e))?;
288        stats.connections_in_pool = 0;
289        
290        Ok(())
291    }
292}
293
294/// A guard that automatically returns connections to the pool when dropped
295pub struct PooledConnectionGuard {
296    connection: Option<PooledConnection>,
297    pool: Arc<Mutex<VecDeque<PooledConnection>>>,
298    stats: Arc<Mutex<PoolStats>>,
299}
300
301impl PooledConnectionGuard {
302    fn new(
303        connection: PooledConnection,
304        pool: Arc<Mutex<VecDeque<PooledConnection>>>,
305        stats: Arc<Mutex<PoolStats>>,
306    ) -> Self {
307        Self {
308            connection: Some(connection),
309            pool,
310            stats,
311        }
312    }
313
314    /// Get a reference to the underlying connection
315    pub fn connection(&self) -> &Connection {
316        &self.connection.as_ref().unwrap().connection
317    }
318}
319
320impl Drop for PooledConnectionGuard {
321    fn drop(&mut self) {
322        if let Some(conn) = self.connection.take() {
323            // Try to return connection to pool
324            let mut pool = match self.pool.lock() {
325                Ok(pool) => pool,
326                Err(_) => {
327                    // Pool lock is poisoned, just update stats
328                    if let Ok(mut stats) = self.stats.lock() {
329                        stats.active_connections = stats.active_connections.saturating_sub(1);
330                    }
331                    return;
332                }
333            };
334
335            // Check if we should keep this connection
336            if !conn.is_expired(Duration::from_secs(3600)) && pool.len() < 10 {
337                pool.push_back(conn);
338                if let Ok(mut stats) = self.stats.lock() {
339                    stats.connections_in_pool += 1;
340                    stats.active_connections = stats.active_connections.saturating_sub(1);
341                }
342            } else {
343                // Connection is being dropped
344                if let Ok(mut stats) = self.stats.lock() {
345                    stats.active_connections = stats.active_connections.saturating_sub(1);
346                }
347            }
348        }
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Build a default-configured pool whose database file lives in `dir`.
    ///
    /// The caller keeps `dir` alive for as long as the pool is in use.
    fn pool_in(dir: &tempfile::TempDir) -> DatabasePool {
        DatabasePool::new_with_defaults(dir.path().join("test.db")).unwrap()
    }

    #[test]
    fn test_pool_creation() {
        let dir = tempdir().unwrap();
        let pool = pool_in(&dir);

        // The default config pre-opens two connections.
        let snapshot = pool.stats().unwrap();
        assert!(snapshot.total_connections_created >= 2);
        assert_eq!(snapshot.connections_in_pool, 2);
    }

    #[tokio::test]
    async fn test_get_connection() {
        let dir = tempdir().unwrap();
        let pool = pool_in(&dir);

        let guard = pool.get_connection().await.unwrap();
        guard
            .connection()
            .execute("CREATE TABLE test (id INTEGER)", [])
            .unwrap();

        // The guard is still alive, so the connection counts as checked out.
        assert_eq!(pool.stats().unwrap().active_connections, 1);
    }

    #[tokio::test]
    async fn test_connection_return() {
        let dir = tempdir().unwrap();
        let pool = pool_in(&dir);

        let guard = pool.get_connection().await.unwrap();
        assert_eq!(pool.stats().unwrap().active_connections, 1);
        drop(guard);

        // Dropping the guard hands the connection back to the pool.
        let after = pool.stats().unwrap();
        assert_eq!(after.active_connections, 0);
        assert!(after.connections_in_pool > 0);
    }
}