1use rig::OneOrMany;
2use rig::embeddings::{Embedding, EmbeddingModel};
3use rig::vector_store::{VectorStoreError, VectorStoreIndex};
4use serde::Deserialize;
5use std::marker::PhantomData;
6use tokio_rusqlite::Connection;
7use tracing::{debug, info};
8use zerocopy::IntoBytes;
9
/// Errors produced by the SQLite vector store layer.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be boxed into
/// `Box<dyn Error>` (e.g. a datastore error) and chained via [`std::error::Error::source`].
#[derive(Debug)]
pub enum SqliteError {
    /// Wraps an underlying database-layer error.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// Wraps an underlying (de)serialization error.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// A column type that could not be handled; carries a description of that type.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DatabaseError(e) => write!(f, "database error: {e}"),
            Self::SerializationError(e) => write!(f, "serialization error: {e}"),
            Self::InvalidColumnType(ty) => write!(f, "invalid column type: {ty}"),
        }
    }
}

impl std::error::Error for SqliteError {
    /// Returns the wrapped error for the variants that carry one.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::DatabaseError(e) | Self::SerializationError(e) => {
                // Upcast `dyn Error + Send + Sync` to plain `dyn Error`.
                let source: &(dyn std::error::Error + 'static) = &**e;
                Some(source)
            }
            Self::InvalidColumnType(_) => None,
        }
    }
}
16
/// A value that can be written into a column of a document table.
pub trait ColumnValue: Sync + Send {
    /// Declared SQL type of this value (for example `"TEXT"`).
    fn column_type(&self) -> &'static str;

