batuta/oracle/rag/quantization/
retriever.rs1use super::calibration::CalibrationStats;
10use super::embedding::QuantizedEmbedding;
11use super::error::{validate_embedding, QuantizationError};
12use super::simd::SimdBackend;
13
14#[derive(Debug, Clone)]
18pub struct RescoreRetrieverConfig {
19 pub rescore_multiplier: usize,
22 pub top_k: usize,
24 pub min_calibration_samples: usize,
26 pub simd_backend: Option<SimdBackend>,
28}
29
30impl Default for RescoreRetrieverConfig {
31 fn default() -> Self {
32 Self {
33 rescore_multiplier: 4, top_k: 10,
35 min_calibration_samples: 1000,
36 simd_backend: None, }
38 }
39}
40
41#[derive(Debug)]
48pub struct RescoreRetriever {
49 embeddings: Vec<QuantizedEmbedding>,
51 doc_ids: Vec<String>,
53 calibration: CalibrationStats,
55 config: RescoreRetrieverConfig,
57 backend: SimdBackend,
59}
60
61impl RescoreRetriever {
62 pub fn new(dims: usize, config: RescoreRetrieverConfig) -> Self {
64 let backend = config.simd_backend.unwrap_or_else(SimdBackend::detect);
65 Self {
66 embeddings: Vec::new(),
67 doc_ids: Vec::new(),
68 calibration: CalibrationStats::new(dims),
69 config,
70 backend,
71 }
72 }
73
74 pub fn add_calibration_sample(&mut self, embedding: &[f32]) -> Result<(), QuantizationError> {
76 self.calibration.update(embedding)
77 }
78
79 pub fn index_document(
81 &mut self,
82 doc_id: &str,
83 embedding: &[f32],
84 ) -> Result<(), QuantizationError> {
85 self.calibration.update(embedding)?;
87
88 let quantized = QuantizedEmbedding::from_f32(embedding, &self.calibration)?;
90
91 self.embeddings.push(quantized);
92 self.doc_ids.push(doc_id.to_string());
93
94 Ok(())
95 }
96
97 fn stage1_retrieve(&self, query_i8: &[i8]) -> Vec<(usize, i32)> {
101 let num_candidates = self.config.top_k * self.config.rescore_multiplier;
102
103 let mut scores: Vec<(usize, i32)> = self
104 .embeddings
105 .iter()
106 .enumerate()
107 .map(|(i, emb)| {
108 let score = self.backend.dot_i8(query_i8, &emb.values);
109 (i, score)
110 })
111 .collect();
112
113 scores.sort_by(|a, b| b.1.cmp(&a.1));
115 scores.truncate(num_candidates);
116
117 scores
118 }
119
120 fn stage2_rescore(&self, query: &[f32], candidates: Vec<(usize, i32)>) -> Vec<RescoreResult> {
122 let mut results: Vec<RescoreResult> = candidates
123 .into_iter()
124 .map(|(doc_idx, approx_score)| {
125 let emb = &self.embeddings[doc_idx];
126 let precise_score = self.backend.dot_f32_i8(query, &emb.values, emb.params.scale);
127
128 RescoreResult {
129 doc_id: self.doc_ids[doc_idx].clone(),
130 score: precise_score,
131 approx_score,
132 }
133 })
134 .collect();
135
136 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
138 results.truncate(self.config.top_k);
139
140 results
141 }
142
143 pub fn retrieve(&self, query: &[f32]) -> Result<Vec<RescoreResult>, QuantizationError> {
151 validate_embedding(query, self.calibration.dims)?;
153
154 let query_quantized = QuantizedEmbedding::from_f32(query, &self.calibration)?;
156 let candidates = self.stage1_retrieve(&query_quantized.values);
157
158 Ok(self.stage2_rescore(query, candidates))
160 }
161
162 pub fn len(&self) -> usize {
164 self.embeddings.len()
165 }
166
167 pub fn is_empty(&self) -> bool {
169 self.embeddings.is_empty()
170 }
171
172 pub fn calibration(&self) -> &CalibrationStats {
174 &self.calibration
175 }
176
177 pub fn memory_usage(&self) -> usize {
179 self.embeddings.iter().map(|e| e.memory_size()).sum()
180 }
181}
182
/// One retrieval hit produced by the two-stage retriever.
#[derive(Clone, Debug)]
pub struct RescoreResult {
    /// Identifier of the matched document, as supplied when it was indexed.
    pub doc_id: String,

    /// Stage-2 score computed against the original `f32` query; results are
    /// ranked by this value.
    pub score: f32,

    /// Stage-1 integer score computed on quantized values.
    pub approx_score: i32,
}