use std::collections::HashMap;
/// Tuning knobs for similarity search over an equation index.
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
    /// Maximum number of hits returned per query.
    pub top_k: usize,
    /// Hits whose similarity is below this value are discarded.
    pub min_similarity: f32,
    /// Approximate-nearest-neighbour flag; the exhaustive scan in `EquationRagIndex`
    /// does not currently read it.
    pub use_ann: bool,
    /// Upper bound on the number of documents an index accepts.
    pub max_index_size: usize,
}

impl Default for RetrievalConfig {
    /// `top_k = 5`, `min_similarity = 0.5`, ANN disabled, at most one million documents.
    fn default() -> RetrievalConfig {
        let max_index_size = 1_000_000;
        let use_ann = false;
        RetrievalConfig {
            max_index_size,
            use_ann,
            min_similarity: 0.5,
            top_k: 5,
        }
    }
}
/// A single equation stored in the index, together with its embedding and optional annotations.
#[derive(Debug, Clone)]
pub struct EquationDocument {
    /// Identifier used for id lookups.
    pub id: String,
    /// LaTeX text of the equation.
    pub latex: String,
    /// Embedding vector; an index only accepts it when its length matches the index dimension.
    pub embedding: Vec<f32>,
    /// Optional source reference; its format is not interpreted here.
    pub source: Option<String>,
    /// Optional label; its format is not interpreted here.
    pub label: Option<String>,
    /// Optional subject area, used for domain-filtered retrieval.
    pub domain: Option<String>,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}

impl EquationDocument {
    /// Creates a document with no source, label, domain, or metadata.
    pub fn new(id: String, latex: String, embedding: Vec<f32>) -> Self {
        EquationDocument {
            metadata: HashMap::new(),
            domain: None,
            label: None,
            source: None,
            embedding,
            latex,
            id,
        }
    }

    /// Returns the document with `source` set, replacing any previous value.
    pub fn with_source(self, source: String) -> Self {
        EquationDocument {
            source: Some(source),
            ..self
        }
    }

    /// Returns the document with `label` set, replacing any previous value.
    pub fn with_label(self, label: String) -> Self {
        EquationDocument {
            label: Some(label),
            ..self
        }
    }

    /// Returns the document with `domain` set, replacing any previous value.
    pub fn with_domain(self, domain: String) -> Self {
        EquationDocument {
            domain: Some(domain),
            ..self
        }
    }

    /// Returns the document with `key -> value` recorded in its metadata,
    /// overwriting any existing entry for `key`.
    pub fn with_metadata(self, key: String, value: String) -> Self {
        let mut doc = self;
        doc.metadata.insert(key, value);
        doc
    }
}
/// A single ranked hit produced by a retrieval query.
#[derive(Debug, Clone)]
pub struct RetrievalResult {
    /// The matched document (a clone of the indexed copy when produced by `EquationRagIndex`).
    pub document: EquationDocument,
    /// Similarity between the query and `document`'s embedding.
    pub similarity: f32,
    /// Position of this hit in the ranked list; `EquationRagIndex` numbers ranks from 0.
    pub rank: usize,
}

impl RetrievalResult {
    /// Bundles a document with its similarity score and rank.
    pub fn new(document: EquationDocument, similarity: f32, rank: usize) -> Self {
        RetrievalResult {
            rank,
            similarity,
            document,
        }
    }
}
/// In-memory equation index with exhaustive cosine-similarity search.
///
/// Documents live in `equations` in insertion order. `id_index` maps each document id to
/// its position, and `domain_index` maps each domain name to the positions of the documents
/// tagged with it.
pub struct EquationRagIndex {
    /// Stored documents, in insertion order.
    equations: Vec<EquationDocument>,
    /// Document id -> position in `equations` (exactly one stored document per id).
    id_index: HashMap<String, usize>,
    /// Domain name -> positions in `equations` of documents tagged with that domain.
    domain_index: HashMap<String, Vec<usize>>,
    /// Retrieval parameters (top-k, similarity threshold, size limit).
    config: RetrievalConfig,
    /// Required length of every document and query embedding.
    dimension: usize,
}

