lucisearch 0.8.1

Embeddable, in-process search engine — the SQLite/DuckDB of search
Documentation
//! KNN query: approximate nearest neighbor search via the index-wide
//! global HNSW per [[global-vector-indices]] Alternative B.
//!
//! The global graph is consulted once at bind time, producing a single
//! `Vec<(SegmentId, DocId, distance)>`. Hits are bucketed by segment
//! and exposed to the per-segment scorer dispatch the rest of the
//! engine uses. Per-segment kNN fan-out and proportional-k oversampling
//! are gone — there is one graph, one search, one ranking.

use std::collections::HashMap;

use crate::core::{
    DocId, LuciError, NO_MORE_DOCS, Result, ScoreMode, Scorer, SegmentId, TwoPhaseIterator,
};

use crate::query::{BoundQuery, Query, ScorerSupplier};
use crate::search::searcher::Searcher;
use crate::segment::reader::SegmentReader;
use crate::vector::DistanceMetric;

/// Approximate k-nearest-neighbor search over a `dense_vector` field.
///
/// Binding performs exactly one search of the index-wide HNSW graph; see
/// [[feature-knn-query-type]].
pub struct KnnQuery {
    /// Name of the `dense_vector` field to search.
    pub field: String,
    /// Query embedding. Its length must equal the field's configured
    /// dimension count, otherwise binding fails.
    pub query_vector: Vec<f32>,
    /// Number of nearest neighbors requested.
    pub k: usize,
    /// Candidate-list size handed to the graph search.
    pub num_candidates: usize,
    /// Optional minimum score; hits scoring strictly below it are dropped.
    /// The score scale depends on the field's metric (see
    /// [`crate::vector::distance_to_score`]).
    pub threshold: Option<f32>,
}

impl Query for KnnQuery {
    /// Validates the query against the index mapping, runs the single global
    /// HNSW search, and returns the surviving hits bucketed by segment.
    ///
    /// # Errors
    ///
    /// Returns `LuciError::InvalidQuery` when the index has no mapping, the
    /// field is not mapped, the field is not a `dense_vector` field, or
    /// `query_vector`'s length differs from the field's dimensions. Errors
    /// from the global graph search are propagated unchanged.
    fn bind(&self, searcher: &Searcher, _score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
        // The honest empty result: a valid dense_vector field with nothing to
        // search yet. With no buckets no scorer is ever built, so the metric
        // below is a placeholder that is never read.
        fn no_hits() -> Box<dyn BoundQuery> {
            Box::new(BoundKnnQuery {
                results_by_segment: HashMap::new(),
                metric: DistanceMetric::Cosine,
            })
        }

        // Resolve the field against the mapping. A knn query names a
        // dense_vector field; if there is no mapping, the field is
        // unknown, or the field is not a dense_vector, the query cannot
        // run — return an explicit error rather than silently producing
        // zero hits ([[code-must-not-lie]]). A *valid* dense_vector field
        // with no vectors indexed yet is the only honest empty case, and
        // is handled by the no-graph branches below.
        let Some(mapping) = searcher.mapping() else {
            return Err(LuciError::InvalidQuery(format!(
                "knn query targets field '{}', but this index has no mapping",
                self.field
            )));
        };
        let Some(field_id) = mapping.field_id(&self.field) else {
            return Err(LuciError::InvalidQuery(format!(
                "knn query targets unknown field '{}'",
                self.field
            )));
        };
        let Some(expected_dims) = mapping.field(field_id).field_type.vector_dims() else {
            return Err(LuciError::InvalidQuery(format!(
                "knn query targets field '{}', which is not a dense_vector field",
                self.field
            )));
        };
        if self.query_vector.len() != expected_dims {
            return Err(LuciError::InvalidQuery(format!(
                "knn query_vector has {} dimensions, field '{}' expects {}",
                self.query_vector.len(),
                self.field,
                expected_dims
            )));
        }

        // Pull the index-wide global graph from the search-side snapshot.
        // No graph attached, or no vectors for this field yet, yields an
        // empty bound query; the rest of the search pipeline handles
        // zero-hit branches uniformly.
        let Some(global) = searcher.global_hnsw() else {
            return Ok(no_hits());
        };
        let Some((hits, metric)) =
            global.search(field_id, &self.query_vector, self.k, self.num_candidates)?
        else {
            return Ok(no_hits());
        };

        // Apply the threshold to the metric-specific score rather than the
        // raw distance, so the cutoff is on the same scale as the score the
        // scorer later reports.
        let mut ranked: Vec<_> = hits
            .into_iter()
            .filter(|hit| match self.threshold {
                Some(min_score) => {
                    crate::vector::distance_to_score(hit.distance, metric) >= min_score
                }
                None => true,
            })
            .collect();

        // Rank closest-first and cap at `k`, so the k contract holds even if
        // the graph hands back more candidates than requested. `total_cmp` is
        // a total order (a positive NaN sorts last), which `sort_by` requires:
        // mapping NaN comparisons to `Equal` is not transitive and can make
        // the standard sort panic.
        ranked.sort_by(|a, b| a.distance.total_cmp(&b.distance));
        ranked.truncate(self.k);

        // Bucket by segment as `(local_doc_id, distance)` pairs, then order
        // each bucket by ascending doc id so the per-segment scorer can
        // iterate in advance() order inside boolean conjunctions.
        let mut results_by_segment: HashMap<SegmentId, Vec<(u32, f32)>> = HashMap::new();
        for hit in ranked {
            results_by_segment
                .entry(hit.segment_id)
                .or_default()
                .push((hit.doc_id.as_u32(), hit.distance));
        }
        for bucket in results_by_segment.values_mut() {
            bucket.sort_by_key(|&(doc_id, _)| doc_id);
        }

        Ok(Box::new(BoundKnnQuery {
            results_by_segment,
            metric,
        }))
    }
}

