1use std::sync::{Arc, OnceLock};
15
16use anyhow::{Context, Result};
17use rusqlite::Connection;
18use tokio::sync::Mutex;
19use tokio::task::spawn_blocking;
20
21use crate::backends::base::{BackendDialect, ColSpec};
22
/// SQLite storage backend, configured by the *name* of an environment variable.
///
/// The variable itself is only read when a connection is opened, so constructing
/// a backend never touches the environment or the filesystem.
#[derive(Clone, Debug)]
pub struct SQLiteBackend {
    /// Name of the environment variable that holds the database path
    /// (not the path itself).
    pub(crate) dsn_env: String,
}

impl SQLiteBackend {
    /// Creates a backend that will look up its database path in the
    /// environment variable named `dsn_env`.
    pub fn new(dsn_env: String) -> Self {
        Self { dsn_env }
    }
}
33
34fn register_sqlite_vec_once() {
38 static ONCE: OnceLock<()> = OnceLock::new();
39 ONCE.get_or_init(|| {
40 unsafe {
44 let _ = rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute(
45 sqlite_vec::sqlite3_vec_init as *const (),
46 )));
47 }
48 });
49}
50
51pub type SqliteConn = Arc<Mutex<Connection>>;
56
57impl SQLiteBackend {
58 pub async fn connect(&self) -> Result<SqliteConn> {
63 let dsn_env = self.dsn_env.clone();
64 spawn_blocking(move || -> Result<SqliteConn> {
65 register_sqlite_vec_once();
66 let path = std::env::var(&dsn_env)
67 .with_context(|| format!("DSN env var {dsn_env} not set"))?;
68 let conn = if path == ":memory:" {
69 Connection::open_in_memory().context("open :memory:")?
70 } else {
71 Connection::open(&path).with_context(|| format!("opening {path}"))?
72 };
73 let _ = conn.pragma_update(None, "journal_mode", &"WAL");
75 Ok(Arc::new(Mutex::new(conn)))
76 })
77 .await
78 .context("spawn_blocking connect")?
79 }
80
81 pub async fn table_exists(&self, conn: &SqliteConn, _db: &str, table: &str) -> Result<bool> {
84 let conn = conn.clone();
85 let table = table.to_string();
86 spawn_blocking(move || -> Result<bool> {
87 let g = conn.blocking_lock();
88 let r: Option<i32> = g
89 .query_row(
90 "SELECT 1 FROM sqlite_master WHERE type IN ('table','virtual table') AND name=?",
91 rusqlite::params![table],
92 |row| row.get(0),
93 )
94 .ok();
95 Ok(r.is_some())
96 })
97 .await
98 .context("spawn_blocking table_exists")?
99 }
100
101 pub async fn embedding_dim(
105 &self,
106 conn: &SqliteConn,
107 _db: &str,
108 table: &str,
109 ) -> Result<Option<usize>> {
110 let conn = conn.clone();
111 let vec_table = format!("{table}_vec");
112 spawn_blocking(move || -> Result<Option<usize>> {
113 let g = conn.blocking_lock();
114 let sql: Option<String> = g
115 .query_row(
116 "SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
117 rusqlite::params![vec_table],
118 |row| row.get(0),
119 )
120 .ok();
121 let Some(sql) = sql else { return Ok(None) };
122 let re = regex::Regex::new(r"(?i)FLOAT\[(\d+)\]").unwrap();
123 Ok(re
124 .captures(&sql)
125 .and_then(|c| c.get(1))
126 .and_then(|m| m.as_str().parse().ok()))
127 })
128 .await
129 .context("spawn_blocking embedding_dim")?
130 }
131
132 pub async fn with_create_lock(&self, _conn: &SqliteConn, _key: &str) -> Result<()> {
134 Ok(())
135 }
136}
137
138impl BackendDialect for SQLiteBackend {
139 const NAME: &'static str = "sqlite";
140 const SUPPORTS_UPSERT: bool = true;
141
142 fn quote_ident(&self, name: &str) -> String {
143 format!("\"{}\"", name.replace('"', "\"\""))
144 }
145
146 fn fq_table(&self, _db: &str, table: &str) -> String {
147 self.quote_ident(table)
149 }
150
151 fn vector_type_ddl(&self, dim: usize) -> String {
152 format!("FLOAT[{dim}]")
153 }
154 fn json_type_ddl(&self) -> String {
155 "TEXT".to_string()
156 }
157 fn tags_array_type_ddl(&self) -> String {
158 "TEXT".to_string()
159 }
160 fn text_pk_type_ddl(&self) -> String {
161 "TEXT".to_string()
162 }
163 fn timestamp_now_default_ddl(&self) -> String {
164 "TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP".to_string()
165 }
166
167 fn vector_literal(&self, arr: &[f32]) -> String {
168 let v: Vec<f64> = arr.iter().map(|x| *x as f64).collect();
172 serde_json::to_string(&v).unwrap_or_else(|_| "[]".to_string())
173 }
174
175 fn json_literal(&self, obj: &serde_json::Value) -> String {
176 serde_json::to_string(obj).unwrap_or_else(|_| "null".to_string())
177 }
178
179 fn json_path_sql(&self, col_expr: &str, dotted_path: &str) -> String {
180 format!("json_extract({col_expr},'$.{dotted_path}')")
181 }
182
183 fn upsert_clause(&self, key_cols: &[&str], update_cols: &[&str]) -> String {
184 let keys: Vec<String> = key_cols.iter().map(|c| self.quote_ident(c)).collect();
185 let keys_sql = keys.join(", ");
186 if update_cols.is_empty() {
187 return format!("ON CONFLICT ({keys_sql}) DO NOTHING");
188 }
189 let sets: Vec<String> = update_cols
190 .iter()
191 .map(|c| format!("{q} = excluded.{q}", q = self.quote_ident(c)))
192 .collect();
193 format!("ON CONFLICT ({keys_sql}) DO UPDATE SET {}", sets.join(", "))
194 }
195
196 fn create_database_sql(&self, _name: &str) -> String {
197 "SELECT 1 -- chunkshop: SQLite has no database/schema concept".to_string()
198 }
199
200 fn add_column_if_not_exists_sql(&self, fq: &str, col: &str, type_ddl: &str) -> String {
201 format!(
204 "ALTER TABLE {fq} ADD COLUMN {} {type_ddl}",
205 self.quote_ident(col)
206 )
207 }
208
209 fn drop_table_sql(&self, fq: &str) -> String {
210 format!("DROP TABLE {fq}")
211 }
212
213 fn emit_chunks_table_ddl(
214 &self,
215 fq: &str,
216 cols: &[ColSpec],
217 _hnsw: bool,
218 dim: usize,
219 _engine: Option<&str>,
220 _vector_metric: Option<&str>,
221 ) -> Vec<String> {
222 let main_cols: Vec<&ColSpec> = cols.iter().filter(|c| c.name != "embedding").collect();
225
226 let mut col_lines: Vec<String> = Vec::with_capacity(main_cols.len());
227 let mut pk_cols: Vec<&str> = Vec::new();
228 for c in &main_cols {
229 let mut line = format!(" {} {}", self.quote_ident(c.name), c.type_ddl);
230 if let Some(default) = c.default {
231 line.push_str(&format!(" DEFAULT {default}"));
232 }
233 if !c.nullable {
234 line.push_str(" NOT NULL");
235 }
236 col_lines.push(line);
237 if c.is_primary_key {
238 pk_cols.push(c.name);
239 }
240 }
241 let mut body = col_lines.join(",\n");
242 if !pk_cols.is_empty() {
243 let pk: Vec<String> = pk_cols.iter().map(|c| self.quote_ident(c)).collect();
244 body.push_str(&format!(",\n PRIMARY KEY ({})", pk.join(", ")));
245 }
246 let create_main = format!("CREATE TABLE IF NOT EXISTS {fq} (\n{body}\n)");
247
248 let bare = fq.trim_matches('"').to_string();
250
251 let create_idx = format!(
252 "CREATE INDEX IF NOT EXISTS {} ON {fq} (\"doc_id\", \"seq_num\")",
253 self.quote_ident(&format!("{bare}_doc_seq_idx"))
254 );
255
256 let vec_fq = self.quote_ident(&format!("{bare}_vec"));
257 let create_vec = format!(
258 "CREATE VIRTUAL TABLE IF NOT EXISTS {vec_fq} USING vec0(\
259 id TEXT PRIMARY KEY, embedding FLOAT[{dim}])"
260 );
261
262 vec![create_main, create_idx, create_vec]
263 }
264}
265
#[cfg(test)]
mod tests {
    //! Unit tests for the pure SQL-building half of the SQLite dialect.
    //! None of them open a database connection.

