1use crate::{Error, GraphOperations, QueryBuilder, Result, Storage};
4use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
5use std::path::{Path, PathBuf};
6use std::str::FromStr;
7use tracing::{debug, info};
8
9#[derive(Clone)]
15pub struct Database {
16 pool: SqlitePool,
17}
18
19impl Database {
20 pub async fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
38 let path = Self::expand_path(path)?;
39 info!("Opening database at: {}", path.display());
40
41 if let Some(parent) = path.parent() {
43 std::fs::create_dir_all(parent)?;
44 }
45
46 let options = SqliteConnectOptions::from_str(&format!("sqlite://{}", path.display()))?
48 .create_if_missing(true)
49 .foreign_keys(true) .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal); let pool = SqlitePoolOptions::new()
54 .max_connections(5)
55 .connect_with(options)
56 .await?;
57
58 let db = Self { pool };
59
60 db.migrate().await?;
62
63 Ok(db)
64 }
65
66 pub async fn open_default() -> Result<Self> {
68 let path = Self::default_path()?;
69 Self::open(path).await
70 }
71
72 pub fn default_path() -> Result<PathBuf> {
74 let home = std::env::var("HOME")
75 .map_err(|_| Error::Other("HOME environment variable not set".to_string()))?;
76 Ok(PathBuf::from(home).join(".niwa").join("graph.db"))
77 }
78
79 async fn migrate(&self) -> Result<()> {
81 info!("Running database migrations");
82
83 let migrations_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("migrations");
87
88 sqlx::migrate::Migrator::new(migrations_path)
89 .await
90 .map_err(|e| Error::Migration(e.to_string()))?
91 .run(&self.pool)
92 .await
93 .map_err(|e| Error::Migration(e.to_string()))?;
94
95 debug!("Migrations completed successfully");
96 Ok(())
97 }
98
99 pub fn storage(&self) -> Storage {
101 Storage::new(self.pool.clone())
102 }
103
104 pub fn query(&self) -> QueryBuilder {
106 QueryBuilder::new(self.pool.clone())
107 }
108
109 pub fn graph(&self) -> GraphOperations {
111 GraphOperations::new(self.pool.clone())
112 }
113
114 pub fn pool(&self) -> &SqlitePool {
116 &self.pool
117 }
118
119 pub async fn close(self) {
121 self.pool.close().await;
122 }
123
124 fn expand_path<P: AsRef<Path>>(path: P) -> Result<PathBuf> {
126 let path = path.as_ref();
127 let path_str = path
128 .to_str()
129 .ok_or_else(|| Error::Other(format!("Invalid path: {}", path.display())))?;
130
131 if let Some(stripped) = path_str.strip_prefix("~/") {
132 let home = std::env::var("HOME")
133 .map_err(|_| Error::Other("HOME environment variable not set".to_string()))?;
134 Ok(PathBuf::from(home).join(stripped))
135 } else {
136 Ok(path.to_path_buf())
137 }
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[tokio::test]
    async fn test_open_database() {
        let temp_dir = TempDir::new().unwrap();
        // Nested, not-yet-existing directories exercise parent-directory creation.
        let db_path = temp_dir.path().join("nested").join("dir").join("test.db");

        let db = Database::open(&db_path).await.unwrap();
        assert!(db_path.exists());

        db.close().await;
    }

    #[tokio::test]
    async fn test_migrations_run() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");

        let db = Database::open(&db_path).await.unwrap();

        let result: (i64,) = sqlx::query_as(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='expertises'",
        )
        .fetch_one(db.pool())
        .await
        .unwrap();

        assert_eq!(result.0, 1, "expertises table should exist");

        db.close().await;
    }

    #[test]
    fn test_expand_path() {
        let home = std::env::var_os("HOME").expect("HOME must be set for this test");

        // A leading `~/` expands to exactly `$HOME/<rest>`.
        let expanded = Database::expand_path("~/test/path").unwrap();
        assert_eq!(expanded, PathBuf::from(home).join("test").join("path"));

        // Absolute paths pass through untouched.
        let normal = Database::expand_path("/absolute/path").unwrap();
        assert_eq!(normal, PathBuf::from("/absolute/path"));

        // `~user` is not tilde-expanded.
        let not_tilde = Database::expand_path("~user/path").unwrap();
        assert_eq!(not_tilde, PathBuf::from("~user/path"));
    }
}