1use rig::OneOrMany;
2use rig::embeddings::{Embedding, EmbeddingModel};
3use rig::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
4use rig::vector_store::{VectorStoreError, VectorStoreIndex};
5use rusqlite::types::Value;
6use serde::Deserialize;
7use std::marker::PhantomData;
8use tokio_rusqlite::Connection;
9use tracing::{debug, info};
10use zerocopy::IntoBytes;
11
/// Errors produced by the SQLite vector store.
#[derive(Debug)]
pub enum SqliteError {
    /// An error reported by the underlying database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// A value could not be serialized or deserialized.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// An invalid column type, described by the `String` payload.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SqliteError::DatabaseError(err) => write!(f, "database error: {err}"),
            SqliteError::SerializationError(err) => write!(f, "serialization error: {err}"),
            SqliteError::InvalidColumnType(ty) => write!(f, "invalid column type: {ty}"),
        }
    }
}

impl std::error::Error for SqliteError {
    /// Exposes the wrapped error of the `DatabaseError` / `SerializationError`
    /// variants so error-chain reporters can walk to the root cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            SqliteError::DatabaseError(err) | SqliteError::SerializationError(err) => {
                let inner: &(dyn std::error::Error + 'static) = &**err;
                Some(inner)
            }
            SqliteError::InvalidColumnType(_) => None,
        }
    }
}
18
/// A value that can be written into a document-table column.
///
/// Values are bound to INSERT statements as text; SQLite then applies the
/// column's declared type affinity.
pub trait ColumnValue: Send + Sync {
    /// Textual form of the value, bound as the INSERT parameter.
    fn to_sql_string(&self) -> String;

    /// SQL type name associated with this value (for example `"TEXT"`).
    fn column_type(&self) -> &'static str;
}
23
/// One column of a document table: its name, SQL type, and whether a
/// secondary index should be created for it.
pub struct Column {
    name: &'static str,
    col_type: &'static str,
    indexed: bool,
}

