use std::collections::HashMap;
use crate::core::{
DocId, LuciError, NO_MORE_DOCS, Result, ScoreMode, Scorer, SegmentId, TwoPhaseIterator,
};
use crate::query::{BoundQuery, Query, ScorerSupplier};
use crate::search::searcher::Searcher;
use crate::segment::reader::SegmentReader;
use crate::vector::DistanceMetric;
/// A k-nearest-neighbour query over a vector field.
///
/// Binding the query validates `field` and `query_vector` against the index
/// mapping, forwards the search to the searcher's global HNSW index, and
/// groups the surviving hits by segment.
pub struct KnnQuery {
/// Name of the field to search. Binding fails with an invalid-query error
/// when the index has no mapping, the name is unknown, or the field's type
/// reports no vector dimensions.
pub field: String,
/// Query vector. Its length must equal the dimension count the mapping
/// reports for `field`, otherwise binding fails.
pub query_vector: Vec<f32>,
/// Forwarded unchanged as the `k` argument of the global HNSW search
/// (presumably the number of neighbours to return; confirm against the
/// HNSW search API).
pub k: usize,
/// Forwarded unchanged to the global HNSW search after `k` (presumably the
/// size of the candidate set explored; confirm against the HNSW search API).
pub num_candidates: usize,
/// Optional minimum score. When set, a hit is kept only if its distance,
/// converted with `crate::vector::distance_to_score`, is `>=` this value.
pub threshold: Option<f32>,
}
impl Query for KnnQuery {
    /// Binds the query: validates it against the index mapping, runs the
    /// global HNSW search once, and groups the surviving hits per segment.
    ///
    /// The returned bound query holds, for every segment that produced hits,
    /// a list of `(doc id, distance)` pairs sorted by ascending doc id.
    ///
    /// # Errors
    ///
    /// Returns `LuciError::InvalidQuery` when
    /// * the index has no mapping,
    /// * `field` is not present in the mapping,
    /// * the field's type reports no vector dimensions, or
    /// * `query_vector.len()` differs from the field's dimension count.
    ///
    /// Errors raised by the HNSW search itself are propagated unchanged.
    fn bind(&self, searcher: &Searcher, _score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
        /// Bound query with no hits. The metric is only a placeholder: with
        /// no buckets there is no distance to convert.
        fn empty_result() -> Box<dyn BoundQuery> {
            Box::new(BoundKnnQuery {
                results_by_segment: HashMap::new(),
                metric: DistanceMetric::Cosine,
            })
        }

        let Some(mapping) = searcher.mapping() else {
            return Err(LuciError::InvalidQuery(format!(
                "knn query targets field '{}', but this index has no mapping",
                self.field
            )));
        };
        let Some(field_id) = mapping.field_id(&self.field) else {
            return Err(LuciError::InvalidQuery(format!(
                "knn query targets unknown field '{}'",
                self.field
            )));
        };
        let Some(expected_dims) = mapping.field(field_id).field_type.vector_dims() else {
            return Err(LuciError::InvalidQuery(format!(
                "knn query targets field '{}', which is not a dense_vector field",
                self.field
            )));
        };
        if self.query_vector.len() != expected_dims {
            return Err(LuciError::InvalidQuery(format!(
                "knn query_vector has {} dimensions, field '{}' expects {}",
                self.query_vector.len(),
                self.field,
                expected_dims
            )));
        }

        // No global HNSW index, or no result for this field: nothing matches.
        let Some(global) = searcher.global_hnsw() else {
            return Ok(empty_result());
        };
        let Some((hits, metric)) =
            global.search(field_id, &self.query_vector, self.k, self.num_candidates)?
        else {
            return Ok(empty_result());
        };

        // A hit survives when no threshold is set or its score reaches it.
        // The comparison is kept in its positive form so a NaN score is
        // rejected whenever a threshold is present.
        let passes_threshold = |distance: f32| match self.threshold {
            Some(min_score) => crate::vector::distance_to_score(distance, metric) >= min_score,
            None => true,
        };

        // Bucket the surviving hits by segment in a single pass.
        let mut results_by_segment: HashMap<SegmentId, Vec<(u32, f32)>> = HashMap::new();
        for hit in hits
            .into_iter()
            .filter(|hit| passes_threshold(hit.distance))
        {
            results_by_segment
                .entry(hit.segment_id)
                .or_default()
                .push((hit.doc_id.as_u32(), hit.distance));
        }

        // Scorers walk each bucket in doc-id order. Distance is only a
        // tie-breaker for equal doc ids, so one sort per bucket replaces the
        // former "sort everything by distance, then re-sort by doc id".
        // `total_cmp` is a total order even with NaN distances, which
        // `partial_cmp(..).unwrap_or(Equal)` is not; `sort_by` may panic when
        // handed a comparator that is not a total order.
        for bucket in results_by_segment.values_mut() {
            bucket.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.total_cmp(&b.1)));
        }

        Ok(Box::new(BoundKnnQuery {
            results_by_segment,
            metric,
        }))
    }
}
/// A `KnnQuery` after binding: the search hits, already grouped by segment.
struct BoundKnnQuery {
/// Hits per segment as `(doc id, distance)` pairs. `KnnQuery::bind` sorts
/// each list by ascending doc id; segments without hits have no entry.
results_by_segment: HashMap<SegmentId, Vec<(u32, f32)>>,
/// Metric handed to scorers so they can convert distances into scores.
metric: DistanceMetric,
}
impl BoundQuery for BoundKnnQuery {
    /// Builds a scorer supplier for the segment behind `reader`.
    ///
    /// Yields `Ok(None)` when the segment has no bucket or its bucket is
    /// empty. Otherwise the supplier receives its own copy of the segment's
    /// `(doc id, distance)` list together with the distance metric.
    fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
        let segment_hits = self
            .results_by_segment
            .get(&reader.segment_id())
            .filter(|hits| !hits.is_empty());

