use std::path::Path;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use rusqlite::types::{Value as SqlValue, ValueRef};
use rusqlite::Connection;
use serde_json::{Map, Value};
use crate::sql::{translate_placeholders, ConnectorError, Dialect, SqlConnector, TranslatedSql};
/// [`SqlConnector`] implementation backed by a single rusqlite [`Connection`].
///
/// The connection sits behind `Arc<Mutex<..>>` so a clone of the handle can be
/// moved into each `tokio::task::spawn_blocking` closure. Every operation holds
/// the mutex for the whole of its blocking work, so operations on one
/// connector run one at a time.
pub struct SqliteConnector {
    // Shared handle; cloned into each blocking task and locked there.
    conn: Arc<Mutex<Connection>>,
}
impl SqliteConnector {
pub fn open(path: &Path) -> Result<Self, ConnectorError> {
let conn = Connection::open(path).map_err(|e| ConnectorError::Connection(e.to_string()))?;
Ok(Self {
conn: Arc::new(Mutex::new(conn)),
})
}
pub fn open_in_memory() -> Result<Self, ConnectorError> {
let conn =
Connection::open_in_memory().map_err(|e| ConnectorError::Connection(e.to_string()))?;
Ok(Self {
conn: Arc::new(Mutex::new(conn)),
})
}
pub async fn execute_batch(&self, sql: &str) -> Result<(), ConnectorError> {
let conn = Arc::clone(&self.conn);
let sql = sql.to_string();
tokio::task::spawn_blocking(move || -> Result<(), ConnectorError> {
let guard = conn
.lock()
.map_err(|_| ConnectorError::Driver("mutex poisoned".into()))?;
guard
.execute_batch(&sql)
.map_err(|e| ConnectorError::Query(e.to_string()))?;
Ok(())
})
.await
.map_err(|e| ConnectorError::Driver(format!("join error: {e}")))?
}
}
/// Converts a JSON parameter value into an owned SQLite value for binding.
///
/// - `null` binds as SQL `NULL`.
/// - Booleans bind as the integers `1` / `0`.
/// - Numbers representable as `i64` bind as integers. Any other number binds
///   as a real, or as `NULL` if no `f64` form is available.
/// - Strings bind as text.
/// - Arrays and objects bind as their serialized JSON text.
fn json_to_sql(v: &Value) -> SqlValue {
    match v {
        Value::Null => SqlValue::Null,
        Value::Bool(flag) => SqlValue::Integer(i64::from(*flag)),
        Value::Number(num) => num
            .as_i64()
            .map(SqlValue::Integer)
            .or_else(|| num.as_f64().map(SqlValue::Real))
            .unwrap_or(SqlValue::Null),
        Value::String(text) => SqlValue::Text(text.to_owned()),
        Value::Array(_) | Value::Object(_) => SqlValue::Text(v.to_string()),
    }
}
/// Converts a borrowed SQLite column value into a JSON value.
///
/// - Integers map to JSON numbers.
/// - Reals that JSON cannot represent (those rejected by
///   `serde_json::Number::from_f64`, such as NaN and infinity) become `null`.
/// - Text is decoded as UTF-8 lossily.
/// - Blobs are replaced by the placeholder string `"<blob>"`. Their bytes are
///   not exposed.
fn sql_to_json(v: ValueRef<'_>) -> Value {
    match v {
        ValueRef::Null => Value::Null,
        ValueRef::Integer(n) => Value::from(n),
        ValueRef::Real(x) => match serde_json::Number::from_f64(x) {
            Some(num) => Value::Number(num),
            None => Value::Null,
        },
        ValueRef::Text(bytes) => Value::String(String::from_utf8_lossy(bytes).into_owned()),
        ValueRef::Blob(_) => Value::String(String::from("<blob>")),
    }
}
fn bind_params(
stmt: &mut rusqlite::Statement<'_>,
ordered_params: &[String],
named_params: &[(String, Value)],
) -> Result<(), ConnectorError> {
for name in ordered_params {
let Some((_, val)) = named_params.iter().find(|(n, _)| n == name) else {
continue;
};
let bind_name = format!(":{name}");
let idx = stmt
.parameter_index(&bind_name)
.map_err(|e| ConnectorError::ParameterBind {
name: name.clone(),
reason: e.to_string(),
})?;
if let Some(idx) = idx {
stmt.raw_bind_parameter(idx, json_to_sql(val))
.map_err(|e| ConnectorError::ParameterBind {
name: name.clone(),
reason: e.to_string(),
})?;
}
}
Ok(())
}
fn collect_rows(stmt: &mut rusqlite::Statement<'_>) -> Result<Vec<Value>, ConnectorError> {
let cols: Vec<String> = stmt
.column_names()
.iter()
.map(|c| (*c).to_string())
.collect();
let mut rows = stmt.raw_query();
let mut out = Vec::new();
while let Some(row) = rows
.next()
.map_err(|e| ConnectorError::Query(e.to_string()))?
{
let mut obj = Map::new();
for (i, col) in cols.iter().enumerate() {
let vr = row
.get_ref(i)
.map_err(|e| ConnectorError::Query(e.to_string()))?;
obj.insert(col.clone(), sql_to_json(vr));
}
out.push(Value::Object(obj));
}
Ok(out)
}
#[async_trait]
impl SqlConnector for SqliteConnector {
    /// Always reports [`Dialect::Sqlite`].
    fn dialect(&self) -> Dialect {
        Dialect::Sqlite
    }

