1use rig::OneOrMany;
2use rig::embeddings::{Embedding, EmbeddingModel};
3use rig::vector_store::request::VectorSearchRequest;
4use rig::vector_store::{VectorStoreError, VectorStoreIndex};
5use serde::Deserialize;
6use std::marker::PhantomData;
7use tokio_rusqlite::Connection;
8use tracing::{debug, info};
9use zerocopy::IntoBytes;
10
/// Errors specific to the SQLite-backed vector store.
///
/// The two boxed variants wrap an underlying error; `InvalidColumnType`
/// carries a plain description.
#[derive(Debug)]
pub enum SqliteError {
    /// A failure reported by the database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// A failure while converting data to or from its stored form.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// A column type that is not acceptable; the string describes it.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DatabaseError(err) => write!(f, "database error: {err}"),
            Self::SerializationError(err) => write!(f, "serialization error: {err}"),
            Self::InvalidColumnType(detail) => write!(f, "invalid column type: {detail}"),
        }
    }
}

impl std::error::Error for SqliteError {
    /// Exposes the wrapped error (if any) so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::DatabaseError(err) | Self::SerializationError(err) => {
                // Drop the `Send + Sync` auto traits to match the trait signature.
                let inner: &(dyn std::error::Error + 'static) = &**err;
                Some(inner)
            }
            Self::InvalidColumnType(_) => None,
        }
    }
}
17
/// A value that can be stored in a column of a document table.
///
/// Implementors must be `Send + Sync` so boxed values can cross threads.
pub trait ColumnValue: Send + Sync {
    /// Renders the value as the string that is bound as the SQL parameter
    /// when a row is inserted.
    fn to_sql_string(&self) -> String;

    /// Returns the SQL type name associated with this value (the `String`
    /// implementation in this file returns `"TEXT"`).
    fn column_type(&self) -> &'static str;
}
22
/// Describes one column of a document table: its name, its SQL type text,
/// and whether a secondary index should be created for it.
pub struct Column {
    /// Column name as it appears in generated SQL.
    name: &'static str,
    /// SQL type (and any constraints) emitted after the name in `CREATE TABLE`.
    col_type: &'static str,
    /// When `true`, the store creates an index on this column.
    indexed: bool,
}