/// A bound kNN query: the global graph's hits, already grouped by the
/// segment that owns each document.
struct BoundKnnQuery {
    /// Hits keyed by owning segment. Each bucket holds
    /// `(segment-local doc id, distance)` pairs in ascending doc-id order,
    /// the order a scorer must emit so `advance()` works against boolean
    /// conjunctions.
    results_by_segment: HashMap<SegmentId, Vec<(u32, f32)>>,
    /// Metric handed to each segment's scorer for distance-to-score mapping.
    metric: DistanceMetric,
}

impl BoundQuery for BoundKnnQuery {
    /// Returns a supplier over this segment's hits, or `None` when the global
    /// search produced no hits (or an empty bucket) for `reader`'s segment.
    fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
        let segment_hits = match self.results_by_segment.get(&reader.segment_id()) {
            Some(hits) if !hits.is_empty() => hits,
            _ => return Ok(None),
        };
        let supplier = KnnScorerSupplier {
            results: segment_hits.clone(),
            metric: self.metric,
        };
        Ok(Some(Box::new(supplier)))
    }
}

/// Deferred constructor for a `KnnScorer` over one segment's hits.
struct KnnScorerSupplier {
    /// `(doc id, distance)` pairs, ascending by doc id.
    results: Vec<(u32, f32)>,
    /// Metric forwarded to the scorer for distance-to-score conversion.
    metric: DistanceMetric,
}

impl ScorerSupplier for KnnScorerSupplier {
    fn cost(&self) -> u64 {
        self.results.len() as u64
    }

    fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
        Ok(Box::new(KnnScorer {
            results: self.results,
            metric: self.metric,
            pos: 0,
        }))
    }
}

/// Replays one segment's precomputed kNN hits as a scorer.
///
/// Hits are held in ascending doc-id order (re-sorted from the graph's
/// distance order), and each score is derived from the stored distance on
/// demand via the metric-specific mapping.
struct KnnScorer {
    /// `(doc id, distance)` pairs, ascending by doc id.
    results: Vec<(u32, f32)>,
    /// Metric used to map a distance to a score.
    metric: DistanceMetric,
    /// Index of the current hit; equals `results.len()` once exhausted.
    pos: usize,
}

impl Scorer for KnnScorer {
    /// Current doc id, or `NO_MORE_DOCS` once every hit has been consumed.
    fn doc_id(&self) -> DocId {
        match self.results.get(self.pos) {
            Some(&(doc, _)) => DocId::new(doc),
            None => NO_MORE_DOCS,
        }
    }

    /// Steps to the next hit; once exhausted, the cursor stays exhausted.
    fn next(&mut self) -> DocId {
        self.pos = (self.pos + 1).min(self.results.len());
        self.doc_id()
    }

    /// Moves forward to the first hit whose doc id is `>= target`, never
    /// moving backwards from the current position.
    fn advance(&mut self, target: DocId) -> DocId {
        let goal = target.as_u32();
        let skipped = self
            .results
            .iter()
            .skip(self.pos)
            .take_while(|&&(doc, _)| doc < goal)
            .count();
        self.pos += skipped;
        self.doc_id()
    }

    /// Score of the current hit, or `0.0` when exhausted.
    fn score(&mut self) -> f32 {
        let metric = self.metric;
        self.results.get(self.pos).map_or(0.0, |&(_, distance)| {
            crate::vector::distance_to_score(distance, metric)
        })
    }

    /// Hits are fully materialized, so there is no approximation phase.
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }
}