Skip to main content

mixtape_tools/sqlite/
manager.rs

1//! Database connection manager for SQLite tools
2//!
3//! Provides a singleton pattern for managing multiple database connections
4//! across tool invocations.
5//!
6//! # Test Isolation
7//!
8//! For test isolation, create local `DatabaseManager` instances instead of using
9//! the global `DATABASE_MANAGER`. Each instance has its own connection pool and
10//! default database setting. Tool tests that use the global should call
11//! `close_all()` in cleanup to reset state.
12
13use crate::sqlite::error::SqliteToolError;
14use lazy_static::lazy_static;
15use mixtape_core::ToolError;
16use rusqlite::Connection;
17use std::collections::HashMap;
18use std::path::Path;
19use std::sync::{Arc, Mutex, RwLock};
20
21lazy_static! {
22    /// Global database manager instance
23    pub static ref DATABASE_MANAGER: DatabaseManager = DatabaseManager::new();
24}
25
26/// Executes a closure with a database connection in a blocking task.
27///
28/// This helper abstracts the common pattern of:
29/// 1. Spawning a blocking task for SQLite operations
30/// 2. Acquiring a connection from the manager
31/// 3. Locking the connection mutex
32/// 4. Mapping errors to ToolError
33///
34/// # Example
35///
36/// ```ignore
37/// let tables = with_connection(input.db_path, |conn| {
38///     let mut stmt = conn.prepare("SELECT name FROM sqlite_master")?;
39///     // ... use the connection
40///     Ok(result)
41/// }).await?;
42/// ```
43pub async fn with_connection<T, F>(db_path: Option<String>, f: F) -> Result<T, ToolError>
44where
45    T: Send + 'static,
46    F: FnOnce(&Connection) -> Result<T, SqliteToolError> + Send + 'static,
47{
48    tokio::task::spawn_blocking(move || {
49        let conn = DATABASE_MANAGER.get(db_path.as_deref())?;
50        let conn = conn.lock().unwrap();
51        f(&conn)
52    })
53    .await
54    .map_err(|e| ToolError::Custom(format!("Task join error: {}", e)))?
55    .map_err(|e| e.into())
56}
57
58/// Manages SQLite database connections
59///
60/// Supports multiple simultaneous database connections, each identified
61/// by a unique name derived from the file path.
62///
63/// # Test Isolation
64///
65/// - Create new `DatabaseManager` instances for isolated tests
66/// - The global `DATABASE_MANAGER` is shared; use `close_all()` for cleanup
67pub struct DatabaseManager {
68    /// Open database connections keyed by normalized path
69    connections: RwLock<HashMap<String, Arc<Mutex<Connection>>>>,
70
71    /// The default database to use when none is specified
72    default_db: RwLock<Option<String>>,
73}
74
75impl Default for DatabaseManager {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
81impl DatabaseManager {
82    /// Creates a new database manager
83    pub fn new() -> Self {
84        Self {
85            connections: RwLock::new(HashMap::new()),
86            default_db: RwLock::new(None),
87        }
88    }
89
90    /// Normalizes a path to a consistent string key
91    fn normalize_path(path: &Path) -> String {
92        path.canonicalize()
93            .unwrap_or_else(|_| path.to_path_buf())
94            .to_string_lossy()
95            .to_string()
96    }
97
98    /// Opens or creates a database connection
99    ///
100    /// If `create` is false and the database doesn't exist, returns an error.
101    /// If the database is already open, returns the existing connection.
102    ///
103    /// Returns the database identifier (normalized path) for future reference.
104    pub fn open(&self, path: &Path, create: bool) -> Result<String, SqliteToolError> {
105        let key = Self::normalize_path(path);
106
107        // Check if already open
108        {
109            let connections = self.connections.read().unwrap();
110            if connections.contains_key(&key) {
111                // Set as default if it's the first/only database
112                self.set_default_if_first(&key);
113                return Ok(key);
114            }
115        }
116
117        // Check if file exists when create=false
118        if !create && !path.exists() {
119            return Err(SqliteToolError::DatabaseDoesNotExist(path.to_path_buf()));
120        }
121
122        // Ensure parent directory exists for new databases
123        if create {
124            if let Some(parent) = path.parent() {
125                if !parent.exists() {
126                    std::fs::create_dir_all(parent)?;
127                }
128            }
129        }
130
131        // Open the connection
132        let conn = Connection::open(path).map_err(|e| SqliteToolError::ConnectionFailed {
133            path: path.to_path_buf(),
134            message: e.to_string(),
135        })?;
136
137        // Enable foreign keys by default
138        conn.execute_batch("PRAGMA foreign_keys = ON;")?;
139
140        let conn = Arc::new(Mutex::new(conn));
141
142        // Store the connection
143        {
144            let mut connections = self.connections.write().unwrap();
145            connections.insert(key.clone(), conn);
146        }
147
148        // Set as default if first database
149        self.set_default_if_first(&key);
150
151        Ok(key)
152    }
153
154    /// Sets a database as default if no default is set
155    fn set_default_if_first(&self, key: &str) {
156        let mut default = self.default_db.write().unwrap();
157        if default.is_none() {
158            *default = Some(key.to_string());
159        }
160    }
161
162    /// Closes a database connection
163    pub fn close(&self, name: &str) -> Result<(), SqliteToolError> {
164        let mut connections = self.connections.write().unwrap();
165
166        // Try to find by exact key or by filename
167        let key = if connections.contains_key(name) {
168            name.to_string()
169        } else {
170            // Search for matching filename
171            connections
172                .keys()
173                .find(|k| k.ends_with(name) || Path::new(k).file_name().is_some_and(|f| f == name))
174                .cloned()
175                .ok_or_else(|| SqliteToolError::DatabaseNotFound(name.to_string()))?
176        };
177
178        connections.remove(&key);
179
180        // Clear default if it was this database
181        let mut default = self.default_db.write().unwrap();
182        if default.as_ref() == Some(&key) {
183            // Set to another open database or None
184            *default = connections.keys().next().cloned();
185        }
186
187        Ok(())
188    }
189
190    /// Gets a connection by name, or the default connection if name is None
191    pub fn get(&self, name: Option<&str>) -> Result<Arc<Mutex<Connection>>, SqliteToolError> {
192        let connections = self.connections.read().unwrap();
193
194        let key = match name {
195            Some(n) => {
196                // Try exact match first
197                if connections.contains_key(n) {
198                    n.to_string()
199                } else {
200                    // Search for matching filename
201                    connections
202                        .keys()
203                        .find(|k| {
204                            k.ends_with(n) || Path::new(k).file_name().is_some_and(|f| f == n)
205                        })
206                        .cloned()
207                        .ok_or_else(|| SqliteToolError::DatabaseNotFound(n.to_string()))?
208                }
209            }
210            None => {
211                let default = self.default_db.read().unwrap();
212                default.clone().ok_or(SqliteToolError::NoDefaultDatabase)?
213            }
214        };
215
216        connections
217            .get(&key)
218            .cloned()
219            .ok_or_else(|| SqliteToolError::DatabaseNotFound(key))
220    }
221
222    /// Sets the default database (thread-local)
223    pub fn set_default(&self, name: &str) -> Result<(), SqliteToolError> {
224        let connections = self.connections.read().unwrap();
225
226        // Verify the database exists
227        let key = if connections.contains_key(name) {
228            name.to_string()
229        } else {
230            connections
231                .keys()
232                .find(|k| k.ends_with(name) || Path::new(k).file_name().is_some_and(|f| f == name))
233                .cloned()
234                .ok_or_else(|| SqliteToolError::DatabaseNotFound(name.to_string()))?
235        };
236
237        let mut default = self.default_db.write().unwrap();
238        *default = Some(key);
239
240        Ok(())
241    }
242
243    /// Returns the current default database name
244    pub fn get_default(&self) -> Option<String> {
245        self.default_db.read().unwrap().clone()
246    }
247
248    /// Lists all open database connections
249    pub fn list_open(&self) -> Vec<String> {
250        self.connections.read().unwrap().keys().cloned().collect()
251    }
252
253    /// Checks if a database is open
254    pub fn is_open(&self, name: &str) -> bool {
255        let connections = self.connections.read().unwrap();
256        connections.contains_key(name)
257            || connections
258                .keys()
259                .any(|k| k.ends_with(name) || Path::new(k).file_name().is_some_and(|f| f == name))
260    }
261
262    /// Closes all database connections and clears the default
263    pub fn close_all(&self) {
264        let mut connections = self.connections.write().unwrap();
265        connections.clear();
266
267        let mut default = self.default_db.write().unwrap();
268        *default = None;
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Every test uses its own manager so no state leaks through the global.
    fn create_test_manager() -> DatabaseManager {
        DatabaseManager::new()
    }

    #[test]
    fn test_open_and_get() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let manager = create_test_manager();

        // Open database
        let key = manager.open(&db_path, true).unwrap();
        assert!(!key.is_empty());

        // The first opened database is the default, so `None` resolves to it
        let conn = manager.get(None).unwrap();
        let guard = conn.lock().unwrap();

        // Verify the connection is usable
        guard
            .execute_batch("CREATE TABLE test (id INTEGER);")
            .unwrap();
    }

    #[test]
    fn test_open_existing_only() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("nonexistent.db");
        let manager = create_test_manager();

        // Should fail when create=false and file doesn't exist
        let result = manager.open(&db_path, false);
        assert!(matches!(
            result,
            Err(SqliteToolError::DatabaseDoesNotExist(_))
        ));
        assert!(manager.list_open().is_empty());

        // Create the file first; SQLite treats a zero-length file as an
        // empty database.
        std::fs::write(&db_path, "").unwrap();

        // Now opening without `create` succeeds
        let key = manager.open(&db_path, false).unwrap();
        assert!(manager.is_open(&key));
    }

    #[test]
    fn test_close_database() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let manager = create_test_manager();

        let key = manager.open(&db_path, true).unwrap();
        assert!(manager.is_open(&key));

        manager.close(&key).unwrap();
        assert!(!manager.is_open(&key));
    }

    #[test]
    fn test_close_unknown_database() {
        let manager = create_test_manager();

        let result = manager.close("missing.db");
        assert!(matches!(result, Err(SqliteToolError::DatabaseNotFound(_))));
    }

    #[test]
    fn test_close_default_falls_back_to_remaining() {
        let temp_dir = TempDir::new().unwrap();
        let db1_path = temp_dir.path().join("db1.db");
        let db2_path = temp_dir.path().join("db2.db");
        let manager = create_test_manager();

        let key1 = manager.open(&db1_path, true).unwrap();
        let key2 = manager.open(&db2_path, true).unwrap();
        assert_eq!(manager.get_default(), Some(key1.clone()));

        // Closing the default hands the role to the only database left
        manager.close(&key1).unwrap();
        assert_eq!(manager.get_default(), Some(key2.clone()));

        // Closing the last database clears the default
        manager.close(&key2).unwrap();
        assert!(manager.get_default().is_none());
    }

    #[test]
    fn test_get_by_file_name() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let manager = create_test_manager();

        manager.open(&db_path, true).unwrap();

        assert!(manager.is_open("test.db"));
        assert!(manager.get(Some("test.db")).is_ok());
        assert!(matches!(
            manager.get(Some("other.db")),
            Err(SqliteToolError::DatabaseNotFound(_))
        ));
    }

    #[test]
    fn test_multiple_databases() {
        let temp_dir = TempDir::new().unwrap();
        let db1_path = temp_dir.path().join("db1.db");
        let db2_path = temp_dir.path().join("db2.db");
        let manager = create_test_manager();

        let key1 = manager.open(&db1_path, true).unwrap();
        let key2 = manager.open(&db2_path, true).unwrap();

        // First opened should be default
        assert_eq!(manager.get_default(), Some(key1.clone()));

        // Can get both
        assert!(manager.get(Some(&key1)).is_ok());
        assert!(manager.get(Some(&key2)).is_ok());

        // List all
        let open = manager.list_open();
        assert_eq!(open.len(), 2);
    }

    #[test]
    fn test_set_default() {
        let temp_dir = TempDir::new().unwrap();
        let db1_path = temp_dir.path().join("db1.db");
        let db2_path = temp_dir.path().join("db2.db");
        let manager = create_test_manager();

        let key1 = manager.open(&db1_path, true).unwrap();
        let key2 = manager.open(&db2_path, true).unwrap();

        assert_eq!(manager.get_default(), Some(key1.clone()));

        manager.set_default(&key2).unwrap();
        assert_eq!(manager.get_default(), Some(key2));
    }

    #[test]
    fn test_no_default_database() {
        let manager = create_test_manager();
        let result = manager.get(None);
        assert!(matches!(result, Err(SqliteToolError::NoDefaultDatabase)));
    }

    #[test]
    fn test_close_all() {
        let temp_dir = TempDir::new().unwrap();
        let db1_path = temp_dir.path().join("db1.db");
        let db2_path = temp_dir.path().join("db2.db");
        let manager = create_test_manager();

        manager.open(&db1_path, true).unwrap();
        manager.open(&db2_path, true).unwrap();

        assert_eq!(manager.list_open().len(), 2);

        manager.close_all();

        assert_eq!(manager.list_open().len(), 0);
        assert!(manager.get_default().is_none());
    }
}
394}