Skip to main content

victauri_plugin/
database.rs

1#[cfg(feature = "sqlite")]
2use std::path::{Path, PathBuf};
3
/// Number of rows `query` returns when the caller does not pass `max_rows`.
#[cfg(feature = "sqlite")]
const MAX_ROWS_DEFAULT: usize = 100;

/// Upper bound on rows `query` will return, whatever `max_rows` the caller asks for.
#[cfg(feature = "sqlite")]
const MAX_ROWS_LIMIT: usize = 10_000;

/// Lower-case leading keywords that `is_read_only` accepts as read-only statements.
#[cfg(feature = "sqlite")]
const READ_ONLY_PREFIXES: &[&str] = &["select", "pragma", "explain", "with"];
11
12#[cfg(feature = "sqlite")]
/// Return `sql` with its comments removed, for keyword / statement inspection only.
///
/// * `-- …` line comments are dropped up to (but not including) the next newline.
/// * `/* … */` block comments are replaced by a single space so the tokens on either side
///   stay separated. They do not nest, and an unterminated one extends to the end of the
///   input (both as documented for SQLite).
/// * Quoted strings and identifiers (`'…'`, `"…"`, `` `…` ``, `[…]`) are copied verbatim,
///   so comment markers or `;` inside them are treated as data, not syntax.
/// * Non-ASCII text is preserved unchanged.
fn strip_sql_comments(sql: &str) -> String {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut result = String::with_capacity(len);
    // Byte offset where the not-yet-copied run of ordinary text begins. Every offset we
    // slice at is either the end of input or sits on/after an ASCII delimiter, so it is
    // always a valid UTF-8 char boundary.
    let mut pending_start = 0;
    let mut i = 0;
    while i < len {
        let here = i;
        match bytes[here] {
            b'-' if bytes.get(here + 1) == Some(&b'-') => {
                result.push_str(&sql[pending_start..here]);
                // Drop everything up to, but not including, the next newline.
                i = bytes[here..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(len, |offset| here + offset);
                pending_start = i;
            }
            b'/' if bytes.get(here + 1) == Some(&b'*') => {
                result.push_str(&sql[pending_start..here]);
                // Skip past the first `*/`; an unterminated comment swallows the rest.
                i = sql[here + 2..]
                    .find("*/")
                    .map_or(len, |offset| here + 2 + offset + 2);
                result.push(' ');
                pending_start = i;
            }
            quote @ (b'\'' | b'"' | b'`' | b'[') => {
                let close = if quote == b'[' { b']' } else { quote };
                // Jump past the closing delimiter (or to end of input if unterminated).
                // A doubled quote such as `'it''s'` simply reopens a literal on the next
                // iteration, which is equivalent to treating `''` as an escape.
                i = bytes[here + 1..]
                    .iter()
                    .position(|&b| b == close)
                    .map_or(len, |offset| here + 1 + offset + 1);
            }
            _ => i += 1,
        }
    }
    result.push_str(&sql[pending_start..]);
    result
}
39
40#[cfg(feature = "sqlite")]
41fn is_read_only(sql: &str) -> bool {
42    let cleaned = strip_sql_comments(sql);
43    let trimmed = cleaned.trim_start().to_lowercase();
44    if trimmed.is_empty() {
45        return false;
46    }
47    READ_ONLY_PREFIXES
48        .iter()
49        .any(|prefix| trimmed.starts_with(prefix))
50}
51
/// Discover `SQLite` database files under `dir`.
///
/// Files whose extension is `sqlite`, `sqlite3`, `db`, or `sdb` are collected from `dir`
/// itself and from subdirectories up to two levels below it. Symbolic links are not
/// followed, and directories that cannot be read are silently skipped. The order of the
/// result follows `std::fs::read_dir` and is therefore not guaranteed.
#[cfg(feature = "sqlite")]
#[must_use]
pub fn discover_databases(dir: &Path) -> Vec<PathBuf> {
    // Depth counts directory levels below `dir`, which itself is level 0.
    const MAX_DEPTH: u32 = 2;
    let mut results = Vec::new();
    discover_recursive(dir, 0, MAX_DEPTH, &mut results);
    results
}
60
61#[cfg(feature = "sqlite")]
/// Walk `dir`, appending database files to `results` and descending into subdirectories
/// while `depth < max_depth`.
///
/// A file counts as a database when its extension is `sqlite`, `sqlite3`, `db`, or `sdb`,
/// compared ASCII-case-insensitively so names like `App.DB` are found too. Symbolic links
/// are never followed, and unreadable directories or entries are skipped.
fn discover_recursive(dir: &Path, depth: u32, max_depth: u32, results: &mut Vec<PathBuf>) {
    const DB_EXTENSIONS: [&str; 4] = ["sqlite", "sqlite3", "db", "sdb"];

    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    for entry in entries.flatten() {
        // `DirEntry::file_type` does not follow symlinks and on most platforms needs no
        // extra syscall, unlike separate `is_symlink`/`is_file`/`is_dir` path queries.
        let Ok(file_type) = entry.file_type() else {
            continue;
        };
        if file_type.is_symlink() {
            continue;
        }
        let path = entry.path();
        if file_type.is_file() {
            let is_db = path
                .extension()
                .and_then(|ext| ext.to_str())
                .is_some_and(|ext| {
                    DB_EXTENSIONS
                        .iter()
                        .any(|known| ext.eq_ignore_ascii_case(known))
                });
            if is_db {
                results.push(path);
            }
        } else if file_type.is_dir() && depth < max_depth {
            discover_recursive(&path, depth + 1, max_depth, results);
        }
    }
}
82
/// Execute a read-only SQL query against a `SQLite` database.
///
/// `sql` must be a single statement whose first keyword (ignoring comments) is `SELECT`,
/// `PRAGMA`, `EXPLAIN`, or `WITH`; a trailing `;` is accepted. Independently of that check the
/// database is opened with `SQLITE_OPEN_READ_ONLY`, so a statement that slips past the
/// keyword filter still cannot write. `params` are bound positionally.
///
/// At most `max_rows` rows are returned (`MAX_ROWS_DEFAULT` when `None`, never more than
/// `MAX_ROWS_LIMIT`). The result is a JSON object with:
/// `columns`, `rows` (one object per row keyed by column name), `row_count`,
/// `truncated` (`true` only if the query produced more rows than were returned), and the
/// effective `max_rows`.
///
/// # Errors
///
/// Returns an error if the query is not read-only, contains more than one statement, the
/// database cannot be opened, or preparing/executing the query fails.
#[cfg(feature = "sqlite")]
pub fn query(
    db_path: &Path,
    sql: &str,
    params: &[serde_json::Value],
    max_rows: Option<usize>,
) -> Result<serde_json::Value, String> {
    // Count the non-empty `;`-separated statements in comment-free SQL. A `;` inside a
    // quoted string or identifier ('..', "..", `..`, [..]) is data, not a separator. A
    // doubled quote ('it''s') closes and immediately reopens the literal, so it needs no
    // special case.
    fn count_statements(cleaned: &str) -> usize {
        let mut count = 0;
        // Whether the current statement has any non-whitespace content yet.
        let mut has_content = false;
        let mut closing: Option<char> = None;
        for ch in cleaned.chars() {
            if let Some(close) = closing {
                if ch == close {
                    closing = None;
                }
                continue;
            }
            match ch {
                ';' => {
                    count += usize::from(has_content);
                    has_content = false;
                }
                '\'' | '"' | '`' => {
                    closing = Some(ch);
                    has_content = true;
                }
                '[' => {
                    closing = Some(']');
                    has_content = true;
                }
                c if c.is_whitespace() => {}
                _ => has_content = true,
            }
        }
        count + usize::from(has_content)
    }

    if !is_read_only(sql) {
        return Err(
            "only SELECT, PRAGMA, EXPLAIN, and WITH queries are allowed (read-only access)"
                .to_string(),
        );
    }

    if count_statements(&strip_sql_comments(sql)) > 1 {
        return Err(
            "stacked queries (multiple statements separated by ;) are not allowed".to_string(),
        );
    }

    let max_rows = max_rows.unwrap_or(MAX_ROWS_DEFAULT).min(MAX_ROWS_LIMIT);

    let conn = rusqlite::Connection::open_with_flags(
        db_path,
        rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX,
    )
    .map_err(|e| format!("failed to open database: {e}"))?;

    // Wait up to 5 s for locks held by other connections. This bounds lock waiting only;
    // it does not limit how long the query itself may run.
    conn.busy_timeout(std::time::Duration::from_secs(5))
        .map_err(|e| format!("failed to set timeout: {e}"))?;

    let mut stmt = conn
        .prepare(sql)
        .map_err(|e| format!("failed to prepare query: {e}"))?;

    let column_names: Vec<String> = stmt
        .column_names()
        .iter()
        .map(|s| (*s).to_string())
        .collect();

    let sqlite_params: Vec<Box<dyn rusqlite::types::ToSql>> =
        params.iter().map(json_to_sql).collect();
    let param_refs: Vec<&dyn rusqlite::types::ToSql> = sqlite_params.iter().map(|b| &**b).collect();

    let mut rows = stmt
        .query(param_refs.as_slice())
        .map_err(|e| format!("query execution failed: {e}"))?;

    let mut rows_out: Vec<serde_json::Value> = Vec::new();
    let mut truncated = false;
    while let Some(row) = rows.next().map_err(|e| format!("row read failed: {e}"))? {
        if rows_out.len() >= max_rows {
            // This row lies beyond the limit, so the result really was cut short. Deciding
            // here (rather than `len == max_rows` afterwards) avoids flagging a result with
            // exactly `max_rows` rows as truncated.
            truncated = true;
            break;
        }
        let obj: serde_json::Map<String, serde_json::Value> = column_names
            .iter()
            .enumerate()
            .map(|(i, name)| (name.clone(), row_value_to_json(row, i)))
            .collect();
        rows_out.push(serde_json::Value::Object(obj));
    }

    Ok(serde_json::json!({
        "columns": column_names,
        "rows": rows_out,
        "row_count": rows_out.len(),
        "truncated": truncated,
        "max_rows": max_rows,
    }))
}
170
171#[cfg(feature = "sqlite")]
172fn json_to_sql(val: &serde_json::Value) -> Box<dyn rusqlite::types::ToSql> {
173    match val {
174        serde_json::Value::Null => Box::new(rusqlite::types::Null),
175        serde_json::Value::Bool(b) => Box::new(*b),
176        serde_json::Value::Number(n) => {
177            if let Some(i) = n.as_i64() {
178                Box::new(i)
179            } else if let Some(f) = n.as_f64() {
180                Box::new(f)
181            } else {
182                Box::new(n.to_string())
183            }
184        }
185        serde_json::Value::String(s) => Box::new(s.clone()),
186        other => Box::new(other.to_string()),
187    }
188}
189
190#[cfg(feature = "sqlite")]
191fn row_value_to_json(row: &rusqlite::Row, idx: usize) -> serde_json::Value {
192    use rusqlite::types::ValueRef;
193    match row.get_ref(idx) {
194        Ok(ValueRef::Null) => serde_json::Value::Null,
195        Ok(ValueRef::Integer(i)) => serde_json::json!(i),
196        Ok(ValueRef::Real(f)) => serde_json::json!(f),
197        Ok(ValueRef::Text(t)) => {
198            let s = String::from_utf8_lossy(t);
199            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&s)
200                && (parsed.is_object() || parsed.is_array())
201            {
202                return parsed;
203            }
204            serde_json::Value::String(s.into_owned())
205        }
206        Ok(ValueRef::Blob(b)) => {
207            use base64::Engine;
208            serde_json::json!({
209                "__blob": true,
210                "size": b.len(),
211                "base64": base64::engine::general_purpose::STANDARD.encode(b),
212            })
213        }
214        Err(_) => serde_json::Value::Null,
215    }
216}
217
#[cfg(all(test, feature = "sqlite"))]
mod tests {
    use super::*;

