hermes_core/query/
vector.rs

1//! Vector query types for dense and sparse vector search
2
3use crate::dsl::Field;
4use crate::segment::SegmentReader;
5use crate::{DocId, Score, TERMINATED};
6
7use super::traits::{CountFuture, Query, Scorer, ScorerFuture};
8
9/// Dense vector query for similarity search
10#[derive(Debug, Clone)]
11pub struct DenseVectorQuery {
12    /// Field containing the dense vectors
13    pub field: Field,
14    /// Query vector
15    pub vector: Vec<f32>,
16    /// Number of results to return
17    pub k: usize,
18    /// Number of clusters to probe (for IVF indexes)
19    pub nprobe: usize,
20    /// Re-ranking factor (multiplied by k for candidate selection)
21    pub rerank_factor: usize,
22}
23
24impl DenseVectorQuery {
25    /// Create a new dense vector query
26    pub fn new(field: Field, vector: Vec<f32>, k: usize) -> Self {
27        Self {
28            field,
29            vector,
30            k,
31            nprobe: 32,
32            rerank_factor: 3,
33        }
34    }
35
36    /// Set the number of clusters to probe (for IVF indexes)
37    pub fn with_nprobe(mut self, nprobe: usize) -> Self {
38        self.nprobe = nprobe;
39        self
40    }
41
42    /// Set the re-ranking factor
43    pub fn with_rerank_factor(mut self, factor: usize) -> Self {
44        self.rerank_factor = factor;
45        self
46    }
47}
48
49impl Query for DenseVectorQuery {
50    fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> ScorerFuture<'a> {
51        Box::pin(async move {
52            let results =
53                reader.search_dense_vector(self.field, &self.vector, self.k, self.rerank_factor)?;
54
55            Ok(Box::new(DenseVectorScorer::new(results)) as Box<dyn Scorer>)
56        })
57    }
58
59    fn count_estimate<'a>(&'a self, _reader: &'a SegmentReader) -> CountFuture<'a> {
60        let k = self.k as u32;
61        Box::pin(async move { Ok(k) })
62    }
63}
64
65/// Scorer for dense vector search results
66struct DenseVectorScorer {
67    results: Vec<(u32, f32)>,
68    position: usize,
69}
70
71impl DenseVectorScorer {
72    fn new(results: Vec<(u32, f32)>) -> Self {
73        Self {
74            results,
75            position: 0,
76        }
77    }
78}
79
80impl Scorer for DenseVectorScorer {
81    fn doc(&self) -> DocId {
82        if self.position < self.results.len() {
83            self.results[self.position].0
84        } else {
85            TERMINATED
86        }
87    }
88
89    fn score(&self) -> Score {
90        if self.position < self.results.len() {
91            // Convert distance to score (smaller distance = higher score)
92            let distance = self.results[self.position].1;
93            1.0 / (1.0 + distance)
94        } else {
95            0.0
96        }
97    }
98
99    fn advance(&mut self) -> DocId {
100        self.position += 1;
101        self.doc()
102    }
103
104    fn seek(&mut self, target: DocId) -> DocId {
105        while self.doc() < target && self.doc() != TERMINATED {
106            self.advance();
107        }
108        self.doc()
109    }
110
111    fn size_hint(&self) -> u32 {
112        (self.results.len() - self.position) as u32
113    }
114}
115
116/// Sparse vector query for similarity search
117#[derive(Debug, Clone)]
118pub struct SparseVectorQuery {
119    /// Field containing the sparse vectors
120    pub field: Field,
121    /// Query vector as (index, weight) pairs
122    pub indices: Vec<u32>,
123    pub weights: Vec<f32>,
124    /// Number of results to return
125    pub k: usize,
126}
127
128impl SparseVectorQuery {
129    /// Create a new sparse vector query
130    pub fn new(field: Field, indices: Vec<u32>, weights: Vec<f32>, k: usize) -> Self {
131        Self {
132            field,
133            indices,
134            weights,
135            k,
136        }
137    }
138
139    /// Create from a sparse vector map
140    pub fn from_map(field: Field, sparse_vec: &[(u32, f32)], k: usize) -> Self {
141        let (indices, weights): (Vec<u32>, Vec<f32>) = sparse_vec.iter().copied().unzip();
142        Self::new(field, indices, weights, k)
143    }
144}
145
146impl Query for SparseVectorQuery {
147    fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> ScorerFuture<'a> {
148        Box::pin(async move {
149            let results = reader
150                .search_sparse_vector(self.field, &self.indices, &self.weights, self.k)
151                .await?;
152
153            Ok(Box::new(SparseVectorScorer::new(results)) as Box<dyn Scorer>)
154        })
155    }
156
157    fn count_estimate<'a>(&'a self, _reader: &'a SegmentReader) -> CountFuture<'a> {
158        let k = self.k as u32;
159        Box::pin(async move { Ok(k) })
160    }
161}
162
163/// Scorer for sparse vector search results
164struct SparseVectorScorer {
165    results: Vec<(u32, f32)>,
166    position: usize,
167}
168
169impl SparseVectorScorer {
170    fn new(results: Vec<(u32, f32)>) -> Self {
171        Self {
172            results,
173            position: 0,
174        }
175    }
176}
177
178impl Scorer for SparseVectorScorer {
179    fn doc(&self) -> DocId {
180        if self.position < self.results.len() {
181            self.results[self.position].0
182        } else {
183            TERMINATED
184        }
185    }
186
187    fn score(&self) -> Score {
188        if self.position < self.results.len() {
189            self.results[self.position].1
190        } else {
191            0.0
192        }
193    }
194
195    fn advance(&mut self) -> DocId {
196        self.position += 1;
197        self.doc()
198    }
199
200    fn seek(&mut self, target: DocId) -> DocId {
201        while self.doc() < target && self.doc() != TERMINATED {
202            self.advance();
203        }
204        self.doc()
205    }
206
207    fn size_hint(&self) -> u32 {
208        (self.results.len() - self.position) as u32
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    #[test]
    fn test_dense_vector_query_builder() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0], 10)
            .with_nprobe(64)
            .with_rerank_factor(5);

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);
        assert_eq!(query.k, 10);
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5);
    }

    #[test]
    fn test_dense_vector_query_defaults() {
        let query = DenseVectorQuery::new(Field(0), vec![0.5], 4);

        assert_eq!(query.nprobe, 32);
        assert_eq!(query.rerank_factor, 3);
    }

    #[test]
    fn test_sparse_vector_query_from_map() {
        let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::from_map(Field(0), &sparse, 10);

        assert_eq!(query.indices, vec![1, 5, 10]);
        assert_eq!(query.weights, vec![0.5, 0.3, 0.2]);
        assert_eq!(query.k, 10);
    }

    #[test]
    fn test_sparse_vector_query_from_empty_map() {
        let query = SparseVectorQuery::from_map(Field(0), &[], 5);

        assert!(query.indices.is_empty());
        assert!(query.weights.is_empty());
        assert_eq!(query.k, 5);
    }

    #[test]
    fn test_dense_vector_scorer_iteration() {
        // Distances 1.0, 3.0, 0.0 map to scores 0.5, 0.25, 1.0 via 1 / (1 + d).
        let mut scorer = DenseVectorScorer::new(vec![(2, 1.0), (5, 3.0), (7, 0.0)]);

        assert_eq!(scorer.size_hint(), 3);
        assert_eq!(scorer.doc(), 2);
        assert_eq!(scorer.score(), 0.5);
        assert_eq!(scorer.seek(6), 7);
        assert_eq!(scorer.score(), 1.0);
        assert_eq!(scorer.advance(), TERMINATED);
        assert_eq!(scorer.score(), 0.0);
        assert_eq!(scorer.size_hint(), 0);
    }

    #[test]
    fn test_sparse_vector_scorer_iteration() {
        let mut scorer = SparseVectorScorer::new(vec![(4, 0.7), (6, 0.3), (9, 0.1)]);

        assert_eq!(scorer.doc(), 4);
        assert_eq!(scorer.score(), 0.7);
        assert_eq!(scorer.seek(6), 6);
        assert_eq!(scorer.score(), 0.3);
        assert_eq!(scorer.advance(), 9);
        assert_eq!(scorer.advance(), TERMINATED);
        assert_eq!(scorer.score(), 0.0);
    }
}