use rusqlite::params;
use crate::embeddings::pooling::{bytes_to_vec, cosine_similarity};
use crate::error::Result;
use crate::search::fts5::FtsResult;
use crate::storage::Database;
pub struct VectorSearch;
impl VectorSearch {
pub fn search(
db: &Database,
query_vector: &[f32],
model_name: &str,
limit: usize,
) -> Result<Vec<FtsResult>> {
db.with_reader(|conn| {
let mut stmt = conn.prepare(
"SELECT memory_id, embedding FROM memory_vectors WHERE model_name = ?1",
)?;
let mut scored: Vec<(i64, f32)> = stmt
.query_map(params![model_name], |row| {
let id: i64 = row.get(0)?;
let blob: Vec<u8> = row.get(1)?;
Ok((id, blob))
})?
.filter_map(|r| r.ok())
.map(|(id, blob)| {
let stored = bytes_to_vec(&blob);
let sim = cosine_similarity(query_vector, &stored);
(id, sim)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(limit);
Ok(scored
.into_iter()
.map(|(id, score)| FtsResult {
memory_id: id,
score,
})
.collect())
})
}
pub fn store_vector(
db: &Database,
memory_id: i64,
embedding: &[f32],
model_name: &str,
content_hash: &str,
) -> Result<()> {
let blob = crate::embeddings::pooling::vec_to_bytes(embedding);
let dims = embedding.len() as i32;
db.with_writer(|conn| {
conn.execute(
"INSERT OR REPLACE INTO memory_vectors (memory_id, embedding, model_name, dimensions, content_hash)
VALUES (?1, ?2, ?3, ?4, ?5)",
params![memory_id, blob, model_name, dims, content_hash],
)?;
Ok(())
})
}
pub fn vector_exists(db: &Database, content_hash: &str) -> Result<bool> {
db.with_reader(|conn| {
let count: i32 = conn.query_row(
"SELECT COUNT(*) FROM memory_vectors WHERE content_hash = ?1",
[content_hash],
|row| row.get(0),
)?;
Ok(count > 0)
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::embeddings::pooling::normalize_l2;
use crate::storage::migrations;
fn setup() -> Database {
let db = Database::open_in_memory().expect("open");
db.with_writer(|conn| {
migrations::migrate(conn)?;
Ok(())
})
.expect("migrate");
db.with_writer(|conn| {
conn.execute(
"INSERT INTO memories (id, searchable_text, memory_type, content_hash, record_json)
VALUES (1, 'auth error', 'semantic', 'h1', '{}')",
[],
)?;
conn.execute(
"INSERT INTO memories (id, searchable_text, memory_type, content_hash, record_json)
VALUES (2, 'db timeout', 'episodic', 'h2', '{}')",
[],
)?;
conn.execute(
"INSERT INTO memories (id, searchable_text, memory_type, content_hash, record_json)
VALUES (3, 'build fix', 'procedural', 'h3', '{}')",
[],
)?;
Ok(())
})
.expect("insert");
db
}
#[test]
fn store_and_search_vectors() {
let db = setup();
let v1 = normalize_l2(&[1.0, 0.0, 0.0]);
let v2 = normalize_l2(&[0.0, 1.0, 0.0]);
let v3 = normalize_l2(&[0.9, 0.1, 0.0]);
VectorSearch::store_vector(&db, 1, &v1, "test-model", "h1").expect("store 1");
VectorSearch::store_vector(&db, 2, &v2, "test-model", "h2").expect("store 2");
VectorSearch::store_vector(&db, 3, &v3, "test-model", "h3").expect("store 3");
let query = normalize_l2(&[1.0, 0.0, 0.0]);
let results = VectorSearch::search(&db, &query, "test-model", 10).expect("search");
assert_eq!(results.len(), 3);
assert_eq!(results[0].memory_id, 1);
assert!((results[0].score - 1.0).abs() < 0.01);
assert_eq!(results[1].memory_id, 3);
}
#[test]
fn model_name_filter() {
let db = setup();
let v = normalize_l2(&[1.0, 0.0, 0.0]);
VectorSearch::store_vector(&db, 1, &v, "model-a", "h1").expect("store");
VectorSearch::store_vector(&db, 2, &v, "model-b", "h2").expect("store");
let results = VectorSearch::search(&db, &v, "model-a", 10).expect("search");
assert_eq!(results.len(), 1);
assert_eq!(results[0].memory_id, 1);
}
#[test]
fn vector_exists_check() {
let db = setup();
let v = normalize_l2(&[1.0, 0.0, 0.0]);
assert!(!VectorSearch::vector_exists(&db, "h1").expect("exists"));
VectorSearch::store_vector(&db, 1, &v, "test", "h1").expect("store");
assert!(VectorSearch::vector_exists(&db, "h1").expect("exists"));
}
#[test]
fn limit_respected() {
let db = setup();
let v = normalize_l2(&[1.0, 0.0, 0.0]);
for i in 1..=3 {
VectorSearch::store_vector(&db, i, &v, "test", &format!("h{i}")).expect("store");
}
let results = VectorSearch::search(&db, &v, "test", 2).expect("search");
assert_eq!(results.len(), 2);
}
}