1use rig::embeddings::{Embedding, EmbeddingModel};
2use rig::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
3use rig::vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex};
4use rig::wasm_compat::{WasmCompatSend, WasmCompatSync};
5use rig::{Embed, OneOrMany};
6use rusqlite::types::Value;
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9use std::ops::RangeInclusive;
10use tokio_rusqlite::Connection;
11use tracing::{debug, info};
12use zerocopy::IntoBytes;
13
/// Errors produced by the SQLite vector store.
#[derive(Debug)]
pub enum SqliteError {
    /// An error reported by the database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// A value could not be serialized or deserialized.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// A column had a type that could not be handled; carries a description of the type.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DatabaseError(e) => write!(f, "database error: {e}"),
            Self::SerializationError(e) => write!(f, "serialization error: {e}"),
            Self::InvalidColumnType(ty) => write!(f, "invalid column type: {ty}"),
        }
    }
}

impl std::error::Error for SqliteError {
    /// Exposes the wrapped error of the database/serialization variants so error chains
    /// (e.g. `anyhow`'s `{:#}`) show the root cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::DatabaseError(e) | Self::SerializationError(e) => {
                let source: &(dyn std::error::Error + 'static) = &**e;
                Some(source)
            }
            Self::InvalidColumnType(_) => None,
        }
    }
}
20
/// A single cell value that can be written into a store table.
///
/// Implementations are used as `Box<dyn ColumnValue>` trait objects, hence the
/// `Send + Sync` supertraits.
pub trait ColumnValue: Send + Sync {
    /// The declared SQL type of this value, e.g. `"TEXT"`.
    fn column_type(&self) -> &'static str;

    /// The string that is bound as the SQL parameter for this value on insert.
    fn to_sql_string(&self) -> String;
}
25
/// One column of a table schema: its SQL name, its declared SQL type, and whether the store
/// should create an index on it when the table is set up.
pub struct Column {
    name: &'static str,
    col_type: &'static str,
    indexed: bool,
}

