rig_sqlite/
lib.rs

1use rig::OneOrMany;
2use rig::embeddings::{Embedding, EmbeddingModel};
3use rig::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
4use rig::vector_store::{VectorStoreError, VectorStoreIndex};
5use rusqlite::types::Value;
6use serde::Deserialize;
7use std::marker::PhantomData;
8use std::ops::RangeInclusive;
9use tokio_rusqlite::Connection;
10use tracing::{debug, info};
11use zerocopy::IntoBytes;
12
/// Error type exposed by this crate.
///
/// NOTE(review): no code in this file constructs these variants; the meanings below are
/// inferred from the variant names — confirm against the callers that build them.
#[derive(Debug)]
pub enum SqliteError {
    /// Presumably a failure reported by the underlying database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// Presumably a failure while (de)serializing stored data.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// Presumably a column whose SQL type could not be handled; the payload describes it.
    InvalidColumnType(String),
}
19
/// A value that can be written into one column of a document table.
///
/// `SqliteVectorStore::add_rows_with_txn` binds the text produced by
/// [`ColumnValue::to_sql_string`] for every column it inserts.
pub trait ColumnValue: Send + Sync {
    /// Text form of the value, bound as a parameter of the `INSERT` statement.
    fn to_sql_string(&self) -> String;

    /// SQL type name describing this value (for example `"TEXT"`).
    fn column_type(&self) -> &'static str;
}
24
/// Declaration of a single column of a document table.
pub struct Column {
    /// Column name as used in SQL.
    name: &'static str,
    /// Type and constraints, emitted verbatim after the name in `CREATE TABLE`
    /// (for example `"TEXT PRIMARY KEY"`).
    col_type: &'static str,
    /// Whether `SqliteVectorStore::new` creates an index on this column.
    indexed: bool,
}

