rig_sqlite/
lib.rs

1use rig::OneOrMany;
2use rig::embeddings::{Embedding, EmbeddingModel};
3use rig::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
4use rig::vector_store::{VectorStoreError, VectorStoreIndex};
5use rusqlite::types::Value;
6use serde::Deserialize;
7use std::marker::PhantomData;
8use tokio_rusqlite::Connection;
9use tracing::{debug, info};
10use zerocopy::IntoBytes;
11
/// Errors specific to the SQLite backend.
///
/// NOTE(review): no code in this file constructs these variants; the store and index report
/// failures through `VectorStoreError` instead. Confirm whether external callers rely on it.
#[derive(Debug)]
pub enum SqliteError {
    /// A failure reported by the database layer, boxed as a thread-safe error.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// A failure while (de)serializing values exchanged with the database.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// A column held a value of an unexpected SQL type (payload presumably names the type).
    InvalidColumnType(String),
}
18
/// A value that can be written into a document-table column.
///
/// Values are bound to `INSERT` statements as the text returned by [`to_sql_string`]; SQLite's
/// column affinity then decides how the value is actually stored.
///
/// [`to_sql_string`]: ColumnValue::to_sql_string
pub trait ColumnValue: Send + Sync {
    /// Text representation bound as the statement parameter for this column.
    fn to_sql_string(&self) -> String;

    /// SQLite type name describing this value (e.g. `"TEXT"`).
    ///
    /// NOTE(review): not read anywhere in this file; confirm whether callers rely on it.
    fn column_type(&self) -> &'static str;
}
23
/// Definition of one column of a document table, as returned by
/// [`SqliteVectorStoreTable::schema`].
///
/// Both `name` and `col_type` are interpolated verbatim (unquoted, unescaped) into the
/// `CREATE TABLE` / `CREATE INDEX` statements issued by `SqliteVectorStore::new`, so they must be
/// trusted, static SQL fragments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name, used as an unquoted SQL identifier.
    name: &'static str,
    /// Declaration text that follows the name, e.g. `"TEXT"` or `"TEXT PRIMARY KEY"`.
    col_type: &'static str,
    /// Whether a secondary index `idx_{table}_{name}` should be created for this column.
    indexed: bool,
}

