Skip to main content

niwa_core/
db.rs

1//! Database connection management
2
3use crate::{Error, GraphOperations, QueryBuilder, Result, Storage};
4use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
5use std::path::{Path, PathBuf};
6use std::str::FromStr;
7use tracing::{debug, info};
8
9/// Database handle
10///
11/// This is the main entry point for all database operations.
12/// It manages the SQLite connection pool and provides access to
13/// storage, query, and graph operations.
14#[derive(Clone)]
15pub struct Database {
16    pool: SqlitePool,
17}
18
19impl Database {
20    /// Open or create a database at the given path
21    ///
22    /// # Arguments
23    ///
24    /// * `path` - Path to the SQLite database file
25    ///
26    /// # Example
27    ///
28    /// ```no_run
29    /// use niwa_core::Database;
30    ///
31    /// #[tokio::main]
32    /// async fn main() -> anyhow::Result<()> {
33    ///     let db = Database::open("~/.niwa/graph.db").await?;
34    ///     Ok(())
35    /// }
36    /// ```
37    pub async fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
38        let path = Self::expand_path(path)?;
39        info!("Opening database at: {}", path.display());
40
41        // Ensure parent directory exists
42        if let Some(parent) = path.parent() {
43            std::fs::create_dir_all(parent)?;
44        }
45
46        // Configure SQLite connection
47        let options = SqliteConnectOptions::from_str(&format!("sqlite://{}", path.display()))?
48            .create_if_missing(true)
49            .foreign_keys(true) // Enable foreign key constraints
50            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal); // Use WAL mode for better concurrency
51
52        // Create connection pool
53        let pool = SqlitePoolOptions::new()
54            .max_connections(5)
55            .connect_with(options)
56            .await?;
57
58        let db = Self { pool };
59
60        // Run migrations
61        db.migrate().await?;
62
63        Ok(db)
64    }
65
66    /// Open database at the default location (~/.niwa/graph.db)
67    pub async fn open_default() -> Result<Self> {
68        let path = Self::default_path()?;
69        Self::open(path).await
70    }
71
72    /// Get the default database path
73    pub fn default_path() -> Result<PathBuf> {
74        let home = std::env::var("HOME")
75            .map_err(|_| Error::Other("HOME environment variable not set".to_string()))?;
76        Ok(PathBuf::from(home).join(".niwa").join("graph.db"))
77    }
78
79    /// Run database migrations
80    async fn migrate(&self) -> Result<()> {
81        info!("Running database migrations");
82
83        // Use runtime migration loading instead of compile-time macro
84        // This is essential for CLI/Desktop apps where migrations can be added
85        // after the binary is built
86        let migrations_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("migrations");
87
88        sqlx::migrate::Migrator::new(migrations_path)
89            .await
90            .map_err(|e| Error::Migration(e.to_string()))?
91            .run(&self.pool)
92            .await
93            .map_err(|e| Error::Migration(e.to_string()))?;
94
95        debug!("Migrations completed successfully");
96        Ok(())
97    }
98
99    /// Get a reference to the storage operations
100    pub fn storage(&self) -> Storage {
101        Storage::new(self.pool.clone())
102    }
103
104    /// Get a query builder
105    pub fn query(&self) -> QueryBuilder {
106        QueryBuilder::new(self.pool.clone())
107    }
108
109    /// Get a reference to the graph operations
110    pub fn graph(&self) -> GraphOperations {
111        GraphOperations::new(self.pool.clone())
112    }
113
114    /// Get the underlying pool (for advanced usage)
115    pub fn pool(&self) -> &SqlitePool {
116        &self.pool
117    }
118
119    /// Close the database connection
120    pub async fn close(self) {
121        self.pool.close().await;
122    }
123
124    /// Expand tilde in path
125    fn expand_path<P: AsRef<Path>>(path: P) -> Result<PathBuf> {
126        let path = path.as_ref();
127        let path_str = path
128            .to_str()
129            .ok_or_else(|| Error::Other(format!("Invalid path: {}", path.display())))?;
130
131        if let Some(stripped) = path_str.strip_prefix("~/") {
132            let home = std::env::var("HOME")
133                .map_err(|_| Error::Other("HOME environment variable not set".to_string()))?;
134            Ok(PathBuf::from(home).join(stripped))
135        } else {
136            Ok(path.to_path_buf())
137        }
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opening a fresh path creates the database file.
    #[tokio::test]
    async fn test_open_database() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");

        let db = Database::open(&db_path).await.unwrap();
        assert!(db_path.exists());

        db.close().await;
    }

    /// Missing parent directories are created before connecting.
    #[tokio::test]
    async fn test_open_creates_parent_directories() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("nested").join("dir").join("test.db");

        let db = Database::open(&db_path).await.unwrap();
        assert!(db_path.exists());

        db.close().await;
    }

    /// Pooled connections run in WAL mode with foreign keys enforced.
    #[tokio::test]
    async fn test_connection_pragmas() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");

        let db = Database::open(&db_path).await.unwrap();

        let (journal_mode,): (String,) = sqlx::query_as("PRAGMA journal_mode")
            .fetch_one(db.pool())
            .await
            .unwrap();
        assert_eq!(journal_mode.to_lowercase(), "wal");

        let (foreign_keys,): (i64,) = sqlx::query_as("PRAGMA foreign_keys")
            .fetch_one(db.pool())
            .await
            .unwrap();
        assert_eq!(foreign_keys, 1, "foreign key enforcement should be enabled");

        db.close().await;
    }

    /// Migrations are applied on open.
    #[tokio::test]
    async fn test_migrations_run() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");

        let db = Database::open(&db_path).await.unwrap();

        // Verify tables exist
        let result: (i64,) = sqlx::query_as(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='expertises'",
        )
        .fetch_one(db.pool())
        .await
        .unwrap();

        assert_eq!(result.0, 1, "expertises table should exist");

        db.close().await;
    }

    /// `~/` is replaced by `$HOME`; other paths are left untouched.
    #[test]
    fn test_expand_path() {
        let home = std::env::var("HOME").expect("HOME must be set to run this test");

        let expanded = Database::expand_path("~/test/path").unwrap();
        assert_eq!(expanded, PathBuf::from(home).join("test/path"));

        let normal = Database::expand_path("/absolute/path").unwrap();
        assert_eq!(normal, PathBuf::from("/absolute/path"));

        // `~user` forms are not expanded.
        let other_user = Database::expand_path("~other/path").unwrap();
        assert_eq!(other_user, PathBuf::from("~other/path"));
    }
}