impl Column {
    /// Declares a column named `name` with SQL type `col_type`; not indexed.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Column {
            indexed: false,
            col_type,
            name,
        }
    }

    /// Builder step that marks this column for a `CREATE INDEX`.
    pub fn indexed(self) -> Self {
        Column {
            indexed: true,
            ..self
        }
    }
}
44
45pub trait SqliteVectorStoreTable: Send + Sync + Clone {
83 fn name() -> &'static str;
84 fn schema() -> Vec<Column>;
85 fn id(&self) -> String;
86 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
87}
88
89#[derive(Clone)]
90pub struct SqliteVectorStore<E, T>
91where
92 E: EmbeddingModel + 'static,
93 T: SqliteVectorStoreTable + 'static,
94{
95 conn: Connection,
96 _phantom: PhantomData<(E, T)>,
97}
98
99impl<E, T> SqliteVectorStore<E, T>
100where
101 E: EmbeddingModel + 'static,
102 T: SqliteVectorStoreTable + 'static,
103{
104 pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
105 let dims = embedding_model.ndims();
106 let table_name = T::name();
107 let schema = T::schema();
108
109 let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
111
112 let mut first = true;
114 for column in &schema {
115 if !first {
116 create_table.push(',');
117 }
118 create_table.push_str(&format!("\n {} {}", column.name, column.col_type));
119 first = false;
120 }
121
122 create_table.push_str("\n)");
123
124 let mut create_indexes = vec![format!(
126 "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
127 table_name, table_name
128 )];
129
130 for column in schema {
132 if column.indexed {
133 create_indexes.push(format!(
134 "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
135 table_name, column.name, table_name, column.name
136 ));
137 }
138 }
139
140 conn.call(move |conn| {
141 conn.execute_batch("BEGIN")?;
142
143 conn.execute_batch(&create_table)?;
145
146 for index_stmt in create_indexes {
148 conn.execute_batch(&index_stmt)?;
149 }
150
151 conn.execute_batch(&format!(
153 "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
154 ))?;
155
156 conn.execute_batch("COMMIT")?;
157 Ok(())
158 })
159 .await
160 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
161
162 Ok(Self {
163 conn,
164 _phantom: PhantomData,
165 })
166 }
167
168 pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
169 SqliteVectorIndex::new(model, self)
170 }
171
172 pub fn add_rows_with_txn(
173 &self,
174 txn: &rusqlite::Transaction<'_>,
175 documents: Vec<(T, OneOrMany<Embedding>)>,
176 ) -> Result<i64, tokio_rusqlite::Error> {
177 info!("Adding {} documents to store", documents.len());
178 let table_name = T::name();
179 let mut last_id = 0;
180
181 for (doc, embeddings) in &documents {
182 debug!("Storing document with id {}", doc.id());
183
184 let values = doc.column_values();
185 let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
186
187 let placeholders = (1..=values.len())
188 .map(|i| format!("?{i}"))
189 .collect::<Vec<_>>();
190
191 let insert_sql = format!(
192 "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
193 table_name,
194 columns.join(", "),
195 placeholders.join(", ")
196 );
197
198 txn.execute(
199 &insert_sql,
200 rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
201 )?;
202 last_id = txn.last_insert_rowid();
203
204 let embeddings_sql =
205 format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
206
207 let mut stmt = txn.prepare(&embeddings_sql)?;
208 for (i, embedding) in embeddings.iter().enumerate() {
209 let vec = serialize_embedding(embedding);
210 debug!(
211 "Storing embedding {} of {} (size: {} bytes)",
212 i + 1,
213 embeddings.len(),
214 vec.len() * 4
215 );
216 let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
217 stmt.execute(rusqlite::params![last_id, blob])?;
218 }
219 }
220
221 Ok(last_id)
222 }
223
224 pub async fn add_rows(
225 &self,
226 documents: Vec<(T, OneOrMany<Embedding>)>,
227 ) -> Result<i64, VectorStoreError> {
228 let documents = documents.clone();
229 let this = self.clone();
230
231 self.conn
232 .call(move |conn| {
233 let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?;
234 let result = this.add_rows_with_txn(&tx, documents)?;
235 tx.commit().map_err(tokio_rusqlite::Error::from)?;
236 Ok(result)
237 })
238 .await
239 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
240 }
241}
242
243#[derive(Clone)]
244pub struct SqliteSearchFilter {
245 condition: String,
246 params: Vec<serde_json::Value>,
247}
248
249impl SearchFilter for SqliteSearchFilter {
250 type Value = serde_json::Value;
251
252 fn eq(key: String, value: Self::Value) -> Self {
253 Self {
254 condition: format!("{key} = ?"),
255 params: vec![value],
256 }
257 }
258
259 fn gt(key: String, value: Self::Value) -> Self {
260 Self {
261 condition: format!("{key} > ?"),
262 params: vec![value],
263 }
264 }
265
266 fn lt(key: String, value: Self::Value) -> Self {
267 Self {
268 condition: format!("{key} < ?"),
269 params: vec![value],
270 }
271 }
272
273 fn and(self, rhs: Self) -> Self {
274 Self {
275 condition: format!("({}) AND ({})", self.condition, rhs.condition),
276 params: self.params.into_iter().chain(rhs.params).collect(),
277 }
278 }
279
280 fn or(self, rhs: Self) -> Self {
281 Self {
282 condition: format!("({}) OR ({})", self.condition, rhs.condition),
283 params: self.params.into_iter().chain(rhs.params).collect(),
284 }
285 }
286}
287
288impl SqliteSearchFilter {
289 #[allow(clippy::should_implement_trait)]
290 pub fn not(self) -> Self {
291 Self {
292 condition: format!("NOT ({})", self.condition),
293 ..self
294 }
295 }
296}
297
298impl SqliteSearchFilter {
299 fn compile_params(self) -> Result<Vec<Value>, FilterError> {
300 let mut params = Vec::with_capacity(self.params.len());
301
302 fn convert(value: serde_json::Value) -> Result<Value, FilterError> {
303 use serde_json::Value::*;
304
305 match value {
306 Null => Ok(Value::Null),
307 Bool(b) => Ok(Value::Integer(b as i64)),
308 String(s) => Ok(Value::Text(s)),
309 Number(n) => Ok(if let Some(float) = n.as_f64() {
310 Value::Real(float)
311 } else if let Some(int) = n.as_i64() {
312 Value::Integer(int)
313 } else {
314 unreachable!()
315 }),
316 Array(arr) => {
317 let blob = serde_json::to_vec(&arr)
318 .map_err(|e| FilterError::Serialization(e.to_string()))?;
319
320 Ok(Value::Blob(blob))
321 }
322 Object(obj) => {
323 let blob = serde_json::to_vec(&obj)
324 .map_err(|e| FilterError::Serialization(e.to_string()))?;
325
326 Ok(Value::Blob(blob))
327 }
328 }
329 }
330
331 for param in self.params.into_iter() {
332 params.push(convert(param)?)
333 }
334
335 Ok(params)
336 }
337}
338
339pub struct SqliteVectorIndex<E, T>
422where
423 E: EmbeddingModel + 'static,
424 T: SqliteVectorStoreTable + 'static,
425{
426 store: SqliteVectorStore<E, T>,
427 embedding_model: E,
428}
429
430impl<E, T> SqliteVectorIndex<E, T>
431where
432 E: EmbeddingModel + 'static,
433 T: SqliteVectorStoreTable,
434{
435 pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
436 Self {
437 store,
438 embedding_model,
439 }
440 }
441}
442
443fn build_where_clause(
444 req: &VectorSearchRequest<SqliteSearchFilter>,
445 query_vec: Vec<f32>,
446) -> Result<(String, Vec<Value>), FilterError> {
447 let thresh = req.threshold().unwrap_or(0.);
448 let thresh = SqliteSearchFilter::gt("e.distance".into(), thresh.into());
449
450 let filter = req
451 .filter()
452 .as_ref()
453 .cloned()
454 .map(|filter| thresh.clone().and(filter))
455 .unwrap_or(thresh);
456
457 let where_clause = format!(
458 "WHERE e.embedding MATCH ? AND k = ? AND {}",
459 filter.condition
460 );
461
462 let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
463 let query_vec = Value::Blob(query_vec);
464 let samples = req.samples() as u32;
465
466 let mut params = vec![query_vec, samples.into()];
467 let filter_params = filter.clone().compile_params()?;
468 params.extend(filter_params);
469
470 Ok((where_clause, params))
471}
472
473impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
474 for SqliteVectorIndex<E, T>
475{
476 type Filter = SqliteSearchFilter;
477
478 async fn top_n<D>(
479 &self,
480 req: VectorSearchRequest<SqliteSearchFilter>,
481 ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
482 where
483 D: for<'de> Deserialize<'de>,
484 {
485 tracing::debug!("Finding top {} matches for query", req.samples() as usize);
486 let embedding = self.embedding_model.embed_text(req.query()).await?;
487 let query_vec: Vec<f32> = serialize_embedding(&embedding);
488 let table_name = T::name();
489
490 let columns = T::schema();
492 let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
493
494 let select_cols = column_names.join(", ");
496
497 let (where_clause, params) = build_where_clause(&req, query_vec)?;
498
499 let rows = self
500 .store
501 .conn
502 .call(move |conn| {
503 let mut stmt = conn.prepare(&format!(
504 "SELECT d.{select_cols}, e.distance
505 FROM {table_name}_embeddings e
506 JOIN {table_name} d ON e.rowid = d.rowid
507 {where_clause}
508 ORDER BY e.distance"
509 ))?;
510
511 dbg!(&stmt);
512
513 let rows = stmt
514 .query_map(rusqlite::params_from_iter(params), |row| {
515 let mut map = serde_json::Map::new();
517 for (i, col_name) in column_names.iter().enumerate() {
518 let value: String = row.get(i)?;
519 map.insert(col_name.to_string(), serde_json::Value::String(value));
520 }
521 let distance: f64 = row.get(column_names.len())?;
522 let id: String = row.get(0)?; Ok((id, serde_json::Value::Object(map), distance))
525 })?
526 .collect::<Result<Vec<_>, _>>()?;
527 Ok(rows)
528 })
529 .await
530 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
531
532 debug!("Found {} potential matches", rows.len());
533 let mut top_n = Vec::new();
534 for (id, doc_value, distance) in rows {
535 match serde_json::from_value::<D>(doc_value) {
536 Ok(doc) => {
537 top_n.push((distance, id, doc));
538 }
539 Err(e) => {
540 debug!("Failed to deserialize document {}: {}", id, e);
541 continue;
542 }
543 }
544 }
545
546 debug!("Returning {} matches", top_n.len());
547 Ok(top_n)
548 }
549
550 async fn top_n_ids(
551 &self,
552 req: VectorSearchRequest<SqliteSearchFilter>,
553 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
554 tracing::debug!(
555 "Finding top {} document IDs for query",
556 req.samples() as usize
557 );
558 let embedding = self.embedding_model.embed_text(req.query()).await?;
559 let query_vec = serialize_embedding(&embedding);
560 let table_name = T::name();
561
562 let (where_clause, params) = build_where_clause(&req, query_vec)?;
563
564 let results = self
565 .store
566 .conn
567 .call(move |conn| {
568 let mut stmt = conn.prepare(&format!(
569 "SELECT d.id, e.distance
570 FROM {table_name}_embeddings e
571 JOIN {table_name} d ON e.rowid = d.rowid
572 {where_clause}
573 ORDER BY e.distance"
574 ))?;
575
576 dbg!(&stmt);
577
578 let results = stmt
579 .query_map(rusqlite::params_from_iter(params), |row| {
580 Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?))
581 })?
582 .collect::<Result<Vec<_>, _>>()?;
583 Ok(results)
584 })
585 .await
586 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
587
588 debug!("Found {} matching document IDs", results.len());
589 Ok(results)
590 }
591}
592
593fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
594 embedding.vec.iter().map(|x| *x as f32).collect()
595}
596
597impl ColumnValue for String {
598 fn to_sql_string(&self) -> String {
599 self.clone()
600 }
601
602 fn column_type(&self) -> &'static str {
603 "TEXT"
604 }
605}