Skip to main content

rig_sqlite/
lib.rs

1use rig::embeddings::{Embedding, EmbeddingModel};
2use rig::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
3use rig::vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex};
4use rig::wasm_compat::{WasmCompatSend, WasmCompatSync};
5use rig::{Embed, OneOrMany};
6use rusqlite::types::Value;
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9use std::ops::RangeInclusive;
10use tokio_rusqlite::Connection;
11use tracing::{debug, info};
12use zerocopy::IntoBytes;
13
/// Errors specific to the SQLite vector store.
///
/// The two wrapping variants carry their underlying cause as a boxed error,
/// which is also reachable through [`std::error::Error::source`].
#[derive(Debug)]
pub enum SqliteError {
    /// A failure attributed to the database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// A failure attributed to (de)serialization.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// An invalid column type; the payload is a free-form description.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DatabaseError(err) => write!(f, "database error: {err}"),
            Self::SerializationError(err) => write!(f, "serialization error: {err}"),
            Self::InvalidColumnType(ty) => write!(f, "invalid column type: {ty}"),
        }
    }
}

impl std::error::Error for SqliteError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::DatabaseError(err) | Self::SerializationError(err) => {
                // Drop the `Send + Sync` auto traits to match the return type.
                let cause: &(dyn std::error::Error + 'static) = &**err;
                Some(cause)
            }
            Self::InvalidColumnType(_) => None,
        }
    }
}
20
/// A value that can be stored in one column of a document table.
///
/// When a document is inserted, the string returned by
/// [`ColumnValue::to_sql_string`] is what gets bound to the column's
/// placeholder.
pub trait ColumnValue: Send + Sync {
    /// Renders the value as the string that is bound to the SQL placeholder.
    fn to_sql_string(&self) -> String;

    /// Names the SQL type of this value (the `String` implementation returns
    /// `"TEXT"`).
    fn column_type(&self) -> &'static str;
}
25
/// Describes one column of a document table.
///
/// `name` and `col_type` are spliced verbatim into the `CREATE TABLE`
/// statement (as `<name> <col_type>`), so `col_type` may carry constraints such
/// as `"TEXT PRIMARY KEY"`. `indexed` requests an additional
/// `CREATE INDEX` on the column.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    name: &'static str,
    col_type: &'static str,
    indexed: bool,
}