impl EquationRagIndex {
    /// Creates an empty index for embeddings of length `dimension` with the default config.
    pub fn new(dimension: usize) -> Self {
        Self::with_config(dimension, RetrievalConfig::default())
    }

    /// Creates an empty index for embeddings of length `dimension` with an explicit config.
    pub fn with_config(dimension: usize, config: RetrievalConfig) -> Self {
        Self {
            equations: Vec::new(),
            id_index: HashMap::new(),
            domain_index: HashMap::new(),
            config,
            dimension,
        }
    }

    /// Adds `doc`, or replaces the stored document that has the same id.
    ///
    /// A replacement keeps the original position, so it does not grow the index and is not
    /// subject to `max_index_size`. Its domain membership is updated to match the new document.
    ///
    /// # Errors
    /// - `"Embedding dimension mismatch"` if `doc.embedding.len() != self.dimension()`.
    /// - `"Index size limit reached"` if `doc.id` is new and the index already holds
    ///   `max_index_size` documents.
    pub fn add(&mut self, doc: EquationDocument) -> Result<(), &'static str> {
        if doc.embedding.len() != self.dimension {
            return Err("Embedding dimension mismatch");
        }
        let idx = match self.id_index.get(&doc.id).copied() {
            Some(existing) => {
                // Re-adding an id must not leave the stale copy behind: it would keep
                // surfacing in retrieval results and inflate `len`/`domain_count`.
                self.unlink_domain(existing);
                existing
            }
            None => {
                if self.equations.len() >= self.config.max_index_size {
                    return Err("Index size limit reached");
                }
                let next = self.equations.len();
                self.id_index.insert(doc.id.clone(), next);
                next
            }
        };
        if let Some(ref domain) = doc.domain {
            self.domain_index
                .entry(domain.clone())
                .or_insert_with(Vec::new)
                .push(idx);
        }
        if idx < self.equations.len() {
            self.equations[idx] = doc;
        } else {
            self.equations.push(doc);
        }
        Ok(())
    }

    /// Removes position `idx` from the domain list of the document currently stored there,
    /// and drops the domain entry once no documents remain in it.
    fn unlink_domain(&mut self, idx: usize) {
        let domain = match self.equations[idx].domain {
            Some(ref domain) => domain,
            None => return,
        };
        let now_empty = match self.domain_index.get_mut(domain) {
            Some(positions) => {
                positions.retain(|&p| p != idx);
                positions.is_empty()
            }
            None => false,
        };
        if now_empty {
            self.domain_index.remove(domain);
        }
    }

    /// Adds every document in `docs` on a best-effort basis.
    ///
    /// Documents rejected by [`EquationRagIndex::add`] are skipped silently. The result is
    /// always `Ok` with the number of documents accepted.
    pub fn add_batch(&mut self, docs: Vec<EquationDocument>) -> Result<usize, &'static str> {
        let mut added = 0;
        for doc in docs {
            if self.add(doc).is_ok() {
                added += 1;
            }
        }
        Ok(added)
    }

    /// Returns up to `top_k` documents ranked by cosine similarity to `query_embedding`.
    ///
    /// Hits below `min_similarity` are excluded. The result is empty when the query length
    /// differs from the index dimension.
    pub fn retrieve(&self, query_embedding: &[f32]) -> Vec<RetrievalResult> {
        self.retrieve_with_filter(query_embedding, None)
    }

    /// Like [`EquationRagIndex::retrieve`], but only considers documents tagged with `domain`.
    pub fn retrieve_in_domain(
        &self,
        query_embedding: &[f32],
        domain: &str,
    ) -> Vec<RetrievalResult> {
        self.retrieve_with_filter(query_embedding, Some(domain))
    }

    /// Shared search: scores candidates (all documents, or one domain's), sorts by descending
    /// similarity, truncates to `top_k`, and assigns 0-based ranks.
    fn retrieve_with_filter(
        &self,
        query_embedding: &[f32],
        domain: Option<&str>,
    ) -> Vec<RetrievalResult> {
        if query_embedding.len() != self.dimension {
            return Vec::new();
        }
        let mut scored = match domain {
            Some(domain) => match self.domain_index.get(domain) {
                Some(positions) => {
                    self.score_candidates(query_embedding, positions.iter().copied())
                }
                None => Vec::new(),
            },
            None => self.score_candidates(query_embedding, 0..self.equations.len()),
        };
        // Stable sort: equal similarities keep candidate order, so ties rank deterministically.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(self.config.top_k);
        scored
            .into_iter()
            .enumerate()
            .map(|(rank, (idx, similarity))| {
                RetrievalResult::new(self.equations[idx].clone(), similarity, rank)
            })
            .collect()
    }

    /// Scores each candidate position against the query and keeps those whose similarity is
    /// at least `min_similarity`. `NaN` scores fail the comparison and are dropped.
    fn score_candidates<I>(&self, query_embedding: &[f32], candidates: I) -> Vec<(usize, f32)>
    where
        I: IntoIterator<Item = usize>,
    {
        candidates
            .into_iter()
            .map(|idx| {
                let similarity =
                    cosine_similarity(query_embedding, &self.equations[idx].embedding);
                (idx, similarity)
            })
            .filter(|&(_, similarity)| similarity >= self.config.min_similarity)
            .collect()
    }

    /// Looks up a document by id.
    pub fn get(&self, id: &str) -> Option<&EquationDocument> {
        self.id_index.get(id).map(|&idx| &self.equations[idx])
    }

    /// Number of stored documents.
    pub fn len(&self) -> usize {
        self.equations.len()
    }

    /// `true` when no documents are stored.
    pub fn is_empty(&self) -> bool {
        self.equations.is_empty()
    }

    /// Names of all domains that have at least one document, sorted so the order is stable.
    pub fn domains(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.domain_index.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Number of documents tagged with `domain` (0 for unknown domains).
    pub fn domain_count(&self, domain: &str) -> usize {
        self.domain_index.get(domain).map(|v| v.len()).unwrap_or(0)
    }

    /// Required embedding length.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Active retrieval configuration.
    pub fn config(&self) -> &RetrievalConfig {
        &self.config
    }

    /// Removes every document; the configuration and dimension are kept.
    pub fn clear(&mut self) {
        self.equations.clear();
        self.id_index.clear();
        self.domain_index.clear();
    }
}
/// Caching front-end for an [`EquationRagIndex`].
///
/// Unfiltered retrievals are memoised under a key derived from the query embedding by
/// `embedding_to_cache_key`. When the cache is full and a new query arrives, the whole cache
/// is dropped, which is a deliberately simple eviction policy.
pub struct EquationRetriever {
    /// The wrapped index.
    index: EquationRagIndex,
    /// Query key -> results previously computed for that query.
    cache: HashMap<Vec<u8>, Vec<RetrievalResult>>,
    /// Number of cached queries that triggers a full flush before the next insertion.
    max_cache_size: usize,
}

