use super::calibration::CalibrationStats;
use super::embedding::QuantizedEmbedding;
use super::error::{validate_embedding, QuantizationError};
use super::simd::SimdBackend;
/// Tuning parameters for [`RescoreRetriever`].
#[derive(Clone, Debug)]
pub struct RescoreRetrieverConfig {
    /// Stage-1 over-fetch factor: the coarse integer search keeps roughly
    /// `top_k * rescore_multiplier` candidates for float rescoring.
    pub rescore_multiplier: usize,
    /// Number of results returned after rescoring.
    pub top_k: usize,
    /// Minimum number of calibration samples.
    ///
    /// NOTE(review): not read by the `RescoreRetriever` methods in this module —
    /// confirm whether it is meant to gate indexing/retrieval.
    pub min_calibration_samples: usize,
    /// Explicit SIMD backend; `None` lets `RescoreRetriever::new` pick one via
    /// `SimdBackend::detect`.
    pub simd_backend: Option<SimdBackend>,
}
impl Default for RescoreRetrieverConfig {
    /// Baseline configuration:
    /// - 10 results returned;
    /// - 4x stage-1 over-fetch;
    /// - a 1000-sample calibration minimum;
    /// - `simd_backend: None`, which `RescoreRetriever::new` resolves through
    ///   `SimdBackend::detect`.
    fn default() -> Self {
        let top_k = 10;
        let rescore_multiplier = 4;
        Self {
            top_k,
            rescore_multiplier,
            min_calibration_samples: 1_000,
            simd_backend: None,
        }
    }
}
/// Two-stage retriever over int8-quantized embeddings.
///
/// Candidates are first selected with an integer dot product against a quantized
/// query, then rescored with the original `f32` query.
#[derive(Debug)]
pub struct RescoreRetriever {
    /// Quantized document embeddings; entry `i` belongs to `doc_ids[i]`.
    embeddings: Vec<QuantizedEmbedding>,
    /// Caller-supplied document identifiers, parallel to `embeddings`.
    doc_ids: Vec<String>,
    /// Running statistics used to quantize both documents and queries.
    calibration: CalibrationStats,
    /// Retrieval parameters supplied at construction.
    config: RescoreRetrieverConfig,
    /// Dot-product backend, resolved once at construction.
    backend: SimdBackend,
}
impl RescoreRetriever {
    /// Creates an empty retriever for embeddings with `dims` dimensions.
    ///
    /// If `config.simd_backend` is `None`, the backend is chosen with
    /// `SimdBackend::detect`.
    pub fn new(dims: usize, config: RescoreRetrieverConfig) -> Self {
        let backend = config.simd_backend.unwrap_or_else(SimdBackend::detect);
        Self {
            embeddings: Vec::new(),
            doc_ids: Vec::new(),
            calibration: CalibrationStats::new(dims),
            config,
            backend,
        }
    }

    /// Feeds `embedding` into the calibration statistics without indexing it.
    ///
    /// # Errors
    ///
    /// Propagates any error from `CalibrationStats::update`.
    pub fn add_calibration_sample(&mut self, embedding: &[f32]) -> Result<(), QuantizationError> {
        self.calibration.update(embedding)
    }

    /// Updates calibration with `embedding`, quantizes it, and stores it under `doc_id`.
    ///
    /// Each document is quantized with the calibration statistics as they stand at the
    /// moment it is indexed; previously indexed documents are not re-quantized.
    ///
    /// # Errors
    ///
    /// Propagates errors from `CalibrationStats::update` and `QuantizedEmbedding::from_f32`.
    /// If quantization fails, the calibration update has already been applied and the
    /// document is not stored.
    pub fn index_document(
        &mut self,
        doc_id: &str,
        embedding: &[f32],
    ) -> Result<(), QuantizationError> {
        self.calibration.update(embedding)?;
        let quantized = QuantizedEmbedding::from_f32(embedding, &self.calibration)?;
        self.embeddings.push(quantized);
        self.doc_ids.push(doc_id.to_string());
        Ok(())
    }