impl Column {
    /// Creates a non-indexed column named `name` declared as `col_type`.
    #[must_use]
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Self {
            name,
            col_type,
            indexed: false,
        }
    }

    /// Marks the column as indexed (builder style), so the store creates an index for it.
    #[must_use]
    pub fn indexed(mut self) -> Self {
        self.indexed = true;
        self
    }
}
44
45/// Example of a document type that can be used with SqliteVectorStore
46/// ```rust
47/// use rig::Embed;
48/// use serde::Deserialize;
49/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStoreTable};
50///
51/// #[derive(Embed, Clone, Debug, Deserialize)]
52/// struct Document {
53///     id: String,
54///     #[embed]
55///     content: String,
56/// }
57///
58/// impl SqliteVectorStoreTable for Document {
59///     fn name() -> &'static str {
60///         "documents"
61///     }
62///
63///     fn schema() -> Vec<Column> {
64///         vec![
65///             Column::new("id", "TEXT PRIMARY KEY"),
66///             Column::new("content", "TEXT"),
67///         ]
68///     }
69///
70///     fn id(&self) -> String {
71///         self.id.clone()
72///     }
73///
74///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
75///         vec![
76///             ("id", Box::new(self.id.clone())),
77///             ("content", Box::new(self.content.clone())),
78///         ]
79///     }
80/// }
81/// ```
82pub trait SqliteVectorStoreTable: Send + Sync + Clone {
83    fn name() -> &'static str;
84    fn schema() -> Vec<Column>;
85    fn id(&self) -> String;
86    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
87}
88
89#[derive(Clone)]
90pub struct SqliteVectorStore<E, T>
91where
92    E: EmbeddingModel + 'static,
93    T: SqliteVectorStoreTable + 'static,
94{
95    conn: Connection,
96    _phantom: PhantomData<(E, T)>,
97}
98
99impl<E, T> SqliteVectorStore<E, T>
100where
101    E: EmbeddingModel + 'static,
102    T: SqliteVectorStoreTable + 'static,
103{
104    pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
105        let dims = embedding_model.ndims();
106        let table_name = T::name();
107        let schema = T::schema();
108
109        // Build the table schema
110        let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
111
112        // Add columns
113        let mut first = true;
114        for column in &schema {
115            if !first {
116                create_table.push(',');
117            }
118            create_table.push_str(&format!("\n    {} {}", column.name, column.col_type));
119            first = false;
120        }
121
122        create_table.push_str("\n)");
123
124        // Build index creation statements
125        let mut create_indexes = vec![format!(
126            "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
127            table_name, table_name
128        )];
129
130        // Add indexes for marked columns
131        for column in schema {
132            if column.indexed {
133                create_indexes.push(format!(
134                    "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
135                    table_name, column.name, table_name, column.name
136                ));
137            }
138        }
139
140        conn.call(move |conn| {
141            conn.execute_batch("BEGIN")?;
142
143            // Create document table
144            conn.execute_batch(&create_table)?;
145
146            // Create indexes
147            for index_stmt in create_indexes {
148                conn.execute_batch(&index_stmt)?;
149            }
150
151            // Create embeddings table
152            conn.execute_batch(&format!(
153                "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
154            ))?;
155
156            conn.execute_batch("COMMIT")?;
157            Ok(())
158        })
159        .await
160        .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
161
162        Ok(Self {
163            conn,
164            _phantom: PhantomData,
165        })
166    }
167
168    pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
169        SqliteVectorIndex::new(model, self)
170    }
171
172    pub fn add_rows_with_txn(
173        &self,
174        txn: &rusqlite::Transaction<'_>,
175        documents: Vec<(T, OneOrMany<Embedding>)>,
176    ) -> Result<i64, tokio_rusqlite::Error> {
177        info!("Adding {} documents to store", documents.len());
178        let table_name = T::name();
179        let mut last_id = 0;
180
181        for (doc, embeddings) in &documents {
182            debug!("Storing document with id {}", doc.id());
183
184            let values = doc.column_values();
185            let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
186
187            let placeholders = (1..=values.len())
188                .map(|i| format!("?{i}"))
189                .collect::<Vec<_>>();
190
191            let insert_sql = format!(
192                "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
193                table_name,
194                columns.join(", "),
195                placeholders.join(", ")
196            );
197
198            txn.execute(
199                &insert_sql,
200                rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
201            )?;
202            last_id = txn.last_insert_rowid();
203
204            let embeddings_sql =
205                format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
206
207            let mut stmt = txn.prepare(&embeddings_sql)?;
208            for (i, embedding) in embeddings.iter().enumerate() {
209                let vec = serialize_embedding(embedding);
210                debug!(
211                    "Storing embedding {} of {} (size: {} bytes)",
212                    i + 1,
213                    embeddings.len(),
214                    vec.len() * 4
215                );
216                let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
217                stmt.execute(rusqlite::params![last_id, blob])?;
218            }
219        }
220
221        Ok(last_id)
222    }
223
224    pub async fn add_rows(
225        &self,
226        documents: Vec<(T, OneOrMany<Embedding>)>,
227    ) -> Result<i64, VectorStoreError> {
228        let documents = documents.clone();
229        let this = self.clone();
230
231        self.conn
232            .call(move |conn| {
233                let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?;
234                let result = this.add_rows_with_txn(&tx, documents)?;
235                tx.commit().map_err(tokio_rusqlite::Error::from)?;
236                Ok(result)
237            })
238            .await
239            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
240    }
241}
242
243#[derive(Clone)]
244pub struct SqliteSearchFilter {
245    condition: String,
246    params: Vec<serde_json::Value>,
247}
248
249impl SearchFilter for SqliteSearchFilter {
250    type Value = serde_json::Value;
251
252    fn eq(key: String, value: Self::Value) -> Self {
253        Self {
254            condition: format!("{key} = ?"),
255            params: vec![value],
256        }
257    }
258
259    fn gt(key: String, value: Self::Value) -> Self {
260        Self {
261            condition: format!("{key} > ?"),
262            params: vec![value],
263        }
264    }
265
266    fn lt(key: String, value: Self::Value) -> Self {
267        Self {
268            condition: format!("{key} < ?"),
269            params: vec![value],
270        }
271    }
272
273    fn and(self, rhs: Self) -> Self {
274        Self {
275            condition: format!("({}) AND ({})", self.condition, rhs.condition),
276            params: self.params.into_iter().chain(rhs.params).collect(),
277        }
278    }
279
280    fn or(self, rhs: Self) -> Self {
281        Self {
282            condition: format!("({}) OR ({})", self.condition, rhs.condition),
283            params: self.params.into_iter().chain(rhs.params).collect(),
284        }
285    }
286}
287
288impl SqliteSearchFilter {
289    #[allow(clippy::should_implement_trait)]
290    pub fn not(self) -> Self {
291        Self {
292            condition: format!("NOT ({})", self.condition),
293            ..self
294        }
295    }
296}
297
298impl SqliteSearchFilter {
299    fn compile_params(self) -> Result<Vec<Value>, FilterError> {
300        let mut params = Vec::with_capacity(self.params.len());
301
302        fn convert(value: serde_json::Value) -> Result<Value, FilterError> {
303            use serde_json::Value::*;
304
305            match value {
306                Null => Ok(Value::Null),
307                Bool(b) => Ok(Value::Integer(b as i64)),
308                String(s) => Ok(Value::Text(s)),
309                Number(n) => Ok(if let Some(float) = n.as_f64() {
310                    Value::Real(float)
311                } else if let Some(int) = n.as_i64() {
312                    Value::Integer(int)
313                } else {
314                    unreachable!()
315                }),
316                Array(arr) => {
317                    let blob = serde_json::to_vec(&arr)
318                        .map_err(|e| FilterError::Serialization(e.to_string()))?;
319
320                    Ok(Value::Blob(blob))
321                }
322                Object(obj) => {
323                    let blob = serde_json::to_vec(&obj)
324                        .map_err(|e| FilterError::Serialization(e.to_string()))?;
325
326                    Ok(Value::Blob(blob))
327                }
328            }
329        }
330
331        for param in self.params.into_iter() {
332            params.push(convert(param)?)
333        }
334
335        Ok(params)
336    }
337}
338
339/// SQLite vector store implementation for Rig.
340///
341/// This crate provides a SQLite-based vector store implementation that can be used with Rig.
342/// It uses the `sqlite-vec` extension to enable vector similarity search capabilities.
343///
344/// # Example
345/// ```rust
346/// use rig::{
347///     embeddings::EmbeddingsBuilder,
348///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
349///     vector_store::VectorStoreIndex,
350///     Embed,
351/// };
352/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
353/// use serde::Deserialize;
354/// use tokio_rusqlite::Connection;
355///
356/// #[derive(Embed, Clone, Debug, Deserialize)]
357/// struct Document {
358///     id: String,
359///     #[embed]
360///     content: String,
361/// }
362///
363/// impl SqliteVectorStoreTable for Document {
364///     fn name() -> &'static str {
365///         "documents"
366///     }
367///
368///     fn schema() -> Vec<Column> {
369///         vec![
370///             Column::new("id", "TEXT PRIMARY KEY"),
371///             Column::new("content", "TEXT"),
372///         ]
373///     }
374///
375///     fn id(&self) -> String {
376///         self.id.clone()
377///     }
378///
379///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
380///         vec![
381///             ("id", Box::new(self.id.clone())),
382///             ("content", Box::new(self.content.clone())),
383///         ]
384///     }
385/// }
386///
387/// let conn = Connection::open("vector_store.db").await?;
388/// let openai_client = Client::new("YOUR_API_KEY");
389/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
390///
391/// // Initialize vector store
392/// let vector_store = SqliteVectorStore::new(conn, &model).await?;
393///
394/// // Create documents
395/// let documents = vec![
396///     Document {
397///         id: "doc1".to_string(),
398///         content: "Example document 1".to_string(),
399///     },
400///     Document {
401///         id: "doc2".to_string(),
402///         content: "Example document 2".to_string(),
403///     },
404/// ];
405///
406/// // Generate embeddings
407/// let embeddings = EmbeddingsBuilder::new(model.clone())
408///     .documents(documents)?
409///     .build()
410///     .await?;
411///
412/// // Add to vector store
413/// vector_store.add_rows(embeddings).await?;
414///
415/// // Create index and search
416/// let index = vector_store.index(model);
417/// let results = index
418///     .top_n::<Document>("Example query", 2)
419///     .await?;
420/// ```
421pub struct SqliteVectorIndex<E, T>
422where
423    E: EmbeddingModel + 'static,
424    T: SqliteVectorStoreTable + 'static,
425{
426    store: SqliteVectorStore<E, T>,
427    embedding_model: E,
428}
429
430impl<E, T> SqliteVectorIndex<E, T>
431where
432    E: EmbeddingModel + 'static,
433    T: SqliteVectorStoreTable,
434{
435    pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
436        Self {
437            store,
438            embedding_model,
439        }
440    }
441}
442
443fn build_where_clause(
444    req: &VectorSearchRequest<SqliteSearchFilter>,
445    query_vec: Vec<f32>,
446) -> Result<(String, Vec<Value>), FilterError> {
447    let thresh = req.threshold().unwrap_or(0.);
448    let thresh = SqliteSearchFilter::gt("e.distance".into(), thresh.into());
449
450    let filter = req
451        .filter()
452        .as_ref()
453        .cloned()
454        .map(|filter| thresh.clone().and(filter))
455        .unwrap_or(thresh);
456
457    let where_clause = format!(
458        "WHERE e.embedding MATCH ? AND k = ? AND {}",
459        filter.condition
460    );
461
462    let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
463    let query_vec = Value::Blob(query_vec);
464    let samples = req.samples() as u32;
465
466    let mut params = vec![query_vec, samples.into()];
467    let filter_params = filter.clone().compile_params()?;
468    params.extend(filter_params);
469
470    Ok((where_clause, params))
471}
472
473impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
474    for SqliteVectorIndex<E, T>
475{
476    type Filter = SqliteSearchFilter;
477
478    async fn top_n<D>(
479        &self,
480        req: VectorSearchRequest<SqliteSearchFilter>,
481    ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
482    where
483        D: for<'de> Deserialize<'de>,
484    {
485        tracing::debug!("Finding top {} matches for query", req.samples() as usize);
486        let embedding = self.embedding_model.embed_text(req.query()).await?;
487        let query_vec: Vec<f32> = serialize_embedding(&embedding);
488        let table_name = T::name();
489
490        // Get all column names from SqliteVectorStoreTable
491        let columns = T::schema();
492        let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
493
494        // Build SELECT statement with all columns
495        let select_cols = column_names.join(", ");
496
497        let (where_clause, params) = build_where_clause(&req, query_vec)?;
498
499        let rows = self
500            .store
501            .conn
502            .call(move |conn| {
503                let mut stmt = conn.prepare(&format!(
504                    "SELECT d.{select_cols}, e.distance
505                    FROM {table_name}_embeddings e
506                    JOIN {table_name} d ON e.rowid = d.rowid
507                    {where_clause}
508                    ORDER BY e.distance"
509                ))?;
510
511                dbg!(&stmt);
512
513                let rows = stmt
514                    .query_map(rusqlite::params_from_iter(params), |row| {
515                        // Create a map of column names to values
516                        let mut map = serde_json::Map::new();
517                        for (i, col_name) in column_names.iter().enumerate() {
518                            let value: String = row.get(i)?;
519                            map.insert(col_name.to_string(), serde_json::Value::String(value));
520                        }
521                        let distance: f64 = row.get(column_names.len())?;
522                        let id: String = row.get(0)?; // Assuming id is always first column
523
524                        Ok((id, serde_json::Value::Object(map), distance))
525                    })?
526                    .collect::<Result<Vec<_>, _>>()?;
527                Ok(rows)
528            })
529            .await
530            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
531
532        debug!("Found {} potential matches", rows.len());
533        let mut top_n = Vec::new();
534        for (id, doc_value, distance) in rows {
535            match serde_json::from_value::<D>(doc_value) {
536                Ok(doc) => {
537                    top_n.push((distance, id, doc));
538                }
539                Err(e) => {
540                    debug!("Failed to deserialize document {}: {}", id, e);
541                    continue;
542                }
543            }
544        }
545
546        debug!("Returning {} matches", top_n.len());
547        Ok(top_n)
548    }
549
550    async fn top_n_ids(
551        &self,
552        req: VectorSearchRequest<SqliteSearchFilter>,
553    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
554        tracing::debug!(
555            "Finding top {} document IDs for query",
556            req.samples() as usize
557        );
558        let embedding = self.embedding_model.embed_text(req.query()).await?;
559        let query_vec = serialize_embedding(&embedding);
560        let table_name = T::name();
561
562        let (where_clause, params) = build_where_clause(&req, query_vec)?;
563
564        let results = self
565            .store
566            .conn
567            .call(move |conn| {
568                let mut stmt = conn.prepare(&format!(
569                    "SELECT d.id, e.distance
570                     FROM {table_name}_embeddings e
571                     JOIN {table_name} d ON e.rowid = d.rowid
572                     {where_clause}
573                     ORDER BY e.distance"
574                ))?;
575
576                dbg!(&stmt);
577
578                let results = stmt
579                    .query_map(rusqlite::params_from_iter(params), |row| {
580                        Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?))
581                    })?
582                    .collect::<Result<Vec<_>, _>>()?;
583                Ok(results)
584            })
585            .await
586            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
587
588        debug!("Found {} matching document IDs", results.len());
589        Ok(results)
590    }
591}
592
593fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
594    embedding.vec.iter().map(|x| *x as f32).collect()
595}
596
597impl ColumnValue for String {
598    fn to_sql_string(&self) -> String {
599        self.clone()
600    }
601
602    fn column_type(&self) -> &'static str {
603        "TEXT"
604    }
605}