1use rig::OneOrMany;
2use rig::embeddings::{Embedding, EmbeddingModel};
3use rig::vector_store::request::VectorSearchRequest;
4use rig::vector_store::{VectorStoreError, VectorStoreIndex};
5use serde::Deserialize;
6use std::marker::PhantomData;
7use tokio_rusqlite::Connection;
8use tracing::{debug, info};
9use zerocopy::IntoBytes;
10
/// Errors raised by the SQLite storage layer.
#[derive(Debug)]
pub enum SqliteError {
    /// A failure reported by the database itself.
    DatabaseError(Box<dyn std::error::Error + Send + Sync + 'static>),
    /// A failure while converting values to or from their stored form.
    SerializationError(Box<dyn std::error::Error + Send + Sync + 'static>),
    /// A column whose type is not supported; carries a description of the offending type.
    InvalidColumnType(String),
}
17
/// A value that can be written into a column of a vector-store table.
///
/// Used as `Box<dyn ColumnValue>`, so the trait must stay object-safe.
pub trait ColumnValue: Send + Sync {
    /// Textual form of the value; this string is what gets bound as the SQL parameter.
    fn to_sql_string(&self) -> String;

    /// SQL type name describing the value (for example `"TEXT"`).
    fn column_type(&self) -> &'static str;
}
22
/// Definition of one column of a vector-store table.
///
/// `name` and `col_type` are interpolated verbatim into `CREATE TABLE` and
/// `CREATE INDEX` statements, so they must be trusted SQL identifiers and types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name as it appears in SQL.
    name: &'static str,
    /// SQL type of the column (for example `"TEXT"`).
    col_type: &'static str,
    /// Whether a dedicated index should be created for this column.
    indexed: bool,
}

impl Column {
    /// Creates an unindexed column named `name` with SQL type `col_type`.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Self {
            name,
            col_type,
            indexed: false,
        }
    }

    /// Returns this column marked as indexed.
    ///
    /// The builder consumes `self`, so discarding the result silently loses the flag;
    /// `#[must_use]` turns that mistake into a warning.
    #[must_use]
    pub fn indexed(mut self) -> Self {
        self.indexed = true;
        self
    }
}
43
44pub trait SqliteVectorStoreTable: Send + Sync + Clone {
82 fn name() -> &'static str;
83 fn schema() -> Vec<Column>;
84 fn id(&self) -> String;
85 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
86}
87
88#[derive(Clone)]
89pub struct SqliteVectorStore<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> {
90 conn: Connection,
91 _phantom: PhantomData<(E, T)>,
92}
93
94impl<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> SqliteVectorStore<E, T> {
95 pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
96 let dims = embedding_model.ndims();
97 let table_name = T::name();
98 let schema = T::schema();
99
100 let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
102
103 let mut first = true;
105 for column in &schema {
106 if !first {
107 create_table.push(',');
108 }
109 create_table.push_str(&format!("\n {} {}", column.name, column.col_type));
110 first = false;
111 }
112
113 create_table.push_str("\n)");
114
115 let mut create_indexes = vec![format!(
117 "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
118 table_name, table_name
119 )];
120
121 for column in schema {
123 if column.indexed {
124 create_indexes.push(format!(
125 "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
126 table_name, column.name, table_name, column.name
127 ));
128 }
129 }
130
131 conn.call(move |conn| {
132 conn.execute_batch("BEGIN")?;
133
134 conn.execute_batch(&create_table)?;
136
137 for index_stmt in create_indexes {
139 conn.execute_batch(&index_stmt)?;
140 }
141
142 conn.execute_batch(&format!(
144 "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
145 ))?;
146
147 conn.execute_batch("COMMIT")?;
148 Ok(())
149 })
150 .await
151 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
152
153 Ok(Self {
154 conn,
155 _phantom: PhantomData,
156 })
157 }
158
159 pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
160 SqliteVectorIndex::new(model, self)
161 }
162
163 pub fn add_rows_with_txn(
164 &self,
165 txn: &rusqlite::Transaction<'_>,
166 documents: Vec<(T, OneOrMany<Embedding>)>,
167 ) -> Result<i64, tokio_rusqlite::Error> {
168 info!("Adding {} documents to store", documents.len());
169 let table_name = T::name();
170 let mut last_id = 0;
171
172 for (doc, embeddings) in &documents {
173 debug!("Storing document with id {}", doc.id());
174
175 let values = doc.column_values();
176 let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
177
178 let placeholders = (1..=values.len())
179 .map(|i| format!("?{i}"))
180 .collect::<Vec<_>>();
181
182 let insert_sql = format!(
183 "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
184 table_name,
185 columns.join(", "),
186 placeholders.join(", ")
187 );
188
189 txn.execute(
190 &insert_sql,
191 rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
192 )?;
193 last_id = txn.last_insert_rowid();
194
195 let embeddings_sql =
196 format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
197
198 let mut stmt = txn.prepare(&embeddings_sql)?;
199 for (i, embedding) in embeddings.iter().enumerate() {
200 let vec = serialize_embedding(embedding);
201 debug!(
202 "Storing embedding {} of {} (size: {} bytes)",
203 i + 1,
204 embeddings.len(),
205 vec.len() * 4
206 );
207 let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
208 stmt.execute(rusqlite::params![last_id, blob])?;
209 }
210 }
211
212 Ok(last_id)
213 }
214
215 pub async fn add_rows(
216 &self,
217 documents: Vec<(T, OneOrMany<Embedding>)>,
218 ) -> Result<i64, VectorStoreError> {
219 let documents = documents.clone();
220 let this = self.clone();
221
222 self.conn
223 .call(move |conn| {
224 let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?;
225 let result = this.add_rows_with_txn(&tx, documents)?;
226 tx.commit().map_err(tokio_rusqlite::Error::from)?;
227 Ok(result)
228 })
229 .await
230 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
231 }
232}
233
234pub struct SqliteVectorIndex<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> {
317 store: SqliteVectorStore<E, T>,
318 embedding_model: E,
319}
320
321impl<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable> SqliteVectorIndex<E, T> {
322 pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
323 Self {
324 store,
325 embedding_model,
326 }
327 }
328}
329
330impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
331 for SqliteVectorIndex<E, T>
332{
333 async fn top_n<D: for<'a> Deserialize<'a>>(
334 &self,
335 req: VectorSearchRequest,
336 ) -> Result<Vec<(f64, String, D)>, VectorStoreError> {
337 tracing::debug!("Finding top {} matches for query", req.samples() as usize);
338 let embedding = self.embedding_model.embed_text(req.query()).await?;
339 let query_vec: Vec<f32> = serialize_embedding(&embedding);
340 let table_name = T::name();
341
342 let columns = T::schema();
344 let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
345
346 let rows = self
347 .store
348 .conn
349 .call(move |conn| {
350 let select_cols = column_names.join(", ");
352 let mut stmt = conn.prepare(&format!(
353 "SELECT d.{select_cols}, e.distance
354 FROM {table_name}_embeddings e
355 JOIN {table_name} d ON e.rowid = d.rowid
356 WHERE e.embedding MATCH ?1 AND k = ?2 AND e.distance >= ?3
357 ORDER BY e.distance"
358 ))?;
359
360 let rows = stmt
361 .query_map(
362 rusqlite::params![
363 query_vec.as_bytes().to_vec(),
364 req.samples() as usize,
365 req.threshold().unwrap_or(0.)
366 ],
367 |row| {
368 let mut map = serde_json::Map::new();
370 for (i, col_name) in column_names.iter().enumerate() {
371 let value: String = row.get(i)?;
372 map.insert(col_name.to_string(), serde_json::Value::String(value));
373 }
374 let distance: f64 = row.get(column_names.len())?;
375 let id: String = row.get(0)?; Ok((id, serde_json::Value::Object(map), distance))
378 },
379 )?
380 .collect::<Result<Vec<_>, _>>()?;
381 Ok(rows)
382 })
383 .await
384 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
385
386 debug!("Found {} potential matches", rows.len());
387 let mut top_n = Vec::new();
388 for (id, doc_value, distance) in rows {
389 match serde_json::from_value::<D>(doc_value) {
390 Ok(doc) => {
391 top_n.push((distance, id, doc));
392 }
393 Err(e) => {
394 debug!("Failed to deserialize document {}: {}", id, e);
395 continue;
396 }
397 }
398 }
399
400 debug!("Returning {} matches", top_n.len());
401 Ok(top_n)
402 }
403
404 async fn top_n_ids(
405 &self,
406 req: VectorSearchRequest,
407 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
408 tracing::debug!(
409 "Finding top {} document IDs for query",
410 req.samples() as usize
411 );
412 let embedding = self.embedding_model.embed_text(req.query()).await?;
413 let query_vec = serialize_embedding(&embedding);
414 let table_name = T::name();
415
416 let results = self
417 .store
418 .conn
419 .call(move |conn| {
420 let mut stmt = conn.prepare(&format!(
421 "SELECT d.id, e.distance
422 FROM {table_name}_embeddings e
423 JOIN {table_name} d ON e.rowid = d.rowid
424 WHERE e.embedding MATCH ?1 AND k = ?2 AND e.distance >= ?3
425 ORDER BY e.distance"
426 ))?;
427
428 let results = stmt
429 .query_map(
430 rusqlite::params![
431 query_vec
432 .iter()
433 .flat_map(|x| x.to_le_bytes())
434 .collect::<Vec<u8>>(),
435 req.samples() as usize,
436 req.threshold().unwrap_or(0.)
437 ],
438 |row| Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?)),
439 )?
440 .collect::<Result<Vec<_>, _>>()?;
441 Ok(results)
442 })
443 .await
444 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
445
446 debug!("Found {} matching document IDs", results.len());
447 Ok(results)
448 }
449}
450
451fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
452 embedding.vec.iter().map(|x| *x as f32).collect()
453}
454
455impl ColumnValue for String {
456 fn to_sql_string(&self) -> String {
457 self.clone()
458 }
459
460 fn column_type(&self) -> &'static str {
461 "TEXT"
462 }
463}