        match segment_hits {
            None => Ok(None),
            Some(hits) => {
                let supplier: Box<dyn ScorerSupplier> = Box::new(KnnScorerSupplier {
                    results: hits.clone(),
                    metric: self.metric,
                });
                Ok(Some(supplier))
            }
        }
    }
}
/// Holds one segment's hits until a scorer is requested.
struct KnnScorerSupplier {
/// `(doc id, distance)` pairs for the segment; moved into the scorer when
/// `scorer` is called.
results: Vec<(u32, f32)>,
/// Metric passed on to the scorer for distance-to-score conversion.
metric: DistanceMetric,
}
impl ScorerSupplier for KnnScorerSupplier {
    /// Cost of this supplier: the number of hits it holds.
    fn cost(&self) -> u64 {
        let hit_count = self.results.len();
        hit_count as u64
    }

    /// Consumes the supplier and moves its hits into a scorer whose cursor
    /// starts at index 0.
    fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
        let supplier = *self;
        let scorer: Box<dyn Scorer> = Box::new(KnnScorer {
            results: supplier.results,
            metric: supplier.metric,
            pos: 0,
        });
        Ok(scorer)
    }
}
/// Cursor over one segment's kNN hits.
struct KnnScorer {
    /// `(doc id, distance)` pairs that the cursor walks front to back.
    results: Vec<(u32, f32)>,
    /// Metric used to turn a stored distance into a score.
    metric: DistanceMetric,
    /// Index of the current entry in `results`; equal to `results.len()`
    /// once the cursor is exhausted.
    pos: usize,
}
impl Scorer for KnnScorer {
    /// Doc id of the entry under the cursor, or `NO_MORE_DOCS` once the
    /// cursor has moved past the last entry.
    fn doc_id(&self) -> DocId {
        match self.results.get(self.pos) {
            Some(entry) => DocId::new(entry.0),
            None => NO_MORE_DOCS,
        }
    }

    /// Steps the cursor one entry forward, unless it is already exhausted,
    /// and returns the doc id at the new position.
    fn next(&mut self) -> DocId {
        let exhausted = self.pos >= self.results.len();
        if !exhausted {
            self.pos += 1;
        }
        self.doc_id()
    }

    /// Moves the cursor forward past every entry whose doc id is below
    /// `target` and returns the doc id it lands on. The cursor never moves
    /// backwards; the scan is linear from the current position.
    fn advance(&mut self, target: DocId) -> DocId {
        if let Some(remaining) = self.results.get(self.pos..) {
            let wanted = target.as_u32();
            let skipped = remaining
                .iter()
                .take_while(|entry| entry.0 < wanted)
                .count();
            self.pos += skipped;
        }
        self.doc_id()
    }

    /// Score of the current entry: its distance converted with
    /// `crate::vector::distance_to_score`. Returns `0.0` once exhausted.
    fn score(&mut self) -> f32 {
        match self.results.get(self.pos) {
            Some(entry) => crate::vector::distance_to_score(entry.1, self.metric),
            None => 0.0,
        }
    }

    /// This scorer offers no two-phase iteration.
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }
}