impl Column {
    /// Declares a non-indexed column called `name` with SQL type/constraints `col_type`.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Column {
            indexed: false,
            name,
            col_type,
        }
    }

    /// Returns this column marked for indexing at table-creation time.
    pub fn indexed(self) -> Self {
        Column {
            indexed: true,
            ..self
        }
    }
}
45
46/// Example of a document type that can be used with SqliteVectorStore
47/// ```rust
48/// use rig::Embed;
49/// use serde::Deserialize;
50/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStoreTable};
51///
52/// #[derive(Embed, Clone, Debug, Deserialize)]
53/// struct Document {
54///     id: String,
55///     #[embed]
56///     content: String,
57/// }
58///
59/// impl SqliteVectorStoreTable for Document {
60///     fn name() -> &'static str {
61///         "documents"
62///     }
63///
64///     fn schema() -> Vec<Column> {
65///         vec![
66///             Column::new("id", "TEXT PRIMARY KEY"),
67///             Column::new("content", "TEXT"),
68///         ]
69///     }
70///
71///     fn id(&self) -> String {
72///         self.id.clone()
73///     }
74///
75///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
76///         vec![
77///             ("id", Box::new(self.id.clone())),
78///             ("content", Box::new(self.content.clone())),
79///         ]
80///     }
81/// }
82/// ```
83pub trait SqliteVectorStoreTable: Send + Sync + Clone {
84    fn name() -> &'static str;
85    fn schema() -> Vec<Column>;
86    fn id(&self) -> String;
87    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
88}
89
90#[derive(Clone)]
91pub struct SqliteVectorStore<E, T>
92where
93    E: EmbeddingModel + 'static,
94    T: SqliteVectorStoreTable + 'static,
95{
96    conn: Connection,
97    _phantom: PhantomData<(E, T)>,
98}
99
100impl<E, T> SqliteVectorStore<E, T>
101where
102    E: EmbeddingModel + Clone + 'static,
103    T: SqliteVectorStoreTable + 'static,
104{
105    pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
106        let dims = embedding_model.ndims();
107        let table_name = T::name();
108        let schema = T::schema();
109
110        // Build the table schema
111        let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
112
113        // Add columns
114        let mut first = true;
115        for column in &schema {
116            if !first {
117                create_table.push(',');
118            }
119            create_table.push_str(&format!("\n    {} {}", column.name, column.col_type));
120            first = false;
121        }
122
123        create_table.push_str("\n)");
124
125        // Build index creation statements
126        let mut create_indexes = vec![format!(
127            "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
128            table_name, table_name
129        )];
130
131        // Add indexes for marked columns
132        for column in schema {
133            if column.indexed {
134                create_indexes.push(format!(
135                    "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
136                    table_name, column.name, table_name, column.name
137                ));
138            }
139        }
140
141        conn.call(move |conn| {
142            conn.execute_batch("BEGIN")?;
143
144            // Create document table
145            conn.execute_batch(&create_table)?;
146
147            // Create indexes
148            for index_stmt in create_indexes {
149                conn.execute_batch(&index_stmt)?;
150            }
151
152            // Create embeddings table
153            conn.execute_batch(&format!(
154                "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
155            ))?;
156
157            conn.execute_batch("COMMIT")?;
158            Ok(())
159        })
160        .await
161        .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
162
163        Ok(Self {
164            conn,
165            _phantom: PhantomData,
166        })
167    }
168
169    pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
170        SqliteVectorIndex::new(model, self)
171    }
172
173    pub fn add_rows_with_txn(
174        &self,
175        txn: &rusqlite::Transaction<'_>,
176        documents: Vec<(T, OneOrMany<Embedding>)>,
177    ) -> Result<i64, tokio_rusqlite::Error> {
178        info!("Adding {} documents to store", documents.len());
179        let table_name = T::name();
180        let mut last_id = 0;
181
182        for (doc, embeddings) in &documents {
183            debug!("Storing document with id {}", doc.id());
184
185            let values = doc.column_values();
186            let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
187
188            let placeholders = (1..=values.len())
189                .map(|i| format!("?{i}"))
190                .collect::<Vec<_>>();
191
192            let insert_sql = format!(
193                "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
194                table_name,
195                columns.join(", "),
196                placeholders.join(", ")
197            );
198
199            txn.execute(
200                &insert_sql,
201                rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
202            )?;
203            last_id = txn.last_insert_rowid();
204
205            let embeddings_sql =
206                format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
207
208            let mut stmt = txn.prepare(&embeddings_sql)?;
209            for (i, embedding) in embeddings.iter().enumerate() {
210                let vec = serialize_embedding(embedding);
211                debug!(
212                    "Storing embedding {} of {} (size: {} bytes)",
213                    i + 1,
214                    embeddings.len(),
215                    vec.len() * 4
216                );
217                let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
218                stmt.execute(rusqlite::params![last_id, blob])?;
219            }
220        }
221
222        Ok(last_id)
223    }
224
225    pub async fn add_rows(
226        &self,
227        documents: Vec<(T, OneOrMany<Embedding>)>,
228    ) -> Result<i64, VectorStoreError>
229    where
230        T: 'static,
231        Self: 'static,
232    {
233        let cloned = self.clone();
234
235        self.conn
236            .call(move |conn| {
237                let tx = conn.transaction()?;
238                let result = cloned.add_rows_with_txn(&tx, documents)?;
239                tx.commit()?;
240
241                Ok(result)
242            })
243            .await
244            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
245    }
246}
247
248#[derive(Clone, Default)]
249pub struct SqliteSearchFilter {
250    condition: String,
251    params: Vec<serde_json::Value>,
252}
253
254impl SearchFilter for SqliteSearchFilter {
255    type Value = serde_json::Value;
256
257    fn eq(key: String, value: Self::Value) -> Self {
258        Self {
259            condition: format!("{key} = ?"),
260            params: vec![value],
261        }
262    }
263
264    fn gt(key: String, value: Self::Value) -> Self {
265        Self {
266            condition: format!("{key} > ?"),
267            params: vec![value],
268        }
269    }
270
271    fn lt(key: String, value: Self::Value) -> Self {
272        Self {
273            condition: format!("{key} < ?"),
274            params: vec![value],
275        }
276    }
277
278    fn and(self, rhs: Self) -> Self {
279        Self {
280            condition: format!("({}) AND ({})", self.condition, rhs.condition),
281            params: self.params.into_iter().chain(rhs.params).collect(),
282        }
283    }
284
285    fn or(self, rhs: Self) -> Self {
286        Self {
287            condition: format!("({}) OR ({})", self.condition, rhs.condition),
288            params: self.params.into_iter().chain(rhs.params).collect(),
289        }
290    }
291}
292
293impl SqliteSearchFilter {
294    #[allow(clippy::should_implement_trait)]
295    pub fn not(self) -> Self {
296        Self {
297            condition: format!("NOT ({})", self.condition),
298            ..self
299        }
300    }
301
302    /// Tests whether the value at `key` is contained in the range
303    pub fn between<N>(key: String, range: RangeInclusive<N>) -> Self
304    where
305        N: Ord + rusqlite::ToSql + std::fmt::Display,
306    {
307        let lo = range.start();
308        let hi = range.end();
309
310        Self {
311            condition: format!("{key} between {lo} and {hi}"),
312            ..Default::default()
313        }
314    }
315
316    // Null checks
317    pub fn is_null(key: String) -> Self {
318        Self {
319            condition: format!("{key} is null"),
320            ..Default::default()
321        }
322    }
323
324    pub fn is_not_null(key: String) -> Self {
325        Self {
326            condition: format!("{key} is not null"),
327            ..Default::default()
328        }
329    }
330
331    // String ops
332    /// Tests whether the value at `key` satisfies the glob pattern
333    /// `pattern` should be a valid SQLite glob pattern
334    pub fn glob<'a, S>(key: String, pattern: S) -> Self
335    where
336        S: AsRef<&'a str>,
337    {
338        Self {
339            condition: format!("{key} glob {}", pattern.as_ref()),
340            ..Default::default()
341        }
342    }
343
344    /// Tests whether the value at `key` satisfies the "like" pattern
345    /// `pattern` should be a valid SQLite like pattern
346    pub fn like<'a, S>(key: String, pattern: S) -> Self
347    where
348        S: AsRef<&'a str>,
349    {
350        Self {
351            condition: format!("{key} like {}", pattern.as_ref()),
352            ..Default::default()
353        }
354    }
355}
356
357impl SqliteSearchFilter {
358    fn compile_params(self) -> Result<Vec<Value>, FilterError> {
359        let mut params = Vec::with_capacity(self.params.len());
360
361        fn convert(value: serde_json::Value) -> Result<Value, FilterError> {
362            use serde_json::Value::*;
363
364            match value {
365                Null => Ok(Value::Null),
366                Bool(b) => Ok(Value::Integer(b as i64)),
367                String(s) => Ok(Value::Text(s)),
368                Number(n) => Ok(if let Some(float) = n.as_f64() {
369                    Value::Real(float)
370                } else if let Some(int) = n.as_i64() {
371                    Value::Integer(int)
372                } else {
373                    unreachable!()
374                }),
375                Array(arr) => {
376                    let blob = serde_json::to_vec(&arr)
377                        .map_err(|e| FilterError::Serialization(e.to_string()))?;
378
379                    Ok(Value::Blob(blob))
380                }
381                Object(obj) => {
382                    let blob = serde_json::to_vec(&obj)
383                        .map_err(|e| FilterError::Serialization(e.to_string()))?;
384
385                    Ok(Value::Blob(blob))
386                }
387            }
388        }
389
390        for param in self.params.into_iter() {
391            params.push(convert(param)?)
392        }
393
394        Ok(params)
395    }
396}
397
398/// SQLite vector store implementation for Rig.
399///
400/// This crate provides a SQLite-based vector store implementation that can be used with Rig.
401/// It uses the `sqlite-vec` extension to enable vector similarity search capabilities.
402///
403/// # Example
404/// ```rust
405/// use rig::{
406///     embeddings::EmbeddingsBuilder,
407///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
408///     vector_store::VectorStoreIndex,
409///     Embed,
410/// };
411/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
412/// use serde::Deserialize;
413/// use tokio_rusqlite::Connection;
414///
415/// #[derive(Embed, Clone, Debug, Deserialize)]
416/// struct Document {
417///     id: String,
418///     #[embed]
419///     content: String,
420/// }
421///
422/// impl SqliteVectorStoreTable for Document {
423///     fn name() -> &'static str {
424///         "documents"
425///     }
426///
427///     fn schema() -> Vec<Column> {
428///         vec![
429///             Column::new("id", "TEXT PRIMARY KEY"),
430///             Column::new("content", "TEXT"),
431///         ]
432///     }
433///
434///     fn id(&self) -> String {
435///         self.id.clone()
436///     }
437///
438///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
439///         vec![
440///             ("id", Box::new(self.id.clone())),
441///             ("content", Box::new(self.content.clone())),
442///         ]
443///     }
444/// }
445///
446/// let conn = Connection::open("vector_store.db").await?;
447/// let openai_client = Client::new("YOUR_API_KEY");
448/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
449///
450/// // Initialize vector store
451/// let vector_store = SqliteVectorStore::new(conn, &model).await?;
452///
453/// // Create documents
454/// let documents = vec![
455///     Document {
456///         id: "doc1".to_string(),
457///         content: "Example document 1".to_string(),
458///     },
459///     Document {
460///         id: "doc2".to_string(),
461///         content: "Example document 2".to_string(),
462///     },
463/// ];
464///
465/// // Generate embeddings
466/// let embeddings = EmbeddingsBuilder::new(model.clone())
467///     .documents(documents)?
468///     .build()
469///     .await?;
470///
471/// // Add to vector store
472/// vector_store.add_rows(embeddings).await?;
473///
474/// // Create index and search
475/// let index = vector_store.index(model);
476/// let results = index
477///     .top_n::<Document>("Example query", 2)
478///     .await?;
479/// ```
480pub struct SqliteVectorIndex<E, T>
481where
482    E: EmbeddingModel + 'static,
483    T: SqliteVectorStoreTable + 'static,
484{
485    store: SqliteVectorStore<E, T>,
486    embedding_model: E,
487}
488
489impl<E, T> SqliteVectorIndex<E, T>
490where
491    E: EmbeddingModel + 'static,
492    T: SqliteVectorStoreTable,
493{
494    pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
495        Self {
496            store,
497            embedding_model,
498        }
499    }
500}
501
502fn build_where_clause(
503    req: &VectorSearchRequest<SqliteSearchFilter>,
504    query_vec: Vec<f32>,
505) -> Result<(String, Vec<Value>), FilterError> {
506    let thresh = req.threshold().unwrap_or(0.);
507    let thresh = SqliteSearchFilter::gt("e.distance".into(), thresh.into());
508
509    let filter = req
510        .filter()
511        .as_ref()
512        .cloned()
513        .map(|filter| thresh.clone().and(filter))
514        .unwrap_or(thresh);
515
516    let where_clause = format!(
517        "WHERE e.embedding MATCH ? AND k = ? AND {}",
518        filter.condition
519    );
520
521    let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
522    let query_vec = Value::Blob(query_vec);
523    let samples = req.samples() as u32;
524
525    let mut params = vec![query_vec, samples.into()];
526    let filter_params = filter.clone().compile_params()?;
527    params.extend(filter_params);
528
529    Ok((where_clause, params))
530}
531
532impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
533    for SqliteVectorIndex<E, T>
534{
535    type Filter = SqliteSearchFilter;
536
537    async fn top_n<D>(
538        &self,
539        req: VectorSearchRequest<SqliteSearchFilter>,
540    ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
541    where
542        D: for<'de> Deserialize<'de>,
543    {
544        tracing::debug!("Finding top {} matches for query", req.samples() as usize);
545        let embedding = self.embedding_model.embed_text(req.query()).await?;
546        let query_vec: Vec<f32> = serialize_embedding(&embedding);
547        let table_name = T::name();
548
549        // Get all column names from SqliteVectorStoreTable
550        let columns = T::schema();
551        let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
552
553        // Build SELECT statement with all columns
554        let select_cols = column_names.join(", ");
555
556        let (where_clause, params) = build_where_clause(&req, query_vec)?;
557
558        let rows = self
559            .store
560            .conn
561            .call(move |conn| {
562                let mut stmt = conn.prepare(&format!(
563                    "SELECT d.{select_cols}, e.distance
564                    FROM {table_name}_embeddings e
565                    JOIN {table_name} d ON e.rowid = d.rowid
566                    {where_clause}
567                    ORDER BY e.distance"
568                ))?;
569
570                let rows = stmt
571                    .query_map(rusqlite::params_from_iter(params), |row| {
572                        // Create a map of column names to values
573                        let mut map = serde_json::Map::new();
574                        for (i, col_name) in column_names.iter().enumerate() {
575                            let value: String = row.get(i)?;
576                            map.insert(col_name.to_string(), serde_json::Value::String(value));
577                        }
578                        let distance: f64 = row.get(column_names.len())?;
579                        let id: String = row.get(0)?; // Assuming id is always first column
580
581                        Ok((id, serde_json::Value::Object(map), distance))
582                    })?
583                    .collect::<Result<Vec<_>, _>>()?;
584                Ok(rows)
585            })
586            .await
587            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
588
589        debug!("Found {} potential matches", rows.len());
590        let mut top_n = Vec::new();
591        for (id, doc_value, distance) in rows {
592            match serde_json::from_value::<D>(doc_value) {
593                Ok(doc) => {
594                    top_n.push((distance, id, doc));
595                }
596                Err(e) => {
597                    debug!("Failed to deserialize document {}: {}", id, e);
598                    continue;
599                }
600            }
601        }
602
603        debug!("Returning {} matches", top_n.len());
604        Ok(top_n)
605    }
606
607    async fn top_n_ids(
608        &self,
609        req: VectorSearchRequest<SqliteSearchFilter>,
610    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
611        tracing::debug!(
612            "Finding top {} document IDs for query",
613            req.samples() as usize
614        );
615        let embedding = self.embedding_model.embed_text(req.query()).await?;
616        let query_vec = serialize_embedding(&embedding);
617        let table_name = T::name();
618
619        let (where_clause, params) = build_where_clause(&req, query_vec)?;
620
621        let results = self
622            .store
623            .conn
624            .call(move |conn| {
625                let mut stmt = conn.prepare(&format!(
626                    "SELECT d.id, e.distance
627                     FROM {table_name}_embeddings e
628                     JOIN {table_name} d ON e.rowid = d.rowid
629                     {where_clause}
630                     ORDER BY e.distance"
631                ))?;
632
633                let results = stmt
634                    .query_map(rusqlite::params_from_iter(params), |row| {
635                        Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?))
636                    })?
637                    .collect::<Result<Vec<_>, _>>()?;
638                Ok(results)
639            })
640            .await
641            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
642
643        debug!("Found {} matching document IDs", results.len());
644        Ok(results)
645    }
646}
647
648fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
649    embedding.vec.iter().map(|x| *x as f32).collect()
650}
651
652impl ColumnValue for String {
653    fn to_sql_string(&self) -> String {
654        self.clone()
655    }
656
657    fn column_type(&self) -> &'static str {
658        "TEXT"
659    }
660}