use crate::{Error, Result};
use petgraph::graph::UnGraph;
use vicinity::hnsw::{HNSWIndex, HNSWParams};
/// Parameters controlling how [`knn_graph_with_config`] builds a k-nearest-neighbour graph.
#[derive(Debug, Clone)]
pub struct KnnGraphConfig {
/// Number of nearest neighbours to link per point; clamped to `n - 1` for `n` embeddings.
pub k: usize,
/// Forwarded to the HNSW index as `HNSWParams::m`
/// (presumably the per-node connection budget — confirm against `vicinity` docs).
pub hnsw_m: usize,
/// Forwarded to the HNSW index as `HNSWParams::ef_construction`; half of it
/// also serves as a lower bound on the search beam width used when querying.
pub hnsw_ef_construction: usize,
/// `true`: points `i` and `j` are connected whenever either one is among the
/// other's nearest neighbours. `false`: an edge is only recorded when `j` is a
/// neighbour of `i` and `i < j` (so the result depends on input order).
pub symmetric: bool,
/// How a neighbour distance is converted into an edge weight.
pub weight_fn: WeightFunction,
}
impl Default for KnnGraphConfig {
fn default() -> Self {
Self {
k: 10,
hnsw_m: 16,
hnsw_ef_construction: 100,
symmetric: true,
weight_fn: WeightFunction::Similarity,
}
}
}
/// How a neighbour distance returned by the HNSW search becomes an edge weight.
///
/// Deriving `PartialEq`/`Eq`/`Hash` lets callers compare configurations and
/// use the variant as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WeightFunction {
    /// `max(1 - distance, 0.001)`. Presumably intended for cosine-style
    /// distances — confirm against the index metric.
    Similarity,
    /// `1 / (1 + distance)`.
    InverseDistance,
    /// `exp(-distance² / σ²)`, where σ is estimated from the median of a
    /// sample of neighbour distances.
    GaussianKernel,
    /// Every edge gets weight `1.0`.
    Uniform,
}
pub fn knn_graph_from_embeddings(embeddings: &[Vec<f32>], k: usize) -> Result<UnGraph<(), f32>> {
knn_graph_with_config(
embeddings,
&KnnGraphConfig {
k,
..Default::default()
},
)
}
pub fn knn_graph_with_config(
embeddings: &[Vec<f32>],
config: &KnnGraphConfig,
) -> Result<UnGraph<(), f32>> {
if embeddings.is_empty() {
return Err(Error::EmptyInput);
}
let dim = embeddings[0].len();
if let Some((_, e)) = embeddings.iter().enumerate().find(|(_, e)| e.len() != dim) {
return Err(Error::DimensionMismatch {
expected: dim,
found: e.len(),
});
}
let n = embeddings.len();
let k = config.k.min(n - 1);
let params = HNSWParams {
m: config.hnsw_m,
ef_construction: config.hnsw_ef_construction,
..Default::default()
};
let mut hnsw = HNSWIndex::with_params(dim, params).map_err(|e| Error::Other(e.to_string()))?;
for (i, embedding) in embeddings.iter().enumerate() {
hnsw.add(i as u32, embedding.clone())
.map_err(|e| Error::Other(e.to_string()))?;
}
hnsw.build().map_err(|e| Error::Other(e.to_string()))?;
let mut graph = UnGraph::<(), f32>::new_undirected();
let nodes: Vec<_> = (0..n).map(|_| graph.add_node(())).collect();
let ef_search = (config.k * 2).max(config.hnsw_ef_construction / 2);
let sigma = if matches!(config.weight_fn, WeightFunction::GaussianKernel) {
let sample_size = (n / 10).clamp(10, 100);
let mut distances = Vec::with_capacity(sample_size * k);
for i in (0..n).step_by((n / sample_size).max(1)).take(sample_size) {
if let Ok(neighbors) = hnsw.search(&embeddings[i], k + 1, ef_search) {
for (_, dist) in neighbors.iter().skip(1) {
distances.push(*dist);
}
}
}
distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
distances.get(distances.len() / 2).copied().unwrap_or(1.0)
} else {
1.0
};
for (i, embedding) in embeddings.iter().enumerate() {
let neighbors = hnsw
.search(embedding, k + 1, ef_search)
.map_err(|e| Error::Other(e.to_string()))?;
for (neighbor_idx, distance) in neighbors {
let neighbor_idx = neighbor_idx as usize;
if neighbor_idx == i {
continue; }
let weight = match config.weight_fn {
WeightFunction::Similarity => (1.0 - distance).max(0.001),
WeightFunction::InverseDistance => 1.0 / (1.0 + distance),
WeightFunction::GaussianKernel => (-distance * distance / (sigma * sigma)).exp(),
WeightFunction::Uniform => 1.0,
};
if config.symmetric || i < neighbor_idx {
let _ = graph.add_edge(nodes[i], nodes[neighbor_idx], weight);
}
}
}
Ok(graph)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three points hugging the x-axis followed by three hugging the y-axis.
    fn two_clusters() -> Vec<Vec<f32>> {
        let x_axis: [[f32; 3]; 3] = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.95, 0.05, 0.0]];
        let y_axis: [[f32; 3]; 3] = [[0.0, 1.0, 0.0], [0.1, 0.9, 0.0], [0.05, 0.95, 0.0]];
        x_axis
            .iter()
            .chain(y_axis.iter())
            .map(|point| point.to_vec())
            .collect()
    }

    #[test]
    fn test_knn_graph_basic() {
        let points = two_clusters();
        let graph = knn_graph_from_embeddings(&points, 2).unwrap();
        assert_eq!(graph.node_count(), 6);
        assert!(graph.edge_count() > 0);
    }

    #[test]
    fn test_knn_graph_empty() {
        let no_points: Vec<Vec<f32>> = Vec::new();
        assert!(knn_graph_from_embeddings(&no_points, 5).is_err());
    }

    #[test]
    fn test_knn_graph_single_point() {
        let lone = vec![vec![1.0_f32, 0.0]];
        let graph = knn_graph_from_embeddings(&lone, 5).unwrap();
        assert_eq!((graph.node_count(), graph.edge_count()), (1, 0));
    }

    #[test]
    fn test_knn_graph_weight_functions() {
        let points: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.0, 1.0]];
        let every_weight_fn = [
            WeightFunction::Similarity,
            WeightFunction::InverseDistance,
            WeightFunction::GaussianKernel,
            WeightFunction::Uniform,
        ];
        for &weight_fn in &every_weight_fn {
            let config = KnnGraphConfig {
                k: 2,
                weight_fn,
                ..KnnGraphConfig::default()
            };
            let graph = knn_graph_with_config(&points, &config).unwrap();
            assert!(
                graph.edge_references().all(|edge| *edge.weight() > 0.0),
                "Weight should be positive for {:?}",
                weight_fn
            );
        }
    }
}