// jin 0.1.0
//
// Approximate Nearest Neighbor Search: HNSW, DiskANN, IVF-PQ, ScaNN, quantization
// Documentation
//! HNSW search algorithm with early termination optimizations.

use std::collections::{BinaryHeap, HashSet};

#[cfg(feature = "hnsw")]
mod distance_impl {
    // Single re-export point for the cosine metric used by the search code.
    pub use crate::hnsw::distance::cosine_distance;
}

#[cfg(feature = "hnsw")]
use distance_impl::cosine_distance;

/// Candidate node during search.
///
/// The ordering is deliberately reversed so that `BinaryHeap` (a max-heap)
/// pops the candidate with the *smallest* distance first. Distances are
/// compared with `f32::total_cmp`, so NaN values have a well-defined place
/// in the order; ties on distance are broken by `id` (smaller id pops
/// first), which keeps the order total and deterministic.
///
/// Equality is defined through the same comparison, so `==`, `partial_cmp`
/// and `cmp` always agree, as the `Eq`/`Ord` contracts require. Note that
/// this means a NaN distance equals itself (same bit pattern) and that
/// `0.0` and `-0.0` are distinct.
#[derive(Clone)]
pub(crate) struct Candidate {
    pub(crate) id: u32,
    pub(crate) distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Delegate to `cmp` so equality can never disagree with ordering.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Candidate {}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Operands are swapped on purpose: smaller distance (then smaller
        // id) means higher heap priority.
        other
            .distance
            .total_cmp(&self.distance)
            .then_with(|| other.id.cmp(&self.id))
    }
}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Mutable bookkeeping for one HNSW search pass.
pub(crate) struct SearchState {
    /// Pending candidates, popped in `Candidate` priority order
    /// (closest distance first).
    candidates: BinaryHeap<Candidate>,

    /// Ids already handed out by `pop_candidate`; they are never queued or
    /// returned again.
    visited: HashSet<u32>,

    /// Smallest distance among the candidates popped so far.
    best_distance: f32,

    /// Consecutive pops that did not improve on `best_distance`
    /// (bookkeeping for early termination).
    no_improvement_count: usize,
}

impl SearchState {
    /// Build an empty state without reserving any storage up front.
    #[allow(dead_code)]
    fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Build an empty state with room reserved for `ef * 2` entries in both
    /// the candidate queue and the visited set.
    pub(crate) fn with_capacity(ef: usize) -> Self {
        let reserve = ef * 2;
        Self {
            candidates: BinaryHeap::with_capacity(reserve),
            visited: HashSet::with_capacity(reserve),
            best_distance: f32::INFINITY,
            no_improvement_count: 0,
        }
    }

    /// Queue `id` at the given `distance` unless it has already been visited.
    pub(crate) fn add_candidate(&mut self, id: u32, distance: f32) {
        if self.visited.contains(&id) {
            return;
        }
        self.candidates.push(Candidate { id, distance });
    }

    /// Pop the highest-priority candidate that has not been visited yet.
    ///
    /// The returned candidate is marked visited. `best_distance` is lowered
    /// (and the no-improvement counter reset) when the candidate is strictly
    /// closer than anything popped before; otherwise the counter is bumped.
    /// Returns `None` once the queue holds no unvisited candidate.
    pub(crate) fn pop_candidate(&mut self) -> Option<Candidate> {
        loop {
            let next = self.candidates.pop()?;

            // `insert` reports `false` for ids that were already visited;
            // such stale queue entries are simply dropped.
            if !self.visited.insert(next.id) {
                continue;
            }

            if next.distance < self.best_distance {
                self.best_distance = next.distance;
                self.no_improvement_count = 0;
            } else {
                self.no_improvement_count += 1;
            }
            return Some(next);
        }
    }
}

