reflex/vectordb/
rescoring.rs

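//! Full-precision rescoring of candidates returned by a first-stage (BQ) search, using f16
//! embeddings and f32-accumulated cosine similarity.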
2
3use half::f16;
4use std::cmp::Ordering;
5use thiserror::Error;
6use tracing::warn;
7
8use crate::storage::CacheEntry;
9
10/// Default number of rescored candidates returned.
11pub const DEFAULT_TOP_K: usize = 5;
12
13/// Default embedding dimension used for validation.
14pub const DEFAULT_EMBEDDING_DIM: usize = crate::constants::DEFAULT_EMBEDDING_DIM;
15
16/// Expected embedding byte length (f16).
17pub const EMBEDDING_BYTES: usize = crate::constants::EMBEDDING_F16_BYTES;
18
19#[derive(Debug, Error)]
20/// Errors returned by rescoring.
21pub enum RescoringError {
22    /// Query dimension mismatch.
23    #[error("invalid query dimension: expected {expected}, got {actual}")]
24    InvalidQueryDimension {
25        /// Expected dimension.
26        expected: usize,
27        /// Actual dimension.
28        actual: usize,
29    },
30
31    /// Candidate embedding byte length mismatch.
32    #[error("invalid embedding size for candidate {id}: expected {expected} bytes, got {actual}")]
33    InvalidEmbeddingSize {
34        /// Candidate id.
35        id: u64,
36        /// Expected bytes.
37        expected: usize,
38        /// Actual bytes.
39        actual: usize,
40    },
41
42    #[error("no candidates provided for rescoring")]
43    /// No candidates were provided.
44    NoCandidates,
45}
46
47/// Convenience result type for rescoring.
48pub type RescoringResult<T> = Result<T, RescoringError>;
49
50#[derive(Debug, Clone)]
51/// Candidate with optional BQ score, ready for full-precision rescoring.
52pub struct CandidateEntry {
53    /// Point id.
54    pub id: u64,
55    /// Full cache entry.
56    pub entry: CacheEntry,
57    /// Optional BQ score from the first-stage search.
58    pub bq_score: Option<f32>,
59}
60
61impl CandidateEntry {
62    /// Creates a candidate entry without a BQ score.
63    pub fn new(id: u64, entry: CacheEntry) -> Self {
64        Self {
65            id,
66            entry,
67            bq_score: None,
68        }
69    }
70
71    /// Creates a candidate entry with a BQ score.
72    pub fn with_bq_score(id: u64, entry: CacheEntry, bq_score: f32) -> Self {
73        Self {
74            id,
75            entry,
76            bq_score: Some(bq_score),
77        }
78    }
79
80    /// Views the candidate embedding as an f16 slice (if the bytes are valid).
81    pub fn embedding_as_f16(&self) -> Option<&[f16]> {
82        bytes_to_f16_slice(&self.entry.embedding)
83    }
84}
85
86#[derive(Debug, Clone)]
87/// Rescored candidate (full-precision cosine similarity).
88pub struct ScoredCandidate {
89    /// Point id.
90    pub id: u64,
91    /// Full cache entry.
92    pub entry: CacheEntry,
93    /// Full-precision score.
94    pub score: f32,
95    /// Optional BQ score for debugging/analysis.
96    pub bq_score: Option<f32>,
97}
98
99impl ScoredCandidate {
100    /// Returns `score - bq_score` if a BQ score is present.
101    pub fn score_delta(&self) -> Option<f32> {
102        self.bq_score.map(|bq| self.score - bq)
103    }
104}
105
/// Tuning knobs for [`VectorRescorer`].
#[derive(Debug, Clone)]
pub struct RescorerConfig {
    /// Maximum number of candidates kept after rescoring.
    pub top_k: usize,
    /// When `true`, the query and every candidate are checked against
    /// `DEFAULT_EMBEDDING_DIM` before scoring.
    pub validate_dimensions: bool,
}
114
115impl Default for RescorerConfig {
116    fn default() -> Self {
117        Self {
118            top_k: DEFAULT_TOP_K,
119            validate_dimensions: true,
120        }
121    }
122}
123
124impl RescorerConfig {
125    /// Creates a config overriding `top_k`.
126    pub fn with_top_k(top_k: usize) -> Self {
127        Self {
128            top_k,
129            ..Default::default()
130        }
131    }
132}
133
134#[derive(Debug, Clone)]
135/// Full-precision rescoring of candidates (cosine similarity).
136pub struct VectorRescorer {
137    config: RescorerConfig,
138}
139
140impl VectorRescorer {
141    /// Creates a rescorer with default config.
142    pub fn new() -> Self {
143        Self {
144            config: RescorerConfig::default(),
145        }
146    }
147
148    /// Creates a rescorer overriding `top_k`.
149    pub fn with_top_k(top_k: usize) -> Self {
150        Self {
151            config: RescorerConfig::with_top_k(top_k),
152        }
153    }
154
155    /// Creates a rescorer with an explicit config.
156    pub fn with_config(config: RescorerConfig) -> Self {
157        Self { config }
158    }
159
160    /// Returns the active config.
161    pub fn config(&self) -> &RescorerConfig {
162        &self.config
163    }
164
165    /// Rescores candidates and returns the top-k results.
166    pub fn rescore(
167        &self,
168        query: &[f16],
169        candidates: Vec<CandidateEntry>,
170    ) -> RescoringResult<Vec<ScoredCandidate>> {
171        if self.config.validate_dimensions && query.len() != DEFAULT_EMBEDDING_DIM {
172            return Err(RescoringError::InvalidQueryDimension {
173                expected: DEFAULT_EMBEDDING_DIM,
174                actual: query.len(),
175            });
176        }
177
178        if candidates.is_empty() {
179            return Err(RescoringError::NoCandidates);
180        }
181
182        let mut scored: Vec<ScoredCandidate> = candidates
183            .into_iter()
184            .filter_map(|candidate| {
185                let embedding = match candidate.embedding_as_f16() {
186                    Some(emb) => emb,
187                    None => {
188                        warn!(
189                            candidate_id = candidate.id,
190                            "Dropping candidate: failed to parse embedding as F16"
191                        );
192                        return None;
193                    }
194                };
195
196                if self.config.validate_dimensions && embedding.len() != DEFAULT_EMBEDDING_DIM {
197                    warn!(
198                        candidate_id = candidate.id,
199                        expected_dim = DEFAULT_EMBEDDING_DIM,
200                        actual_dim = embedding.len(),
201                        "Dropping candidate: embedding dimension mismatch"
202                    );
203                    return None;
204                }
205
206                let score = cosine_similarity_f16(query, embedding);
207
208                Some(ScoredCandidate {
209                    id: candidate.id,
210                    entry: candidate.entry,
211                    score,
212                    bq_score: candidate.bq_score,
213                })
214            })
215            .collect();
216
217        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
218
219        scored.truncate(self.config.top_k);
220
221        Ok(scored)
222    }
223
224    /// Like [`Self::rescore`], but takes query bytes (little-endian f16).
225    pub fn rescore_from_bytes(
226        &self,
227        query_bytes: &[u8],
228        candidates: Vec<CandidateEntry>,
229    ) -> RescoringResult<Vec<ScoredCandidate>> {
230        let query =
231            bytes_to_f16_slice(query_bytes).ok_or(RescoringError::InvalidQueryDimension {
232                expected: EMBEDDING_BYTES,
233                actual: query_bytes.len(),
234            })?;
235
236        self.rescore(query, candidates)
237    }
238}
239
240impl Default for VectorRescorer {
241    fn default() -> Self {
242        Self::new()
243    }
244}
245
246#[inline]
247/// Computes cosine similarity between two f16 vectors.
248pub fn cosine_similarity_f16(a: &[f16], b: &[f16]) -> f32 {
249    if a.len() != b.len() || a.is_empty() {
250        return 0.0;
251    }
252
253    let (dot, norm_a_sq, norm_b_sq) =
254        a.iter()
255            .zip(b.iter())
256            .fold((0.0f32, 0.0f32, 0.0f32), |(dot, na, nb), (av, bv)| {
257                let av = av.to_f32();
258                let bv = bv.to_f32();
259                (dot + av * bv, na + av * av, nb + bv * bv)
260            });
261
262    let norm_a = norm_a_sq.sqrt();
263    let norm_b = norm_b_sq.sqrt();
264
265    if norm_a == 0.0 || norm_b == 0.0 {
266        0.0
267    } else {
268        dot / (norm_a * norm_b)
269    }
270}
271
272#[inline]
273/// Computes cosine similarity between an f16 query and an f32 candidate.
274pub fn cosine_similarity_f16_f32(a: &[f16], b: &[f32]) -> f32 {
275    if a.len() != b.len() || a.is_empty() {
276        return 0.0;
277    }
278
279    let mut dot_product = 0.0f32;
280    let mut norm_a_sq = 0.0f32;
281    let mut norm_b_sq = 0.0f32;
282
283    for (av_f16, &bv) in a.iter().zip(b.iter()) {
284        let av = av_f16.to_f32();
285        dot_product += av * bv;
286        norm_a_sq += av * av;
287        norm_b_sq += bv * bv;
288    }
289
290    let norm_a = norm_a_sq.sqrt();
291    let norm_b = norm_b_sq.sqrt();
292
293    if norm_a == 0.0 || norm_b == 0.0 {
294        0.0
295    } else {
296        dot_product / (norm_a * norm_b)
297    }
298}
299
300/// Reinterprets little-endian f16 bytes as an `&[f16]` (no copy).
301#[inline]
302pub fn bytes_to_f16_slice(bytes: &[u8]) -> Option<&[f16]> {
303    bytemuck::try_cast_slice(bytes).ok()
304}
305
306#[inline]
307/// Views an `&[f16]` as bytes (no copy).
308pub fn f16_slice_to_bytes(values: &[f16]) -> &[u8] {
309    bytemuck::cast_slice(values)
310}
311
312/// Converts `&[f32]` to `Vec<f16>`.
313pub fn f32_to_f16_vec(values: &[f32]) -> Vec<f16> {
314    values.iter().map(|&v| f16::from_f32(v)).collect()
315}
316
317/// Converts `&[f16]` to `Vec<f32>`.
318pub fn f16_to_f32_vec(values: &[f16]) -> Vec<f32> {
319    values.iter().map(|v| v.to_f32()).collect()
320}
321
// Unit tests for this module, compiled only under `cfg(test)`. Because `mod tests;` has no
// inline body, Rust resolves it to `rescoring/tests.rs` (or `rescoring/tests/mod.rs`).
#[cfg(test)]
mod tests;