use bep::embeddings::{Embedding, EmbeddingModel};
use bep::vector_store::{VectorStoreError, VectorStoreIndex};
use bep::OneOrMany;
use serde::Deserialize;
use std::marker::PhantomData;
use tokio_rusqlite::Connection;
use tracing::{debug, info};
use zerocopy::IntoBytes;
/// Errors specific to the SQLite-backed vector store.
#[derive(Debug)]
pub enum SqliteError {
    /// An error reported by the database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// An error raised while (de)serializing stored data.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// A column type that the store cannot handle; the payload describes the offending type.
    InvalidColumnType(String),
}

impl std::fmt::Display for SqliteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SqliteError::DatabaseError(e) => write!(f, "database error: {}", e),
            SqliteError::SerializationError(e) => write!(f, "serialization error: {}", e),
            SqliteError::InvalidColumnType(ty) => write!(f, "invalid column type: {}", ty),
        }
    }
}

impl std::error::Error for SqliteError {
    /// Exposes the wrapped error, if any, so error chains (e.g. `anyhow`) can walk it.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            SqliteError::DatabaseError(e) | SqliteError::SerializationError(e) => Some(&**e),
            SqliteError::InvalidColumnType(_) => None,
        }
    }
}
/// A value that can be written into a column of a [`SqliteVectorStoreTable`].
///
/// Values are bound to SQL as text; the declared column type is available separately.
pub trait ColumnValue: Send + Sync {
    /// Text representation bound as the SQL parameter when the value is inserted.
    fn to_sql_string(&self) -> String;

    /// SQL type name associated with this value (e.g. `"TEXT"`).
    fn column_type(&self) -> &'static str;
}
/// Definition of one column of a document table: its name, SQL type declaration and
/// whether a secondary index should be created for it.
#[derive(Debug, Clone)]
pub struct Column {
    /// Column name as used in SQL.
    name: &'static str,
    /// SQL type declaration, possibly including constraints (e.g. `"TEXT PRIMARY KEY"`).
    col_type: &'static str,
    /// Whether a `CREATE INDEX` is emitted for this column.
    indexed: bool,
}

impl Column {
    /// Creates a non-indexed column named `name` declared as `col_type`.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        Self {
            name,
            col_type,
            indexed: false,
        }
    }

    /// Marks the column as indexed, returning the updated column (builder style).
    #[must_use = "`indexed` returns the modified column; it does not change a column in place"]
    pub fn indexed(mut self) -> Self {
        self.indexed = true;
        self
    }
}
/// A document type that can be stored in, and searched from, a [`SqliteVectorStore`].
///
/// Implementors describe their table (`name`, `schema`) and how a value maps onto its
/// columns (`id`, `column_values`). The store creates an index on, and selects, a column
/// named `id`, so the schema is expected to contain one.
pub trait SqliteVectorStoreTable: Clone + Send + Sync {
    /// Name of the SQL table holding documents of this type; the matching vector table is
    /// created as `<name>_embeddings`.
    fn name() -> &'static str;

    /// Column definitions used when creating the table.
    fn schema() -> Vec<Column>;

    /// Identifier of this document.
    fn id(&self) -> String;

    /// `(column name, value)` pairs written when this document is inserted.
    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
}
/// SQLite-backed storage for documents of type `T` embedded with model `E`.
///
/// Only the connection is stored; `E` and `T` exist at the type level via [`PhantomData`].
#[derive(Clone)]
pub struct SqliteVectorStore<E, T>
where
    E: EmbeddingModel + 'static,
    T: SqliteVectorStoreTable + 'static,
{
    conn: Connection,
    _phantom: PhantomData<(E, T)>,
}
impl<E: EmbeddingModel + 'static, T: SqliteVectorStoreTable + 'static> SqliteVectorStore<E, T> {
    /// Creates the store on `conn`, initializing its tables if they do not exist yet.
    ///
    /// Within a single transaction this creates:
    /// - the document table `T::name()` with the columns from `T::schema()`;
    /// - an index on the table's `id` column, plus one per column marked [`Column::indexed`];
    /// - the `vec0` virtual table `<name>_embeddings` with `embedding_model.ndims()` dimensions.
    ///
    /// If any statement fails the transaction is rolled back, so a failed call never leaves
    /// `conn` inside an open transaction.
    ///
    /// # Errors
    /// Returns [`VectorStoreError::DatastoreError`] if any statement fails.
    pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
        let dims = embedding_model.ndims();
        let table_name = T::name();
        let schema = T::schema();

