use crate::sqlite::error::SqliteToolError;
use lazy_static::lazy_static;
use mixtape_core::ToolError;
use rusqlite::Connection;
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock};
lazy_static! {
    /// Process-wide connection registry; `with_connection` resolves databases through it.
    pub static ref DATABASE_MANAGER: DatabaseManager = DatabaseManager::default();
}
pub async fn with_connection<T, F>(db_path: Option<String>, f: F) -> Result<T, ToolError>
where
T: Send + 'static,
F: FnOnce(&Connection) -> Result<T, SqliteToolError> + Send + 'static,
{
tokio::task::spawn_blocking(move || {
let conn = DATABASE_MANAGER.get(db_path.as_deref())?;
let conn = conn.lock().unwrap();
f(&conn)
})
.await
.map_err(|e| ToolError::Custom(format!("Task join error: {}", e)))?
.map_err(|e| e.into())
}
/// Registry of open SQLite connections, keyed by normalized filesystem path, together
/// with the key of the database used when a caller does not name one.
pub struct DatabaseManager {
    /// Key of the database returned by `get(None)`; `None` until one is chosen.
    default_db: RwLock<Option<String>>,
    /// Open connections. Each sits behind its own mutex and is handed out as a shared
    /// `Arc` by `get`.
    connections: RwLock<HashMap<String, Arc<Mutex<Connection>>>>,
}
impl Default for DatabaseManager {
fn default() -> Self {
Self::new()
}
}
impl DatabaseManager {
    /// Creates a manager with no open connections and no default database.
    pub fn new() -> Self {
        Self {
            connections: RwLock::new(HashMap::new()),
            default_db: RwLock::new(None),
        }
    }

    /// Returns the registry key for `path`: its canonical form when `canonicalize`
    /// succeeds (i.e. the path exists), otherwise the path verbatim.
    fn normalize_path(path: &Path) -> String {
        path.canonicalize()
            .unwrap_or_else(|_| path.to_path_buf())
            .to_string_lossy()
            .to_string()
    }

    /// Returns true if the registered `key` is named by `name`, i.e. `name` equals the
    /// trailing whole path component(s) of `key` (`"app.db"` or `"data/app.db"` match
    /// `/srv/data/app.db`, but `"p.db"` does not). An empty name never matches.
    fn key_matches(key: &str, name: &str) -> bool {
        !name.is_empty() && Path::new(key).ends_with(name)
    }

    /// Resolves a user-supplied database `name` to a registered key.
    ///
    /// An exact key match wins; otherwise `name` must match exactly one key by trailing
    /// path components (see [`DatabaseManager::key_matches`]).
    ///
    /// # Errors
    /// [`SqliteToolError::DatabaseNotFound`] if nothing matches, or if several keys match
    /// (e.g. `/a/app.db` and `/b/app.db` for `"app.db"`) — picking one would depend on
    /// the map's arbitrary iteration order.
    fn resolve_key(
        connections: &HashMap<String, Arc<Mutex<Connection>>>,
        name: &str,
    ) -> Result<String, SqliteToolError> {
        if connections.contains_key(name) {
            return Ok(name.to_string());
        }
        let mut candidates = connections
            .keys()
            .filter(|k| Self::key_matches(k.as_str(), name));
        match (candidates.next(), candidates.next()) {
            (Some(key), None) => Ok(key.clone()),
            _ => Err(SqliteToolError::DatabaseNotFound(name.to_string())),
        }
    }

