Skip to main content

hermes_core/query/vector/
dense.rs

1//! Dense vector query for similarity search (ANN)
2
3use crate::dsl::Field;
4use crate::segment::{SegmentReader, VectorSearchResult};
5use crate::{DocId, Score, TERMINATED};
6
7use super::combiner::MultiValueCombiner;
8use crate::query::ScoredPosition;
9use crate::query::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
10
/// Dense vector query for similarity search
///
/// Bundles the target field, the query vector and the search tuning knobs
/// that are handed to the segment reader when a scorer is built. Create one
/// with [`DenseVectorQuery::new`] and adjust it with the `with_*` builders.
#[derive(Debug, Clone)]
pub struct DenseVectorQuery {
    /// Field containing the dense vectors
    pub field: Field,
    /// Query vector; its length is what `Display` prints as `dim`
    pub vector: Vec<f32>,
    /// Number of clusters to probe (for IVF indexes); `new` sets this to 32
    pub nprobe: usize,
    /// Re-ranking factor (multiplied by k for candidate selection, e.g. 3.0);
    /// `new` sets this to 3.0
    pub rerank_factor: f32,
    /// How to combine scores for multi-valued documents; `new` sets this to
    /// `MultiValueCombiner::Max`
    pub combiner: MultiValueCombiner,
}
25
26impl std::fmt::Display for DenseVectorQuery {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        write!(
29            f,
30            "Dense({}, dim={}, nprobe={}, rerank={})",
31            self.field.0,
32            self.vector.len(),
33            self.nprobe,
34            self.rerank_factor
35        )
36    }
37}
38
39impl DenseVectorQuery {
40    /// Create a new dense vector query
41    pub fn new(field: Field, vector: Vec<f32>) -> Self {
42        Self {
43            field,
44            vector,
45            nprobe: 32,
46            rerank_factor: 3.0,
47            combiner: MultiValueCombiner::Max,
48        }
49    }
50
51    /// Set the number of clusters to probe (for IVF indexes)
52    pub fn with_nprobe(mut self, nprobe: usize) -> Self {
53        self.nprobe = nprobe;
54        self
55    }
56
57    /// Set the re-ranking factor (e.g. 3.0 = fetch 3x candidates for reranking)
58    pub fn with_rerank_factor(mut self, factor: f32) -> Self {
59        self.rerank_factor = factor;
60        self
61    }
62
63    /// Set the multi-value score combiner
64    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
65        self.combiner = combiner;
66        self
67    }
68}
69
70impl Query for DenseVectorQuery {
71    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
72        let field = self.field;
73        let vector = self.vector.clone();
74        let nprobe = self.nprobe;
75        let rerank_factor = self.rerank_factor;
76        let combiner = self.combiner;
77        Box::pin(async move {
78            let results = reader
79                .search_dense_vector(field, &vector, limit, nprobe, rerank_factor, combiner)
80                .await?;
81
82            Ok(Box::new(DenseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
83        })
84    }
85
86    #[cfg(feature = "sync")]
87    fn scorer_sync<'a>(
88        &self,
89        reader: &'a SegmentReader,
90        limit: usize,
91    ) -> crate::Result<Box<dyn Scorer + 'a>> {
92        let results = reader.search_dense_vector_sync(
93            self.field,
94            &self.vector,
95            limit,
96            self.nprobe,
97            self.rerank_factor,
98            self.combiner,
99        )?;
100        Ok(Box::new(DenseVectorScorer::new(results, self.field.0)) as Box<dyn Scorer>)
101    }
102
103    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
104        Box::pin(async move { Ok(u32::MAX) })
105    }
106}
107
/// Scorer for dense vector search results with ordinal tracking
///
/// Acts as a cursor over an already-computed result list: `position` points
/// at the current hit and the `DocSet`/`Scorer` impls read from that slot.
struct DenseVectorScorer {
    /// Search hits, sorted by ascending `doc_id` in `DenseVectorScorer::new`
    results: Vec<VectorSearchResult>,
    /// Index of the current hit in `results`; at or beyond `results.len()`
    /// the scorer reports `TERMINATED`
    position: usize,
    /// Field id paired with the scored positions in `matched_positions`
    field_id: u32,
}
114
115impl DenseVectorScorer {
116    fn new(mut results: Vec<VectorSearchResult>, field_id: u32) -> Self {
117        // Sort by doc_id ascending — DocSet contract requires monotonic doc IDs
118        results.sort_unstable_by_key(|r| r.doc_id);
119        Self {
120            results,
121            position: 0,
122            field_id,
123        }
124    }
125}
126
127impl crate::query::docset::DocSet for DenseVectorScorer {
128    fn doc(&self) -> DocId {
129        if self.position < self.results.len() {
130            self.results[self.position].doc_id
131        } else {
132            TERMINATED
133        }
134    }
135
136    fn advance(&mut self) -> DocId {
137        self.position += 1;
138        self.doc()
139    }
140
141    fn seek(&mut self, target: DocId) -> DocId {
142        // Binary search within remaining results for O(log k) seek
143        let remaining = &self.results[self.position..];
144        let offset = remaining.partition_point(|r| r.doc_id < target);
145        self.position += offset;
146        self.doc()
147    }
148
149    fn size_hint(&self) -> u32 {
150        (self.results.len() - self.position) as u32
151    }
152}
153
154impl Scorer for DenseVectorScorer {
155    fn score(&self) -> Score {
156        if self.position < self.results.len() {
157            self.results[self.position].score
158        } else {
159            0.0
160        }
161    }
162
163    fn matched_positions(&self) -> Option<MatchedPositions> {
164        if self.position >= self.results.len() {
165            return None;
166        }
167        let result = &self.results[self.position];
168        let scored_positions: Vec<ScoredPosition> = result
169            .ordinals
170            .iter()
171            .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
172            .collect();
173        Some(vec![(self.field_id, scored_positions)])
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder overrides must land on the query; constructor inputs pass through.
    #[test]
    fn test_dense_vector_query_builder() {
        let built = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5.0);

        let DenseVectorQuery {
            field,
            vector,
            nprobe,
            rerank_factor,
            ..
        } = built;

        assert_eq!(field, Field(0));
        assert_eq!(vector.len(), 3);
        assert_eq!(nprobe, 64);
        assert_eq!(rerank_factor, 5.0);
    }
}