1use rig::embeddings::{Embedding, EmbeddingModel};
2use rig::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
3use rig::vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex};
4use rig::wasm_compat::{WasmCompatSend, WasmCompatSync};
5use rig::{Embed, OneOrMany};
6use rusqlite::types::Value;
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9use std::ops::RangeInclusive;
10use tokio_rusqlite::Connection;
11use tracing::{debug, info};
12use zerocopy::IntoBytes;
13
/// Error categories for the SQLite vector store.
///
/// NOTE(review): none of these variants is constructed in the code visible in this file
/// section; the variant descriptions below follow the variant names — confirm against callers.
#[derive(Debug)]
pub enum SqliteError {
    /// Wraps an arbitrary boxed error, presumably one raised by the database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// Wraps an arbitrary boxed error, presumably one raised while (de)serializing data.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// Carries a message, presumably describing an unsupported or mismatched column type.
    InvalidColumnType(String),
}
20
/// A value that can be stored in a document-table column.
///
/// Implementors must be `Send + Sync`.
pub trait ColumnValue: Send + Sync {
    /// Renders the value as a string; this string is what gets bound as the SQL parameter
    /// when a row is inserted.
    fn to_sql_string(&self) -> String;
    /// Returns the SQL type name associated with the value (for example `"TEXT"`).
    fn column_type(&self) -> &'static str;
}
25
/// Description of one column of a document table.
pub struct Column {
    /// Column name, interpolated as-is into the generated SQL.
    name: &'static str,
    /// SQL type text placed after the column name in `CREATE TABLE`.
    col_type: &'static str,
    /// Whether a dedicated index is created for this column.
    indexed: bool,
}
31
32impl Column {
33 pub fn new(name: &'static str, col_type: &'static str) -> Self {
34 Self {
35 name,
36 col_type,
37 indexed: false,
38 }
39 }
40
41 pub fn indexed(mut self) -> Self {
42 self.indexed = true;
43 self
44 }
45}
46
/// Describes a document type that can be stored in a [`SqliteVectorStore`].
pub trait SqliteVectorStoreTable: Send + Sync + Clone {
    /// Name of the document table; the embeddings table is named `{name}_embeddings`.
    fn name() -> &'static str;
    /// Columns of the document table, used for `CREATE TABLE` and for `SELECT` lists.
    ///
    /// The store always creates an index on a column named `id`, so the schema is expected
    /// to contain one.
    fn schema() -> Vec<Column>;
    /// Identifier of this document (used for logging when the document is stored).
    fn id(&self) -> String;
    /// `(column name, value)` pairs written when this document is inserted.
    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
}
90
/// Vector store backed by a SQLite connection.
///
/// `E` is the embedding model type and `T` the document table type; neither is stored, they
/// are only tracked through `PhantomData`.
#[derive(Clone)]
pub struct SqliteVectorStore<E, T>
where
    E: EmbeddingModel + 'static,
    T: SqliteVectorStoreTable + 'static,
{
    /// Async handle to the SQLite connection all statements run on.
    conn: Connection,
    /// Ties the store to its embedding model and table types without holding values of them.
    _phantom: PhantomData<(E, T)>,
}
100
101impl<E, T> SqliteVectorStore<E, T>
102where
103 E: EmbeddingModel + Clone + 'static,
104 T: SqliteVectorStoreTable + 'static,
105{
106 pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
107 let dims = embedding_model.ndims();
108 let table_name = T::name();
109 let schema = T::schema();
110
111 let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
113
114 let mut first = true;
116 for column in &schema {
117 if !first {
118 create_table.push(',');
119 }
120 create_table.push_str(&format!("\n {} {}", column.name, column.col_type));
121 first = false;
122 }
123
124 create_table.push_str("\n)");
125
126 let mut create_indexes = vec![format!(
128 "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
129 table_name, table_name
130 )];
131
132 for column in schema {
134 if column.indexed {
135 create_indexes.push(format!(
136 "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
137 table_name, column.name, table_name, column.name
138 ));
139 }
140 }
141
142 conn.call(move |conn| {
143 conn.execute_batch("BEGIN")?;
144
145 conn.execute_batch(&create_table)?;
147
148 for index_stmt in create_indexes {
150 conn.execute_batch(&index_stmt)?;
151 }
152
153 conn.execute_batch(&format!(
155 "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
156 ))?;
157
158 conn.execute_batch("COMMIT")?;
159 Ok(())
160 })
161 .await
162 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
163
164 Ok(Self {
165 conn,
166 _phantom: PhantomData,
167 })
168 }
169
170 pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
171 SqliteVectorIndex::new(model, self)
172 }
173
174 pub fn add_rows_with_txn(
175 &self,
176 txn: &rusqlite::Transaction<'_>,
177 documents: Vec<(T, OneOrMany<Embedding>)>,
178 ) -> Result<i64, tokio_rusqlite::Error> {
179 info!("Adding {} documents to store", documents.len());
180 let table_name = T::name();
181 let mut last_id = 0;
182
183 for (doc, embeddings) in &documents {
184 debug!("Storing document with id {}", doc.id());
185
186 let values = doc.column_values();
187 let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
188
189 let placeholders = (1..=values.len())
190 .map(|i| format!("?{i}"))
191 .collect::<Vec<_>>();
192
193 let insert_sql = format!(
194 "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
195 table_name,
196 columns.join(", "),
197 placeholders.join(", ")
198 );
199
200 txn.execute(
201 &insert_sql,
202 rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
203 )?;
204 last_id = txn.last_insert_rowid();
205
206 let embeddings_sql =
207 format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
208
209 let mut stmt = txn.prepare(&embeddings_sql)?;
210 for (i, embedding) in embeddings.iter().enumerate() {
211 let vec = serialize_embedding(embedding);
212 debug!(
213 "Storing embedding {} of {} (size: {} bytes)",
214 i + 1,
215 embeddings.len(),
216 vec.len() * 4
217 );
218 let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
219 stmt.execute(rusqlite::params![last_id, blob])?;
220 }
221 }
222
223 Ok(last_id)
224 }
225
226 pub async fn add_rows(
227 &self,
228 documents: Vec<(T, OneOrMany<Embedding>)>,
229 ) -> Result<i64, VectorStoreError>
230 where
231 T: 'static,
232 Self: 'static,
233 {
234 let cloned = self.clone();
235
236 self.conn
237 .call(move |conn| {
238 let tx = conn.transaction()?;
239 let result = cloned.add_rows_with_txn(&tx, documents)?;
240 tx.commit()?;
241
242 Ok(result)
243 })
244 .await
245 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
246 }
247}
248
249impl<E, T> InsertDocuments for SqliteVectorStore<E, T>
250where
251 E: EmbeddingModel + Clone + WasmCompatSend + WasmCompatSync + 'static,
252 T: SqliteVectorStoreTable
253 + for<'de> Deserialize<'de>
254 + WasmCompatSend
255 + WasmCompatSync
256 + 'static,
257{
258 async fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
259 &self,
260 documents: Vec<(Doc, OneOrMany<Embedding>)>,
261 ) -> Result<(), VectorStoreError> {
262 if documents.is_empty() {
263 return Ok(());
264 }
265
266 let rows = documents
267 .into_iter()
268 .map(|(document, embeddings)| {
269 let document = serde_json::to_value(document)?;
270 let row = serde_json::from_value::<T>(document)?;
271
272 Ok((row, embeddings))
273 })
274 .collect::<Result<Vec<_>, VectorStoreError>>()?;
275
276 self.add_rows(rows).await?;
277
278 Ok(())
279 }
280}
281
/// A SQL filter expression together with the values bound to its `?` placeholders.
#[derive(Clone, Default, Deserialize, Serialize, Debug)]
pub struct SqliteSearchFilter {
    /// SQL boolean expression text; appended to the `WHERE` clause of search queries.
    condition: String,
    /// Values for the `?` placeholders in `condition`, in placeholder order.
    params: Vec<serde_json::Value>,
}
287
288impl SearchFilter for SqliteSearchFilter {
289 type Value = serde_json::Value;
290
291 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
292 Self {
293 condition: format!("{} = ?", key.as_ref()),
294 params: vec![value],
295 }
296 }
297
298 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
299 Self {
300 condition: format!("{} > ?", key.as_ref()),
301 params: vec![value],
302 }
303 }
304
305 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
306 Self {
307 condition: format!("{} < ?", key.as_ref()),
308 params: vec![value],
309 }
310 }
311
312 fn and(self, rhs: Self) -> Self {
313 Self {
314 condition: format!("({}) AND ({})", self.condition, rhs.condition),
315 params: self.params.into_iter().chain(rhs.params).collect(),
316 }
317 }
318
319 fn or(self, rhs: Self) -> Self {
320 Self {
321 condition: format!("({}) OR ({})", self.condition, rhs.condition),
322 params: self.params.into_iter().chain(rhs.params).collect(),
323 }
324 }
325}
326
327impl SqliteSearchFilter {
328 #[allow(clippy::should_implement_trait)]
329 pub fn not(self) -> Self {
330 Self {
331 condition: format!("NOT ({})", self.condition),
332 ..self
333 }
334 }
335
336 pub fn between<N>(key: String, range: RangeInclusive<N>) -> Self
338 where
339 N: Ord + rusqlite::ToSql + std::fmt::Display,
340 {
341 let lo = range.start();
342 let hi = range.end();
343
344 Self {
345 condition: format!("{key} between {lo} and {hi}"),
346 ..Default::default()
347 }
348 }
349
350 pub fn is_null(key: String) -> Self {
352 Self {
353 condition: format!("{key} is null"),
354 ..Default::default()
355 }
356 }
357
358 pub fn is_not_null(key: String) -> Self {
359 Self {
360 condition: format!("{key} is not null"),
361 ..Default::default()
362 }
363 }
364
365 pub fn glob<'a, S>(key: String, pattern: S) -> Self
369 where
370 S: AsRef<&'a str>,
371 {
372 Self {
373 condition: format!("{key} glob {}", pattern.as_ref()),
374 ..Default::default()
375 }
376 }
377
378 pub fn like<'a, S>(key: String, pattern: S) -> Self
381 where
382 S: AsRef<&'a str>,
383 {
384 Self {
385 condition: format!("{key} like {}", pattern.as_ref()),
386 ..Default::default()
387 }
388 }
389}
390
391impl SqliteSearchFilter {
392 fn compile_params(self) -> Result<Vec<Value>, FilterError> {
393 let mut params = Vec::with_capacity(self.params.len());
394
395 fn convert(value: serde_json::Value) -> Result<Value, FilterError> {
396 use serde_json::Value::*;
397
398 match value {
399 Null => Ok(Value::Null),
400 Bool(b) => Ok(Value::Integer(b as i64)),
401 String(s) => Ok(Value::Text(s)),
402 Number(n) => Ok(if let Some(float) = n.as_f64() {
403 Value::Real(float)
404 } else if let Some(int) = n.as_i64() {
405 Value::Integer(int)
406 } else if let Some(int) = n.as_u64() {
407 Value::Integer(int as i64)
408 } else {
409 Value::Text(n.to_string())
410 }),
411 Array(arr) => {
412 let blob = serde_json::to_vec(&arr)
413 .map_err(|e| FilterError::Serialization(e.to_string()))?;
414
415 Ok(Value::Blob(blob))
416 }
417 Object(obj) => {
418 let blob = serde_json::to_vec(&obj)
419 .map_err(|e| FilterError::Serialization(e.to_string()))?;
420
421 Ok(Value::Blob(blob))
422 }
423 }
424 }
425
426 for param in self.params.into_iter() {
427 params.push(convert(param)?)
428 }
429
430 Ok(params)
431 }
432}
433
/// Searchable view over a [`SqliteVectorStore`], pairing it with the embedding model used to
/// embed query text.
pub struct SqliteVectorIndex<E, T>
where
    E: EmbeddingModel + 'static,
    T: SqliteVectorStoreTable + 'static,
{
    /// Store whose connection the search queries run on.
    store: SqliteVectorStore<E, T>,
    /// Model used to turn the query text into an embedding.
    embedding_model: E,
}
533
534impl<E, T> SqliteVectorIndex<E, T>
535where
536 E: EmbeddingModel + 'static,
537 T: SqliteVectorStoreTable,
538{
539 pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
540 Self {
541 store,
542 embedding_model,
543 }
544 }
545}
546
547fn build_where_clause(
548 req: &VectorSearchRequest<SqliteSearchFilter>,
549 query_vec: Vec<f32>,
550) -> Result<(String, Vec<Value>), FilterError> {
551 let thresh = req.threshold().unwrap_or(0.);
552 let thresh = SqliteSearchFilter::gt("distance", thresh.into());
553
554 let filter = req
555 .filter()
556 .as_ref()
557 .cloned()
558 .map(|filter| thresh.clone().and(filter))
559 .unwrap_or(thresh);
560
561 let where_clause = format!(
562 "WHERE e.embedding MATCH ? AND k = ? AND {}",
563 filter.condition
564 );
565
566 let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
567 let query_vec = Value::Blob(query_vec);
568 let samples = req.samples() as u32;
569
570 let mut params = vec![query_vec.clone(), query_vec, samples.into()];
571 let filter_params = filter.clone().compile_params()?;
572 params.extend(filter_params);
573
574 Ok((where_clause, params))
575}
576
577impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
578 for SqliteVectorIndex<E, T>
579{
580 type Filter = SqliteSearchFilter;
581
582 async fn top_n<D>(
583 &self,
584 req: VectorSearchRequest<SqliteSearchFilter>,
585 ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
586 where
587 D: for<'de> Deserialize<'de>,
588 {
589 tracing::debug!("Finding top {} matches for query", req.samples() as usize);
590 let embedding = self.embedding_model.embed_text(req.query()).await?;
591 let query_vec: Vec<f32> = serialize_embedding(&embedding);
592 let table_name = T::name();
593
594 let columns = T::schema();
596 let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
597
598 let select_cols = column_names.join(", ");
600
601 let (where_clause, params) = build_where_clause(&req, query_vec)?;
602
603 let rows = self
604 .store
605 .conn
606 .call(move |conn| {
607 let mut stmt = conn.prepare(&format!(
608 "SELECT d.{select_cols}, (1 - vec_distance_cosine(?, e.embedding)) as distance
609 FROM {table_name}_embeddings e
610 JOIN {table_name} d ON e.rowid = d.rowid
611 {where_clause}
612 ORDER BY distance"
613 ))?;
614
615 let rows = stmt
616 .query_map(rusqlite::params_from_iter(params), |row| {
617 let mut map = serde_json::Map::new();
619 for (i, col_name) in column_names.iter().enumerate() {
620 let value: String = row.get(i)?;
621 map.insert(col_name.to_string(), serde_json::Value::String(value));
622 }
623 let distance: f64 = row.get(column_names.len())?;
624 let id: String = row.get(0)?; Ok((id, serde_json::Value::Object(map), distance))
627 })?
628 .collect::<Result<Vec<_>, _>>()?;
629 Ok(rows)
630 })
631 .await
632 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
633
634 debug!("Found {} potential matches", rows.len());
635 let mut top_n = Vec::new();
636 for (id, doc_value, distance) in rows {
637 match serde_json::from_value::<D>(doc_value) {
638 Ok(doc) => {
639 top_n.push((distance, id, doc));
640 }
641 Err(e) => {
642 debug!("Failed to deserialize document {}: {}", id, e);
643 continue;
644 }
645 }
646 }
647
648 debug!("Returning {} matches", top_n.len());
649 Ok(top_n)
650 }
651
652 async fn top_n_ids(
653 &self,
654 req: VectorSearchRequest<SqliteSearchFilter>,
655 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
656 tracing::debug!(
657 "Finding top {} document IDs for query",
658 req.samples() as usize
659 );
660 let embedding = self.embedding_model.embed_text(req.query()).await?;
661 let query_vec = serialize_embedding(&embedding);
662 let table_name = T::name();
663
664 let (where_clause, params) = build_where_clause(&req, query_vec)?;
665
666 let results = self
667 .store
668 .conn
669 .call(move |conn| {
670 let mut stmt = conn.prepare(&format!(
671 "SELECT d.id, (1 - vec_distance_cosine(?1, e.embedding)) as distance
672 FROM {table_name}_embeddings e
673 JOIN {table_name} d ON e.rowid = d.rowid
674 {where_clause}
675 ORDER BY distance"
676 ))?;
677
678 let results = stmt
679 .query_map(rusqlite::params_from_iter(params), |row| {
680 Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?))
681 })?
682 .collect::<Result<Vec<_>, _>>()?;
683 Ok(results)
684 })
685 .await
686 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
687
688 debug!("Found {} matching document IDs", results.len());
689 Ok(results)
690 }
691}
692
693fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
694 embedding.vec.iter().map(|x| *x as f32).collect()
695}
696
697impl ColumnValue for String {
698 fn to_sql_string(&self) -> String {
699 self.clone()
700 }
701
702 fn column_type(&self) -> &'static str {
703 "TEXT"
704 }
705}