rig_sqlite/
lib.rs

1use rig::OneOrMany;
2use rig::embeddings::{Embedding, EmbeddingModel};
3use rig::vector_store::request::VectorSearchRequest;
4use rig::vector_store::{VectorStoreError, VectorStoreIndex};
5use serde::Deserialize;
6use std::marker::PhantomData;
7use tokio_rusqlite::Connection;
8use tracing::{debug, info};
9use zerocopy::IntoBytes;
10
/// Errors specific to the SQLite vector store.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so values can be
/// printed, boxed, and propagated with `?` like any other error type.
#[derive(Debug)]
pub enum SqliteError {
    /// Wraps an underlying error categorized as a database failure.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// Wraps an underlying error categorized as a serialization failure.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// A column type was not acceptable; the payload describes it.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DatabaseError(err) => write!(f, "database error: {err}"),
            Self::SerializationError(err) => write!(f, "serialization error: {err}"),
            Self::InvalidColumnType(ty) => write!(f, "invalid column type: {ty}"),
        }
    }
}

impl std::error::Error for SqliteError {
    /// Exposes the wrapped error, if any, so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::DatabaseError(err) | Self::SerializationError(err) => Some(&**err),
            Self::InvalidColumnType(_) => None,
        }
    }
}
17
/// A value that can be stored in a column of a vector-store table.
///
/// Implementors describe which SQL type the value maps to and how it is
/// rendered as a string when bound as a statement parameter.
pub trait ColumnValue: Sync + Send {
    /// Returns the SQL type name for this value (for example `"TEXT"`).
    fn column_type(&self) -> &'static str;

    /// Returns the value rendered as a string for binding to a SQL parameter.
    fn to_sql_string(&self) -> String;
}
22
/// Describes a single column of a table that backs a vector store.
///
/// Built with [`Column::new`] and optionally marked for indexing with
/// [`Column::indexed`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
    /// Column name; interpolated as-is into the generated SQL.
    name: &'static str,
    /// SQL type (and any constraints), e.g. `"TEXT PRIMARY KEY"`.
    col_type: &'static str,
    /// Whether an index should be created for this column.
    indexed: bool,
}

