use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use instant_distance::{Builder, HnswMap, Search};
use tracing::warn;
use crate::Database;
use crate::embeddings::{blob_to_embedding, cosine_similarity};
/// Acquires a shared read guard on `lock`, tolerating poisoning.
///
/// If the lock is poisoned, a warning is logged and the guard carried by the
/// poison error is returned, so readers keep working with whatever value was
/// last stored (which may be stale).
fn read_or_recover<T>(lock: &RwLock<T>) -> RwLockReadGuard<'_, T> {
    match lock.read() {
        Ok(guard) => guard,
        Err(poisoned) => {
            warn!("ANN index RwLock poisoned on read, recovering with potentially stale data");
            poisoned.into_inner()
        }
    }
}
/// Acquires an exclusive write guard on `lock`, refusing to write through poison.
///
/// # Errors
///
/// Returns `RoboticusError::Database` when the lock is poisoned.
///
/// NOTE(review): nothing visible in this file clears the poison flag, so once
/// this fails every later call on the same lock presumably fails as well —
/// confirm that "rebuild required" means constructing a fresh index.
fn write_or_poison_err<T>(lock: &RwLock<T>) -> roboticus_core::Result<RwLockWriteGuard<'_, T>> {
    match lock.write() {
        Ok(guard) => Ok(guard),
        Err(_) => Err(roboticus_core::RoboticusError::Database(
            "ANN index RwLock poisoned, cannot write; rebuild required".into(),
        )),
    }
}
/// Newtype around an embedding vector so it can be used as the point type of
/// the HNSW map (it implements `instant_distance::Point`).
#[derive(Clone)]
struct EmbeddingPoint(Vec<f32>);
impl instant_distance::Point for EmbeddingPoint {
    /// Distance between two points, defined as one minus their cosine similarity.
    fn distance(&self, other: &Self) -> f32 {
        let similarity = cosine_sim_f32(&self.0, &other.0);
        1.0 - similarity
    }
}
/// Cosine similarity of `a` and `b`, narrowed to `f32`.
///
/// Thin adapter over `crate::embeddings::cosine_similarity` so the result
/// matches the `f32` distance type required by the HNSW point trait.
fn cosine_sim_f32(a: &[f32], b: &[f32]) -> f32 {
    let similarity = cosine_similarity(a, b);
    similarity as f32
}
/// Metadata kept for each indexed embedding.
///
/// The fields are filled from the identically named columns of the
/// `embeddings` table when the index is built.
#[derive(Clone)]
struct IndexEntry {
/// Value of the `source_table` column for this row.
source_table: String,
/// Value of the `source_id` column for this row.
source_id: String,
/// Value of the `content_preview` column for this row.
content_preview: String,
}
/// Handle to an approximate-nearest-neighbour index over stored embeddings.
///
/// The built state lives behind an `Arc<RwLock<..>>`, so clones of a handle
/// refer to the same underlying index.
pub struct AnnIndex {
/// Shared index state; `None` until a build produces an index.
inner: Arc<RwLock<Option<IndexState>>>,
/// When `false`, `build_from_db` returns immediately without reading the database.
enabled: bool,
/// Minimum number of usable embeddings required before an HNSW index is
/// built; below this the stored state is reset to `None`.
pub min_entries_for_index: usize,
}
/// A fully built index: the HNSW graph plus the metadata for each point.
struct IndexState {
/// HNSW map whose values are positions into `entries`.
hnsw: HnswMap<EmbeddingPoint, usize>,
/// Row metadata, addressed by the values stored in `hnsw`.
entries: Vec<IndexEntry>,
}
/// One hit returned by an ANN search.
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and
/// compare results (for example in tests) without hand-written glue.
#[derive(Debug, Clone, PartialEq)]
pub struct AnnSearchResult {
    /// Table the matched embedding was generated from.
    pub source_table: String,
    /// Identifier of the matched row within `source_table`.
    pub source_id: String,
    /// Stored preview text of the matched content.
    pub content_preview: String,
    /// One minus the index distance between the query and this hit, so a
    /// larger value means a closer match.
    pub similarity: f64,
}
/// Initial value of `AnnIndex::min_entries_for_index` set by `AnnIndex::new`.
const DEFAULT_MIN_ENTRIES: usize = 100;
impl AnnIndex {
pub fn new(enabled: bool) -> Self {
Self {
inner: Arc::new(RwLock::new(None)),
enabled,
min_entries_for_index: DEFAULT_MIN_ENTRIES,
}
}
pub fn build_from_db(&self, db: &Database) -> roboticus_core::Result<usize> {
if !self.enabled {
return Ok(0);
}
let conn = db.conn();
let mut stmt = conn
.prepare(
"SELECT source_table, source_id, content_preview, embedding_blob \
FROM embeddings",
)
.map_err(|e| roboticus_core::RoboticusError::Database(e.to_string()))?;
let mut points = Vec::new();
let mut values = Vec::new();
let mut entries = Vec::new();
let rows = stmt
.query_map([], |row| {
Ok((
row.get::<_, String>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
row.get::<_, Option<Vec<u8>>>(3)?,
))
})
.map_err(|e| roboticus_core::RoboticusError::Database(e.to_string()))?;
for row in rows {
let (source_table, source_id, content_preview, blob) =
row.map_err(|e| roboticus_core::RoboticusError::Database(e.to_string()))?;
let embedding = if let Some(b) = blob
&& !b.is_empty()
{
blob_to_embedding(&b)
} else {
continue;
};
if embedding.is_empty() {
continue;
}
let idx = entries.len();
points.push(EmbeddingPoint(embedding));
values.push(idx);
entries.push(IndexEntry {
source_table,
source_id,
content_preview,
});
}
let count = points.len();
if count < self.min_entries_for_index {
*write_or_poison_err(&self.inner)? = None;
return Ok(count);
}
let hnsw = Builder::default().build(points, values);
*write_or_poison_err(&self.inner)? = Some(IndexState { hnsw, entries });
Ok(count)
}
pub fn search(&self, query_embedding: &[f32], k: usize) -> Option<Vec<AnnSearchResult>> {
let guard = read_or_recover(&self.inner);
let state = guard.as_ref()?;
let query = EmbeddingPoint(query_embedding.to_vec());
let mut search = Search::default();
let results: Vec<AnnSearchResult> = state
.hnsw
.search(&query, &mut search)
.take(k)
.map(|item| {
let idx = *item.value;
let entry = &state.entries[idx];
let similarity = 1.0 - item.distance as f64;
AnnSearchResult {
source_table: entry.source_table.clone(),
source_id: entry.source_id.clone(),
content_preview: entry.content_preview.clone(),
similarity,
}
})
.collect();
Some(results)
}
pub fn is_built(&self) -> bool {
read_or_recover(&self.inner).is_some()
}
pub fn entry_count(&self) -> usize {
read_or_recover(&self.inner)
.as_ref()
.map(|s| s.entries.len())
.unwrap_or(0)
}
pub fn rebuild(&self, db: &Database) -> roboticus_core::Result<usize> {
self.build_from_db(db)
}
}
impl Clone for AnnIndex {
    /// Produces another handle to the same shared index state.
    ///
    /// Only the `Arc` is cloned, so an index built through one handle is
    /// visible through every clone; the scalar settings are copied by value.
    fn clone(&self) -> Self {
        let Self {
            inner,
            enabled,
            min_entries_for_index,
        } = self;
        Self {
            inner: Arc::clone(inner),
            enabled: *enabled,
            min_entries_for_index: *min_entries_for_index,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::store_embedding;

    /// Opens a fresh in-memory database for a single test.
    fn test_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    /// Creates an enabled index whose build threshold is lowered to `min_entries`.
    fn index_with_threshold(min_entries: usize) -> AnnIndex {
        let mut index = AnnIndex::new(true);
        index.min_entries_for_index = min_entries;
        index
    }

    /// Stores the three 3-dimensional unit vectors as rows `valid1`..`valid3`.
    fn store_unit_vectors(db: &Database) {
        let axes: [[f32; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        for (n, axis) in (1..).zip(axes.iter()) {
            store_embedding(
                db,
                &format!("valid{n}"),
                "episodic_memory",
                &format!("v{n}"),
                &format!("ok{n}"),
                &axis[..],
            )
            .unwrap();
        }
    }

    // A disabled index never builds and never answers searches.
    #[test]
    fn disabled_index_returns_none() {
        let db = test_db();
        let index = AnnIndex::new(false);
        let built = index.build_from_db(&db).unwrap();
        assert_eq!(built, 0);
        assert!(!index.is_built());
        assert!(index.search(&[1.0, 0.0], 5).is_none());
    }

    // With no rows at all there is nothing to index.
    #[test]
    fn empty_db_no_index() {
        let db = test_db();
        let index = AnnIndex::new(true);
        let built = index.build_from_db(&db).unwrap();
        assert_eq!(built, 0);
        assert!(!index.is_built());
    }

    // Ten rows is below the default threshold, so the count is reported but
    // no index is built.
    #[test]
    fn below_min_entries_no_index() {
        let db = test_db();
        for i in 0..10 {
            let id = format!("e{i}");
            let source_id = format!("{i}");
            store_embedding(
                &db,
                &id,
                "episodic_memory",
                &source_id,
                "preview",
                &[1.0, 0.0],
            )
            .unwrap();
        }
        let index = AnnIndex::new(true);
        let built = index.build_from_db(&db).unwrap();
        assert_eq!(built, 10);
        assert!(!index.is_built());
    }

    // Once the row count reaches the threshold the index is built and reports
    // every entry.
    #[test]
    fn builds_index_above_threshold() {
        let db = test_db();
        let index = index_with_threshold(5);
        for i in 0..10 {
            let emb = vec![i as f32 / 10.0, 1.0 - i as f32 / 10.0];
            let id = format!("e{i}");
            let source_id = format!("t{i}");
            let preview = format!("entry {i}");
            store_embedding(&db, &id, "episodic_memory", &source_id, &preview, &emb).unwrap();
        }
        let built = index.build_from_db(&db).unwrap();
        assert_eq!(built, 10);
        assert!(index.is_built());
        assert_eq!(index.entry_count(), 10);
    }

    // The exact match must come first and outrank the runner-up.
    #[test]
    fn search_returns_nearest() {
        let db = test_db();
        let index = index_with_threshold(3);
        let fixtures: [(&str, &str, &str, [f32; 3]); 3] = [
            ("e1", "t1", "near", [1.0, 0.0, 0.0]),
            ("e2", "t2", "far", [0.0, 1.0, 0.0]),
            ("e3", "t3", "medium", [0.7, 0.3, 0.0]),
        ];
        for (id, source_id, preview, vector) in fixtures.iter() {
            store_embedding(&db, id, "episodic_memory", source_id, preview, &vector[..]).unwrap();
        }
        index.build_from_db(&db).unwrap();
        assert!(index.is_built());
        let hits = index.search(&[1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].content_preview, "near");
        assert!(hits[0].similarity > hits[1].similarity);
    }

    // A clone reports the same build status as the handle it came from.
    #[test]
    fn clone_shares_state() {
        let original = AnnIndex::new(true);
        let duplicate = original.clone();
        assert_eq!(original.is_built(), duplicate.is_built());
    }

    // Embeddings written through `store_embedding` are decoded and searchable.
    #[test]
    fn build_from_blob_embeddings() {
        let db = test_db();
        let index = index_with_threshold(3);
        for i in 0..5 {
            let emb = vec![i as f32 / 5.0, 1.0 - i as f32 / 5.0, 0.5];
            let id = format!("b{i}");
            let source_id = format!("s{i}");
            let preview = format!("entry {i}");
            store_embedding(&db, &id, "episodic_memory", &source_id, &preview, &emb).unwrap();
        }
        let built = index.build_from_db(&db).unwrap();
        assert_eq!(built, 5);
        assert!(index.is_built());
        let hits = index.search(&[1.0, 0.0, 0.5], 2).unwrap();
        assert!(!hits.is_empty());
    }

    // Rows inserted without an embedding blob are ignored by the build.
    #[test]
    fn build_skips_empty_embeddings() {
        let db = test_db();
        let index = index_with_threshold(3);
        {
            // Keep the connection in its own scope so it is released before
            // `store_embedding` needs the database again.
            let conn = db.conn();
            conn.execute(
                "INSERT INTO embeddings (id, source_table, source_id, content_preview, dimensions) \
                 VALUES ('empty1', 'episodic_memory', 's1', 'empty', 0)",
                [],
            )
            .unwrap();
            conn.execute(
                "INSERT INTO embeddings (id, source_table, source_id, content_preview, dimensions) \
                 VALUES ('empty2', 'episodic_memory', 's2', 'empty2', 0)",
                [],
            )
            .unwrap();
        }
        store_unit_vectors(&db);
        let built = index.build_from_db(&db).unwrap();
        assert_eq!(built, 3);
        assert!(index.is_built());
    }

    // Rows whose blob is present but zero-length are ignored by the build.
    #[test]
    fn build_skips_empty_blobs() {
        let db = test_db();
        let index = index_with_threshold(3);
        {
            // Keep the connection in its own scope so it is released before
            // `store_embedding` needs the database again.
            let conn = db.conn();
            for i in 0..2 {
                conn.execute(
                    "INSERT INTO embeddings (id, source_table, source_id, content_preview, embedding_blob, dimensions) \
                     VALUES (?1, 'episodic_memory', ?2, ?3, X'', 3)",
                    rusqlite::params![format!("empty{i}"), format!("e{i}"), format!("empty {i}")],
                )
                .unwrap();
            }
        }
        store_unit_vectors(&db);
        let built = index.build_from_db(&db).unwrap();
        assert_eq!(built, 3);
        assert!(index.is_built());
    }
}