    /// Runs one SQL statement with named parameters and returns its result
    /// rows as JSON objects keyed by column name.
    ///
    /// These steps run on tokio's blocking thread pool:
    /// 1. Rewrite placeholders with [`translate_placeholders`] for the SQLite
    ///    dialect.
    /// 2. Prepare the statement under the connection mutex.
    /// 3. Bind `params` by name.
    /// 4. Step the statement to completion.
    ///
    /// Statements that return no rows (DDL, `INSERT`, ...) yield an empty vector.
    ///
    /// # Errors
    /// - [`ConnectorError::Driver`] if the mutex is poisoned or the blocking
    ///   task cannot be joined.
    /// - [`ConnectorError::Query`] if preparing or stepping the statement fails.
    /// - [`ConnectorError::ParameterBind`] if a parameter cannot be bound.
    async fn execute(
        &self,
        sql: &str,
        params: &[(String, Value)],
    ) -> Result<Vec<Value>, ConnectorError> {
        let conn = Arc::clone(&self.conn);
        // `spawn_blocking` requires `'static` data, so move owned copies in.
        let sql = sql.to_string();
        let params = params.to_vec();
        tokio::task::spawn_blocking(move || -> Result<Vec<Value>, ConnectorError> {
            let TranslatedSql {
                sql: translated,
                ordered_params,
            } = translate_placeholders(&sql, Dialect::Sqlite);
            let guard = conn
                .lock()
                .map_err(|_| ConnectorError::Driver("mutex poisoned".into()))?;
            let mut stmt = guard
                .prepare(&translated)
                .map_err(|e| ConnectorError::Query(e.to_string()))?;
            bind_params(&mut stmt, &ordered_params, &params)?;
            collect_rows(&mut stmt)
        })
        .await
        .map_err(|e| ConnectorError::Driver(format!("join error: {e}")))?
    }

    /// Returns the stored DDL of every table and view in `sqlite_master`.
    ///
    /// Entries are ordered by name and each is terminated with `";\n"`. Objects
    /// whose `sql` column is `NULL` are omitted, and indexes and triggers are
    /// excluded by the `type` filter.
    ///
    /// # Errors
    /// - [`ConnectorError::Driver`] if the mutex is poisoned or the blocking
    ///   task cannot be joined.
    /// - [`ConnectorError::Schema`] if querying `sqlite_master` fails.
    async fn schema_text(&self) -> Result<String, ConnectorError> {
        let conn = Arc::clone(&self.conn);
        tokio::task::spawn_blocking(move || -> Result<String, ConnectorError> {
            let guard = conn
                .lock()
                .map_err(|_| ConnectorError::Driver("mutex poisoned".into()))?;
            let mut stmt = guard
                .prepare(
                    "SELECT name, sql FROM sqlite_master \
                     WHERE type IN ('table', 'view') AND sql IS NOT NULL \
                     ORDER BY name",
                )
                .map_err(|e| ConnectorError::Schema(e.to_string()))?;
            let mut rows = stmt
                .query([])
                .map_err(|e| ConnectorError::Schema(e.to_string()))?;
            let mut out = String::new();
            while let Some(row) = rows
                .next()
                .map_err(|e| ConnectorError::Schema(e.to_string()))?
            {
                // Column 1 is the DDL text; column 0 (name) only drives ordering.
                let ddl: String = row
                    .get(1)
                    .map_err(|e| ConnectorError::Schema(e.to_string()))?;
                out.push_str(&ddl);
                out.push_str(";\n");
            }
            Ok(out)
        })
        .await
        .map_err(|e| ConnectorError::Driver(format!("join error: {e}")))?
    }
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for [`SqliteConnector`]; each test uses its own
    //! freshly opened in-memory database.
    use super::*;
    use serde_json::json;

    /// Opens a new, empty in-memory connector for a single test.
    fn memory_db() -> SqliteConnector {
        SqliteConnector::open_in_memory().unwrap()
    }

