rig_sqlite/
lib.rs

1use rig::OneOrMany;
2use rig::embeddings::{Embedding, EmbeddingModel};
3use rig::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
4use rig::vector_store::{VectorStoreError, VectorStoreIndex};
5use rusqlite::types::Value;
6use serde::Deserialize;
7use std::marker::PhantomData;
8use std::ops::RangeInclusive;
9use tokio_rusqlite::Connection;
10use tracing::{debug, info};
11use zerocopy::IntoBytes;
12
/// Errors specific to the SQLite vector store.
#[derive(Debug)]
pub enum SqliteError {
    /// An error reported by the database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// A value could not be serialized or deserialized.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// A column had a type the store cannot handle; carries the offending type.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SqliteError::DatabaseError(err) => write!(f, "database error: {err}"),
            SqliteError::SerializationError(err) => write!(f, "serialization error: {err}"),
            SqliteError::InvalidColumnType(col_type) => {
                write!(f, "invalid column type: {col_type}")
            }
        }
    }
}

impl std::error::Error for SqliteError {
    /// Exposes the wrapped error for the two wrapping variants so error chains are kept.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            SqliteError::DatabaseError(err) | SqliteError::SerializationError(err) => {
                let source: &(dyn std::error::Error + 'static) = &**err;
                Some(source)
            }
            SqliteError::InvalidColumnType(_) => None,
        }
    }
}
19
/// A single value that a document writes into one column of its table.
///
/// Implementors must be `Send + Sync` so boxed values can cross into the
/// connection's worker closure.
pub trait ColumnValue: Send + Sync {
    /// Text form of the value; this string is what gets bound as the SQL parameter
    /// when the document row is inserted.
    fn to_sql_string(&self) -> String;

    /// SQLite type name associated with the value (for example `"TEXT"`).
    ///
    /// NOTE(review): the store does not read this when inserting rows; the table
    /// definition comes from `SqliteVectorStoreTable::schema`.
    fn column_type(&self) -> &'static str;
}
24
/// One column of the SQL table that backs a `SqliteVectorStoreTable`.
///
/// `name` and `col_type` are spliced verbatim into the `CREATE TABLE` / `CREATE INDEX`
/// statements issued by `SqliteVectorStore::new`. They must therefore be a valid SQLite
/// identifier and a valid column definition respectively (e.g. `"TEXT PRIMARY KEY"`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name as it appears in SQL.
    name: &'static str,
    /// Everything after the name in the column definition (type and constraints).
    col_type: &'static str,
    /// Whether `SqliteVectorStore::new` should create a secondary index on this column.
    indexed: bool,
}

