rig_sqlite/
lib.rs

1use rig::OneOrMany;
2use rig::embeddings::{Embedding, EmbeddingModel};
3use rig::vector_store::request::VectorSearchRequest;
4use rig::vector_store::{VectorStoreError, VectorStoreIndex};
5use serde::Deserialize;
6use std::marker::PhantomData;
7use tokio_rusqlite::Connection;
8use tracing::{debug, info};
9use zerocopy::IntoBytes;
10
/// Errors specific to the SQLite vector store.
///
/// The two boxed variants wrap an underlying cause, which is exposed through
/// [`std::error::Error::source`].
#[derive(Debug)]
pub enum SqliteError {
    /// Wraps an error attributed to the database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// Wraps an error attributed to (de)serialization.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// Carries a message describing an invalid column type.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DatabaseError(err) => write!(f, "database error: {err}"),
            Self::SerializationError(err) => write!(f, "serialization error: {err}"),
            Self::InvalidColumnType(ty) => write!(f, "invalid column type: {ty}"),
        }
    }
}

impl std::error::Error for SqliteError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::DatabaseError(err) | Self::SerializationError(err) => {
                // Drop the `Send + Sync` auto traits to match the `source` signature.
                let cause: &(dyn std::error::Error + 'static) = &**err;
                Some(cause)
            }
            Self::InvalidColumnType(_) => None,
        }
    }
}
17
/// A value that can be stored in a column of a [`SqliteVectorStoreTable`].
///
/// Values are handed to the store as `Box<dyn ColumnValue>`, so implementors must be
/// `Send + Sync`.
pub trait ColumnValue
where
    Self: Send + Sync,
{
    /// Renders the value as the string that gets bound as the statement parameter on insert.
    fn to_sql_string(&self) -> String;

    /// Names the SQL type associated with this value (e.g. `"TEXT"` for `String`).
    fn column_type(&self) -> &'static str;
}
22
/// Describes one column of a document table.
///
/// Built with [`Column::new`] and optionally marked for indexing with [`Column::indexed`].
#[derive(Debug, Clone)]
pub struct Column {
    /// Column name as it appears in the generated SQL.
    name: &'static str,
    /// SQL type and constraints, emitted verbatim after the name (e.g. `"TEXT PRIMARY KEY"`).
    col_type: &'static str,
    /// Whether a secondary index is created for this column.
    indexed: bool,
}

