Skip to main content

victauri_plugin/
database.rs

1#[cfg(feature = "sqlite")]
2use std::path::{Path, PathBuf};
3
/// Number of rows `query` returns when the caller does not pass `max_rows`.
#[cfg(feature = "sqlite")]
const MAX_ROWS_DEFAULT: usize = 100;
/// Ceiling applied to any caller-supplied `max_rows`; larger requests are clamped.
#[cfg(feature = "sqlite")]
const MAX_ROWS_LIMIT: usize = 10_000;

/// Lower-case leading keywords accepted by [`is_read_only`].
#[cfg(feature = "sqlite")]
const READ_ONLY_PREFIXES: &[&str] = &["select", "pragma", "explain", "with"];

/// Syntactic read-only gate: returns `true` when `sql`, after skipping leading
/// whitespace and ignoring case, begins with one of [`READ_ONLY_PREFIXES`].
///
/// This is purely a prefix test; the statement is not parsed, so e.g. a
/// `WITH ... DELETE ...` statement also passes. It is meant to be paired with
/// opening the database read-only.
#[cfg(feature = "sqlite")]
fn is_read_only(sql: &str) -> bool {
    let head = sql.trim_start().to_lowercase();
    for keyword in READ_ONLY_PREFIXES {
        if head.starts_with(keyword) {
            return true;
        }
    }
    false
}
19
/// Discover `SQLite` database files under `dir`.
///
/// The walk is recursive but bounded: `dir` itself plus at most two levels of
/// subdirectories beneath it are scanned. Symbolic links (to files or
/// directories) are never followed, and directories that cannot be read are
/// silently skipped. A regular file is reported when its extension is
/// `sqlite`, `sqlite3`, `db`, or `sdb`, compared ASCII case-insensitively.
///
/// The returned paths are sorted, so the result does not depend on the
/// platform's directory iteration order.
#[cfg(feature = "sqlite")]
#[must_use]
pub fn discover_databases(dir: &Path) -> Vec<PathBuf> {
    let mut results = Vec::new();
    discover_recursive(dir, 0, 2, &mut results);
    results.sort();
    results
}

/// File extensions treated as `SQLite` databases by [`discover_databases`].
#[cfg(feature = "sqlite")]
const DB_EXTENSIONS: &[&str] = &["sqlite", "sqlite3", "db", "sdb"];

/// Returns `true` when `path` has one of [`DB_EXTENSIONS`] (ASCII case-insensitive).
#[cfg(feature = "sqlite")]
fn has_db_extension(path: &Path) -> bool {
    path.extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| DB_EXTENSIONS.iter().any(|known| ext.eq_ignore_ascii_case(known)))
}

