Skip to main content

mixtape_tools/sqlite/query/
read.rs

1//! Read query tool
2
3use crate::prelude::*;
4use crate::sqlite::error::SqliteToolError;
5use crate::sqlite::manager::with_connection;
6use crate::sqlite::types::{json_to_sql, QueryResult};
7use rusqlite::types::ValueRef;
8
9/// Input for read query execution
10#[derive(Debug, Deserialize, JsonSchema)]
11pub struct ReadQueryInput {
12    /// SQL query to execute (SELECT, PRAGMA, or EXPLAIN only)
13    pub query: String,
14
15    /// Query parameters for prepared statements
16    #[serde(default)]
17    pub params: Vec<serde_json::Value>,
18
19    /// Database file path. If not specified, uses the default database.
20    #[serde(default)]
21    pub db_path: Option<String>,
22
23    /// Maximum number of rows to return (default: 1000)
24    #[serde(default = "default_limit")]
25    pub limit: usize,
26
27    /// Number of rows to skip (default: 0)
28    #[serde(default)]
29    pub offset: usize,
30}
31
32impl ReadQueryInput {
33    /// Creates a new ReadQueryInput with the given query and default values.
34    pub fn new(query: impl Into<String>) -> Self {
35        Self {
36            query: query.into(),
37            params: vec![],
38            db_path: None,
39            limit: 1000,
40            offset: 0,
41        }
42    }
43
44    /// Sets the database path.
45    pub fn db_path(mut self, path: impl Into<String>) -> Self {
46        self.db_path = Some(path.into());
47        self
48    }
49
50    /// Sets the query parameters.
51    pub fn params(mut self, params: Vec<serde_json::Value>) -> Self {
52        self.params = params;
53        self
54    }
55}
56
/// Row limit used when the caller does not supply one.
fn default_limit() -> usize {
    const DEFAULT_ROW_LIMIT: usize = 1000;
    DEFAULT_ROW_LIMIT
}
60
/// Tool for executing read-only queries (SAFE)
///
/// Executes SELECT, PRAGMA, and EXPLAIN queries, as well as WITH queries
/// whose main statement is a SELECT.
/// Other query types will be rejected for safety.
pub struct ReadQueryTool;

impl ReadQueryTool {
    /// Validates that a query is read-only.
    ///
    /// Returns `true` when the trimmed, case-insensitive query starts with
    /// `SELECT`, `PRAGMA` or `EXPLAIN`, or when it starts with `WITH` and the
    /// statement following the common table expressions is a `SELECT` or
    /// `VALUES`. Anything else (including a `WITH` clause that feeds an
    /// `INSERT`, `UPDATE`, `DELETE` or `REPLACE`) returns `false`.
    // NOTE(review): PRAGMA statements are accepted by prefix only; some
    // pragmas are believed to modify database or connection state when given
    // a value — confirm whether those need to be filtered here as well.
    fn is_read_only(sql: &str) -> bool {
        // SQL keywords are ASCII, so ASCII case folding is sufficient.
        let normalized = sql.trim().to_ascii_uppercase();

        if normalized.starts_with("WITH") {
            return Self::with_query_is_read_only(&normalized);
        }

        ["SELECT", "PRAGMA", "EXPLAIN"]
            .iter()
            .any(|prefix| normalized.starts_with(prefix))
    }

    /// Decides whether an upper-cased `WITH ...` query ends in a read-only
    /// main statement.
    ///
    /// Scans the text once, ignoring everything inside parentheses (the CTE
    /// bodies), quoted strings/identifiers and comments, and classifies the
    /// first statement keyword found at the top level. Returns `false` when
    /// no such keyword is found.
    fn with_query_is_read_only(normalized: &str) -> bool {
        let mut chars = normalized.chars().peekable();
        let mut depth = 0usize;
        let mut word = String::new();

        while let Some(c) = chars.next() {
            // Non-ASCII characters are treated as identifier characters so a
            // name such as `éselect` is never mistaken for the keyword.
            let is_identifier_char =
                c.is_ascii_alphanumeric() || c == '_' || c == '$' || !c.is_ascii();
            if is_identifier_char {
                if depth == 0 {
                    word.push(c);
                }
                continue;
            }

            // `c` ends the current top-level word (if any).
            if let Some(verdict) = Self::classify_statement_keyword(&word) {
                return verdict;
            }
            word.clear();

            match c {
                '(' => depth += 1,
                ')' => depth = depth.saturating_sub(1),
                // Quoted text: skip to the matching quote. A doubled quote
                // simply reads as two adjacent quoted spans.
                '\'' | '"' | '`' => {
                    let _ = chars.find(|&d| d == c);
                }
                '[' => {
                    let _ = chars.find(|&d| d == ']');
                }
                // `-- line comment`
                '-' if chars.peek() == Some(&'-') => {
                    let _ = chars.find(|&d| d == '\n');
                }
                // `/* block comment */`
                '/' if chars.peek() == Some(&'*') => {
                    chars.next();
                    let mut prev = '\0';
                    for d in chars.by_ref() {
                        if prev == '*' && d == '/' {
                            break;
                        }
                        prev = d;
                    }
                }
                _ => {}
            }
        }

        Self::classify_statement_keyword(&word).unwrap_or(false)
    }