    /// Opens (or reuses) a connection to the SQLite database at `path` and returns its
    /// registry key.
    ///
    /// If a connection for the same normalized path is already registered it is reused.
    /// With `create == true`, missing parent directories are created before the file is
    /// opened; with `create == false` a missing file is an error. New connections get
    /// `PRAGMA foreign_keys = ON`. The first database registered becomes the default.
    ///
    /// # Errors
    /// - [`SqliteToolError::DatabaseDoesNotExist`] if `create` is false and `path` is missing.
    /// - An I/O error (converted via `?`) if parent directories cannot be created.
    /// - [`SqliteToolError::ConnectionFailed`] if `Connection::open` fails.
    /// - A SQLite error (converted via `?`) if enabling foreign keys fails.
    ///
    /// # Panics
    /// Panics if an internal lock is poisoned.
    pub fn open(&self, path: &Path, create: bool) -> Result<String, SqliteToolError> {
        let key = Self::normalize_path(path);
        {
            let connections = self.connections.read().unwrap();
            if connections.contains_key(&key) {
                self.set_default_if_first(&key);
                return Ok(key);
            }
        }

        if !create && !path.exists() {
            return Err(SqliteToolError::DatabaseDoesNotExist(path.to_path_buf()));
        }
        if create {
            if let Some(parent) = path.parent() {
                if !parent.exists() {
                    std::fs::create_dir_all(parent)?;
                }
            }
        }

        let conn = Connection::open(path).map_err(|e| SqliteToolError::ConnectionFailed {
            path: path.to_path_buf(),
            message: e.to_string(),
        })?;
        conn.execute_batch("PRAGMA foreign_keys = ON;")?;

        // If opening just created the file, `canonicalize` can now succeed where it failed
        // above. Re-derive the key so a later `open` of the same path finds this entry
        // instead of registering a second connection under the canonical key.
        let key = Self::normalize_path(path);
        {
            let mut connections = self.connections.write().unwrap();
            // Keep an entry registered in the meantime (e.g. by a concurrent `open` of
            // the same file) rather than replacing its connection; ours is then dropped.
            connections
                .entry(key.clone())
                .or_insert_with(|| Arc::new(Mutex::new(conn)));
        }
        self.set_default_if_first(&key);
        Ok(key)
    }

    /// Makes `key` the default database if no default is set yet.
    fn set_default_if_first(&self, key: &str) {
        let mut default = self.default_db.write().unwrap();
        if default.is_none() {
            *default = Some(key.to_string());
        }
    }

    /// Removes the database identified by `name` (resolved as in [`DatabaseManager::get`])
    /// from the registry.
    ///
    /// Callers still holding an `Arc` obtained from `get` keep that connection alive until
    /// they drop it. If the closed database was the default, an arbitrary remaining
    /// database (or none) becomes the default.
    ///
    /// # Errors
    /// [`SqliteToolError::DatabaseNotFound`] if `name` resolves to no single open database.
    pub fn close(&self, name: &str) -> Result<(), SqliteToolError> {
        // Lock order everywhere in this impl: `connections` before `default_db`.
        let mut connections = self.connections.write().unwrap();
        let key = Self::resolve_key(&connections, name)?;
        connections.remove(&key);

        let mut default = self.default_db.write().unwrap();
        if default.as_ref() == Some(&key) {
            *default = connections.keys().next().cloned();
        }
        Ok(())
    }

    /// Returns a shared handle to an open connection.
    ///
    /// `Some(name)` is matched first against registered keys exactly, then against
    /// exactly one key by trailing whole path components (so a bare file name works when
    /// it is unambiguous). `None` selects the default database.
    ///
    /// # Errors
    /// - [`SqliteToolError::DatabaseNotFound`] if `name` matches no key or several keys.
    /// - [`SqliteToolError::NoDefaultDatabase`] if `name` is `None` and no default is set.
    pub fn get(&self, name: Option<&str>) -> Result<Arc<Mutex<Connection>>, SqliteToolError> {
        let connections = self.connections.read().unwrap();
        let key = match name {
            Some(n) => Self::resolve_key(&connections, n)?,
            None => self
                .default_db
                .read()
                .unwrap()
                .clone()
                .ok_or(SqliteToolError::NoDefaultDatabase)?,
        };
        connections
            .get(&key)
            .cloned()
            .ok_or_else(|| SqliteToolError::DatabaseNotFound(key))
    }

    /// Makes the database identified by `name` (resolved as in [`DatabaseManager::get`])
    /// the default.
    ///
    /// # Errors
    /// [`SqliteToolError::DatabaseNotFound`] if `name` resolves to no single open database.
    pub fn set_default(&self, name: &str) -> Result<(), SqliteToolError> {
        let connections = self.connections.read().unwrap();
        let key = Self::resolve_key(&connections, name)?;
        let mut default = self.default_db.write().unwrap();
        *default = Some(key);
        Ok(())
    }

    /// Returns the key of the current default database, if any.
    pub fn get_default(&self) -> Option<String> {
        self.default_db.read().unwrap().clone()
    }