/// Depth-first helper for [`discover_databases`]: scans `dir` (at nesting
/// level `depth`) and descends into subdirectories while `depth < max_depth`,
/// appending matching files to `results`.
#[cfg(feature = "sqlite")]
fn discover_recursive(dir: &Path, depth: u32, max_depth: u32, results: &mut Vec<PathBuf>) {
    // Best-effort scan: an unreadable directory is skipped, not reported.
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    for entry in entries.flatten() {
        // `DirEntry::file_type` does not follow symlinks, so a single call tells
        // us whether the entry is a link, a file, or a directory.
        let Ok(file_type) = entry.file_type() else {
            continue;
        };
        if file_type.is_symlink() {
            // Skipping links means the walk can neither loop nor leave the tree.
            continue;
        }
        let path = entry.path();
        if file_type.is_file() {
            if has_db_extension(&path) {
                results.push(path);
            }
        } else if file_type.is_dir() && depth < max_depth {
            discover_recursive(&path, depth + 1, max_depth, results);
        }
    }
}
50
/// Execute a read-only SQL query against a `SQLite` database.
///
/// `sql` must begin (after leading whitespace, case-insensitively) with
/// `SELECT`, `PRAGMA`, `EXPLAIN`, or `WITH`. In addition, the database is
/// opened with `SQLITE_OPEN_READ_ONLY` and without `SQLITE_OPEN_CREATE`.
/// A statement that passes the keyword check but tries to write therefore
/// fails, and a missing file is not created.
///
/// `params` are bound positionally. JSON null/bool/number/string map to the
/// matching SQL value; arrays and objects are bound as their JSON text.
/// At most `max_rows` rows are returned. The default is `MAX_ROWS_DEFAULT`,
/// and any request is clamped to `MAX_ROWS_LIMIT`.
///
/// On success, returns a JSON object with these keys:
/// - `columns`: every result column name, in order.
/// - `rows`: one object per row, keyed by column name. If several columns
///   share a name, only the last one's value survives in the row object.
/// - `row_count`: the number of rows returned.
/// - `truncated`: `true` only when the result set had more rows than were
///   returned.
/// - `max_rows`: the effective row limit.
///
/// # Errors
///
/// Returns an error if the query is not read-only, the database cannot be opened,
/// or the query fails.
#[cfg(feature = "sqlite")]
pub fn query(
    db_path: &Path,
    sql: &str,
    params: &[serde_json::Value],
    max_rows: Option<usize>,
) -> Result<serde_json::Value, String> {
    if !is_read_only(sql) {
        return Err(
            "only SELECT, PRAGMA, EXPLAIN, and WITH queries are allowed (read-only access)"
                .to_string(),
        );
    }

    let max_rows = max_rows.unwrap_or(MAX_ROWS_DEFAULT).min(MAX_ROWS_LIMIT);

    let conn = rusqlite::Connection::open_with_flags(
        db_path,
        rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX,
    )
    .map_err(|e| format!("failed to open database: {e}"))?;

    // Wait up to 5 seconds for a lock held by another connection before giving
    // up with SQLITE_BUSY. This bounds lock waits only; it does NOT limit how
    // long the query itself may run.
    conn.busy_timeout(std::time::Duration::from_secs(5))
        .map_err(|e| format!("failed to set timeout: {e}"))?;

    let mut stmt = conn
        .prepare(sql)
        .map_err(|e| format!("failed to prepare query: {e}"))?;

    // Copied out as owned strings: `stmt` is mutably borrowed by the row cursor below.
    let column_names: Vec<String> = stmt
        .column_names()
        .into_iter()
        .map(String::from)
        .collect();

    let sqlite_params: Vec<Box<dyn rusqlite::types::ToSql>> =
        params.iter().map(json_to_sql).collect();
    let param_refs: Vec<&dyn rusqlite::types::ToSql> =
        sqlite_params.iter().map(|b| &**b).collect();

    let mut rows = stmt
        .query(param_refs.as_slice())
        .map_err(|e| format!("query execution failed: {e}"))?;

    let mut rows_out: Vec<serde_json::Value> = Vec::new();
    // Set only when a row beyond the limit actually exists. Comparing
    // `rows_out.len()` with `max_rows` would wrongly flag a result set of
    // exactly `max_rows` rows, or an empty one when `max_rows == 0`.
    let mut truncated = false;
    while let Some(row) = rows.next().map_err(|e| format!("row read failed: {e}"))? {
        if rows_out.len() >= max_rows {
            truncated = true;
            break;
        }
        let obj: serde_json::Map<String, serde_json::Value> = column_names
            .iter()
            .enumerate()
            .map(|(i, name)| (name.clone(), row_value_to_json(row, i)))
            .collect();
        rows_out.push(serde_json::Value::Object(obj));
    }

    Ok(serde_json::json!({
        "columns": column_names,
        "rows": rows_out,
        "row_count": rows_out.len(),
        "truncated": truncated,
        "max_rows": max_rows,
    }))
}
125
/// Convert a JSON parameter into a boxed SQL value for positional binding.
///
/// - `null` becomes SQL NULL.
/// - `true`/`false` are bound as a `bool`.
/// - Numbers become an `i64` when they fit, otherwise an `f64`; if neither
///   fits, the number's textual form is used.
/// - Strings are bound as text.
/// - Arrays and objects are bound as their serialized JSON text.
#[cfg(feature = "sqlite")]
fn json_to_sql(val: &serde_json::Value) -> Box<dyn rusqlite::types::ToSql> {
    use serde_json::Value;

    match val {
        Value::Null => Box::new(rusqlite::types::Null),
        Value::Bool(flag) => Box::new(*flag),
        Value::Number(num) => match (num.as_i64(), num.as_f64()) {
            (Some(int), _) => Box::new(int),
            (None, Some(float)) => Box::new(float),
            (None, None) => Box::new(num.to_string()),
        },
        Value::String(text) => Box::new(text.clone()),
        compound => Box::new(compound.to_string()),
    }
}
144
/// Convert column `idx` of `row` into a JSON value.
///
/// - NULL becomes `null`.
/// - INTEGER and REAL become JSON numbers.
/// - TEXT becomes the parsed JSON value when the text is a JSON object or
///   array. Otherwise it becomes a JSON string, with invalid UTF-8 replaced
///   lossily.
/// - BLOB becomes `{"__blob": true, "size": <byte count>, "base64": <standard base64>}`.
///
/// A column that cannot be read yields `null`.
#[cfg(feature = "sqlite")]
fn row_value_to_json(row: &rusqlite::Row, idx: usize) -> serde_json::Value {
    use rusqlite::types::ValueRef;
    use serde_json::Value;

    let Ok(cell) = row.get_ref(idx) else {
        return Value::Null;
    };

    match cell {
        ValueRef::Null => Value::Null,
        ValueRef::Integer(int) => serde_json::json!(int),
        ValueRef::Real(float) => serde_json::json!(float),
        ValueRef::Text(bytes) => {
            let text = String::from_utf8_lossy(bytes);
            // Only structured JSON is unwrapped; scalars stay as plain strings.
            match serde_json::from_str::<Value>(&text) {
                Ok(structured @ (Value::Object(_) | Value::Array(_))) => structured,
                _ => Value::String(text.into_owned()),
            }
        }
        ValueRef::Blob(bytes) => {
            use base64::Engine;
            let encoded = base64::engine::general_purpose::STANDARD.encode(bytes);
            serde_json::json!({
                "__blob": true,
                "size": bytes.len(),
                "base64": encoded,
            })
        }
    }
}
172
#[cfg(all(test, feature = "sqlite"))]
mod tests {
    use super::*;