impl Column {
    /// Creates a non-indexed column with the given name and SQL type.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Column {
            indexed: false,
            col_type,
            name,
        }
    }

    /// Consumes the column and returns it flagged as indexed.
    pub fn indexed(self) -> Self {
        Column {
            indexed: true,
            ..self
        }
    }
}
43
44pub trait SqliteVectorStoreTable: Send + Sync + Clone {
82 fn name() -> &'static str;
83 fn schema() -> Vec<Column>;
84 fn id(&self) -> String;
85 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
86}
87
88#[derive(Clone)]
89pub struct SqliteVectorStore<E, T>
90where
91 E: EmbeddingModel + 'static,
92 T: SqliteVectorStoreTable + 'static,
93{
94 conn: Connection,
95 _phantom: PhantomData<(E, T)>,
96}
97
98impl<E, T> SqliteVectorStore<E, T>
99where
100 E: EmbeddingModel + 'static,
101 T: SqliteVectorStoreTable + 'static,
102{
103 pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
104 let dims = embedding_model.ndims();
105 let table_name = T::name();
106 let schema = T::schema();
107
108 let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
110
111 let mut first = true;
113 for column in &schema {
114 if !first {
115 create_table.push(',');
116 }
117 create_table.push_str(&format!("\n {} {}", column.name, column.col_type));
118 first = false;
119 }
120
121 create_table.push_str("\n)");
122
123 let mut create_indexes = vec![format!(
125 "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
126 table_name, table_name
127 )];
128
129 for column in schema {
131 if column.indexed {
132 create_indexes.push(format!(
133 "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
134 table_name, column.name, table_name, column.name
135 ));
136 }
137 }
138
139 conn.call(move |conn| {
140 conn.execute_batch("BEGIN")?;
141
142 conn.execute_batch(&create_table)?;
144
145 for index_stmt in create_indexes {
147 conn.execute_batch(&index_stmt)?;
148 }
149
150 conn.execute_batch(&format!(
152 "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
153 ))?;
154
155 conn.execute_batch("COMMIT")?;
156 Ok(())
157 })
158 .await
159 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
160
161 Ok(Self {
162 conn,
163 _phantom: PhantomData,
164 })
165 }
166
167 pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
168 SqliteVectorIndex::new(model, self)
169 }
170
171 pub fn add_rows_with_txn(
172 &self,
173 txn: &rusqlite::Transaction<'_>,
174 documents: Vec<(T, OneOrMany<Embedding>)>,
175 ) -> Result<i64, tokio_rusqlite::Error> {
176 info!("Adding {} documents to store", documents.len());
177 let table_name = T::name();
178 let mut last_id = 0;
179
180 for (doc, embeddings) in &documents {
181 debug!("Storing document with id {}", doc.id());
182
183 let values = doc.column_values();
184 let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
185
186 let placeholders = (1..=values.len())
187 .map(|i| format!("?{i}"))
188 .collect::<Vec<_>>();
189
190 let insert_sql = format!(
191 "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
192 table_name,
193 columns.join(", "),
194 placeholders.join(", ")
195 );
196
197 txn.execute(
198 &insert_sql,
199 rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
200 )?;
201 last_id = txn.last_insert_rowid();
202
203 let embeddings_sql =
204 format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
205
206 let mut stmt = txn.prepare(&embeddings_sql)?;
207 for (i, embedding) in embeddings.iter().enumerate() {
208 let vec = serialize_embedding(embedding);
209 debug!(
210 "Storing embedding {} of {} (size: {} bytes)",
211 i + 1,
212 embeddings.len(),
213 vec.len() * 4
214 );
215 let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
216 stmt.execute(rusqlite::params![last_id, blob])?;
217 }
218 }
219
220 Ok(last_id)
221 }
222
223 pub async fn add_rows(
224 &self,
225 documents: Vec<(T, OneOrMany<Embedding>)>,
226 ) -> Result<i64, VectorStoreError> {
227 let documents = documents.clone();
228 let this = self.clone();
229
230 self.conn
231 .call(move |conn| {
232 let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?;
233 let result = this.add_rows_with_txn(&tx, documents)?;
234 tx.commit().map_err(tokio_rusqlite::Error::from)?;
235 Ok(result)
236 })
237 .await
238 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
239 }
240}
241
242pub struct SqliteVectorIndex<E, T>
325where
326 E: EmbeddingModel + 'static,
327 T: SqliteVectorStoreTable + 'static,
328{
329 store: SqliteVectorStore<E, T>,
330 embedding_model: E,
331}
332
333impl<E, T> SqliteVectorIndex<E, T>
334where
335 E: EmbeddingModel + 'static,
336 T: SqliteVectorStoreTable,
337{
338 pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
339 Self {
340 store,
341 embedding_model,
342 }
343 }
344}
345
346impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
347 for SqliteVectorIndex<E, T>
348{
349 async fn top_n<D: for<'a> Deserialize<'a>>(
350 &self,
351 req: VectorSearchRequest,
352 ) -> Result<Vec<(f64, String, D)>, VectorStoreError> {
353 tracing::debug!("Finding top {} matches for query", req.samples() as usize);
354 let embedding = self.embedding_model.embed_text(req.query()).await?;
355 let query_vec: Vec<f32> = serialize_embedding(&embedding);
356 let table_name = T::name();
357
358 let columns = T::schema();
360 let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
361
362 let rows = self
363 .store
364 .conn
365 .call(move |conn| {
366 let select_cols = column_names.join(", ");
368 let mut stmt = conn.prepare(&format!(
369 "SELECT d.{select_cols}, e.distance
370 FROM {table_name}_embeddings e
371 JOIN {table_name} d ON e.rowid = d.rowid
372 WHERE e.embedding MATCH ?1 AND k = ?2 AND e.distance >= ?3
373 ORDER BY e.distance"
374 ))?;
375
376 let rows = stmt
377 .query_map(
378 rusqlite::params![
379 query_vec.as_bytes().to_vec(),
380 req.samples() as usize,
381 req.threshold().unwrap_or(0.)
382 ],
383 |row| {
384 let mut map = serde_json::Map::new();
386 for (i, col_name) in column_names.iter().enumerate() {
387 let value: String = row.get(i)?;
388 map.insert(col_name.to_string(), serde_json::Value::String(value));
389 }
390 let distance: f64 = row.get(column_names.len())?;
391 let id: String = row.get(0)?; Ok((id, serde_json::Value::Object(map), distance))
394 },
395 )?
396 .collect::<Result<Vec<_>, _>>()?;
397 Ok(rows)
398 })
399 .await
400 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
401
402 debug!("Found {} potential matches", rows.len());
403 let mut top_n = Vec::new();
404 for (id, doc_value, distance) in rows {
405 match serde_json::from_value::<D>(doc_value) {
406 Ok(doc) => {
407 top_n.push((distance, id, doc));
408 }
409 Err(e) => {
410 debug!("Failed to deserialize document {}: {}", id, e);
411 continue;
412 }
413 }
414 }
415
416 debug!("Returning {} matches", top_n.len());
417 Ok(top_n)
418 }
419
420 async fn top_n_ids(
421 &self,
422 req: VectorSearchRequest,
423 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
424 tracing::debug!(
425 "Finding top {} document IDs for query",
426 req.samples() as usize
427 );
428 let embedding = self.embedding_model.embed_text(req.query()).await?;
429 let query_vec = serialize_embedding(&embedding);
430 let table_name = T::name();
431
432 let results = self
433 .store
434 .conn
435 .call(move |conn| {
436 let mut stmt = conn.prepare(&format!(
437 "SELECT d.id, e.distance
438 FROM {table_name}_embeddings e
439 JOIN {table_name} d ON e.rowid = d.rowid
440 WHERE e.embedding MATCH ?1 AND k = ?2 AND e.distance >= ?3
441 ORDER BY e.distance"
442 ))?;
443
444 let results = stmt
445 .query_map(
446 rusqlite::params![
447 query_vec
448 .iter()
449 .flat_map(|x| x.to_le_bytes())
450 .collect::<Vec<u8>>(),
451 req.samples() as usize,
452 req.threshold().unwrap_or(0.)
453 ],
454 |row| Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?)),
455 )?
456 .collect::<Result<Vec<_>, _>>()?;
457 Ok(results)
458 })
459 .await
460 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
461
462 debug!("Found {} matching document IDs", results.len());
463 Ok(results)
464 }
465}
466
467fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
468 embedding.vec.iter().map(|x| *x as f32).collect()
469}
470
471impl ColumnValue for String {
472 fn to_sql_string(&self) -> String {
473 self.clone()
474 }
475
476 fn column_type(&self) -> &'static str {
477 "TEXT"
478 }
479}