mixtape_tools/sqlite/
manager.rs1use crate::sqlite::error::SqliteToolError;
14use lazy_static::lazy_static;
15use mixtape_core::ToolError;
16use rusqlite::Connection;
17use std::collections::HashMap;
18use std::path::Path;
19use std::sync::{Arc, Mutex, RwLock};
20
21lazy_static! {
22 pub static ref DATABASE_MANAGER: DatabaseManager = DatabaseManager::new();
24}
25
26pub async fn with_connection<T, F>(db_path: Option<String>, f: F) -> Result<T, ToolError>
44where
45 T: Send + 'static,
46 F: FnOnce(&Connection) -> Result<T, SqliteToolError> + Send + 'static,
47{
48 tokio::task::spawn_blocking(move || {
49 let conn = DATABASE_MANAGER.get(db_path.as_deref())?;
50 let conn = conn.lock().unwrap();
51 f(&conn)
52 })
53 .await
54 .map_err(|e| ToolError::Custom(format!("Task join error: {}", e)))?
55 .map_err(|e| e.into())
56}
57
58pub struct DatabaseManager {
68 connections: RwLock<HashMap<String, Arc<Mutex<Connection>>>>,
70
71 default_db: RwLock<Option<String>>,
73}
74
75impl Default for DatabaseManager {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81impl DatabaseManager {
82 pub fn new() -> Self {
84 Self {
85 connections: RwLock::new(HashMap::new()),
86 default_db: RwLock::new(None),
87 }
88 }
89
90 fn normalize_path(path: &Path) -> String {
92 path.canonicalize()
93 .unwrap_or_else(|_| path.to_path_buf())
94 .to_string_lossy()
95 .to_string()
96 }
97
98 pub fn open(&self, path: &Path, create: bool) -> Result<String, SqliteToolError> {
105 let key = Self::normalize_path(path);
106
107 {
109 let connections = self.connections.read().unwrap();
110 if connections.contains_key(&key) {
111 self.set_default_if_first(&key);
113 return Ok(key);
114 }
115 }
116
117 if !create && !path.exists() {
119 return Err(SqliteToolError::DatabaseDoesNotExist(path.to_path_buf()));
120 }
121
122 if create {
124 if let Some(parent) = path.parent() {
125 if !parent.exists() {
126 std::fs::create_dir_all(parent)?;
127 }
128 }
129 }
130
131 let conn = Connection::open(path).map_err(|e| SqliteToolError::ConnectionFailed {
133 path: path.to_path_buf(),
134 message: e.to_string(),
135 })?;
136
137 conn.execute_batch("PRAGMA foreign_keys = ON;")?;
139
140 let conn = Arc::new(Mutex::new(conn));
141
142 {
144 let mut connections = self.connections.write().unwrap();
145 connections.insert(key.clone(), conn);
146 }
147
148 self.set_default_if_first(&key);
150
151 Ok(key)
152 }
153
154 fn set_default_if_first(&self, key: &str) {
156 let mut default = self.default_db.write().unwrap();
157 if default.is_none() {
158 *default = Some(key.to_string());
159 }
160 }
161
162 pub fn close(&self, name: &str) -> Result<(), SqliteToolError> {
164 let mut connections = self.connections.write().unwrap();
165
166 let key = if connections.contains_key(name) {
168 name.to_string()
169 } else {
170 connections
172 .keys()
173 .find(|k| k.ends_with(name) || Path::new(k).file_name().is_some_and(|f| f == name))
174 .cloned()
175 .ok_or_else(|| SqliteToolError::DatabaseNotFound(name.to_string()))?
176 };
177
178 connections.remove(&key);
179
180 let mut default = self.default_db.write().unwrap();
182 if default.as_ref() == Some(&key) {
183 *default = connections.keys().next().cloned();
185 }
186
187 Ok(())
188 }
189
190 pub fn get(&self, name: Option<&str>) -> Result<Arc<Mutex<Connection>>, SqliteToolError> {
192 let connections = self.connections.read().unwrap();
193
194 let key = match name {
195 Some(n) => {
196 if connections.contains_key(n) {
198 n.to_string()
199 } else {
200 connections
202 .keys()
203 .find(|k| {
204 k.ends_with(n) || Path::new(k).file_name().is_some_and(|f| f == n)
205 })
206 .cloned()
207 .ok_or_else(|| SqliteToolError::DatabaseNotFound(n.to_string()))?
208 }
209 }
210 None => {
211 let default = self.default_db.read().unwrap();
212 default.clone().ok_or(SqliteToolError::NoDefaultDatabase)?
213 }
214 };
215
216 connections
217 .get(&key)
218 .cloned()
219 .ok_or_else(|| SqliteToolError::DatabaseNotFound(key))
220 }
221
222 pub fn set_default(&self, name: &str) -> Result<(), SqliteToolError> {
224 let connections = self.connections.read().unwrap();
225
226 let key = if connections.contains_key(name) {
228 name.to_string()
229 } else {
230 connections
231 .keys()
232 .find(|k| k.ends_with(name) || Path::new(k).file_name().is_some_and(|f| f == name))
233 .cloned()
234 .ok_or_else(|| SqliteToolError::DatabaseNotFound(name.to_string()))?
235 };
236
237 let mut default = self.default_db.write().unwrap();
238 *default = Some(key);
239
240 Ok(())
241 }
242
243 pub fn get_default(&self) -> Option<String> {
245 self.default_db.read().unwrap().clone()
246 }
247
248 pub fn list_open(&self) -> Vec<String> {
250 self.connections.read().unwrap().keys().cloned().collect()
251 }
252
253 pub fn is_open(&self, name: &str) -> bool {
255 let connections = self.connections.read().unwrap();
256 connections.contains_key(name)
257 || connections
258 .keys()
259 .any(|k| k.ends_with(name) || Path::new(k).file_name().is_some_and(|f| f == name))
260 }
261
262 pub fn close_all(&self) {
264 let mut connections = self.connections.write().unwrap();
265 connections.clear();
266
267 let mut default = self.default_db.write().unwrap();
268 *default = None;
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    fn create_test_manager() -> DatabaseManager {
        DatabaseManager::new()
    }

    #[test]
    fn test_open_and_get() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let manager = create_test_manager();

        let key = manager.open(&db_path, true).unwrap();
        assert!(!key.is_empty());

        // The first database opened becomes the default.
        let conn = manager.get(None).unwrap();
        let guard = conn.lock().unwrap();
        guard
            .execute_batch("CREATE TABLE test (id INTEGER);")
            .unwrap();
    }

    #[test]
    fn test_open_existing_only() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("nonexistent.db");
        let manager = create_test_manager();

        // Without `create`, a missing file is rejected and nothing is registered.
        let result = manager.open(&db_path, false);
        assert!(matches!(result, Err(SqliteToolError::DatabaseDoesNotExist(_))));
        assert!(manager.list_open().is_empty());

        // Once the file exists it can be opened without `create`.
        std::fs::write(&db_path, "").unwrap();
        let key = manager.open(&db_path, false).unwrap();
        assert!(manager.is_open(&key));
    }

    #[test]
    fn test_open_creates_parent_directories() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("nested").join("dir").join("test.db");
        let manager = create_test_manager();

        manager.open(&db_path, true).unwrap();
        assert!(db_path.parent().unwrap().is_dir());
    }

    #[test]
    fn test_close_database() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let manager = create_test_manager();

        let key = manager.open(&db_path, true).unwrap();
        assert!(manager.is_open(&key));

        manager.close(&key).unwrap();
        assert!(!manager.is_open(&key));
    }

    #[test]
    fn test_get_by_file_name() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let manager = create_test_manager();

        manager.open(&db_path, true).unwrap();
        assert!(manager.is_open("test.db"));
        assert!(manager.get(Some("test.db")).is_ok());
    }

    #[test]
    fn test_get_unknown_database() {
        let manager = create_test_manager();
        let result = manager.get(Some("missing.db"));
        assert!(matches!(result, Err(SqliteToolError::DatabaseNotFound(_))));
    }

    #[test]
    fn test_multiple_databases() {
        let temp_dir = TempDir::new().unwrap();
        let db1_path = temp_dir.path().join("db1.db");
        let db2_path = temp_dir.path().join("db2.db");
        let manager = create_test_manager();

        let key1 = manager.open(&db1_path, true).unwrap();
        let key2 = manager.open(&db2_path, true).unwrap();

        assert_eq!(manager.get_default(), Some(key1.clone()));

        assert!(manager.get(Some(&key1)).is_ok());
        assert!(manager.get(Some(&key2)).is_ok());

        let open = manager.list_open();
        assert_eq!(open.len(), 2);
    }

    #[test]
    fn test_set_default() {
        let temp_dir = TempDir::new().unwrap();
        let db1_path = temp_dir.path().join("db1.db");
        let db2_path = temp_dir.path().join("db2.db");
        let manager = create_test_manager();

        let key1 = manager.open(&db1_path, true).unwrap();
        let key2 = manager.open(&db2_path, true).unwrap();

        assert_eq!(manager.get_default(), Some(key1.clone()));

        manager.set_default(&key2).unwrap();
        assert_eq!(manager.get_default(), Some(key2));
    }

    #[test]
    fn test_close_default_promotes_remaining() {
        let temp_dir = TempDir::new().unwrap();
        let db1_path = temp_dir.path().join("db1.db");
        let db2_path = temp_dir.path().join("db2.db");
        let manager = create_test_manager();

        let key1 = manager.open(&db1_path, true).unwrap();
        let key2 = manager.open(&db2_path, true).unwrap();
        assert_eq!(manager.get_default(), Some(key1.clone()));

        manager.close(&key1).unwrap();
        assert_eq!(manager.get_default(), Some(key2));
        assert!(manager.get(None).is_ok());
    }

    #[test]
    fn test_no_default_database() {
        let manager = create_test_manager();
        let result = manager.get(None);
        assert!(matches!(result, Err(SqliteToolError::NoDefaultDatabase)));
    }

    #[test]
    fn test_close_all() {
        let temp_dir = TempDir::new().unwrap();
        let db1_path = temp_dir.path().join("db1.db");
        let db2_path = temp_dir.path().join("db2.db");
        let manager = create_test_manager();

        manager.open(&db1_path, true).unwrap();
        manager.open(&db2_path, true).unwrap();

        assert_eq!(manager.list_open().len(), 2);

        manager.close_all();

        assert_eq!(manager.list_open().len(), 0);
        assert!(manager.get_default().is_none());
    }
}