        let column_defs = schema
            .iter()
            .map(|column| format!("\n    {} {}", column.name, column.col_type))
            .collect::<Vec<_>>()
            .join(",");
        let create_table = format!("CREATE TABLE IF NOT EXISTS {} ({}\n)", table_name, column_defs);

        // The `id` column is always indexed, in addition to any column marked `indexed`.
        let mut create_indexes = vec![format!(
            "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
            table_name, table_name
        )];
        create_indexes.extend(schema.iter().filter(|column| column.indexed).map(|column| {
            format!(
                "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
                table_name, column.name, table_name, column.name
            )
        }));

        let create_embeddings = format!(
            "CREATE VIRTUAL TABLE IF NOT EXISTS {}_embeddings USING vec0(embedding float[{}])",
            table_name, dims
        );

        conn.call(move |conn| {
            // Unlike a raw `BEGIN`/`COMMIT` pair, a `Transaction` rolls back when dropped, so an
            // early `?` return cannot leave the connection stuck mid-transaction.
            let tx = conn.transaction()?;
            tx.execute_batch(&create_table)?;
            for index_stmt in &create_indexes {
                tx.execute_batch(index_stmt)?;
            }
            tx.execute_batch(&create_embeddings)?;
            tx.commit()?;
            Ok(())
        })
        .await
        .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;

