// engram-storage 0.3.0
//
// SQLite storage with FTS5.
use rusqlite::params;
use serde::{Deserialize, Serialize};

use crate::database::Database;
use crate::error::StorageError;

/// One row of the `memories` table.
///
/// Rows are read with [`row_to_memory`] and written with
/// `Database::insert_memory` / `Database::bulk_insert_memories`; each field
/// corresponds to the column of the same name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Memory {
    /// Identifier of the memory; every lookup, update and delete in this module keys on it.
    pub id: String,
    /// Kind of memory; the allowed values are not constrained in this module.
    pub memory_type: String,
    /// Context text of the memory.
    pub context: String,
    /// Action text of the memory.
    pub action: String,
    /// Result text of the memory.
    pub result: String,
    /// Score value; overwritten by `Database::set_memory_score`.
    pub score: f32,
    /// Raw embedding bytes for `context`, if computed (byte encoding defined by the caller — TODO confirm).
    pub embedding_context: Option<Vec<u8>>,
    /// Raw embedding bytes for `action`, if computed (byte encoding defined by the caller — TODO confirm).
    pub embedding_action: Option<Vec<u8>>,
    /// Raw embedding bytes for `result`, if computed (byte encoding defined by the caller — TODO confirm).
    pub embedding_result: Option<Vec<u8>>,
    /// Whether the memory has been indexed; `Database::get_unindexed_memories`
    /// selects rows where this is false.
    pub indexed: bool,
    /// Optional tags, stored as a single string (separator/format not defined here — TODO confirm).
    pub tags: Option<String>,
    /// Optional project name.
    pub project: Option<String>,
    /// Optional parent reference (presumably another memory's `id` — verify against callers).
    pub parent_id: Option<String>,
    /// Optional source ids, stored as a single string (encoding not defined here — TODO confirm).
    pub source_ids: Option<String>,
    /// Optional insight category.
    pub insight_type: Option<String>,
    /// Creation timestamp as a string; the format is chosen by the caller.
    pub created_at: String,
    /// Last-update timestamp as a string; the format is chosen by the caller.
    pub updated_at: String,
    /// Usage counter; incremented by `Database::touch_memory`.
    pub used_count: i64,
    /// Timestamp of the last use; set by `Database::touch_memory`.
    pub last_used_at: Option<String>,
    /// Set by `Database::set_superseded_by` (presumably the id of the replacing memory);
    /// `Database::list_all_memories` skips rows where this is non-null.
    pub superseded_by: Option<String>,
}

/// Converts a row from the `memories` table into a [`Memory`].
///
/// Columns are fetched by name, not position, so the row may come from a
/// `SELECT *` regardless of the table's physical column order.
///
/// # Errors
///
/// Returns the first `rusqlite::Error` raised while fetching or converting a
/// column (for example a missing column or a type mismatch).
pub fn row_to_memory(row: &rusqlite::Row) -> rusqlite::Result<Memory> {
    let id: String = row.get("id")?;
    let memory_type: String = row.get("memory_type")?;
    let context: String = row.get("context")?;
    let action: String = row.get("action")?;
    let result: String = row.get("result")?;
    let score: f32 = row.get("score")?;
    let embedding_context: Option<Vec<u8>> = row.get("embedding_context")?;
    let embedding_action: Option<Vec<u8>> = row.get("embedding_action")?;
    let embedding_result: Option<Vec<u8>> = row.get("embedding_result")?;
    let indexed: bool = row.get("indexed")?;
    let tags: Option<String> = row.get("tags")?;
    let project: Option<String> = row.get("project")?;
    let parent_id: Option<String> = row.get("parent_id")?;
    let source_ids: Option<String> = row.get("source_ids")?;
    let insight_type: Option<String> = row.get("insight_type")?;
    let created_at: String = row.get("created_at")?;
    let updated_at: String = row.get("updated_at")?;
    let used_count: i64 = row.get("used_count")?;
    let last_used_at: Option<String> = row.get("last_used_at")?;
    let superseded_by: Option<String> = row.get("superseded_by")?;

    Ok(Memory {
        id,
        memory_type,
        context,
        action,
        result,
        score,
        embedding_context,
        embedding_action,
        embedding_result,
        indexed,
        tags,
        project,
        parent_id,
        source_ids,
        insight_type,
        created_at,
        updated_at,
        used_count,
        last_used_at,
        superseded_by,
    })
}

/// Parameterized `INSERT` for a single row of `memories`.
///
/// Placeholders `?1`..`?20` follow the field order of [`Memory`]; every caller
/// must bind parameters in exactly that order.
const INSERT_SQL: &str = r#"
    INSERT INTO memories (
        id, memory_type, context, action, result, score,
        embedding_context, embedding_action, embedding_result,
        indexed, tags, project, parent_id, source_ids, insight_type,
        created_at, updated_at, used_count, last_used_at, superseded_by
    ) VALUES (
        ?1, ?2, ?3, ?4, ?5, ?6,
        ?7, ?8, ?9,
        ?10, ?11, ?12, ?13, ?14, ?15,
        ?16, ?17, ?18, ?19, ?20
    )
"#;

