rig_sqlite/
lib.rs

1use rig::OneOrMany;
2use rig::embeddings::{Embedding, EmbeddingModel};
3use rig::vector_store::{VectorStoreError, VectorStoreIndex};
4use serde::Deserialize;
5use std::marker::PhantomData;
6use tokio_rusqlite::Connection;
7use tracing::{debug, info};
8use zerocopy::IntoBytes;
9
/// Errors specific to the SQLite vector store.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so values can be boxed as
/// `dyn Error`, wrapped by other error types, and propagated with `?`.
#[derive(Debug)]
pub enum SqliteError {
    /// A failure reported by the database layer; the boxed error is the underlying cause.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// A failure while converting a value to or from its stored form; the boxed error is
    /// the underlying cause.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// A column type that is not accepted; the payload carries the offending type text.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DatabaseError(err) => write!(f, "database error: {err}"),
            Self::SerializationError(err) => write!(f, "serialization error: {err}"),
            Self::InvalidColumnType(col_type) => write!(f, "invalid column type: {col_type}"),
        }
    }
}

impl std::error::Error for SqliteError {
    /// Exposes the wrapped error (if any) so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::DatabaseError(err) | Self::SerializationError(err) => {
                // Drop the `Send + Sync` auto traits to match the signature of `source`.
                let inner: &(dyn std::error::Error + 'static) = &**err;
                Some(inner)
            }
            Self::InvalidColumnType(_) => None,
        }
    }
}
16
/// A value that can be stored in a single column of a document row.
///
/// Implementors must be `Send + Sync`. The only implementation in this file is for
/// `String`.
pub trait ColumnValue: Send + Sync {
    /// Renders the value as the string that the store binds as the SQL parameter for this
    /// column when inserting a row.
    fn to_sql_string(&self) -> String;
    /// Returns the SQL type name for this value (the `String` implementation returns
    /// `"TEXT"`). Nothing in this file calls this method.
    fn column_type(&self) -> &'static str;
}
21
/// Description of one column of a document table.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so that schemas (`Vec<Column>`) can be
/// logged, copied and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name, written verbatim into the generated `CREATE TABLE` / `CREATE INDEX` SQL.
    name: &'static str,
    /// Column definition that follows the name in `CREATE TABLE`
    /// (for example `"TEXT PRIMARY KEY"`), written verbatim.
    col_type: &'static str,
    /// When `true`, the store also creates an index on this column.
    indexed: bool,
}
27
28impl Column {
29    pub fn new(name: &'static str, col_type: &'static str) -> Self {
30        Self {
31            name,
32            col_type,
33            indexed: false,
34        }
35    }
36
37    pub fn indexed(mut self) -> Self {
38        self.indexed = true;
39        self
40    }
41}
42
/// Describes how a document type maps onto a SQLite table used by `SqliteVectorStore`.
///
/// Implementors must be `Send + Sync + Clone`.
///
/// # Example
///
/// A document type that can be used with `SqliteVectorStore`:
///
/// ```rust
/// use rig::Embed;
/// use serde::Deserialize;
/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStoreTable};
///
/// #[derive(Embed, Clone, Debug, Deserialize)]
/// struct Document {
///     id: String,
///     #[embed]
///     content: String,
/// }
///
/// impl SqliteVectorStoreTable for Document {
///     fn name() -> &'static str {
///         "documents"
///     }
///
///     fn schema() -> Vec<Column> {
///         vec![
///             Column::new("id", "TEXT PRIMARY KEY"),
///             Column::new("content", "TEXT"),
///         ]
///     }
///
///     fn id(&self) -> String {
///         self.id.clone()
///     }
///
///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
///         vec![
///             ("id", Box::new(self.id.clone())),
///             ("content", Box::new(self.content.clone())),
///         ]
///     }
/// }
/// ```
pub trait SqliteVectorStoreTable: Send + Sync + Clone {
    /// Name of the document table. The store interpolates it directly into SQL text and
    /// derives the embeddings table name from it as `{name}_embeddings`.
    fn name() -> &'static str;
    /// Columns of the document table, in declaration order.
    ///
    /// The store always creates an index on a column named `id`, and the search code in
    /// this file treats the first column as the document id and reads every column as a
    /// string.
    fn schema() -> Vec<Column>;
    /// Identifier of this document. In this file it is only used in log output when the
    /// document is stored.
    fn id(&self) -> String;
    /// `(column name, value)` pairs for this document. The names form the column list of
    /// the generated `INSERT` statement and each value is bound through
    /// `ColumnValue::to_sql_string`.
    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
}
86
/// SQLite-backed store for documents of type `T` and their embeddings.
///
/// `E` is the embedding model type the store is associated with; neither `E` nor `T` is
/// stored as a value.
#[derive(Clone)]
pub struct SqliteVectorStore<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> {
    /// Connection on which all statements for this store are executed.
    conn: Connection,
    /// Zero-sized marker that ties the store to its `E` and `T` type parameters.
    _phantom: PhantomData<(E, T)>,
}
92
93impl<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> SqliteVectorStore<E, T> {
94    pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
95        let dims = embedding_model.ndims();
96        let table_name = T::name();
97        let schema = T::schema();
98
99        // Build the table schema
100        let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
101
102        // Add columns
103        let mut first = true;
104        for column in &schema {
105            if !first {
106                create_table.push(',');
107            }
108            create_table.push_str(&format!("\n    {} {}", column.name, column.col_type));
109            first = false;
110        }
111
112        create_table.push_str("\n)");
113
114        // Build index creation statements
115        let mut create_indexes = vec![format!(
116            "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
117            table_name, table_name
118        )];
119
120        // Add indexes for marked columns
121        for column in schema {
122            if column.indexed {
123                create_indexes.push(format!(
124                    "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
125                    table_name, column.name, table_name, column.name
126                ));
127            }
128        }
129
130        conn.call(move |conn| {
131            conn.execute_batch("BEGIN")?;
132
133            // Create document table
134            conn.execute_batch(&create_table)?;
135
136            // Create indexes
137            for index_stmt in create_indexes {
138                conn.execute_batch(&index_stmt)?;
139            }
140
141            // Create embeddings table
142            conn.execute_batch(&format!(
143                "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
144            ))?;
145
146            conn.execute_batch("COMMIT")?;
147            Ok(())
148        })
149        .await
150        .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
151
152        Ok(Self {
153            conn,
154            _phantom: PhantomData,
155        })
156    }
157
158    pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
159        SqliteVectorIndex::new(model, self)
160    }
161
162    pub fn add_rows_with_txn(
163        &self,
164        txn: &rusqlite::Transaction<'_>,
165        documents: Vec<(T, OneOrMany<Embedding>)>,
166    ) -> Result<i64, tokio_rusqlite::Error> {
167        info!("Adding {} documents to store", documents.len());
168        let table_name = T::name();
169        let mut last_id = 0;
170
171        for (doc, embeddings) in &documents {
172            debug!("Storing document with id {}", doc.id());
173
174            let values = doc.column_values();
175            let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
176
177            let placeholders = (1..=values.len())
178                .map(|i| format!("?{i}"))
179                .collect::<Vec<_>>();
180
181            let insert_sql = format!(
182                "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
183                table_name,
184                columns.join(", "),
185                placeholders.join(", ")
186            );
187
188            txn.execute(
189                &insert_sql,
190                rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
191            )?;
192            last_id = txn.last_insert_rowid();
193
194            let embeddings_sql =
195                format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
196
197            let mut stmt = txn.prepare(&embeddings_sql)?;
198            for (i, embedding) in embeddings.iter().enumerate() {
199                let vec = serialize_embedding(embedding);
200                debug!(
201                    "Storing embedding {} of {} (size: {} bytes)",
202                    i + 1,
203                    embeddings.len(),
204                    vec.len() * 4
205                );
206                let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
207                stmt.execute(rusqlite::params![last_id, blob])?;
208            }
209        }
210
211        Ok(last_id)
212    }
213
214    pub async fn add_rows(
215        &self,
216        documents: Vec<(T, OneOrMany<Embedding>)>,
217    ) -> Result<i64, VectorStoreError> {
218        let documents = documents.clone();
219        let this = self.clone();
220
221        self.conn
222            .call(move |conn| {
223                let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?;
224                let result = this.add_rows_with_txn(&tx, documents)?;
225                tx.commit().map_err(tokio_rusqlite::Error::from)?;
226                Ok(result)
227            })
228            .await
229            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
230    }
231}
232
/// Searchable index over a [`SqliteVectorStore`]: the store paired with the embedding
/// model used to embed queries.
///
/// # About this crate
///
/// SQLite vector store implementation for Rig.
///
/// This crate provides a SQLite-based vector store implementation that can be used with Rig.
/// It uses the `sqlite-vec` extension to enable vector similarity search capabilities.
///
/// # Example
/// ```rust
/// use rig::{
///     embeddings::EmbeddingsBuilder,
///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
///     vector_store::VectorStoreIndex,
///     Embed,
/// };
/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
/// use serde::Deserialize;
/// use tokio_rusqlite::Connection;
///
/// #[derive(Embed, Clone, Debug, Deserialize)]
/// struct Document {
///     id: String,
///     #[embed]
///     content: String,
/// }
///
/// impl SqliteVectorStoreTable for Document {
///     fn name() -> &'static str {
///         "documents"
///     }
///
///     fn schema() -> Vec<Column> {
///         vec![
///             Column::new("id", "TEXT PRIMARY KEY"),
///             Column::new("content", "TEXT"),
///         ]
///     }
///
///     fn id(&self) -> String {
///         self.id.clone()
///     }
///
///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
///         vec![
///             ("id", Box::new(self.id.clone())),
///             ("content", Box::new(self.content.clone())),
///         ]
///     }
/// }
///
/// let conn = Connection::open("vector_store.db").await?;
/// let openai_client = Client::new("YOUR_API_KEY");
/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
///
/// // Initialize vector store
/// let vector_store = SqliteVectorStore::new(conn, &model).await?;
///
/// // Create documents
/// let documents = vec![
///     Document {
///         id: "doc1".to_string(),
///         content: "Example document 1".to_string(),
///     },
///     Document {
///         id: "doc2".to_string(),
///         content: "Example document 2".to_string(),
///     },
/// ];
///
/// // Generate embeddings
/// let embeddings = EmbeddingsBuilder::new(model.clone())
///     .documents(documents)?
///     .build()
///     .await?;
///
/// // Add to vector store
/// vector_store.add_rows(embeddings).await?;
///
/// // Create index and search
/// let index = vector_store.index(model);
/// let results = index
///     .top_n::<Document>("Example query", 2)
///     .await?;
/// ```
pub struct SqliteVectorIndex<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> {
    /// Store whose connection is used to run the similarity queries.
    store: SqliteVectorStore<E, T>,
    /// Model used to embed the query text before searching.
    embedding_model: E,
}
319
320impl<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable> SqliteVectorIndex<E, T> {
321    pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
322        Self {
323            store,
324            embedding_model,
325        }
326    }
327}
328
329impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
330    for SqliteVectorIndex<E, T>
331{
332    async fn top_n<D: for<'a> Deserialize<'a>>(
333        &self,
334        query: &str,
335        n: usize,
336    ) -> Result<Vec<(f64, String, D)>, VectorStoreError> {
337        debug!("Finding top {} matches for query", n);
338        let embedding = self.embedding_model.embed_text(query).await?;
339        let query_vec: Vec<f32> = serialize_embedding(&embedding);
340        let table_name = T::name();
341
342        // Get all column names from SqliteVectorStoreTable
343        let columns = T::schema();
344        let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
345
346        let rows = self
347            .store
348            .conn
349            .call(move |conn| {
350                // Build SELECT statement with all columns
351                let select_cols = column_names.join(", ");
352                let mut stmt = conn.prepare(&format!(
353                    "SELECT d.{select_cols}, e.distance 
354                    FROM {table_name}_embeddings e
355                    JOIN {table_name} d ON e.rowid = d.rowid
356                    WHERE e.embedding MATCH ?1 AND k = ?2
357                    ORDER BY e.distance"
358                ))?;
359
360                let rows = stmt
361                    .query_map(rusqlite::params![query_vec.as_bytes().to_vec(), n], |row| {
362                        // Create a map of column names to values
363                        let mut map = serde_json::Map::new();
364                        for (i, col_name) in column_names.iter().enumerate() {
365                            let value: String = row.get(i)?;
366                            map.insert(col_name.to_string(), serde_json::Value::String(value));
367                        }
368                        let distance: f64 = row.get(column_names.len())?;
369                        let id: String = row.get(0)?; // Assuming id is always first column
370
371                        Ok((id, serde_json::Value::Object(map), distance))
372                    })?
373                    .collect::<Result<Vec<_>, _>>()?;
374                Ok(rows)
375            })
376            .await
377            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
378
379        debug!("Found {} potential matches", rows.len());
380        let mut top_n = Vec::new();
381        for (id, doc_value, distance) in rows {
382            match serde_json::from_value::<D>(doc_value) {
383                Ok(doc) => {
384                    top_n.push((distance, id, doc));
385                }
386                Err(e) => {
387                    debug!("Failed to deserialize document {}: {}", id, e);
388                    continue;
389                }
390            }
391        }
392
393        debug!("Returning {} matches", top_n.len());
394        Ok(top_n)
395    }
396
397    async fn top_n_ids(
398        &self,
399        query: &str,
400        n: usize,
401    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
402        debug!("Finding top {} document IDs for query", n);
403        let embedding = self.embedding_model.embed_text(query).await?;
404        let query_vec = serialize_embedding(&embedding);
405        let table_name = T::name();
406
407        let results = self
408            .store
409            .conn
410            .call(move |conn| {
411                let mut stmt = conn.prepare(&format!(
412                    "SELECT d.id, e.distance 
413                     FROM {table_name}_embeddings e
414                     JOIN {table_name} d ON e.rowid = d.rowid
415                     WHERE e.embedding MATCH ?1 AND k = ?2
416                     ORDER BY e.distance"
417                ))?;
418
419                let results = stmt
420                    .query_map(
421                        rusqlite::params![
422                            query_vec
423                                .iter()
424                                .flat_map(|x| x.to_le_bytes())
425                                .collect::<Vec<u8>>(),
426                            n
427                        ],
428                        |row| Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?)),
429                    )?
430                    .collect::<Result<Vec<_>, _>>()?;
431                Ok(results)
432            })
433            .await
434            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
435
436        debug!("Found {} matching document IDs", results.len());
437        Ok(results)
438    }
439}
440
441fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
442    embedding.vec.iter().map(|x| *x as f32).collect()
443}
444
445impl ColumnValue for String {
446    fn to_sql_string(&self) -> String {
447        self.clone()
448    }
449
450    fn column_type(&self) -> &'static str {
451        "TEXT"
452    }
453}