    #[tokio::test]
    async fn test_open_in_memory_succeeds() {
        let opened = SqliteConnector::open_in_memory();
        assert!(opened.is_ok(), "open_in_memory must succeed");
    }

    #[tokio::test]
    async fn test_dialect_returns_sqlite() {
        assert_eq!(memory_db().dialect(), Dialect::Sqlite);
    }

    #[tokio::test]
    async fn test_execute_no_params() {
        let db = memory_db();
        let rows = db.execute("SELECT 1 AS x", &[]).await.unwrap();
        assert_eq!(rows, [json!({ "x": 1 })]);
    }

    #[tokio::test]
    async fn test_execute_with_named_param() {
        let db = memory_db();
        let params = [("v".to_string(), json!(42))];
        let rows = db.execute("SELECT :v AS x", &params).await.unwrap();
        assert_eq!(rows, [json!({ "x": 42 })]);
    }

    #[tokio::test]
    async fn test_schema_text_returns_ddl() {
        let db = memory_db();
        db.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)", &[])
            .await
            .unwrap();
        let schema = db.schema_text().await.unwrap();
        assert!(
            schema.contains("CREATE TABLE users"),
            "schema_text must echo sqlite_master DDL verbatim; got: {schema:?}"
        );
    }

    #[tokio::test]
    async fn test_execute_after_insert_returns_rows() {
        let db = memory_db();
        let setup = [
            "CREATE TABLE users (id INTEGER, name TEXT)",
            "INSERT INTO users VALUES (1, 'Ada')",
        ];
        for statement in setup {
            db.execute(statement, &[]).await.unwrap();
        }
        let rows = db.execute("SELECT name FROM users", &[]).await.unwrap();
        assert_eq!(rows, [json!({ "name": "Ada" })]);
    }

    #[tokio::test]
    async fn test_execute_batch_seeds_multiple_tables() {
        let db = memory_db();
        let seed = concat!(
            "CREATE TABLE artists (id INTEGER, name TEXT);\n",
            "CREATE TABLE albums (id INTEGER, title TEXT);\n",
            "INSERT INTO artists VALUES (1, 'AC-DC');\n",
            "INSERT INTO albums VALUES (1, 'For Those About To Rock');\n",
            "INSERT INTO albums VALUES (2, 'Let There Be Rock');"
        );
        db.execute_batch(seed).await.unwrap();

        let artists = db.execute("SELECT name FROM artists", &[]).await.unwrap();
        assert_eq!(artists, [json!({ "name": "AC-DC" })]);

        let albums = db
            .execute("SELECT COUNT(*) AS c FROM albums", &[])
            .await
            .unwrap();
        assert_eq!(albums, [json!({ "c": 2 })]);
    }

    #[tokio::test]
    async fn test_execute_batch_invalid_statement_returns_query_error() {
        let db = memory_db();
        let outcome = db
            .execute_batch("CREATE TABLE ok (id INTEGER); NOT VALID SQL;")
            .await;
        let err =
            outcome.expect_err("a syntactically-invalid batch statement must return Err, not panic");
        assert!(
            matches!(err, ConnectorError::Query(_)),
            "expected ConnectorError::Query, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn test_execute_batch_idempotent_second_run_leaves_seeded_rows() {
        let db = memory_db();
        let bootstrap = concat!(
            "CREATE TABLE IF NOT EXISTS t (id INTEGER PRIMARY KEY);\n",
            "INSERT OR IGNORE INTO t VALUES (1);\n",
            "INSERT OR IGNORE INTO t VALUES (2);"
        );
        let run_labels = [
            "first bootstrap run succeeds",
            "second bootstrap run against persisted DB succeeds (idempotent)",
        ];
        for label in run_labels {
            db.execute_batch(bootstrap).await.expect(label);
        }
        let rows = db
            .execute("SELECT COUNT(*) AS c FROM t", &[])
            .await
            .unwrap();
        assert_eq!(
            rows,
            [json!({ "c": 2 })],
            "idempotent batch must leave exactly the seeded rows after a second run"
        );
    }

    #[tokio::test]
    async fn test_concurrent_executes_serialize_via_mutex() {
        let shared = Arc::new(memory_db());
        // Spawns one task running `sql` against its own clone of the shared connector.
        let spawn_select = |sql: &'static str| {
            let db = Arc::clone(&shared);
            tokio::spawn(async move { db.execute(sql, &[]).await })
        };
        let first = spawn_select("SELECT 1 AS x");
        let second = spawn_select("SELECT 2 AS x");
        let (first, second) = tokio::join!(first, second);
        assert_eq!(first.unwrap().unwrap(), [json!({ "x": 1 })]);
        assert_eq!(second.unwrap().unwrap(), [json!({ "x": 2 })]);
    }
}