impl Column {
    /// Creates a non-indexed column named `name` with definition `col_type`.
    #[must_use]
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Self {
            name,
            col_type,
            indexed: false,
        }
    }

    /// Marks this column to receive an index when the store's tables are created.
    #[must_use]
    pub fn indexed(mut self) -> Self {
        self.indexed = true;
        self
    }
}
45
46/// Example of a document type that can be used with SqliteVectorStore
47/// ```rust
48/// use rig::Embed;
49/// use serde::Deserialize;
50/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStoreTable};
51///
52/// #[derive(Embed, Clone, Debug, Deserialize)]
53/// struct Document {
54///     id: String,
55///     #[embed]
56///     content: String,
57/// }
58///
59/// impl SqliteVectorStoreTable for Document {
60///     fn name() -> &'static str {
61///         "documents"
62///     }
63///
64///     fn schema() -> Vec<Column> {
65///         vec![
66///             Column::new("id", "TEXT PRIMARY KEY"),
67///             Column::new("content", "TEXT"),
68///         ]
69///     }
70///
71///     fn id(&self) -> String {
72///         self.id.clone()
73///     }
74///
75///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
76///         vec![
77///             ("id", Box::new(self.id.clone())),
78///             ("content", Box::new(self.content.clone())),
79///         ]
80///     }
81/// }
82/// ```
83pub trait SqliteVectorStoreTable: Send + Sync + Clone {
84    fn name() -> &'static str;
85    fn schema() -> Vec<Column>;
86    fn id(&self) -> String;
87    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
88}
89
90#[derive(Clone)]
91pub struct SqliteVectorStore<E, T>
92where
93    E: EmbeddingModel + 'static,
94    T: SqliteVectorStoreTable + 'static,
95{
96    conn: Connection,
97    _phantom: PhantomData<(E, T)>,
98}
99
100impl<E, T> SqliteVectorStore<E, T>
101where
102    E: EmbeddingModel + 'static,
103    T: SqliteVectorStoreTable + 'static,
104{
105    pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
106        let dims = embedding_model.ndims();
107        let table_name = T::name();
108        let schema = T::schema();
109
110        // Build the table schema
111        let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
112
113        // Add columns
114        let mut first = true;
115        for column in &schema {
116            if !first {
117                create_table.push(',');
118            }
119            create_table.push_str(&format!("\n    {} {}", column.name, column.col_type));
120            first = false;
121        }
122
123        create_table.push_str("\n)");
124
125        // Build index creation statements
126        let mut create_indexes = vec![format!(
127            "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
128            table_name, table_name
129        )];
130
131        // Add indexes for marked columns
132        for column in schema {
133            if column.indexed {
134                create_indexes.push(format!(
135                    "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
136                    table_name, column.name, table_name, column.name
137                ));
138            }
139        }
140
141        conn.call(move |conn| {
142            conn.execute_batch("BEGIN")?;
143
144            // Create document table
145            conn.execute_batch(&create_table)?;
146
147            // Create indexes
148            for index_stmt in create_indexes {
149                conn.execute_batch(&index_stmt)?;
150            }
151
152            // Create embeddings table
153            conn.execute_batch(&format!(
154                "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
155            ))?;
156
157            conn.execute_batch("COMMIT")?;
158            Ok(())
159        })
160        .await
161        .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
162
163        Ok(Self {
164            conn,
165            _phantom: PhantomData,
166        })
167    }
168
169    pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
170        SqliteVectorIndex::new(model, self)
171    }
172
173    pub fn add_rows_with_txn(
174        &self,
175        txn: &rusqlite::Transaction<'_>,
176        documents: Vec<(T, OneOrMany<Embedding>)>,
177    ) -> Result<i64, tokio_rusqlite::Error> {
178        info!("Adding {} documents to store", documents.len());
179        let table_name = T::name();
180        let mut last_id = 0;
181
182        for (doc, embeddings) in &documents {
183            debug!("Storing document with id {}", doc.id());
184
185            let values = doc.column_values();
186            let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
187
188            let placeholders = (1..=values.len())
189                .map(|i| format!("?{i}"))
190                .collect::<Vec<_>>();
191
192            let insert_sql = format!(
193                "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
194                table_name,
195                columns.join(", "),
196                placeholders.join(", ")
197            );
198
199            txn.execute(
200                &insert_sql,
201                rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
202            )?;
203            last_id = txn.last_insert_rowid();
204
205            let embeddings_sql =
206                format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
207
208            let mut stmt = txn.prepare(&embeddings_sql)?;
209            for (i, embedding) in embeddings.iter().enumerate() {
210                let vec = serialize_embedding(embedding);
211                debug!(
212                    "Storing embedding {} of {} (size: {} bytes)",
213                    i + 1,
214                    embeddings.len(),
215                    vec.len() * 4
216                );
217                let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
218                stmt.execute(rusqlite::params![last_id, blob])?;
219            }
220        }
221
222        Ok(last_id)
223    }
224
225    pub async fn add_rows(
226        &self,
227        documents: Vec<(T, OneOrMany<Embedding>)>,
228    ) -> Result<i64, VectorStoreError> {
229        let documents = documents.clone();
230        let this = self.clone();
231
232        self.conn
233            .call(move |conn| {
234                let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?;
235                let result = this.add_rows_with_txn(&tx, documents)?;
236                tx.commit().map_err(tokio_rusqlite::Error::from)?;
237                Ok(result)
238            })
239            .await
240            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
241    }
242}
243
244#[derive(Clone, Default)]
245pub struct SqliteSearchFilter {
246    condition: String,
247    params: Vec<serde_json::Value>,
248}
249
250impl SearchFilter for SqliteSearchFilter {
251    type Value = serde_json::Value;
252
253    fn eq(key: String, value: Self::Value) -> Self {
254        Self {
255            condition: format!("{key} = ?"),
256            params: vec![value],
257        }
258    }
259
260    fn gt(key: String, value: Self::Value) -> Self {
261        Self {
262            condition: format!("{key} > ?"),
263            params: vec![value],
264        }
265    }
266
267    fn lt(key: String, value: Self::Value) -> Self {
268        Self {
269            condition: format!("{key} < ?"),
270            params: vec![value],
271        }
272    }
273
274    fn and(self, rhs: Self) -> Self {
275        Self {
276            condition: format!("({}) AND ({})", self.condition, rhs.condition),
277            params: self.params.into_iter().chain(rhs.params).collect(),
278        }
279    }
280
281    fn or(self, rhs: Self) -> Self {
282        Self {
283            condition: format!("({}) OR ({})", self.condition, rhs.condition),
284            params: self.params.into_iter().chain(rhs.params).collect(),
285        }
286    }
287}
288
289impl SqliteSearchFilter {
290    #[allow(clippy::should_implement_trait)]
291    pub fn not(self) -> Self {
292        Self {
293            condition: format!("NOT ({})", self.condition),
294            ..self
295        }
296    }
297
298    /// Tests whether the value at `key` is contained in the range
299    pub fn between<N>(key: String, range: RangeInclusive<N>) -> Self
300    where
301        N: Ord + rusqlite::ToSql + std::fmt::Display,
302    {
303        let lo = range.start();
304        let hi = range.end();
305
306        Self {
307            condition: format!("{key} between {lo} and {hi}"),
308            ..Default::default()
309        }
310    }
311
312    // Null checks
313    pub fn is_null(key: String) -> Self {
314        Self {
315            condition: format!("{key} is null"),
316            ..Default::default()
317        }
318    }
319
320    pub fn is_not_null(key: String) -> Self {
321        Self {
322            condition: format!("{key} is not null"),
323            ..Default::default()
324        }
325    }
326
327    // String ops
328    /// Tests whether the value at `key` satisfies the glob pattern
329    /// `pattern` should be a valid SQLite glob pattern
330    pub fn glob<'a, S>(key: String, pattern: S) -> Self
331    where
332        S: AsRef<&'a str>,
333    {
334        Self {
335            condition: format!("{key} glob {}", pattern.as_ref()),
336            ..Default::default()
337        }
338    }
339
340    /// Tests whether the value at `key` satisfies the "like" pattern
341    /// `pattern` should be a valid SQLite like pattern
342    pub fn like<'a, S>(key: String, pattern: S) -> Self
343    where
344        S: AsRef<&'a str>,
345    {
346        Self {
347            condition: format!("{key} like {}", pattern.as_ref()),
348            ..Default::default()
349        }
350    }
351}
352
353impl SqliteSearchFilter {
354    fn compile_params(self) -> Result<Vec<Value>, FilterError> {
355        let mut params = Vec::with_capacity(self.params.len());
356
357        fn convert(value: serde_json::Value) -> Result<Value, FilterError> {
358            use serde_json::Value::*;
359
360            match value {
361                Null => Ok(Value::Null),
362                Bool(b) => Ok(Value::Integer(b as i64)),
363                String(s) => Ok(Value::Text(s)),
364                Number(n) => Ok(if let Some(float) = n.as_f64() {
365                    Value::Real(float)
366                } else if let Some(int) = n.as_i64() {
367                    Value::Integer(int)
368                } else {
369                    unreachable!()
370                }),
371                Array(arr) => {
372                    let blob = serde_json::to_vec(&arr)
373                        .map_err(|e| FilterError::Serialization(e.to_string()))?;
374
375                    Ok(Value::Blob(blob))
376                }
377                Object(obj) => {
378                    let blob = serde_json::to_vec(&obj)
379                        .map_err(|e| FilterError::Serialization(e.to_string()))?;
380
381                    Ok(Value::Blob(blob))
382                }
383            }
384        }
385
386        for param in self.params.into_iter() {
387            params.push(convert(param)?)
388        }
389
390        Ok(params)
391    }
392}
393
394/// SQLite vector store implementation for Rig.
395///
396/// This crate provides a SQLite-based vector store implementation that can be used with Rig.
397/// It uses the `sqlite-vec` extension to enable vector similarity search capabilities.
398///
399/// # Example
400/// ```rust
401/// use rig::{
402///     embeddings::EmbeddingsBuilder,
403///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
404///     vector_store::VectorStoreIndex,
405///     Embed,
406/// };
407/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
408/// use serde::Deserialize;
409/// use tokio_rusqlite::Connection;
410///
411/// #[derive(Embed, Clone, Debug, Deserialize)]
412/// struct Document {
413///     id: String,
414///     #[embed]
415///     content: String,
416/// }
417///
418/// impl SqliteVectorStoreTable for Document {
419///     fn name() -> &'static str {
420///         "documents"
421///     }
422///
423///     fn schema() -> Vec<Column> {
424///         vec![
425///             Column::new("id", "TEXT PRIMARY KEY"),
426///             Column::new("content", "TEXT"),
427///         ]
428///     }
429///
430///     fn id(&self) -> String {
431///         self.id.clone()
432///     }
433///
434///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
435///         vec![
436///             ("id", Box::new(self.id.clone())),
437///             ("content", Box::new(self.content.clone())),
438///         ]
439///     }
440/// }
441///
442/// let conn = Connection::open("vector_store.db").await?;
443/// let openai_client = Client::new("YOUR_API_KEY");
444/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
445///
446/// // Initialize vector store
447/// let vector_store = SqliteVectorStore::new(conn, &model).await?;
448///
449/// // Create documents
450/// let documents = vec![
451///     Document {
452///         id: "doc1".to_string(),
453///         content: "Example document 1".to_string(),
454///     },
455///     Document {
456///         id: "doc2".to_string(),
457///         content: "Example document 2".to_string(),
458///     },
459/// ];
460///
461/// // Generate embeddings
462/// let embeddings = EmbeddingsBuilder::new(model.clone())
463///     .documents(documents)?
464///     .build()
465///     .await?;
466///
467/// // Add to vector store
468/// vector_store.add_rows(embeddings).await?;
469///
470/// // Create index and search
471/// let index = vector_store.index(model);
472/// let results = index
473///     .top_n::<Document>("Example query", 2)
474///     .await?;
475/// ```
476pub struct SqliteVectorIndex<E, T>
477where
478    E: EmbeddingModel + 'static,
479    T: SqliteVectorStoreTable + 'static,
480{
481    store: SqliteVectorStore<E, T>,
482    embedding_model: E,
483}
484
485impl<E, T> SqliteVectorIndex<E, T>
486where
487    E: EmbeddingModel + 'static,
488    T: SqliteVectorStoreTable,
489{
490    pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
491        Self {
492            store,
493            embedding_model,
494        }
495    }
496}
497
498fn build_where_clause(
499    req: &VectorSearchRequest<SqliteSearchFilter>,
500    query_vec: Vec<f32>,
501) -> Result<(String, Vec<Value>), FilterError> {
502    let thresh = req.threshold().unwrap_or(0.);
503    let thresh = SqliteSearchFilter::gt("e.distance".into(), thresh.into());
504
505    let filter = req
506        .filter()
507        .as_ref()
508        .cloned()
509        .map(|filter| thresh.clone().and(filter))
510        .unwrap_or(thresh);
511
512    let where_clause = format!(
513        "WHERE e.embedding MATCH ? AND k = ? AND {}",
514        filter.condition
515    );
516
517    let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
518    let query_vec = Value::Blob(query_vec);
519    let samples = req.samples() as u32;
520
521    let mut params = vec![query_vec, samples.into()];
522    let filter_params = filter.clone().compile_params()?;
523    params.extend(filter_params);
524
525    Ok((where_clause, params))
526}
527
528impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
529    for SqliteVectorIndex<E, T>
530{
531    type Filter = SqliteSearchFilter;
532
533    async fn top_n<D>(
534        &self,
535        req: VectorSearchRequest<SqliteSearchFilter>,
536    ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
537    where
538        D: for<'de> Deserialize<'de>,
539    {
540        tracing::debug!("Finding top {} matches for query", req.samples() as usize);
541        let embedding = self.embedding_model.embed_text(req.query()).await?;
542        let query_vec: Vec<f32> = serialize_embedding(&embedding);
543        let table_name = T::name();
544
545        // Get all column names from SqliteVectorStoreTable
546        let columns = T::schema();
547        let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
548
549        // Build SELECT statement with all columns
550        let select_cols = column_names.join(", ");
551
552        let (where_clause, params) = build_where_clause(&req, query_vec)?;
553
554        let rows = self
555            .store
556            .conn
557            .call(move |conn| {
558                let mut stmt = conn.prepare(&format!(
559                    "SELECT d.{select_cols}, e.distance
560                    FROM {table_name}_embeddings e
561                    JOIN {table_name} d ON e.rowid = d.rowid
562                    {where_clause}
563                    ORDER BY e.distance"
564                ))?;
565
566                dbg!(&stmt);
567
568                let rows = stmt
569                    .query_map(rusqlite::params_from_iter(params), |row| {
570                        // Create a map of column names to values
571                        let mut map = serde_json::Map::new();
572                        for (i, col_name) in column_names.iter().enumerate() {
573                            let value: String = row.get(i)?;
574                            map.insert(col_name.to_string(), serde_json::Value::String(value));
575                        }
576                        let distance: f64 = row.get(column_names.len())?;
577                        let id: String = row.get(0)?; // Assuming id is always first column
578
579                        Ok((id, serde_json::Value::Object(map), distance))
580                    })?
581                    .collect::<Result<Vec<_>, _>>()?;
582                Ok(rows)
583            })
584            .await
585            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
586
587        debug!("Found {} potential matches", rows.len());
588        let mut top_n = Vec::new();
589        for (id, doc_value, distance) in rows {
590            match serde_json::from_value::<D>(doc_value) {
591                Ok(doc) => {
592                    top_n.push((distance, id, doc));
593                }
594                Err(e) => {
595                    debug!("Failed to deserialize document {}: {}", id, e);
596                    continue;
597                }
598            }
599        }
600
601        debug!("Returning {} matches", top_n.len());
602        Ok(top_n)
603    }
604
605    async fn top_n_ids(
606        &self,
607        req: VectorSearchRequest<SqliteSearchFilter>,
608    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
609        tracing::debug!(
610            "Finding top {} document IDs for query",
611            req.samples() as usize
612        );
613        let embedding = self.embedding_model.embed_text(req.query()).await?;
614        let query_vec = serialize_embedding(&embedding);
615        let table_name = T::name();
616
617        let (where_clause, params) = build_where_clause(&req, query_vec)?;
618
619        let results = self
620            .store
621            .conn
622            .call(move |conn| {
623                let mut stmt = conn.prepare(&format!(
624                    "SELECT d.id, e.distance
625                     FROM {table_name}_embeddings e
626                     JOIN {table_name} d ON e.rowid = d.rowid
627                     {where_clause}
628                     ORDER BY e.distance"
629                ))?;
630
631                dbg!(&stmt);
632
633                let results = stmt
634                    .query_map(rusqlite::params_from_iter(params), |row| {
635                        Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?))
636                    })?
637                    .collect::<Result<Vec<_>, _>>()?;
638                Ok(results)
639            })
640            .await
641            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
642
643        debug!("Found {} matching document IDs", results.len());
644        Ok(results)
645    }
646}
647
648fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
649    embedding.vec.iter().map(|x| *x as f32).collect()
650}
651
652impl ColumnValue for String {
653    fn to_sql_string(&self) -> String {
654        self.clone()
655    }
656
657    fn column_type(&self) -> &'static str {
658        "TEXT"
659    }
660}