        Ok(Self {
            conn,
            _phantom: PhantomData,
        })
    }

    /// Consumes the store and pairs it with `model`, producing a searchable index.
    pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
        SqliteVectorIndex::new(model, self)
    }

    /// Inserts `documents` and their embeddings using the caller's transaction `txn`.
    ///
    /// Each document is written with `INSERT OR REPLACE`, and its embeddings are stored in
    /// `<name>_embeddings` under the document's rowid. When `REPLACE` removes an earlier row
    /// (for example, re-adding a document with the same `id`), that row's embeddings are
    /// deleted too. Any embedding already stored under the newly assigned rowid is also
    /// removed, because SQLite may reuse the rowid of a row it just deleted. Without this,
    /// stale vectors would be orphaned, or the vec0 insert would fail on the rowid.
    ///
    /// NOTE(review): all embeddings of one document share the document's rowid. vec0 rowids
    /// appear to be unique keys, so a document with more than one embedding would likely
    /// fail on its second insert. TODO: confirm; fixing it needs a schema change.
    ///
    /// Returns the rowid of the last inserted document, or `0` if `documents` is empty.
    ///
    /// # Errors
    /// Propagates any SQLite error. Nothing is committed here; the caller owns `txn`.
    pub fn add_rows_with_txn(
        &self,
        txn: &rusqlite::Transaction<'_>,
        documents: Vec<(T, OneOrMany<Embedding>)>,
    ) -> Result<i64, tokio_rusqlite::Error> {
        info!("Adding {} documents to store", documents.len());
        if documents.is_empty() {
            return Ok(0);
        }
        let table_name = T::name();

        // Statements reused for every document.
        let mut select_rowids =
            txn.prepare(&format!("SELECT rowid FROM {} WHERE id = ?1", table_name))?;
        let mut row_exists = txn.prepare(&format!(
            "SELECT EXISTS(SELECT 1 FROM {} WHERE rowid = ?1)",
            table_name
        ))?;
        let mut delete_embedding =
            txn.prepare(&format!("DELETE FROM {}_embeddings WHERE rowid = ?1", table_name))?;
        let mut insert_embedding = txn.prepare(&format!(
            "INSERT INTO {}_embeddings (rowid, embedding) VALUES (?1, ?2)",
            table_name
        ))?;

        let mut last_id = 0;
        for (doc, embeddings) in &documents {
            let doc_id = doc.id();
            debug!("Storing document with id {}", doc_id);

            // Rows currently holding this id; the REPLACE below may delete them.
            let previous_rowids = select_rowids
                .query_map(rusqlite::params![doc_id], |row| row.get::<_, i64>(0))?
                .collect::<Result<Vec<_>, _>>()?;

            let values = doc.column_values();
            let columns = values
                .iter()
                .map(|(col, _)| *col)
                .collect::<Vec<_>>()
                .join(", ");
            let placeholders = (1..=values.len())
                .map(|i| format!("?{}", i))
                .collect::<Vec<_>>()
                .join(", ");
            let insert_sql = format!(
                "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
                table_name, columns, placeholders
            );
            txn.execute(
                &insert_sql,
                rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())),
            )?;
            last_id = txn.last_insert_rowid();

            // Drop embeddings that no longer belong to a live row (see the doc comment).
            delete_embedding.execute(rusqlite::params![last_id])?;
            for old_rowid in previous_rowids {
                if old_rowid == last_id {
                    continue;
                }
                let still_present: bool =
                    row_exists.query_row(rusqlite::params![old_rowid], |row| row.get(0))?;
                if !still_present {
                    delete_embedding.execute(rusqlite::params![old_rowid])?;
                }
            }

            for (i, embedding) in embeddings.iter().enumerate() {
                // vec0 `float[N]` columns take the f32 values as a raw byte blob.
                let bytes = serialize_embedding(embedding).as_bytes().to_vec();
                debug!(
                    "Storing embedding {} of {} (size: {} bytes)",
                    i + 1,
                    embeddings.len(),
                    bytes.len()
                );
                let blob = rusqlite::types::Value::Blob(bytes);
                insert_embedding.execute(rusqlite::params![last_id, blob])?;
            }
        }
        Ok(last_id)
    }

    /// Inserts `documents` and their embeddings in a single transaction on the store's
    /// connection. See [`Self::add_rows_with_txn`] for the insertion semantics.
    ///
    /// Returns the rowid of the last inserted document, or `0` if `documents` is empty.
    ///
    /// # Errors
    /// Returns [`VectorStoreError::DatastoreError`] if any statement fails. The transaction
    /// is then dropped without committing.
    pub async fn add_rows(
        &self,
        documents: Vec<(T, OneOrMany<Embedding>)>,
    ) -> Result<i64, VectorStoreError> {
        // The closure handed to `call` must own what it uses, so it gets its own store handle;
        // `documents` is already owned and is simply moved in.
        let this = self.clone();
        self.conn
            .call(move |conn| {
                let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?;
                let last_id = this.add_rows_with_txn(&tx, documents)?;
                tx.commit().map_err(tokio_rusqlite::Error::from)?;
                Ok(last_id)
            })
            .await
            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
    }
}
/// Searchable view over a [`SqliteVectorStore`], paired with the embedding model used to
/// embed incoming queries.
pub struct SqliteVectorIndex<E, T>
where
    E: EmbeddingModel + 'static,
    T: SqliteVectorStoreTable + 'static,
{
    store: SqliteVectorStore<E, T>,
    embedding_model: E,
}

