use crate::{Error, GraphOperations, QueryBuilder, Result, Storage};
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use std::path::{Path, PathBuf};
use std::str::FromStr;
use tracing::{debug, info};
/// Handle to the application's SQLite database.
///
/// Wraps a [`SqlitePool`]. Cloning a `Database` clones the pool handle, so every
/// clone shares the same underlying connection pool.
#[derive(Clone, Debug)]
pub struct Database {
    pool: SqlitePool,
}
impl Database {
pub async fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
let path = Self::expand_path(path)?;
info!("Opening database at: {}", path.display());
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let options = SqliteConnectOptions::from_str(&format!("sqlite://{}", path.display()))?
.create_if_missing(true)
.foreign_keys(true) .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);
let pool = SqlitePoolOptions::new()
.max_connections(5)
.connect_with(options)
.await?;
let db = Self { pool };
db.migrate().await?;
Ok(db)
}
pub async fn open_default() -> Result<Self> {
let path = Self::default_path()?;
Self::open(path).await
}
pub fn default_path() -> Result<PathBuf> {
let home = std::env::var("HOME")
.map_err(|_| Error::Other("HOME environment variable not set".to_string()))?;
Ok(PathBuf::from(home).join(".niwa").join("graph.db"))
}
async fn migrate(&self) -> Result<()> {
info!("Running database migrations");
let migrations_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("migrations");
sqlx::migrate::Migrator::new(migrations_path)
.await
.map_err(|e| Error::Migration(e.to_string()))?
.run(&self.pool)
.await
.map_err(|e| Error::Migration(e.to_string()))?;
debug!("Migrations completed successfully");
Ok(())
}
pub fn storage(&self) -> Storage {
Storage::new(self.pool.clone())
}
pub fn query(&self) -> QueryBuilder {
QueryBuilder::new(self.pool.clone())
}
pub fn graph(&self) -> GraphOperations {
GraphOperations::new(self.pool.clone())
}
pub fn pool(&self) -> &SqlitePool {
&self.pool
}
pub async fn close(self) {
self.pool.close().await;
}
fn expand_path<P: AsRef<Path>>(path: P) -> Result<PathBuf> {
let path = path.as_ref();
let path_str = path
.to_str()
.ok_or_else(|| Error::Other(format!("Invalid path: {}", path.display())))?;
if let Some(stripped) = path_str.strip_prefix("~/") {
let home = std::env::var("HOME")
.map_err(|_| Error::Other("HOME environment variable not set".to_string()))?;
Ok(PathBuf::from(home).join(stripped))
} else {
Ok(path.to_path_buf())
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opening a fresh path creates the database file on disk.
    #[tokio::test]
    async fn test_open_database() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let db = Database::open(&db_path).await.unwrap();
        assert!(db_path.exists());
        db.close().await;
    }

    /// `open` creates any missing parent directories before connecting.
    #[tokio::test]
    async fn test_open_creates_missing_parent_dirs() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("nested").join("dir").join("test.db");
        let db = Database::open(&db_path).await.unwrap();
        assert!(db_path.exists());
        db.close().await;
    }

    /// Reopening an already-migrated database succeeds (migrations are not reapplied).
    #[tokio::test]
    async fn test_reopen_existing_database() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        Database::open(&db_path).await.unwrap().close().await;
        let db = Database::open(&db_path).await.unwrap();
        db.close().await;
    }

    /// Migrations create the `expertises` table.
    #[tokio::test]
    async fn test_migrations_run() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let db = Database::open(&db_path).await.unwrap();
        let result: (i64,) = sqlx::query_as(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='expertises'",
        )
        .fetch_one(db.pool())
        .await
        .unwrap();
        assert_eq!(result.0, 1, "expertises table should exist");
        db.close().await;
    }

    /// `~/...` expands to exactly `$HOME/...`. Other paths are returned untouched.
    #[test]
    fn test_expand_path() {
        let home = std::env::var_os("HOME").expect("HOME must be set to run this test");

        let expanded = Database::expand_path("~/test/path").unwrap();
        assert_eq!(
            expanded,
            std::path::PathBuf::from(home).join("test").join("path")
        );

        let absolute = Database::expand_path("/absolute/path").unwrap();
        assert_eq!(absolute, std::path::Path::new("/absolute/path"));

        let relative = Database::expand_path("relative/path").unwrap();
        assert_eq!(relative, std::path::Path::new("relative/path"));

        // `~user` forms are not expanded.
        let other_user = Database::expand_path("~user/path").unwrap();
        assert_eq!(other_user, std::path::Path::new("~user/path"));
    }
}