use std::path::Path;
use super::{DocumentId, RagError, Result};
/// Common interface for backends that store document embeddings and answer
/// similarity queries over them.
///
/// Implementors must be `Send + Sync`. Every method except [`load`](Self::load)
/// takes a receiver and no generics, and `load` is bounded by `Self: Sized`,
/// so the trait can be used as a trait object.
pub trait RetrievalBackend: Send + Sync {
    // --- Mutation -----------------------------------------------------------

    /// Adds `embedding` to the index under `id`.
    ///
    /// # Errors
    ///
    /// Failure conditions are defined by each implementation (presumably a
    /// length mismatch against [`embedding_dim`](Self::embedding_dim) is one
    /// of them; confirm against the concrete backend).
    fn add(&mut self, id: DocumentId, embedding: &[f32]) -> Result<()>;

    /// Removes the document stored under `id`.
    ///
    /// The meaning of the returned `bool` is left to implementations
    /// (presumably "an entry was actually removed"; confirm per backend).
    ///
    /// # Errors
    ///
    /// The default implementation ignores `id` and always returns
    /// `RagError::IndexError`, so backends that support deletion must
    /// override this method.
    fn remove(&mut self, id: DocumentId) -> Result<bool> {
        // Explicitly discard the argument so the unused parameter does not warn.
        let _ = id;
        let message = String::from("Removal not supported by this backend");
        Err(RagError::IndexError(message))
    }

    /// Clears the backend's contents.
    fn clear(&mut self);

    // --- Queries ------------------------------------------------------------

    /// Looks up entries for the query `embedding`, returning
    /// `(document id, score)` pairs.
    ///
    /// NOTE(review): presumably at most `top_k` pairs are returned, best
    /// match first; the ordering and score scale are not fixed by this trait.
    fn query(&self, embedding: &[f32], top_k: usize) -> Vec<(DocumentId, f32)>;

    /// Reports whether `id` is present in the index.
    fn contains(&self, id: DocumentId) -> bool;

    /// Number of entries in the index.
    fn len(&self) -> usize;

    /// Returns `true` when [`len`](Self::len) reports zero.
    fn is_empty(&self) -> bool {
        let count = self.len();
        count == 0
    }

    /// Embedding dimensionality associated with this backend.
    fn embedding_dim(&self) -> usize;

    // --- Persistence --------------------------------------------------------

    /// Persists the backend to `path`.
    fn save(&self, path: &Path) -> Result<()>;

    /// Builds a backend from `path` for embeddings of size `embedding_dim`.
    ///
    /// The `Self: Sized` bound keeps this constructor out of trait objects.
    fn load(path: &Path, embedding_dim: usize) -> Result<Self>
    where
        Self: Sized;
}
/// Identifies which retrieval backend implementation to use.
///
/// The default value is [`BackendType::ExactCosine`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum BackendType {
    /// Default backend; going by its name, an exact cosine-similarity search.
    #[default]
    ExactCosine,
    /// Backend that only exists when the `rag-hnsw` feature is enabled
    /// (presumably an HNSW approximate-nearest-neighbour index).
    #[cfg(feature = "rag-hnsw")]
    Hnsw,
}
impl BackendType {
pub fn recommended_for_size(num_documents: usize) -> Self {
if num_documents > 1_000_000 {
#[cfg(feature = "rag-hnsw")]
return BackendType::Hnsw;
}
BackendType::ExactCosine
}
}
/// Returns a copy of `embedding` scaled to unit Euclidean (L2) length.
///
/// If the computed norm compares equal to zero (for example an all-zero or
/// empty slice), the input is returned unchanged, avoiding a division by zero.
pub fn normalize_embedding(embedding: &[f32]) -> Vec<f32> {
    let squared_total = embedding.iter().map(|&value| value * value).sum::<f32>();
    let magnitude = squared_total.sqrt();

    if magnitude != 0.0 {
        return embedding.iter().map(|&value| value / magnitude).collect();
    }

    // Nothing to scale by: hand back the values as they are.
    embedding.to_vec()
}
/// Computes the dot product of `a` and `b`.
///
/// # Panics
///
/// When debug assertions are enabled, panics if the slices differ in length.
/// Without debug assertions the check is compiled out and only the leading
/// elements common to both slices contribute to the result.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
/// Computes the cosine similarity of `a` and `b`: their dot product divided
/// by the product of their Euclidean norms.
///
/// Returns `0.0` when either norm compares equal to zero, so zero vectors
/// never cause a division by zero.
///
/// The dot product is computed via [`dot_product`], so its debug-only
/// length assertion applies here as well.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let magnitude = |values: &[f32]| -> f32 {
        values.iter().map(|&v| v * v).sum::<f32>().sqrt()
    };

    let numerator = dot_product(a, b);
    let magnitude_a = magnitude(a);
    let magnitude_b = magnitude(b);

    if magnitude_a != 0.0 && magnitude_b != 0.0 {
        return numerator / (magnitude_a * magnitude_b);
    }
    0.0
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOLERANCE: f32 = 1e-6;

    /// Normalizing a 3-4-5 vector must yield a vector of length one.
    #[test]
    fn test_normalize_embedding() {
        let input = [3.0_f32, 4.0];
        let unit = normalize_embedding(&input);
        let length = unit.iter().map(|&v| v * v).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < TOLERANCE);
    }

    /// 1*4 + 2*5 + 3*6 == 32.
    #[test]
    fn test_dot_product() {
        let left = [1.0_f32, 2.0, 3.0];
        let right = [4.0_f32, 5.0, 6.0];
        let product = dot_product(&left, &right);
        assert!((product - 32.0).abs() < TOLERANCE);
    }

    /// Identical vectors score 1, orthogonal vectors score 0.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = [1.0_f32, 0.0];
        let same_direction = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];

        let parallel = cosine_similarity(&x_axis, &same_direction);
        assert!((parallel - 1.0).abs() < TOLERANCE);

        let orthogonal = cosine_similarity(&x_axis, &y_axis);
        assert!(orthogonal.abs() < TOLERANCE);
    }

    /// Collections at or below the threshold always get the exact backend.
    #[test]
    fn test_recommended_backend() {
        for &size in &[1000_usize, 100_000] {
            assert_eq!(
                BackendType::recommended_for_size(size),
                BackendType::ExactCosine
            );
        }
    }
}