use mnem_core::codec::to_canonical_bytes;
use mnem_core::error::Error;
use mnem_core::id::{CODEC_RAW, Cid, Multihash, NodeId};
use serde::{Deserialize, Serialize};
#[cfg(feature = "hnsw")]
use crate::hnsw::HnswVectorIndex;
/// Metric used to score vector pairs when building k-NN edges.
///
/// Each variant carries a fixed one-byte tag (see [`DistanceMetric::tag`])
/// that is mixed into [`KnnEdgeIndex::compute_cid`], so the discriminants
/// must never be renumbered.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum DistanceMetric {
    /// Cosine metric; wire tag `1`.
    Cosine = 1,
    /// Euclidean (L2) metric; wire tag `2`.
    L2 = 2,
    /// Dot-product metric; wire tag `3`.
    Dot = 3,
}

impl DistanceMetric {
    /// Returns the stable one-byte wire tag for this metric.
    ///
    /// The values match the enum discriminants and are part of the hashed
    /// preimage of a [`KnnEdgeIndex`] CID.
    #[must_use]
    pub const fn tag(self) -> u8 {
        match self {
            Self::Cosine => 1,
            Self::L2 => 2,
            Self::Dot => 3,
        }
    }
}
/// One edge from `src` to a node ranked among its `k` most similar neighbours.
///
/// NOTE(review): fields are serialized (and hashed into the owning index's
/// CID); renaming or reordering them may change that CID depending on how
/// the canonical codec orders fields — confirm before touching.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct KnnEdge {
/// Node whose neighbourhood this edge belongs to.
pub src: NodeId,
/// The neighbour of `src` that this edge points to.
pub dst: NodeId,
/// Similarity score of `dst` relative to `src` under the index's metric;
/// larger values mean more similar.
pub weight: f32,
}
/// A set of k-nearest-neighbour edges together with the parameters that
/// produced it; content-addressed via [`KnnEdgeIndex::compute_cid`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct KnnEdgeIndex {
/// CID of the vector-index snapshot the edges were derived from. It is
/// supplied by the caller and stored as-is (not verified here).
pub root_cid: Cid,
/// Requested number of neighbours per node. A source node can have fewer
/// edges when there are fewer than `k` other nodes.
pub k: u32,
/// Metric used to compute each edge's `weight`.
pub metric: DistanceMetric,
/// The edges; when produced by `derive_knn_edges_from_vectors` they are
/// sorted by `(src, dst)`.
pub edges: Vec<KnnEdge>,
}
impl KnnEdgeIndex {
    /// Creates an index for `root_cid` with the given `k` and `metric` and no
    /// edges yet.
    #[must_use]
    pub fn empty(root_cid: Cid, k: u32, metric: DistanceMetric) -> Self {
        let edges = Vec::new();
        Self {
            root_cid,
            k,
            metric,
            edges,
        }
    }

    /// Number of edges held by the index.
    #[must_use]
    pub fn len(&self) -> usize {
        self.edges.len()
    }

    /// Returns `true` when the index holds no edges.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.edges.is_empty()
    }

    /// Computes the content identifier (CID) of this index.
    ///
    /// The SHA2-256 preimage is, in order: the domain tag
    /// `mnem/knn-edge/v1`, the root CID bytes, `k` as a big-endian `u32`,
    /// the metric's one-byte tag, and the canonical encoding of the whole
    /// index. The digest is wrapped in a raw-codec CID.
    ///
    /// # Errors
    ///
    /// Returns an error if canonical serialization of `self` fails.
    pub fn compute_cid(&self) -> Result<Cid, Error> {
        const DOMAIN_TAG: &[u8] = b"mnem/knn-edge/v1";

        let canonical = to_canonical_bytes(self)?;
        let root_bytes = self.root_cid.to_bytes();
        let k_bytes = self.k.to_be_bytes();
        let metric_tag = [self.metric.tag()];

        let segments: [&[u8]; 5] = [DOMAIN_TAG, &root_bytes, &k_bytes, &metric_tag, &canonical];
        let preimage = segments.concat();

        Ok(Cid::new(CODEC_RAW, Multihash::sha2_256(&preimage)))
    }
}
/// Builds k-nearest-neighbour edges by exhaustively comparing every pair of
/// vectors.
///
/// For each `ids[i]`, every *other* vector is scored against `vecs[i]` with
/// `metric` (higher score = more similar). An edge is emitted to each of the
/// `k` best-scoring nodes, carrying the score as its `weight`. When fewer than
/// `k` other nodes exist, every other node becomes a neighbour.
///
/// Equal scores are broken by ascending `NodeId`, and NaN scores rank below
/// every real score, so ranking is a total order and the output depends only
/// on the input. The returned edges are sorted by `(src, dst)`.
///
/// Cost is O(n²) similarity evaluations for `n` input vectors.
///
/// # Panics
///
/// Panics if `ids.len() != vecs.len()`.
#[must_use]
pub fn derive_knn_edges_from_vectors(
    ids: &[NodeId],
    vecs: &[Vec<f32>],
    k: u32,
    metric: DistanceMetric,
) -> Vec<KnnEdge> {
    assert_eq!(ids.len(), vecs.len(), "ids/vecs length mismatch");
    let n = ids.len();
    if n == 0 || k == 0 {
        return Vec::new();
    }
    // A node is never its own neighbour, so each source has at most n - 1.
    let k_usize = usize::try_from(k).unwrap_or(usize::MAX).min(n - 1);
    if k_usize == 0 {
        return Vec::new();
    }

    // Best-first ranking; equal scores fall back to the smaller NodeId.
    let by_rank = |a: &(f32, NodeId), b: &(f32, NodeId)| {
        cmp_similarity_desc(a.0, b.0).then_with(|| a.1.cmp(&b.1))
    };

    let mut edges: Vec<KnnEdge> = Vec::with_capacity(n.saturating_mul(k_usize));
    // Reused across source nodes to avoid one allocation per node.
    let mut scored: Vec<(f32, NodeId)> = Vec::with_capacity(n - 1);
    for (i, (&src, src_vec)) in ids.iter().zip(vecs).enumerate() {
        scored.clear();
        scored.extend(
            ids.iter()
                .zip(vecs)
                .enumerate()
                .filter(|&(j, _)| j != i)
                .map(|(_, (&dst, dst_vec))| (similarity(src_vec, dst_vec, metric), dst)),
        );
        // Partition the k best candidates into the prefix in O(n), then sort
        // only that prefix rather than all n - 1 candidates.
        if k_usize < scored.len() {
            scored.select_nth_unstable_by(k_usize - 1, by_rank);
            scored.truncate(k_usize);
        }
        scored.sort_by(by_rank);
        edges.extend(
            scored
                .iter()
                .map(|&(weight, dst)| KnnEdge { src, dst, weight }),
        );
    }
    edges.sort_by(|a, b| a.src.cmp(&b.src).then_with(|| a.dst.cmp(&b.dst)));
    edges
}