impl EquationRetriever {
    /// Cache capacity used by [`EquationRetriever::new`].
    pub const DEFAULT_MAX_CACHE_SIZE: usize = 1000;

    /// Wraps `index` with a cache of [`EquationRetriever::DEFAULT_MAX_CACHE_SIZE`] queries.
    pub fn new(index: EquationRagIndex) -> Self {
        Self::with_cache_size(index, Self::DEFAULT_MAX_CACHE_SIZE)
    }

    /// Wraps `index` with a cache that is flushed once it holds `max_cache_size` queries.
    pub fn with_cache_size(index: EquationRagIndex, max_cache_size: usize) -> Self {
        Self {
            index,
            cache: HashMap::new(),
            max_cache_size,
        }
    }

    /// Returns the (possibly cached) unfiltered results for `query_embedding`.
    pub fn retrieve(&mut self, query_embedding: &[f32]) -> &[RetrievalResult] {
        let key = embedding_to_cache_key(query_embedding);
        // Only flush when a *new* entry is about to be inserted; a hit must not wipe the cache.
        if self.cache.len() >= self.max_cache_size && !self.cache.contains_key(&key) {
            self.cache.clear();
        }
        // Borrow the index separately so the closure does not capture all of `self`.
        let index = &self.index;
        self.cache
            .entry(key)
            .or_insert_with(|| index.retrieve(query_embedding))
            .as_slice()
    }

