1use rig_core::embeddings::{Embedding, EmbeddingModel};
11use rig_core::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
12use rig_core::vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex};
13use rig_core::wasm_compat::{WasmCompatSend, WasmCompatSync};
14use rig_core::{Embed, OneOrMany};
15use rusqlite::types::Value;
16use serde::{Deserialize, Serialize};
17use std::marker::PhantomData;
18use std::ops::RangeInclusive;
19use tokio_rusqlite::Connection;
20use tracing::{debug, info};
21use zerocopy::IntoBytes;
22
/// Errors specific to the SQLite vector store.
///
/// The boxed variants carry the underlying cause, which is also reachable
/// through [`std::error::Error::source`].
#[derive(Debug)]
pub enum SqliteError {
    /// A failure reported by the database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// A failure while serializing or deserializing a value.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// A column type that cannot be handled; the payload describes it.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DatabaseError(cause) => write!(f, "database error: {cause}"),
            Self::SerializationError(cause) => write!(f, "serialization error: {cause}"),
            Self::InvalidColumnType(col_type) => write!(f, "invalid column type: {col_type}"),
        }
    }
}

impl std::error::Error for SqliteError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::DatabaseError(cause) | Self::SerializationError(cause) => {
                // Drop the `Send + Sync` auto traits to match the return type.
                let cause: &(dyn std::error::Error + 'static) = &**cause;
                Some(cause)
            }
            Self::InvalidColumnType(_) => None,
        }
    }
}
29
/// A value that can be written into a column of a SQLite table.
///
/// Implementors must be `Send + Sync`.
pub trait ColumnValue: Send + Sync {
    /// Returns the textual form of the value that is bound when the row is
    /// inserted.
    fn to_sql_string(&self) -> String;

    /// Returns the name of the column type for this value (for example
    /// `"TEXT"`).
    fn column_type(&self) -> &'static str;
}
34
/// Describes one column of a document table: its name, its SQL type and
/// whether an index should be created for it.
pub struct Column {
    name: &'static str,
    col_type: &'static str,
    indexed: bool,
}

