use crate::hnsw::distance_metric::{DistanceMetric, compute_distance};
use crate::hnsw::errors::{HnswError, HnswIndexError};
use crate::hnsw::layer::HnswLayer;
use std::collections::HashSet;
/// A node reached during a layer search, together with its distance to the query.
#[derive(Debug, Clone, PartialEq)]
struct SearchCandidate {
    /// Identifier of the node.
    node_id: u64,
    /// Distance between the query vector and this node's vector.
    distance: f32,
    /// Layer level recorded when the candidate was created.
    level: u8,
}

impl SearchCandidate {
    /// Bundles a node id, its distance to the query, and a layer level.
    fn new(node_id: u64, distance: f32, level: u8) -> Self {
        Self {
            node_id,
            distance,
            level,
        }
    }
}

/// Orders candidates by ascending distance, then by node id, then by level.
///
/// `level` takes part in the comparison so that the ordering agrees with the
/// derived `PartialEq`: two candidates compare as `Some(Equal)` exactly when
/// they are `==`. Candidates whose distances are incomparable (a NaN is
/// involved) are themselves incomparable.
impl PartialOrd for SearchCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // `?` propagates `None` when either distance is NaN.
        let by_distance = self.distance.partial_cmp(&other.distance)?;
        Some(
            by_distance
                .then_with(|| self.node_id.cmp(&other.node_id))
                .then_with(|| self.level.cmp(&other.level)),
        )
    }
}
/// Outcome of a layer search: the selected neighbors, their distances, and
/// bookkeeping about how the search ran.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Ids of the selected neighbors.
    neighbors: Vec<u64>,
    /// Distances stored alongside `neighbors`.
    distances: Vec<f32>,
    /// Count of candidates the search examined.
    candidates_examined: usize,
    /// Additional statistics attached to this result.
    search_metrics: SearchMetrics,
}

impl SearchResult {
    /// Assembles a result from its parts without any further validation.
    fn new(
        neighbors: Vec<u64>,
        distances: Vec<f32>,
        candidates_examined: usize,
        search_metrics: SearchMetrics,
    ) -> Self {
        SearchResult {
            search_metrics,
            candidates_examined,
            distances,
            neighbors,
        }
    }

    /// Ids of the returned neighbors.
    pub fn neighbors(&self) -> &[u64] {
        self.neighbors.as_slice()
    }

    /// Distances stored for the returned neighbors.
    pub fn distances(&self) -> &[f32] {
        self.distances.as_slice()
    }

    /// Number of neighbors in the result.
    pub fn len(&self) -> usize {
        self.neighbors().len()
    }

    /// `true` when the result holds no neighbors.
    pub fn is_empty(&self) -> bool {
        self.neighbors().is_empty()
    }

    /// Number of candidates examined while producing this result.
    pub fn candidates_examined(&self) -> usize {
        self.candidates_examined
    }

    /// Statistics recorded for the search.
    pub fn metrics(&self) -> &SearchMetrics {
        &self.search_metrics
    }
}
/// Counters and statistics describing a search run.
#[derive(Debug, Clone)]
pub struct SearchMetrics {
    /// Number of layers visited.
    layers_visited: u8,
    /// Number of entry points handed to the search.
    entry_points_considered: usize,
    /// Average node degree.
    average_degree: f32,
    /// Depth reached by the search.
    search_depth: usize,
}

impl SearchMetrics {
    /// Creates metrics with every field set to zero.
    fn new() -> Self {
        SearchMetrics {
            search_depth: 0,
            average_degree: 0.0,
            entry_points_considered: 0,
            layers_visited: 0,
        }
    }

    /// Recorded number of layers visited.
    pub fn layers_visited(&self) -> u8 {
        self.layers_visited
    }

    /// Recorded number of entry points considered.
    pub fn entry_points_considered(&self) -> usize {
        self.entry_points_considered
    }

    /// Recorded average degree.
    pub fn average_degree(&self) -> f32 {
        self.average_degree
    }

