1use rig::OneOrMany;
2use rig::embeddings::{Embedding, EmbeddingModel};
3use rig::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
4use rig::vector_store::{VectorStoreError, VectorStoreIndex};
5use rusqlite::types::Value;
6use serde::Deserialize;
7use std::marker::PhantomData;
8use std::ops::RangeInclusive;
9use tokio_rusqlite::Connection;
10use tracing::{debug, info};
11use zerocopy::IntoBytes;
12
/// Error categories exposed by the SQLite vector store.
///
/// The first two variants carry a boxed, thread-safe source error; the last one
/// carries a plain message.
#[derive(Debug)]
pub enum SqliteError {
    /// A boxed error attributed to the database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// A boxed error attributed to (de)serialization.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// A message describing a column type that was rejected.
    InvalidColumnType(String),
}
19
/// A value that can be stored in one column of a vector-store table.
///
/// Implementors must be `Send + Sync` so boxed values can cross thread boundaries.
pub trait ColumnValue: Send + Sync {
    /// Renders the value as the string that is handed to SQLite as a bound parameter.
    fn to_sql_string(&self) -> String;

    /// Names the SQL type associated with this value (the `String` impl reports `"TEXT"`).
    fn column_type(&self) -> &'static str;
}
24
/// Description of a single column in a vector-store table schema.
pub struct Column {
    /// Column name, spliced verbatim into the generated `CREATE TABLE` / `SELECT` SQL.
    name: &'static str,
    /// SQL type text that follows the name in the generated `CREATE TABLE` statement.
    col_type: &'static str,
    /// When `true`, the store creates an index on this column during setup.
    indexed: bool,
}
30
31impl Column {
32 pub fn new(name: &'static str, col_type: &'static str) -> Self {
33 Self {
34 name,
35 col_type,
36 indexed: false,
37 }
38 }
39
40 pub fn indexed(mut self) -> Self {
41 self.indexed = true;
42 self
43 }
44}
45
46pub trait SqliteVectorStoreTable: Send + Sync + Clone {
84 fn name() -> &'static str;
85 fn schema() -> Vec<Column>;
86 fn id(&self) -> String;
87 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
88}
89
90#[derive(Clone)]
91pub struct SqliteVectorStore<E, T>
92where
93 E: EmbeddingModel + 'static,
94 T: SqliteVectorStoreTable + 'static,
95{
96 conn: Connection,
97 _phantom: PhantomData<(E, T)>,
98}
99
100impl<E, T> SqliteVectorStore<E, T>
101where
102 E: EmbeddingModel + 'static,
103 T: SqliteVectorStoreTable + 'static,
104{
105 pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
106 let dims = embedding_model.ndims();
107 let table_name = T::name();
108 let schema = T::schema();
109
110 let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
112
113 let mut first = true;
115 for column in &schema {
116 if !first {
117 create_table.push(',');
118 }
119 create_table.push_str(&format!("\n {} {}", column.name, column.col_type));
120 first = false;
121 }
122
123 create_table.push_str("\n)");
124
125 let mut create_indexes = vec![format!(
127 "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
128 table_name, table_name
129 )];
130
131 for column in schema {
133 if column.indexed {
134 create_indexes.push(format!(
135 "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
136 table_name, column.name, table_name, column.name
137 ));
138 }
139 }
140
141 conn.call(move |conn| {
142 conn.execute_batch("BEGIN")?;
143
144 conn.execute_batch(&create_table)?;
146
147 for index_stmt in create_indexes {
149 conn.execute_batch(&index_stmt)?;
150 }
151
152 conn.execute_batch(&format!(
154 "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
155 ))?;
156
157 conn.execute_batch("COMMIT")?;
158 Ok(())
159 })
160 .await
161 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
162
163 Ok(Self {
164 conn,
165 _phantom: PhantomData,
166 })
167 }
168
169 pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
170 SqliteVectorIndex::new(model, self)
171 }
172
173 pub fn add_rows_with_txn(
174 &self,
175 txn: &rusqlite::Transaction<'_>,
176 documents: Vec<(T, OneOrMany<Embedding>)>,
177 ) -> Result<i64, tokio_rusqlite::Error> {
178 info!("Adding {} documents to store", documents.len());
179 let table_name = T::name();
180 let mut last_id = 0;
181
182 for (doc, embeddings) in &documents {
183 debug!("Storing document with id {}", doc.id());
184
185 let values = doc.column_values();
186 let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
187
188 let placeholders = (1..=values.len())
189 .map(|i| format!("?{i}"))
190 .collect::<Vec<_>>();
191
192 let insert_sql = format!(
193 "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
194 table_name,
195 columns.join(", "),
196 placeholders.join(", ")
197 );
198
199 txn.execute(
200 &insert_sql,
201 rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
202 )?;
203 last_id = txn.last_insert_rowid();
204
205 let embeddings_sql =
206 format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
207
208 let mut stmt = txn.prepare(&embeddings_sql)?;
209 for (i, embedding) in embeddings.iter().enumerate() {
210 let vec = serialize_embedding(embedding);
211 debug!(
212 "Storing embedding {} of {} (size: {} bytes)",
213 i + 1,
214 embeddings.len(),
215 vec.len() * 4
216 );
217 let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
218 stmt.execute(rusqlite::params![last_id, blob])?;
219 }
220 }
221
222 Ok(last_id)
223 }
224
225 pub async fn add_rows(
226 &self,
227 documents: Vec<(T, OneOrMany<Embedding>)>,
228 ) -> Result<i64, VectorStoreError> {
229 let documents = documents.clone();
230 let this = self.clone();
231
232 self.conn
233 .call(move |conn| {
234 let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?;
235 let result = this.add_rows_with_txn(&tx, documents)?;
236 tx.commit().map_err(tokio_rusqlite::Error::from)?;
237 Ok(result)
238 })
239 .await
240 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
241 }
242}
243
244#[derive(Clone, Default)]
245pub struct SqliteSearchFilter {
246 condition: String,
247 params: Vec<serde_json::Value>,
248}
249
250impl SearchFilter for SqliteSearchFilter {
251 type Value = serde_json::Value;
252
253 fn eq(key: String, value: Self::Value) -> Self {
254 Self {
255 condition: format!("{key} = ?"),
256 params: vec![value],
257 }
258 }
259
260 fn gt(key: String, value: Self::Value) -> Self {
261 Self {
262 condition: format!("{key} > ?"),
263 params: vec![value],
264 }
265 }
266
267 fn lt(key: String, value: Self::Value) -> Self {
268 Self {
269 condition: format!("{key} < ?"),
270 params: vec![value],
271 }
272 }
273
274 fn and(self, rhs: Self) -> Self {
275 Self {
276 condition: format!("({}) AND ({})", self.condition, rhs.condition),
277 params: self.params.into_iter().chain(rhs.params).collect(),
278 }
279 }
280
281 fn or(self, rhs: Self) -> Self {
282 Self {
283 condition: format!("({}) OR ({})", self.condition, rhs.condition),
284 params: self.params.into_iter().chain(rhs.params).collect(),
285 }
286 }
287}
288
289impl SqliteSearchFilter {
290 #[allow(clippy::should_implement_trait)]
291 pub fn not(self) -> Self {
292 Self {
293 condition: format!("NOT ({})", self.condition),
294 ..self
295 }
296 }
297
298 pub fn between<N>(key: String, range: RangeInclusive<N>) -> Self
300 where
301 N: Ord + rusqlite::ToSql + std::fmt::Display,
302 {
303 let lo = range.start();
304 let hi = range.end();
305
306 Self {
307 condition: format!("{key} between {lo} and {hi}"),
308 ..Default::default()
309 }
310 }
311
312 pub fn is_null(key: String) -> Self {
314 Self {
315 condition: format!("{key} is null"),
316 ..Default::default()
317 }
318 }
319
320 pub fn is_not_null(key: String) -> Self {
321 Self {
322 condition: format!("{key} is not null"),
323 ..Default::default()
324 }
325 }
326
327 pub fn glob<'a, S>(key: String, pattern: S) -> Self
331 where
332 S: AsRef<&'a str>,
333 {
334 Self {
335 condition: format!("{key} glob {}", pattern.as_ref()),
336 ..Default::default()
337 }
338 }
339
340 pub fn like<'a, S>(key: String, pattern: S) -> Self
343 where
344 S: AsRef<&'a str>,
345 {
346 Self {
347 condition: format!("{key} like {}", pattern.as_ref()),
348 ..Default::default()
349 }
350 }
351}
352
353impl SqliteSearchFilter {
354 fn compile_params(self) -> Result<Vec<Value>, FilterError> {
355 let mut params = Vec::with_capacity(self.params.len());
356
357 fn convert(value: serde_json::Value) -> Result<Value, FilterError> {
358 use serde_json::Value::*;
359
360 match value {
361 Null => Ok(Value::Null),
362 Bool(b) => Ok(Value::Integer(b as i64)),
363 String(s) => Ok(Value::Text(s)),
364 Number(n) => Ok(if let Some(float) = n.as_f64() {
365 Value::Real(float)
366 } else if let Some(int) = n.as_i64() {
367 Value::Integer(int)
368 } else {
369 unreachable!()
370 }),
371 Array(arr) => {
372 let blob = serde_json::to_vec(&arr)
373 .map_err(|e| FilterError::Serialization(e.to_string()))?;
374
375 Ok(Value::Blob(blob))
376 }
377 Object(obj) => {
378 let blob = serde_json::to_vec(&obj)
379 .map_err(|e| FilterError::Serialization(e.to_string()))?;
380
381 Ok(Value::Blob(blob))
382 }
383 }
384 }
385
386 for param in self.params.into_iter() {
387 params.push(convert(param)?)
388 }
389
390 Ok(params)
391 }
392}
393
394pub struct SqliteVectorIndex<E, T>
477where
478 E: EmbeddingModel + 'static,
479 T: SqliteVectorStoreTable + 'static,
480{
481 store: SqliteVectorStore<E, T>,
482 embedding_model: E,
483}
484
485impl<E, T> SqliteVectorIndex<E, T>
486where
487 E: EmbeddingModel + 'static,
488 T: SqliteVectorStoreTable,
489{
490 pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
491 Self {
492 store,
493 embedding_model,
494 }
495 }
496}
497
498fn build_where_clause(
499 req: &VectorSearchRequest<SqliteSearchFilter>,
500 query_vec: Vec<f32>,
501) -> Result<(String, Vec<Value>), FilterError> {
502 let thresh = req.threshold().unwrap_or(0.);
503 let thresh = SqliteSearchFilter::gt("e.distance".into(), thresh.into());
504
505 let filter = req
506 .filter()
507 .as_ref()
508 .cloned()
509 .map(|filter| thresh.clone().and(filter))
510 .unwrap_or(thresh);
511
512 let where_clause = format!(
513 "WHERE e.embedding MATCH ? AND k = ? AND {}",
514 filter.condition
515 );
516
517 let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
518 let query_vec = Value::Blob(query_vec);
519 let samples = req.samples() as u32;
520
521 let mut params = vec![query_vec, samples.into()];
522 let filter_params = filter.clone().compile_params()?;
523 params.extend(filter_params);
524
525 Ok((where_clause, params))
526}
527
528impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
529 for SqliteVectorIndex<E, T>
530{
531 type Filter = SqliteSearchFilter;
532
533 async fn top_n<D>(
534 &self,
535 req: VectorSearchRequest<SqliteSearchFilter>,
536 ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
537 where
538 D: for<'de> Deserialize<'de>,
539 {
540 tracing::debug!("Finding top {} matches for query", req.samples() as usize);
541 let embedding = self.embedding_model.embed_text(req.query()).await?;
542 let query_vec: Vec<f32> = serialize_embedding(&embedding);
543 let table_name = T::name();
544
545 let columns = T::schema();
547 let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
548
549 let select_cols = column_names.join(", ");
551
552 let (where_clause, params) = build_where_clause(&req, query_vec)?;
553
554 let rows = self
555 .store
556 .conn
557 .call(move |conn| {
558 let mut stmt = conn.prepare(&format!(
559 "SELECT d.{select_cols}, e.distance
560 FROM {table_name}_embeddings e
561 JOIN {table_name} d ON e.rowid = d.rowid
562 {where_clause}
563 ORDER BY e.distance"
564 ))?;
565
566 dbg!(&stmt);
567
568 let rows = stmt
569 .query_map(rusqlite::params_from_iter(params), |row| {
570 let mut map = serde_json::Map::new();
572 for (i, col_name) in column_names.iter().enumerate() {
573 let value: String = row.get(i)?;
574 map.insert(col_name.to_string(), serde_json::Value::String(value));
575 }
576 let distance: f64 = row.get(column_names.len())?;
577 let id: String = row.get(0)?; Ok((id, serde_json::Value::Object(map), distance))
580 })?
581 .collect::<Result<Vec<_>, _>>()?;
582 Ok(rows)
583 })
584 .await
585 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
586
587 debug!("Found {} potential matches", rows.len());
588 let mut top_n = Vec::new();
589 for (id, doc_value, distance) in rows {
590 match serde_json::from_value::<D>(doc_value) {
591 Ok(doc) => {
592 top_n.push((distance, id, doc));
593 }
594 Err(e) => {
595 debug!("Failed to deserialize document {}: {}", id, e);
596 continue;
597 }
598 }
599 }
600
601 debug!("Returning {} matches", top_n.len());
602 Ok(top_n)
603 }
604
605 async fn top_n_ids(
606 &self,
607 req: VectorSearchRequest<SqliteSearchFilter>,
608 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
609 tracing::debug!(
610 "Finding top {} document IDs for query",
611 req.samples() as usize
612 );
613 let embedding = self.embedding_model.embed_text(req.query()).await?;
614 let query_vec = serialize_embedding(&embedding);
615 let table_name = T::name();
616
617 let (where_clause, params) = build_where_clause(&req, query_vec)?;
618
619 let results = self
620 .store
621 .conn
622 .call(move |conn| {
623 let mut stmt = conn.prepare(&format!(
624 "SELECT d.id, e.distance
625 FROM {table_name}_embeddings e
626 JOIN {table_name} d ON e.rowid = d.rowid
627 {where_clause}
628 ORDER BY e.distance"
629 ))?;
630
631 dbg!(&stmt);
632
633 let results = stmt
634 .query_map(rusqlite::params_from_iter(params), |row| {
635 Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?))
636 })?
637 .collect::<Result<Vec<_>, _>>()?;
638 Ok(results)
639 })
640 .await
641 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
642
643 debug!("Found {} matching document IDs", results.len());
644 Ok(results)
645 }
646}
647
648fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
649 embedding.vec.iter().map(|x| *x as f32).collect()
650}
651
652impl ColumnValue for String {
653 fn to_sql_string(&self) -> String {
654 self.clone()
655 }
656
657 fn column_type(&self) -> &'static str {
658 "TEXT"
659 }
660}