Skip to main content

rig_sqlite/
lib.rs

1use rig::embeddings::{Embedding, EmbeddingModel};
2use rig::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
3use rig::vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex};
4use rig::wasm_compat::{WasmCompatSend, WasmCompatSync};
5use rig::{Embed, OneOrMany};
6use rusqlite::types::Value;
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9use std::ops::RangeInclusive;
10use tokio_rusqlite::Connection;
11use tracing::{debug, info};
12use zerocopy::IntoBytes;
13
/// Error type for the SQLite vector store.
///
/// NOTE(review): no variant is constructed in the visible part of this file;
/// the descriptions below follow the variant names — confirm against callers.
#[derive(Debug)]
pub enum SqliteError {
    /// Wraps an error raised by the database layer.
    DatabaseError(Box<dyn std::error::Error + Sync + Send>),
    /// Wraps an error raised while (de)serializing a value.
    SerializationError(Box<dyn std::error::Error + Sync + Send>),
    /// A column type that could not be handled; the string describes it.
    InvalidColumnType(String),
}
20
/// A value that can be written into one column of a document row.
///
/// The store binds every column value as text, obtained from
/// [`ColumnValue::to_sql_string`].
pub trait ColumnValue: Sync + Send {
    /// Renders the value as the string bound into the `INSERT` statement.
    fn to_sql_string(&self) -> String;

    /// SQL type name associated with the value (e.g. `"TEXT"`).
    ///
    /// NOTE(review): not consulted by the store in the visible code.
    fn column_type(&self) -> &'static str;
}
25
/// One column of a document table: its name, the type declaration used in
/// `CREATE TABLE`, and whether a secondary index should be created for it.
pub struct Column {
    name: &'static str,
    col_type: &'static str,
    indexed: bool,
}