    /// Domain-filtered retrieval; bypasses the cache.
    pub fn retrieve_in_domain(
        &self,
        query_embedding: &[f32],
        domain: &str,
    ) -> Vec<RetrievalResult> {
        self.index.retrieve_in_domain(query_embedding, domain)
    }

    /// Read-only access to the wrapped index.
    pub fn index(&self) -> &EquationRagIndex {
        &self.index
    }

    /// Mutable access to the wrapped index.
    ///
    /// The cache is cleared first, because any mutation may invalidate cached results.
    pub fn index_mut(&mut self) -> &mut EquationRagIndex {
        self.cache.clear();
        &mut self.index
    }

    /// Drops every cached result.
    pub fn clear_cache(&mut self) {
        self.cache.clear();
    }
}
/// Cosine similarity of two embeddings, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has zero norm.
/// Non-finite components produce `0.0` or `NaN` rather than a meaningful score.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Accumulate in f64. Squared f32 norms, and especially their product, overflow to inf
    // (or underflow to 0) long before the inputs themselves do, which would turn parallel
    // large- or small-magnitude vectors into a similarity of 0.
    let mut dot = 0.0f64;
    let mut norm_a = 0.0f64;
    let mut norm_b = 0.0f64;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }
    // Take the roots separately so the denominator stays in range.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom > 0.0 {
        // Rounding can push the ratio a hair past ±1; keep it a valid cosine.
        ((dot / denom) as f32).clamp(-1.0, 1.0)
    } else {
        0.0
    }
}
/// Builds an exact byte key for a query embedding, for use as a cache key.
///
/// Each component contributes the four little-endian bytes of its IEEE-754 bit pattern, so
/// two keys are equal only when the embeddings are bit-for-bit identical. A lossy key (such
/// as quantising to `i8`) lets nearby but different queries share a cache entry and receive
/// each other's results. Numerically equal values with different bits (`0.0` vs `-0.0`, or
/// distinct NaN payloads) get different keys, which only costs a cache miss.
fn embedding_to_cache_key(embedding: &[f32]) -> Vec<u8> {
    let mut key = Vec::with_capacity(embedding.len() * std::mem::size_of::<u32>());
    for value in embedding {
        key.extend_from_slice(&value.to_bits().to_le_bytes());
    }
    key
}
/// Fluent builder for an [`EquationRagIndex`].
pub struct EquationRagIndexBuilder {
    /// Embedding length of the index being built.
    dimension: usize,
    /// Configuration handed to the index.
    config: RetrievalConfig,
    /// Documents queued for insertion, in order.
    equations: Vec<EquationDocument>,
}

impl EquationRagIndexBuilder {
    /// Starts a builder for embeddings of length `dimension` with the default config.
    pub fn new(dimension: usize) -> Self {
        Self {
            dimension,
            config: RetrievalConfig::default(),
            equations: Vec::new(),
        }
    }