    /// Returns the keys of all open databases, in no particular order.
    pub fn list_open(&self) -> Vec<String> {
        self.connections.read().unwrap().keys().cloned().collect()
    }

    /// Returns true if `name` equals an open key or matches at least one open key by
    /// trailing whole path components.
    pub fn is_open(&self, name: &str) -> bool {
        let connections = self.connections.read().unwrap();
        connections.contains_key(name)
            || connections
                .keys()
                .any(|k| Self::key_matches(k.as_str(), name))
    }

    /// Drops every registered connection and clears the default database.
    pub fn close_all(&self) {
        let mut connections = self.connections.write().unwrap();
        connections.clear();
        let mut default = self.default_db.write().unwrap();
        *default = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    fn create_test_manager() -> DatabaseManager {
        DatabaseManager::new()
    }

    #[test]
    fn test_open_and_get() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let manager = create_test_manager();
        let key = manager.open(&db_path, true).unwrap();
        assert!(!key.is_empty());
        let conn = manager.get(None).unwrap();
        let guard = conn.lock().unwrap();
        guard
            .execute_batch("CREATE TABLE test (id INTEGER);")
            .unwrap();
    }

    #[test]
    fn test_open_existing_only() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("nonexistent.db");
        let manager = create_test_manager();

        // Without `create`, a missing file is rejected and must not be created.
        let result = manager.open(&db_path, false);
        assert!(matches!(
            result,
            Err(SqliteToolError::DatabaseDoesNotExist(_))
        ));
        assert!(!db_path.exists());

        // Once the file exists, opening without `create` succeeds.
        std::fs::write(&db_path, "").unwrap();
        let key = manager.open(&db_path, false).unwrap();
        assert!(manager.is_open(&key));
    }

    #[test]
    fn test_get_by_file_name() {
        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager();
        manager
            .open(&temp_dir.path().join("alpha.db"), true)
            .unwrap();
        assert!(manager.get(Some("alpha.db")).is_ok());
        assert!(manager.is_open("alpha.db"));
        assert!(matches!(
            manager.get(Some("beta.db")),
            Err(SqliteToolError::DatabaseNotFound(_))
        ));
    }

    #[test]
    fn test_close_database() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let manager = create_test_manager();
        let key = manager.open(&db_path, true).unwrap();
        assert!(manager.is_open(&key));
        manager.close(&key).unwrap();
        assert!(!manager.is_open(&key));
    }

    #[test]
    fn test_multiple_databases() {
        let temp_dir = TempDir::new().unwrap();
        let db1_path = temp_dir.path().join("db1.db");
        let db2_path = temp_dir.path().join("db2.db");
        let manager = create_test_manager();
        let key1 = manager.open(&db1_path, true).unwrap();
        let key2 = manager.open(&db2_path, true).unwrap();
        assert_eq!(manager.get_default(), Some(key1.clone()));
        assert!(manager.get(Some(&key1)).is_ok());
        assert!(manager.get(Some(&key2)).is_ok());
        let open = manager.list_open();
        assert_eq!(open.len(), 2);
    }

    #[test]
    fn test_set_default() {
        let temp_dir = TempDir::new().unwrap();
        let db1_path = temp_dir.path().join("db1.db");
        let db2_path = temp_dir.path().join("db2.db");
        let manager = create_test_manager();
        let key1 = manager.open(&db1_path, true).unwrap();
        let key2 = manager.open(&db2_path, true).unwrap();
        assert_eq!(manager.get_default(), Some(key1.clone()));
        manager.set_default(&key2).unwrap();
        assert_eq!(manager.get_default(), Some(key2));
    }

    #[test]
    fn test_no_default_database() {
        let manager = create_test_manager();
        let result = manager.get(None);
        assert!(matches!(result, Err(SqliteToolError::NoDefaultDatabase)));
    }

    #[test]
    fn test_close_all() {
        let temp_dir = TempDir::new().unwrap();
        let db1_path = temp_dir.path().join("db1.db");
        let db2_path = temp_dir.path().join("db2.db");
        let manager = create_test_manager();
        manager.open(&db1_path, true).unwrap();
        manager.open(&db2_path, true).unwrap();
        assert_eq!(manager.list_open().len(), 2);
        manager.close_all();
        assert_eq!(manager.list_open().len(), 0);
        assert!(manager.get_default().is_none());
    }
}