Skip to main content

rig_sqlite/
lib.rs

1//! SQLite vector store integration for Rig.
2//!
3//! This crate provides [`SqliteVectorStore`] and [`SqliteVectorIndex`] for
4//! storing embedded documents in SQLite with the `sqlite-vec` extension. Define
5//! document table schemas by implementing [`SqliteVectorStoreTable`].
6//!
7//! The root `rig` facade re-exports this crate as `rig::sqlite` when the
8//! `sqlite` feature is enabled.
9
10use rig_core::embeddings::{Embedding, EmbeddingModel};
11use rig_core::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
12use rig_core::vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex};
13use rig_core::wasm_compat::{WasmCompatSend, WasmCompatSync};
14use rig_core::{Embed, OneOrMany};
15use rusqlite::types::Value;
16use serde::{Deserialize, Serialize};
17use std::marker::PhantomData;
18use std::ops::RangeInclusive;
19use tokio_rusqlite::Connection;
20use tracing::{debug, info};
21use zerocopy::IntoBytes;
22
/// Errors specific to the SQLite vector store.
///
/// The type implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// propagated with `?` into `Box<dyn Error>`-style error types.
#[derive(Debug)]
pub enum SqliteError {
    /// A database-level failure, wrapping the underlying error.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// A serialization failure, wrapping the underlying error.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// A column type that is not accepted; the payload describes the offending type.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DatabaseError(err) => write!(f, "database error: {err}"),
            Self::SerializationError(err) => write!(f, "serialization error: {err}"),
            Self::InvalidColumnType(ty) => write!(f, "invalid column type: {ty}"),
        }
    }
}

impl std::error::Error for SqliteError {
    /// Exposes the wrapped error (if any) so error chains can be walked.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::DatabaseError(err) | Self::SerializationError(err) => {
                let source: &(dyn std::error::Error + 'static) = &**err;
                Some(source)
            }
            Self::InvalidColumnType(_) => None,
        }
    }
}
29
/// A value that can be written into a column of a document table.
pub trait ColumnValue: Send + Sync {
    /// Returns the textual form of the value; this string is what gets bound as the
    /// SQL parameter when a row is inserted.
    fn to_sql_string(&self) -> String;
    /// Returns the SQL type name associated with the value (`"TEXT"` for `String`).
    fn column_type(&self) -> &'static str;
}
34
/// Description of one column of a document table.
///
/// A column is created with [`Column::new`] and can optionally be flagged for
/// indexing with [`Column::indexed`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    // Column name as it appears in generated SQL.
    name: &'static str,
    // SQL type (and constraints) placed after the name in `CREATE TABLE`.
    col_type: &'static str,
    // Whether a dedicated index is created for this column.
    indexed: bool,
}

