1use rig::embeddings::{Embedding, EmbeddingModel};
2use rig::vector_store::{VectorStoreError, VectorStoreIndex};
3use rig::OneOrMany;
4use serde::Deserialize;
5use std::marker::PhantomData;
6use tokio_rusqlite::Connection;
7use tracing::{debug, info};
8use zerocopy::IntoBytes;
9
/// Errors produced by the SQLite vector store.
///
/// The type implements [`std::fmt::Display`] and [`std::error::Error`], so it
/// can be propagated with `?`, boxed as `Box<dyn Error>`, and shown to users.
#[derive(Debug)]
pub enum SqliteError {
    /// Wraps an underlying error categorised as a database failure.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// Wraps an underlying error categorised as a serialization failure.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// Carries a textual description of an unusable column type.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DatabaseError(err) => write!(f, "database error: {}", err),
            Self::SerializationError(err) => write!(f, "serialization error: {}", err),
            Self::InvalidColumnType(detail) => write!(f, "invalid column type: {}", detail),
        }
    }
}

impl std::error::Error for SqliteError {
    /// Returns the wrapped error for the two boxed variants, `None` otherwise.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::DatabaseError(err) | Self::SerializationError(err) => {
                // Explicit binding drops the `Send + Sync` auto traits from the object type.
                let inner: &(dyn std::error::Error + 'static) = &**err;
                Some(inner)
            }
            Self::InvalidColumnType(_) => None,
        }
    }
}
16
/// A value that can be written into a table column.
///
/// Implementors must be `Send + Sync`; the trait is used as a boxed trait
/// object (`Box<dyn ColumnValue>`) by `SqliteVectorStoreTable::column_values`.
pub trait ColumnValue: Send + Sync {
    /// Renders the value as the string that gets bound to the SQL statement.
    fn to_sql_string(&self) -> String;

    /// Reports the SQL type name for the value (the `String` impl returns `"TEXT"`).
    fn column_type(&self) -> &'static str;
}
21
/// Describes one column of a table managed by the vector store.
///
/// Built with [`Column::new`] and optionally flagged with [`Column::indexed`].
#[derive(Debug, Clone)]
pub struct Column {
    /// Column name, spliced verbatim into generated SQL.
    name: &'static str,
    /// SQL type text placed after the name in the `CREATE TABLE` statement.
    col_type: &'static str,
    /// Whether the store should create a secondary index on this column.
    indexed: bool,
}

