1use crate::dsl::Field;
4use crate::segment::SegmentReader;
5use crate::{DocId, Score, TERMINATED};
6
7use super::traits::{CountFuture, Query, Scorer, ScorerFuture};
8
9#[derive(Debug, Clone)]
11pub struct DenseVectorQuery {
12 pub field: Field,
14 pub vector: Vec<f32>,
16 pub k: usize,
18 pub nprobe: usize,
20 pub rerank_factor: usize,
22}
23
24impl DenseVectorQuery {
25 pub fn new(field: Field, vector: Vec<f32>, k: usize) -> Self {
27 Self {
28 field,
29 vector,
30 k,
31 nprobe: 32,
32 rerank_factor: 3,
33 }
34 }
35
36 pub fn with_nprobe(mut self, nprobe: usize) -> Self {
38 self.nprobe = nprobe;
39 self
40 }
41
42 pub fn with_rerank_factor(mut self, factor: usize) -> Self {
44 self.rerank_factor = factor;
45 self
46 }
47}
48
49impl Query for DenseVectorQuery {
50 fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> ScorerFuture<'a> {
51 Box::pin(async move {
52 let results =
53 reader.search_dense_vector(self.field, &self.vector, self.k, self.rerank_factor)?;
54
55 Ok(Box::new(DenseVectorScorer::new(results)) as Box<dyn Scorer>)
56 })
57 }
58
59 fn count_estimate<'a>(&'a self, _reader: &'a SegmentReader) -> CountFuture<'a> {
60 let k = self.k as u32;
61 Box::pin(async move { Ok(k) })
62 }
63}
64
/// Cursor over dense-search hits.
struct DenseVectorScorer {
    /// Hits as `(doc id, distance)` pairs, in the order they were handed to `new`.
    results: Vec<(u32, f32)>,
    /// Index into `results` of the hit currently under the cursor.
    position: usize,
}

impl DenseVectorScorer {
    /// Wraps `results` with the cursor placed on the first entry.
    fn new(results: Vec<(u32, f32)>) -> Self {
        let position = 0;
        Self { results, position }
    }
}
79
80impl Scorer for DenseVectorScorer {
81 fn doc(&self) -> DocId {
82 if self.position < self.results.len() {
83 self.results[self.position].0
84 } else {
85 TERMINATED
86 }
87 }
88
89 fn score(&self) -> Score {
90 if self.position < self.results.len() {
91 let distance = self.results[self.position].1;
93 1.0 / (1.0 + distance)
94 } else {
95 0.0
96 }
97 }
98
99 fn advance(&mut self) -> DocId {
100 self.position += 1;
101 self.doc()
102 }
103
104 fn seek(&mut self, target: DocId) -> DocId {
105 while self.doc() < target && self.doc() != TERMINATED {
106 self.advance();
107 }
108 self.doc()
109 }
110
111 fn size_hint(&self) -> u32 {
112 (self.results.len() - self.position) as u32
113 }
114}
115
116#[derive(Debug, Clone)]
118pub struct SparseVectorQuery {
119 pub field: Field,
121 pub indices: Vec<u32>,
123 pub weights: Vec<f32>,
124 pub k: usize,
126}
127
128impl SparseVectorQuery {
129 pub fn new(field: Field, indices: Vec<u32>, weights: Vec<f32>, k: usize) -> Self {
131 Self {
132 field,
133 indices,
134 weights,
135 k,
136 }
137 }
138
139 pub fn from_map(field: Field, sparse_vec: &[(u32, f32)], k: usize) -> Self {
141 let (indices, weights): (Vec<u32>, Vec<f32>) = sparse_vec.iter().copied().unzip();
142 Self::new(field, indices, weights, k)
143 }
144}
145
146impl Query for SparseVectorQuery {
147 fn scorer<'a>(&'a self, reader: &'a SegmentReader) -> ScorerFuture<'a> {
148 Box::pin(async move {
149 let results = reader
150 .search_sparse_vector(self.field, &self.indices, &self.weights, self.k)
151 .await?;
152
153 Ok(Box::new(SparseVectorScorer::new(results)) as Box<dyn Scorer>)
154 })
155 }
156
157 fn count_estimate<'a>(&'a self, _reader: &'a SegmentReader) -> CountFuture<'a> {
158 let k = self.k as u32;
159 Box::pin(async move { Ok(k) })
160 }
161}
162
/// Cursor over sparse-search hits.
struct SparseVectorScorer {
    /// Hits as `(doc id, score)` pairs, in the order they were handed to `new`.
    results: Vec<(u32, f32)>,
    /// Index into `results` of the hit currently under the cursor.
    position: usize,
}

impl SparseVectorScorer {
    /// Wraps `results` with the cursor placed on the first entry.
    fn new(results: Vec<(u32, f32)>) -> Self {
        let position = 0;
        Self { results, position }
    }
}
177
178impl Scorer for SparseVectorScorer {
179 fn doc(&self) -> DocId {
180 if self.position < self.results.len() {
181 self.results[self.position].0
182 } else {
183 TERMINATED
184 }
185 }
186
187 fn score(&self) -> Score {
188 if self.position < self.results.len() {
189 self.results[self.position].1
190 } else {
191 0.0
192 }
193 }
194
195 fn advance(&mut self) -> DocId {
196 self.position += 1;
197 self.doc()
198 }
199
200 fn seek(&mut self, target: DocId) -> DocId {
201 while self.doc() < target && self.doc() != TERMINATED {
202 self.advance();
203 }
204 self.doc()
205 }
206
207 fn size_hint(&self) -> u32 {
208 (self.results.len() - self.position) as u32
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    /// The `with_*` builders override the defaults and leave the other fields intact.
    #[test]
    fn test_dense_vector_query_builder() {
        let base = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0], 10);
        let tuned = base.with_nprobe(64).with_rerank_factor(5);

        assert_eq!(tuned.field, Field(0));
        assert_eq!(tuned.vector.len(), 3);
        assert_eq!(tuned.k, 10);
        assert_eq!(tuned.nprobe, 64);
        assert_eq!(tuned.rerank_factor, 5);
    }

    /// `from_map` splits `(index, weight)` pairs into parallel vectors, keeping order.
    #[test]
    fn test_sparse_vector_query_from_map() {
        let pairs = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let built = SparseVectorQuery::from_map(Field(0), &pairs, 10);

        let expected_indices = vec![1, 5, 10];
        let expected_weights = vec![0.5, 0.3, 0.2];
        assert_eq!(built.indices, expected_indices);
        assert_eq!(built.weights, expected_weights);
        assert_eq!(built.k, 10);
    }
}