    /// Replaces the whole configuration, discarding earlier `top_k`/`min_similarity` calls.
    pub fn with_config(mut self, config: RetrievalConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the maximum number of hits per query.
    pub fn top_k(mut self, k: usize) -> Self {
        self.config.top_k = k;
        self
    }

    /// Sets the minimum similarity a hit must reach.
    pub fn min_similarity(mut self, threshold: f32) -> Self {
        self.config.min_similarity = threshold;
        self
    }

    /// Queues one document for insertion.
    pub fn add_equation(mut self, doc: EquationDocument) -> Self {
        self.equations.push(doc);
        self
    }

    /// Queues several documents for insertion, preserving their order.
    pub fn add_equations(mut self, docs: Vec<EquationDocument>) -> Self {
        self.equations.extend(docs);
        self
    }

    /// Builds the index, inserting queued documents in order.
    ///
    /// # Errors
    /// Returns the first error reported by `EquationRagIndex::add`, for example an embedding
    /// whose length differs from the builder's dimension, or exceeding `max_index_size`.
    pub fn build(self) -> Result<EquationRagIndex, &'static str> {
        let mut index = EquationRagIndex::with_config(self.dimension, self.config);
        for doc in self.equations {
            // Fail loudly instead of silently dropping a document the caller explicitly
            // queued; `add_batch` would swallow the error and still return `Ok`.
            index.add(doc)?;
        }
        Ok(index)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, unit-normalised test embedding derived from `seed`.
    fn create_test_embedding(seed: u32, dim: usize) -> Vec<f32> {
        let mut embedding = vec![0.0f32; dim];
        for (i, v) in embedding.iter_mut().enumerate() {
            *v = ((seed as f32 * 0.1 + i as f32 * 0.01) % 1.0) * 2.0 - 1.0;
        }
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        for v in &mut embedding {
            *v /= norm;
        }
        embedding
    }

    /// Shorthand for a document with no optional fields set.
    fn doc(id: &str, latex: &str, embedding: Vec<f32>) -> EquationDocument {
        EquationDocument::new(id.to_string(), latex.to_string(), embedding)
    }

    #[test]
    fn test_add_and_retrieve() {
        let mut index = EquationRagIndex::new(8);
        let doc1 = EquationDocument::new(
            "eq1".to_string(),
            "x^2".to_string(),
            create_test_embedding(1, 8),
        );
        let doc2 = EquationDocument::new(
            "eq2".to_string(),
            "y^2".to_string(),
            create_test_embedding(2, 8),
        );
        index.add(doc1).unwrap();
        index.add(doc2).unwrap();
        assert_eq!(index.len(), 2);
        let query = create_test_embedding(1, 8);
        let results = index.retrieve(&query);
        assert!(!results.is_empty());
        assert_eq!(results[0].document.id, "eq1");
    }

    #[test]
    fn test_domain_filter() {
        let mut index = EquationRagIndex::new(8);
        let doc1 = EquationDocument::new(
            "eq1".to_string(),
            "x^2".to_string(),
            create_test_embedding(1, 8),
        )
        .with_domain("algebra".to_string());
        let doc2 = EquationDocument::new(
            "eq2".to_string(),
            "\\int x dx".to_string(),
            create_test_embedding(2, 8),
        )
        .with_domain("calculus".to_string());
        index.add(doc1).unwrap();
        index.add(doc2).unwrap();
        assert_eq!(index.domain_count("algebra"), 1);
        assert_eq!(index.domain_count("calculus"), 1);
        let query = create_test_embedding(1, 8);
        let results = index.retrieve_in_domain(&query, "algebra");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].document.domain, Some("algebra".to_string()));
    }

    #[test]
    fn test_similarity_threshold() {
        let config = RetrievalConfig {
            min_similarity: 0.99,
            ..Default::default()
        };
        let mut index = EquationRagIndex::with_config(8, config);
        let mut doc_emb = vec![0.0f32; 8];
        doc_emb[0] = 1.0;
        let doc = EquationDocument::new("eq1".to_string(), "x^2".to_string(), doc_emb);
        index.add(doc).unwrap();
        let mut query = vec![0.0f32; 8];
        query[1] = 1.0;
        let results = index.retrieve(&query);
        assert!(results.is_empty());
    }

    #[test]
    fn test_builder() {
        let index = EquationRagIndexBuilder::new(8)
            .top_k(3)
            .min_similarity(0.5)
            .add_equation(EquationDocument::new(
                "eq1".to_string(),
                "x^2".to_string(),
                create_test_embedding(1, 8),
            ))
            .build()
            .unwrap();
        assert_eq!(index.len(), 1);
        assert_eq!(index.config().top_k, 3);
    }

    #[test]
    fn test_retriever_cache() {
        let mut index = EquationRagIndex::new(8);
        index
            .add(EquationDocument::new(
                "eq1".to_string(),
                "x^2".to_string(),
                create_test_embedding(1, 8),
            ))
            .unwrap();
        let mut retriever = EquationRetriever::new(index);
        let query = create_test_embedding(1, 8);
        let len1 = retriever.retrieve(&query).len();
        assert!(len1 > 0);
        let len2 = retriever.retrieve(&query).len();
        assert_eq!(len1, len2);
        retriever.clear_cache();
        // After clearing, the query is recomputed from the index and must agree.
        assert_eq!(retriever.retrieve(&query).len(), len1);
    }

    #[test]
    fn test_index_mut_invalidates_cache() {
        let mut index = EquationRagIndex::new(8);
        index.add(doc("eq1", "x^2", create_test_embedding(1, 8))).unwrap();
        let mut retriever = EquationRetriever::new(index);
        let query = create_test_embedding(1, 8);
        assert_eq!(retriever.retrieve(&query).len(), 1);
        retriever
            .index_mut()
            .add(doc("eq2", "y^2", create_test_embedding(2, 8)))
            .unwrap();
        assert_eq!(retriever.retrieve(&query).len(), 2);
    }

    #[test]
    fn test_add_rejects_dimension_mismatch() {
        let mut index = EquationRagIndex::new(8);
        assert_eq!(
            index.add(doc("bad", "x", vec![1.0; 3])),
            Err("Embedding dimension mismatch")
        );
        assert!(index.is_empty());
        assert!(index.retrieve(&[1.0; 3]).is_empty());
    }

    #[test]
    fn test_size_limit() {
        let config = RetrievalConfig {
            max_index_size: 1,
            ..Default::default()
        };
        let mut index = EquationRagIndex::with_config(8, config);
        index.add(doc("eq1", "x^2", create_test_embedding(1, 8))).unwrap();
        assert_eq!(
            index.add(doc("eq2", "y^2", create_test_embedding(2, 8))),
            Err("Index size limit reached")
        );
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_top_k_and_ranks() {
        let config = RetrievalConfig {
            top_k: 2,
            min_similarity: -2.0,
            ..Default::default()
        };
        let mut index = EquationRagIndex::with_config(8, config);
        for seed in 1..=3 {
            index
                .add(doc(&format!("eq{}", seed), "x", create_test_embedding(seed, 8)))
                .unwrap();
        }
        let results = index.retrieve(&create_test_embedding(1, 8));
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].rank, 0);
        assert_eq!(results[1].rank, 1);
        assert!(results[0].similarity >= results[1].similarity);
        assert_eq!(results[0].document.id, "eq1");
    }

    #[test]
    fn test_get_and_clear() {
        let mut index = EquationRagIndex::new(8);
        index
            .add(doc("eq1", "x^2", create_test_embedding(1, 8)).with_domain("algebra".to_string()))
            .unwrap();
        assert_eq!(index.get("eq1").map(|d| d.latex.as_str()), Some("x^2"));
        assert!(index.get("missing").is_none());
        assert_eq!(index.domains(), vec!["algebra"]);
        index.clear();
        assert!(index.is_empty());
        assert!(index.get("eq1").is_none());
        assert!(index.domains().is_empty());
        assert_eq!(index.domain_count("algebra"), 0);
        assert_eq!(index.dimension(), 8);
    }

    #[test]
    fn test_cosine_similarity_basics() {
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[1.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
        assert!((cosine_similarity(&[1.0, 2.0], &[2.0, 4.0]) - 1.0).abs() < 1e-6);
        assert!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
    }
}