impl Database {
    pub fn insert_memory(&self, memory: &Memory) -> Result<(), StorageError> {
        self.connection()
            .execute(
                INSERT_SQL,
                params![
                    memory.id,
                    memory.memory_type,
                    memory.context,
                    memory.action,
                    memory.result,
                    memory.score,
                    memory.embedding_context,
                    memory.embedding_action,
                    memory.embedding_result,
                    memory.indexed,
                    memory.tags,
                    memory.project,
                    memory.parent_id,
                    memory.source_ids,
                    memory.insight_type,
                    memory.created_at,
                    memory.updated_at,
                    memory.used_count,
                    memory.last_used_at,
                    memory.superseded_by,
                ],
            )
            .map_err(|error| match error {
                rusqlite::Error::SqliteFailure(sql_error, _)
                    if sql_error.extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_PRIMARYKEY
                        || sql_error.extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_UNIQUE =>
                {
                    StorageError::DuplicateKey(format!("memory id={}", memory.id))
                }
                other => StorageError::Sqlite(other),
            })?;
        Ok(())
    }

    pub fn get_memory(&self, id: &str) -> Result<Memory, StorageError> {
        self.connection()
            .query_row(
                "SELECT * FROM memories WHERE id = ?1",
                params![id],
                row_to_memory,
            )
            .map_err(|error| match error {
                rusqlite::Error::QueryReturnedNoRows => {
                    StorageError::NotFound(format!("memory id={id}"))
                }
                other => StorageError::Sqlite(other),
            })
    }

    pub fn set_memory_indexed(&self, id: &str, indexed: bool) -> Result<(), StorageError> {
        let affected = self.connection().execute(
            "UPDATE memories SET indexed = ?1 WHERE id = ?2",
            params![indexed, id],
        )?;
        if affected == 0 {
            return Err(StorageError::NotFound(format!("memory id={id}")));
        }
        Ok(())
    }

    pub fn set_memory_embeddings(
        &self,
        id: &str,
        embedding_context: &[u8],
        embedding_action: &[u8],
        embedding_result: &[u8],
    ) -> Result<(), StorageError> {
        let affected = self.connection().execute(
            "UPDATE memories
             SET embedding_context = ?1,
                 embedding_action = ?2,
                 embedding_result = ?3
             WHERE id = ?4",
            params![embedding_context, embedding_action, embedding_result, id],
        )?;
        if affected == 0 {
            return Err(StorageError::NotFound(format!("memory id={id}")));
        }
        Ok(())
    }

    pub fn set_memory_score(&self, id: &str, score: f32) -> Result<(), StorageError> {
        let affected = self.connection().execute(
            "UPDATE memories SET score = ?1 WHERE id = ?2",
            params![score, id],
        )?;
        if affected == 0 {
            return Err(StorageError::NotFound(format!("memory id={id}")));
        }
        Ok(())
    }

    pub fn touch_memory(&self, id: &str, timestamp: &str) -> Result<(), StorageError> {
        let affected = self.connection().execute(
            "UPDATE memories SET used_count = used_count + 1, last_used_at = ?1 WHERE id = ?2",
            params![timestamp, id],
        )?;
        if affected == 0 {
            return Err(StorageError::NotFound(format!("memory id={id}")));
        }
        Ok(())
    }

    pub fn set_superseded_by(&self, id: &str, superseded_by: &str) -> Result<(), StorageError> {
        let affected = self.connection().execute(
            "UPDATE memories SET superseded_by = ?1 WHERE id = ?2",
            params![superseded_by, id],
        )?;
        if affected == 0 {
            return Err(StorageError::NotFound(format!("memory id={id}")));
        }
        Ok(())
    }

    pub fn delete_memory(&self, id: &str) -> Result<(), StorageError> {
        let affected = self
            .connection()
            .execute("DELETE FROM memories WHERE id = ?1", params![id])?;
        if affected == 0 {
            return Err(StorageError::NotFound(format!("memory id={id}")));
        }
        Ok(())
    }

    pub fn bulk_insert_memories(&self, memories: &[Memory]) -> Result<usize, StorageError> {
        let transaction = self.connection().unchecked_transaction()?;
        let mut statement = transaction.prepare(INSERT_SQL)?;
        let mut count = 0;
        for memory in memories {
            statement.execute(params![
                memory.id,
                memory.memory_type,
                memory.context,
                memory.action,
                memory.result,
                memory.score,
                memory.embedding_context,
                memory.embedding_action,
                memory.embedding_result,
                memory.indexed,
                memory.tags,
                memory.project,
                memory.parent_id,
                memory.source_ids,
                memory.insight_type,
                memory.created_at,
                memory.updated_at,
                memory.used_count,
                memory.last_used_at,
                memory.superseded_by,
            ])?;
            count += 1;
        }
        drop(statement);
        transaction.commit()?;
        Ok(count)
    }

    pub fn list_all_memories(&self) -> Result<Vec<Memory>, StorageError> {
        let mut statement = self
            .connection()
            .prepare("SELECT * FROM memories WHERE superseded_by IS NULL")?;
        let rows = statement.query_map([], row_to_memory)?;
        let mut results = Vec::new();
        for row in rows {
            results.push(row?);
        }
        Ok(results)
    }

    pub fn get_unindexed_memories(&self, limit: usize) -> Result<Vec<Memory>, StorageError> {
        let mut statement = self
            .connection()
            .prepare("SELECT * FROM memories WHERE indexed = FALSE LIMIT ?1")?;
        let rows = statement.query_map(params![limit as i64], row_to_memory)?;
        let mut results = Vec::new();
        for row in rows {
            results.push(row?);
        }
        Ok(results)
    }

    pub fn get_indexed_memory_ids(&self) -> Result<Vec<String>, StorageError> {
        let mut statement = self
            .connection()
            .prepare("SELECT id FROM memories WHERE indexed = TRUE")?;
        let rows = statement.query_map([], |row| row.get::<_, String>(0))?;
        let mut results = Vec::new();
        for row in rows {
            results.push(row?);
        }
        Ok(results)
    }
}