use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use khive_runtime::{KhiveRuntime, Namespace, NamespaceToken, RuntimeError};
use khive_storage::types::{SqlStatement, SqlValue};
use khive_vamana::{CorpusFingerprint, VamanaConfig, VamanaIndex, VamanaSnapshot};
use tokio::sync::RwLock;
use uuid::Uuid;
/// A Vamana ANN index paired with the UUIDs of the vectors it holds.
///
/// Search results come back from the index as internal positions and are
/// translated to UUIDs through `id_map`.
pub(crate) struct AnnBridge {
/// The underlying Vamana graph index (over L2-normalized vectors, see `AnnBridge::build`).
index: VamanaIndex,
/// `id_map[i]` is the UUID of the vector at internal position `i` in `index`.
id_map: Vec<Uuid>,
}
/// Shared, lazily populated ANN state.
pub(crate) struct AnnState {
/// The installed index, or `None` until a snapshot restore or rebuild installs one.
pub(crate) index: RwLock<Option<AnnBridge>>,
/// Guard flag used by `ensure_ann_background`: while it is set, no further
/// background warm-up task is spawned.
warming: AtomicBool,
}
/// Reference-counted handle to [`AnnState`], cloned into background warm-up tasks.
pub(crate) type SharedAnn = Arc<AnnState>;
/// Creates an empty shared ANN state: no index installed and the warm-up flag cleared.
pub(crate) fn new_shared() -> SharedAnn {
    let state = AnnState {
        warming: AtomicBool::new(false),
        index: RwLock::new(None),
    };
    Arc::new(state)
}
impl AnnBridge {
    /// Builds a Vamana index over a flat buffer of `dim`-dimensional vectors.
    ///
    /// Row `i` of `vectors` (`vectors[i * dim..(i + 1) * dim]`) is associated with
    /// `id_map[i]`. Rows are L2-normalized in place before indexing so that
    /// [`AnnBridge::search`] can map index distances back to cosine similarity.
    ///
    /// # Errors
    ///
    /// Returns a description of the problem if `dim` is zero, either input is empty,
    /// `vectors.len()` is not a multiple of `dim`, the row count differs from
    /// `id_map.len()`, or the underlying Vamana build fails.
    pub fn build(mut vectors: Vec<f32>, dim: usize, id_map: Vec<Uuid>) -> Result<Self, String> {
        if dim == 0 {
            return Err("dimension must be > 0".into());
        }
        if vectors.is_empty() || id_map.is_empty() {
            return Err("no vectors to build ANN index from".into());
        }
        // Without this check a trailing partial row is ignored by the row count below and
        // by `chunks_exact_mut` (so it is never normalized), yet it would still be handed
        // to `VamanaIndex::build` as part of the buffer.
        if vectors.len() % dim != 0 {
            return Err(format!(
                "vector buffer length {} is not a multiple of dimension {}",
                vectors.len(),
                dim
            ));
        }
        let n = vectors.len() / dim;
        if n != id_map.len() {
            return Err(format!(
                "id_map length {} != vector count {}",
                id_map.len(),
                n
            ));
        }
        for row in vectors.chunks_exact_mut(dim) {
            l2_normalize(row);
        }
        let cfg = VamanaConfig::with_dimensions(dim);
        let index = VamanaIndex::build(&vectors, cfg).map_err(|e| e.to_string())?;
        Ok(Self { index, id_map })
    }

    /// Returns up to `k` nearest neighbours of `query` as `(id, score)` pairs.
    ///
    /// The query is L2-normalized like the indexed vectors, and each distance `d` is
    /// mapped to `1 - d / 2`, clamped at 0.
    /// NOTE(review): that equals cosine similarity only if the index reports squared
    /// Euclidean distance between unit vectors; confirm against `khive_vamana`.
    ///
    /// Hits whose internal position has no entry in `id_map` are dropped. A search error
    /// is logged and yields an empty result instead of being propagated.
    pub fn search(&self, query: &[f32], k: usize) -> Vec<(Uuid, f32)> {
        let mut q = query.to_vec();
        l2_normalize(&mut q);
        match self.index.search(&q, k) {
            Ok(results) => results
                .into_iter()
                .filter_map(|(idx, dist)| {
                    self.id_map.get(idx as usize).map(|uuid| {
                        let cosine = 1.0 - dist / 2.0;
                        (*uuid, cosine.max(0.0))
                    })
                })
                .collect(),
            Err(e) => {
                tracing::warn!(error = %e, "vamana ANN search failed");
                Vec::new()
            }
        }
    }