impl Column {
    /// Creates a non-indexed column with the given name and SQL type.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Self {
            name,
            col_type,
            indexed: false,
        }
    }

    /// Marks the column as indexed and returns it, builder-style.
    pub fn indexed(mut self) -> Self {
        self.indexed = true;
        self
    }
}
43
44/// Example of a document type that can be used with SqliteVectorStore
45/// ```rust
46/// use rig::Embed;
47/// use serde::Deserialize;
48/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStoreTable};
49///
50/// #[derive(Embed, Clone, Debug, Deserialize)]
51/// struct Document {
52///     id: String,
53///     #[embed]
54///     content: String,
55/// }
56///
57/// impl SqliteVectorStoreTable for Document {
58///     fn name() -> &'static str {
59///         "documents"
60///     }
61///
62///     fn schema() -> Vec<Column> {
63///         vec![
64///             Column::new("id", "TEXT PRIMARY KEY"),
65///             Column::new("content", "TEXT"),
66///         ]
67///     }
68///
69///     fn id(&self) -> String {
70///         self.id.clone()
71///     }
72///
73///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
74///         vec![
75///             ("id", Box::new(self.id.clone())),
76///             ("content", Box::new(self.content.clone())),
77///         ]
78///     }
79/// }
80/// ```
81pub trait SqliteVectorStoreTable: Send + Sync + Clone {
82    fn name() -> &'static str;
83    fn schema() -> Vec<Column>;
84    fn id(&self) -> String;
85    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
86}
87
88#[derive(Clone)]
89pub struct SqliteVectorStore<E, T>
90where
91    E: EmbeddingModel + 'static,
92    T: SqliteVectorStoreTable + 'static,
93{
94    conn: Connection,
95    _phantom: PhantomData<(E, T)>,
96}
97
98impl<E, T> SqliteVectorStore<E, T>
99where
100    E: EmbeddingModel + 'static,
101    T: SqliteVectorStoreTable + 'static,
102{
103    pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
104        let dims = embedding_model.ndims();
105        let table_name = T::name();
106        let schema = T::schema();
107
108        // Build the table schema
109        let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
110
111        // Add columns
112        let mut first = true;
113        for column in &schema {
114            if !first {
115                create_table.push(',');
116            }
117            create_table.push_str(&format!("\n    {} {}", column.name, column.col_type));
118            first = false;
119        }
120
121        create_table.push_str("\n)");
122
123        // Build index creation statements
124        let mut create_indexes = vec![format!(
125            "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
126            table_name, table_name
127        )];
128
129        // Add indexes for marked columns
130        for column in schema {
131            if column.indexed {
132                create_indexes.push(format!(
133                    "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
134                    table_name, column.name, table_name, column.name
135                ));
136            }
137        }
138
139        conn.call(move |conn| {
140            conn.execute_batch("BEGIN")?;
141
142            // Create document table
143            conn.execute_batch(&create_table)?;
144
145            // Create indexes
146            for index_stmt in create_indexes {
147                conn.execute_batch(&index_stmt)?;
148            }
149
150            // Create embeddings table
151            conn.execute_batch(&format!(
152                "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
153            ))?;
154
155            conn.execute_batch("COMMIT")?;
156            Ok(())
157        })
158        .await
159        .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
160
161        Ok(Self {
162            conn,
163            _phantom: PhantomData,
164        })
165    }
166
167    pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
168        SqliteVectorIndex::new(model, self)
169    }
170
171    pub fn add_rows_with_txn(
172        &self,
173        txn: &rusqlite::Transaction<'_>,
174        documents: Vec<(T, OneOrMany<Embedding>)>,
175    ) -> Result<i64, tokio_rusqlite::Error> {
176        info!("Adding {} documents to store", documents.len());
177        let table_name = T::name();
178        let mut last_id = 0;
179
180        for (doc, embeddings) in &documents {
181            debug!("Storing document with id {}", doc.id());
182
183            let values = doc.column_values();
184            let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
185
186            let placeholders = (1..=values.len())
187                .map(|i| format!("?{i}"))
188                .collect::<Vec<_>>();
189
190            let insert_sql = format!(
191                "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
192                table_name,
193                columns.join(", "),
194                placeholders.join(", ")
195            );
196
197            txn.execute(
198                &insert_sql,
199                rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
200            )?;
201            last_id = txn.last_insert_rowid();
202
203            let embeddings_sql =
204                format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
205
206            let mut stmt = txn.prepare(&embeddings_sql)?;
207            for (i, embedding) in embeddings.iter().enumerate() {
208                let vec = serialize_embedding(embedding);
209                debug!(
210                    "Storing embedding {} of {} (size: {} bytes)",
211                    i + 1,
212                    embeddings.len(),
213                    vec.len() * 4
214                );
215                let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
216                stmt.execute(rusqlite::params![last_id, blob])?;
217            }
218        }
219
220        Ok(last_id)
221    }
222
223    pub async fn add_rows(
224        &self,
225        documents: Vec<(T, OneOrMany<Embedding>)>,
226    ) -> Result<i64, VectorStoreError> {
227        let documents = documents.clone();
228        let this = self.clone();
229
230        self.conn
231            .call(move |conn| {
232                let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?;
233                let result = this.add_rows_with_txn(&tx, documents)?;
234                tx.commit().map_err(tokio_rusqlite::Error::from)?;
235                Ok(result)
236            })
237            .await
238            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
239    }
240}
241
242/// SQLite vector store implementation for Rig.
243///
244/// This crate provides a SQLite-based vector store implementation that can be used with Rig.
245/// It uses the `sqlite-vec` extension to enable vector similarity search capabilities.
246///
247/// # Example
248/// ```rust
249/// use rig::{
250///     embeddings::EmbeddingsBuilder,
251///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
252///     vector_store::VectorStoreIndex,
253///     Embed,
254/// };
255/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
256/// use serde::Deserialize;
257/// use tokio_rusqlite::Connection;
258///
259/// #[derive(Embed, Clone, Debug, Deserialize)]
260/// struct Document {
261///     id: String,
262///     #[embed]
263///     content: String,
264/// }
265///
266/// impl SqliteVectorStoreTable for Document {
267///     fn name() -> &'static str {
268///         "documents"
269///     }
270///
271///     fn schema() -> Vec<Column> {
272///         vec![
273///             Column::new("id", "TEXT PRIMARY KEY"),
274///             Column::new("content", "TEXT"),
275///         ]
276///     }
277///
278///     fn id(&self) -> String {
279///         self.id.clone()
280///     }
281///
282///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
283///         vec![
284///             ("id", Box::new(self.id.clone())),
285///             ("content", Box::new(self.content.clone())),
286///         ]
287///     }
288/// }
289///
290/// let conn = Connection::open("vector_store.db").await?;
291/// let openai_client = Client::new("YOUR_API_KEY");
292/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
293///
294/// // Initialize vector store
295/// let vector_store = SqliteVectorStore::new(conn, &model).await?;
296///
297/// // Create documents
298/// let documents = vec![
299///     Document {
300///         id: "doc1".to_string(),
301///         content: "Example document 1".to_string(),
302///     },
303///     Document {
304///         id: "doc2".to_string(),
305///         content: "Example document 2".to_string(),
306///     },
307/// ];
308///
309/// // Generate embeddings
310/// let embeddings = EmbeddingsBuilder::new(model.clone())
311///     .documents(documents)?
312///     .build()
313///     .await?;
314///
315/// // Add to vector store
316/// vector_store.add_rows(embeddings).await?;
317///
318/// // Create index and search
319/// let index = vector_store.index(model);
320/// let results = index
321///     .top_n::<Document>("Example query", 2)
322///     .await?;
323/// ```
324pub struct SqliteVectorIndex<E, T>
325where
326    E: EmbeddingModel + 'static,
327    T: SqliteVectorStoreTable + 'static,
328{
329    store: SqliteVectorStore<E, T>,
330    embedding_model: E,
331}
332
333impl<E, T> SqliteVectorIndex<E, T>
334where
335    E: EmbeddingModel + 'static,
336    T: SqliteVectorStoreTable,
337{
338    pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
339        Self {
340            store,
341            embedding_model,
342        }
343    }
344}
345
346impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
347    for SqliteVectorIndex<E, T>
348{
349    async fn top_n<D: for<'a> Deserialize<'a>>(
350        &self,
351        req: VectorSearchRequest,
352    ) -> Result<Vec<(f64, String, D)>, VectorStoreError> {
353        tracing::debug!("Finding top {} matches for query", req.samples() as usize);
354        let embedding = self.embedding_model.embed_text(req.query()).await?;
355        let query_vec: Vec<f32> = serialize_embedding(&embedding);
356        let table_name = T::name();
357
358        // Get all column names from SqliteVectorStoreTable
359        let columns = T::schema();
360        let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
361
362        let rows = self
363            .store
364            .conn
365            .call(move |conn| {
366                // Build SELECT statement with all columns
367                let select_cols = column_names.join(", ");
368                let mut stmt = conn.prepare(&format!(
369                    "SELECT d.{select_cols}, e.distance
370                    FROM {table_name}_embeddings e
371                    JOIN {table_name} d ON e.rowid = d.rowid
372                    WHERE e.embedding MATCH ?1 AND k = ?2 AND e.distance >= ?3
373                    ORDER BY e.distance"
374                ))?;
375
376                let rows = stmt
377                    .query_map(
378                        rusqlite::params![
379                            query_vec.as_bytes().to_vec(),
380                            req.samples() as usize,
381                            req.threshold().unwrap_or(0.)
382                        ],
383                        |row| {
384                            // Create a map of column names to values
385                            let mut map = serde_json::Map::new();
386                            for (i, col_name) in column_names.iter().enumerate() {
387                                let value: String = row.get(i)?;
388                                map.insert(col_name.to_string(), serde_json::Value::String(value));
389                            }
390                            let distance: f64 = row.get(column_names.len())?;
391                            let id: String = row.get(0)?; // Assuming id is always first column
392
393                            Ok((id, serde_json::Value::Object(map), distance))
394                        },
395                    )?
396                    .collect::<Result<Vec<_>, _>>()?;
397                Ok(rows)
398            })
399            .await
400            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
401
402        debug!("Found {} potential matches", rows.len());
403        let mut top_n = Vec::new();
404        for (id, doc_value, distance) in rows {
405            match serde_json::from_value::<D>(doc_value) {
406                Ok(doc) => {
407                    top_n.push((distance, id, doc));
408                }
409                Err(e) => {
410                    debug!("Failed to deserialize document {}: {}", id, e);
411                    continue;
412                }
413            }
414        }
415
416        debug!("Returning {} matches", top_n.len());
417        Ok(top_n)
418    }
419
420    async fn top_n_ids(
421        &self,
422        req: VectorSearchRequest,
423    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
424        tracing::debug!(
425            "Finding top {} document IDs for query",
426            req.samples() as usize
427        );
428        let embedding = self.embedding_model.embed_text(req.query()).await?;
429        let query_vec = serialize_embedding(&embedding);
430        let table_name = T::name();
431
432        let results = self
433            .store
434            .conn
435            .call(move |conn| {
436                let mut stmt = conn.prepare(&format!(
437                    "SELECT d.id, e.distance
438                     FROM {table_name}_embeddings e
439                     JOIN {table_name} d ON e.rowid = d.rowid
440                     WHERE e.embedding MATCH ?1 AND k = ?2 AND e.distance >= ?3
441                     ORDER BY e.distance"
442                ))?;
443
444                let results = stmt
445                    .query_map(
446                        rusqlite::params![
447                            query_vec
448                                .iter()
449                                .flat_map(|x| x.to_le_bytes())
450                                .collect::<Vec<u8>>(),
451                            req.samples() as usize,
452                            req.threshold().unwrap_or(0.)
453                        ],
454                        |row| Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?)),
455                    )?
456                    .collect::<Result<Vec<_>, _>>()?;
457                Ok(results)
458            })
459            .await
460            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
461
462        debug!("Found {} matching document IDs", results.len());
463        Ok(results)
464    }
465}
466
467fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
468    embedding.vec.iter().map(|x| *x as f32).collect()
469}
470
471impl ColumnValue for String {
472    fn to_sql_string(&self) -> String {
473        self.clone()
474    }
475
476    fn column_type(&self) -> &'static str {
477        "TEXT"
478    }
479}