impl<E, T> SqliteVectorIndex<E, T>
where
    E: EmbeddingModel + 'static,
    T: SqliteVectorStoreTable,
{
    /// Wraps `store` so it can be searched with queries embedded by `embedding_model`.
    pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
        SqliteVectorIndex {
            embedding_model,
            store,
        }
    }
}
impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
    for SqliteVectorIndex<E, T>
{
    /// Embeds `query` and returns up to `n` nearest documents as `(distance, id, document)`,
    /// ordered by ascending distance.
    ///
    /// Every column of `T::schema()` is read into a JSON object keyed by column name and
    /// deserialized into `D`. TEXT becomes a JSON string, INTEGER/REAL a JSON number and
    /// NULL becomes `null`. The id is taken from the table's `id` column. Rows that fail to
    /// deserialize are skipped (logged at debug level), so fewer than `n` results may come back.
    ///
    /// # Errors
    /// Returns an error if embedding the query fails or the SQL query fails.
    async fn top_n<D: for<'a> Deserialize<'a>>(
        &self,
        query: &str,
        n: usize,
    ) -> Result<Vec<(f64, String, D)>, VectorStoreError> {
        debug!("Finding top {} matches for query", n);
        let embedding = self.embedding_model.embed_text(query).await?;
        let query_vec: Vec<f32> = serialize_embedding(&embedding);
        let table_name = T::name();
        let column_names: Vec<&'static str> =
            T::schema().iter().map(|column| column.name).collect();

        let rows = self
            .store
            .conn
            .call(move |conn| {
                // Column 0 is `d.id`, then every schema column, then the distance. Each document
                // column is qualified with `d.` so names like `distance` or `embedding` cannot
                // collide with columns of the vec0 table.
                let select_list = std::iter::once("d.id".to_string())
                    .chain(column_names.iter().map(|name| format!("d.{}", name)))
                    .chain(std::iter::once("e.distance".to_string()))
                    .collect::<Vec<_>>()
                    .join(", ");
                let mut stmt = conn.prepare(&format!(
                    "SELECT {}
                    FROM {}_embeddings e
                    JOIN {} d ON e.rowid = d.rowid
                    WHERE e.embedding MATCH ?1 AND k = ?2
                    ORDER BY e.distance",
                    select_list, table_name, table_name
                ))?;
                let rows = stmt
                    .query_map(rusqlite::params![query_vec.as_bytes().to_vec(), n], |row| {
                        let id: String = row.get(0)?;
                        let mut map = serde_json::Map::new();
                        for (i, col_name) in column_names.iter().enumerate() {
                            map.insert(col_name.to_string(), sql_value_to_json(row, i + 1)?);
                        }
                        let distance: f64 = row.get(column_names.len() + 1)?;
                        Ok((id, serde_json::Value::Object(map), distance))
                    })?
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(rows)
            })
            .await
            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;

        debug!("Found {} potential matches", rows.len());
        let mut top_n = Vec::with_capacity(rows.len());
        for (id, doc_value, distance) in rows {
            match serde_json::from_value::<D>(doc_value) {
                Ok(doc) => top_n.push((distance, id, doc)),
                Err(e) => debug!("Failed to deserialize document {}: {}", id, e),
            }
        }
        debug!("Returning {} matches", top_n.len());
        Ok(top_n)
    }

    /// Embeds `query` and returns up to `n` nearest `(distance, id)` pairs, ordered by
    /// ascending distance.
    ///
    /// # Errors
    /// Returns an error if embedding the query fails or the SQL query fails.
    async fn top_n_ids(
        &self,
        query: &str,
        n: usize,
    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
        debug!("Finding top {} document IDs for query", n);
        let embedding = self.embedding_model.embed_text(query).await?;
        let query_vec = serialize_embedding(&embedding);
        let table_name = T::name();
        let results = self
            .store
            .conn
            .call(move |conn| {
                let mut stmt = conn.prepare(&format!(
                    "SELECT d.id, e.distance
                    FROM {0}_embeddings e
                    JOIN {0} d ON e.rowid = d.rowid
                    WHERE e.embedding MATCH ?1 AND k = ?2
                    ORDER BY e.distance",
                    table_name
                ))?;
                // Encode the query exactly as stored embeddings and `top_n` do (`as_bytes`),
                // rather than a separate little-endian path that diverges on big-endian hosts.
                let results = stmt
                    .query_map(rusqlite::params![query_vec.as_bytes().to_vec(), n], |row| {
                        Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?))
                    })?
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(results)
            })
            .await
            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
        debug!("Found {} matching document IDs", results.len());
        Ok(results)
    }
}