/// Orders similarity scores from highest to lowest, placing NaN after every
/// real value (NaNs compare equal to each other).
///
/// `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN appears.
/// It would declare NaN equal to both 1.0 and 2.0, breaking transitivity,
/// which sort/select routines require. For non-NaN inputs this ordering is
/// identical to `b.partial_cmp(&a)`.
fn cmp_similarity_desc(a: f32, b: f32) -> std::cmp::Ordering {
    use std::cmp::Ordering;
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
    }
}
/// Scores how similar `a` and `b` are under `metric`; larger means closer.
///
/// * `Cosine`: the dot product divided by the product of the Euclidean norms,
///   so the score ignores vector magnitude. If either vector has zero norm
///   the score is defined as `0.0`.
/// * `Dot`: the raw dot product.
/// * `L2`: `1 / (1 + ‖a − b‖₂)`. This is 1.0 for identical vectors and
///   decays toward 0 as the Euclidean distance grows.
///
/// NaN components propagate to a NaN score. The slices are expected to have
/// equal length, which is checked only in debug builds. Otherwise the longer
/// slice's tail is ignored.
fn similarity(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let pairs = a.iter().copied().zip(b.iter().copied());
    match metric {
        DistanceMetric::Cosine => {
            let (dot, norm_a_sq, norm_b_sq) = pairs.fold(
                (0.0_f32, 0.0_f32, 0.0_f32),
                |(dot, na, nb), (x, y)| (dot + x * y, na + x * x, nb + y * y),
            );
            // Taking sqrt of each norm separately keeps the product further
            // from overflow than sqrt(na * nb).
            let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
            // Cosine is undefined for a zero vector. Treat it as unrelated
            // rather than emitting NaN/inf from a 0/0 or x/0 division.
            if denom == 0.0 {
                0.0
            } else {
                dot / denom
            }
        }
        DistanceMetric::Dot => pairs.fold(0.0_f32, |acc, (x, y)| acc + x * y),
        DistanceMetric::L2 => {
            let sq_dist = pairs.fold(0.0_f32, |acc, (x, y)| {
                let d = x - y;
                acc + d * d
            });
            1.0 / (1.0 + sq_dist.sqrt())
        }
    }
}
/// Builds a [`KnnEdgeIndex`] from every point stored in `hnsw`.
///
/// All stored vectors are copied out and compared exhaustively via
/// [`derive_knn_edges_from_vectors`] using [`DistanceMetric::Cosine`]; only
/// the point list of `hnsw` is read. `root_cid` and `k` are recorded as-is in
/// the returned index.
#[cfg(feature = "hnsw")]
#[must_use]
pub fn derive_knn_edges(hnsw: &HnswVectorIndex, k: u32, root_cid: Cid) -> KnnEdgeIndex {
    let metric = DistanceMetric::Cosine;
    let (ids, vecs): (Vec<NodeId>, Vec<Vec<f32>>) = hnsw
        .points_iter()
        .map(|(id, v)| (id, v.to_vec()))
        .unzip();
    let edges = derive_knn_edges_from_vectors(&ids, &vecs, k, metric);
    KnnEdgeIndex {
        root_cid,
        k,
        metric,
        edges,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fixed, arbitrary root CID for tests that need one.
    fn demo_cid() -> Cid {
        let digest = Multihash::sha2_256(b"demo-hnsw-root");
        Cid::new(mnem_core::id::CODEC_DAG_CBOR, digest)
    }

    #[test]
    fn empty_input_yields_empty_index() {
        let no_ids: [NodeId; 0] = [];
        let no_vecs: [Vec<f32>; 0] = [];
        let edges = derive_knn_edges_from_vectors(&no_ids, &no_vecs, 5, DistanceMetric::Cosine);
        assert!(edges.is_empty());
    }

    #[test]
    fn compute_cid_is_stable_across_two_calls() {
        let index = KnnEdgeIndex::empty(demo_cid(), 3, DistanceMetric::Cosine);
        let first = index.compute_cid().unwrap();
        let second = index.compute_cid().unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn distance_metric_tag_stable() {
        let cases = [
            (DistanceMetric::Cosine, 1_u8),
            (DistanceMetric::L2, 2),
            (DistanceMetric::Dot, 3),
        ];
        for &(metric, expected) in &cases {
            assert_eq!(metric.tag(), expected);
        }
    }
}