funes-memory 1.0.0

Your machine's memory, queryable. Local AI memory for the terminal.
use anyhow::Result;
use rusqlite::{params, Connection};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Chunk {
    pub id: String,
    pub source_path: String,
    pub content: String,
    pub embedding: Vec<f32>,
    pub timestamp: i64,
    pub chunk_type: String,
}

pub struct Store {
    conn: Connection,
    pub path: PathBuf,
}

impl Store {
    pub fn new(db_path: &Path) -> Result<Self> {
        if let Some(parent) = db_path.parent() {
            std::fs::create_dir_all(parent)?;
        }

        let conn = Connection::open(db_path)?;

        conn.execute_batch("
            CREATE TABLE IF NOT EXISTS chunks (
                id TEXT PRIMARY KEY,
                source_path TEXT NOT NULL,
                content TEXT NOT NULL,
                embedding TEXT NOT NULL,
                timestamp INTEGER NOT NULL,
                chunk_type TEXT NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_source_path ON chunks(source_path);
            CREATE INDEX IF NOT EXISTS idx_timestamp ON chunks(timestamp);
        ")?;

        Ok(Store { conn, path: db_path.to_path_buf() })
    }

    pub fn insert(&self, chunk: &Chunk) -> Result<()> {
        let embedding_json = serde_json::to_string(&chunk.embedding)?;

        self.conn.execute(
            "INSERT OR REPLACE INTO chunks
             (id, source_path, content, embedding, timestamp, chunk_type)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
            params![
                chunk.id,
                chunk.source_path,
                chunk.content,
                embedding_json,
                chunk.timestamp,
                chunk.chunk_type,
            ],
        )?;

        Ok(())
    }

    pub fn get_all(&self) -> Result<Vec<Chunk>> {
        let mut stmt = self.conn.prepare(
            "SELECT id, source_path, content, embedding, timestamp, chunk_type
             FROM chunks"
        )?;

        let chunks = stmt.query_map([], |row| {
            Ok((
                row.get::<_, String>(0)?,
                row.get::<_, String>(1)?,
                row.get::<_, String>(2)?,
                row.get::<_, String>(3)?,
                row.get::<_, i64>(4)?,
                row.get::<_, String>(5)?,
            ))
        })?
        .filter_map(|r| r.ok())
        .filter_map(|(id, source_path, content, embedding_json, timestamp, chunk_type)| {
            let embedding: Vec<f32> = serde_json::from_str(&embedding_json).ok()?;
            Some(Chunk { id, source_path, content, embedding, timestamp, chunk_type })
        })
        .collect();

        Ok(chunks)
    }

    pub fn delete_by_path(&self, path: &str) -> Result<()> {
        self.conn.execute(
            "DELETE FROM chunks WHERE source_path = ?1",
            params![path],
        )?;
        Ok(())
    }

    pub fn count(&self) -> Result<i64> {
        let count: i64 = self.conn.query_row(
            "SELECT COUNT(*) FROM chunks",
            [],
            |row| row.get(0),
        )?;
        Ok(count)
    }

    pub fn clear(&self) -> Result<()> {
        self.conn.execute_batch("DELETE FROM chunks;")?;
        Ok(())
    }
}   

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    fn temp_store() -> Store {
        let path = PathBuf::from(format!(
            "test_{}.db",
            uuid::Uuid::new_v4()
        ));
        Store::new(&path).unwrap()
    }

    fn cleanup(store: &Store) {
        let path = store.path.clone();
        drop(store);
        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_insert_and_count() {
        let store = temp_store();
        let chunk = Chunk {
            id: "test-1".to_string(),
            source_path: "/test/file.rs".to_string(),
            content: "fn main() {}".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
            timestamp: 0,
            chunk_type: "code".to_string(),
        };
        store.insert(&chunk).unwrap();
        assert_eq!(store.count().unwrap(), 1);
    }

    #[test]
    fn test_get_all_returns_inserted() {
        let store = temp_store();
        let chunk = Chunk {
            id: "test-2".to_string(),
            source_path: "/test/file.md".to_string(),
            content: "hello world".to_string(),
            embedding: vec![0.5, 0.5],
            timestamp: 0,
            chunk_type: "markdown".to_string(),
        };
        store.insert(&chunk).unwrap();
        let all = store.get_all().unwrap();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].content, "hello world");
    }

    #[test]
    fn test_delete_by_path() {
        let store = temp_store();
        let chunk = Chunk {
            id: "test-3".to_string(),
            source_path: "/test/delete.rs".to_string(),
            content: "to be deleted".to_string(),
            embedding: vec![0.1],
            timestamp: 0,
            chunk_type: "code".to_string(),
        };
        store.insert(&chunk).unwrap();
        assert_eq!(store.count().unwrap(), 1);
        store.delete_by_path("/test/delete.rs").unwrap();
        assert_eq!(store.count().unwrap(), 0);
    }

    #[test]
    fn test_clear() {
        let store = temp_store();
        for i in 0..5 {
            let chunk = Chunk {
                id: format!("test-{}", i),
                source_path: "/test/file.rs".to_string(),
                content: format!("content {}", i),
                embedding: vec![0.1],
                timestamp: 0,
                chunk_type: "code".to_string(),
            };
            store.insert(&chunk).unwrap();
        }
        assert_eq!(store.count().unwrap(), 5);
        store.clear().unwrap();
        assert_eq!(store.count().unwrap(), 0);
    }

    #[test]
    fn test_insert_or_replace() {
        let store = temp_store();
        let chunk = Chunk {
            id: "same-id".to_string(),
            source_path: "/test/file.rs".to_string(),
            content: "original".to_string(),
            embedding: vec![0.1],
            timestamp: 0,
            chunk_type: "code".to_string(),
        };
        store.insert(&chunk).unwrap();
        let updated = Chunk {
            id: "same-id".to_string(),
            content: "updated".to_string(),
            ..chunk
        };
        store.insert(&updated).unwrap();
        assert_eq!(store.count().unwrap(), 1);
        assert_eq!(store.get_all().unwrap()[0].content, "updated");
    }
}