    /// Stage 1: scores every indexed embedding against the quantized query with an
    /// integer dot product and keeps the best `top_k * rescore_multiplier` candidates.
    ///
    /// The result is ordered by descending approximate score, with ties broken by
    /// ascending document index.
    fn stage1_retrieve(&self, query_i8: &[i8]) -> Vec<(usize, i32)> {
        // Higher score first, then lower index. This total order reproduces exactly what
        // a stable descending sort over index-ordered input would give, so unstable
        // selection/sorting below is deterministic.
        fn by_score_desc(a: &(usize, i32), b: &(usize, i32)) -> std::cmp::Ordering {
            b.1.cmp(&a.1).then(a.0.cmp(&b.0))
        }

        // A multiplier of 0 would discard every candidate, so always keep at least `top_k`.
        // `saturating_mul` avoids overflow for very large configured values.
        let num_candidates = self
            .config
            .top_k
            .saturating_mul(self.config.rescore_multiplier.max(1));
        if num_candidates == 0 {
            return Vec::new();
        }

        let mut scores: Vec<(usize, i32)> = self
            .embeddings
            .iter()
            .enumerate()
            .map(|(i, emb)| (i, self.backend.dot_i8(query_i8, &emb.values)))
            .collect();

        if scores.len() > num_candidates {
            // Move the best `num_candidates` entries to the front in O(n), instead of
            // sorting the entire corpus only to discard most of it.
            scores.select_nth_unstable_by(num_candidates - 1, by_score_desc);
            scores.truncate(num_candidates);
        }
        scores.sort_unstable_by(by_score_desc);
        scores
    }

    /// Stage 2: recomputes each candidate's score with `dot_f32_i8`, using the original
    /// `f32` query, the stored quantized values, and the embedding's `params.scale`.
    /// Returns the `top_k` best candidates by that score.
    fn stage2_rescore(&self, query: &[f32], candidates: Vec<(usize, i32)>) -> Vec<RescoreResult> {
        let mut rescored: Vec<(usize, f32, i32)> = candidates
            .into_iter()
            .map(|(doc_idx, approx_score)| {
                let emb = &self.embeddings[doc_idx];
                let precise_score = self.backend.dot_f32_i8(query, &emb.values, emb.params.scale);
                (doc_idx, precise_score, approx_score)
            })
            .collect();

        // The stable sort keeps stage-1 order among equal scores. `cmp_score_desc` is a
        // total order: a NaN-tolerant `partial_cmp` comparator is not, and `sort_by` may
        // misorder or panic on such comparators.
        rescored.sort_by(|a, b| Self::cmp_score_desc(a.1, b.1));
        rescored.truncate(self.config.top_k);

        // Clone document ids only for the results actually returned.
        rescored
            .into_iter()
            .map(|(doc_idx, score, approx_score)| RescoreResult {
                doc_id: self.doc_ids[doc_idx].clone(),
                score,
                approx_score,
            })
            .collect()
    }

    /// Orders scores descending, placing NaN after every non-NaN value so that a NaN
    /// score can never outrank a real result.
    fn cmp_score_desc(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither value is NaN, so `partial_cmp` always returns `Some`.
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }

    /// Two-stage search for the documents most similar to `query`.
    ///
    /// The query is validated against the calibrated dimensionality and quantized.
    /// It then goes through the integer candidate search (stage 1) and the float
    /// rescoring (stage 2). At most `top_k` results are returned, best first.
    ///
    /// # Errors
    ///
    /// Returns an error if `validate_embedding` rejects `query` or if quantizing it fails.
    pub fn retrieve(&self, query: &[f32]) -> Result<Vec<RescoreResult>, QuantizationError> {
        validate_embedding(query, self.calibration.dims)?;
        let query_quantized = QuantizedEmbedding::from_f32(query, &self.calibration)?;
        let candidates = self.stage1_retrieve(&query_quantized.values);
        Ok(self.stage2_rescore(query, candidates))
    }

    /// Number of indexed documents.
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }

    /// Returns `true` if no documents have been indexed.
    pub fn is_empty(&self) -> bool {
        self.embeddings.is_empty()
    }

    /// Read-only access to the current calibration statistics.
    pub fn calibration(&self) -> &CalibrationStats {
        &self.calibration
    }

    /// Sum of `memory_size()` over all stored quantized embeddings.
    ///
    /// Document-id strings and the calibration statistics are not included.
    pub fn memory_usage(&self) -> usize {
        self.embeddings.iter().map(|e| e.memory_size()).sum()
    }
}
/// One ranked hit returned by `RescoreRetriever::retrieve`.
#[derive(Clone, Debug)]
pub struct RescoreResult {
    /// Identifier the document was indexed under.
    pub doc_id: String,
    /// Final score from the float rescoring stage; results are ranked by this.
    pub score: f32,
    /// Integer dot-product score from the quantized candidate stage.
    pub approx_score: i32,
}