impl Column {
    /// Describes a non-indexed column called `name` with SQL type `col_type`.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Column {
            indexed: false,
            name,
            col_type,
        }
    }

    /// Builder-style toggle: returns the same column, but marked for indexing.
    pub fn indexed(self) -> Self {
        Column {
            indexed: true,
            ..self
        }
    }
}
46
47pub trait SqliteVectorStoreTable: Send + Sync + Clone {
85 fn name() -> &'static str;
86 fn schema() -> Vec<Column>;
87 fn id(&self) -> String;
88 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
89}
90
91#[derive(Clone)]
92pub struct SqliteVectorStore<E, T>
93where
94 E: EmbeddingModel + 'static,
95 T: SqliteVectorStoreTable + 'static,
96{
97 conn: Connection,
98 _phantom: PhantomData<(E, T)>,
99}
100
101impl<E, T> SqliteVectorStore<E, T>
102where
103 E: EmbeddingModel + Clone + 'static,
104 T: SqliteVectorStoreTable + 'static,
105{
106 pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
107 let dims = embedding_model.ndims();
108 let table_name = T::name();
109 let schema = T::schema();
110
111 let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
113
114 let mut first = true;
116 for column in &schema {
117 if !first {
118 create_table.push(',');
119 }
120 create_table.push_str(&format!("\n {} {}", column.name, column.col_type));
121 first = false;
122 }
123
124 create_table.push_str("\n)");
125
126 let mut create_indexes = vec![format!(
128 "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
129 table_name, table_name
130 )];
131
132 for column in schema {
134 if column.indexed {
135 create_indexes.push(format!(
136 "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
137 table_name, column.name, table_name, column.name
138 ));
139 }
140 }
141
142 conn.call(move |conn| {
143 conn.execute_batch("BEGIN")?;
144
145 conn.execute_batch(&create_table)?;
147
148 for index_stmt in create_indexes {
150 conn.execute_batch(&index_stmt)?;
151 }
152
153 conn.execute_batch(&format!(
155 "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
156 ))?;
157
158 conn.execute_batch("COMMIT")?;
159 Ok(())
160 })
161 .await
162 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
163
164 Ok(Self {
165 conn,
166 _phantom: PhantomData,
167 })
168 }
169
170 pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
171 SqliteVectorIndex::new(model, self)
172 }
173
174 pub fn add_rows_with_txn(
175 &self,
176 txn: &rusqlite::Transaction<'_>,
177 documents: Vec<(T, OneOrMany<Embedding>)>,
178 ) -> Result<i64, tokio_rusqlite::Error> {
179 info!("Adding {} documents to store", documents.len());
180 let table_name = T::name();
181 let mut last_id = 0;
182
183 for (doc, embeddings) in &documents {
184 debug!("Storing document with id {}", doc.id());
185
186 let values = doc.column_values();
187 let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
188
189 let placeholders = (1..=values.len())
190 .map(|i| format!("?{i}"))
191 .collect::<Vec<_>>();
192
193 let insert_sql = format!(
194 "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
195 table_name,
196 columns.join(", "),
197 placeholders.join(", ")
198 );
199
200 txn.execute(
201 &insert_sql,
202 rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
203 )?;
204 last_id = txn.last_insert_rowid();
205
206 let embeddings_sql =
207 format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
208
209 let mut stmt = txn.prepare(&embeddings_sql)?;
210 for (i, embedding) in embeddings.iter().enumerate() {
211 let vec = serialize_embedding(embedding);
212 debug!(
213 "Storing embedding {} of {} (size: {} bytes)",
214 i + 1,
215 embeddings.len(),
216 vec.len() * 4
217 );
218 let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
219 stmt.execute(rusqlite::params![last_id, blob])?;
220 }
221 }
222
223 Ok(last_id)
224 }
225
226 pub async fn add_rows(
227 &self,
228 documents: Vec<(T, OneOrMany<Embedding>)>,
229 ) -> Result<i64, VectorStoreError>
230 where
231 T: 'static,
232 Self: 'static,
233 {
234 let cloned = self.clone();
235
236 self.conn
237 .call(move |conn| {
238 let tx = conn.transaction()?;
239 let result = cloned.add_rows_with_txn(&tx, documents)?;
240 tx.commit()?;
241
242 Ok(result)
243 })
244 .await
245 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
246 }
247}
248
249impl<E, T> InsertDocuments for SqliteVectorStore<E, T>
250where
251 E: EmbeddingModel + Clone + WasmCompatSend + WasmCompatSync + 'static,
252 T: SqliteVectorStoreTable
253 + for<'de> Deserialize<'de>
254 + WasmCompatSend
255 + WasmCompatSync
256 + 'static,
257{
258 async fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
259 &self,
260 documents: Vec<(Doc, OneOrMany<Embedding>)>,
261 ) -> Result<(), VectorStoreError> {
262 if documents.is_empty() {
263 return Ok(());
264 }
265
266 let rows = documents
267 .into_iter()
268 .map(|(document, embeddings)| {
269 let document = serde_json::to_value(document)?;
270 let row = serde_json::from_value::<T>(document)?;
271
272 Ok((row, embeddings))
273 })
274 .collect::<Result<Vec<_>, VectorStoreError>>()?;
275
276 self.add_rows(rows).await?;
277
278 Ok(())
279 }
280}
281
282#[derive(Clone, Default, Deserialize, Serialize, Debug)]
283pub struct SqliteSearchFilter {
284 condition: String,
285 params: Vec<serde_json::Value>,
286}
287
288impl SearchFilter for SqliteSearchFilter {
289 type Value = serde_json::Value;
290
291 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
292 Self {
293 condition: format!("{} = ?", key.as_ref()),
294 params: vec![value],
295 }
296 }
297
298 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
299 Self {
300 condition: format!("{} > ?", key.as_ref()),
301 params: vec![value],
302 }
303 }
304
305 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
306 Self {
307 condition: format!("{} < ?", key.as_ref()),
308 params: vec![value],
309 }
310 }
311
312 fn and(self, rhs: Self) -> Self {
313 Self {
314 condition: format!("({}) AND ({})", self.condition, rhs.condition),
315 params: self.params.into_iter().chain(rhs.params).collect(),
316 }
317 }
318
319 fn or(self, rhs: Self) -> Self {
320 Self {
321 condition: format!("({}) OR ({})", self.condition, rhs.condition),
322 params: self.params.into_iter().chain(rhs.params).collect(),
323 }
324 }
325}
326
327impl SqliteSearchFilter {
328 #[allow(clippy::should_implement_trait)]
329 pub fn not(self) -> Self {
330 Self {
331 condition: format!("NOT ({})", self.condition),
332 ..self
333 }
334 }
335
336 pub fn between<N>(key: String, range: RangeInclusive<N>) -> Self
338 where
339 N: Ord + rusqlite::ToSql + std::fmt::Display,
340 {
341 let lo = range.start();
342 let hi = range.end();
343
344 Self {
345 condition: format!("{key} between {lo} and {hi}"),
346 ..Default::default()
347 }
348 }
349
350 pub fn is_null(key: String) -> Self {
352 Self {
353 condition: format!("{key} is null"),
354 ..Default::default()
355 }
356 }
357
358 pub fn is_not_null(key: String) -> Self {
359 Self {
360 condition: format!("{key} is not null"),
361 ..Default::default()
362 }
363 }
364
365 pub fn glob<'a, S>(key: String, pattern: S) -> Self
369 where
370 S: AsRef<&'a str>,
371 {
372 Self {
373 condition: format!("{key} glob {}", pattern.as_ref()),
374 ..Default::default()
375 }
376 }
377
378 pub fn like<'a, S>(key: String, pattern: S) -> Self
381 where
382 S: AsRef<&'a str>,
383 {
384 Self {
385 condition: format!("{key} like {}", pattern.as_ref()),
386 ..Default::default()
387 }
388 }
389}
390
391impl SqliteSearchFilter {
392 fn compile_params(self) -> Result<Vec<Value>, FilterError> {
393 let mut params = Vec::with_capacity(self.params.len());
394
395 fn convert(value: serde_json::Value) -> Result<Value, FilterError> {
396 use serde_json::Value::*;
397
398 match value {
399 Null => Ok(Value::Null),
400 Bool(b) => Ok(Value::Integer(b as i64)),
401 String(s) => Ok(Value::Text(s)),
402 Number(n) => Ok(if let Some(float) = n.as_f64() {
403 Value::Real(float)
404 } else if let Some(int) = n.as_i64() {
405 Value::Integer(int)
406 } else {
407 unreachable!()
408 }),
409 Array(arr) => {
410 let blob = serde_json::to_vec(&arr)
411 .map_err(|e| FilterError::Serialization(e.to_string()))?;
412
413 Ok(Value::Blob(blob))
414 }
415 Object(obj) => {
416 let blob = serde_json::to_vec(&obj)
417 .map_err(|e| FilterError::Serialization(e.to_string()))?;
418
419 Ok(Value::Blob(blob))
420 }
421 }
422 }
423
424 for param in self.params.into_iter() {
425 params.push(convert(param)?)
426 }
427
428 Ok(params)
429 }
430}
431
432pub struct SqliteVectorIndex<E, T>
518where
519 E: EmbeddingModel + 'static,
520 T: SqliteVectorStoreTable + 'static,
521{
522 store: SqliteVectorStore<E, T>,
523 embedding_model: E,
524}
525
526impl<E, T> SqliteVectorIndex<E, T>
527where
528 E: EmbeddingModel + 'static,
529 T: SqliteVectorStoreTable,
530{
531 pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
532 Self {
533 store,
534 embedding_model,
535 }
536 }
537}
538
539fn build_where_clause(
540 req: &VectorSearchRequest<SqliteSearchFilter>,
541 query_vec: Vec<f32>,
542) -> Result<(String, Vec<Value>), FilterError> {
543 let thresh = req.threshold().unwrap_or(0.);
544 let thresh = SqliteSearchFilter::gt("distance", thresh.into());
545
546 let filter = req
547 .filter()
548 .as_ref()
549 .cloned()
550 .map(|filter| thresh.clone().and(filter))
551 .unwrap_or(thresh);
552
553 let where_clause = format!(
554 "WHERE e.embedding MATCH ? AND k = ? AND {}",
555 filter.condition
556 );
557
558 let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
559 let query_vec = Value::Blob(query_vec);
560 let samples = req.samples() as u32;
561
562 let mut params = vec![query_vec.clone(), query_vec, samples.into()];
563 let filter_params = filter.clone().compile_params()?;
564 params.extend(filter_params);
565
566 Ok((where_clause, params))
567}
568
569impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
570 for SqliteVectorIndex<E, T>
571{
572 type Filter = SqliteSearchFilter;
573
574 async fn top_n<D>(
575 &self,
576 req: VectorSearchRequest<SqliteSearchFilter>,
577 ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
578 where
579 D: for<'de> Deserialize<'de>,
580 {
581 tracing::debug!("Finding top {} matches for query", req.samples() as usize);
582 let embedding = self.embedding_model.embed_text(req.query()).await?;
583 let query_vec: Vec<f32> = serialize_embedding(&embedding);
584 let table_name = T::name();
585
586 let columns = T::schema();
588 let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
589
590 let select_cols = column_names.join(", ");
592
593 let (where_clause, params) = build_where_clause(&req, query_vec)?;
594
595 let rows = self
596 .store
597 .conn
598 .call(move |conn| {
599 let mut stmt = conn.prepare(&format!(
600 "SELECT d.{select_cols}, (1 - vec_distance_cosine(?, e.embedding)) as distance
601 FROM {table_name}_embeddings e
602 JOIN {table_name} d ON e.rowid = d.rowid
603 {where_clause}
604 ORDER BY distance"
605 ))?;
606
607 let rows = stmt
608 .query_map(rusqlite::params_from_iter(params), |row| {
609 let mut map = serde_json::Map::new();
611 for (i, col_name) in column_names.iter().enumerate() {
612 let value: String = row.get(i)?;
613 map.insert(col_name.to_string(), serde_json::Value::String(value));
614 }
615 let distance: f64 = row.get(column_names.len())?;
616 let id: String = row.get(0)?; Ok((id, serde_json::Value::Object(map), distance))
619 })?
620 .collect::<Result<Vec<_>, _>>()?;
621 Ok(rows)
622 })
623 .await
624 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
625
626 debug!("Found {} potential matches", rows.len());
627 let mut top_n = Vec::new();
628 for (id, doc_value, distance) in rows {
629 match serde_json::from_value::<D>(doc_value) {
630 Ok(doc) => {
631 top_n.push((distance, id, doc));
632 }
633 Err(e) => {
634 debug!("Failed to deserialize document {}: {}", id, e);
635 continue;
636 }
637 }
638 }
639
640 debug!("Returning {} matches", top_n.len());
641 Ok(top_n)
642 }
643
644 async fn top_n_ids(
645 &self,
646 req: VectorSearchRequest<SqliteSearchFilter>,
647 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
648 tracing::debug!(
649 "Finding top {} document IDs for query",
650 req.samples() as usize
651 );
652 let embedding = self.embedding_model.embed_text(req.query()).await?;
653 let query_vec = serialize_embedding(&embedding);
654 let table_name = T::name();
655
656 let (where_clause, params) = build_where_clause(&req, query_vec)?;
657
658 let results = self
659 .store
660 .conn
661 .call(move |conn| {
662 let mut stmt = conn.prepare(&format!(
663 "SELECT d.id, (1 - vec_distance_cosine(?1, e.embedding)) as distance
664 FROM {table_name}_embeddings e
665 JOIN {table_name} d ON e.rowid = d.rowid
666 {where_clause}
667 ORDER BY distance"
668 ))?;
669
670 let results = stmt
671 .query_map(rusqlite::params_from_iter(params), |row| {
672 Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?))
673 })?
674 .collect::<Result<Vec<_>, _>>()?;
675 Ok(results)
676 })
677 .await
678 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
679
680 debug!("Found {} matching document IDs", results.len());
681 Ok(results)
682 }
683}
684
685fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
686 embedding.vec.iter().map(|x| *x as f32).collect()
687}
688
689impl ColumnValue for String {
690 fn to_sql_string(&self) -> String {
691 self.clone()
692 }
693
694 fn column_type(&self) -> &'static str {
695 "TEXT"
696 }
697}