impl Column {
    /// Creates a column description that is not indexed.
    ///
    /// `name` and `col_type` are inserted into SQL text as-is, so they must be
    /// trusted, static identifiers.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Self {
            name,
            col_type,
            indexed: false,
        }
    }

    /// Consumes the description and returns it with the index flag set.
    pub fn indexed(self) -> Self {
        Self {
            indexed: true,
            ..self
        }
    }
}
42
43pub trait SqliteVectorStoreTable: Send + Sync + Clone {
81 fn name() -> &'static str;
82 fn schema() -> Vec<Column>;
83 fn id(&self) -> String;
84 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
85}
86
87#[derive(Clone)]
88pub struct SqliteVectorStore<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> {
89 conn: Connection,
90 _phantom: PhantomData<(E, T)>,
91}
92
93impl<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> SqliteVectorStore<E, T> {
94 pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
95 let dims = embedding_model.ndims();
96 let table_name = T::name();
97 let schema = T::schema();
98
99 let mut create_table = format!("CREATE TABLE IF NOT EXISTS {} (", table_name);
101
102 let mut first = true;
104 for column in &schema {
105 if !first {
106 create_table.push(',');
107 }
108 create_table.push_str(&format!("\n {} {}", column.name, column.col_type));
109 first = false;
110 }
111
112 create_table.push_str("\n)");
113
114 let mut create_indexes = vec![format!(
116 "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
117 table_name, table_name
118 )];
119
120 for column in schema {
122 if column.indexed {
123 create_indexes.push(format!(
124 "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
125 table_name, column.name, table_name, column.name
126 ));
127 }
128 }
129
130 conn.call(move |conn| {
131 conn.execute_batch("BEGIN")?;
132
133 conn.execute_batch(&create_table)?;
135
136 for index_stmt in create_indexes {
138 conn.execute_batch(&index_stmt)?;
139 }
140
141 conn.execute_batch(&format!(
143 "CREATE VIRTUAL TABLE IF NOT EXISTS {}_embeddings USING vec0(embedding float[{}])",
144 table_name, dims
145 ))?;
146
147 conn.execute_batch("COMMIT")?;
148 Ok(())
149 })
150 .await
151 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
152
153 Ok(Self {
154 conn,
155 _phantom: PhantomData,
156 })
157 }
158
159 pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
160 SqliteVectorIndex::new(model, self)
161 }
162
163 pub fn add_rows_with_txn(
164 &self,
165 txn: &rusqlite::Transaction<'_>,
166 documents: Vec<(T, OneOrMany<Embedding>)>,
167 ) -> Result<i64, tokio_rusqlite::Error> {
168 info!("Adding {} documents to store", documents.len());
169 let table_name = T::name();
170 let mut last_id = 0;
171
172 for (doc, embeddings) in &documents {
173 debug!("Storing document with id {}", doc.id());
174
175 let values = doc.column_values();
176 let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
177
178 let placeholders = (1..=values.len())
179 .map(|i| format!("?{}", i))
180 .collect::<Vec<_>>();
181
182 let insert_sql = format!(
183 "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
184 table_name,
185 columns.join(", "),
186 placeholders.join(", ")
187 );
188
189 txn.execute(
190 &insert_sql,
191 rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
192 )?;
193 last_id = txn.last_insert_rowid();
194
195 let embeddings_sql = format!(
196 "INSERT INTO {}_embeddings (rowid, embedding) VALUES (?1, ?2)",
197 table_name
198 );
199
200 let mut stmt = txn.prepare(&embeddings_sql)?;
201 for (i, embedding) in embeddings.iter().enumerate() {
202 let vec = serialize_embedding(embedding);
203 debug!(
204 "Storing embedding {} of {} (size: {} bytes)",
205 i + 1,
206 embeddings.len(),
207 vec.len() * 4
208 );
209 let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
210 stmt.execute(rusqlite::params![last_id, blob])?;
211 }
212 }
213
214 Ok(last_id)
215 }
216
217 pub async fn add_rows(
218 &self,
219 documents: Vec<(T, OneOrMany<Embedding>)>,
220 ) -> Result<i64, VectorStoreError> {
221 let documents = documents.clone();
222 let this = self.clone();
223
224 self.conn
225 .call(move |conn| {
226 let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?;
227 let result = this.add_rows_with_txn(&tx, documents)?;
228 tx.commit().map_err(tokio_rusqlite::Error::from)?;
229 Ok(result)
230 })
231 .await
232 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
233 }
234}
235
236pub struct SqliteVectorIndex<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> {
319 store: SqliteVectorStore<E, T>,
320 embedding_model: E,
321}
322
323impl<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable> SqliteVectorIndex<E, T> {
324 pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
325 Self {
326 store,
327 embedding_model,
328 }
329 }
330}
331
332impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
333 for SqliteVectorIndex<E, T>
334{
335 async fn top_n<D: for<'a> Deserialize<'a>>(
336 &self,
337 query: &str,
338 n: usize,
339 ) -> Result<Vec<(f64, String, D)>, VectorStoreError> {
340 debug!("Finding top {} matches for query", n);
341 let embedding = self.embedding_model.embed_text(query).await?;
342 let query_vec: Vec<f32> = serialize_embedding(&embedding);
343 let table_name = T::name();
344
345 let columns = T::schema();
347 let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
348
349 let rows = self
350 .store
351 .conn
352 .call(move |conn| {
353 let select_cols = column_names.join(", ");
355 let mut stmt = conn.prepare(&format!(
356 "SELECT d.{}, e.distance
357 FROM {}_embeddings e
358 JOIN {} d ON e.rowid = d.rowid
359 WHERE e.embedding MATCH ?1 AND k = ?2
360 ORDER BY e.distance",
361 select_cols, table_name, table_name
362 ))?;
363
364 let rows = stmt
365 .query_map(rusqlite::params![query_vec.as_bytes().to_vec(), n], |row| {
366 let mut map = serde_json::Map::new();
368 for (i, col_name) in column_names.iter().enumerate() {
369 let value: String = row.get(i)?;
370 map.insert(col_name.to_string(), serde_json::Value::String(value));
371 }
372 let distance: f64 = row.get(column_names.len())?;
373 let id: String = row.get(0)?; Ok((id, serde_json::Value::Object(map), distance))
376 })?
377 .collect::<Result<Vec<_>, _>>()?;
378 Ok(rows)
379 })
380 .await
381 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
382
383 debug!("Found {} potential matches", rows.len());
384 let mut top_n = Vec::new();
385 for (id, doc_value, distance) in rows {
386 match serde_json::from_value::<D>(doc_value) {
387 Ok(doc) => {
388 top_n.push((distance, id, doc));
389 }
390 Err(e) => {
391 debug!("Failed to deserialize document {}: {}", id, e);
392 continue;
393 }
394 }
395 }
396
397 debug!("Returning {} matches", top_n.len());
398 Ok(top_n)
399 }
400
401 async fn top_n_ids(
402 &self,
403 query: &str,
404 n: usize,
405 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
406 debug!("Finding top {} document IDs for query", n);
407 let embedding = self.embedding_model.embed_text(query).await?;
408 let query_vec = serialize_embedding(&embedding);
409 let table_name = T::name();
410
411 let results = self
412 .store
413 .conn
414 .call(move |conn| {
415 let mut stmt = conn.prepare(&format!(
416 "SELECT d.id, e.distance
417 FROM {0}_embeddings e
418 JOIN {0} d ON e.rowid = d.rowid
419 WHERE e.embedding MATCH ?1 AND k = ?2
420 ORDER BY e.distance",
421 table_name
422 ))?;
423
424 let results = stmt
425 .query_map(
426 rusqlite::params![
427 query_vec
428 .iter()
429 .flat_map(|x| x.to_le_bytes())
430 .collect::<Vec<u8>>(),
431 n
432 ],
433 |row| Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?)),
434 )?
435 .collect::<Result<Vec<_>, _>>()?;
436 Ok(results)
437 })
438 .await
439 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
440
441 debug!("Found {} matching document IDs", results.len());
442 Ok(results)
443 }
444}
445
446fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
447 embedding.vec.iter().map(|x| *x as f32).collect()
448}
449
450impl ColumnValue for String {
451 fn to_sql_string(&self) -> String {
452 self.clone()
453 }
454
455 fn column_type(&self) -> &'static str {
456 "TEXT"
457 }
458}