    /// Create a temporary `.sqlite` database initialised by running `setup_sql`.
    ///
    /// Keep the returned `NamedTempFile` alive (bind it to `_f`, not `_`) while the
    /// database is in use: dropping it deletes the file.
    fn create_db(setup_sql: &str) -> (tempfile::NamedTempFile, PathBuf) {
        let file = tempfile::NamedTempFile::with_suffix(".sqlite").unwrap();
        let path = file.path().to_path_buf();
        let conn = rusqlite::Connection::open(&path).unwrap();
        conn.execute_batch(setup_sql).unwrap();
        (file, path)
    }

    /// Shared fixture: `users(id, name, score)` with three rows (Alice, Bob, Charlie).
    fn create_test_db() -> (tempfile::NamedTempFile, PathBuf) {
        create_db(
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, score REAL);
             INSERT INTO users VALUES (1, 'Alice', 95.5);
             INSERT INTO users VALUES (2, 'Bob', 87.0);
             INSERT INTO users VALUES (3, 'Charlie', 92.3);",
        )
    }

    #[test]
    fn select_all_rows() {
        let (_f, path) = create_test_db();
        let result = query(&path, "SELECT * FROM users", &[], None).unwrap();
        assert_eq!(result["row_count"], 3);
        assert_eq!(
            result["columns"],
            serde_json::json!(["id", "name", "score"])
        );
        assert_eq!(result["rows"][0]["name"], "Alice");
        assert_eq!(result["rows"][1]["name"], "Bob");
    }

    #[test]
    fn select_with_params() {
        let (_f, path) = create_test_db();
        let result = query(
            &path,
            "SELECT name FROM users WHERE score > ?",
            &[serde_json::json!(90.0)],
            None,
        )
        .unwrap();
        assert_eq!(result["row_count"], 2);
    }

    #[test]
    fn max_rows_truncation() {
        let (_f, path) = create_test_db();
        let result = query(&path, "SELECT * FROM users", &[], Some(2)).unwrap();
        assert_eq!(result["row_count"], 2);
        assert_eq!(result["truncated"], true);
    }

    #[test]
    fn rejects_insert() {
        let (_f, path) = create_test_db();
        let err = query(
            &path,
            "INSERT INTO users VALUES (4, 'Eve', 99.0)",
            &[],
            None,
        )
        .unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_delete() {
        let (_f, path) = create_test_db();
        let err = query(&path, "DELETE FROM users", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_drop() {
        let (_f, path) = create_test_db();
        let err = query(&path, "DROP TABLE users", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_update() {
        let (_f, path) = create_test_db();
        let err = query(&path, "UPDATE users SET name = 'X'", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn data_modifying_cte_cannot_write() {
        // A `WITH` prefix passes the keyword filter, but the read-only connection must still
        // refuse the write and leave the table untouched.
        let (_f, path) = create_test_db();
        let result = query(
            &path,
            "WITH doomed AS (SELECT id FROM users) DELETE FROM users WHERE id IN (SELECT id FROM doomed)",
            &[],
            None,
        );
        assert!(result.is_err());
        let remaining = query(&path, "SELECT COUNT(*) AS n FROM users", &[], None).unwrap();
        assert_eq!(remaining["rows"][0]["n"], 3);
    }

    #[test]
    fn pragma_works() {
        let (_f, path) = create_test_db();
        let result = query(&path, "PRAGMA table_info(users)", &[], None).unwrap();
        assert!(result["row_count"].as_u64().unwrap() >= 3);
    }

    #[test]
    fn with_cte_works() {
        let (_f, path) = create_test_db();
        let result = query(
            &path,
            "WITH top AS (SELECT * FROM users WHERE score > 90) SELECT name FROM top",
            &[],
            None,
        )
        .unwrap();
        assert_eq!(result["row_count"], 2);
    }

    #[test]
    fn nonexistent_db_fails() {
        let err = query(Path::new("/nonexistent/db.sqlite"), "SELECT 1", &[], None).unwrap_err();
        assert!(err.contains("failed to open"));
    }

    #[test]
    fn json_column_parsed() {
        let (_f, path) = create_db(
            r#"CREATE TABLE config (key TEXT, value TEXT);
               INSERT INTO config VALUES ('settings', '{"theme":"dark","lang":"en"}');"#,
        );
        let result = query(&path, "SELECT * FROM config", &[], None).unwrap();
        assert!(result["rows"][0]["value"].is_object());
        assert_eq!(result["rows"][0]["value"]["theme"], "dark");
    }

    #[test]
    fn discover_finds_sqlite_files() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::File::create(dir.path().join("app.sqlite")).unwrap();
        std::fs::File::create(dir.path().join("cache.db")).unwrap();
        std::fs::File::create(dir.path().join("readme.txt")).unwrap();
        let sub = dir.path().join("subdir");
        std::fs::create_dir(&sub).unwrap();
        std::fs::File::create(sub.join("deep.sqlite3")).unwrap();

        let dbs = discover_databases(dir.path());
        assert_eq!(dbs.len(), 3);
    }

    #[test]
    fn rejects_comment_bypass_block() {
        let (_f, path) = create_test_db();
        let err = query(&path, "/* sneaky */DELETE FROM users", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_line_comment_bypass() {
        let (_f, path) = create_test_db();
        let err = query(&path, "-- comment\nDELETE FROM users", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_stacked_queries() {
        let (_f, path) = create_test_db();
        let err = query(&path, "SELECT 1; DROP TABLE users", &[], None).unwrap_err();
        assert!(err.contains("stacked queries"));
    }

    #[test]
    fn allows_trailing_semicolon() {
        let (_f, path) = create_test_db();
        let result = query(&path, "SELECT * FROM users;", &[], None).unwrap();
        assert_eq!(result["row_count"], 3);
    }

    #[test]
    fn allows_select_with_block_comment() {
        let (_f, path) = create_test_db();
        let result = query(
            &path,
            "/* filter */ SELECT name FROM users WHERE id = 1",
            &[],
            None,
        )
        .unwrap();
        assert_eq!(result["row_count"], 1);
        assert_eq!(result["rows"][0]["name"], "Alice");
    }

    #[test]
    fn rejects_empty_query() {
        let (_f, path) = create_test_db();
        let err = query(&path, "", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_comment_only_query() {
        let (_f, path) = create_test_db();
        let err = query(&path, "/* just a comment */", &[], None).unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn rejects_nested_comment_bypass() {
        let (_f, path) = create_test_db();
        let err = query(
            &path,
            "/* outer /* inner */ still comment */ DROP TABLE users",
            &[],
            None,
        )
        .unwrap_err();
        assert!(err.contains("read-only"));
    }

    #[test]
    fn blob_column_base64() {
        let (_f, path) = create_db(
            "CREATE TABLE blobs (id INTEGER, data BLOB);
             INSERT INTO blobs VALUES (1, X'DEADBEEF');",
        );
        let result = query(&path, "SELECT * FROM blobs", &[], None).unwrap();
        let blob = &result["rows"][0]["data"];
        assert!(blob["__blob"].as_bool().unwrap());
        assert_eq!(blob["size"], 4);
        // Standard base64 of the bytes DE AD BE EF.
        assert_eq!(blob["base64"], "3q2+7w==");
    }
}