// abu-rag 0.2.0
//
// Retrieval-Augmented Generation system
// Documentation
use std::{borrow::Cow, marker::PhantomData, path::Path, sync::Mutex};
use rusqlite::{params, Connection};
use serde::{de::DeserializeOwned, Serialize};
use crate::vectordb::VectorId;
use super::VectorStorage;

/// Errors returned by [`SqliteStorage`] operations.
///
/// Both variants are produced automatically by `?` through their `#[from]` conversions.
#[derive(Debug, thiserror::Error)]
pub enum SqliteStorageError {
    /// A `rusqlite` call (opening the database, preparing or executing a statement) failed.
    #[error("SQLite database error: {0}")]
    Db(#[from] rusqlite::Error),

    /// `serde_json` failed to encode a payload to text or decode stored text into a payload.
    #[error("Serialization error: {0}")]
    Serde(#[from] serde_json::Error),
}

/// Payload storage backed by a SQLite database.
///
/// `P` is the payload type. No value of type `P` is held by the struct itself; the type
/// parameter is only recorded through [`PhantomData`].
pub struct SqliteStorage<P> {
    /// SQLite connection, guarded by a mutex.
    conn: Mutex<Connection>,
    /// Zero-sized marker that ties this storage to its payload type `P`.
    _marker: PhantomData<P>,
}

impl<P> SqliteStorage<P>
where
    P: Clone + 'static + Send + Sync,
{
    /// Opens the SQLite database at `db_path` and makes sure the `vector_store` table exists.
    ///
    /// The table has an integer primary key `id` and a non-null text column `payload`.
    /// It is created with `CREATE TABLE IF NOT EXISTS`, so an existing table is left untouched.
    ///
    /// # Errors
    ///
    /// Returns [`SqliteStorageError::Db`] if the database cannot be opened or the
    /// table-creation statement fails.
    pub fn open(db_path: impl AsRef<Path>) -> Result<Self, SqliteStorageError> {
        // Schema statement; run on every open, a no-op when the table is already there.
        const CREATE_TABLE_SQL: &str = "CREATE TABLE IF NOT EXISTS vector_store (
                id INTEGER PRIMARY KEY,
                payload TEXT NOT NULL
            )";

        let connection = Connection::open(db_path)?;
        connection.execute(CREATE_TABLE_SQL, [])?;

        let storage = Self {
            conn: Mutex::new(connection),
            _marker: PhantomData,
        };
        Ok(storage)
    }
}

/// [`VectorStorage`] implementation that keeps each payload as a JSON string in the
/// `vector_store` table, keyed by its vector id.
///
/// None of the methods panic on a poisoned connection mutex. Methods taking `&mut self`
/// reach the connection through [`Mutex::get_mut`], which needs no locking because the
/// exclusive borrow already rules out concurrent access. [`get`](Self::get) locks the mutex
/// and, if it is poisoned, takes the guard anyway: within this impl the only code that
/// holds the lock is a read-only query, so poisoning does not imply a half-applied write.
impl<P> VectorStorage for SqliteStorage<P>
where
    P: Clone + 'static + Send + Sync + Serialize + DeserializeOwned,
{
    type Error = SqliteStorageError;
    type Payload = P;

    /// Stores `payload` under `id`, replacing any payload already stored for that id.
    ///
    /// # Errors
    ///
    /// Returns [`SqliteStorageError::Serde`] if the payload cannot be serialized to JSON and
    /// [`SqliteStorageError::Db`] if the insert fails.
    async fn add(&mut self, id: VectorId, payload: Self::Payload) -> Result<(), Self::Error> {
        let payload_str = serde_json::to_string(&payload)?;

        // `&mut self` guarantees exclusive access, so the mutex is bypassed rather than locked.
        let conn = self
            .conn
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        conn.execute(
            "INSERT OR REPLACE INTO vector_store (id, payload) VALUES (?1, ?2)",
            params![id, payload_str],
        )?;

        Ok(())
    }

    /// Stores every `(id, payload)` pair inside a single transaction, replacing existing rows.
    ///
    /// `ids` and `payloads` are paired positionally with `zip`. If one sequence is longer
    /// than the other, the surplus items are ignored without an error, so callers should
    /// pass sequences of equal length.
    ///
    /// # Errors
    ///
    /// Returns [`SqliteStorageError::Serde`] if a payload cannot be serialized and
    /// [`SqliteStorageError::Db`] if any database step fails. On error the function returns
    /// before `commit`, so the transaction is dropped uncommitted (rusqlite rolls a
    /// transaction back on drop by default).
    async fn add_batch<IS, PS>(&mut self, ids: IS, payloads: PS) -> Result<(), Self::Error>
    where
        IS: IntoIterator<Item = VectorId>,
        PS: IntoIterator<Item = Self::Payload>,
    {
        let conn = self
            .conn
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let tx = conn.transaction()?;
        {
            // Inner scope: the statement borrows `tx` and must be dropped before `commit`.
            let mut stmt =
                tx.prepare("INSERT OR REPLACE INTO vector_store (id, payload) VALUES (?1, ?2)")?;
            for (id, payload) in ids.into_iter().zip(payloads) {
                let payload_str = serde_json::to_string(&payload)?;
                stmt.execute(params![id, payload_str])?;
            }
        }
        tx.commit()?;

        Ok(())
    }

    /// Deletes the row stored under `id`. Succeeds even when no such row exists.
    ///
    /// # Errors
    ///
    /// Returns [`SqliteStorageError::Db`] if the delete statement fails.
    async fn delete(&mut self, id: VectorId) -> Result<(), Self::Error> {
        let conn = self
            .conn
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        conn.execute("DELETE FROM vector_store WHERE id = ?1", params![id])?;
        Ok(())
    }

    /// Deletes the rows for all `ids` inside a single transaction.
    ///
    /// # Errors
    ///
    /// Returns [`SqliteStorageError::Db`] if any database step fails. On error the function
    /// returns before `commit`, so the transaction is dropped uncommitted (rusqlite rolls a
    /// transaction back on drop by default).
    async fn delete_batch<IS>(&mut self, ids: IS) -> Result<(), Self::Error>
    where
        IS: IntoIterator<Item = VectorId>,
    {
        let conn = self
            .conn
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let tx = conn.transaction()?;
        {
            // Inner scope: the statement borrows `tx` and must be dropped before `commit`.
            let mut stmt = tx.prepare("DELETE FROM vector_store WHERE id = ?1")?;
            for id in ids {
                stmt.execute(params![id])?;
            }
        }
        tx.commit()?;

        Ok(())
    }

    /// Looks up the payload stored under `id`.
    ///
    /// Returns `Ok(None)` when no row exists for `id`. A found payload is deserialized from
    /// its JSON text and returned as `Cow::Owned`.
    ///
    /// # Errors
    ///
    /// Returns [`SqliteStorageError::Db`] if the query fails and
    /// [`SqliteStorageError::Serde`] if the stored text cannot be deserialized into `P`.
    async fn get(&self, id: VectorId) -> Result<Option<Cow<Self::Payload>>, Self::Error> {
        // Shared access needs the lock; a poisoned mutex is recovered instead of panicking.
        let conn = self
            .conn
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Lookups are expected to be frequent, so reuse the prepared statement across calls.
        let mut stmt = conn.prepare_cached("SELECT payload FROM vector_store WHERE id = ?1")?;
        // `rows` is a named local (declared after `stmt`) so it is dropped before the
        // statement it borrows, regardless of the edition's temporary-drop rules.
        let mut rows = stmt.query(params![id])?;
        let Some(row) = rows.next()? else {
            return Ok(None);
        };

        let payload_str: String = row.get(0)?;
        let payload: P = serde_json::from_str(&payload_str)?;
        Ok(Some(Cow::Owned(payload)))
    }

    /// Deletes every row from the `vector_store` table.
    ///
    /// # Errors
    ///
    /// Returns [`SqliteStorageError::Db`] if the delete statement fails.
    async fn clear(&mut self) -> Result<(), Self::Error> {
        let conn = self
            .conn
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        conn.execute("DELETE FROM vector_store", [])?;
        Ok(())
    }
}