/// Beam search restricted to a single HNSW layer.
///
/// Follows the layer search described by Malkov & Yashunin (2016):
/// - a min-heap frontier expands the closest unexplored node first;
/// - a max-heap of kept results exposes the current worst hit for pruning;
/// - expansion stops once the closest frontier node is farther than that
///   worst hit while at least `ef` results are held.
///
/// Distances are computed with `cosine_distance` between `query` and each
/// node's vector, which is looked up in `vectors` by node id and `dimension`.
///
/// Returns `(id, distance)` pairs sorted by ascending distance. The entry
/// point seeds the result set, so at least one pair is always returned;
/// beyond that, no more than `ef` pairs are kept.
#[cfg(feature = "hnsw")]
pub fn greedy_search_layer(
    query: &[f32],
    entry_point: u32,
    layer: &crate::hnsw::graph::Layer,
    vectors: &[f32],
    dimension: usize,
    ef: usize,
) -> Vec<(u32, f32)> {
    use std::cmp::{Ordering, Reverse};
    use std::collections::{BinaryHeap, HashSet};

    // A node paired with its distance to the query. Ordered by distance, so
    // a plain `BinaryHeap<Scored>` keeps the worst hit on top, while wrapping
    // in `Reverse` yields the closest node first.
    #[derive(PartialEq)]
    struct Scored {
        id: u32,
        distance: f32,
    }
    impl Eq for Scored {}
    impl Ord for Scored {
        fn cmp(&self, rhs: &Self) -> Ordering {
            self.distance.total_cmp(&rhs.distance)
        }
    }
    impl PartialOrd for Scored {
        fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
            Some(self.cmp(rhs))
        }
    }

    // Distance of the worst kept result, or +inf while nothing is kept.
    let worst_kept =
        |kept: &BinaryHeap<Scored>| kept.peek().map_or(f32::INFINITY, |hit| hit.distance);

    let mut frontier: BinaryHeap<Reverse<Scored>> = BinaryHeap::with_capacity(ef * 2);
    let mut kept: BinaryHeap<Scored> = BinaryHeap::with_capacity(ef + 1);
    let mut seen = HashSet::with_capacity(ef * 2);

    // Seed both heaps with the entry point.
    let entry_distance = cosine_distance(
        query,
        get_vector(vectors, dimension, entry_point as usize),
    );
    frontier.push(Reverse(Scored {
        id: entry_point,
        distance: entry_distance,
    }));
    kept.push(Scored {
        id: entry_point,
        distance: entry_distance,
    });
    seen.insert(entry_point);

    while let Some(Reverse(current)) = frontier.pop() {
        // Nothing left on the frontier can improve a full result set once
        // its closest node is already farther than the worst kept hit.
        if kept.len() >= ef && current.distance > worst_kept(&kept) {
            break;
        }

        let adjacent = layer.get_neighbors(current.id);
        for &neighbor_id in adjacent.iter() {
            // Each node is scored at most once.
            if !seen.insert(neighbor_id) {
                continue;
            }

            let distance = cosine_distance(
                query,
                get_vector(vectors, dimension, neighbor_id as usize),
            );

            // Keep the neighbor only while there is room or it beats the
            // current worst hit.
            let admit = kept.len() < ef || distance < worst_kept(&kept);
            if !admit {
                continue;
            }

            frontier.push(Reverse(Scored {
                id: neighbor_id,
                distance,
            }));
            kept.push(Scored {
                id: neighbor_id,
                distance,
            });

            // Evict the worst hit when the result set overflows.
            if kept.len() > ef {
                kept.pop();
            }
        }
    }

    // Heap iteration order is arbitrary; sort closest-first for the caller.
    let mut ranked: Vec<(u32, f32)> = kept
        .into_iter()
        .map(|hit| (hit.id, hit.distance))
        .collect();
    ranked.sort_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
    ranked
}

/// Borrow the `idx`-th vector from flat, contiguous storage.
///
/// Vector `idx` occupies `vectors[idx * dimension..(idx + 1) * dimension]`.
///
/// # Panics
///
/// Panics if that range does not lie within `vectors`. This includes the
/// case where the offset computation overflows `usize`, which would
/// otherwise wrap silently in builds without overflow checks and could
/// yield an unrelated slice.
fn get_vector(vectors: &[f32], dimension: usize, idx: usize) -> &[f32] {
    // Compute `start..end` with checked arithmetic so overflow is detected
    // in every build profile.
    let range = idx
        .checked_mul(dimension)
        .and_then(|start| start.checked_add(dimension).map(|end| start..end));

    match range.and_then(|span| vectors.get(span)) {
        Some(vector) => vector,
        None => panic!(
            "vector index {} out of bounds (dimension {}, storage length {})",
            idx,
            dimension,
            vectors.len()
        ),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `Candidate` must turn `BinaryHeap` into a min-heap on distance.
    #[test]
    fn test_candidate_ordering() {
        let mut heap = BinaryHeap::new();
        for &(id, distance) in &[(0u32, 0.5f32), (1, 0.1), (2, 0.3)] {
            heap.push(Candidate { id, distance });
        }

        // Drain the heap and record the distances in pop order.
        let mut popped = Vec::new();
        while let Some(candidate) = heap.pop() {
            popped.push(candidate.distance);
        }

        // Closest first: 0.1, 0.3, 0.5.
        assert_eq!(popped, vec![0.1, 0.3, 0.5]);
    }
}