    /// Renders the value as the string that gets bound into the SQL statement.
    fn to_sql_string(&self) -> String;
}
21
/// Definition of a single column in a document table's schema.
pub struct Column {
    /// Column name as written in SQL.
    name: &'static str,
    /// Declared SQL type, e.g. `"TEXT"`.
    col_type: &'static str,
    /// Whether table creation should also create an index on this column.
    indexed: bool,
}
27
28impl Column {
29 pub fn new(name: &'static str, col_type: &'static str) -> Self {
30 Self {
31 name,
32 col_type,
33 indexed: false,
34 }
35 }
36
37 pub fn indexed(mut self) -> Self {
38 self.indexed = true;
39 self
40 }
41}
42
43pub trait SqliteVectorStoreTable: Send + Sync + Clone {
81 fn name() -> &'static str;
82 fn schema() -> Vec<Column>;
83 fn id(&self) -> String;
84 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
85}
86
87#[derive(Clone)]
88pub struct SqliteVectorStore<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> {
89 conn: Connection,
90 _phantom: PhantomData<(E, T)>,
91}
92
93impl<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> SqliteVectorStore<E, T> {
94 pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
95 let dims = embedding_model.ndims();
96 let table_name = T::name();
97 let schema = T::schema();
98
99 let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
101
102 let mut first = true;
104 for column in &schema {
105 if !first {
106 create_table.push(',');
107 }
108 create_table.push_str(&format!("\n {} {}", column.name, column.col_type));
109 first = false;
110 }
111
112 create_table.push_str("\n)");
113
114 let mut create_indexes = vec![format!(
116 "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
117 table_name, table_name
118 )];
119
120 for column in schema {
122 if column.indexed {
123 create_indexes.push(format!(
124 "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
125 table_name, column.name, table_name, column.name
126 ));
127 }
128 }
129
130 conn.call(move |conn| {
131 conn.execute_batch("BEGIN")?;
132
133 conn.execute_batch(&create_table)?;
135
136 for index_stmt in create_indexes {
138 conn.execute_batch(&index_stmt)?;
139 }
140
141 conn.execute_batch(&format!(
143 "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
144 ))?;
145
146 conn.execute_batch("COMMIT")?;
147 Ok(())
148 })
149 .await
150 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
151
152 Ok(Self {
153 conn,
154 _phantom: PhantomData,
155 })
156 }
157
158 pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
159 SqliteVectorIndex::new(model, self)
160 }
161
162 pub fn add_rows_with_txn(
163 &self,
164 txn: &rusqlite::Transaction<'_>,
165 documents: Vec<(T, OneOrMany<Embedding>)>,
166 ) -> Result<i64, tokio_rusqlite::Error> {
167 info!("Adding {} documents to store", documents.len());
168 let table_name = T::name();
169 let mut last_id = 0;
170
171 for (doc, embeddings) in &documents {
172 debug!("Storing document with id {}", doc.id());
173
174 let values = doc.column_values();
175 let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
176
177 let placeholders = (1..=values.len())
178 .map(|i| format!("?{i}"))
179 .collect::<Vec<_>>();
180
181 let insert_sql = format!(
182 "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
183 table_name,
184 columns.join(", "),
185 placeholders.join(", ")
186 );
187
188 txn.execute(
189 &insert_sql,
190 rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
191 )?;
192 last_id = txn.last_insert_rowid();
193
194 let embeddings_sql =
195 format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
196
197 let mut stmt = txn.prepare(&embeddings_sql)?;
198 for (i, embedding) in embeddings.iter().enumerate() {
199 let vec = serialize_embedding(embedding);
200 debug!(
201 "Storing embedding {} of {} (size: {} bytes)",
202 i + 1,
203 embeddings.len(),
204 vec.len() * 4
205 );
206 let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
207 stmt.execute(rusqlite::params![last_id, blob])?;
208 }
209 }
210
211 Ok(last_id)
212 }
213
214 pub async fn add_rows(
215 &self,
216 documents: Vec<(T, OneOrMany<Embedding>)>,
217 ) -> Result<i64, VectorStoreError> {
218 let documents = documents.clone();
219 let this = self.clone();
220
221 self.conn
222 .call(move |conn| {
223 let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?;
224 let result = this.add_rows_with_txn(&tx, documents)?;
225 tx.commit().map_err(tokio_rusqlite::Error::from)?;
226 Ok(result)
227 })
228 .await
229 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
230 }
231}
232
233pub struct SqliteVectorIndex<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> {
316 store: SqliteVectorStore<E, T>,
317 embedding_model: E,
318}
319
320impl<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable> SqliteVectorIndex<E, T> {
321 pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
322 Self {
323 store,
324 embedding_model,
325 }
326 }
327}
328
329impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
330 for SqliteVectorIndex<E, T>
331{
332 async fn top_n<D: for<'a> Deserialize<'a>>(
333 &self,
334 query: &str,
335 n: usize,
336 ) -> Result<Vec<(f64, String, D)>, VectorStoreError> {
337 debug!("Finding top {} matches for query", n);
338 let embedding = self.embedding_model.embed_text(query).await?;
339 let query_vec: Vec<f32> = serialize_embedding(&embedding);
340 let table_name = T::name();
341
342 let columns = T::schema();
344 let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
345
346 let rows = self
347 .store
348 .conn
349 .call(move |conn| {
350 let select_cols = column_names.join(", ");
352 let mut stmt = conn.prepare(&format!(
353 "SELECT d.{select_cols}, e.distance
354 FROM {table_name}_embeddings e
355 JOIN {table_name} d ON e.rowid = d.rowid
356 WHERE e.embedding MATCH ?1 AND k = ?2
357 ORDER BY e.distance"
358 ))?;
359
360 let rows = stmt
361 .query_map(rusqlite::params![query_vec.as_bytes().to_vec(), n], |row| {
362 let mut map = serde_json::Map::new();
364 for (i, col_name) in column_names.iter().enumerate() {
365 let value: String = row.get(i)?;
366 map.insert(col_name.to_string(), serde_json::Value::String(value));
367 }
368 let distance: f64 = row.get(column_names.len())?;
369 let id: String = row.get(0)?; Ok((id, serde_json::Value::Object(map), distance))
372 })?
373 .collect::<Result<Vec<_>, _>>()?;
374 Ok(rows)
375 })
376 .await
377 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
378
379 debug!("Found {} potential matches", rows.len());
380 let mut top_n = Vec::new();
381 for (id, doc_value, distance) in rows {
382 match serde_json::from_value::<D>(doc_value) {
383 Ok(doc) => {
384 top_n.push((distance, id, doc));
385 }
386 Err(e) => {
387 debug!("Failed to deserialize document {}: {}", id, e);
388 continue;
389 }
390 }
391 }
392
393 debug!("Returning {} matches", top_n.len());
394 Ok(top_n)
395 }
396
397 async fn top_n_ids(
398 &self,
399 query: &str,
400 n: usize,
401 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
402 debug!("Finding top {} document IDs for query", n);
403 let embedding = self.embedding_model.embed_text(query).await?;
404 let query_vec = serialize_embedding(&embedding);
405 let table_name = T::name();
406
407 let results = self
408 .store
409 .conn
410 .call(move |conn| {
411 let mut stmt = conn.prepare(&format!(
412 "SELECT d.id, e.distance
413 FROM {table_name}_embeddings e
414 JOIN {table_name} d ON e.rowid = d.rowid
415 WHERE e.embedding MATCH ?1 AND k = ?2
416 ORDER BY e.distance"
417 ))?;
418
419 let results = stmt
420 .query_map(
421 rusqlite::params![
422 query_vec
423 .iter()
424 .flat_map(|x| x.to_le_bytes())
425 .collect::<Vec<u8>>(),
426 n
427 ],
428 |row| Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?)),
429 )?
430 .collect::<Result<Vec<_>, _>>()?;
431 Ok(results)
432 })
433 .await
434 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
435
436 debug!("Found {} matching document IDs", results.len());
437 Ok(results)
438 }
439}
440
441fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
442 embedding.vec.iter().map(|x| *x as f32).collect()
443}
444
445impl ColumnValue for String {
446 fn to_sql_string(&self) -> String {
447 self.clone()
448 }
449
450 fn column_type(&self) -> &'static str {
451 "TEXT"
452 }
453}