impl Column {
    /// Creates a column named `name` with SQL type clause `col_type`.
    ///
    /// The column is not indexed; chain [`Column::indexed`] to change that.
    #[must_use]
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Self {
            name,
            col_type,
            indexed: false,
        }
    }

    /// Marks the column as indexed and returns it (builder style).
    #[must_use]
    pub fn indexed(mut self) -> Self {
        self.indexed = true;
        self
    }
}
46
47/// Example of a document type that can be used with SqliteVectorStore
48/// ```rust
49/// use rig::Embed;
50/// use serde::{Deserialize, Serialize};
51/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStoreTable};
52///
53/// #[derive(Embed, Clone, Debug, Deserialize, Serialize)]
54/// struct Document {
55///     id: String,
56///     #[embed]
57///     content: String,
58/// }
59///
60/// impl SqliteVectorStoreTable for Document {
61///     fn name() -> &'static str {
62///         "documents"
63///     }
64///
65///     fn schema() -> Vec<Column> {
66///         vec![
67///             Column::new("id", "TEXT PRIMARY KEY"),
68///             Column::new("content", "TEXT"),
69///         ]
70///     }
71///
72///     fn id(&self) -> String {
73///         self.id.clone()
74///     }
75///
76///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
77///         vec![
78///             ("id", Box::new(self.id.clone())),
79///             ("content", Box::new(self.content.clone())),
80///         ]
81///     }
82/// }
83/// ```
84pub trait SqliteVectorStoreTable: Send + Sync + Clone {
85    fn name() -> &'static str;
86    fn schema() -> Vec<Column>;
87    fn id(&self) -> String;
88    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
89}
90
91#[derive(Clone)]
92pub struct SqliteVectorStore<E, T>
93where
94    E: EmbeddingModel + 'static,
95    T: SqliteVectorStoreTable + 'static,
96{
97    conn: Connection,
98    _phantom: PhantomData<(E, T)>,
99}
100
101impl<E, T> SqliteVectorStore<E, T>
102where
103    E: EmbeddingModel + Clone + 'static,
104    T: SqliteVectorStoreTable + 'static,
105{
106    pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
107        let dims = embedding_model.ndims();
108        let table_name = T::name();
109        let schema = T::schema();
110
111        // Build the table schema
112        let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
113
114        // Add columns
115        let mut first = true;
116        for column in &schema {
117            if !first {
118                create_table.push(',');
119            }
120            create_table.push_str(&format!("\n    {} {}", column.name, column.col_type));
121            first = false;
122        }
123
124        create_table.push_str("\n)");
125
126        // Build index creation statements
127        let mut create_indexes = vec![format!(
128            "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
129            table_name, table_name
130        )];
131
132        // Add indexes for marked columns
133        for column in schema {
134            if column.indexed {
135                create_indexes.push(format!(
136                    "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
137                    table_name, column.name, table_name, column.name
138                ));
139            }
140        }
141
142        conn.call(move |conn| {
143            conn.execute_batch("BEGIN")?;
144
145            // Create document table
146            conn.execute_batch(&create_table)?;
147
148            // Create indexes
149            for index_stmt in create_indexes {
150                conn.execute_batch(&index_stmt)?;
151            }
152
153            // Create embeddings table
154            conn.execute_batch(&format!(
155                "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
156            ))?;
157
158            conn.execute_batch("COMMIT")?;
159            Ok(())
160        })
161        .await
162        .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
163
164        Ok(Self {
165            conn,
166            _phantom: PhantomData,
167        })
168    }
169
170    pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
171        SqliteVectorIndex::new(model, self)
172    }
173
174    pub fn add_rows_with_txn(
175        &self,
176        txn: &rusqlite::Transaction<'_>,
177        documents: Vec<(T, OneOrMany<Embedding>)>,
178    ) -> Result<i64, tokio_rusqlite::Error> {
179        info!("Adding {} documents to store", documents.len());
180        let table_name = T::name();
181        let mut last_id = 0;
182
183        for (doc, embeddings) in &documents {
184            debug!("Storing document with id {}", doc.id());
185
186            let values = doc.column_values();
187            let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
188
189            let placeholders = (1..=values.len())
190                .map(|i| format!("?{i}"))
191                .collect::<Vec<_>>();
192
193            let insert_sql = format!(
194                "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
195                table_name,
196                columns.join(", "),
197                placeholders.join(", ")
198            );
199
200            txn.execute(
201                &insert_sql,
202                rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
203            )?;
204            last_id = txn.last_insert_rowid();
205
206            let embeddings_sql =
207                format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
208
209            let mut stmt = txn.prepare(&embeddings_sql)?;
210            for (i, embedding) in embeddings.iter().enumerate() {
211                let vec = serialize_embedding(embedding);
212                debug!(
213                    "Storing embedding {} of {} (size: {} bytes)",
214                    i + 1,
215                    embeddings.len(),
216                    vec.len() * 4
217                );
218                let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
219                stmt.execute(rusqlite::params![last_id, blob])?;
220            }
221        }
222
223        Ok(last_id)
224    }
225
226    pub async fn add_rows(
227        &self,
228        documents: Vec<(T, OneOrMany<Embedding>)>,
229    ) -> Result<i64, VectorStoreError>
230    where
231        T: 'static,
232        Self: 'static,
233    {
234        let cloned = self.clone();
235
236        self.conn
237            .call(move |conn| {
238                let tx = conn.transaction()?;
239                let result = cloned.add_rows_with_txn(&tx, documents)?;
240                tx.commit()?;
241
242                Ok(result)
243            })
244            .await
245            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
246    }
247}
248
249impl<E, T> InsertDocuments for SqliteVectorStore<E, T>
250where
251    E: EmbeddingModel + Clone + WasmCompatSend + WasmCompatSync + 'static,
252    T: SqliteVectorStoreTable
253        + for<'de> Deserialize<'de>
254        + WasmCompatSend
255        + WasmCompatSync
256        + 'static,
257{
258    async fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
259        &self,
260        documents: Vec<(Doc, OneOrMany<Embedding>)>,
261    ) -> Result<(), VectorStoreError> {
262        if documents.is_empty() {
263            return Ok(());
264        }
265
266        let rows = documents
267            .into_iter()
268            .map(|(document, embeddings)| {
269                let document = serde_json::to_value(document)?;
270                let row = serde_json::from_value::<T>(document)?;
271
272                Ok((row, embeddings))
273            })
274            .collect::<Result<Vec<_>, VectorStoreError>>()?;
275
276        self.add_rows(rows).await?;
277
278        Ok(())
279    }
280}
281
282#[derive(Clone, Default, Deserialize, Serialize, Debug)]
283pub struct SqliteSearchFilter {
284    condition: String,
285    params: Vec<serde_json::Value>,
286}
287
288impl SearchFilter for SqliteSearchFilter {
289    type Value = serde_json::Value;
290
291    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
292        Self {
293            condition: format!("{} = ?", key.as_ref()),
294            params: vec![value],
295        }
296    }
297
298    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
299        Self {
300            condition: format!("{} > ?", key.as_ref()),
301            params: vec![value],
302        }
303    }
304
305    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
306        Self {
307            condition: format!("{} < ?", key.as_ref()),
308            params: vec![value],
309        }
310    }
311
312    fn and(self, rhs: Self) -> Self {
313        Self {
314            condition: format!("({}) AND ({})", self.condition, rhs.condition),
315            params: self.params.into_iter().chain(rhs.params).collect(),
316        }
317    }
318
319    fn or(self, rhs: Self) -> Self {
320        Self {
321            condition: format!("({}) OR ({})", self.condition, rhs.condition),
322            params: self.params.into_iter().chain(rhs.params).collect(),
323        }
324    }
325}
326
327impl SqliteSearchFilter {
328    #[allow(clippy::should_implement_trait)]
329    pub fn not(self) -> Self {
330        Self {
331            condition: format!("NOT ({})", self.condition),
332            ..self
333        }
334    }
335
336    /// Tests whether the value at `key` is contained in the range
337    pub fn between<N>(key: String, range: RangeInclusive<N>) -> Self
338    where
339        N: Ord + rusqlite::ToSql + std::fmt::Display,
340    {
341        let lo = range.start();
342        let hi = range.end();
343
344        Self {
345            condition: format!("{key} between {lo} and {hi}"),
346            ..Default::default()
347        }
348    }
349
350    // Null checks
351    pub fn is_null(key: String) -> Self {
352        Self {
353            condition: format!("{key} is null"),
354            ..Default::default()
355        }
356    }
357
358    pub fn is_not_null(key: String) -> Self {
359        Self {
360            condition: format!("{key} is not null"),
361            ..Default::default()
362        }
363    }
364
365    // String ops
366    /// Tests whether the value at `key` satisfies the glob pattern
367    /// `pattern` should be a valid SQLite glob pattern
368    pub fn glob<'a, S>(key: String, pattern: S) -> Self
369    where
370        S: AsRef<&'a str>,
371    {
372        Self {
373            condition: format!("{key} glob {}", pattern.as_ref()),
374            ..Default::default()
375        }
376    }
377
378    /// Tests whether the value at `key` satisfies the "like" pattern
379    /// `pattern` should be a valid SQLite like pattern
380    pub fn like<'a, S>(key: String, pattern: S) -> Self
381    where
382        S: AsRef<&'a str>,
383    {
384        Self {
385            condition: format!("{key} like {}", pattern.as_ref()),
386            ..Default::default()
387        }
388    }
389}
390
391impl SqliteSearchFilter {
392    fn compile_params(self) -> Result<Vec<Value>, FilterError> {
393        let mut params = Vec::with_capacity(self.params.len());
394
395        fn convert(value: serde_json::Value) -> Result<Value, FilterError> {
396            use serde_json::Value::*;
397
398            match value {
399                Null => Ok(Value::Null),
400                Bool(b) => Ok(Value::Integer(b as i64)),
401                String(s) => Ok(Value::Text(s)),
402                Number(n) => Ok(if let Some(float) = n.as_f64() {
403                    Value::Real(float)
404                } else if let Some(int) = n.as_i64() {
405                    Value::Integer(int)
406                } else if let Some(int) = n.as_u64() {
407                    Value::Integer(int as i64)
408                } else {
409                    Value::Text(n.to_string())
410                }),
411                Array(arr) => {
412                    let blob = serde_json::to_vec(&arr)
413                        .map_err(|e| FilterError::Serialization(e.to_string()))?;
414
415                    Ok(Value::Blob(blob))
416                }
417                Object(obj) => {
418                    let blob = serde_json::to_vec(&obj)
419                        .map_err(|e| FilterError::Serialization(e.to_string()))?;
420
421                    Ok(Value::Blob(blob))
422                }
423            }
424        }
425
426        for param in self.params.into_iter() {
427            params.push(convert(param)?)
428        }
429
430        Ok(params)
431    }
432}
433
434/// SQLite vector store implementation for Rig.
435///
436/// This crate provides a SQLite-based vector store implementation that can be used with Rig.
437/// It uses the `sqlite-vec` extension to enable vector similarity search capabilities.
438///
439/// # Example
440/// ```no_run
441/// use rig::{
442///     client::EmbeddingsClient,
443///     embeddings::EmbeddingsBuilder,
444///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
445///     vector_store::{InsertDocuments, VectorStoreIndex},
446///     Embed,
447/// };
448/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
449/// use rig::vector_store::request::VectorSearchRequest;
450/// use serde::{Deserialize, Serialize};
451/// use tokio_rusqlite::Connection;
452///
453/// # async fn example() -> anyhow::Result<()> {
454/// #[derive(Embed, Clone, Debug, Deserialize, Serialize)]
455/// struct Document {
456///     id: String,
457///     #[embed]
458///     content: String,
459/// }
460///
461/// impl SqliteVectorStoreTable for Document {
462///     fn name() -> &'static str {
463///         "documents"
464///     }
465///
466///     fn schema() -> Vec<Column> {
467///         vec![
468///             Column::new("id", "TEXT PRIMARY KEY"),
469///             Column::new("content", "TEXT"),
470///         ]
471///     }
472///
473///     fn id(&self) -> String {
474///         self.id.clone()
475///     }
476///
477///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
478///         vec![
479///             ("id", Box::new(self.id.clone())),
480///             ("content", Box::new(self.content.clone())),
481///         ]
482///     }
483/// }
484///
485/// let conn = Connection::open("vector_store.db").await?;
486/// let openai_client = Client::new("YOUR_API_KEY")?;
487/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
488///
489/// // Initialize vector store
490/// let vector_store: SqliteVectorStore<_, Document> = SqliteVectorStore::new(conn, &model).await?;
491///
492/// // Create documents
493/// let documents = vec![
494///     Document {
495///         id: "doc1".to_string(),
496///         content: "Example document 1".to_string(),
497///     },
498///     Document {
499///         id: "doc2".to_string(),
500///         content: "Example document 2".to_string(),
501///     },
502/// ];
503///
504/// // Generate embeddings
505/// let embeddings = EmbeddingsBuilder::new(model.clone())
506///     .documents(documents)?
507///     .build()
508///     .await?;
509///
510/// // Add to vector store
511/// vector_store.insert_documents(embeddings).await?;
512///
513/// // Create index and search
514/// let index = vector_store.index(model);
515/// let req = VectorSearchRequest::builder()
516///     .query("Example query")
517///     .samples(2)
518///     .build();
519/// let results = index.top_n::<Document>(req).await?;
520/// # let _ = results;
521/// # Ok(())
522/// # }
523/// # let _ = example();
524/// ```
525pub struct SqliteVectorIndex<E, T>
526where
527    E: EmbeddingModel + 'static,
528    T: SqliteVectorStoreTable + 'static,
529{
530    store: SqliteVectorStore<E, T>,
531    embedding_model: E,
532}
533
534impl<E, T> SqliteVectorIndex<E, T>
535where
536    E: EmbeddingModel + 'static,
537    T: SqliteVectorStoreTable,
538{
539    pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
540        Self {
541            store,
542            embedding_model,
543        }
544    }
545}
546
547fn build_where_clause(
548    req: &VectorSearchRequest<SqliteSearchFilter>,
549    query_vec: Vec<f32>,
550) -> Result<(String, Vec<Value>), FilterError> {
551    let thresh = req.threshold().unwrap_or(0.);
552    let thresh = SqliteSearchFilter::gt("distance", thresh.into());
553
554    let filter = req
555        .filter()
556        .as_ref()
557        .cloned()
558        .map(|filter| thresh.clone().and(filter))
559        .unwrap_or(thresh);
560
561    let where_clause = format!(
562        "WHERE e.embedding MATCH ? AND k = ? AND {}",
563        filter.condition
564    );
565
566    let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
567    let query_vec = Value::Blob(query_vec);
568    let samples = req.samples() as u32;
569
570    let mut params = vec![query_vec.clone(), query_vec, samples.into()];
571    let filter_params = filter.clone().compile_params()?;
572    params.extend(filter_params);
573
574    Ok((where_clause, params))
575}
576
577impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
578    for SqliteVectorIndex<E, T>
579{
580    type Filter = SqliteSearchFilter;
581
582    async fn top_n<D>(
583        &self,
584        req: VectorSearchRequest<SqliteSearchFilter>,
585    ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
586    where
587        D: for<'de> Deserialize<'de>,
588    {
589        tracing::debug!("Finding top {} matches for query", req.samples() as usize);
590        let embedding = self.embedding_model.embed_text(req.query()).await?;
591        let query_vec: Vec<f32> = serialize_embedding(&embedding);
592        let table_name = T::name();
593
594        // Get all column names from SqliteVectorStoreTable
595        let columns = T::schema();
596        let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
597
598        // Build SELECT statement with all columns
599        let select_cols = column_names.join(", ");
600
601        let (where_clause, params) = build_where_clause(&req, query_vec)?;
602
603        let rows = self
604            .store
605            .conn
606            .call(move |conn| {
607                let mut stmt = conn.prepare(&format!(
608                    "SELECT d.{select_cols}, (1 - vec_distance_cosine(?, e.embedding)) as distance
609                    FROM {table_name}_embeddings e
610                    JOIN {table_name} d ON e.rowid = d.rowid
611                    {where_clause}
612                    ORDER BY distance"
613                ))?;
614
615                let rows = stmt
616                    .query_map(rusqlite::params_from_iter(params), |row| {
617                        // Create a map of column names to values
618                        let mut map = serde_json::Map::new();
619                        for (i, col_name) in column_names.iter().enumerate() {
620                            let value: String = row.get(i)?;
621                            map.insert(col_name.to_string(), serde_json::Value::String(value));
622                        }
623                        let distance: f64 = row.get(column_names.len())?;
624                        let id: String = row.get(0)?; // Assuming id is always first column
625
626                        Ok((id, serde_json::Value::Object(map), distance))
627                    })?
628                    .collect::<Result<Vec<_>, _>>()?;
629                Ok(rows)
630            })
631            .await
632            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
633
634        debug!("Found {} potential matches", rows.len());
635        let mut top_n = Vec::new();
636        for (id, doc_value, distance) in rows {
637            match serde_json::from_value::<D>(doc_value) {
638                Ok(doc) => {
639                    top_n.push((distance, id, doc));
640                }
641                Err(e) => {
642                    debug!("Failed to deserialize document {}: {}", id, e);
643                    continue;
644                }
645            }
646        }
647
648        debug!("Returning {} matches", top_n.len());
649        Ok(top_n)
650    }
651
652    async fn top_n_ids(
653        &self,
654        req: VectorSearchRequest<SqliteSearchFilter>,
655    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
656        tracing::debug!(
657            "Finding top {} document IDs for query",
658            req.samples() as usize
659        );
660        let embedding = self.embedding_model.embed_text(req.query()).await?;
661        let query_vec = serialize_embedding(&embedding);
662        let table_name = T::name();
663
664        let (where_clause, params) = build_where_clause(&req, query_vec)?;
665
666        let results = self
667            .store
668            .conn
669            .call(move |conn| {
670                let mut stmt = conn.prepare(&format!(
671                    "SELECT d.id, (1 - vec_distance_cosine(?1, e.embedding)) as distance
672                     FROM {table_name}_embeddings e
673                     JOIN {table_name} d ON e.rowid = d.rowid
674                     {where_clause}
675                     ORDER BY distance"
676                ))?;
677
678                let results = stmt
679                    .query_map(rusqlite::params_from_iter(params), |row| {
680                        Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?))
681                    })?
682                    .collect::<Result<Vec<_>, _>>()?;
683                Ok(results)
684            })
685            .await
686            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
687
688        debug!("Found {} matching document IDs", results.len());
689        Ok(results)
690    }
691}
692
693fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
694    embedding.vec.iter().map(|x| *x as f32).collect()
695}
696
697impl ColumnValue for String {
698    fn to_sql_string(&self) -> String {
699        self.clone()
700    }
701
702    fn column_type(&self) -> &'static str {
703        "TEXT"
704    }
705}