impl Column {
    /// Creates a non-indexed column.
    ///
    /// `col_type` is written verbatim after the column name in the
    /// `CREATE TABLE` statement, so it may carry constraints such as
    /// `"TEXT PRIMARY KEY"`.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Column {
            indexed: false,
            col_type,
            name,
        }
    }

    /// Marks the column so that `SqliteVectorStore::new` creates an index on it.
    pub fn indexed(self) -> Self {
        Column {
            indexed: true,
            ..self
        }
    }
}
46
47/// Example of a document type that can be used with SqliteVectorStore
48/// ```rust
49/// use rig::Embed;
50/// use serde::{Deserialize, Serialize};
51/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStoreTable};
52///
53/// #[derive(Embed, Clone, Debug, Deserialize, Serialize)]
54/// struct Document {
55///     id: String,
56///     #[embed]
57///     content: String,
58/// }
59///
60/// impl SqliteVectorStoreTable for Document {
61///     fn name() -> &'static str {
62///         "documents"
63///     }
64///
65///     fn schema() -> Vec<Column> {
66///         vec![
67///             Column::new("id", "TEXT PRIMARY KEY"),
68///             Column::new("content", "TEXT"),
69///         ]
70///     }
71///
72///     fn id(&self) -> String {
73///         self.id.clone()
74///     }
75///
76///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
77///         vec![
78///             ("id", Box::new(self.id.clone())),
79///             ("content", Box::new(self.content.clone())),
80///         ]
81///     }
82/// }
83/// ```
84pub trait SqliteVectorStoreTable: Send + Sync + Clone {
85    fn name() -> &'static str;
86    fn schema() -> Vec<Column>;
87    fn id(&self) -> String;
88    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
89}
90
91#[derive(Clone)]
92pub struct SqliteVectorStore<E, T>
93where
94    E: EmbeddingModel + 'static,
95    T: SqliteVectorStoreTable + 'static,
96{
97    conn: Connection,
98    _phantom: PhantomData<(E, T)>,
99}
100
101impl<E, T> SqliteVectorStore<E, T>
102where
103    E: EmbeddingModel + Clone + 'static,
104    T: SqliteVectorStoreTable + 'static,
105{
106    pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
107        let dims = embedding_model.ndims();
108        let table_name = T::name();
109        let schema = T::schema();
110
111        // Build the table schema
112        let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
113
114        // Add columns
115        let mut first = true;
116        for column in &schema {
117            if !first {
118                create_table.push(',');
119            }
120            create_table.push_str(&format!("\n    {} {}", column.name, column.col_type));
121            first = false;
122        }
123
124        create_table.push_str("\n)");
125
126        // Build index creation statements
127        let mut create_indexes = vec![format!(
128            "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
129            table_name, table_name
130        )];
131
132        // Add indexes for marked columns
133        for column in schema {
134            if column.indexed {
135                create_indexes.push(format!(
136                    "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
137                    table_name, column.name, table_name, column.name
138                ));
139            }
140        }
141
142        conn.call(move |conn| {
143            conn.execute_batch("BEGIN")?;
144
145            // Create document table
146            conn.execute_batch(&create_table)?;
147
148            // Create indexes
149            for index_stmt in create_indexes {
150                conn.execute_batch(&index_stmt)?;
151            }
152
153            // Create embeddings table
154            conn.execute_batch(&format!(
155                "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
156            ))?;
157
158            conn.execute_batch("COMMIT")?;
159            Ok(())
160        })
161        .await
162        .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
163
164        Ok(Self {
165            conn,
166            _phantom: PhantomData,
167        })
168    }
169
170    pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
171        SqliteVectorIndex::new(model, self)
172    }
173
174    pub fn add_rows_with_txn(
175        &self,
176        txn: &rusqlite::Transaction<'_>,
177        documents: Vec<(T, OneOrMany<Embedding>)>,
178    ) -> Result<i64, tokio_rusqlite::Error> {
179        info!("Adding {} documents to store", documents.len());
180        let table_name = T::name();
181        let mut last_id = 0;
182
183        for (doc, embeddings) in &documents {
184            debug!("Storing document with id {}", doc.id());
185
186            let values = doc.column_values();
187            let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
188
189            let placeholders = (1..=values.len())
190                .map(|i| format!("?{i}"))
191                .collect::<Vec<_>>();
192
193            let insert_sql = format!(
194                "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
195                table_name,
196                columns.join(", "),
197                placeholders.join(", ")
198            );
199
200            txn.execute(
201                &insert_sql,
202                rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
203            )?;
204            last_id = txn.last_insert_rowid();
205
206            let embeddings_sql =
207                format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
208
209            let mut stmt = txn.prepare(&embeddings_sql)?;
210            for (i, embedding) in embeddings.iter().enumerate() {
211                let vec = serialize_embedding(embedding);
212                debug!(
213                    "Storing embedding {} of {} (size: {} bytes)",
214                    i + 1,
215                    embeddings.len(),
216                    vec.len() * 4
217                );
218                let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
219                stmt.execute(rusqlite::params![last_id, blob])?;
220            }
221        }
222
223        Ok(last_id)
224    }
225
226    pub async fn add_rows(
227        &self,
228        documents: Vec<(T, OneOrMany<Embedding>)>,
229    ) -> Result<i64, VectorStoreError>
230    where
231        T: 'static,
232        Self: 'static,
233    {
234        let cloned = self.clone();
235
236        self.conn
237            .call(move |conn| {
238                let tx = conn.transaction()?;
239                let result = cloned.add_rows_with_txn(&tx, documents)?;
240                tx.commit()?;
241
242                Ok(result)
243            })
244            .await
245            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
246    }
247}
248
249impl<E, T> InsertDocuments for SqliteVectorStore<E, T>
250where
251    E: EmbeddingModel + Clone + WasmCompatSend + WasmCompatSync + 'static,
252    T: SqliteVectorStoreTable
253        + for<'de> Deserialize<'de>
254        + WasmCompatSend
255        + WasmCompatSync
256        + 'static,
257{
258    async fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
259        &self,
260        documents: Vec<(Doc, OneOrMany<Embedding>)>,
261    ) -> Result<(), VectorStoreError> {
262        if documents.is_empty() {
263            return Ok(());
264        }
265
266        let rows = documents
267            .into_iter()
268            .map(|(document, embeddings)| {
269                let document = serde_json::to_value(document)?;
270                let row = serde_json::from_value::<T>(document)?;
271
272                Ok((row, embeddings))
273            })
274            .collect::<Result<Vec<_>, VectorStoreError>>()?;
275
276        self.add_rows(rows).await?;
277
278        Ok(())
279    }
280}
281
282#[derive(Clone, Default, Deserialize, Serialize, Debug)]
283pub struct SqliteSearchFilter {
284    condition: String,
285    params: Vec<serde_json::Value>,
286}
287
288impl SearchFilter for SqliteSearchFilter {
289    type Value = serde_json::Value;
290
291    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
292        Self {
293            condition: format!("{} = ?", key.as_ref()),
294            params: vec![value],
295        }
296    }
297
298    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
299        Self {
300            condition: format!("{} > ?", key.as_ref()),
301            params: vec![value],
302        }
303    }
304
305    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
306        Self {
307            condition: format!("{} < ?", key.as_ref()),
308            params: vec![value],
309        }
310    }
311
312    fn and(self, rhs: Self) -> Self {
313        Self {
314            condition: format!("({}) AND ({})", self.condition, rhs.condition),
315            params: self.params.into_iter().chain(rhs.params).collect(),
316        }
317    }
318
319    fn or(self, rhs: Self) -> Self {
320        Self {
321            condition: format!("({}) OR ({})", self.condition, rhs.condition),
322            params: self.params.into_iter().chain(rhs.params).collect(),
323        }
324    }
325}
326
327impl SqliteSearchFilter {
328    #[allow(clippy::should_implement_trait)]
329    pub fn not(self) -> Self {
330        Self {
331            condition: format!("NOT ({})", self.condition),
332            ..self
333        }
334    }
335
336    /// Tests whether the value at `key` is contained in the range
337    pub fn between<N>(key: String, range: RangeInclusive<N>) -> Self
338    where
339        N: Ord + rusqlite::ToSql + std::fmt::Display,
340    {
341        let lo = range.start();
342        let hi = range.end();
343
344        Self {
345            condition: format!("{key} between {lo} and {hi}"),
346            ..Default::default()
347        }
348    }
349
350    // Null checks
351    pub fn is_null(key: String) -> Self {
352        Self {
353            condition: format!("{key} is null"),
354            ..Default::default()
355        }
356    }
357
358    pub fn is_not_null(key: String) -> Self {
359        Self {
360            condition: format!("{key} is not null"),
361            ..Default::default()
362        }
363    }
364
365    // String ops
366    /// Tests whether the value at `key` satisfies the glob pattern
367    /// `pattern` should be a valid SQLite glob pattern
368    pub fn glob<'a, S>(key: String, pattern: S) -> Self
369    where
370        S: AsRef<&'a str>,
371    {
372        Self {
373            condition: format!("{key} glob {}", pattern.as_ref()),
374            ..Default::default()
375        }
376    }
377
378    /// Tests whether the value at `key` satisfies the "like" pattern
379    /// `pattern` should be a valid SQLite like pattern
380    pub fn like<'a, S>(key: String, pattern: S) -> Self
381    where
382        S: AsRef<&'a str>,
383    {
384        Self {
385            condition: format!("{key} like {}", pattern.as_ref()),
386            ..Default::default()
387        }
388    }
389}
390
391impl SqliteSearchFilter {
392    fn compile_params(self) -> Result<Vec<Value>, FilterError> {
393        let mut params = Vec::with_capacity(self.params.len());
394
395        fn convert(value: serde_json::Value) -> Result<Value, FilterError> {
396            use serde_json::Value::*;
397
398            match value {
399                Null => Ok(Value::Null),
400                Bool(b) => Ok(Value::Integer(b as i64)),
401                String(s) => Ok(Value::Text(s)),
402                Number(n) => Ok(if let Some(float) = n.as_f64() {
403                    Value::Real(float)
404                } else if let Some(int) = n.as_i64() {
405                    Value::Integer(int)
406                } else {
407                    unreachable!()
408                }),
409                Array(arr) => {
410                    let blob = serde_json::to_vec(&arr)
411                        .map_err(|e| FilterError::Serialization(e.to_string()))?;
412
413                    Ok(Value::Blob(blob))
414                }
415                Object(obj) => {
416                    let blob = serde_json::to_vec(&obj)
417                        .map_err(|e| FilterError::Serialization(e.to_string()))?;
418
419                    Ok(Value::Blob(blob))
420                }
421            }
422        }
423
424        for param in self.params.into_iter() {
425            params.push(convert(param)?)
426        }
427
428        Ok(params)
429    }
430}
431
432/// SQLite vector store implementation for Rig.
433///
434/// This crate provides a SQLite-based vector store implementation that can be used with Rig.
435/// It uses the `sqlite-vec` extension to enable vector similarity search capabilities.
436///
437/// # Example
438/// ```rust
439/// use rig::{
440///     embeddings::EmbeddingsBuilder,
441///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
442///     vector_store::{InsertDocuments, VectorStoreIndex},
443///     Embed,
444/// };
445/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
446/// use rig::vector_store::request::VectorSearchRequest;
447/// use serde::{Deserialize, Serialize};
448/// use tokio_rusqlite::Connection;
449///
450/// #[derive(Embed, Clone, Debug, Deserialize, Serialize)]
451/// struct Document {
452///     id: String,
453///     #[embed]
454///     content: String,
455/// }
456///
457/// impl SqliteVectorStoreTable for Document {
458///     fn name() -> &'static str {
459///         "documents"
460///     }
461///
462///     fn schema() -> Vec<Column> {
463///         vec![
464///             Column::new("id", "TEXT PRIMARY KEY"),
465///             Column::new("content", "TEXT"),
466///         ]
467///     }
468///
469///     fn id(&self) -> String {
470///         self.id.clone()
471///     }
472///
473///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
474///         vec![
475///             ("id", Box::new(self.id.clone())),
476///             ("content", Box::new(self.content.clone())),
477///         ]
478///     }
479/// }
480///
481/// let conn = Connection::open("vector_store.db").await?;
482/// let openai_client = Client::new("YOUR_API_KEY");
483/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
484///
485/// // Initialize vector store
486/// let vector_store = SqliteVectorStore::new(conn, &model).await?;
487///
488/// // Create documents
489/// let documents = vec![
490///     Document {
491///         id: "doc1".to_string(),
492///         content: "Example document 1".to_string(),
493///     },
494///     Document {
495///         id: "doc2".to_string(),
496///         content: "Example document 2".to_string(),
497///     },
498/// ];
499///
500/// // Generate embeddings
501/// let embeddings = EmbeddingsBuilder::new(model.clone())
502///     .documents(documents)?
503///     .build()
504///     .await?;
505///
506/// // Add to vector store
507/// vector_store.insert_documents(embeddings).await?;
508///
509/// // Create index and search
510/// let index = vector_store.index(model);
511/// let req = VectorSearchRequest::builder()
512///     .query("Example query")
513///     .samples(2)
514///     .build()?;
515/// let results = index.top_n::<Document>(req).await?;
516/// ```
517pub struct SqliteVectorIndex<E, T>
518where
519    E: EmbeddingModel + 'static,
520    T: SqliteVectorStoreTable + 'static,
521{
522    store: SqliteVectorStore<E, T>,
523    embedding_model: E,
524}
525
526impl<E, T> SqliteVectorIndex<E, T>
527where
528    E: EmbeddingModel + 'static,
529    T: SqliteVectorStoreTable,
530{
531    pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
532        Self {
533            store,
534            embedding_model,
535        }
536    }
537}
538
539fn build_where_clause(
540    req: &VectorSearchRequest<SqliteSearchFilter>,
541    query_vec: Vec<f32>,
542) -> Result<(String, Vec<Value>), FilterError> {
543    let thresh = req.threshold().unwrap_or(0.);
544    let thresh = SqliteSearchFilter::gt("distance", thresh.into());
545
546    let filter = req
547        .filter()
548        .as_ref()
549        .cloned()
550        .map(|filter| thresh.clone().and(filter))
551        .unwrap_or(thresh);
552
553    let where_clause = format!(
554        "WHERE e.embedding MATCH ? AND k = ? AND {}",
555        filter.condition
556    );
557
558    let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
559    let query_vec = Value::Blob(query_vec);
560    let samples = req.samples() as u32;
561
562    let mut params = vec![query_vec.clone(), query_vec, samples.into()];
563    let filter_params = filter.clone().compile_params()?;
564    params.extend(filter_params);
565
566    Ok((where_clause, params))
567}
568
569impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
570    for SqliteVectorIndex<E, T>
571{
572    type Filter = SqliteSearchFilter;
573
574    async fn top_n<D>(
575        &self,
576        req: VectorSearchRequest<SqliteSearchFilter>,
577    ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
578    where
579        D: for<'de> Deserialize<'de>,
580    {
581        tracing::debug!("Finding top {} matches for query", req.samples() as usize);
582        let embedding = self.embedding_model.embed_text(req.query()).await?;
583        let query_vec: Vec<f32> = serialize_embedding(&embedding);
584        let table_name = T::name();
585
586        // Get all column names from SqliteVectorStoreTable
587        let columns = T::schema();
588        let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
589
590        // Build SELECT statement with all columns
591        let select_cols = column_names.join(", ");
592
593        let (where_clause, params) = build_where_clause(&req, query_vec)?;
594
595        let rows = self
596            .store
597            .conn
598            .call(move |conn| {
599                let mut stmt = conn.prepare(&format!(
600                    "SELECT d.{select_cols}, (1 - vec_distance_cosine(?, e.embedding)) as distance
601                    FROM {table_name}_embeddings e
602                    JOIN {table_name} d ON e.rowid = d.rowid
603                    {where_clause}
604                    ORDER BY distance"
605                ))?;
606
607                let rows = stmt
608                    .query_map(rusqlite::params_from_iter(params), |row| {
609                        // Create a map of column names to values
610                        let mut map = serde_json::Map::new();
611                        for (i, col_name) in column_names.iter().enumerate() {
612                            let value: String = row.get(i)?;
613                            map.insert(col_name.to_string(), serde_json::Value::String(value));
614                        }
615                        let distance: f64 = row.get(column_names.len())?;
616                        let id: String = row.get(0)?; // Assuming id is always first column
617
618                        Ok((id, serde_json::Value::Object(map), distance))
619                    })?
620                    .collect::<Result<Vec<_>, _>>()?;
621                Ok(rows)
622            })
623            .await
624            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
625
626        debug!("Found {} potential matches", rows.len());
627        let mut top_n = Vec::new();
628        for (id, doc_value, distance) in rows {
629            match serde_json::from_value::<D>(doc_value) {
630                Ok(doc) => {
631                    top_n.push((distance, id, doc));
632                }
633                Err(e) => {
634                    debug!("Failed to deserialize document {}: {}", id, e);
635                    continue;
636                }
637            }
638        }
639
640        debug!("Returning {} matches", top_n.len());
641        Ok(top_n)
642    }
643
644    async fn top_n_ids(
645        &self,
646        req: VectorSearchRequest<SqliteSearchFilter>,
647    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
648        tracing::debug!(
649            "Finding top {} document IDs for query",
650            req.samples() as usize
651        );
652        let embedding = self.embedding_model.embed_text(req.query()).await?;
653        let query_vec = serialize_embedding(&embedding);
654        let table_name = T::name();
655
656        let (where_clause, params) = build_where_clause(&req, query_vec)?;
657
658        let results = self
659            .store
660            .conn
661            .call(move |conn| {
662                let mut stmt = conn.prepare(&format!(
663                    "SELECT d.id, (1 - vec_distance_cosine(?1, e.embedding)) as distance
664                     FROM {table_name}_embeddings e
665                     JOIN {table_name} d ON e.rowid = d.rowid
666                     {where_clause}
667                     ORDER BY distance"
668                ))?;
669
670                let results = stmt
671                    .query_map(rusqlite::params_from_iter(params), |row| {
672                        Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?))
673                    })?
674                    .collect::<Result<Vec<_>, _>>()?;
675                Ok(results)
676            })
677            .await
678            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
679
680        debug!("Found {} matching document IDs", results.len());
681        Ok(results)
682    }
683}
684
685fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
686    embedding.vec.iter().map(|x| *x as f32).collect()
687}
688
689impl ColumnValue for String {
690    fn to_sql_string(&self) -> String {
691        self.clone()
692    }
693
694    fn column_type(&self) -> &'static str {
695        "TEXT"
696    }
697}