use std::collections::BTreeMap;
use ckg_core::Result;
use cozo::{DataValue, ScriptMutability};
use super::lifecycle::run_idempotent;
use super::map_err;
use super::Storage;
/// Number of `f32` components an embedding must have: `put_embeddings` rejects rows of
/// any other length, and `hnsw_search` rejects non-empty queries of any other length.
pub const STORAGE_EMBED_DIM: usize = 768;
/// Upper bound applied to the `k` argument of `hnsw_search` before the query is built.
pub const HNSW_K_MAX: usize = 10_000;
impl Storage {
    /// Stores embedding vectors and then makes sure the HNSW index over them exists.
    ///
    /// Every vector is length-checked against [`STORAGE_EMBED_DIM`] before anything is
    /// written, so a malformed row is rejected without touching the database. Rows are
    /// then written with `:put` in batches, one script call per batch; if a later batch
    /// fails, this function does not undo the batches already written.
    ///
    /// An empty `items` is a no-op: nothing is written and the index is not created.
    ///
    /// # Errors
    ///
    /// Returns an error when a vector has the wrong dimension, when a `:put` script
    /// fails, or when `run_idempotent` reports a failure for the index-creation script.
    pub fn put_embeddings(&self, items: Vec<(String, Vec<f32>)>) -> Result<()> {
        const SCRIPT: &str = "?[id, vec] <- $rows :put Embedding {id => vec}";
        // Number of rows sent to the database per script invocation.
        const BATCH_ROWS: usize = 500;

        if items.is_empty() {
            return Ok(());
        }

        let bad_row = items
            .iter()
            .enumerate()
            .find(|(_, (_, v))| v.len() != STORAGE_EMBED_DIM);
        if let Some((i, (id, v))) = bad_row {
            return Err(map_err(format!(
                "put_embeddings row {i} (id={id}) has dim {} (expected {STORAGE_EMBED_DIM})",
                v.len()
            )));
        }

        // `items` is owned, so drain it batch by batch and move each vector into the
        // row instead of cloning it.
        let mut pending = items.into_iter();
        loop {
            let rows: Vec<DataValue> = pending
                .by_ref()
                .take(BATCH_ROWS)
                .map(|(id, v)| {
                    DataValue::List(vec![
                        DataValue::from(id.as_str()),
                        DataValue::Vec(cozo::Vector::F32(ndarray::Array1::from_vec(v))),
                    ])
                })
                .collect();
            if rows.is_empty() {
                break;
            }

            let mut params = BTreeMap::new();
            params.insert("rows".into(), DataValue::List(rows));
            self.db
                .run_script(SCRIPT, params, ScriptMutability::Mutable)
                .map_err(map_err)?;
        }

        // The index dimension is derived from the same constant used for validation
        // above, so the two cannot drift apart.
        // NOTE(review): `run_idempotent` presumably tolerates the index already
        // existing — confirm against its definition.
        run_idempotent(
            &self.db,
            &format!(
                "::hnsw create Embedding:embed_idx {{ \
                 fields: [vec], dim: {STORAGE_EMBED_DIM}, m: 16, ef_construction: 200, distance: Cosine \
                 }}"
            ),
        )?;
        Ok(())
    }

    /// Runs a k-nearest-neighbour query against the `Embedding:embed_idx` HNSW index.
    ///
    /// Returns `(id, similarity)` pairs ordered by ascending index distance. The
    /// similarity is `1 - dist / 2`, clamped to `[0, 1]`.
    /// NOTE(review): this mapping assumes the index reports cosine distance in
    /// `[0, 2]` — confirm against the index definition.
    ///
    /// `k` is capped at [`HNSW_K_MAX`]. An empty `query` or `k == 0` yields an empty
    /// result without querying the database. Rows whose id is not a string, or whose
    /// distance is not numeric, are skipped rather than reported with a made-up score.
    ///
    /// # Errors
    ///
    /// Returns an error when a non-empty `query` does not have [`STORAGE_EMBED_DIM`]
    /// components, or when the database rejects the search script.
    pub fn hnsw_search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
        if query.is_empty() {
            return Ok(Vec::new());
        }
        if query.len() != STORAGE_EMBED_DIM {
            return Err(map_err(format!(
                "hnsw_search query has dim {} (expected {STORAGE_EMBED_DIM})",
                query.len()
            )));
        }
        // Asking for zero neighbours has an obvious answer; do not hand `k: 0` to the
        // query engine.
        if k == 0 {
            return Ok(Vec::new());
        }

        let k = k.min(HNSW_K_MAX);
        let mut params = BTreeMap::new();
        params.insert(
            "q".into(),
            DataValue::Vec(cozo::Vector::F32(ndarray::Array1::from_vec(query.to_vec()))),
        );
        let script = format!(
            "?[id, dist] := ~Embedding:embed_idx{{ id, dist | query: $q, k: {k}, ef: 200 }}\n\
             :order dist\n\
             :limit {k}"
        );
        let result = self
            .db
            .run_script(&script, params, ScriptMutability::Immutable)
            .map_err(map_err)?;

        let hits = result
            .rows
            .into_iter()
            .filter_map(|row| {
                let mut cols = row.into_iter();
                let id = match cols.next() {
                    Some(DataValue::Str(s)) => s.to_string(),
                    _ => return None,
                };
                // A missing or non-numeric distance used to be treated as 0.0, which
                // turned a malformed row into a perfect match; skip it instead.
                let dist = match cols.next() {
                    Some(DataValue::Num(cozo::Num::Float(f))) => f as f32,
                    Some(DataValue::Num(cozo::Num::Int(i))) => i as f32,
                    _ => return None,
                };
                Some((id, (1.0 - dist / 2.0).clamp(0.0, 1.0)))
            })
            .collect();
        Ok(hits)
    }

    /// Reads at most `max_rows` `(id, vector)` pairs from the `Embedding` relation.
    ///
    /// Vectors stored as `F64` are narrowed to `f32`. Rows whose id is not a string or
    /// whose second column is not a vector are skipped, so the result may contain fewer
    /// than `max_rows` entries even when more rows exist.
    ///
    /// # Errors
    ///
    /// Returns an error when `max_rows` is zero or when the database query fails.
    pub fn iter_embeddings_capped(&self, max_rows: usize) -> Result<Vec<(String, Vec<f32>)>> {
        if max_rows == 0 {
            return Err(map_err(
                "iter_embeddings_capped: max_rows must be > 0 (use hnsw_search for k-NN queries)",
            ));
        }

        // The script parameter is a signed 64-bit integer; saturate instead of wrapping
        // if `max_rows` does not fit.
        let limit = i64::try_from(max_rows).unwrap_or(i64::MAX);
        let mut params = BTreeMap::new();
        params.insert("limit".into(), DataValue::from(limit));
        let result = self
            .db
            .run_script(
                "?[id, vec] := *Embedding{id, vec} :limit $limit",
                params,
                ScriptMutability::Immutable,
            )
            .map_err(map_err)?;

        let embeddings = result
            .rows
            .into_iter()
            .filter_map(|row| {
                let mut cols = row.into_iter();
                let id = match cols.next() {
                    Some(DataValue::Str(s)) => s.to_string(),
                    _ => return None,
                };
                let vec = match cols.next() {
                    Some(DataValue::Vec(cozo::Vector::F32(arr))) => arr.to_vec(),
                    Some(DataValue::Vec(cozo::Vector::F64(arr))) => {
                        arr.iter().map(|f| *f as f32).collect()
                    }
                    _ => return None,
                };
                Some((id, vec))
            })
            .collect();
        Ok(embeddings)
    }
}