    use super::*;
    use crate::backends::base::ColSpec;

    /// Backend whose DSN variable is never read by these tests.
    fn backend() -> SQLiteBackend {
        SQLiteBackend::new("UNUSED".to_string())
    }

    /// Non-nullable column without a default value.
    fn col(name: &'static str, type_ddl: &str, is_primary_key: bool) -> ColSpec {
        ColSpec {
            name,
            type_ddl: type_ddl.into(),
            nullable: false,
            default: None,
            is_primary_key,
        }
    }

    /// Minimal chunks schema: id (PK), doc_id, seq_num, embedding.
    fn canonical_cols(dim: usize) -> Vec<ColSpec> {
        vec![
            col("id", "TEXT", true),
            col("doc_id", "TEXT", false),
            col("seq_num", "INTEGER", false),
            col("embedding", &format!("FLOAT[{dim}]"), false),
        ]
    }

    #[test]
    fn quote_ident_wraps_in_double_quotes() {
        assert_eq!(backend().quote_ident("my_table"), "\"my_table\"");
    }

    #[test]
    fn quote_ident_doubles_embedded_quote() {
        assert_eq!(backend().quote_ident("a\"b"), "\"a\"\"b\"");
    }

    #[test]
    fn fq_table_returns_table_only_no_schema() {
        assert_eq!(backend().fq_table("ignored", "my_table"), "\"my_table\"");
    }