    /// Number of vectors held by the underlying index.
    pub fn num_vectors(&self) -> usize {
        self.index.num_vectors()
    }

    /// Exports the index as a [`VamanaSnapshot`].
    ///
    /// Each UUID in `id_map` is stored as its string form in the snapshot's external ids.
    pub fn to_vamana_snapshot(
        &self,
        namespace: &str,
        model: &str,
        fingerprint: CorpusFingerprint,
    ) -> Result<VamanaSnapshot, khive_vamana::VamanaError> {
        let external_ids: Vec<String> = self.id_map.iter().map(|id| id.to_string()).collect();
        self.index
            .to_snapshot(namespace, model, fingerprint, external_ids)
    }

    /// Restores a bridge from a snapshot produced by [`AnnBridge::to_vamana_snapshot`].
    ///
    /// # Errors
    ///
    /// Fails if any external id is not a valid UUID, if the index cannot be restored, or
    /// if the number of ids differs from the number of restored vectors. Such a mismatch
    /// would make `search` silently drop hits whose position has no id.
    pub fn from_vamana_snapshot(snapshot: VamanaSnapshot) -> Result<Self, String> {
        let id_map: Vec<Uuid> = snapshot
            .external_ids
            .iter()
            .map(|s| Uuid::parse_str(s).map_err(|e| format!("bad UUID {s}: {e}")))
            .collect::<Result<_, _>>()?;
        let index =
            VamanaIndex::from_snapshot(&snapshot).map_err(|e| format!("snapshot restore: {e}"))?;
        let restored = index.num_vectors();
        if restored != id_map.len() {
            return Err(format!(
                "snapshot id count {} != restored vector count {}",
                id_map.len(),
                restored
            ));
        }
        Ok(Self { index, id_map })
    }
}
/// Scales `v` in place to unit Euclidean length.
///
/// Vectors whose norm is not above `1e-8` (including all-zero vectors) are left
/// untouched, which avoids dividing by (nearly) zero.
fn l2_normalize(v: &mut [f32]) {
    let sum_of_squares: f32 = v.iter().map(|x| x * x).sum();
    let magnitude = sum_of_squares.sqrt();
    if magnitude > 1e-8 {
        v.iter_mut().for_each(|component| *component /= magnitude);
    }
}
/// Builds the `retrieval_snapshots.namespace` key for a Vamana snapshot:
/// `<namespace>::vamana::<model>`.
pub(crate) fn snapshot_key(namespace: &str, model: &str) -> String {
    [namespace, "vamana", model].join("::")
}
/// Maps a model name to an identifier-safe form.
///
/// Every character that is not an ASCII letter or digit is replaced with `_`,
/// one underscore per `char`. The result is therefore safe to splice into a SQL
/// table name.
pub(crate) fn sanitize_model_key(s: &str) -> String {
    let mut sanitized = String::with_capacity(s.len());
    for ch in s.chars() {
        let safe = if ch.is_ascii_alphanumeric() { ch } else { '_' };
        sanitized.push(safe);
    }
    sanitized
}
/// Creates the `retrieval_snapshots` table and its namespace index if they do not exist.
///
/// Safe to call repeatedly because both statements use `IF NOT EXISTS`. Failure to
/// obtain a writer or to run the script is reported as `RuntimeError::Internal`.
async fn ensure_snapshot_schema(rt: &KhiveRuntime) -> Result<(), RuntimeError> {
    fn internal<E: ToString>(err: E) -> RuntimeError {
        RuntimeError::Internal(err.to_string())
    }

    let script = r#"
CREATE TABLE IF NOT EXISTS retrieval_snapshots (
namespace TEXT NOT NULL,
index_type TEXT NOT NULL,
snapshot BLOB NOT NULL,
created_at INTEGER NOT NULL,
PRIMARY KEY (namespace, index_type)
);
CREATE INDEX IF NOT EXISTS idx_retrieval_snapshots_namespace
ON retrieval_snapshots(namespace);
"#;
    let sql = rt.sql();
    let mut writer = sql.writer().await.map_err(internal)?;
    writer.execute_script(script.into()).await.map_err(internal)
}
pub(crate) async fn persist_snapshot(
rt: &KhiveRuntime,
namespace: &str,
model: &str,
bridge: &AnnBridge,
fingerprint: CorpusFingerprint,
) -> Result<(), RuntimeError> {
if let Err(e) = ensure_snapshot_schema(rt).await {
tracing::warn!(error = %e, "failed to create retrieval_snapshots schema");
return Err(e);
}
let snapshot = bridge
.to_vamana_snapshot(namespace, model, fingerprint)
.map_err(|e| RuntimeError::Internal(format!("to_snapshot: {e}")))?;
let blob = serde_json::to_vec(&snapshot)
.map_err(|e| RuntimeError::Internal(format!("snapshot serialize: {e}")))?;
let key = snapshot_key(namespace, model);
let sql = rt.sql();
let mut w = sql
.writer()
.await
.map_err(|e| RuntimeError::Internal(e.to_string()))?;
w.execute(SqlStatement {
sql: "INSERT OR REPLACE INTO retrieval_snapshots \
(namespace, index_type, snapshot, created_at) VALUES (?1, ?2, ?3, ?4)"
.into(),
params: vec![
SqlValue::Text(key),
SqlValue::Text("vamana".into()),
SqlValue::Blob(blob),
SqlValue::Integer(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_micros() as i64,
),
],
label: Some("persist_vamana_snapshot".into()),
})
.await
.map_err(|e| RuntimeError::Internal(e.to_string()))?;
Ok(())
}
/// Fetches and deserializes the persisted Vamana snapshot for `(namespace, model)`.
///
/// This is best-effort. Any failure returns `None`: no reader, a failed query (for
/// example a missing table), no matching row, a non-blob column, or undecodable JSON.
async fn try_load_snapshot(
    rt: &KhiveRuntime,
    namespace: &str,
    model: &str,
) -> Option<VamanaSnapshot> {
    let lookup = SqlStatement {
        sql: "SELECT snapshot FROM retrieval_snapshots \
              WHERE namespace = ?1 AND index_type = ?2"
            .into(),
        params: vec![
            SqlValue::Text(snapshot_key(namespace, model)),
            SqlValue::Text("vamana".into()),
        ],
        label: None,
    };
    let sql = rt.sql();
    let mut reader = sql.reader().await.ok()?;
    let rows = reader.query_all(lookup).await.ok()?;
    let first = rows.into_iter().next()?;
    let bytes = match first.get("snapshot")? {
        SqlValue::Blob(bytes) => bytes,
        _ => return None,
    };
    serde_json::from_slice::<VamanaSnapshot>(bytes).ok()
}
/// Summarizes the current vector store for `model` as a [`CorpusFingerprint`].
///
/// Only the store's entry count and dimensionality are captured. Any change that leaves
/// both unchanged therefore produces the same fingerprint. Returns `None` if the store
/// cannot be opened or its info cannot be read.
pub(crate) async fn compute_fingerprint(
    rt: &KhiveRuntime,
    token: &NamespaceToken,
    model: &str,
) -> Option<CorpusFingerprint> {
    let store = match rt.vectors_for_model(token, model) {
        Ok(store) => store,
        Err(_) => return None,
    };
    let info = match store.info().await {
        Ok(info) => info,
        Err(_) => return None,
    };
    let fingerprint = CorpusFingerprint {
        vector_count: info.entry_count,
        dimensions: info.dimensions as u32,
    };
    Some(fingerprint)
}
pub(crate) async fn load_and_build_from_vector_store(
rt: &KhiveRuntime,
token: &NamespaceToken,
model: &str,
) -> Result<Option<AnnBridge>, RuntimeError> {
let store = match rt.vectors_for_model(token, model) {
Ok(s) => s,
Err(_) => return Ok(None),
};
let info = store
.info()
.await
.map_err(|e| RuntimeError::Internal(e.to_string()))?;
let count = info.entry_count;
let dims = info.dimensions;
if count == 0 || dims == 0 {
return Ok(None);
}
let ns = token.namespace().as_str().to_owned();
let model_key = sanitize_model_key(model);
let table_name = format!("vec_{model_key}");
let model_str = model.to_owned();
let sql = rt.sql();
let mut reader = sql
.reader()
.await
.map_err(|e| RuntimeError::Internal(e.to_string()))?;
let rows = reader
.query_all(SqlStatement {
sql: format!(
"SELECT subject_id, embedding FROM {table_name} \
WHERE namespace = ?1 AND embedding_model = ?2 \
AND field = 'knowledge.atom' \
ORDER BY subject_id"
),
params: vec![SqlValue::Text(ns), SqlValue::Text(model_str)],
label: Some("vamana_corpus_scan".into()),
})
.await
.map_err(|e| RuntimeError::Internal(e.to_string()))?;
if rows.is_empty() {
return Ok(None);
}
let mut id_map: Vec<Uuid> = Vec::with_capacity(rows.len());
let mut flat: Vec<f32> = Vec::with_capacity(rows.len() * dims);
for row in &rows {
let id_str = match row.get("subject_id") {
Some(SqlValue::Text(s)) => s.as_str(),
_ => continue,
};
let uuid = match Uuid::parse_str(id_str) {
Ok(id) => id,
Err(_) => continue,
};
let bytes = match row.get("embedding") {
Some(SqlValue::Blob(b)) => b.as_slice(),
_ => continue,
};
if bytes.len() != dims * 4 {
continue;
}
let vec: Vec<f32> = bytes
.chunks_exact(4)
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect();
id_map.push(uuid);
flat.extend_from_slice(&vec);
}
if id_map.is_empty() {
return Ok(None);
}
AnnBridge::build(flat, dims, id_map)
.map(Some)
.map_err(RuntimeError::Internal)
}
pub(crate) async fn invalidate_snapshot(rt: &KhiveRuntime, namespace: &str) {
let pattern = format!("{namespace}::vamana::%");
let sql = rt.sql();
let mut w = match sql.writer().await {
Ok(w) => w,
Err(e) => {
tracing::warn!(error = %e, "failed to open writer for Vamana snapshot invalidation");
return;
}
};
match w
.execute(SqlStatement {
sql: "DELETE FROM retrieval_snapshots WHERE namespace LIKE ?1".into(),
params: vec![SqlValue::Text(pattern)],
label: Some("invalidate_vamana_snapshot".into()),
})
.await
{
Ok(_) => {}
Err(e) if e.to_string().contains("no such table") => {}
Err(e) => {
tracing::warn!(error = %e, "failed to invalidate Vamana snapshot");
}
}
}
/// Best-effort warm-up: installs an ANN index for the first namespace that has a
/// persisted Vamana snapshot.
///
/// Does nothing if an index is already installed. Candidate namespaces come from
/// `retrieval_snapshots` keys of the form `<namespace>::vamana::<model>`. Each distinct
/// namespace is parsed, authorized and passed to [`ensure_ann`], which uses the runtime's
/// default embedder rather than the model named in the key. This stops as soon as an
/// index is installed. Every error along the way is swallowed.
///
/// NOTE(review): `AnnState` holds a single index and the discovery query has no
/// `ORDER BY`, so which namespace gets warmed is unspecified — confirm this is intended.
pub(crate) async fn warm_known_snapshots(rt: &KhiveRuntime, ann: &SharedAnn) {
    use std::collections::HashSet;

    if ann.index.read().await.is_some() {
        return;
    }
    let sql = rt.sql();
    // Scope the reader to the discovery query so its connection is released before
    // `ensure_ann`, which acquires its own reader and writer.
    let rows = {
        let mut reader = match sql.reader().await {
            Ok(r) => r,
            Err(_) => return,
        };
        let queried = reader
            .query_all(SqlStatement {
                sql: "SELECT DISTINCT namespace FROM retrieval_snapshots WHERE namespace LIKE ?1"
                    .into(),
                params: vec![SqlValue::Text("%::vamana::%".into())],
                label: None,
            })
            .await;
        match queried {
            Ok(r) => r,
            Err(_) => return,
        }
    };

    // A namespace with snapshots for several models appears once per model. `ensure_ann`
    // ignores the key's model, so a second attempt would only repeat a failed one,
    // possibly including a full corpus scan.
    let mut attempted: HashSet<&str> = HashSet::new();
    for row in &rows {
        let ns_str = match row.get("namespace") {
            Some(SqlValue::Text(key)) => match key.split("::vamana::").next() {
                Some(prefix) if !prefix.is_empty() => prefix,
                _ => continue,
            },
            _ => continue,
        };
        if !attempted.insert(ns_str) {
            continue;
        }
        let ns = match Namespace::parse(ns_str) {
            Ok(n) => n,
            Err(_) => continue,
        };
        let token = match rt.authorize(ns) {
            Ok(t) => t,
            Err(_) => continue,
        };
        ensure_ann(rt, &token, ann).await;
        if ann.index.read().await.is_some() {
            break;
        }
    }
}
/// Runs [`ensure_ann`] for `token`'s namespace on a spawned tokio task without blocking.
///
/// No task is spawned while an index is installed or while a previous background warm-up
/// is still running; the `warming` flag acts as a single-flight guard. The flag is cleared
/// when the task finishes, whether or not an index was installed. The original code left
/// it set after a successful load, which meant the index could never be re-warmed in the
/// background after being cleared (e.g. on invalidation).
///
/// Must be called from within a tokio runtime (it uses `tokio::spawn`).
pub(crate) fn ensure_ann_background(rt: &KhiveRuntime, token: &NamespaceToken, ann: &SharedAnn) {
    // Non-blocking fast path for the common "already loaded" case. If a writer currently
    // holds the lock we simply fall through; `ensure_ann` re-checks under the lock.
    if let Ok(guard) = ann.index.try_read() {
        if guard.is_some() {
            return;
        }
    }
    if ann.warming.swap(true, Ordering::AcqRel) {
        // Another background warm-up is already in flight.
        return;
    }
    let rt = rt.clone();
    let ann = ann.clone();
    let ns = token.namespace().clone();
    tokio::spawn(async move {
        if let Ok(token) = rt.authorize(ns) {
            ensure_ann(&rt, &token, &ann).await;
        }
        ann.warming.store(false, Ordering::Release);
    });
}
/// Ensures `ann` holds an index built for the runtime's default embedder in `token`'s
/// namespace.
///
/// This is a no-op if an index is already installed or no default embedder is configured.
/// Otherwise it first tries the persisted snapshot, accepting it only when the snapshot's
/// fingerprint equals the current corpus fingerprint. If there is no snapshot, it is
/// stale, or it fails to decode, the index is rebuilt from the vector store and a fresh
/// snapshot is persisted. Failures are logged, never returned. An index installed
/// concurrently by another caller is never overwritten.
pub(crate) async fn ensure_ann(rt: &KhiveRuntime, token: &NamespaceToken, ann: &SharedAnn) {
    if ann.index.read().await.is_some() {
        return;
    }
    let model = rt.default_embedder_name().to_string();
    if model.is_empty() {
        return;
    }
    let ns = token.namespace().as_str().to_owned();

    // Fingerprint the corpus once, *before* any scan. The fingerprint is just a vector
    // count plus dimensions. If it were taken after the scan, a vector inserted mid-scan
    // but missed by it would be counted, and the snapshot missing it would later be
    // accepted as fresh. Taken beforehand, a concurrent change can only cause a fingerprint
    // mismatch (and a rebuild) on the next load.
    let current_fp = compute_fingerprint(rt, token, &model).await;

    if let Some(snapshot) = try_load_snapshot(rt, &ns, &model).await {
        match current_fp.as_ref() {
            Some(fp) if snapshot.fingerprint == *fp => {
                match AnnBridge::from_vamana_snapshot(snapshot) {
                    Ok(bridge) => {
                        let mut guard = ann.index.write().await;
                        if guard.is_none() {
                            *guard = Some(bridge);
                        }
                        return;
                    }
                    Err(e) => {
                        tracing::warn!(error = %e, "corrupt Vamana snapshot; rebuilding");
                    }
                }
            }
            Some(_) => {
                tracing::info!(
                    namespace = %ns,
                    model = %model,
                    "stale Vamana snapshot rejected (fingerprint mismatch); rebuilding"
                );
            }
            // The store could not be fingerprinted, so the snapshot cannot be validated;
            // fall through to a rebuild attempt.
            None => {}
        }
    }

    match load_and_build_from_vector_store(rt, token, &model).await {
        Ok(Some(bridge)) => {
            if let Some(fingerprint) = current_fp {
                if let Err(e) = persist_snapshot(rt, &ns, &model, &bridge, fingerprint).await {
                    tracing::error!(error = %e, "failed to persist Vamana snapshot after rebuild");
                }
            }
            let mut guard = ann.index.write().await;
            if guard.is_none() {
                *guard = Some(bridge);
            }
        }
        Ok(None) => {}
        Err(e) => {
            tracing::warn!(error = %e, "failed to rebuild Vamana ANN index");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use khive_runtime::KhiveRuntime;
    use khive_storage::types::{SqlStatement, SqlValue};

    #[tokio::test]
    async fn test_invalidate_snapshot_removes_vamana_rows() {
        let rt = KhiveRuntime::memory().expect("in-memory runtime");
        let sql = rt.sql();
        let mut w = sql.writer().await.expect("writer");
        w.execute_script(
            "CREATE TABLE IF NOT EXISTS retrieval_snapshots (\
             namespace TEXT NOT NULL, index_type TEXT NOT NULL, \
             snapshot BLOB NOT NULL, created_at INTEGER NOT NULL, \
             PRIMARY KEY (namespace, index_type));"
                .into(),
        )
        .await
        .expect("create table");
        // Two Vamana snapshots (different models) plus a non-Vamana row that must survive.
        for (ns, idx_type) in &[
            ("local::vamana::model-a", "vamana"),
            ("local::vamana::model-b", "vamana"),
            ("local::hnsw::model-a", "hnsw"),
        ] {
            w.execute(SqlStatement {
                sql: "INSERT INTO retrieval_snapshots (namespace, index_type, snapshot, created_at) VALUES (?1, ?2, ?3, 0)".into(),
                params: vec![
                    SqlValue::Text(ns.to_string()),
                    SqlValue::Text(idx_type.to_string()),
                    SqlValue::Blob(b"{}".to_vec()),
                ],
                label: None,
            })
            .await
            .expect("insert");
        }
        // Release our writer before invalidate_snapshot requests one
        // (presumably the writer is exclusive).
        drop(w);
        invalidate_snapshot(&rt, "local").await;
        let mut r = sql.reader().await.expect("reader");
        let rows = r
            .query_all(SqlStatement {
                sql: "SELECT namespace FROM retrieval_snapshots ORDER BY namespace".into(),
                params: vec![],
                label: None,
            })
            .await
            .expect("query");
        let remaining: Vec<String> = rows
            .iter()
            .filter_map(|row| match row.get("namespace") {
                Some(SqlValue::Text(s)) => Some(s.clone()),
                _ => None,
            })
            .collect();
        assert!(
            remaining.contains(&"local::hnsw::model-a".to_string()),
            "HNSW rows must survive: {remaining:?}"
        );
        assert!(
            !remaining.contains(&"local::vamana::model-a".to_string()),
            "vamana model-a must be deleted: {remaining:?}"
        );
        assert!(
            !remaining.contains(&"local::vamana::model-b".to_string()),
            "vamana model-b must be deleted: {remaining:?}"
        );
    }

    /// Invalidation before any snapshot was ever persisted (no table) must not panic.
    #[tokio::test]
    async fn test_invalidate_snapshot_tolerates_missing_table() {
        let rt = KhiveRuntime::memory().expect("in-memory runtime");
        invalidate_snapshot(&rt, "local").await;
    }

    /// Despite its name, this does not call `invalidate_snapshot`. It checks that
    /// clearing the `SharedAnn` slot removes an installed bridge.
    #[tokio::test]
    async fn test_invalidate_clears_in_memory_ann() {
        let ann = new_shared();
        let dim = 4;
        let vectors = vec![1.0f32, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let ids = vec![Uuid::new_v4(), Uuid::new_v4()];
        let bridge = AnnBridge::build(vectors, dim, ids).expect("build");
        *ann.index.write().await = Some(bridge);
        assert!(
            ann.index.read().await.is_some(),
            "pre-condition: ANN loaded"
        );
        *ann.index.write().await = None;
        assert!(
            ann.index.read().await.is_none(),
            "clearing SharedAnn must remove the bridge"
        );
    }

    #[test]
    fn test_snapshot_key_format() {
        assert_eq!(snapshot_key("local", "model-a"), "local::vamana::model-a");
    }

    #[test]
    fn test_sanitize_model_key_replaces_non_alphanumerics() {
        assert_eq!(
            sanitize_model_key("text-embedding-3/small"),
            "text_embedding_3_small"
        );
        assert_eq!(sanitize_model_key("Abc123"), "Abc123");
        assert_eq!(sanitize_model_key("né"), "n_");
        assert_eq!(sanitize_model_key(""), "");
    }

    #[test]
    fn test_l2_normalize_unit_length_and_zero_vector() {
        let mut v = [3.0f32, 4.0];
        l2_normalize(&mut v);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
        let mut zero = [0.0f32; 3];
        l2_normalize(&mut zero);
        assert_eq!(zero, [0.0f32; 3]);
    }

    #[test]
    fn test_build_rejects_invalid_input() {
        let id = Uuid::new_v4();
        assert!(AnnBridge::build(vec![1.0, 0.0], 0, vec![id]).is_err(), "zero dim");
        assert!(AnnBridge::build(Vec::new(), 2, vec![id]).is_err(), "no vectors");
        assert!(AnnBridge::build(vec![1.0, 0.0], 2, Vec::new()).is_err(), "no ids");
        assert!(
            AnnBridge::build(vec![1.0, 0.0, 0.0, 1.0], 2, vec![id]).is_err(),
            "id count mismatch"
        );
    }
}