roboticus-db 0.11.4

SQLite persistence layer with 28 tables, FTS5 search, WAL mode, and migration system
Documentation
//! Tool description embedding cache.
//! Keyed by (tool_name, sha256(description)) — re-embedded only on description change.

use crate::{
    Database, DbResultExt,
    embeddings::{blob_to_embedding, embedding_to_blob},
};
use roboticus_core::Result;
use rusqlite::OptionalExtension;

/// Persist the embedding vector for a tool in the `tool_embeddings` cache.
///
/// The vector is encoded with `embedding_to_blob` and written together with
/// its element count (`dimensions`) using `INSERT OR REPLACE`, so a row that
/// conflicts on the table's uniqueness constraint is overwritten rather than
/// duplicated.
///
/// # Errors
///
/// Propagates any failure from executing the statement, converted through
/// `db_err`.
pub fn save_tool_embedding(
    db: &Database,
    tool_name: &str,
    desc_hash: &str,
    embedding: &[f32],
) -> Result<()> {
    const UPSERT_SQL: &str =
        "INSERT OR REPLACE INTO tool_embeddings (tool_name, description_hash, embedding, dimensions)
         VALUES (?1, ?2, ?3, ?4)";

    // SQLite integers are signed 64-bit, hence the cast of the element count.
    let dimensions = embedding.len() as i64;
    let encoded = embedding_to_blob(embedding);

    db.conn()
        .execute(
            UPSERT_SQL,
            rusqlite::params![tool_name, desc_hash, encoded, dimensions],
        )
        .db_err()?;
    Ok(())
}

/// Look up a cached embedding vector for a tool.
///
/// Yields `Ok(None)` when no row matches the given `(tool_name, desc_hash)`
/// pair; otherwise the stored blob is decoded with `blob_to_embedding`.
///
/// # Errors
///
/// Propagates query failures (other than "no rows"), converted through
/// `db_err`.
pub fn get_tool_embedding(
    db: &Database,
    tool_name: &str,
    desc_hash: &str,
) -> Result<Option<Vec<f32>>> {
    // `optional()` turns the "no rows returned" error into `None`.
    let stored: Option<Vec<u8>> = db
        .conn()
        .query_row(
            "SELECT embedding FROM tool_embeddings
             WHERE tool_name = ?1 AND description_hash = ?2",
            rusqlite::params![tool_name, desc_hash],
            |row| row.get(0),
        )
        .optional()
        .db_err()?;

    Ok(stored.map(|bytes| blob_to_embedding(&bytes)))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing round-tripped `f32` components.
    const EPSILON: f32 = 1e-6;

    /// Builds an in-memory database holding only the `tool_embeddings` table.
    fn test_db() -> Database {
        let database = Database::new(":memory:").unwrap();
        database
            .conn()
            .execute_batch(
                "CREATE TABLE IF NOT EXISTS tool_embeddings (
                    tool_name TEXT NOT NULL,
                    description_hash TEXT NOT NULL,
                    embedding BLOB NOT NULL,
                    dimensions INTEGER NOT NULL,
                    created_at TEXT NOT NULL DEFAULT (datetime('now')),
                    PRIMARY KEY (tool_name, description_hash)
                )",
            )
            .unwrap();
        database
    }

    /// Asserts that `actual` lies strictly within `EPSILON` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < EPSILON);
    }

    /// A saved vector can be read back under the same name and hash.
    #[test]
    fn save_and_load_tool_embedding() {
        let db = test_db();
        save_tool_embedding(&db, "web_search", "abc123", &[0.1f32, 0.2, 0.3]).unwrap();

        let vector = get_tool_embedding(&db, "web_search", "abc123")
            .unwrap()
            .expect("embedding should be present after saving");
        assert_eq!(vector.len(), 3);
        assert_close(vector[0], 0.1);
    }

    /// A lookup with a different description hash must miss.
    #[test]
    fn stale_embedding_not_returned() {
        let db = test_db();
        save_tool_embedding(&db, "web_search", "old_hash", &[0.1f32, 0.2, 0.3]).unwrap();

        let fetched = get_tool_embedding(&db, "web_search", "new_hash").unwrap();
        assert!(fetched.is_none());
    }

    /// Saving twice under the same key keeps only the latest vector.
    #[test]
    fn insert_or_replace_updates_embedding() {
        let db = test_db();
        let first = [1.0f32, 0.0];
        let second = [0.0f32, 1.0];
        for version in [&first, &second] {
            save_tool_embedding(&db, "bash", "hash1", version).unwrap();
        }

        let vector = get_tool_embedding(&db, "bash", "hash1").unwrap().unwrap();
        assert_close(vector[0], 0.0);
        assert_close(vector[1], 1.0);
    }

    /// Two tools sharing a description hash keep separate entries.
    #[test]
    fn different_tools_same_hash_are_independent() {
        let db = test_db();
        save_tool_embedding(&db, "tool_a", "shared_hash", &[1.0f32, 0.0]).unwrap();
        save_tool_embedding(&db, "tool_b", "shared_hash", &[0.0f32, 1.0]).unwrap();

        let fetch = |tool: &str| {
            get_tool_embedding(&db, tool, "shared_hash")
                .unwrap()
                .unwrap()
        };
        let vector_a = fetch("tool_a");
        let vector_b = fetch("tool_b");
        assert_close(vector_a[0], 1.0);
        assert_close(vector_b[1], 1.0);
    }

    /// Looking up a tool that was never saved yields `None`, not an error.
    #[test]
    fn missing_tool_returns_none() {
        let db = test_db();
        let fetched = get_tool_embedding(&db, "nonexistent_tool", "any_hash").unwrap();
        assert!(fetched.is_none());
    }
}