// reflex/vectordb/rescoring.rs
use std::borrow::Cow;
use std::cmp::Ordering;

use half::f16;
use thiserror::Error;
use tracing::warn;

use crate::storage::CacheEntry;

10pub const DEFAULT_TOP_K: usize = 5;
12
13pub const DEFAULT_EMBEDDING_DIM: usize = crate::constants::DEFAULT_EMBEDDING_DIM;
15
16pub const EMBEDDING_BYTES: usize = crate::constants::EMBEDDING_F16_BYTES;
18
19#[derive(Debug, Error)]
20pub enum RescoringError {
22 #[error("invalid query dimension: expected {expected}, got {actual}")]
24 InvalidQueryDimension {
25 expected: usize,
27 actual: usize,
29 },
30
31 #[error("invalid embedding size for candidate {id}: expected {expected} bytes, got {actual}")]
33 InvalidEmbeddingSize {
34 id: u64,
36 expected: usize,
38 actual: usize,
40 },
41
42 #[error("no candidates provided for rescoring")]
43 NoCandidates,
45}
46
47pub type RescoringResult<T> = Result<T, RescoringError>;
49
50#[derive(Debug, Clone)]
51pub struct CandidateEntry {
53 pub id: u64,
55 pub entry: CacheEntry,
57 pub bq_score: Option<f32>,
59}
60
61impl CandidateEntry {
62 pub fn new(id: u64, entry: CacheEntry) -> Self {
64 Self {
65 id,
66 entry,
67 bq_score: None,
68 }
69 }
70
71 pub fn with_bq_score(id: u64, entry: CacheEntry, bq_score: f32) -> Self {
73 Self {
74 id,
75 entry,
76 bq_score: Some(bq_score),
77 }
78 }
79
80 pub fn embedding_as_f16(&self) -> Option<&[f16]> {
82 bytes_to_f16_slice(&self.entry.embedding)
83 }
84}
85
86#[derive(Debug, Clone)]
87pub struct ScoredCandidate {
89 pub id: u64,
91 pub entry: CacheEntry,
93 pub score: f32,
95 pub bq_score: Option<f32>,
97}
98
99impl ScoredCandidate {
100 pub fn score_delta(&self) -> Option<f32> {
102 self.bq_score.map(|bq| self.score - bq)
103 }
104}
105
106#[derive(Debug, Clone)]
107pub struct RescorerConfig {
109 pub top_k: usize,
111 pub validate_dimensions: bool,
113}
114
115impl Default for RescorerConfig {
116 fn default() -> Self {
117 Self {
118 top_k: DEFAULT_TOP_K,
119 validate_dimensions: true,
120 }
121 }
122}
123
124impl RescorerConfig {
125 pub fn with_top_k(top_k: usize) -> Self {
127 Self {
128 top_k,
129 ..Default::default()
130 }
131 }
132}
133
134#[derive(Debug, Clone)]
135pub struct VectorRescorer {
137 config: RescorerConfig,
138}
139
140impl VectorRescorer {
141 pub fn new() -> Self {
143 Self {
144 config: RescorerConfig::default(),
145 }
146 }
147
148 pub fn with_top_k(top_k: usize) -> Self {
150 Self {
151 config: RescorerConfig::with_top_k(top_k),
152 }
153 }
154
155 pub fn with_config(config: RescorerConfig) -> Self {
157 Self { config }
158 }
159
160 pub fn config(&self) -> &RescorerConfig {
162 &self.config
163 }
164
165 pub fn rescore(
167 &self,
168 query: &[f16],
169 candidates: Vec<CandidateEntry>,
170 ) -> RescoringResult<Vec<ScoredCandidate>> {
171 if self.config.validate_dimensions && query.len() != DEFAULT_EMBEDDING_DIM {
172 return Err(RescoringError::InvalidQueryDimension {
173 expected: DEFAULT_EMBEDDING_DIM,
174 actual: query.len(),
175 });
176 }
177
178 if candidates.is_empty() {
179 return Err(RescoringError::NoCandidates);
180 }
181
182 let mut scored: Vec<ScoredCandidate> = candidates
183 .into_iter()
184 .filter_map(|candidate| {
185 let embedding = match candidate.embedding_as_f16() {
186 Some(emb) => emb,
187 None => {
188 warn!(
189 candidate_id = candidate.id,
190 "Dropping candidate: failed to parse embedding as F16"
191 );
192 return None;
193 }
194 };
195
196 if self.config.validate_dimensions && embedding.len() != DEFAULT_EMBEDDING_DIM {
197 warn!(
198 candidate_id = candidate.id,
199 expected_dim = DEFAULT_EMBEDDING_DIM,
200 actual_dim = embedding.len(),
201 "Dropping candidate: embedding dimension mismatch"
202 );
203 return None;
204 }
205
206 let score = cosine_similarity_f16(query, embedding);
207
208 Some(ScoredCandidate {
209 id: candidate.id,
210 entry: candidate.entry,
211 score,
212 bq_score: candidate.bq_score,
213 })
214 })
215 .collect();
216
217 scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
218
219 scored.truncate(self.config.top_k);
220
221 Ok(scored)
222 }
223
224 pub fn rescore_from_bytes(
226 &self,
227 query_bytes: &[u8],
228 candidates: Vec<CandidateEntry>,
229 ) -> RescoringResult<Vec<ScoredCandidate>> {
230 let query =
231 bytes_to_f16_slice(query_bytes).ok_or(RescoringError::InvalidQueryDimension {
232 expected: EMBEDDING_BYTES,
233 actual: query_bytes.len(),
234 })?;
235
236 self.rescore(query, candidates)
237 }
238}
239
240impl Default for VectorRescorer {
241 fn default() -> Self {
242 Self::new()
243 }
244}
245
246#[inline]
247pub fn cosine_similarity_f16(a: &[f16], b: &[f16]) -> f32 {
249 if a.len() != b.len() || a.is_empty() {
250 return 0.0;
251 }
252
253 let (dot, norm_a_sq, norm_b_sq) =
254 a.iter()
255 .zip(b.iter())
256 .fold((0.0f32, 0.0f32, 0.0f32), |(dot, na, nb), (av, bv)| {
257 let av = av.to_f32();
258 let bv = bv.to_f32();
259 (dot + av * bv, na + av * av, nb + bv * bv)
260 });
261
262 let norm_a = norm_a_sq.sqrt();
263 let norm_b = norm_b_sq.sqrt();
264
265 if norm_a == 0.0 || norm_b == 0.0 {
266 0.0
267 } else {
268 dot / (norm_a * norm_b)
269 }
270}
271
272#[inline]
273pub fn cosine_similarity_f16_f32(a: &[f16], b: &[f32]) -> f32 {
275 if a.len() != b.len() || a.is_empty() {
276 return 0.0;
277 }
278
279 let mut dot_product = 0.0f32;
280 let mut norm_a_sq = 0.0f32;
281 let mut norm_b_sq = 0.0f32;
282
283 for (av_f16, &bv) in a.iter().zip(b.iter()) {
284 let av = av_f16.to_f32();
285 dot_product += av * bv;
286 norm_a_sq += av * av;
287 norm_b_sq += bv * bv;
288 }
289
290 let norm_a = norm_a_sq.sqrt();
291 let norm_b = norm_b_sq.sqrt();
292
293 if norm_a == 0.0 || norm_b == 0.0 {
294 0.0
295 } else {
296 dot_product / (norm_a * norm_b)
297 }
298}
299
300#[inline]
302pub fn bytes_to_f16_slice(bytes: &[u8]) -> Option<&[f16]> {
303 bytemuck::try_cast_slice(bytes).ok()
304}
305
306#[inline]
307pub fn f16_slice_to_bytes(values: &[f16]) -> &[u8] {
309 bytemuck::cast_slice(values)
310}
311
312pub fn f32_to_f16_vec(values: &[f32]) -> Vec<f16> {
314 values.iter().map(|&v| f16::from_f32(v)).collect()
315}
316
317pub fn f16_to_f32_vec(values: &[f16]) -> Vec<f32> {
319 values.iter().map(|v| v.to_f32()).collect()
320}
321
// Unit tests live in an out-of-line `tests` submodule that is compiled only for test builds.
#[cfg(test)]
mod tests;