impl Column {
    /// Creates a non-indexed column with the given name and SQL type.
    ///
    /// `col_type` is emitted verbatim after the column name, so it may carry
    /// constraints, e.g. `"TEXT PRIMARY KEY"`.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Self {
            name,
            col_type,
            indexed: false,
        }
    }

    /// Marks the column as indexed and returns it, builder-style.
    #[must_use]
    pub fn indexed(mut self) -> Self {
        self.indexed = true;
        self
    }
}
55
/// Describes how a document type maps onto a SQLite table used by `SqliteVectorStore`.
///
/// The store creates an index on a column named `id` and the search code reads the
/// document id from the first selected column, so the schema is expected to list an
/// `id` column first.
///
/// # Example
/// ```rust
/// use rig_core::Embed;
/// use serde::{Deserialize, Serialize};
/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStoreTable};
///
/// #[derive(Embed, Clone, Debug, Deserialize, Serialize)]
/// struct Document {
///     id: String,
///     #[embed]
///     content: String,
/// }
///
/// impl SqliteVectorStoreTable for Document {
///     fn name() -> &'static str {
///         "documents"
///     }
///
///     fn schema() -> Vec<Column> {
///         vec![
///             Column::new("id", "TEXT PRIMARY KEY"),
///             Column::new("content", "TEXT"),
///         ]
///     }
///
///     fn id(&self) -> String {
///         self.id.clone()
///     }
///
///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
///         vec![
///             ("id", Box::new(self.id.clone())),
///             ("content", Box::new(self.content.clone())),
///         ]
///     }
/// }
/// ```
pub trait SqliteVectorStoreTable: Send + Sync + Clone {
    /// Name of the document table; the embeddings table is named `{name}_embeddings`.
    /// The name is spliced into SQL text as-is.
    fn name() -> &'static str;
    /// Columns of the document table, in the order used for `CREATE TABLE` and `SELECT`.
    fn schema() -> Vec<Column>;
    /// Identifier of this document.
    fn id(&self) -> String;
    /// `(column name, value)` pairs written when the document is inserted.
    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
}
99
/// SQLite-backed store for documents of type `T` and their embeddings.
///
/// `E` is the embedding model type the store is associated with; neither `E` nor `T`
/// is stored, they only parameterize the generated SQL and the public API.
#[derive(Clone)]
pub struct SqliteVectorStore<E, T>
where
    E: EmbeddingModel + 'static,
    T: SqliteVectorStoreTable + 'static,
{
    // Async connection through which every statement of the store is issued.
    conn: Connection,
    // Ties the store to its model and table types without holding values of them.
    _phantom: PhantomData<(E, T)>,
}
109
110impl<E, T> SqliteVectorStore<E, T>
111where
112    E: EmbeddingModel + Clone + 'static,
113    T: SqliteVectorStoreTable + 'static,
114{
115    pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
116        let dims = embedding_model.ndims();
117        let table_name = T::name();
118        let schema = T::schema();
119
120        // Build the table schema
121        let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
122
123        // Add columns
124        let mut first = true;
125        for column in &schema {
126            if !first {
127                create_table.push(',');
128            }
129            create_table.push_str(&format!("\n    {} {}", column.name, column.col_type));
130            first = false;
131        }
132
133        create_table.push_str("\n)");
134
135        // Build index creation statements
136        let mut create_indexes = vec![format!(
137            "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
138            table_name, table_name
139        )];
140
141        // Add indexes for marked columns
142        for column in schema {
143            if column.indexed {
144                create_indexes.push(format!(
145                    "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
146                    table_name, column.name, table_name, column.name
147                ));
148            }
149        }
150
151        conn.call(move |conn| {
152            conn.execute_batch("BEGIN")?;
153
154            // Create document table
155            conn.execute_batch(&create_table)?;
156
157            // Create indexes
158            for index_stmt in create_indexes {
159                conn.execute_batch(&index_stmt)?;
160            }
161
162            // Create embeddings table
163            conn.execute_batch(&format!(
164                "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
165            ))?;
166
167            conn.execute_batch("COMMIT")?;
168            Ok(())
169        })
170        .await
171        .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
172
173        Ok(Self {
174            conn,
175            _phantom: PhantomData,
176        })
177    }
178
179    pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
180        SqliteVectorIndex::new(model, self)
181    }
182
183    pub fn add_rows_with_txn(
184        &self,
185        txn: &rusqlite::Transaction<'_>,
186        documents: Vec<(T, OneOrMany<Embedding>)>,
187    ) -> Result<i64, tokio_rusqlite::Error> {
188        info!("Adding {} documents to store", documents.len());
189        let table_name = T::name();
190        let mut last_id = 0;
191
192        for (doc, embeddings) in &documents {
193            debug!("Storing document with id {}", doc.id());
194
195            let values = doc.column_values();
196            let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
197
198            let placeholders = (1..=values.len())
199                .map(|i| format!("?{i}"))
200                .collect::<Vec<_>>();
201
202            let insert_sql = format!(
203                "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
204                table_name,
205                columns.join(", "),
206                placeholders.join(", ")
207            );
208
209            txn.execute(
210                &insert_sql,
211                rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
212            )?;
213            last_id = txn.last_insert_rowid();
214
215            let embeddings_sql =
216                format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
217
218            let mut stmt = txn.prepare(&embeddings_sql)?;
219            for (i, embedding) in embeddings.iter().enumerate() {
220                let vec = serialize_embedding(embedding);
221                debug!(
222                    "Storing embedding {} of {} (size: {} bytes)",
223                    i + 1,
224                    embeddings.len(),
225                    vec.len() * 4
226                );
227                let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
228                stmt.execute(rusqlite::params![last_id, blob])?;
229            }
230        }
231
232        Ok(last_id)
233    }
234
235    pub async fn add_rows(
236        &self,
237        documents: Vec<(T, OneOrMany<Embedding>)>,
238    ) -> Result<i64, VectorStoreError>
239    where
240        T: 'static,
241        Self: 'static,
242    {
243        let cloned = self.clone();
244
245        self.conn
246            .call(move |conn| {
247                let tx = conn.transaction()?;
248                let result = cloned.add_rows_with_txn(&tx, documents)?;
249                tx.commit()?;
250
251                Ok(result)
252            })
253            .await
254            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
255    }
256}
257
258impl<E, T> InsertDocuments for SqliteVectorStore<E, T>
259where
260    E: EmbeddingModel + Clone + WasmCompatSend + WasmCompatSync + 'static,
261    T: SqliteVectorStoreTable
262        + for<'de> Deserialize<'de>
263        + WasmCompatSend
264        + WasmCompatSync
265        + 'static,
266{
267    async fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
268        &self,
269        documents: Vec<(Doc, OneOrMany<Embedding>)>,
270    ) -> Result<(), VectorStoreError> {
271        if documents.is_empty() {
272            return Ok(());
273        }
274
275        let rows = documents
276            .into_iter()
277            .map(|(document, embeddings)| {
278                let document = serde_json::to_value(document)?;
279                let row = serde_json::from_value::<T>(document)?;
280
281                Ok((row, embeddings))
282            })
283            .collect::<Result<Vec<_>, VectorStoreError>>()?;
284
285        self.add_rows(rows).await?;
286
287        Ok(())
288    }
289}
290
/// A SQL condition fragment together with the values bound to its `?` placeholders.
#[derive(Clone, Default, Deserialize, Serialize, Debug)]
pub struct SqliteSearchFilter {
    // SQL boolean expression; it is appended with `AND` to the search `WHERE` clause.
    condition: String,
    // Values for the `?` placeholders of `condition`, in order of appearance.
    params: Vec<serde_json::Value>,
}
296
297impl SearchFilter for SqliteSearchFilter {
298    type Value = serde_json::Value;
299
300    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
301        Self {
302            condition: format!("{} = ?", key.as_ref()),
303            params: vec![value],
304        }
305    }
306
307    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
308        Self {
309            condition: format!("{} > ?", key.as_ref()),
310            params: vec![value],
311        }
312    }
313
314    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
315        Self {
316            condition: format!("{} < ?", key.as_ref()),
317            params: vec![value],
318        }
319    }
320
321    fn and(self, rhs: Self) -> Self {
322        Self {
323            condition: format!("({}) AND ({})", self.condition, rhs.condition),
324            params: self.params.into_iter().chain(rhs.params).collect(),
325        }
326    }
327
328    fn or(self, rhs: Self) -> Self {
329        Self {
330            condition: format!("({}) OR ({})", self.condition, rhs.condition),
331            params: self.params.into_iter().chain(rhs.params).collect(),
332        }
333    }
334}
335
336impl SqliteSearchFilter {
337    #[allow(clippy::should_implement_trait)]
338    pub fn not(self) -> Self {
339        Self {
340            condition: format!("NOT ({})", self.condition),
341            ..self
342        }
343    }
344
345    /// Tests whether the value at `key` is contained in the range
346    pub fn between<N>(key: String, range: RangeInclusive<N>) -> Self
347    where
348        N: Ord + rusqlite::ToSql + std::fmt::Display,
349    {
350        let lo = range.start();
351        let hi = range.end();
352
353        Self {
354            condition: format!("{key} between {lo} and {hi}"),
355            ..Default::default()
356        }
357    }
358
359    // Null checks
360    pub fn is_null(key: String) -> Self {
361        Self {
362            condition: format!("{key} is null"),
363            ..Default::default()
364        }
365    }
366
367    pub fn is_not_null(key: String) -> Self {
368        Self {
369            condition: format!("{key} is not null"),
370            ..Default::default()
371        }
372    }
373
374    // String ops
375    /// Tests whether the value at `key` satisfies the glob pattern
376    /// `pattern` should be a valid SQLite glob pattern
377    pub fn glob<'a, S>(key: String, pattern: S) -> Self
378    where
379        S: AsRef<&'a str>,
380    {
381        Self {
382            condition: format!("{key} glob {}", pattern.as_ref()),
383            ..Default::default()
384        }
385    }
386
387    /// Tests whether the value at `key` satisfies the "like" pattern
388    /// `pattern` should be a valid SQLite like pattern
389    pub fn like<'a, S>(key: String, pattern: S) -> Self
390    where
391        S: AsRef<&'a str>,
392    {
393        Self {
394            condition: format!("{key} like {}", pattern.as_ref()),
395            ..Default::default()
396        }
397    }
398}
399
400impl SqliteSearchFilter {
401    fn compile_params(self) -> Result<Vec<Value>, FilterError> {
402        let mut params = Vec::with_capacity(self.params.len());
403
404        fn convert(value: serde_json::Value) -> Result<Value, FilterError> {
405            use serde_json::Value::*;
406
407            match value {
408                Null => Ok(Value::Null),
409                Bool(b) => Ok(Value::Integer(b as i64)),
410                String(s) => Ok(Value::Text(s)),
411                Number(n) => Ok(if let Some(float) = n.as_f64() {
412                    Value::Real(float)
413                } else if let Some(int) = n.as_i64() {
414                    Value::Integer(int)
415                } else if let Some(int) = n.as_u64() {
416                    Value::Integer(int as i64)
417                } else {
418                    Value::Text(n.to_string())
419                }),
420                Array(arr) => {
421                    let blob = serde_json::to_vec(&arr)
422                        .map_err(|e| FilterError::Serialization(e.to_string()))?;
423
424                    Ok(Value::Blob(blob))
425                }
426                Object(obj) => {
427                    let blob = serde_json::to_vec(&obj)
428                        .map_err(|e| FilterError::Serialization(e.to_string()))?;
429
430                    Ok(Value::Blob(blob))
431                }
432            }
433        }
434
435        for param in self.params.into_iter() {
436            params.push(convert(param)?)
437        }
438
439        Ok(params)
440    }
441}
442
/// Searchable index over a [`SqliteVectorStore`].
///
/// The index pairs a store with an embedding model: queries are embedded with the
/// model and compared against the stored vectors using the `sqlite-vec` extension.
///
/// # Example
/// ```no_run
/// use rig_core::{
///     client::EmbeddingsClient,
///     embeddings::EmbeddingsBuilder,
///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
///     vector_store::{InsertDocuments, VectorStoreIndex},
///     Embed,
/// };
/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
/// use rig_core::vector_store::request::VectorSearchRequest;
/// use serde::{Deserialize, Serialize};
/// use tokio_rusqlite::Connection;
///
/// # async fn example() -> anyhow::Result<()> {
/// #[derive(Embed, Clone, Debug, Deserialize, Serialize)]
/// struct Document {
///     id: String,
///     #[embed]
///     content: String,
/// }
///
/// impl SqliteVectorStoreTable for Document {
///     fn name() -> &'static str {
///         "documents"
///     }
///
///     fn schema() -> Vec<Column> {
///         vec![
///             Column::new("id", "TEXT PRIMARY KEY"),
///             Column::new("content", "TEXT"),
///         ]
///     }
///
///     fn id(&self) -> String {
///         self.id.clone()
///     }
///
///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
///         vec![
///             ("id", Box::new(self.id.clone())),
///             ("content", Box::new(self.content.clone())),
///         ]
///     }
/// }
///
/// let conn = Connection::open("vector_store.db").await?;
/// let openai_client = Client::new("YOUR_API_KEY")?;
/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
///
/// // Initialize vector store
/// let vector_store: SqliteVectorStore<_, Document> = SqliteVectorStore::new(conn, &model).await?;
///
/// // Create documents
/// let documents = vec![
///     Document {
///         id: "doc1".to_string(),
///         content: "Example document 1".to_string(),
///     },
///     Document {
///         id: "doc2".to_string(),
///         content: "Example document 2".to_string(),
///     },
/// ];
///
/// // Generate embeddings
/// let embeddings = EmbeddingsBuilder::new(model.clone())
///     .documents(documents)?
///     .build()
///     .await?;
///
/// // Add to vector store
/// vector_store.insert_documents(embeddings).await?;
///
/// // Create index and search
/// let index = vector_store.index(model);
/// let req = VectorSearchRequest::builder()
///     .query("Example query")
///     .samples(2)
///     .build();
/// let results = index.top_n::<Document>(req).await?;
/// # let _ = results;
/// # Ok(())
/// # }
/// # let _ = example();
/// ```
pub struct SqliteVectorIndex<E, T>
where
    E: EmbeddingModel + 'static,
    T: SqliteVectorStoreTable + 'static,
{
    // Store whose connection and tables are queried.
    store: SqliteVectorStore<E, T>,
    // Model used to embed the query text of each search request.
    embedding_model: E,
}
542
543impl<E, T> SqliteVectorIndex<E, T>
544where
545    E: EmbeddingModel + 'static,
546    T: SqliteVectorStoreTable,
547{
548    pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
549        Self {
550            store,
551            embedding_model,
552        }
553    }
554}
555
556fn build_where_clause(
557    req: &VectorSearchRequest<SqliteSearchFilter>,
558    query_vec: Vec<f32>,
559) -> Result<(String, Vec<Value>), FilterError> {
560    let thresh = req.threshold().unwrap_or(0.);
561    let thresh = SqliteSearchFilter::gt("distance", thresh.into());
562
563    let filter = req
564        .filter()
565        .as_ref()
566        .cloned()
567        .map(|filter| thresh.clone().and(filter))
568        .unwrap_or(thresh);
569
570    let where_clause = format!(
571        "WHERE e.embedding MATCH ? AND k = ? AND {}",
572        filter.condition
573    );
574
575    let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
576    let query_vec = Value::Blob(query_vec);
577    let samples = req.samples() as u32;
578
579    let mut params = vec![query_vec.clone(), query_vec, samples.into()];
580    let filter_params = filter.clone().compile_params()?;
581    params.extend(filter_params);
582
583    Ok((where_clause, params))
584}
585
586impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
587    for SqliteVectorIndex<E, T>
588{
589    type Filter = SqliteSearchFilter;
590
591    async fn top_n<D>(
592        &self,
593        req: VectorSearchRequest<SqliteSearchFilter>,
594    ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
595    where
596        D: for<'de> Deserialize<'de>,
597    {
598        tracing::debug!("Finding top {} matches for query", req.samples() as usize);
599        let embedding = self.embedding_model.embed_text(req.query()).await?;
600        let query_vec: Vec<f32> = serialize_embedding(&embedding);
601        let table_name = T::name();
602
603        // Get all column names from SqliteVectorStoreTable
604        let columns = T::schema();
605        let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
606
607        // Build SELECT statement with all columns
608        let select_cols = column_names.join(", ");
609
610        let (where_clause, params) = build_where_clause(&req, query_vec)?;
611
612        let rows = self
613            .store
614            .conn
615            .call(move |conn| {
616                let mut stmt = conn.prepare(&format!(
617                    "SELECT d.{select_cols}, (1 - vec_distance_cosine(?, e.embedding)) as distance
618                    FROM {table_name}_embeddings e
619                    JOIN {table_name} d ON e.rowid = d.rowid
620                    {where_clause}
621                    ORDER BY distance"
622                ))?;
623
624                let rows = stmt
625                    .query_map(rusqlite::params_from_iter(params), |row| {
626                        // Create a map of column names to values
627                        let mut map = serde_json::Map::new();
628                        for (i, col_name) in column_names.iter().enumerate() {
629                            let value: String = row.get(i)?;
630                            map.insert(col_name.to_string(), serde_json::Value::String(value));
631                        }
632                        let distance: f64 = row.get(column_names.len())?;
633                        let id: String = row.get(0)?; // Assuming id is always first column
634
635                        Ok((id, serde_json::Value::Object(map), distance))
636                    })?
637                    .collect::<Result<Vec<_>, _>>()?;
638                Ok(rows)
639            })
640            .await
641            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
642
643        debug!("Found {} potential matches", rows.len());
644        let mut top_n = Vec::new();
645        for (id, doc_value, distance) in rows {
646            match serde_json::from_value::<D>(doc_value) {
647                Ok(doc) => {
648                    top_n.push((distance, id, doc));
649                }
650                Err(e) => {
651                    debug!("Failed to deserialize document {}: {}", id, e);
652                    continue;
653                }
654            }
655        }
656
657        debug!("Returning {} matches", top_n.len());
658        Ok(top_n)
659    }
660
661    async fn top_n_ids(
662        &self,
663        req: VectorSearchRequest<SqliteSearchFilter>,
664    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
665        tracing::debug!(
666            "Finding top {} document IDs for query",
667            req.samples() as usize
668        );
669        let embedding = self.embedding_model.embed_text(req.query()).await?;
670        let query_vec = serialize_embedding(&embedding);
671        let table_name = T::name();
672
673        let (where_clause, params) = build_where_clause(&req, query_vec)?;
674
675        let results = self
676            .store
677            .conn
678            .call(move |conn| {
679                let mut stmt = conn.prepare(&format!(
680                    "SELECT d.id, (1 - vec_distance_cosine(?1, e.embedding)) as distance
681                     FROM {table_name}_embeddings e
682                     JOIN {table_name} d ON e.rowid = d.rowid
683                     {where_clause}
684                     ORDER BY distance"
685                ))?;
686
687                let results = stmt
688                    .query_map(rusqlite::params_from_iter(params), |row| {
689                        Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?))
690                    })?
691                    .collect::<Result<Vec<_>, _>>()?;
692                Ok(results)
693            })
694            .await
695            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
696
697        debug!("Found {} matching document IDs", results.len());
698        Ok(results)
699    }
700}
701
702fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
703    embedding.vec.iter().map(|x| *x as f32).collect()
704}
705
706impl ColumnValue for String {
707    fn to_sql_string(&self) -> String {
708        self.clone()
709    }
710
711    fn column_type(&self) -> &'static str {
712        "TEXT"
713    }
714}