impl Column {
    /// Creates a column that is not indexed.
    ///
    /// `name` and `col_type` are interpolated directly into the `CREATE TABLE` statement, so
    /// they must be valid SQL fragments.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Self {
            name,
            col_type,
            indexed: false,
        }
    }

    /// Marks the column as indexed and returns it, builder-style.
    pub fn indexed(mut self) -> Self {
        self.indexed = true;
        self
    }
}
43
44/// Example of a document type that can be used with SqliteVectorStore
45/// ```rust
46/// use rig::Embed;
47/// use serde::Deserialize;
48/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStoreTable};
49///
50/// #[derive(Embed, Clone, Debug, Deserialize)]
51/// struct Document {
52///     id: String,
53///     #[embed]
54///     content: String,
55/// }
56///
57/// impl SqliteVectorStoreTable for Document {
58///     fn name() -> &'static str {
59///         "documents"
60///     }
61///
62///     fn schema() -> Vec<Column> {
63///         vec![
64///             Column::new("id", "TEXT PRIMARY KEY"),
65///             Column::new("content", "TEXT"),
66///         ]
67///     }
68///
69///     fn id(&self) -> String {
70///         self.id.clone()
71///     }
72///
73///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
74///         vec![
75///             ("id", Box::new(self.id.clone())),
76///             ("content", Box::new(self.content.clone())),
77///         ]
78///     }
79/// }
80/// ```
81pub trait SqliteVectorStoreTable: Send + Sync + Clone {
82    fn name() -> &'static str;
83    fn schema() -> Vec<Column>;
84    fn id(&self) -> String;
85    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
86}
87
88#[derive(Clone)]
89pub struct SqliteVectorStore<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> {
90    conn: Connection,
91    _phantom: PhantomData<(E, T)>,
92}
93
94impl<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> SqliteVectorStore<E, T> {
95    pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
96        let dims = embedding_model.ndims();
97        let table_name = T::name();
98        let schema = T::schema();
99
100        // Build the table schema
101        let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
102
103        // Add columns
104        let mut first = true;
105        for column in &schema {
106            if !first {
107                create_table.push(',');
108            }
109            create_table.push_str(&format!("\n    {} {}", column.name, column.col_type));
110            first = false;
111        }
112
113        create_table.push_str("\n)");
114
115        // Build index creation statements
116        let mut create_indexes = vec![format!(
117            "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
118            table_name, table_name
119        )];
120
121        // Add indexes for marked columns
122        for column in schema {
123            if column.indexed {
124                create_indexes.push(format!(
125                    "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
126                    table_name, column.name, table_name, column.name
127                ));
128            }
129        }
130
131        conn.call(move |conn| {
132            conn.execute_batch("BEGIN")?;
133
134            // Create document table
135            conn.execute_batch(&create_table)?;
136
137            // Create indexes
138            for index_stmt in create_indexes {
139                conn.execute_batch(&index_stmt)?;
140            }
141
142            // Create embeddings table
143            conn.execute_batch(&format!(
144                "CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_embeddings USING vec0(embedding float[{dims}])"
145            ))?;
146
147            conn.execute_batch("COMMIT")?;
148            Ok(())
149        })
150        .await
151        .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
152
153        Ok(Self {
154            conn,
155            _phantom: PhantomData,
156        })
157    }
158
159    pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
160        SqliteVectorIndex::new(model, self)
161    }
162
163    pub fn add_rows_with_txn(
164        &self,
165        txn: &rusqlite::Transaction<'_>,
166        documents: Vec<(T, OneOrMany<Embedding>)>,
167    ) -> Result<i64, tokio_rusqlite::Error> {
168        info!("Adding {} documents to store", documents.len());
169        let table_name = T::name();
170        let mut last_id = 0;
171
172        for (doc, embeddings) in &documents {
173            debug!("Storing document with id {}", doc.id());
174
175            let values = doc.column_values();
176            let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
177
178            let placeholders = (1..=values.len())
179                .map(|i| format!("?{i}"))
180                .collect::<Vec<_>>();
181
182            let insert_sql = format!(
183                "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
184                table_name,
185                columns.join(", "),
186                placeholders.join(", ")
187            );
188
189            txn.execute(
190                &insert_sql,
191                rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
192            )?;
193            last_id = txn.last_insert_rowid();
194
195            let embeddings_sql =
196                format!("INSERT INTO {table_name}_embeddings (rowid, embedding) VALUES (?1, ?2)");
197
198            let mut stmt = txn.prepare(&embeddings_sql)?;
199            for (i, embedding) in embeddings.iter().enumerate() {
200                let vec = serialize_embedding(embedding);
201                debug!(
202                    "Storing embedding {} of {} (size: {} bytes)",
203                    i + 1,
204                    embeddings.len(),
205                    vec.len() * 4
206                );
207                let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
208                stmt.execute(rusqlite::params![last_id, blob])?;
209            }
210        }
211
212        Ok(last_id)
213    }
214
215    pub async fn add_rows(
216        &self,
217        documents: Vec<(T, OneOrMany<Embedding>)>,
218    ) -> Result<i64, VectorStoreError> {
219        let documents = documents.clone();
220        let this = self.clone();
221
222        self.conn
223            .call(move |conn| {
224                let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?;
225                let result = this.add_rows_with_txn(&tx, documents)?;
226                tx.commit().map_err(tokio_rusqlite::Error::from)?;
227                Ok(result)
228            })
229            .await
230            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
231    }
232}
233
234/// SQLite vector store implementation for Rig.
235///
236/// This crate provides a SQLite-based vector store implementation that can be used with Rig.
237/// It uses the `sqlite-vec` extension to enable vector similarity search capabilities.
238///
239/// # Example
240/// ```rust
241/// use rig::{
242///     embeddings::EmbeddingsBuilder,
243///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
244///     vector_store::VectorStoreIndex,
245///     Embed,
246/// };
247/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
248/// use serde::Deserialize;
249/// use tokio_rusqlite::Connection;
250///
251/// #[derive(Embed, Clone, Debug, Deserialize)]
252/// struct Document {
253///     id: String,
254///     #[embed]
255///     content: String,
256/// }
257///
258/// impl SqliteVectorStoreTable for Document {
259///     fn name() -> &'static str {
260///         "documents"
261///     }
262///
263///     fn schema() -> Vec<Column> {
264///         vec![
265///             Column::new("id", "TEXT PRIMARY KEY"),
266///             Column::new("content", "TEXT"),
267///         ]
268///     }
269///
270///     fn id(&self) -> String {
271///         self.id.clone()
272///     }
273///
274///     fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
275///         vec![
276///             ("id", Box::new(self.id.clone())),
277///             ("content", Box::new(self.content.clone())),
278///         ]
279///     }
280/// }
281///
282/// let conn = Connection::open("vector_store.db").await?;
283/// let openai_client = Client::new("YOUR_API_KEY");
284/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
285///
286/// // Initialize vector store
287/// let vector_store = SqliteVectorStore::new(conn, &model).await?;
288///
289/// // Create documents
290/// let documents = vec![
291///     Document {
292///         id: "doc1".to_string(),
293///         content: "Example document 1".to_string(),
294///     },
295///     Document {
296///         id: "doc2".to_string(),
297///         content: "Example document 2".to_string(),
298///     },
299/// ];
300///
301/// // Generate embeddings
302/// let embeddings = EmbeddingsBuilder::new(model.clone())
303///     .documents(documents)?
304///     .build()
305///     .await?;
306///
307/// // Add to vector store
308/// vector_store.add_rows(embeddings).await?;
309///
310/// // Create index and search
311/// let index = vector_store.index(model);
312/// let results = index
313///     .top_n::<Document>("Example query", 2)
314///     .await?;
315/// ```
316pub struct SqliteVectorIndex<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> {
317    store: SqliteVectorStore<E, T>,
318    embedding_model: E,
319}
320
321impl<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable> SqliteVectorIndex<E, T> {
322    pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
323        Self {
324            store,
325            embedding_model,
326        }
327    }
328}
329
330impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
331    for SqliteVectorIndex<E, T>
332{
333    async fn top_n<D: for<'a> Deserialize<'a>>(
334        &self,
335        req: VectorSearchRequest,
336    ) -> Result<Vec<(f64, String, D)>, VectorStoreError> {
337        tracing::debug!("Finding top {} matches for query", req.samples() as usize);
338        let embedding = self.embedding_model.embed_text(req.query()).await?;
339        let query_vec: Vec<f32> = serialize_embedding(&embedding);
340        let table_name = T::name();
341
342        // Get all column names from SqliteVectorStoreTable
343        let columns = T::schema();
344        let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect();
345
346        let rows = self
347            .store
348            .conn
349            .call(move |conn| {
350                // Build SELECT statement with all columns
351                let select_cols = column_names.join(", ");
352                let mut stmt = conn.prepare(&format!(
353                    "SELECT d.{select_cols}, e.distance
354                    FROM {table_name}_embeddings e
355                    JOIN {table_name} d ON e.rowid = d.rowid
356                    WHERE e.embedding MATCH ?1 AND k = ?2
357                    ORDER BY e.distance"
358                ))?;
359
360                let rows = stmt
361                    .query_map(
362                        rusqlite::params![query_vec.as_bytes().to_vec(), req.samples() as usize],
363                        |row| {
364                            // Create a map of column names to values
365                            let mut map = serde_json::Map::new();
366                            for (i, col_name) in column_names.iter().enumerate() {
367                                let value: String = row.get(i)?;
368                                map.insert(col_name.to_string(), serde_json::Value::String(value));
369                            }
370                            let distance: f64 = row.get(column_names.len())?;
371                            let id: String = row.get(0)?; // Assuming id is always first column
372
373                            Ok((id, serde_json::Value::Object(map), distance))
374                        },
375                    )?
376                    .collect::<Result<Vec<_>, _>>()?;
377                Ok(rows)
378            })
379            .await
380            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
381
382        debug!("Found {} potential matches", rows.len());
383        let mut top_n = Vec::new();
384        for (id, doc_value, distance) in rows {
385            match serde_json::from_value::<D>(doc_value) {
386                Ok(doc) => {
387                    top_n.push((distance, id, doc));
388                }
389                Err(e) => {
390                    debug!("Failed to deserialize document {}: {}", id, e);
391                    continue;
392                }
393            }
394        }
395
396        debug!("Returning {} matches", top_n.len());
397        Ok(top_n)
398    }
399
400    async fn top_n_ids(
401        &self,
402        req: VectorSearchRequest,
403    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
404        tracing::debug!(
405            "Finding top {} document IDs for query",
406            req.samples() as usize
407        );
408        let embedding = self.embedding_model.embed_text(req.query()).await?;
409        let query_vec = serialize_embedding(&embedding);
410        let table_name = T::name();
411
412        let results = self
413            .store
414            .conn
415            .call(move |conn| {
416                let mut stmt = conn.prepare(&format!(
417                    "SELECT d.id, e.distance
418                     FROM {table_name}_embeddings e
419                     JOIN {table_name} d ON e.rowid = d.rowid
420                     WHERE e.embedding MATCH ?1 AND k = ?2
421                     ORDER BY e.distance"
422                ))?;
423
424                let results = stmt
425                    .query_map(
426                        rusqlite::params![
427                            query_vec
428                                .iter()
429                                .flat_map(|x| x.to_le_bytes())
430                                .collect::<Vec<u8>>(),
431                            req.samples() as usize
432                        ],
433                        |row| Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?)),
434                    )?
435                    .collect::<Result<Vec<_>, _>>()?;
436                Ok(results)
437            })
438            .await
439            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
440
441        debug!("Found {} matching document IDs", results.len());
442        Ok(results)
443    }
444}
445
446fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
447    embedding.vec.iter().map(|x| *x as f32).collect()
448}
449
450impl ColumnValue for String {
451    fn to_sql_string(&self) -> String {
452        self.clone()
453    }
454
455    fn column_type(&self) -> &'static str {
456        "TEXT"
457    }
458}