use crate::error::Result;
use rusqlite::Connection;
/// Reinterprets a slice of `f32` as its raw bytes, in native byte order, without copying.
///
/// The result is used as the BLOB value bound for vec0's `float[N]` column.
/// NOTE(review): sqlite-vec's float32 blob layout is little-endian; on a big-endian target
/// the native-order bytes produced here would need conversion — confirm if such targets matter.
fn f32_slice_to_bytes(vec: &[f32]) -> &[u8] {
    // SAFETY: `f32` has no padding bytes, so every byte of the slice is initialised and a
    // valid `u8`; `u8` has alignment 1, so the cast pointer is always suitably aligned;
    // `size_of_val(vec)` is exactly the byte length of the borrowed slice; and the returned
    // slice carries the same lifetime as `vec`, so it cannot outlive the data.
    unsafe {
        std::slice::from_raw_parts(vec.as_ptr().cast::<u8>(), std::mem::size_of_val(vec))
    }
}
/// Registers the sqlite-vec extension through SQLite's auto-extension mechanism.
///
/// This passes `sqlite_vec::sqlite3_vec_init` to `rusqlite::ffi::sqlite3_auto_extension`.
/// Per the SQLite C API docs, every database connection created after this call runs that
/// entry point. This is how `Connection::open_in_memory` in this module's tests gains the
/// `vec0` module. Call it before opening connections that need `vec_episodes`.
/// The SQLite docs also state that registering an already-registered entry point is a
/// harmless no-op, which is why repeated calls (one per test) are fine.
///
/// NOTE(review): the status code returned by `sqlite3_auto_extension` is discarded, so a
/// registration failure is silent here and only surfaces later as a missing `vec0` module.
// dead_code allowed: presumably there is no non-test caller yet — confirm before removing.
#[allow(dead_code)]
pub fn init_vec_extension() {
// SAFETY (review): the `transmute` converts the type-erased `*const ()` into whatever
// fn-pointer type `sqlite3_auto_extension`'s `Option<...>` parameter expects (the target
// type is inferred from the call below). This is only sound if `sqlite3_vec_init` really
// has that SQLite extension entry-point signature — assumed from sqlite-vec's documented
// usage; re-check when upgrading rusqlite/libsqlite3-sys or sqlite-vec.
unsafe {
// Target type intentionally left to inference (see SAFETY note above).
#[allow(clippy::missing_transmute_annotations)]
let func = std::mem::transmute(sqlite_vec::sqlite3_vec_init as *const ());
rusqlite::ffi::sqlite3_auto_extension(Some(func));
}
}
/// Creates the `vec_episodes` sqlite-vec (`vec0`) virtual table unless it already exists.
///
/// Each row is keyed by an integer `episode_id` and stores an `embedding` of
/// `float[dimensions]`, compared with the cosine distance metric.
///
/// # Errors
/// Returns an error if SQLite rejects the `CREATE VIRTUAL TABLE` statement. This
/// presumably includes the case where the `vec0` module has not been registered for `conn`.
// dead_code allowed: presumably there is no non-test caller yet — confirm before removing.
#[allow(dead_code)]
pub fn create_vec_table(conn: &Connection, dimensions: usize) -> Result<()> {
    // `IF NOT EXISTS` makes repeated setup calls harmless.
    let ddl = format!(
        "CREATE VIRTUAL TABLE IF NOT EXISTS vec_episodes USING vec0(\
         episode_id INTEGER PRIMARY KEY, \
         embedding float[{dimensions}] distance_metric=cosine\
         )"
    );
    conn.execute_batch(&ddl)?;
    Ok(())
}
/// Inserts the embedding for `node_id` into `vec_episodes`, replacing any existing one.
///
/// sqlite-vec's `vec0` virtual tables do not implement SQLite's `OR REPLACE` conflict
/// resolution. Inserting a primary key that already exists fails with a UNIQUE constraint
/// error instead of replacing the row. The upsert is therefore done as a DELETE followed
/// by a plain INSERT.
///
/// Both statements run inside a savepoint. If the insert fails (for example because the
/// vector has the wrong length), the delete is rolled back and the previous embedding
/// survives. A savepoint is used rather than `BEGIN` so that this also works when the
/// caller already has a transaction open.
///
/// # Errors
/// Returns the first SQLite error encountered; the savepoint is rolled back before returning.
pub fn upsert_vec(conn: &Connection, node_id: i64, embedding: &[f32]) -> Result<()> {
    /// Deletes any existing row for `node_id`, then inserts the new embedding.
    fn replace_row(conn: &Connection, node_id: i64, embedding: &[f32]) -> rusqlite::Result<()> {
        conn.execute(
            "DELETE FROM vec_episodes WHERE episode_id = ?",
            rusqlite::params![node_id],
        )?;
        conn.execute(
            "INSERT INTO vec_episodes(episode_id, embedding) VALUES (?, ?)",
            rusqlite::params![node_id, f32_slice_to_bytes(embedding)],
        )?;
        Ok(())
    }

    conn.execute_batch("SAVEPOINT upsert_vec")?;
    let outcome = replace_row(conn, node_id, embedding)
        .and_then(|()| conn.execute_batch("RELEASE upsert_vec"));
    if let Err(err) = outcome {
        // Best-effort cleanup: the original error is the one worth reporting, so a failure
        // while rolling back is deliberately ignored.
        let _ = conn.execute_batch("ROLLBACK TO upsert_vec; RELEASE upsert_vec");
        return Err(err.into());
    }
    Ok(())
}
/// Returns up to `limit` episodes whose embeddings are nearest to `query`.
///
/// Each result is an `(episode_id, similarity)` pair, ordered nearest-first (ascending
/// distance). Similarity is `1 - distance`. With the cosine metric configured by
/// `create_vec_table`, an identical direction scores ~1.0.
///
/// `limit` is bound as the KNN `LIMIT`. Values that do not fit in `i64` saturate to
/// `i64::MAX` instead of wrapping to a negative number.
///
/// # Errors
/// Returns an error if the statement cannot be prepared or executed, or if a result row
/// cannot be read.
pub fn knn_search(conn: &Connection, query: &[f32], limit: usize) -> Result<Vec<(i64, f32)>> {
    // A plain `as` cast would wrap huge `usize` values to a negative LIMIT.
    let k = i64::try_from(limit).unwrap_or(i64::MAX);
    let mut stmt = conn.prepare(
        "SELECT episode_id, distance FROM vec_episodes \
         WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
    )?;
    let rows = stmt.query_map(rusqlite::params![f32_slice_to_bytes(query), k], |row| {
        let id: i64 = row.get(0)?;
        let distance: f64 = row.get(1)?;
        Ok((id, 1.0f32 - distance as f32))
    })?;
    // Stops at the first row error and propagates it.
    let results = rows.collect::<rusqlite::Result<Vec<_>>>()?;
    Ok(results)
}
/// Removes the embedding stored for `node_id` from `vec_episodes`.
///
/// The number of affected rows is not checked, so deleting an id that has no row
/// still returns `Ok(())`.
///
/// # Errors
/// Returns an error if SQLite fails to execute the DELETE.
// dead_code allowed: presumably there is no non-test caller yet — confirm before removing.
#[allow(dead_code)]
pub fn delete_vec(conn: &Connection, node_id: i64) -> Result<()> {
    const DELETE_SQL: &str = "DELETE FROM vec_episodes WHERE episode_id = ?";
    let _affected = conn.execute(DELETE_SQL, rusqlite::params![node_id])?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers sqlite-vec, opens a fresh in-memory database, and creates a
    /// `vec_episodes` table whose embeddings have `dimensions` components.
    fn setup_vec_db(dimensions: usize) -> Connection {
        init_vec_extension();
        let db = Connection::open_in_memory().unwrap();
        create_vec_table(&db, dimensions).unwrap();
        db
    }

    #[test]
    fn test_vec_table_creation() {
        // Setup itself panics if the virtual table cannot be created.
        let _db = setup_vec_db(3);
    }

    #[test]
    fn test_vec_upsert_and_search() {
        let db = setup_vec_db(3);
        let fixtures: [(i64, [f32; 3]); 3] = [
            (1, [1.0, 0.0, 0.0]),
            (2, [0.7, 0.7, 0.0]),
            (3, [0.0, 0.0, 1.0]),
        ];
        for (id, embedding) in &fixtures {
            upsert_vec(&db, *id, embedding).unwrap();
        }

        let hits = knn_search(&db, &[1.0f32, 0.0, 0.0], 10).unwrap();
        assert_eq!(hits.len(), 3, "should return all 3 vectors");
        assert_eq!(hits[0].0, 1, "nearest should be episode 1");
        assert_eq!(hits[1].0, 2, "second nearest should be episode 2");
        // Similarity must never increase as we walk down the ranked list.
        assert!(hits.windows(2).all(|pair| pair[0].1 >= pair[1].1));
    }

    #[test]
    fn test_vec_search_respects_limit() {
        let db = setup_vec_db(3);
        (1..=5).for_each(|id: i64| {
            let tilt = (id as f32) * 0.01;
            upsert_vec(&db, id, &[1.0, 0.0, tilt]).unwrap();
        });

        let hits = knn_search(&db, &[1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(hits.len(), 2, "should respect the limit of 2");
    }

    #[test]
    fn test_vec_delete() {
        let db = setup_vec_db(3);
        upsert_vec(&db, 1, &[1.0, 0.0, 0.0]).unwrap();
        upsert_vec(&db, 2, &[0.0, 1.0, 0.0]).unwrap();
        delete_vec(&db, 1).unwrap();

        let remaining = knn_search(&db, &[1.0, 0.0, 0.0], 10).unwrap();
        assert_eq!(remaining.len(), 1, "deleted vector should not appear");
        assert_eq!(remaining[0].0, 2, "only episode 2 should remain");
    }

    #[test]
    fn test_vec_similarity_score() {
        let db = setup_vec_db(3);
        let unit_x = [1.0f32, 0.0, 0.0];
        upsert_vec(&db, 1, &unit_x).unwrap();

        let hits = knn_search(&db, &unit_x, 1).unwrap();
        assert_eq!(hits.len(), 1);
        let similarity = hits[0].1;
        assert!(
            (similarity - 1.0f32).abs() < 1e-5,
            "identical vector should have similarity ~1.0, got {}",
            similarity
        );
    }
}