    #[test]
    fn vector_type_ddl_uses_float_brackets() {
        assert_eq!(backend().vector_type_ddl(384), "FLOAT[384]");
    }

    #[test]
    fn json_type_is_text() {
        assert_eq!(backend().json_type_ddl(), "TEXT");
    }

    #[test]
    fn tags_array_type_is_text() {
        assert_eq!(backend().tags_array_type_ddl(), "TEXT");
    }

    #[test]
    fn text_pk_type_is_text() {
        assert_eq!(backend().text_pk_type_ddl(), "TEXT");
    }

    #[test]
    fn timestamp_default_is_current_timestamp() {
        assert_eq!(
            backend().timestamp_now_default_ddl(),
            "TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP"
        );
    }

    #[test]
    fn vector_literal_matches_python_json_array() {
        let input = [0.1_f32, 0.2_f32, -0.3_f32];
        let lit = backend().vector_literal(&input);
        let parsed: serde_json::Value = serde_json::from_str(&lit).unwrap();
        let arr = parsed.as_array().unwrap();
        assert_eq!(arr.len(), input.len());
        // Every element is checked, including the middle one.
        for (got, want) in arr.iter().zip([0.1_f64, 0.2, -0.3]) {
            assert!((got.as_f64().unwrap() - want).abs() < 1e-6);
        }
    }

    #[test]
    fn json_path_sql_uses_json_extract_with_dollar_dot() {
        assert_eq!(
            backend().json_path_sql("metadata", "a.b.c"),
            "json_extract(metadata,'$.a.b.c')"
        );
    }

    #[test]
    fn upsert_clause_do_nothing_when_no_updates() {
        assert_eq!(
            backend().upsert_clause(&["id"], &[]),
            "ON CONFLICT (\"id\") DO NOTHING"
        );
    }

    #[test]
    fn upsert_clause_joins_composite_keys() {
        assert_eq!(
            backend().upsert_clause(&["doc_id", "seq_num"], &[]),
            "ON CONFLICT (\"doc_id\", \"seq_num\") DO NOTHING"
        );
    }

    #[test]
    fn upsert_clause_excluded_form() {
        assert_eq!(
            backend().upsert_clause(&["id"], &["content", "metadata"]),
            "ON CONFLICT (\"id\") DO UPDATE SET \"content\" = excluded.\"content\", \
             \"metadata\" = excluded.\"metadata\""
        );
    }

    #[test]
    fn create_database_sql_is_noop_select() {
        assert_eq!(
            backend().create_database_sql("ignored"),
            "SELECT 1 -- chunkshop: SQLite has no database/schema concept"
        );
    }

    #[test]
    fn add_column_lacks_if_not_exists() {
        assert_eq!(
            backend().add_column_if_not_exists_sql("\"chunks\"", "source", "TEXT"),
            "ALTER TABLE \"chunks\" ADD COLUMN \"source\" TEXT"
        );
    }

    #[test]
    fn drop_table_sql_is_plain_drop() {
        assert_eq!(
            backend().drop_table_sql("\"chunks\""),
            "DROP TABLE \"chunks\""
        );
    }

    #[test]
    fn emit_chunks_table_ddl_returns_three_statements() {
        let stmts = backend().emit_chunks_table_ddl(
            "\"chunks\"",
            &canonical_cols(384),
            false,
            384,
            None,
            None,
        );
        assert_eq!(stmts.len(), 3, "main table + index + vec0 virtual table");

        let (main, idx, vec_tbl) = (&stmts[0], &stmts[1], &stmts[2]);

        assert!(main.starts_with("CREATE TABLE IF NOT EXISTS \"chunks\""));
        assert!(main.contains("\"id\" TEXT NOT NULL"));
        assert!(main.contains("PRIMARY KEY (\"id\")"));
        assert!(!main.contains("\"embedding\" FLOAT"));

        assert!(idx.contains("CREATE INDEX IF NOT EXISTS \"chunks_doc_seq_idx\""));
        assert!(idx.ends_with("ON \"chunks\" (\"doc_id\", \"seq_num\")"));

        assert!(vec_tbl.starts_with("CREATE VIRTUAL TABLE IF NOT EXISTS \"chunks_vec\""));
        assert!(vec_tbl.contains("USING vec0("));
        assert!(vec_tbl.contains("FLOAT[384]"));
        assert!(vec_tbl.contains("id TEXT PRIMARY KEY, embedding FLOAT[384]"));
    }

    #[test]
    fn emit_chunks_table_ddl_hnsw_does_not_change_output() {
        let emit = |hnsw: bool| {
            backend().emit_chunks_table_ddl("\"c\"", &canonical_cols(8), hnsw, 8, None, None)
        };
        assert_eq!(emit(false), emit(true));
    }
}