Skip to main content

capsule_core/config/
database.rs

1use rusqlite::{Connection, Error as SqliteError, Row};
2use std::fmt;
3use std::sync::{Arc, Mutex};
4
/// Errors produced by the database layer.
///
/// Every variant carries the textual message of the underlying failure.
#[derive(Debug)]
pub enum DatabaseError {
    /// A failure reported by SQLite (built from `rusqlite::Error`).
    SqliteError(String),
    /// A file-system failure (built from `std::io::Error`; `serde_json::Error`
    /// is also mapped to this variant).
    FsError(String),
    /// A query or identifier was rejected, e.g. by table-name validation.
    InvalidQuery(String),
    /// The connection mutex could not be acquired because it was poisoned.
    LockError(String),
}
12
13impl fmt::Display for DatabaseError {
14    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15        match self {
16            DatabaseError::SqliteError(msg) => write!(f, "SQLite error > {}", msg),
17            DatabaseError::FsError(msg) => write!(f, "File system error > {}", msg),
18            DatabaseError::InvalidQuery(msg) => write!(f, "Invalid query > {}", msg),
19            DatabaseError::LockError(msg) => write!(f, "Lock error > {}", msg),
20        }
21    }
22}
23
// Marker impl so `DatabaseError` can be used as a standard error type
// (e.g. behind `Box<dyn std::error::Error>`); it relies on the `Debug` and
// `Display` implementations and overrides nothing.
impl std::error::Error for DatabaseError {}
25
26impl From<SqliteError> for DatabaseError {
27    fn from(err: SqliteError) -> Self {
28        DatabaseError::SqliteError(err.to_string())
29    }
30}
31
32impl From<std::sync::PoisonError<std::sync::MutexGuard<'_, Connection>>> for DatabaseError {
33    fn from(err: std::sync::PoisonError<std::sync::MutexGuard<'_, Connection>>) -> Self {
34        DatabaseError::LockError(err.to_string())
35    }
36}
37
38impl From<serde_json::Error> for DatabaseError {
39    fn from(err: serde_json::Error) -> Self {
40        DatabaseError::FsError(err.to_string())
41    }
42}
43
44impl From<std::io::Error> for DatabaseError {
45    fn from(err: std::io::Error) -> Self {
46        DatabaseError::FsError(err.to_string())
47    }
48}
49
/// Handle to a single SQLite connection.
///
/// The connection lives behind `Arc<Mutex<_>>`, so cloning a `Database`
/// clones the `Arc` and every clone refers to the same connection.
#[derive(Clone)]
pub struct Database {
    /// The shared connection; lock the mutex to use it directly.
    pub conn: Arc<Mutex<Connection>>,
}
54
55impl Database {
56    pub fn new(path: Option<&str>, database_name: &str) -> Result<Self, DatabaseError> {
57        let conn = match path {
58            Some(path) => {
59                let database_path = &format!("{}/{}", path, database_name);
60
61                std::fs::create_dir_all(path)?;
62
63                Connection::open(database_path)?
64            }
65            None => Connection::open(":memory:")?,
66        };
67
68        conn.execute_batch(
69            "
70            PRAGMA journal_mode = WAL;
71            PRAGMA synchronous = NORMAL;
72            PRAGMA cache_size = -64000;
73            PRAGMA foreign_keys = ON;
74            PRAGMA temp_store = MEMORY;
75            PRAGMA mmap_size = 30000000000;
76        ",
77        )?;
78
79        let db = Self {
80            conn: Arc::new(Mutex::new(conn)),
81        };
82
83        Ok(db)
84    }
85
86    pub fn create_table(
87        &self,
88        table: &str,
89        columns: &[&str],
90        constraints: &[&str],
91    ) -> Result<(), DatabaseError> {
92        self.validate_table_name(table)?;
93
94        let conn = self.conn.lock()?;
95
96        let mut all_columns = vec!["id TEXT PRIMARY KEY"];
97
98        all_columns.extend_from_slice(columns);
99        all_columns.push("created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP");
100        all_columns.push("updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP");
101
102        let mut definitions = all_columns
103            .iter()
104            .map(|s| s.to_string())
105            .collect::<Vec<_>>();
106
107        definitions.extend(constraints.iter().map(|s| s.to_string()));
108
109        let sql = format!(
110            "CREATE TABLE IF NOT EXISTS {} (\n                {}\n            )",
111            table,
112            definitions.join(",\n                ")
113        );
114
115        conn.execute(&sql, [])?;
116
117        Ok(())
118    }
119
120    pub fn execute<P>(&self, query: &str, params: P) -> Result<usize, DatabaseError>
121    where
122        P: rusqlite::Params,
123    {
124        let conn = self.conn.lock()?;
125
126        let result = conn.execute(query, params)?;
127
128        Ok(result)
129    }
130
131    pub fn query<P, F, T>(
132        &self,
133        query: &str,
134        params: P,
135        mut mapper: F,
136    ) -> Result<Vec<T>, DatabaseError>
137    where
138        P: rusqlite::Params,
139        F: FnMut(&Row) -> Result<T, DatabaseError>,
140    {
141        let conn = self.conn.lock()?;
142
143        let mut stmt = conn.prepare(query)?;
144
145        let rows = stmt.query_map(params, |row| {
146            mapper(row).map_err(|_| SqliteError::InvalidQuery)
147        })?;
148
149        let mut results = Vec::new();
150        for row in rows {
151            let value = row?;
152            results.push(value);
153        }
154
155        Ok(results)
156    }
157
158    pub fn table_exists(&self, table: &str) -> Result<bool, DatabaseError> {
159        let conn = self.conn.lock()?;
160
161        let count: i64 = conn.query_row(
162            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?",
163            [table],
164            |row| row.get(0),
165        )?;
166
167        Ok(count > 0)
168    }
169
170    pub fn validate_table_name(&self, table: &str) -> Result<(), DatabaseError> {
171        if table.is_empty() || table.len() > 64 {
172            return Err(DatabaseError::InvalidQuery(
173                "Table name must be between 1-64 characters".to_string(),
174            ));
175        }
176
177        if !table.chars().all(|c| c.is_alphanumeric() || c == '_') {
178            return Err(DatabaseError::InvalidQuery(format!(
179                "Table name can only contain alphanumeric characters and underscores for {}",
180                table
181            )));
182        }
183
184        if table
185            .chars()
186            .next()
187            .expect("Table name is empty")
188            .is_numeric()
189        {
190            return Err(DatabaseError::InvalidQuery(
191                "Table name cannot start with a number".to_string(),
192            ));
193        }
194
195        Ok(())
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// SQL shared by the tests that need a plain two-column table without
    /// going through `Database::create_table`.
    const CREATE_TEST_TABLE: &str =
        "CREATE TABLE IF NOT EXISTS test_table (name TEXT NOT NULL, description TEXT NOT NULL)";

    /// Opens a fresh in-memory database; the file name is unused in that mode.
    fn in_memory_db() -> Database {
        Database::new(None, "trace.db-wal").expect("Failed to create database")
    }

    /// Creates `test_table` directly on the connection and inserts one row per
    /// `(name, description)` pair.
    fn seed_test_table(db: &Database, rows: &[(&str, &str)]) {
        let conn = db.conn.lock().unwrap();

        conn.execute(CREATE_TEST_TABLE, [])
            .expect("Failed to create test table");

        for (name, description) in rows {
            conn.execute(
                "INSERT INTO test_table (name, description) VALUES (?, ?)",
                [name, description],
            )
            .expect("Failed to insert test row");
        }
    }

    mod creation {
        use super::*;

        #[test]
        fn test_new_database_creates_in_memory_database() {
            let db = Database::new(None, "trace.db-wal");
            assert!(db.is_ok(), "Failed to create in-memory database");
        }

        #[test]
        fn test_create_table() {
            let db = in_memory_db();
            db.create_table(
                "test_table",
                &["name TEXT NOT NULL", "path TEXT NOT NULL"],
                &[],
            )
            .expect("Failed to create test table");

            let conn = db.conn.lock().unwrap();

            let mut test_table_stmt = conn
                .prepare("SELECT name FROM sqlite_master WHERE type='table' AND name='test_table'")
                .expect("Failed to prepare query");

            let test_table_exists: bool = test_table_stmt
                .exists([])
                .expect("Failed to check if table exists");

            assert!(test_table_exists, "test table was not created");
        }

        #[test]
        fn test_create_table_with_constraints() {
            let db = in_memory_db();
            db.create_table(
                "test_table",
                &["name TEXT NOT NULL", "path TEXT NOT NULL"],
                &["UNIQUE (path)"],
            )
            .expect("Failed to create test table");

            let conn = db.conn.lock().unwrap();

            conn.execute(
                "INSERT INTO test_table (name, path) VALUES (?, ?)",
                ["test", "test"],
            )
            .expect("Failed to insert test");

            // The second insert reuses the same `path` and must be rejected
            // by the UNIQUE constraint.
            let duplicate = conn.execute(
                "INSERT INTO test_table (name, path) VALUES (?, ?)",
                ["test", "test"],
            );
            assert!(
                duplicate.is_err(),
                "Second insert should have failed the UNIQUE (path) constraint"
            );
        }
    }

    mod execution {
        use super::*;

        #[test]
        fn test_execute() {
            let db = in_memory_db();
            seed_test_table(&db, &[]);

            let changed = db
                .execute(
                    "INSERT INTO test_table (name, description) VALUES (?, ?)",
                    ["test", "test"],
                )
                .expect("Failed to execute insert");

            assert_eq!(changed, 1, "Insert should report exactly one changed row");
        }
    }

    mod queries {
        use super::*;

        #[test]
        fn test_query() {
            let db = in_memory_db();
            seed_test_table(&db, &[("test1", "desc1")]);

            let result = db
                .query("SELECT * FROM test_table", [], |row| {
                    Ok((row.get::<_, String>(0)?,))
                })
                .expect("Failed to query test");

            assert!(!result.is_empty(), "Query returned no rows");
        }

        #[test]
        fn test_query_multiple_rows() {
            let db = in_memory_db();
            seed_test_table(
                &db,
                &[("test1", "desc1"), ("test2", "desc2"), ("test3", "desc3")],
            );

            let result = db
                .query("SELECT * FROM test_table", [], |row| {
                    Ok((row.get::<_, String>(0)?,))
                })
                .expect("Failed to query test");

            assert_eq!(result.len(), 3, "Query should return every inserted row");
        }
    }

    mod utilities {
        use super::*;

        #[test]
        fn test_table_exists() {
            let db = in_memory_db();
            seed_test_table(&db, &[]);

            let table_exists: bool = db
                .table_exists("test_table")
                .expect("Failed to check if table exists");

            assert!(table_exists, "test table was not created");
        }

        #[test]
        fn test_table_exists_returns_false_for_missing_table() {
            let db = in_memory_db();

            let table_exists = db
                .table_exists("missing_table")
                .expect("Failed to check if table exists");

            assert!(!table_exists, "A table that was never created was reported");
        }

        #[test]
        fn test_validate_table_name() {
            let db = in_memory_db();

            let simple_name = db.validate_table_name("test_table");
            assert!(simple_name.is_ok(), "A plain snake_case name should be valid");

            let empty_name = db.validate_table_name("");
            assert!(empty_name.is_err(), "An empty name should be rejected");

            let starts_with_number = db.validate_table_name("123test_table");
            assert!(
                starts_with_number.is_err(),
                "A name starting with a digit should be rejected"
            );

            let special_characters = db.validate_table_name("test; dede");
            assert!(
                special_characters.is_err(),
                "A name containing ';' or spaces should be rejected"
            );

            let long_name = db.validate_table_name(
                "A_long_naaaaaaaaaaaaaaammmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmme",
            );
            assert!(
                long_name.is_err(),
                "A name longer than 64 bytes should be rejected"
            );
        }
    }
}