    /// Recorded search depth.
    pub fn search_depth(&self) -> usize {
        self.search_depth
    }
}
pub struct NeighborhoodSearch {
distance_metric: DistanceMetric,
}
impl NeighborhoodSearch {
pub fn new(distance_metric: DistanceMetric) -> Self {
Self { distance_metric }
}
pub fn search_layer(
&self,
layer: &HnswLayer,
query_vector: &[f32],
vectors: &[Vec<f32>],
entry_points: &[u64],
k: usize,
) -> Result<SearchResult, HnswError> {
if query_vector.is_empty() {
return Err(HnswError::Index(HnswIndexError::InvalidSearchParameters));
}
if k == 0 {
return Ok(SearchResult::new(vec![], vec![], 0, SearchMetrics::new()));
}
if entry_points.is_empty() {
return Err(HnswError::Index(HnswIndexError::InvalidSearchParameters));
}
if layer.node_count() == 0 {
return Err(HnswError::Index(HnswIndexError::IndexNotInitialized));
}
let mut metrics = SearchMetrics::new();
metrics.entry_points_considered = entry_points.len();
let mut candidates = Vec::new();
let mut visited = HashSet::new();
for &entry_point in entry_points {
if layer.contains_node(entry_point) {
let distance =
self.compute_distance(query_vector, &vectors[entry_point as usize])?;
candidates.push(SearchCandidate::new(entry_point, distance, layer.level()));
visited.insert(entry_point);
}
}
let mut candidates_examined = 0;
let mut result_candidates = Vec::new();
while !candidates.is_empty() && result_candidates.len() < k + layer.max_connections() {
candidates.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.node_id.cmp(&b.node_id))
});
let candidate = candidates.remove(0);
candidates_examined += 1;
result_candidates.push(candidate.clone());
if let Ok(connections) = layer.get_connections(candidate.node_id) {
for &neighbor_id in connections {
if !visited.contains(&neighbor_id) {
visited.insert(neighbor_id);
let distance =
self.compute_distance(query_vector, &vectors[neighbor_id as usize])?;
candidates.push(SearchCandidate::new(neighbor_id, distance, layer.level()));
}
}
}
}
result_candidates.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.node_id.cmp(&b.node_id))
});
result_candidates.truncate(k);
let neighbors: Vec<u64> = result_candidates.iter().map(|c| c.node_id).collect();
let distances: Vec<f32> = result_candidates.iter().map(|c| c.distance).collect();
metrics.search_depth = visited.len();
Ok(SearchResult::new(
neighbors,
distances,
candidates_examined,
metrics,
))
}
fn compute_distance(
&self,
query_vector: &[f32],
target_vector: &[f32],
) -> Result<f32, HnswError> {
if query_vector.len() != target_vector.len() {
return Err(HnswError::Index(HnswIndexError::VectorDimensionMismatch {
expected: query_vector.len(),
actual: target_vector.len(),
}));
}
Ok(compute_distance(
self.distance_metric,
query_vector,
target_vector,
))
}
}
impl Default for NeighborhoodSearch {
fn default() -> Self {
Self::new(DistanceMetric::Cosine)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::distance_metric::DistanceMetric;

    /// Five 3-dimensional vectors; index `i` holds the vector of node id `i`.
    fn create_test_vectors() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![0.5, 0.5, 0.5],
        ]
    }

    /// Layer holding nodes 0..5 linked as 0-1, 1-2, 2-3, 3-4 and 0-3.
    fn create_test_layer() -> HnswLayer {
        // NOTE(review): arguments are presumably (level, max_connections) — confirm.
        let mut layer = HnswLayer::new(0, 4);
        for i in 0..5 {
            layer.add_node(i).unwrap();
        }
        layer.add_connection(0, 1).unwrap();
        layer.add_connection(1, 2).unwrap();
        layer.add_connection(2, 3).unwrap();
        layer.add_connection(3, 4).unwrap();
        layer.add_connection(0, 3).unwrap();
        layer
    }

    #[test]
    fn test_search_candidate_creation() {
        let candidate = SearchCandidate::new(42, 0.75, 2);
        assert_eq!(candidate.node_id, 42);
        assert_eq!(candidate.distance, 0.75);
        assert_eq!(candidate.level, 2);
    }

    #[test]
    fn test_search_candidate_ordering() {
        let c1 = SearchCandidate::new(1, 0.9, 0);
        let c2 = SearchCandidate::new(2, 0.5, 0);
        let c3 = SearchCandidate::new(3, 0.7, 0);
        // A smaller distance ranks first regardless of node id.
        assert!(c2 < c1);
        assert!(c2 < c3);
        assert!(c3 < c1);
    }

    #[test]
    fn test_search_candidate_tie_break_and_nan() {
        // Equal distances fall back to the node id.
        let low_id = SearchCandidate::new(1, 0.5, 0);
        let high_id = SearchCandidate::new(2, 0.5, 0);
        assert!(low_id < high_id);
        // A NaN distance cannot be ordered against anything.
        let nan = SearchCandidate::new(3, f32::NAN, 0);
        assert_eq!(nan.partial_cmp(&low_id), None);
    }

    #[test]
    fn test_neighborhood_search_creation() {
        let search = NeighborhoodSearch::new(DistanceMetric::Euclidean);
        assert_eq!(search.distance_metric, DistanceMetric::Euclidean);
    }

    #[test]
    fn test_neighborhood_search_default() {
        let search = NeighborhoodSearch::default();
        assert_eq!(search.distance_metric, DistanceMetric::Cosine);
    }

    #[test]
    fn test_search_layer_basic() {
        let search = NeighborhoodSearch::new(DistanceMetric::Cosine);
        let vectors = create_test_vectors();
        let layer = create_test_layer();
        let query_vector = vec![1.0, 0.0, 0.0];
        let result = search
            .search_layer(&layer, &query_vector, &vectors, &[0], 3)
            .unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result.neighbors().len(), 3);
        assert_eq!(result.distances().len(), 3);
        // The query equals node 0's vector, so node 0 must come first.
        assert_eq!(result.neighbors()[0], 0);
        // Compare with a tolerance: the distance is computed, not stored.
        assert!(result.distances()[0].abs() < f32::EPSILON);
        // Results are returned closest-first.
        assert!(result.distances().windows(2).all(|pair| pair[0] <= pair[1]));
    }

    #[test]
    fn test_search_layer_k_zero() {
        let search = NeighborhoodSearch::new(DistanceMetric::Cosine);
        let vectors = create_test_vectors();
        let layer = create_test_layer();
        let query_vector = vec![1.0, 0.0, 0.0];
        let result = search
            .search_layer(&layer, &query_vector, &vectors, &[0], 0)
            .unwrap();
        assert!(result.is_empty());
        assert_eq!(result.len(), 0);
        assert_eq!(result.candidates_examined(), 0);
    }

    #[test]
    fn test_search_layer_no_entry_points() {
        let search = NeighborhoodSearch::new(DistanceMetric::Cosine);
        let vectors = create_test_vectors();
        let layer = create_test_layer();
        let query_vector = vec![1.0, 0.0, 0.0];
        let result = search.search_layer(&layer, &query_vector, &vectors, &[], 3);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            HnswError::Index(HnswIndexError::InvalidSearchParameters)
        ));
    }

    #[test]
    fn test_search_layer_empty_layer() {
        let search = NeighborhoodSearch::new(DistanceMetric::Cosine);
        let vectors: Vec<Vec<f32>> = Vec::new();
        // No nodes are added, so the layer is empty.
        let layer = HnswLayer::new(0, 4);
        let query_vector = vec![1.0, 0.0, 0.0];
        let result = search.search_layer(&layer, &query_vector, &vectors, &[0], 3);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            HnswError::Index(HnswIndexError::IndexNotInitialized)
        ));
    }

    #[test]
    fn test_search_layer_empty_query_vector() {
        let search = NeighborhoodSearch::new(DistanceMetric::Cosine);
        let vectors = create_test_vectors();
        let layer = create_test_layer();
        let result = search.search_layer(&layer, &[], &vectors, &[0], 3);
        assert!(
            result.is_err(),
            "search_layer should reject empty query_vector, got {:?}",
            result
        );
        assert!(matches!(
            result.unwrap_err(),
            HnswError::Index(HnswIndexError::InvalidSearchParameters)
        ));
    }

    #[test]
    fn test_search_result_accessors() {
        let result =
            SearchResult::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3], 10, SearchMetrics::new());
        assert_eq!(result.len(), 3);
        assert!(!result.is_empty());
        assert_eq!(result.neighbors(), &[1, 2, 3]);
        assert_eq!(result.distances(), &[0.1, 0.2, 0.3]);
        assert_eq!(result.candidates_examined(), 10);
        assert_eq!(result.metrics().layers_visited(), 0);
    }

    #[test]
    fn test_search_result_empty() {
        let result = SearchResult::new(vec![], vec![], 0, SearchMetrics::new());
        assert_eq!(result.len(), 0);
        assert!(result.is_empty());
        assert!(result.neighbors().is_empty());
        assert!(result.distances().is_empty());
    }

    #[test]
    fn test_search_metrics() {
        let metrics = SearchMetrics::new();
        assert_eq!(metrics.layers_visited(), 0);
        assert_eq!(metrics.entry_points_considered(), 0);
        assert_eq!(metrics.average_degree(), 0.0);
        assert_eq!(metrics.search_depth(), 0);
    }

    #[test]
    fn test_compute_distance_success() {
        let search = NeighborhoodSearch::new(DistanceMetric::Euclidean);
        let query = vec![1.0, 0.0];
        let target = vec![0.0, 1.0];
        let distance = search.compute_distance(&query, &target).unwrap();
        assert!((distance - std::f32::consts::SQRT_2).abs() < f32::EPSILON);
    }

    #[test]
    fn test_compute_distance_dimension_mismatch() {
        let search = NeighborhoodSearch::new(DistanceMetric::Euclidean);
        let query = vec![1.0, 0.0, 0.0];
        let target = vec![1.0, 0.0];
        let result = search.compute_distance(&query, &target);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            HnswError::Index(HnswIndexError::VectorDimensionMismatch {
                expected: 3,
                actual: 2
            })
        ));
    }

    #[test]
    fn test_different_distance_metrics() {
        let vectors = create_test_vectors();
        let layer = create_test_layer();
        let query_vector = vec![0.3, 0.7, 0.2];

        let euclidean_search = NeighborhoodSearch::new(DistanceMetric::Euclidean);
        let euclidean_result = euclidean_search
            .search_layer(&layer, &query_vector, &vectors, &[0], 1)
            .unwrap();

        let manhattan_search = NeighborhoodSearch::new(DistanceMetric::Manhattan);
        let manhattan_result = manhattan_search
            .search_layer(&layer, &query_vector, &vectors, &[0], 1)
            .unwrap();

        assert_eq!(euclidean_result.len(), 1);
        assert_eq!(manhattan_result.len(), 1);
        // Both metrics pick the same nearest node but report different values.
        assert_eq!(
            euclidean_result.neighbors()[0],
            manhattan_result.neighbors()[0]
        );
        assert_ne!(
            euclidean_result.distances()[0],
            manhattan_result.distances()[0]
        );
    }
}