use crate::{
Database, DbResultExt,
embeddings::{blob_to_embedding, embedding_to_blob},
};
use roboticus_core::Result;
use rusqlite::OptionalExtension;
/// Persists the embedding vector for a tool into the `tool_embeddings` table.
///
/// The vector is serialised with [`embedding_to_blob`] and written together
/// with `tool_name`, `desc_hash` and the number of elements in `embedding`
/// (stored in the `dimensions` column).
///
/// The statement is an `INSERT OR REPLACE`, so a row that conflicts with an
/// existing one is overwritten instead of being rejected. Which columns
/// constitute a conflict is decided by the table schema, which is not defined
/// in this function.
///
/// # Errors
///
/// Any error reported by the insert is converted through
/// [`DbResultExt::db_err`] and returned to the caller.
pub fn save_tool_embedding(
    db: &Database,
    tool_name: &str,
    desc_hash: &str,
    embedding: &[f32],
) -> Result<()> {
    const UPSERT_SQL: &str =
        "INSERT OR REPLACE INTO tool_embeddings (tool_name, description_hash, embedding, dimensions)\n\
         VALUES (?1, ?2, ?3, ?4)";

    // Prepare everything that gets bound before touching the connection.
    let dimensions = embedding.len() as i64;
    let encoded = embedding_to_blob(embedding);

    let conn = db.conn();
    conn.execute(
        UPSERT_SQL,
        rusqlite::params![tool_name, desc_hash, encoded, dimensions],
    )
    .db_err()?;

    Ok(())
}
/// Looks up the stored embedding for `tool_name` at a specific description hash.
///
/// The lookup requires both `tool_name` and `desc_hash` to match, so an
/// embedding saved under a different hash for the same tool is not returned.
///
/// # Returns
///
/// * `Ok(Some(vector))` — a matching row exists; its blob is decoded with
///   [`blob_to_embedding`].
/// * `Ok(None)` — no row matches the given name/hash pair.
///
/// # Errors
///
/// Any query error other than "no rows" is converted through
/// [`DbResultExt::db_err`] and returned to the caller.
pub fn get_tool_embedding(
    db: &Database,
    tool_name: &str,
    desc_hash: &str,
) -> Result<Option<Vec<f32>>> {
    const SELECT_SQL: &str = "SELECT embedding FROM tool_embeddings\n\
                              WHERE tool_name = ?1 AND description_hash = ?2";

    let conn = db.conn();
    // `optional()` turns the "no rows returned" case into `None` instead of an error.
    let stored = conn
        .query_row(
            SELECT_SQL,
            rusqlite::params![tool_name, desc_hash],
            |row| row.get::<_, Vec<u8>>(0),
        )
        .optional()
        .db_err()?;

    Ok(stored.map(|bytes| blob_to_embedding(&bytes)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `actual` is within `1e-6` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    /// Opens an in-memory database and ensures the `tool_embeddings` table exists.
    fn test_db() -> Database {
        let schema = "CREATE TABLE IF NOT EXISTS tool_embeddings (\n\
                      tool_name TEXT NOT NULL,\n\
                      description_hash TEXT NOT NULL,\n\
                      embedding BLOB NOT NULL,\n\
                      dimensions INTEGER NOT NULL,\n\
                      created_at TEXT NOT NULL DEFAULT (datetime('now')),\n\
                      PRIMARY KEY (tool_name, description_hash)\n\
                      )";

        let db = Database::new(":memory:").unwrap();
        db.conn().execute_batch(schema).unwrap();
        db
    }

    /// A saved embedding can be read back under the same name and hash.
    #[test]
    fn save_and_load_tool_embedding() {
        let db = test_db();
        let original = vec![0.1f32, 0.2, 0.3];

        save_tool_embedding(&db, "web_search", "abc123", &original).unwrap();

        let fetched = get_tool_embedding(&db, "web_search", "abc123").unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.len(), 3);
        assert!(close(fetched[0], 0.1));
    }

    /// Querying with a hash other than the one saved yields no embedding.
    #[test]
    fn stale_embedding_not_returned() {
        let db = test_db();
        let original = vec![0.1f32, 0.2, 0.3];

        save_tool_embedding(&db, "web_search", "old_hash", &original).unwrap();

        let fetched = get_tool_embedding(&db, "web_search", "new_hash").unwrap();
        assert!(fetched.is_none());
    }

    /// Saving twice under the same name and hash keeps only the latest vector.
    #[test]
    fn insert_or_replace_updates_embedding() {
        let db = test_db();
        let first = vec![1.0f32, 0.0];
        let second = vec![0.0f32, 1.0];

        save_tool_embedding(&db, "bash", "hash1", &first).unwrap();
        save_tool_embedding(&db, "bash", "hash1", &second).unwrap();

        let fetched = get_tool_embedding(&db, "bash", "hash1").unwrap().unwrap();
        assert!(close(fetched[0], 0.0));
        assert!(close(fetched[1], 1.0));
    }

    /// Two tools sharing a description hash keep separate embeddings.
    #[test]
    fn different_tools_same_hash_are_independent() {
        let db = test_db();
        let for_a = vec![1.0f32, 0.0];
        let for_b = vec![0.0f32, 1.0];

        save_tool_embedding(&db, "tool_a", "shared_hash", &for_a).unwrap();
        save_tool_embedding(&db, "tool_b", "shared_hash", &for_b).unwrap();

        let fetched_a = get_tool_embedding(&db, "tool_a", "shared_hash")
            .unwrap()
            .unwrap();
        let fetched_b = get_tool_embedding(&db, "tool_b", "shared_hash")
            .unwrap()
            .unwrap();
        assert!(close(fetched_a[0], 1.0));
        assert!(close(fetched_b[1], 1.0));
    }

    /// Looking up a tool that was never saved yields no embedding.
    #[test]
    fn missing_tool_returns_none() {
        let db = test_db();

        let fetched = get_tool_embedding(&db, "nonexistent_tool", "any_hash").unwrap();
        assert!(fetched.is_none());
    }
}