impl Column {
    /// Creates a column called `name` with SQL type `col_type`.
    ///
    /// The column starts out without an index.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        let indexed = false;
        Column {
            name,
            col_type,
            indexed,
        }
    }

    /// Consumes the column and returns it flagged as indexed.
    pub fn indexed(self) -> Self {
        Self {
            indexed: true,
            ..self
        }
    }
}
55
56pub trait SqliteVectorStoreTable: Send + Sync + Clone {
94 fn name() -> &'static str;
95 fn schema() -> Vec<Column>;
96 fn id(&self) -> String;
97 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
98}
99
100#[derive(Clone)]
101pub struct SqliteVectorStore<E, T>
102where
103 E: EmbeddingModel + 'static,
104 T: SqliteVectorStoreTable + 'static,
105{
106 conn: Connection,
107 _phantom: PhantomData<(E, T)>,
108}
109
110impl<E, T> SqliteVectorStore<E, T>
111where
112 E: EmbeddingModel + Clone + 'static,
113 T: SqliteVectorStoreTable + 'static,
114{
115 pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
116 let dims = embedding_model.ndims();
117 let table_name = T::name();
118 let schema = T::schema();
119
120 let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
122
123 let mut first = true;
125 for column in &schema {
126 if !first {
127 create_table.push(',');
128 }
129 create_table.push_str(&format!("\n {} {}", column.name, column.col_type));
130 first = false;
131 }
132
133 create_table.push_str("\n)");
134
135 let mut create_indexes = vec![format!(
137 "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
138 table_name, table_name
139 )];
140
141 for column in schema {
143 if column.indexed {
144 create_indexes.push(format!(
145 "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
146 table_name, column.name, table_name, column.name
147 ));
148 }
149 }
150
151 conn.call(move |conn| {
152 conn.execute_batch("BEGIN")?;
153
154 conn.execute_batch(&create_table)?;
156
157 for index_stmt in create_indexes {
159 conn.execute_batch(&index_stmt)?;
160 }
161
162 conn.execute_batch(&format!(
164 "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
165 ))?;
166
167 conn.execute_batch("COMMIT")?;
168 Ok(())
169 })
170 .await
171 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
172
173 Ok(Self {
174 conn,
175 _phantom: PhantomData,
176 })
177 }
178
179 pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
180 SqliteVectorIndex::new(model, self)
181 }
182
183 pub fn add_rows_with_txn(
184 &self,
185 txn: &rusqlite::Transaction<'_>,
186 documents: Vec<(T, OneOrMany<Embedding>)>,
187 ) -> Result<i64, tokio_rusqlite::Error> {
188 info!("Adding {} documents to store", documents.len());
189 let table_name = T::name();
190 let mut last_id = 0;
191
192 for (doc, embeddings) in &documents {
193 debug!("Storing document with id {}", doc.id());
194
195 let values = doc.column_values();
196 let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
197
198 let placeholders = (1..=values.len())
199 .map(|i| format!("?{i}"))
200 .collect::<Vec<_>>();
201
202 let insert_sql = format!(
203 "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
204 table_name,
205 columns.join(", "),
206 placeholders.join(", ")
207 );
208
209 txn.execute(
210 &insert_sql,
211 rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
212 )?;
213 last_id = txn.last_insert_rowid();
214
215 let embeddings_sql =
216 format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
217
218 let mut stmt = txn.prepare(&embeddings_sql)?;
219 for (i, embedding) in embeddings.iter().enumerate() {
220 let vec = serialize_embedding(embedding);
221 debug!(
222 "Storing embedding {} of {} (size: {} bytes)",
223 i + 1,
224 embeddings.len(),
225 vec.len() * 4
226 );
227 let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
228 stmt.execute(rusqlite::params![last_id, blob])?;
229 }
230 }
231
232 Ok(last_id)
233 }
234
235 pub async fn add_rows(
236 &self,
237 documents: Vec<(T, OneOrMany<Embedding>)>,
238 ) -> Result<i64, VectorStoreError>
239 where
240 T: 'static,
241 Self: 'static,
242 {
243 let cloned = self.clone();
244
245 self.conn
246 .call(move |conn| {
247 let tx = conn.transaction()?;
248 let result = cloned.add_rows_with_txn(&tx, documents)?;
249 tx.commit()?;
250
251 Ok(result)
252 })
253 .await
254 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
255 }
256}
257
258impl<E, T> InsertDocuments for SqliteVectorStore<E, T>
259where
260 E: EmbeddingModel + Clone + WasmCompatSend + WasmCompatSync + 'static,
261 T: SqliteVectorStoreTable
262 + for<'de> Deserialize<'de>
263 + WasmCompatSend
264 + WasmCompatSync
265 + 'static,
266{
267 async fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
268 &self,
269 documents: Vec<(Doc, OneOrMany<Embedding>)>,
270 ) -> Result<(), VectorStoreError> {
271 if documents.is_empty() {
272 return Ok(());
273 }
274
275 let rows = documents
276 .into_iter()
277 .map(|(document, embeddings)| {
278 let document = serde_json::to_value(document)?;
279 let row = serde_json::from_value::<T>(document)?;
280
281 Ok((row, embeddings))
282 })
283 .collect::<Result<Vec<_>, VectorStoreError>>()?;
284
285 self.add_rows(rows).await?;
286
287 Ok(())
288 }
289}
290
291#[derive(Clone, Default, Deserialize, Serialize, Debug)]
292pub struct SqliteSearchFilter {
293 condition: String,
294 params: Vec<serde_json::Value>,
295}
296
297impl SearchFilter for SqliteSearchFilter {
298 type Value = serde_json::Value;
299
300 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
301 Self {
302 condition: format!("{} = ?", key.as_ref()),
303 params: vec![value],
304 }
305 }
306
307 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
308 Self {
309 condition: format!("{} > ?", key.as_ref()),
310 params: vec![value],
311 }
312 }
313
314 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
315 Self {
316 condition: format!("{} < ?", key.as_ref()),
317 params: vec![value],
318 }
319 }
320
321 fn and(self, rhs: Self) -> Self {
322 Self {
323 condition: format!("({}) AND ({})", self.condition, rhs.condition),
324 params: self.params.into_iter().chain(rhs.params).collect(),
325 }
326 }
327
328 fn or(self, rhs: Self) -> Self {
329 Self {
330 condition: format!("({}) OR ({})", self.condition, rhs.condition),
331 params: self.params.into_iter().chain(rhs.params).collect(),
332 }
333 }
334}
335
336impl SqliteSearchFilter {
337 #[allow(clippy::should_implement_trait)]
338 pub fn not(self) -> Self {
339 Self {
340 condition: format!("NOT ({})", self.condition),
341 ..self
342 }
343 }
344
345 pub fn between<N>(key: String, range: RangeInclusive<N>) -> Self
347 where
348 N: Ord + rusqlite::ToSql + std::fmt::Display,
349 {
350 let lo = range.start();
351 let hi = range.end();
352
353 Self {
354 condition: format!("{key} between {lo} and {hi}"),
355 ..Default::default()
356 }
357 }
358
359 pub fn is_null(key: String) -> Self {
361 Self {
362 condition: format!("{key} is null"),
363 ..Default::default()
364 }
365 }
366
367 pub fn is_not_null(key: String) -> Self {
368 Self {
369 condition: format!("{key} is not null"),
370 ..Default::default()
371 }
372 }
373
374 pub fn glob<'a, S>(key: String, pattern: S) -> Self
378 where
379 S: AsRef<&'a str>,
380 {
381 Self {
382 condition: format!("{key} glob {}", pattern.as_ref()),
383 ..Default::default()
384 }
385 }
386
387 pub fn like<'a, S>(key: String, pattern: S) -> Self
390 where
391 S: AsRef<&'a str>,
392 {
393 Self {
394 condition: format!("{key} like {}", pattern.as_ref()),
395 ..Default::default()
396 }
397 }
398}
399
400impl SqliteSearchFilter {
401 fn compile_params(self) -> Result<Vec<Value>, FilterError> {
402 let mut params = Vec::with_capacity(self.params.len());
403
404 fn convert(value: serde_json::Value) -> Result<Value, FilterError> {
405 use serde_json::Value::*;
406
407 match value {
408 Null => Ok(Value::Null),
409 Bool(b) => Ok(Value::Integer(b as i64)),
410 String(s) => Ok(Value::Text(s)),
411 Number(n) => Ok(if let Some(float) = n.as_f64() {
412 Value::Real(float)
413 } else if let Some(int) = n.as_i64() {
414 Value::Integer(int)
415 } else if let Some(int) = n.as_u64() {
416 Value::Integer(int as i64)
417 } else {
418 Value::Text(n.to_string())
419 }),
420 Array(arr) => {
421 let blob = serde_json::to_vec(&arr)
422 .map_err(|e| FilterError::Serialization(e.to_string()))?;
423
424 Ok(Value::Blob(blob))
425 }
426 Object(obj) => {
427 let blob = serde_json::to_vec(&obj)
428 .map_err(|e| FilterError::Serialization(e.to_string()))?;
429
430 Ok(Value::Blob(blob))
431 }
432 }
433 }
434
435 for param in self.params.into_iter() {
436 params.push(convert(param)?)
437 }
438
439 Ok(params)
440 }
441}
442
443pub struct SqliteVectorIndex<E, T>
535where
536 E: EmbeddingModel + 'static,
537 T: SqliteVectorStoreTable + 'static,
538{
539 store: SqliteVectorStore<E, T>,
540 embedding_model: E,
541}
542
543impl<E, T> SqliteVectorIndex<E, T>
544where
545 E: EmbeddingModel + 'static,
546 T: SqliteVectorStoreTable,
547{
548 pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
549 Self {
550 store,
551 embedding_model,
552 }
553 }
554}
555
556fn build_where_clause(
557 req: &VectorSearchRequest<SqliteSearchFilter>,
558 query_vec: Vec<f32>,
559) -> Result<(String, Vec<Value>), FilterError> {
560 let thresh = req.threshold().unwrap_or(0.);
561 let thresh = SqliteSearchFilter::gt("distance", thresh.into());
562
563 let filter = req
564 .filter()
565 .as_ref()
566 .cloned()
567 .map(|filter| thresh.clone().and(filter))
568 .unwrap_or(thresh);
569
570 let where_clause = format!(
571 "WHERE e.embedding MATCH ? AND k = ? AND {}",
572 filter.condition
573 );
574
575 let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
576 let query_vec = Value::Blob(query_vec);
577 let samples = req.samples() as u32;
578
579 let mut params = vec![query_vec.clone(), query_vec, samples.into()];
580 let filter_params = filter.clone().compile_params()?;
581 params.extend(filter_params);
582
583 Ok((where_clause, params))
584}
585
586impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
587 for SqliteVectorIndex<E, T>
588{
589 type Filter = SqliteSearchFilter;
590
591 async fn top_n<D>(
592 &self,
593 req: VectorSearchRequest<SqliteSearchFilter>,
594 ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
595 where
596 D: for<'de> Deserialize<'de>,
597 {
598 tracing::debug!("Finding top {} matches for query", req.samples() as usize);
599 let embedding = self.embedding_model.embed_text(req.query()).await?;
600 let query_vec: Vec<f32> = serialize_embedding(&embedding);
601 let table_name = T::name();
602
603 let columns = T::schema();
605 let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
606
607 let select_cols = column_names.join(", ");
609
610 let (where_clause, params) = build_where_clause(&req, query_vec)?;
611
612 let rows = self
613 .store
614 .conn
615 .call(move |conn| {
616 let mut stmt = conn.prepare(&format!(
617 "SELECT d.{select_cols}, (1 - vec_distance_cosine(?, e.embedding)) as distance
618 FROM {table_name}_embeddings e
619 JOIN {table_name} d ON e.rowid = d.rowid
620 {where_clause}
621 ORDER BY distance"
622 ))?;
623
624 let rows = stmt
625 .query_map(rusqlite::params_from_iter(params), |row| {
626 let mut map = serde_json::Map::new();
628 for (i, col_name) in column_names.iter().enumerate() {
629 let value: String = row.get(i)?;
630 map.insert(col_name.to_string(), serde_json::Value::String(value));
631 }
632 let distance: f64 = row.get(column_names.len())?;
633 let id: String = row.get(0)?; Ok((id, serde_json::Value::Object(map), distance))
636 })?
637 .collect::<Result<Vec<_>, _>>()?;
638 Ok(rows)
639 })
640 .await
641 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
642
643 debug!("Found {} potential matches", rows.len());
644 let mut top_n = Vec::new();
645 for (id, doc_value, distance) in rows {
646 match serde_json::from_value::<D>(doc_value) {
647 Ok(doc) => {
648 top_n.push((distance, id, doc));
649 }
650 Err(e) => {
651 debug!("Failed to deserialize document {}: {}", id, e);
652 continue;
653 }
654 }
655 }
656
657 debug!("Returning {} matches", top_n.len());
658 Ok(top_n)
659 }
660
661 async fn top_n_ids(
662 &self,
663 req: VectorSearchRequest<SqliteSearchFilter>,
664 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
665 tracing::debug!(
666 "Finding top {} document IDs for query",
667 req.samples() as usize
668 );
669 let embedding = self.embedding_model.embed_text(req.query()).await?;
670 let query_vec = serialize_embedding(&embedding);
671 let table_name = T::name();
672
673 let (where_clause, params) = build_where_clause(&req, query_vec)?;
674
675 let results = self
676 .store
677 .conn
678 .call(move |conn| {
679 let mut stmt = conn.prepare(&format!(
680 "SELECT d.id, (1 - vec_distance_cosine(?1, e.embedding)) as distance
681 FROM {table_name}_embeddings e
682 JOIN {table_name} d ON e.rowid = d.rowid
683 {where_clause}
684 ORDER BY distance"
685 ))?;
686
687 let results = stmt
688 .query_map(rusqlite::params_from_iter(params), |row| {
689 Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?))
690 })?
691 .collect::<Result<Vec<_>, _>>()?;
692 Ok(results)
693 })
694 .await
695 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
696
697 debug!("Found {} matching document IDs", results.len());
698 Ok(results)
699 }
700}
701
702fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
703 embedding.vec.iter().map(|x| *x as f32).collect()
704}
705
706impl ColumnValue for String {
707 fn to_sql_string(&self) -> String {
708 self.clone()
709 }
710
711 fn column_type(&self) -> &'static str {
712 "TEXT"
713 }
714}