/// Converts column `idx` of a result row into JSON for deserialization in `top_n`.
///
/// TEXT maps to a JSON string, INTEGER/REAL to a JSON number (non-finite reals become
/// `null`), and NULL to `null`. BLOBs are still rejected with rusqlite's error for reading
/// a blob as `String`.
fn sql_value_to_json(row: &rusqlite::Row<'_>, idx: usize) -> rusqlite::Result<serde_json::Value> {
    use rusqlite::types::ValueRef;
    let value = match row.get_ref(idx)? {
        ValueRef::Null => serde_json::Value::Null,
        ValueRef::Integer(int) => serde_json::Value::from(int),
        ValueRef::Real(real) => serde_json::Number::from_f64(real)
            .map_or(serde_json::Value::Null, serde_json::Value::Number),
        // Reading as `String` validates UTF-8 for text and errors for blobs.
        ValueRef::Text(_) | ValueRef::Blob(_) => serde_json::Value::String(row.get(idx)?),
    };
    Ok(value)
}
/// Narrows each component of `embedding` to `f32`, the element type of the `vec0`
/// `float[N]` column the vectors are stored in.
fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
    let mut narrowed = Vec::with_capacity(embedding.vec.len());
    narrowed.extend(embedding.vec.iter().map(|&component| component as f32));
    narrowed
}
/// Strings are declared as `TEXT` and bound as their own contents.
impl ColumnValue for String {
    fn column_type(&self) -> &'static str {
        "TEXT"
    }

    fn to_sql_string(&self) -> String {
        self.as_str().to_owned()
    }
}
#[cfg(test)]
mod tests {
    // `super::*` already brings in `Column`, `ColumnValue`, `SqliteVectorStore`,
    // `SqliteVectorStoreTable`, `Connection` and `Deserialize`.
    use super::*;
    use bep::{
        embeddings::EmbeddingsBuilder,
        providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
        Embed,
    };
    use rusqlite::ffi::sqlite3_auto_extension;
    use sqlite_vec::sqlite3_vec_init;

    /// Minimal document: `content` is embedded, `id` is the table's primary key.
    #[derive(Embed, Clone, Debug, Deserialize)]
    struct TestDocument {
        id: String,
        #[embed]
        content: String,
    }

    impl SqliteVectorStoreTable for TestDocument {
        fn name() -> &'static str {
            "test_documents"
        }

        fn schema() -> Vec<Column> {
            vec![
                Column::new("id", "TEXT PRIMARY KEY"),
                Column::new("content", "TEXT"),
            ]
        }

        fn id(&self) -> String {
            self.id.clone()
        }

        fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
            vec![
                ("id", Box::new(self.id.clone())),
                ("content", Box::new(self.content.clone())),
            ]
        }
    }

    /// End-to-end check: embed two documents with OpenAI, store them in an in-memory
    /// database, and verify that both search APIs rank the exact-match document first.
    ///
    /// Requires the `OPENAI_API_KEY` environment variable and network access.
    #[tokio::test]
    async fn test_vector_search() -> Result<(), anyhow::Error> {
        // SAFETY: `sqlite3_vec_init` is sqlite-vec's extension entry point. Per the SQLite
        // docs, `sqlite3_auto_extension` takes entry points as an untyped function pointer
        // (hence the transmute) and invokes them for each connection opened afterwards.
        unsafe {
            sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ())));
        }

        let conn = Connection::open(":memory:").await?;
        let openai_api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
        let openai_client = Client::new(&openai_api_key);
        let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

        let documents = vec![
            TestDocument {
                id: "doc0".to_string(),
                content: "The quick brown fox jumps over the lazy dog".to_string(),
            },
            TestDocument {
                id: "doc1".to_string(),
                content: "The lazy dog sleeps while the quick brown fox runs".to_string(),
            },
        ];
        let embeddings = EmbeddingsBuilder::new(model.clone())
            .documents(documents)?
            .build()
            .await?;

        let vector_store = SqliteVectorStore::new(conn, &model).await?;
        vector_store.add_rows(embeddings).await?;
        let index = vector_store.index(model);

        let results = index
            .top_n::<TestDocument>("The quick brown fox jumps over the lazy dog", 1)
            .await?;
        assert_eq!(results.len(), 1);
        // The query is doc0's exact text, so doc0 must be the nearest neighbour.
        let (_, top_id, top_doc) = &results[0];
        assert_eq!(top_id, "doc0");
        assert_eq!(top_doc.id, "doc0");
        assert_eq!(
            top_doc.content,
            "The quick brown fox jumps over the lazy dog"
        );

        let id_results = index
            .top_n_ids("The quick brown fox jumps over the lazy dog", 1)
            .await?;
        assert_eq!(id_results.len(), 1);
        assert_eq!(id_results[0].1, "doc0");
        Ok(())
    }
}