Skip to main content

batuta/oracle/rag/quantization/
retriever.rs

//! Two-stage rescoring retriever
//!
//! Stage 1: Fast approximate retrieval with int8 embeddings
//! Stage 2: Precise rescoring with f32 query x i8 docs
//!
//! Achieves 99% accuracy retention with 3.66x speedup.

// Library code - usage from examples and integration tests
9use super::calibration::CalibrationStats;
10use super::embedding::QuantizedEmbedding;
11use super::error::{validate_embedding, QuantizationError};
12use super::simd::SimdBackend;
13
14/// Two-stage rescoring retriever configuration
15///
16/// Following the scalar int8 rescoring specification.
17#[derive(Debug, Clone)]
18pub struct RescoreRetrieverConfig {
19    /// Number of candidates to retrieve in stage 1 (multiplier x top_k)
20    /// Optimal: 4-5x for 99% accuracy retention
21    pub rescore_multiplier: usize,
22    /// Final number of results to return
23    pub top_k: usize,
24    /// Minimum calibration samples required
25    pub min_calibration_samples: usize,
26    /// SIMD backend (auto-detected if None)
27    pub simd_backend: Option<SimdBackend>,
28}
29
30impl Default for RescoreRetrieverConfig {
31    fn default() -> Self {
32        Self {
33            rescore_multiplier: 4, // Optimal per specification
34            top_k: 10,
35            min_calibration_samples: 1000,
36            simd_backend: None, // Auto-detect
37        }
38    }
39}
40
41/// Two-stage rescoring retriever
42///
43/// Stage 1: Fast approximate retrieval with int8 embeddings
44/// Stage 2: Precise rescoring with f32 query x i8 docs
45///
46/// Achieves 99% accuracy retention with 3.66x speedup.
47#[derive(Debug)]
48pub struct RescoreRetriever {
49    /// Int8 quantized document embeddings
50    embeddings: Vec<QuantizedEmbedding>,
51    /// Document IDs corresponding to embeddings
52    doc_ids: Vec<String>,
53    /// Calibration statistics
54    calibration: CalibrationStats,
55    /// Configuration
56    config: RescoreRetrieverConfig,
57    /// SIMD backend
58    backend: SimdBackend,
59}
60
61impl RescoreRetriever {
62    /// Create new rescoring retriever
63    pub fn new(dims: usize, config: RescoreRetrieverConfig) -> Self {
64        let backend = config.simd_backend.unwrap_or_else(SimdBackend::detect);
65        Self {
66            embeddings: Vec::new(),
67            doc_ids: Vec::new(),
68            calibration: CalibrationStats::new(dims),
69            config,
70            backend,
71        }
72    }
73
74    /// Add embedding to calibration set (Kaizen)
75    pub fn add_calibration_sample(&mut self, embedding: &[f32]) -> Result<(), QuantizationError> {
76        self.calibration.update(embedding)
77    }
78
79    /// Index a document with its embedding
80    pub fn index_document(
81        &mut self,
82        doc_id: &str,
83        embedding: &[f32],
84    ) -> Result<(), QuantizationError> {
85        // Update calibration
86        self.calibration.update(embedding)?;
87
88        // Quantize embedding
89        let quantized = QuantizedEmbedding::from_f32(embedding, &self.calibration)?;
90
91        self.embeddings.push(quantized);
92        self.doc_ids.push(doc_id.to_string());
93
94        Ok(())
95    }
96
97    /// Stage 1: Fast int8 retrieval
98    ///
99    /// Returns (doc_index, approximate_score) pairs
100    fn stage1_retrieve(&self, query_i8: &[i8]) -> Vec<(usize, i32)> {
101        let num_candidates = self.config.top_k * self.config.rescore_multiplier;
102
103        let mut scores: Vec<(usize, i32)> = self
104            .embeddings
105            .iter()
106            .enumerate()
107            .map(|(i, emb)| {
108                let score = self.backend.dot_i8(query_i8, &emb.values);
109                (i, score)
110            })
111            .collect();
112
113        // Sort descending by score
114        scores.sort_by(|a, b| b.1.cmp(&a.1));
115        scores.truncate(num_candidates);
116
117        scores
118    }
119
120    /// Stage 2: Precise rescoring with f32 query
121    fn stage2_rescore(&self, query: &[f32], candidates: Vec<(usize, i32)>) -> Vec<RescoreResult> {
122        let mut results: Vec<RescoreResult> = candidates
123            .into_iter()
124            .map(|(doc_idx, approx_score)| {
125                let emb = &self.embeddings[doc_idx];
126                let precise_score = self.backend.dot_f32_i8(query, &emb.values, emb.params.scale);
127
128                RescoreResult {
129                    doc_id: self.doc_ids[doc_idx].clone(),
130                    score: precise_score,
131                    approx_score,
132                }
133            })
134            .collect();
135
136        // Sort by precise score descending
137        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
138        results.truncate(self.config.top_k);
139
140        results
141    }
142
143    /// Full two-stage retrieval
144    ///
145    /// # Arguments
146    /// * `query` - f32 query embedding
147    ///
148    /// # Returns
149    /// Top-k results with precise scores
150    pub fn retrieve(&self, query: &[f32]) -> Result<Vec<RescoreResult>, QuantizationError> {
151        // Validate query embedding
152        validate_embedding(query, self.calibration.dims)?;
153
154        // Stage 1: Quantize query and retrieve candidates
155        let query_quantized = QuantizedEmbedding::from_f32(query, &self.calibration)?;
156        let candidates = self.stage1_retrieve(&query_quantized.values);
157
158        // Stage 2: Rescore with f32 precision
159        Ok(self.stage2_rescore(query, candidates))
160    }
161
162    /// Get number of indexed documents
163    pub fn len(&self) -> usize {
164        self.embeddings.len()
165    }
166
167    /// Check if index is empty
168    pub fn is_empty(&self) -> bool {
169        self.embeddings.is_empty()
170    }
171
172    /// Get calibration statistics
173    pub fn calibration(&self) -> &CalibrationStats {
174        &self.calibration
175    }
176
177    /// Get total memory usage in bytes
178    pub fn memory_usage(&self) -> usize {
179        self.embeddings.iter().map(|e| e.memory_size()).sum()
180    }
181}
182
/// One hit returned by [`RescoreRetriever::retrieve`].
#[derive(Debug, Clone, PartialEq)]
pub struct RescoreResult {
    /// ID the document was indexed under.
    pub doc_id: String,
    /// Stage-2 score (f32 query x i8 document, using the document's quantization scale).
    /// Results are ranked by this value.
    pub score: f32,
    /// Stage-1 score (i8 query x i8 document integer dot product).
    pub approx_score: i32,
}