1use rusqlite::{Connection, Error as SqliteError, Row};
2use std::fmt;
3use std::sync::{Arc, Mutex};
4
/// Errors produced by [`Database`] operations.
///
/// Every variant carries an already-rendered, human-readable message; the
/// original error value is converted to a `String` when the error is created.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DatabaseError {
    /// An error reported by SQLite / `rusqlite`.
    SqliteError(String),
    /// A file-system error. In this module `serde_json` errors are mapped here too.
    FsError(String),
    /// A query or identifier rejected before reaching SQLite (e.g. a bad table name).
    InvalidQuery(String),
    /// A poisoned lock around the shared connection.
    LockError(String),
}

impl fmt::Display for DatabaseError {
    /// Renders the error as `"<category> > <message>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (label, msg) = match self {
            DatabaseError::SqliteError(msg) => ("SQLite error", msg),
            DatabaseError::FsError(msg) => ("File system error", msg),
            DatabaseError::InvalidQuery(msg) => ("Invalid query", msg),
            DatabaseError::LockError(msg) => ("Lock error", msg),
        };
        write!(f, "{} > {}", label, msg)
    }
}

impl std::error::Error for DatabaseError {}
25
26impl From<SqliteError> for DatabaseError {
27 fn from(err: SqliteError) -> Self {
28 DatabaseError::SqliteError(err.to_string())
29 }
30}
31
32impl From<std::sync::PoisonError<std::sync::MutexGuard<'_, Connection>>> for DatabaseError {
33 fn from(err: std::sync::PoisonError<std::sync::MutexGuard<'_, Connection>>) -> Self {
34 DatabaseError::LockError(err.to_string())
35 }
36}
37
38impl From<serde_json::Error> for DatabaseError {
39 fn from(err: serde_json::Error) -> Self {
40 DatabaseError::FsError(err.to_string())
41 }
42}
43
44impl From<std::io::Error> for DatabaseError {
45 fn from(err: std::io::Error) -> Self {
46 DatabaseError::FsError(err.to_string())
47 }
48}
49
50#[derive(Clone)]
51pub struct Database {
52 pub conn: Arc<Mutex<Connection>>,
53}
54
55impl Database {
56 pub fn new(path: Option<&str>, database_name: &str) -> Result<Self, DatabaseError> {
57 let conn = match path {
58 Some(path) => {
59 let database_path = &format!("{}/{}", path, database_name);
60
61 std::fs::create_dir_all(path)?;
62
63 Connection::open(database_path)?
64 }
65 None => Connection::open(":memory:")?,
66 };
67
68 conn.execute_batch(
69 "
70 PRAGMA journal_mode = WAL;
71 PRAGMA synchronous = NORMAL;
72 PRAGMA cache_size = -64000;
73 PRAGMA foreign_keys = ON;
74 PRAGMA temp_store = MEMORY;
75 PRAGMA mmap_size = 30000000000;
76 ",
77 )?;
78
79 let db = Self {
80 conn: Arc::new(Mutex::new(conn)),
81 };
82
83 Ok(db)
84 }
85
86 pub fn create_table(
87 &self,
88 table: &str,
89 columns: &[&str],
90 constraints: &[&str],
91 ) -> Result<(), DatabaseError> {
92 self.validate_table_name(table)?;
93
94 let conn = self.conn.lock()?;
95
96 let mut all_columns = vec!["id TEXT PRIMARY KEY"];
97
98 all_columns.extend_from_slice(columns);
99 all_columns.push("created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP");
100 all_columns.push("updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP");
101
102 let mut definitions = all_columns
103 .iter()
104 .map(|s| s.to_string())
105 .collect::<Vec<_>>();
106
107 definitions.extend(constraints.iter().map(|s| s.to_string()));
108
109 let sql = format!(
110 "CREATE TABLE IF NOT EXISTS {} (\n {}\n )",
111 table,
112 definitions.join(",\n ")
113 );
114
115 conn.execute(&sql, [])?;
116
117 Ok(())
118 }
119
120 pub fn execute<P>(&self, query: &str, params: P) -> Result<usize, DatabaseError>
121 where
122 P: rusqlite::Params,
123 {
124 let conn = self.conn.lock()?;
125
126 let result = conn.execute(query, params)?;
127
128 Ok(result)
129 }
130
131 pub fn query<P, F, T>(
132 &self,
133 query: &str,
134 params: P,
135 mut mapper: F,
136 ) -> Result<Vec<T>, DatabaseError>
137 where
138 P: rusqlite::Params,
139 F: FnMut(&Row) -> Result<T, DatabaseError>,
140 {
141 let conn = self.conn.lock()?;
142
143 let mut stmt = conn.prepare(query)?;
144
145 let rows = stmt.query_map(params, |row| {
146 mapper(row).map_err(|_| SqliteError::InvalidQuery)
147 })?;
148
149 let mut results = Vec::new();
150 for row in rows {
151 let value = row?;
152 results.push(value);
153 }
154
155 Ok(results)
156 }
157
158 pub fn table_exists(&self, table: &str) -> Result<bool, DatabaseError> {
159 let conn = self.conn.lock()?;
160
161 let count: i64 = conn.query_row(
162 "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?",
163 [table],
164 |row| row.get(0),
165 )?;
166
167 Ok(count > 0)
168 }
169
170 pub fn validate_table_name(&self, table: &str) -> Result<(), DatabaseError> {
171 if table.is_empty() || table.len() > 64 {
172 return Err(DatabaseError::InvalidQuery(
173 "Table name must be between 1-64 characters".to_string(),
174 ));
175 }
176
177 if !table.chars().all(|c| c.is_alphanumeric() || c == '_') {
178 return Err(DatabaseError::InvalidQuery(format!(
179 "Table name can only contain alphanumeric characters and underscores for {}",
180 table
181 )));
182 }
183
184 if table
185 .chars()
186 .next()
187 .expect("Table name is empty")
188 .is_numeric()
189 {
190 return Err(DatabaseError::InvalidQuery(
191 "Table name cannot start with a number".to_string(),
192 ));
193 }
194
195 Ok(())
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Database name passed to `Database::new`; ignored for in-memory databases.
    const DB_NAME: &str = "trace.db-wal";

    /// Opens a fresh, private in-memory database.
    fn memory_db() -> Database {
        Database::new(None, DB_NAME).expect("Failed to create database")
    }

    /// Creates a bare two-column `test_table` directly on the connection,
    /// bypassing `Database::create_table` (so there are no id/timestamp columns).
    fn create_plain_test_table(db: &Database) {
        let conn = db.conn.lock().unwrap();
        conn.execute(
            "CREATE TABLE IF NOT EXISTS test_table (name TEXT NOT NULL, description TEXT NOT NULL)",
            [],
        )
        .expect("Failed to create test table");
    }

    /// Inserts one `(name, description)` row into the plain `test_table`.
    fn insert_plain_row(db: &Database, name: &str, description: &str) {
        db.conn
            .lock()
            .unwrap()
            .execute(
                "INSERT INTO test_table (name, description) VALUES (?, ?)",
                [name, description],
            )
            .expect("Failed to insert test");
    }

    mod creation {
        use super::*;

        #[test]
        fn test_new_database_creates_in_memory_database() {
            let db = Database::new(None, DB_NAME);
            assert!(db.is_ok(), "Failed to create in-memory database");
        }

        #[test]
        fn test_create_table() {
            let db = memory_db();
            db.create_table(
                "test_table",
                &["name TEXT NOT NULL", "path TEXT NOT NULL"],
                &[],
            )
            .expect("Failed to create test table");

            let conn = db.conn.lock().unwrap();
            let mut stmt = conn
                .prepare("SELECT name FROM sqlite_master WHERE type='table' AND name='test_table'")
                .expect("Failed to prepare query");
            let exists = stmt.exists([]).expect("Failed to check if table exists");

            assert!(exists, "test table was not created");
        }

        #[test]
        fn test_create_table_adds_id_and_timestamp_columns() {
            let db = memory_db();
            db.create_table("test_table", &["name TEXT NOT NULL"], &[])
                .expect("Failed to create test table");

            let conn = db.conn.lock().unwrap();
            conn.execute(
                "INSERT INTO test_table (id, name) VALUES (?, ?)",
                ["1", "test"],
            )
            .expect("Failed to insert test");

            let has_timestamps: bool = conn
                .query_row(
                    "SELECT created_at IS NOT NULL AND updated_at IS NOT NULL \
                     FROM test_table WHERE id = ?",
                    ["1"],
                    |row| row.get(0),
                )
                .expect("Failed to read timestamps");

            assert!(has_timestamps, "timestamp columns were not defaulted");
        }

        #[test]
        fn test_create_table_rejects_invalid_name() {
            let db = memory_db();
            let result = db.create_table("bad; name", &["name TEXT"], &[]);
            assert!(
                matches!(result, Err(DatabaseError::InvalidQuery(_))),
                "invalid table name was accepted"
            );
        }

        #[test]
        fn test_create_table_with_constraints() {
            let db = memory_db();
            db.create_table(
                "test_table",
                &["name TEXT NOT NULL", "path TEXT NOT NULL"],
                &["UNIQUE (path)"],
            )
            .expect("Failed to create test table");

            let conn = db.conn.lock().unwrap();
            conn.execute(
                "INSERT INTO test_table (name, path) VALUES (?, ?)",
                ["test", "test"],
            )
            .expect("Failed to insert test");

            let duplicate = conn.execute(
                "INSERT INTO test_table (name, path) VALUES (?, ?)",
                ["test", "test"],
            );
            assert!(duplicate.is_err(), "Second insert should have failed");
        }
    }

    mod execution {
        use super::*;

        #[test]
        fn test_execute() {
            let db = memory_db();
            create_plain_test_table(&db);

            let changed = db
                .execute(
                    "INSERT INTO test_table (name, description) VALUES (?, ?)",
                    ["test", "test"],
                )
                .expect("Failed to execute test");

            assert_eq!(changed, 1, "expected exactly one inserted row");
        }
    }

    mod queries {
        use super::*;

        #[test]
        fn test_query() {
            let db = memory_db();
            create_plain_test_table(&db);
            insert_plain_row(&db, "test1", "desc1");

            let names = db
                .query("SELECT * FROM test_table", [], |row| {
                    row.get::<_, String>(0).map_err(DatabaseError::from)
                })
                .expect("Failed to query test");

            assert_eq!(names, vec!["test1"]);
        }

        #[test]
        fn test_query_multiple_rows() {
            let db = memory_db();
            create_plain_test_table(&db);
            for (name, description) in [("test1", "desc1"), ("test2", "desc2"), ("test3", "desc3")] {
                insert_plain_row(&db, name, description);
            }

            let names = db
                .query("SELECT name FROM test_table ORDER BY name", [], |row| {
                    row.get::<_, String>(0).map_err(DatabaseError::from)
                })
                .expect("Failed to query test");

            assert_eq!(names, vec!["test1", "test2", "test3"]);
        }

        #[test]
        fn test_query_empty_table_returns_no_rows() {
            let db = memory_db();
            create_plain_test_table(&db);

            let names = db
                .query("SELECT name FROM test_table", [], |row| {
                    row.get::<_, String>(0).map_err(DatabaseError::from)
                })
                .expect("Failed to query test");

            assert!(names.is_empty(), "expected no rows");
        }

        #[test]
        fn test_query_surfaces_mapper_error() {
            let db = memory_db();
            create_plain_test_table(&db);
            insert_plain_row(&db, "test1", "desc1");

            let result: Result<Vec<String>, DatabaseError> =
                db.query("SELECT name FROM test_table", [], |_| {
                    Err(DatabaseError::InvalidQuery("boom".to_string()))
                });

            assert!(result.is_err(), "mapper error was swallowed");
        }
    }

    mod utilities {
        use super::*;

        #[test]
        fn test_table_exists() {
            let db = memory_db();
            create_plain_test_table(&db);

            let exists = db
                .table_exists("test_table")
                .expect("Failed to check if table exists");

            assert!(exists, "test table was not created");
        }

        #[test]
        fn test_table_exists_returns_false_for_missing_table() {
            let db = memory_db();

            let exists = db
                .table_exists("missing_table")
                .expect("Failed to check if table exists");

            assert!(!exists, "missing table reported as existing");
        }

        #[test]
        fn test_validate_table_name() {
            let db = memory_db();

            assert!(
                db.validate_table_name("test_table").is_ok(),
                "simple name should be valid"
            );
            assert!(
                db.validate_table_name("123test_table").is_err(),
                "name starting with a digit should be rejected"
            );
            assert!(
                db.validate_table_name("test; dede").is_err(),
                "name with special characters should be rejected"
            );
            assert!(
                db.validate_table_name("").is_err(),
                "empty name should be rejected"
            );
            assert!(
                db.validate_table_name(&"a".repeat(64)).is_ok(),
                "64-character name should be valid"
            );
            assert!(
                db.validate_table_name(&"a".repeat(65)).is_err(),
                "65-character name should be rejected"
            );
        }
    }
}