hermes_core/query/vector/
dense.rs1use crate::dsl::Field;
4use crate::segment::{SegmentReader, VectorSearchResult};
5use crate::{DocId, Score, TERMINATED};
6
7use super::combiner::MultiValueCombiner;
8use crate::query::ScoredPosition;
9use crate::query::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
10
11#[derive(Debug, Clone)]
13pub struct DenseVectorQuery {
14 pub field: Field,
16 pub vector: Vec<f32>,
18 pub nprobe: usize,
20 pub rerank_factor: f32,
22 pub combiner: MultiValueCombiner,
24}
25
26impl std::fmt::Display for DenseVectorQuery {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 write!(
29 f,
30 "Dense({}, dim={}, nprobe={}, rerank={})",
31 self.field.0,
32 self.vector.len(),
33 self.nprobe,
34 self.rerank_factor
35 )
36 }
37}
38
39impl DenseVectorQuery {
40 pub fn new(field: Field, vector: Vec<f32>) -> Self {
42 Self {
43 field,
44 vector,
45 nprobe: 32,
46 rerank_factor: 3.0,
47 combiner: MultiValueCombiner::Max,
48 }
49 }
50
51 pub fn with_nprobe(mut self, nprobe: usize) -> Self {
53 self.nprobe = nprobe;
54 self
55 }
56
57 pub fn with_rerank_factor(mut self, factor: f32) -> Self {
59 self.rerank_factor = factor;
60 self
61 }
62
63 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
65 self.combiner = combiner;
66 self
67 }
68}
69
70impl Query for DenseVectorQuery {
71 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
72 let field = self.field;
73 let vector = self.vector.clone();
74 let nprobe = self.nprobe;
75 let rerank_factor = self.rerank_factor;
76 let combiner = self.combiner;
77 Box::pin(async move {
78 let results = reader
79 .search_dense_vector(field, &vector, limit, nprobe, rerank_factor, combiner)
80 .await?;
81
82 Ok(Box::new(DenseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
83 })
84 }
85
86 #[cfg(feature = "sync")]
87 fn scorer_sync<'a>(
88 &self,
89 reader: &'a SegmentReader,
90 limit: usize,
91 ) -> crate::Result<Box<dyn Scorer + 'a>> {
92 let results = reader.search_dense_vector_sync(
93 self.field,
94 &self.vector,
95 limit,
96 self.nprobe,
97 self.rerank_factor,
98 self.combiner,
99 )?;
100 Ok(Box::new(DenseVectorScorer::new(results, self.field.0)) as Box<dyn Scorer>)
101 }
102
103 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
104 Box::pin(async move { Ok(u32::MAX) })
105 }
106}
107
108struct DenseVectorScorer {
110 results: Vec<VectorSearchResult>,
111 position: usize,
112 field_id: u32,
113}
114
115impl DenseVectorScorer {
116 fn new(mut results: Vec<VectorSearchResult>, field_id: u32) -> Self {
117 results.sort_unstable_by_key(|r| r.doc_id);
119 Self {
120 results,
121 position: 0,
122 field_id,
123 }
124 }
125}
126
127impl crate::query::docset::DocSet for DenseVectorScorer {
128 fn doc(&self) -> DocId {
129 if self.position < self.results.len() {
130 self.results[self.position].doc_id
131 } else {
132 TERMINATED
133 }
134 }
135
136 fn advance(&mut self) -> DocId {
137 self.position += 1;
138 self.doc()
139 }
140
141 fn seek(&mut self, target: DocId) -> DocId {
142 let remaining = &self.results[self.position..];
144 let offset = remaining.partition_point(|r| r.doc_id < target);
145 self.position += offset;
146 self.doc()
147 }
148
149 fn size_hint(&self) -> u32 {
150 (self.results.len() - self.position) as u32
151 }
152}
153
154impl Scorer for DenseVectorScorer {
155 fn score(&self) -> Score {
156 if self.position < self.results.len() {
157 self.results[self.position].score
158 } else {
159 0.0
160 }
161 }
162
163 fn matched_positions(&self) -> Option<MatchedPositions> {
164 if self.position >= self.results.len() {
165 return None;
166 }
167 let result = &self.results[self.position];
168 let scored_positions: Vec<ScoredPosition> = result
169 .ordinals
170 .iter()
171 .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
172 .collect();
173 Some(vec![(self.field_id, scored_positions)])
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods override their own setting and leave the field and
    /// vector passed to `new` untouched.
    #[test]
    fn test_dense_vector_query_builder() {
        let base = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0]);
        let query = base.with_nprobe(64).with_rerank_factor(5.0);

        // Values given to the constructor.
        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);

        // Values overridden through the builder.
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5.0);
    }
}