// rig_sqlite/lib.rs

1use rig::embeddings::{Embedding, EmbeddingModel};
2use rig::vector_store::{VectorStoreError, VectorStoreIndex};
3use rig::OneOrMany;
4use serde::Deserialize;
5use std::marker::PhantomData;
6use tokio_rusqlite::Connection;
7use tracing::{debug, info};
8use zerocopy::IntoBytes;
9
/// Errors specific to the SQLite vector store.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so a value can be
/// boxed into a `Box<dyn std::error::Error + Send + Sync>` and propagated with `?`.
/// For the two wrapping variants the inner error is exposed through
/// [`std::error::Error::source`].
#[derive(Debug)]
pub enum SqliteError {
    /// Wraps an underlying error attributed to the database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// Wraps an underlying error attributed to (de)serialization.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// A column type problem; the payload is a free-form description.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DatabaseError(inner) => write!(f, "database error: {}", inner),
            Self::SerializationError(inner) => write!(f, "serialization error: {}", inner),
            Self::InvalidColumnType(detail) => write!(f, "invalid column type: {}", detail),
        }
    }
}

impl std::error::Error for SqliteError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::DatabaseError(inner) | Self::SerializationError(inner) => {
                // Drop the `Send + Sync` auto traits to match the signature of `source`.
                let cause: &(dyn std::error::Error + 'static) = &**inner;
                Some(cause)
            }
            Self::InvalidColumnType(_) => None,
        }
    }
}
16
/// A value that can be written into one column of a document row.
///
/// The store binds the result of [`ColumnValue::to_sql_string`] as the SQL
/// parameter for the column when inserting a row. The trait is used as
/// `Box<dyn ColumnValue>`, so it must stay object safe.
pub trait ColumnValue: Send + Sync {
    /// Returns the textual form of the value that is bound as the SQL parameter.
    fn to_sql_string(&self) -> String;

    /// Returns the SQL type name for this kind of value (for example `"TEXT"`).
    fn column_type(&self) -> &'static str;
}
21
/// Declaration of one column of a document table.
///
/// Built with [`Column::new`] and optionally marked for indexing with
/// [`Column::indexed`]:
///
/// ```text
/// Column::new("title", "TEXT").indexed()
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name, interpolated verbatim into the generated SQL.
    name: &'static str,
    /// Column type and constraints, e.g. `"TEXT PRIMARY KEY"`; also interpolated verbatim.
    col_type: &'static str,
    /// Whether a secondary index should be created for this column.
    indexed: bool,
}

