1use std::path::Path;
44use std::sync::{Arc, Mutex};
45
46use async_trait::async_trait;
47use rusqlite::types::{Value as SqlValue, ValueRef};
48use rusqlite::Connection;
49use serde_json::{Map, Value};
50
51use crate::sql::{translate_placeholders, ConnectorError, Dialect, SqlConnector, TranslatedSql};
52
53pub struct SqliteConnector {
60 conn: Arc<Mutex<Connection>>,
61}
62
63impl SqliteConnector {
64 pub fn open(path: &Path) -> Result<Self, ConnectorError> {
73 let conn = Connection::open(path).map_err(|e| ConnectorError::Connection(e.to_string()))?;
74 Ok(Self {
75 conn: Arc::new(Mutex::new(conn)),
76 })
77 }
78
79 pub fn open_in_memory() -> Result<Self, ConnectorError> {
89 let conn =
90 Connection::open_in_memory().map_err(|e| ConnectorError::Connection(e.to_string()))?;
91 Ok(Self {
92 conn: Arc::new(Mutex::new(conn)),
93 })
94 }
95
96 pub async fn execute_batch(&self, sql: &str) -> Result<(), ConnectorError> {
140 let conn = Arc::clone(&self.conn);
141 let sql = sql.to_string();
142 tokio::task::spawn_blocking(move || -> Result<(), ConnectorError> {
143 let guard = conn
144 .lock()
145 .map_err(|_| ConnectorError::Driver("mutex poisoned".into()))?;
146 guard
147 .execute_batch(&sql)
148 .map_err(|e| ConnectorError::Query(e.to_string()))?;
149 Ok(())
150 })
151 .await
152 .map_err(|e| ConnectorError::Driver(format!("join error: {e}")))?
153 }
154}
155
156fn json_to_sql(v: &Value) -> SqlValue {
162 match v {
163 Value::Null => SqlValue::Null,
164 Value::Bool(b) => SqlValue::Integer(i64::from(*b)),
165 Value::Number(n) => {
166 if let Some(i) = n.as_i64() {
167 SqlValue::Integer(i)
168 } else if let Some(f) = n.as_f64() {
169 SqlValue::Real(f)
170 } else {
171 SqlValue::Null
172 }
173 },
174 Value::String(s) => SqlValue::Text(s.clone()),
175 _ => SqlValue::Text(v.to_string()),
176 }
177}
178
179fn sql_to_json(v: ValueRef<'_>) -> Value {
185 match v {
186 ValueRef::Null => Value::Null,
187 ValueRef::Integer(i) => Value::Number(i.into()),
188 ValueRef::Real(f) => serde_json::Number::from_f64(f)
189 .map(Value::Number)
190 .unwrap_or(Value::Null),
191 ValueRef::Text(t) => Value::String(String::from_utf8_lossy(t).into_owned()),
192 ValueRef::Blob(_) => Value::String("<blob>".into()),
193 }
194}
195
196fn bind_params(
203 stmt: &mut rusqlite::Statement<'_>,
204 ordered_params: &[String],
205 named_params: &[(String, Value)],
206) -> Result<(), ConnectorError> {
207 for name in ordered_params {
208 let Some((_, val)) = named_params.iter().find(|(n, _)| n == name) else {
209 continue;
210 };
211 let bind_name = format!(":{name}");
212 let idx = stmt
213 .parameter_index(&bind_name)
214 .map_err(|e| ConnectorError::ParameterBind {
215 name: name.clone(),
216 reason: e.to_string(),
217 })?;
218 if let Some(idx) = idx {
219 stmt.raw_bind_parameter(idx, json_to_sql(val))
220 .map_err(|e| ConnectorError::ParameterBind {
221 name: name.clone(),
222 reason: e.to_string(),
223 })?;
224 }
225 }
226 Ok(())
227}
228
229fn collect_rows(stmt: &mut rusqlite::Statement<'_>) -> Result<Vec<Value>, ConnectorError> {
232 let cols: Vec<String> = stmt
233 .column_names()
234 .iter()
235 .map(|c| (*c).to_string())
236 .collect();
237 let mut rows = stmt.raw_query();
238 let mut out = Vec::new();
239 while let Some(row) = rows
240 .next()
241 .map_err(|e| ConnectorError::Query(e.to_string()))?
242 {
243 let mut obj = Map::new();
244 for (i, col) in cols.iter().enumerate() {
245 let vr = row
246 .get_ref(i)
247 .map_err(|e| ConnectorError::Query(e.to_string()))?;
248 obj.insert(col.clone(), sql_to_json(vr));
249 }
250 out.push(Value::Object(obj));
251 }
252 Ok(out)
253}
254
255#[async_trait]
256impl SqlConnector for SqliteConnector {
257 fn dialect(&self) -> Dialect {
258 Dialect::Sqlite
259 }
260
261 async fn execute(
262 &self,
263 sql: &str,
264 params: &[(String, Value)],
265 ) -> Result<Vec<Value>, ConnectorError> {
266 let conn = Arc::clone(&self.conn);
267 let sql = sql.to_string();
268 let params = params.to_vec();
269 tokio::task::spawn_blocking(move || -> Result<Vec<Value>, ConnectorError> {
270 let TranslatedSql {
271 sql: translated,
272 ordered_params,
273 } = translate_placeholders(&sql, Dialect::Sqlite);
274 let guard = conn
277 .lock()
278 .map_err(|_| ConnectorError::Driver("mutex poisoned".into()))?;
279 let mut stmt = guard
280 .prepare(&translated)
281 .map_err(|e| ConnectorError::Query(e.to_string()))?;
282 bind_params(&mut stmt, &ordered_params, ¶ms)?;
283 collect_rows(&mut stmt)
284 })
285 .await
286 .map_err(|e| ConnectorError::Driver(format!("join error: {e}")))?
287 }
288
289 async fn schema_text(&self) -> Result<String, ConnectorError> {
290 let conn = Arc::clone(&self.conn);
291 tokio::task::spawn_blocking(move || -> Result<String, ConnectorError> {
292 let guard = conn
293 .lock()
294 .map_err(|_| ConnectorError::Driver("mutex poisoned".into()))?;
295 let mut stmt = guard
296 .prepare(
297 "SELECT name, sql FROM sqlite_master \
298 WHERE type IN ('table', 'view') AND sql IS NOT NULL \
299 ORDER BY name",
300 )
301 .map_err(|e| ConnectorError::Schema(e.to_string()))?;
302 let mut rows = stmt
303 .query([])
304 .map_err(|e| ConnectorError::Schema(e.to_string()))?;
305 let mut out = String::new();
306 while let Some(row) = rows
307 .next()
308 .map_err(|e| ConnectorError::Schema(e.to_string()))?
309 {
310 let ddl: String = row
311 .get(1)
312 .map_err(|e| ConnectorError::Schema(e.to_string()))?;
313 out.push_str(&ddl);
314 out.push_str(";\n");
315 }
316 Ok(out)
317 })
318 .await
319 .map_err(|e| ConnectorError::Driver(format!("join error: {e}")))?
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Opens a fresh private in-memory database for a single test.
    fn memory_db() -> SqliteConnector {
        SqliteConnector::open_in_memory().unwrap()
    }

    #[tokio::test]
    async fn test_open_in_memory_succeeds() {
        let opened = SqliteConnector::open_in_memory();
        assert!(opened.is_ok(), "open_in_memory must succeed");
    }

    #[tokio::test]
    async fn test_dialect_returns_sqlite() {
        assert_eq!(memory_db().dialect(), Dialect::Sqlite);
    }

    #[tokio::test]
    async fn test_execute_no_params() {
        let db = memory_db();
        let got = db.execute("SELECT 1 AS x", &[]).await.unwrap();
        assert_eq!(got, vec![json!({ "x": 1 })]);
    }

    #[tokio::test]
    async fn test_execute_with_named_param() {
        let db = memory_db();
        let params = [("v".to_string(), json!(42))];
        let got = db.execute("SELECT :v AS x", &params).await.unwrap();
        assert_eq!(got, vec![json!({ "x": 42 })]);
    }

    #[tokio::test]
    async fn test_schema_text_returns_ddl() {
        let db = memory_db();
        db.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)", &[])
            .await
            .unwrap();
        let schema = db.schema_text().await.unwrap();
        assert!(
            schema.contains("CREATE TABLE users"),
            "schema_text must echo sqlite_master DDL verbatim; got: {schema:?}"
        );
    }

    #[tokio::test]
    async fn test_execute_after_insert_returns_rows() {
        let db = memory_db();
        for statement in [
            "CREATE TABLE users (id INTEGER, name TEXT)",
            "INSERT INTO users VALUES (1, 'Ada')",
        ] {
            db.execute(statement, &[]).await.unwrap();
        }
        let names = db.execute("SELECT name FROM users", &[]).await.unwrap();
        assert_eq!(names, vec![json!({ "name": "Ada" })]);
    }

    #[tokio::test]
    async fn test_execute_batch_seeds_multiple_tables() {
        let db = memory_db();
        db.execute_batch(
            "CREATE TABLE artists (id INTEGER, name TEXT);
             CREATE TABLE albums (id INTEGER, title TEXT);
             INSERT INTO artists VALUES (1, 'AC-DC');
             INSERT INTO albums VALUES (1, 'For Those About To Rock');
             INSERT INTO albums VALUES (2, 'Let There Be Rock');",
        )
        .await
        .unwrap();

        let artist_rows = db.execute("SELECT name FROM artists", &[]).await.unwrap();
        assert_eq!(artist_rows, vec![json!({ "name": "AC-DC" })]);

        let album_count = db
            .execute("SELECT COUNT(*) AS c FROM albums", &[])
            .await
            .unwrap();
        assert_eq!(album_count, vec![json!({ "c": 2 })]);
    }

    #[tokio::test]
    async fn test_execute_batch_invalid_statement_returns_query_error() {
        let outcome = memory_db()
            .execute_batch("CREATE TABLE ok (id INTEGER); NOT VALID SQL;")
            .await;
        let err = outcome
            .expect_err("a syntactically-invalid batch statement must return Err, not panic");
        assert!(
            matches!(err, ConnectorError::Query(_)),
            "expected ConnectorError::Query, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn test_execute_batch_idempotent_second_run_leaves_seeded_rows() {
        let db = memory_db();
        let bootstrap = "CREATE TABLE IF NOT EXISTS t (id INTEGER PRIMARY KEY);
                         INSERT OR IGNORE INTO t VALUES (1);
                         INSERT OR IGNORE INTO t VALUES (2);";

        // Run the same bootstrap script twice against the same database.
        for run_label in [
            "first bootstrap run succeeds",
            "second bootstrap run against persisted DB succeeds (idempotent)",
        ] {
            db.execute_batch(bootstrap).await.expect(run_label);
        }

        let count = db
            .execute("SELECT COUNT(*) AS c FROM t", &[])
            .await
            .unwrap();
        assert_eq!(
            count,
            vec![json!({ "c": 2 })],
            "idempotent batch must leave exactly the seeded rows after a second run"
        );
    }

    #[tokio::test]
    async fn test_concurrent_executes_serialize_via_mutex() {
        let db = Arc::new(memory_db());
        // Each spawned task gets its own handle clone and runs one query.
        let spawn_select = |sql: &'static str| {
            let db = Arc::clone(&db);
            tokio::spawn(async move { db.execute(sql, &[]).await })
        };
        let (first, second) =
            tokio::join!(spawn_select("SELECT 1 AS x"), spawn_select("SELECT 2 AS x"));
        assert_eq!(first.unwrap().unwrap(), vec![json!({ "x": 1 })]);
        assert_eq!(second.unwrap().unwrap(), vec![json!({ "x": 2 })]);
    }
}