1use rig::OneOrMany;
2use rig::embeddings::{Embedding, EmbeddingModel};
3use rig::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
4use rig::vector_store::{VectorStoreError, VectorStoreIndex};
5use rusqlite::types::Value;
6use serde::Deserialize;
7use std::marker::PhantomData;
8use std::ops::RangeInclusive;
9use tokio_rusqlite::Connection;
10use tracing::{debug, info};
11use zerocopy::IntoBytes;
12
/// Error categories produced by the SQLite vector store.
///
/// The two boxed variants carry an arbitrary underlying error; the third
/// carries a plain message.
#[derive(Debug)]
pub enum SqliteError {
    /// A boxed error categorized as a database failure.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// A boxed error categorized as a serialization failure.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// A message describing an invalid column type.
    InvalidColumnType(String),
}
19
/// A value that can be stored in a column of a vector store table.
///
/// Implementors must be `Send + Sync`.
pub trait ColumnValue: Send + Sync {
    /// Returns the value rendered as a `String`; this string is what gets
    /// bound as the SQL parameter when a row is inserted.
    fn to_sql_string(&self) -> String;
    /// Returns the SQL type name associated with this value (for example
    /// `"TEXT"` for `String`).
    fn column_type(&self) -> &'static str;
}
24
/// Describes one column of a document table: its name, its SQL type, and
/// whether a secondary index should be created for it.
pub struct Column {
    /// Column name as it appears in SQL.
    name: &'static str,
    /// SQL type (and any constraints) emitted verbatim after the name.
    col_type: &'static str,
    /// `true` when an index should be created on this column.
    indexed: bool,
}