impl Column {
    /// Creates a non-indexed column with the given name and SQL type/constraint text.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Self {
            name,
            col_type,
            indexed: false,
        }
    }

    /// Consumes the column and returns it flagged for index creation.
    #[must_use = "`indexed` returns the modified column; the original is consumed"]
    pub fn indexed(self) -> Self {
        Self {
            indexed: true,
            ..self
        }
    }
}
42
43/// Example of a document type that can be used with SqliteVectorStore
44/// ```rust
45/// use rig::Embed;
46/// use serde::Deserialize;
47/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStoreTable};
48///
49/// #[derive(Embed, Clone, Debug, Deserialize)]
50/// struct Document {
51///     id: String,
52///     #[embed]
53///     content: String,
54/// }
55///
56/// impl SqliteVectorStoreTable for Document {
57///     fn name() -> &'static str {
58///         "documents"
59///     }
60///
61///     fn schema() -> Vec<Column> {
62///         vec![
63///             Column::new("id", "TEXT PRIMARY KEY"),
64///             Column::new("content", "TEXT"),
65///         ]
66///     }
67///
68///     fn id(&self) -> String {
69///         self.id.clone()
70///     }
71///
72///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
73///         vec![
74///             ("id", Box::new(self.id.clone())),
75///             ("content", Box::new(self.content.clone())),
76///         ]
77///     }
78/// }
79/// ```
80pub trait SqliteVectorStoreTable: Send + Sync + Clone {
81    fn name() -> &'static str;
82    fn schema() -> Vec<Column>;
83    fn id(&self) -> String;
84    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
85}
86
87#[derive(Clone)]
88pub struct SqliteVectorStore<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> {
89    conn: Connection,
90    _phantom: PhantomData<(E, T)>,
91}
92
93impl<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> SqliteVectorStore<E, T> {
94    pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
95        let dims = embedding_model.ndims();
96        let table_name = T::name();
97        let schema = T::schema();
98
99        // Build the table schema
100        let mut create_table = format!("CREATE TABLE IF NOT EXISTS {} (", table_name);
101
102        // Add columns
103        let mut first = true;
104        for column in &schema {
105            if !first {
106                create_table.push(',');
107            }
108            create_table.push_str(&format!("\n    {} {}", column.name, column.col_type));
109            first = false;
110        }
111
112        create_table.push_str("\n)");
113
114        // Build index creation statements
115        let mut create_indexes = vec![format!(
116            "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
117            table_name, table_name
118        )];
119
120        // Add indexes for marked columns
121        for column in schema {
122            if column.indexed {
123                create_indexes.push(format!(
124                    "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
125                    table_name, column.name, table_name, column.name
126                ));
127            }
128        }
129
130        conn.call(move |conn| {
131            conn.execute_batch("BEGIN")?;
132
133            // Create document table
134            conn.execute_batch(&create_table)?;
135
136            // Create indexes
137            for index_stmt in create_indexes {
138                conn.execute_batch(&index_stmt)?;
139            }
140
141            // Create embeddings table
142            conn.execute_batch(&format!(
143                "CREATE VIRTUAL TABLE IF NOT EXISTS {}_embeddings USING vec0(embedding float[{}])",
144                table_name, dims
145            ))?;
146
147            conn.execute_batch("COMMIT")?;
148            Ok(())
149        })
150        .await
151        .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
152
153        Ok(Self {
154            conn,
155            _phantom: PhantomData,
156        })
157    }
158
159    pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
160        SqliteVectorIndex::new(model, self)
161    }
162
163    pub fn add_rows_with_txn(
164        &self,
165        txn: &rusqlite::Transaction<'_>,
166        documents: Vec<(T, OneOrMany<Embedding>)>,
167    ) -> Result<i64, tokio_rusqlite::Error> {
168        info!("Adding {} documents to store", documents.len());
169        let table_name = T::name();
170        let mut last_id = 0;
171
172        for (doc, embeddings) in &documents {
173            debug!("Storing document with id {}", doc.id());
174
175            let values = doc.column_values();
176            let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
177
178            let placeholders = (1..=values.len())
179                .map(|i| format!("?{}", i))
180                .collect::<Vec<_>>();
181
182            let insert_sql = format!(
183                "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
184                table_name,
185                columns.join(", "),
186                placeholders.join(", ")
187            );
188
189            txn.execute(
190                &insert_sql,
191                rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
192            )?;
193            last_id = txn.last_insert_rowid();
194
195            let embeddings_sql = format!(
196                "INSERT INTO {}_embeddings (rowid, embedding) VALUES (?1, ?2)",
197                table_name
198            );
199
200            let mut stmt = txn.prepare(&embeddings_sql)?;
201            for (i, embedding) in embeddings.iter().enumerate() {
202                let vec = serialize_embedding(embedding);
203                debug!(
204                    "Storing embedding {} of {} (size: {} bytes)",
205                    i + 1,
206                    embeddings.len(),
207                    vec.len() * 4
208                );
209                let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
210                stmt.execute(rusqlite::params![last_id, blob])?;
211            }
212        }
213
214        Ok(last_id)
215    }
216
217    pub async fn add_rows(
218        &self,
219        documents: Vec<(T, OneOrMany<Embedding>)>,
220    ) -> Result<i64, VectorStoreError> {
221        let documents = documents.clone();
222        let this = self.clone();
223
224        self.conn
225            .call(move |conn| {
226                let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?;
227                let result = this.add_rows_with_txn(&tx, documents)?;
228                tx.commit().map_err(tokio_rusqlite::Error::from)?;
229                Ok(result)
230            })
231            .await
232            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
233    }
234}
235
236/// SQLite vector store implementation for Rig.
237///
238/// This crate provides a SQLite-based vector store implementation that can be used with Rig.
239/// It uses the `sqlite-vec` extension to enable vector similarity search capabilities.
240///
241/// # Example
242/// ```rust
243/// use rig::{
244///     embeddings::EmbeddingsBuilder,
245///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
246///     vector_store::VectorStoreIndex,
247///     Embed,
248/// };
249/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
250/// use serde::Deserialize;
251/// use tokio_rusqlite::Connection;
252///
253/// #[derive(Embed, Clone, Debug, Deserialize)]
254/// struct Document {
255///     id: String,
256///     #[embed]
257///     content: String,
258/// }
259///
260/// impl SqliteVectorStoreTable for Document {
261///     fn name() -> &'static str {
262///         "documents"
263///     }
264///
265///     fn schema() -> Vec<Column> {
266///         vec![
267///             Column::new("id", "TEXT PRIMARY KEY"),
268///             Column::new("content", "TEXT"),
269///         ]
270///     }
271///
272///     fn id(&self) -> String {
273///         self.id.clone()
274///     }
275///
276///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
277///         vec![
278///             ("id", Box::new(self.id.clone())),
279///             ("content", Box::new(self.content.clone())),
280///         ]
281///     }
282/// }
283///
284/// let conn = Connection::open("vector_store.db").await?;
285/// let openai_client = Client::new("YOUR_API_KEY");
286/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
287///
288/// // Initialize vector store
289/// let vector_store = SqliteVectorStore::new(conn, &model).await?;
290///
291/// // Create documents
292/// let documents = vec![
293///     Document {
294///         id: "doc1".to_string(),
295///         content: "Example document 1".to_string(),
296///     },
297///     Document {
298///         id: "doc2".to_string(),
299///         content: "Example document 2".to_string(),
300///     },
301/// ];
302///
303/// // Generate embeddings
304/// let embeddings = EmbeddingsBuilder::new(model.clone())
305///     .documents(documents)?
306///     .build()
307///     .await?;
308///
309/// // Add to vector store
310/// vector_store.add_rows(embeddings).await?;
311///
312/// // Create index and search
313/// let index = vector_store.index(model);
314/// let results = index
315///     .top_n::<Document>("Example query", 2)
316///     .await?;
317/// ```
318pub struct SqliteVectorIndex<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> {
319    store: SqliteVectorStore<E, T>,
320    embedding_model: E,
321}
322
323impl<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable> SqliteVectorIndex<E, T> {
324    pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
325        Self {
326            store,
327            embedding_model,
328        }
329    }
330}
331
332impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
333    for SqliteVectorIndex<E, T>
334{
335    async fn top_n<D: for<'a> Deserialize<'a>>(
336        &self,
337        query: &str,
338        n: usize,
339    ) -> Result<Vec<(f64, String, D)>, VectorStoreError> {
340        debug!("Finding top {} matches for query", n);
341        let embedding = self.embedding_model.embed_text(query).await?;
342        let query_vec: Vec<f32> = serialize_embedding(&embedding);
343        let table_name = T::name();
344
345        // Get all column names from SqliteVectorStoreTable
346        let columns = T::schema();
347        let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
348
349        let rows = self
350            .store
351            .conn
352            .call(move |conn| {
353                // Build SELECT statement with all columns
354                let select_cols = column_names.join(", ");
355                let mut stmt = conn.prepare(&format!(
356                    "SELECT d.{}, e.distance 
357                    FROM {}_embeddings e
358                    JOIN {} d ON e.rowid = d.rowid
359                    WHERE e.embedding MATCH ?1 AND k = ?2
360                    ORDER BY e.distance",
361                    select_cols, table_name, table_name
362                ))?;
363
364                let rows = stmt
365                    .query_map(rusqlite::params![query_vec.as_bytes().to_vec(), n], |row| {
366                        // Create a map of column names to values
367                        let mut map = serde_json::Map::new();
368                        for (i, col_name) in column_names.iter().enumerate() {
369                            let value: String = row.get(i)?;
370                            map.insert(col_name.to_string(), serde_json::Value::String(value));
371                        }
372                        let distance: f64 = row.get(column_names.len())?;
373                        let id: String = row.get(0)?; // Assuming id is always first column
374
375                        Ok((id, serde_json::Value::Object(map), distance))
376                    })?
377                    .collect::<Result<Vec<_>, _>>()?;
378                Ok(rows)
379            })
380            .await
381            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
382
383        debug!("Found {} potential matches", rows.len());
384        let mut top_n = Vec::new();
385        for (id, doc_value, distance) in rows {
386            match serde_json::from_value::<D>(doc_value) {
387                Ok(doc) => {
388                    top_n.push((distance, id, doc));
389                }
390                Err(e) => {
391                    debug!("Failed to deserialize document {}: {}", id, e);
392                    continue;
393                }
394            }
395        }
396
397        debug!("Returning {} matches", top_n.len());
398        Ok(top_n)
399    }
400
401    async fn top_n_ids(
402        &self,
403        query: &str,
404        n: usize,
405    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
406        debug!("Finding top {} document IDs for query", n);
407        let embedding = self.embedding_model.embed_text(query).await?;
408        let query_vec = serialize_embedding(&embedding);
409        let table_name = T::name();
410
411        let results = self
412            .store
413            .conn
414            .call(move |conn| {
415                let mut stmt = conn.prepare(&format!(
416                    "SELECT d.id, e.distance 
417                     FROM {0}_embeddings e
418                     JOIN {0} d ON e.rowid = d.rowid
419                     WHERE e.embedding MATCH ?1 AND k = ?2
420                     ORDER BY e.distance",
421                    table_name
422                ))?;
423
424                let results = stmt
425                    .query_map(
426                        rusqlite::params![
427                            query_vec
428                                .iter()
429                                .flat_map(|x| x.to_le_bytes())
430                                .collect::<Vec<u8>>(),
431                            n
432                        ],
433                        |row| Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?)),
434                    )?
435                    .collect::<Result<Vec<_>, _>>()?;
436                Ok(results)
437            })
438            .await
439            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
440
441        debug!("Found {} matching document IDs", results.len());
442        Ok(results)
443    }
444}
445
446fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
447    embedding.vec.iter().map(|x| *x as f32).collect()
448}
449
450impl ColumnValue for String {
451    fn to_sql_string(&self) -> String {
452        self.clone()
453    }
454
455    fn column_type(&self) -> &'static str {
456        "TEXT"
457    }
458}