Skip to main content

database_mcp/db/
sqlite.rs

1//! `SQLite` backend implementation via sqlx.
2//!
3//! Implements [`DatabaseBackend`] for `SQLite` file-based databases.
4
5use crate::config::DatabaseConfig;
6use crate::db::backend::DatabaseBackend;
7use crate::db::identifier::validate_identifier;
8use crate::error::AppError;
9use serde_json::{Value, json};
10use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteRow};
11use sqlx::{Row, SqlitePool};
12use sqlx_to_json::RowExt;
13use std::collections::HashMap;
14use tracing::info;
15
16/// Converts [`DatabaseConfig`] into [`SqliteConnectOptions`].
17impl From<&DatabaseConfig> for SqliteConnectOptions {
18    fn from(config: &DatabaseConfig) -> Self {
19        let name = config.name.as_deref().unwrap_or_default();
20        SqliteConnectOptions::new().filename(name)
21    }
22}
23
/// `SQLite` file-based database backend.
///
/// Pairs the connection pool that every query in this file runs against with
/// the read-only flag reported by the backend.
#[derive(Clone)]
pub struct SqliteBackend {
    /// Pool used to execute all statements issued by this backend.
    pool: SqlitePool,
    /// Read-only flag; returned as-is by the backend's `read_only()` accessor.
    pub read_only: bool,
}
30
31impl std::fmt::Debug for SqliteBackend {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("SqliteBackend")
34            .field("read_only", &self.read_only)
35            .finish_non_exhaustive()
36    }
37}
38
39impl SqliteBackend {
40    /// Creates a lazy in-memory backend for tests.
41    #[cfg(test)]
42    pub(crate) fn in_memory(read_only: bool) -> Self {
43        let pool = sqlx::sqlite::SqlitePoolOptions::new()
44            .max_connections(1)
45            .connect_lazy("sqlite::memory:")
46            .expect("in-memory SQLite");
47        Self { pool, read_only }
48    }
49
50    /// Creates a new `SQLite` backend from configuration.
51    ///
52    /// # Errors
53    ///
54    /// Returns [`AppError::Connection`] if the database file cannot be opened.
55    pub async fn new(config: &DatabaseConfig) -> Result<Self, AppError> {
56        let name = config.name.as_deref().unwrap_or_default();
57        let pool = SqlitePoolOptions::new()
58            .max_connections(1) // SQLite is single-writer
59            .connect_with(config.into())
60            .await
61            .map_err(|e| AppError::Connection(format!("Failed to open SQLite: {e}")))?;
62
63        info!("SQLite connection initialized: {name}");
64
65        Ok(Self {
66            pool,
67            read_only: config.read_only,
68        })
69    }
70}
71
72impl SqliteBackend {
73    /// Wraps `name` in double quotes for safe use in `SQLite` SQL statements.
74    ///
75    /// Escapes internal double quotes by doubling them.
76    fn quote_identifier(name: &str) -> String {
77        let escaped = name.replace('"', "\"\"");
78        format!("\"{escaped}\"")
79    }
80}
81
82impl DatabaseBackend for SqliteBackend {
83    #[allow(clippy::unused_async)]
84    async fn list_databases(&self) -> Result<Vec<String>, AppError> {
85        // SQLite has one database: "main"
86        Ok(vec!["main".to_string()])
87    }
88
89    async fn list_tables(&self, _database: &str) -> Result<Vec<String>, AppError> {
90        let rows: Vec<(String,)> = sqlx::query_as(
91            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name",
92        )
93        .fetch_all(&self.pool)
94        .await
95        .map_err(|e| AppError::Query(e.to_string()))?;
96        Ok(rows.into_iter().map(|r| r.0).collect())
97    }
98
99    async fn get_table_schema(&self, _database: &str, table: &str) -> Result<Value, AppError> {
100        validate_identifier(table)?;
101        let rows: Vec<SqliteRow> = sqlx::query(&format!("PRAGMA table_info({})", Self::quote_identifier(table)))
102            .fetch_all(&self.pool)
103            .await
104            .map_err(|e| AppError::Query(e.to_string()))?;
105
106        if rows.is_empty() {
107            return Err(AppError::TableNotFound(table.to_string()));
108        }
109
110        let mut schema: HashMap<String, Value> = HashMap::new();
111        for row in &rows {
112            let col_name: String = row.try_get("name").unwrap_or_default();
113            let col_type: String = row.try_get("type").unwrap_or_default();
114            let notnull: i32 = row.try_get("notnull").unwrap_or(0);
115            let default: Option<String> = row.try_get("dflt_value").ok();
116            let pk: i32 = row.try_get("pk").unwrap_or(0);
117            schema.insert(
118                col_name,
119                json!({
120                    "type": col_type,
121                    "nullable": notnull == 0,
122                    "key": if pk > 0 { "PRI" } else { "" },
123                    "default": default,
124                    "extra": Value::Null,
125                }),
126            );
127        }
128        Ok(json!(schema))
129    }
130
131    async fn get_table_schema_with_relations(&self, database: &str, table: &str) -> Result<Value, AppError> {
132        let schema = self.get_table_schema(database, table).await?;
133        let mut columns: HashMap<String, Value> = serde_json::from_value(schema).unwrap_or_default();
134
135        // Add null foreign_key to all columns
136        for col in columns.values_mut() {
137            if let Some(obj) = col.as_object_mut() {
138                obj.entry("foreign_key".to_string()).or_insert(Value::Null);
139            }
140        }
141
142        // Get FK info via PRAGMA
143        let fk_rows: Vec<SqliteRow> =
144            sqlx::query(&format!("PRAGMA foreign_key_list({})", Self::quote_identifier(table)))
145                .fetch_all(&self.pool)
146                .await
147                .map_err(|e| AppError::Query(e.to_string()))?;
148
149        for fk_row in &fk_rows {
150            let from_col: String = fk_row.try_get("from").unwrap_or_default();
151            if let Some(col_info) = columns.get_mut(&from_col)
152                && let Some(obj) = col_info.as_object_mut()
153            {
154                let ref_table: String = fk_row.try_get("table").unwrap_or_default();
155                let ref_col: String = fk_row.try_get("to").unwrap_or_default();
156                let on_update: String = fk_row.try_get("on_update").unwrap_or_default();
157                let on_delete: String = fk_row.try_get("on_delete").unwrap_or_default();
158                obj.insert(
159                    "foreign_key".to_string(),
160                    json!({
161                        "constraint_name": Value::Null,
162                        "referenced_table": ref_table,
163                        "referenced_column": ref_col,
164                        "on_update": on_update,
165                        "on_delete": on_delete,
166                    }),
167                );
168            }
169        }
170
171        Ok(json!({
172            "table_name": table,
173            "columns": columns,
174        }))
175    }
176
177    async fn execute_query(&self, sql: &str, _database: Option<&str>) -> Result<Value, AppError> {
178        let rows: Vec<SqliteRow> = sqlx::query(sql)
179            .fetch_all(&self.pool)
180            .await
181            .map_err(|e| AppError::Query(e.to_string()))?;
182        Ok(Value::Array(rows.iter().map(RowExt::to_json).collect()))
183    }
184
185    #[allow(clippy::unused_async)]
186    async fn create_database(&self, _name: &str) -> Result<Value, AppError> {
187        Ok(json!({
188            "status": "unsupported",
189            "message": "SQLite does not support creating databases. Use --db-path to specify the database file.",
190        }))
191    }
192
193    fn dialect(&self) -> Box<dyn sqlparser::dialect::Dialect> {
194        Box::new(sqlparser::dialect::SQLiteDialect {})
195    }
196
197    fn read_only(&self) -> bool {
198        self.read_only
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::DatabaseBackend;
    // The config enum imported above shadows the backend trait's name, so the
    // trait is imported anonymously to make `execute_query` callable.
    use crate::db::backend::DatabaseBackend as _;

    #[test]
    fn quote_identifier_wraps_in_double_quotes() {
        assert_eq!(SqliteBackend::quote_identifier("users"), "\"users\"");
        assert_eq!(SqliteBackend::quote_identifier("eu-docker"), "\"eu-docker\"");
    }

    #[test]
    fn quote_identifier_escapes_double_quotes() {
        assert_eq!(SqliteBackend::quote_identifier("test\"db"), "\"test\"\"db\"");
        assert_eq!(SqliteBackend::quote_identifier("a\"b\"c"), "\"a\"\"b\"\"c\"");
    }

    /// The `From<&DatabaseConfig>` conversion uses `name` as the filename.
    #[test]
    fn try_from_sets_filename() {
        let config = DatabaseConfig {
            backend: DatabaseBackend::Sqlite,
            name: Some("test.db".into()),
            ..DatabaseConfig::default()
        };
        let opts = SqliteConnectOptions::from(&config);

        assert_eq!(opts.get_filename().to_str().expect("valid path"), "test.db");
    }

    /// The `From<&DatabaseConfig>` conversion falls back to an empty filename.
    #[test]
    fn try_from_empty_name_defaults() {
        let config = DatabaseConfig {
            backend: DatabaseBackend::Sqlite,
            name: None,
            ..DatabaseConfig::default()
        };
        let opts = SqliteConnectOptions::from(&config);

        // Empty string filename — validated elsewhere by Config::validate()
        assert_eq!(opts.get_filename().to_str().expect("valid path"), "");
    }

    // Row-to-JSON conversion tests live in crates/sqlx_to_json.
    // These tests cover the array-level wrapping done by execute_query.

    /// Helper: creates an in-memory `SQLite` pool for unit tests.
    async fn mem_pool() -> SqlitePool {
        SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("in-memory SQLite")
    }

    /// Helper: runs `sql` through [`SqliteBackend::execute_query`] on `pool`.
    ///
    /// Going through the backend, rather than re-implementing the row
    /// conversion here, makes the tests below fail if `execute_query` itself
    /// regresses.
    async fn query_json(pool: &SqlitePool, sql: &str) -> Value {
        // Cloning the pool yields a handle to the same single connection, so
        // the backend sees the tables the test created.
        let backend = SqliteBackend {
            pool: pool.clone(),
            read_only: false,
        };
        match backend.execute_query(sql, None).await {
            Ok(rows) => rows,
            Err(_) => panic!("query failed"),
        }
    }

    #[tokio::test]
    async fn execute_query_empty_result() {
        let pool = mem_pool().await;
        sqlx::query("CREATE TABLE t (v INTEGER)").execute(&pool).await.unwrap();

        let rows = query_json(&pool, "SELECT v FROM t").await;
        assert_eq!(rows, Value::Array(vec![]));
    }

    #[tokio::test]
    async fn execute_query_multiple_rows() {
        let pool = mem_pool().await;
        sqlx::query("CREATE TABLE t (id INTEGER, name TEXT, score REAL)")
            .execute(&pool)
            .await
            .unwrap();
        sqlx::query("INSERT INTO t VALUES (1, 'alice', 9.5), (2, 'bob', 8.0)")
            .execute(&pool)
            .await
            .unwrap();

        let rows = query_json(&pool, "SELECT id, name, score FROM t ORDER BY id").await;
        assert_eq!(rows.as_array().expect("should be array").len(), 2);

        assert_eq!(rows[0]["id"], Value::Number(1.into()));
        assert_eq!(rows[0]["name"], Value::String("alice".into()));
        assert!(rows[0]["score"].is_number());

        assert_eq!(rows[1]["id"], Value::Number(2.into()));
        assert_eq!(rows[1]["name"], Value::String("bob".into()));
    }
}