impl Column {
    /// Builds a non-indexed column description from a name and a SQL type.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        let indexed = false;
        Self {
            name,
            col_type,
            indexed,
        }
    }

    /// Consumes the column and returns it with the index flag switched on.
    pub fn indexed(self) -> Self {
        Self {
            indexed: true,
            ..self
        }
    }
}
45
/// Describes a document type that can be persisted in a SQLite table.
///
/// Implementors must be `Send + Sync + Clone`.
pub trait SqliteVectorStoreTable: Send + Sync + Clone {
    /// Name of the SQL table that stores documents of this type.
    fn name() -> &'static str;
    /// Column definitions of the table.
    fn schema() -> Vec<Column>;
    /// Identifier of this particular document.
    fn id(&self) -> String;
    /// `(column name, value)` pairs for this document, used when the row is
    /// inserted.
    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
}
89
/// A vector store backed by a SQLite connection.
///
/// `E` is the embedding model type and `T` the document table description;
/// neither is stored, they are only tracked through `PhantomData`.
#[derive(Clone)]
pub struct SqliteVectorStore<E, T>
where
    E: EmbeddingModel + 'static,
    T: SqliteVectorStoreTable + 'static,
{
    /// Async handle to the underlying SQLite connection.
    conn: Connection,
    /// Ties the store to its embedding model and table types at compile time.
    _phantom: PhantomData<(E, T)>,
}
99
100impl<E, T> SqliteVectorStore<E, T>
101where
102 E: EmbeddingModel + Clone + 'static,
103 T: SqliteVectorStoreTable + 'static,
104{
105 pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
106 let dims = embedding_model.ndims();
107 let table_name = T::name();
108 let schema = T::schema();
109
110 let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
112
113 let mut first = true;
115 for column in &schema {
116 if !first {
117 create_table.push(',');
118 }
119 create_table.push_str(&format!("\n {} {}", column.name, column.col_type));
120 first = false;
121 }
122
123 create_table.push_str("\n)");
124
125 let mut create_indexes = vec![format!(
127 "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
128 table_name, table_name
129 )];
130
131 for column in schema {
133 if column.indexed {
134 create_indexes.push(format!(
135 "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
136 table_name, column.name, table_name, column.name
137 ));
138 }
139 }
140
141 conn.call(move |conn| {
142 conn.execute_batch("BEGIN")?;
143
144 conn.execute_batch(&create_table)?;
146
147 for index_stmt in create_indexes {
149 conn.execute_batch(&index_stmt)?;
150 }
151
152 conn.execute_batch(&format!(
154 "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
155 ))?;
156
157 conn.execute_batch("COMMIT")?;
158 Ok(())
159 })
160 .await
161 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
162
163 Ok(Self {
164 conn,
165 _phantom: PhantomData,
166 })
167 }
168
169 pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
170 SqliteVectorIndex::new(model, self)
171 }
172
173 pub fn add_rows_with_txn(
174 &self,
175 txn: &rusqlite::Transaction<'_>,
176 documents: Vec<(T, OneOrMany<Embedding>)>,
177 ) -> Result<i64, tokio_rusqlite::Error> {
178 info!("Adding {} documents to store", documents.len());
179 let table_name = T::name();
180 let mut last_id = 0;
181
182 for (doc, embeddings) in &documents {
183 debug!("Storing document with id {}", doc.id());
184
185 let values = doc.column_values();
186 let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
187
188 let placeholders = (1..=values.len())
189 .map(|i| format!("?{i}"))
190 .collect::<Vec<_>>();
191
192 let insert_sql = format!(
193 "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
194 table_name,
195 columns.join(", "),
196 placeholders.join(", ")
197 );
198
199 txn.execute(
200 &insert_sql,
201 rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
202 )?;
203 last_id = txn.last_insert_rowid();
204
205 let embeddings_sql =
206 format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
207
208 let mut stmt = txn.prepare(&embeddings_sql)?;
209 for (i, embedding) in embeddings.iter().enumerate() {
210 let vec = serialize_embedding(embedding);
211 debug!(
212 "Storing embedding {} of {} (size: {} bytes)",
213 i + 1,
214 embeddings.len(),
215 vec.len() * 4
216 );
217 let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
218 stmt.execute(rusqlite::params![last_id, blob])?;
219 }
220 }
221
222 Ok(last_id)
223 }
224
225 pub async fn add_rows(
226 &self,
227 documents: Vec<(T, OneOrMany<Embedding>)>,
228 ) -> Result<i64, VectorStoreError>
229 where
230 T: 'static,
231 Self: 'static,
232 {
233 let cloned = self.clone();
234
235 self.conn
236 .call(move |conn| {
237 let tx = conn.transaction()?;
238 let result = cloned.add_rows_with_txn(&tx, documents)?;
239 tx.commit()?;
240
241 Ok(result)
242 })
243 .await
244 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
245 }
246}
247
/// A SQL filter expression together with the values bound to its
/// placeholders.
#[derive(Clone, Default)]
pub struct SqliteSearchFilter {
    /// SQL boolean expression; may contain anonymous `?` placeholders.
    condition: String,
    /// Values for the `?` placeholders in `condition`, in order of appearance.
    params: Vec<serde_json::Value>,
}
253
254impl SearchFilter for SqliteSearchFilter {
255 type Value = serde_json::Value;
256
257 fn eq(key: String, value: Self::Value) -> Self {
258 Self {
259 condition: format!("{key} = ?"),
260 params: vec![value],
261 }
262 }
263
264 fn gt(key: String, value: Self::Value) -> Self {
265 Self {
266 condition: format!("{key} > ?"),
267 params: vec![value],
268 }
269 }
270
271 fn lt(key: String, value: Self::Value) -> Self {
272 Self {
273 condition: format!("{key} < ?"),
274 params: vec![value],
275 }
276 }
277
278 fn and(self, rhs: Self) -> Self {
279 Self {
280 condition: format!("({}) AND ({})", self.condition, rhs.condition),
281 params: self.params.into_iter().chain(rhs.params).collect(),
282 }
283 }
284
285 fn or(self, rhs: Self) -> Self {
286 Self {
287 condition: format!("({}) OR ({})", self.condition, rhs.condition),
288 params: self.params.into_iter().chain(rhs.params).collect(),
289 }
290 }
291}
292
293impl SqliteSearchFilter {
294 #[allow(clippy::should_implement_trait)]
295 pub fn not(self) -> Self {
296 Self {
297 condition: format!("NOT ({})", self.condition),
298 ..self
299 }
300 }
301
302 pub fn between<N>(key: String, range: RangeInclusive<N>) -> Self
304 where
305 N: Ord + rusqlite::ToSql + std::fmt::Display,
306 {
307 let lo = range.start();
308 let hi = range.end();
309
310 Self {
311 condition: format!("{key} between {lo} and {hi}"),
312 ..Default::default()
313 }
314 }
315
316 pub fn is_null(key: String) -> Self {
318 Self {
319 condition: format!("{key} is null"),
320 ..Default::default()
321 }
322 }
323
324 pub fn is_not_null(key: String) -> Self {
325 Self {
326 condition: format!("{key} is not null"),
327 ..Default::default()
328 }
329 }
330
331 pub fn glob<'a, S>(key: String, pattern: S) -> Self
335 where
336 S: AsRef<&'a str>,
337 {
338 Self {
339 condition: format!("{key} glob {}", pattern.as_ref()),
340 ..Default::default()
341 }
342 }
343
344 pub fn like<'a, S>(key: String, pattern: S) -> Self
347 where
348 S: AsRef<&'a str>,
349 {
350 Self {
351 condition: format!("{key} like {}", pattern.as_ref()),
352 ..Default::default()
353 }
354 }
355}
356
357impl SqliteSearchFilter {
358 fn compile_params(self) -> Result<Vec<Value>, FilterError> {
359 let mut params = Vec::with_capacity(self.params.len());
360
361 fn convert(value: serde_json::Value) -> Result<Value, FilterError> {
362 use serde_json::Value::*;
363
364 match value {
365 Null => Ok(Value::Null),
366 Bool(b) => Ok(Value::Integer(b as i64)),
367 String(s) => Ok(Value::Text(s)),
368 Number(n) => Ok(if let Some(float) = n.as_f64() {
369 Value::Real(float)
370 } else if let Some(int) = n.as_i64() {
371 Value::Integer(int)
372 } else {
373 unreachable!()
374 }),
375 Array(arr) => {
376 let blob = serde_json::to_vec(&arr)
377 .map_err(|e| FilterError::Serialization(e.to_string()))?;
378
379 Ok(Value::Blob(blob))
380 }
381 Object(obj) => {
382 let blob = serde_json::to_vec(&obj)
383 .map_err(|e| FilterError::Serialization(e.to_string()))?;
384
385 Ok(Value::Blob(blob))
386 }
387 }
388 }
389
390 for param in self.params.into_iter() {
391 params.push(convert(param)?)
392 }
393
394 Ok(params)
395 }
396}
397
/// A searchable index: a [`SqliteVectorStore`] paired with the embedding
/// model used to embed query text.
pub struct SqliteVectorIndex<E, T>
where
    E: EmbeddingModel + 'static,
    T: SqliteVectorStoreTable + 'static,
{
    /// Store holding the documents and their embeddings.
    store: SqliteVectorStore<E, T>,
    /// Model used to turn query text into an embedding.
    embedding_model: E,
}
488
489impl<E, T> SqliteVectorIndex<E, T>
490where
491 E: EmbeddingModel + 'static,
492 T: SqliteVectorStoreTable,
493{
494 pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
495 Self {
496 store,
497 embedding_model,
498 }
499 }
500}
501
502fn build_where_clause(
503 req: &VectorSearchRequest<SqliteSearchFilter>,
504 query_vec: Vec<f32>,
505) -> Result<(String, Vec<Value>), FilterError> {
506 let thresh = req.threshold().unwrap_or(0.);
507 let thresh = SqliteSearchFilter::gt("e.distance".into(), thresh.into());
508
509 let filter = req
510 .filter()
511 .as_ref()
512 .cloned()
513 .map(|filter| thresh.clone().and(filter))
514 .unwrap_or(thresh);
515
516 let where_clause = format!(
517 "WHERE e.embedding MATCH ? AND k = ? AND {}",
518 filter.condition
519 );
520
521 let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
522 let query_vec = Value::Blob(query_vec);
523 let samples = req.samples() as u32;
524
525 let mut params = vec![query_vec, samples.into()];
526 let filter_params = filter.clone().compile_params()?;
527 params.extend(filter_params);
528
529 Ok((where_clause, params))
530}
531
532impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
533 for SqliteVectorIndex<E, T>
534{
535 type Filter = SqliteSearchFilter;
536
537 async fn top_n<D>(
538 &self,
539 req: VectorSearchRequest<SqliteSearchFilter>,
540 ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
541 where
542 D: for<'de> Deserialize<'de>,
543 {
544 tracing::debug!("Finding top {} matches for query", req.samples() as usize);
545 let embedding = self.embedding_model.embed_text(req.query()).await?;
546 let query_vec: Vec<f32> = serialize_embedding(&embedding);
547 let table_name = T::name();
548
549 let columns = T::schema();
551 let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
552
553 let select_cols = column_names.join(", ");
555
556 let (where_clause, params) = build_where_clause(&req, query_vec)?;
557
558 let rows = self
559 .store
560 .conn
561 .call(move |conn| {
562 let mut stmt = conn.prepare(&format!(
563 "SELECT d.{select_cols}, e.distance
564 FROM {table_name}_embeddings e
565 JOIN {table_name} d ON e.rowid = d.rowid
566 {where_clause}
567 ORDER BY e.distance"
568 ))?;
569
570 let rows = stmt
571 .query_map(rusqlite::params_from_iter(params), |row| {
572 let mut map = serde_json::Map::new();
574 for (i, col_name) in column_names.iter().enumerate() {
575 let value: String = row.get(i)?;
576 map.insert(col_name.to_string(), serde_json::Value::String(value));
577 }
578 let distance: f64 = row.get(column_names.len())?;
579 let id: String = row.get(0)?; Ok((id, serde_json::Value::Object(map), distance))
582 })?
583 .collect::<Result<Vec<_>, _>>()?;
584 Ok(rows)
585 })
586 .await
587 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
588
589 debug!("Found {} potential matches", rows.len());
590 let mut top_n = Vec::new();
591 for (id, doc_value, distance) in rows {
592 match serde_json::from_value::<D>(doc_value) {
593 Ok(doc) => {
594 top_n.push((distance, id, doc));
595 }
596 Err(e) => {
597 debug!("Failed to deserialize document {}: {}", id, e);
598 continue;
599 }
600 }
601 }
602
603 debug!("Returning {} matches", top_n.len());
604 Ok(top_n)
605 }
606
607 async fn top_n_ids(
608 &self,
609 req: VectorSearchRequest<SqliteSearchFilter>,
610 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
611 tracing::debug!(
612 "Finding top {} document IDs for query",
613 req.samples() as usize
614 );
615 let embedding = self.embedding_model.embed_text(req.query()).await?;
616 let query_vec = serialize_embedding(&embedding);
617 let table_name = T::name();
618
619 let (where_clause, params) = build_where_clause(&req, query_vec)?;
620
621 let results = self
622 .store
623 .conn
624 .call(move |conn| {
625 let mut stmt = conn.prepare(&format!(
626 "SELECT d.id, e.distance
627 FROM {table_name}_embeddings e
628 JOIN {table_name} d ON e.rowid = d.rowid
629 {where_clause}
630 ORDER BY e.distance"
631 ))?;
632
633 let results = stmt
634 .query_map(rusqlite::params_from_iter(params), |row| {
635 Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?))
636 })?
637 .collect::<Result<Vec<_>, _>>()?;
638 Ok(results)
639 })
640 .await
641 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
642
643 debug!("Found {} matching document IDs", results.len());
644 Ok(results)
645 }
646}
647
648fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
649 embedding.vec.iter().map(|x| *x as f32).collect()
650}
651
652impl ColumnValue for String {
653 fn to_sql_string(&self) -> String {
654 self.clone()
655 }
656
657 fn column_type(&self) -> &'static str {
658 "TEXT"
659 }
660}