    /// Create a temporary `.sqlite` file initialised by running `setup_sql`.
    ///
    /// The returned `NamedTempFile` must stay alive for the whole test;
    /// dropping it deletes the database file.
    fn temp_db(setup_sql: &str) -> (tempfile::NamedTempFile, PathBuf) {
        let file = tempfile::NamedTempFile::with_suffix(".sqlite").unwrap();
        let path = file.path().to_path_buf();
        rusqlite::Connection::open(&path)
            .unwrap()
            .execute_batch(setup_sql)
            .unwrap();
        (file, path)
    }

    /// The three-row `users` fixture shared by most tests.
    fn create_test_db() -> (tempfile::NamedTempFile, PathBuf) {
        temp_db(
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, score REAL);
             INSERT INTO users VALUES (1, 'Alice', 95.5);
             INSERT INTO users VALUES (2, 'Bob', 87.0);
             INSERT INTO users VALUES (3, 'Charlie', 92.3);",
        )
    }

    /// Assert that `sql` is refused by the read-only guard in `query`.
    fn assert_rejected(sql: &str) {
        let (_db, path) = create_test_db();
        let err = query(&path, sql, &[], None).unwrap_err();
        assert!(err.contains("read-only"), "unexpected error for {sql:?}: {err}");
    }

    #[test]
    fn select_all_rows() {
        let (_db, path) = create_test_db();
        let out = query(&path, "SELECT * FROM users", &[], None).unwrap();
        assert_eq!(out["row_count"], 3);
        assert_eq!(out["columns"], serde_json::json!(["id", "name", "score"]));
        assert_eq!(out["rows"][0]["name"], "Alice");
        assert_eq!(out["rows"][1]["name"], "Bob");
    }

    #[test]
    fn select_with_params() {
        let (_db, path) = create_test_db();
        let out = query(
            &path,
            "SELECT name FROM users WHERE score > ?",
            &[serde_json::json!(90.0)],
            None,
        )
        .unwrap();
        assert_eq!(out["row_count"], 2);
    }

    #[test]
    fn max_rows_truncation() {
        let (_db, path) = create_test_db();
        let out = query(&path, "SELECT * FROM users", &[], Some(2)).unwrap();
        assert_eq!(out["row_count"], 2);
        assert_eq!(out["truncated"], true);
    }

    #[test]
    fn rejects_insert() {
        assert_rejected("INSERT INTO users VALUES (4, 'Eve', 99.0)");
    }

    #[test]
    fn rejects_delete() {
        assert_rejected("DELETE FROM users");
    }

    #[test]
    fn rejects_drop() {
        assert_rejected("DROP TABLE users");
    }

    #[test]
    fn rejects_update() {
        assert_rejected("UPDATE users SET name = 'X'");
    }

    #[test]
    fn pragma_works() {
        let (_db, path) = create_test_db();
        let out = query(&path, "PRAGMA table_info(users)", &[], None).unwrap();
        assert!(out["row_count"].as_u64().unwrap() >= 3);
    }

    #[test]
    fn with_cte_works() {
        let (_db, path) = create_test_db();
        let out = query(
            &path,
            "WITH top AS (SELECT * FROM users WHERE score > 90) SELECT name FROM top",
            &[],
            None,
        )
        .unwrap();
        assert_eq!(out["row_count"], 2);
    }

    #[test]
    fn nonexistent_db_fails() {
        let err = query(Path::new("/nonexistent/db.sqlite"), "SELECT 1", &[], None).unwrap_err();
        assert!(err.contains("failed to open"));
    }

    #[test]
    fn json_column_parsed() {
        let (_db, path) = temp_db(
            r#"CREATE TABLE config (key TEXT, value TEXT);
               INSERT INTO config VALUES ('settings', '{"theme":"dark","lang":"en"}');"#,
        );
        let out = query(&path, "SELECT * FROM config", &[], None).unwrap();
        let value = &out["rows"][0]["value"];
        assert!(value.is_object());
        assert_eq!(value["theme"], "dark");
    }

    #[test]
    fn discover_finds_sqlite_files() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path();
        for name in ["app.sqlite", "cache.db", "readme.txt"] {
            std::fs::File::create(root.join(name)).unwrap();
        }
        let sub = root.join("subdir");
        std::fs::create_dir(&sub).unwrap();
        std::fs::File::create(sub.join("deep.sqlite3")).unwrap();

        assert_eq!(discover_databases(root).len(), 3);
    }

    #[test]
    fn blob_column_base64() {
        let (_db, path) = temp_db(
            "CREATE TABLE blobs (id INTEGER, data BLOB);
             INSERT INTO blobs VALUES (1, X'DEADBEEF');",
        );
        let out = query(&path, "SELECT * FROM blobs", &[], None).unwrap();
        let blob = &out["rows"][0]["data"];
        assert!(blob["__blob"].as_bool().unwrap());
        assert_eq!(blob["size"], 4);
    }
}