    /// Maps a top-level word to a read-only verdict: `Some(true)` for
    /// `SELECT`/`VALUES`, `Some(false)` for data-modifying keywords, and
    /// `None` for any other word.
    fn classify_statement_keyword(word: &str) -> Option<bool> {
        match word {
            "SELECT" | "VALUES" => Some(true),
            "INSERT" | "UPDATE" | "DELETE" | "REPLACE" => Some(false),
            _ => None,
        }
    }
}
86
87impl Tool for ReadQueryTool {
88    type Input = ReadQueryInput;
89
90    fn name(&self) -> &str {
91        "sqlite_read_query"
92    }
93
94    fn description(&self) -> &str {
95        "Execute a read-only SQL query (SELECT, PRAGMA, EXPLAIN). Returns the query results with column names and row data."
96    }
97
98    async fn execute(&self, input: Self::Input) -> Result<ToolResult, ToolError> {
99        // Validate query is read-only
100        if !Self::is_read_only(&input.query) {
101            return Err(SqliteToolError::InvalidQuery(
102                "Only SELECT, PRAGMA, EXPLAIN, and WITH...SELECT queries are allowed. Use sqlite_write_query for modifications.".to_string()
103            ).into());
104        }
105
106        let query = input.query;
107        let params = input.params;
108        let limit = input.limit;
109        let offset = input.offset;
110
111        let result = with_connection(input.db_path, move |conn| {
112            let mut stmt = conn.prepare(&query)?;
113
114            // Get column names
115            let columns: Vec<String> = stmt.column_names().iter().map(|s| s.to_string()).collect();
116
117            // Convert params to rusqlite values
118            let params_ref: Vec<Box<dyn rusqlite::ToSql>> =
119                params.iter().map(|v| json_to_sql(v)).collect();
120
121            let params_slice: Vec<&dyn rusqlite::ToSql> =
122                params_ref.iter().map(|b| b.as_ref()).collect();
123
124            // Execute query and collect rows
125            let mut rows_result = stmt.query(params_slice.as_slice())?;
126            let mut rows: Vec<Vec<serde_json::Value>> = Vec::new();
127            let mut skipped = 0;
128
129            while let Some(row) = rows_result.next()? {
130                // Handle offset
131                if skipped < offset {
132                    skipped += 1;
133                    continue;
134                }
135
136                // Handle limit
137                if rows.len() >= limit {
138                    break;
139                }
140
141                let mut row_data: Vec<serde_json::Value> = Vec::new();
142                for i in 0..columns.len() {
143                    let value = row.get_ref(i)?;
144                    row_data.push(sql_to_json(value));
145                }
146                rows.push(row_data);
147            }
148
149            Ok(QueryResult {
150                row_count: rows.len(),
151                columns,
152                rows,
153                rows_affected: None,
154            })
155        })
156        .await?;
157
158        Ok(ToolResult::Json(serde_json::to_value(result)?))
159    }
160}
161
162/// Convert a rusqlite value to JSON
163fn sql_to_json(value: ValueRef) -> serde_json::Value {
164    match value {
165        ValueRef::Null => serde_json::Value::Null,
166        ValueRef::Integer(i) => serde_json::Value::Number(i.into()),
167        ValueRef::Real(f) => serde_json::Number::from_f64(f)
168            .map(serde_json::Value::Number)
169            .unwrap_or(serde_json::Value::Null),
170        ValueRef::Text(s) => serde_json::Value::String(String::from_utf8_lossy(s).to_string()),
171        ValueRef::Blob(b) => {
172            // Return as base64-encoded string
173            use base64::Engine;
174            serde_json::Value::String(base64::engine::general_purpose::STANDARD.encode(b))
175        }
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::test_utils::{unwrap_json, TestDatabase};

    /// Runs `input` through the tool and returns the JSON payload, panicking
    /// if the tool reports an error.
    async fn run_query(input: ReadQueryInput) -> serde_json::Value {
        unwrap_json(ReadQueryTool.execute(input).await.unwrap())
    }

    /// Database holding a single `numbers` table with the values 1 through 5.
    async fn numbers_db() -> TestDatabase {
        TestDatabase::with_schema(
            "CREATE TABLE numbers (n INTEGER);
             INSERT INTO numbers VALUES (1), (2), (3), (4), (5);",
        )
        .await
    }

    /// A plain SELECT returns every row and both columns.
    #[tokio::test]
    async fn test_read_query() {
        let db = TestDatabase::with_schema(
            "CREATE TABLE users (id INTEGER, name TEXT);
             INSERT INTO users VALUES (1, 'Alice');
             INSERT INTO users VALUES (2, 'Bob');",
        )
        .await;

        let json =
            run_query(ReadQueryInput::new("SELECT * FROM users ORDER BY id").db_path(db.key()))
                .await;

        assert_eq!(json["row_count"].as_i64(), Some(2));
        assert_eq!(json["columns"].as_array().unwrap().len(), 2);
    }

    /// Write statements are refused by the read tool.
    #[tokio::test]
    async fn test_reject_write_query() {
        let db = TestDatabase::new().await;
        let input = ReadQueryInput::new("INSERT INTO users VALUES (1, 'test')").db_path(db.key());

        let outcome = ReadQueryTool.execute(input).await;

        assert!(outcome.is_err());
    }

    /// The read-only classifier accepts reads and rejects writes.
    #[test]
    fn test_is_read_only() {
        let accepted = [
            "SELECT * FROM users",
            "  SELECT * FROM users",
            "PRAGMA table_info(users)",
            "EXPLAIN SELECT * FROM users",
            "WITH cte AS (SELECT 1) SELECT * FROM cte",
        ];
        for sql in accepted {
            assert!(ReadQueryTool::is_read_only(sql), "should be accepted: {sql}");
        }

        let rejected = [
            "INSERT INTO users VALUES (1)",
            "UPDATE users SET name = 'x'",
            "DELETE FROM users",
            "DROP TABLE users",
        ];
        for sql in rejected {
            assert!(!ReadQueryTool::is_read_only(sql), "should be rejected: {sql}");
        }
    }

    /// The tool exposes its registered name and a non-empty description.
    #[test]
    fn test_tool_metadata() {
        let tool = ReadQueryTool;

        assert_eq!(tool.name(), "sqlite_read_query");
        assert!(!tool.description().is_empty());
    }

    /// Integer, string, float, boolean and mixed parameters all bind.
    #[tokio::test]
    async fn test_parameterized_query_with_types() {
        let db = TestDatabase::with_schema(
            "CREATE TABLE data (id INTEGER, name TEXT, score REAL, active INTEGER);
             INSERT INTO data VALUES (1, 'Alice', 95.5, 1);
             INSERT INTO data VALUES (2, 'Bob', 87.0, 0);
             INSERT INTO data VALUES (3, NULL, 72.5, 1);",
        )
        .await;

        // Integer parameter
        let json = run_query(
            ReadQueryInput::new("SELECT * FROM data WHERE id = ?")
                .db_path(db.key())
                .params(vec![serde_json::json!(2)]),
        )
        .await;
        assert_eq!(json["row_count"], 1);
        assert_eq!(json["rows"][0][1], "Bob");

        // String parameter
        let json = run_query(
            ReadQueryInput::new("SELECT * FROM data WHERE name = ?")
                .db_path(db.key())
                .params(vec![serde_json::json!("Alice")]),
        )
        .await;
        assert_eq!(json["row_count"], 1);

        // Float parameter
        let json = run_query(
            ReadQueryInput::new("SELECT * FROM data WHERE score > ?")
                .db_path(db.key())
                .params(vec![serde_json::json!(90.0)]),
        )
        .await;
        assert_eq!(json["row_count"], 1);

        // Boolean parameter (converts to 1/0)
        let json = run_query(
            ReadQueryInput::new("SELECT * FROM data WHERE active = ?")
                .db_path(db.key())
                .params(vec![serde_json::json!(true)]),
        )
        .await;
        assert_eq!(json["row_count"], 2);

        // Multiple parameters
        let json = run_query(
            ReadQueryInput::new("SELECT * FROM data WHERE id > ? AND score < ?")
                .db_path(db.key())
                .params(vec![serde_json::json!(1), serde_json::json!(80.0)]),
        )
        .await;
        assert_eq!(json["row_count"], 1);
        assert_eq!(json["rows"][0][0], 3);
    }

    /// A JSON null parameter matches NULL columns through `IS ?`.
    #[tokio::test]
    async fn test_null_parameter() {
        let db = TestDatabase::with_schema(
            "CREATE TABLE data (id INTEGER, name TEXT);
             INSERT INTO data VALUES (1, 'Alice');
             INSERT INTO data VALUES (2, NULL);",
        )
        .await;

        let json = run_query(
            ReadQueryInput::new("SELECT * FROM data WHERE name IS ?")
                .db_path(db.key())
                .params(vec![serde_json::Value::Null]),
        )
        .await;

        assert_eq!(json["row_count"], 1);
        assert_eq!(json["rows"][0][0], 2);
    }

    /// BLOB columns come back as base64-encoded strings.
    #[tokio::test]
    async fn test_blob_data_base64() {
        use base64::Engine;

        let db = TestDatabase::with_schema("CREATE TABLE files (id INTEGER, data BLOB);").await;
        // Insert raw bytes
        db.execute("INSERT INTO files VALUES (1, X'48656C6C6F')"); // "Hello" in hex
        db.execute("INSERT INTO files VALUES (2, X'0001020304')");

        let json =
            run_query(ReadQueryInput::new("SELECT * FROM files ORDER BY id").db_path(db.key()))
                .await;

        let expected_hello = base64::engine::general_purpose::STANDARD.encode(b"Hello");
        let expected_bytes = base64::engine::general_purpose::STANDARD.encode([0, 1, 2, 3, 4]);

        assert_eq!(json["rows"][0][1], expected_hello);
        assert_eq!(json["rows"][1][1], expected_bytes);
    }

    /// `limit` caps the number of returned rows.
    #[tokio::test]
    async fn test_limit_parameter() {
        let db = numbers_db().await;

        let json = run_query(ReadQueryInput {
            limit: 2,
            ..ReadQueryInput::new("SELECT * FROM numbers ORDER BY n").db_path(db.key())
        })
        .await;

        assert_eq!(json["row_count"], 2);
        assert_eq!(json["rows"][0][0], 1);
        assert_eq!(json["rows"][1][0], 2);
    }

    /// `offset` skips the leading rows.
    #[tokio::test]
    async fn test_offset_parameter() {
        let db = numbers_db().await;

        let json = run_query(ReadQueryInput {
            offset: 2,
            ..ReadQueryInput::new("SELECT * FROM numbers ORDER BY n").db_path(db.key())
        })
        .await;

        assert_eq!(json["row_count"], 3);
        assert_eq!(json["rows"][0][0], 3);
        assert_eq!(json["rows"][1][0], 4);
        assert_eq!(json["rows"][2][0], 5);
    }

    /// `offset` is applied before `limit`.
    #[tokio::test]
    async fn test_limit_and_offset_combined() {
        let db = numbers_db().await;

        let json = run_query(ReadQueryInput {
            limit: 2,
            offset: 1,
            ..ReadQueryInput::new("SELECT * FROM numbers ORDER BY n").db_path(db.key())
        })
        .await;

        assert_eq!(json["row_count"], 2);
        assert_eq!(json["rows"][0][0], 2);
        assert_eq!(json["rows"][1][0], 3);
    }

    /// A CTE followed by SELECT is executed as a read query.
    #[tokio::test]
    async fn test_with_select_query() {
        let db = TestDatabase::with_schema(
            "CREATE TABLE employees (id INTEGER, manager_id INTEGER, name TEXT);
             INSERT INTO employees VALUES (1, NULL, 'CEO');
             INSERT INTO employees VALUES (2, 1, 'VP');
             INSERT INTO employees VALUES (3, 2, 'Manager');",
        )
        .await;

        let json = run_query(
            ReadQueryInput::new("WITH managers AS (SELECT * FROM employees WHERE manager_id IS NOT NULL) SELECT * FROM managers")
                .db_path(db.key()),
        )
        .await;

        assert_eq!(json["row_count"], 2);
    }

    /// PRAGMA table_info yields one row per column of the table.
    #[tokio::test]
    async fn test_pragma_query() {
        let db =
            TestDatabase::with_schema("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);")
                .await;

        let json = run_query(ReadQueryInput::new("PRAGMA table_info(users)").db_path(db.key())).await;

        assert_eq!(json["row_count"], 2);
    }

    /// SQL NULL values are returned as JSON null.
    #[tokio::test]
    async fn test_null_in_results() {
        let db = TestDatabase::with_schema(
            "CREATE TABLE data (id INTEGER, value TEXT);
             INSERT INTO data VALUES (1, NULL);",
        )
        .await;

        let json = run_query(ReadQueryInput::new("SELECT * FROM data").db_path(db.key())).await;

        assert!(json["rows"][0][1].is_null());
    }
}