mod model;
pub use model::LemurEncoder;
use crate::RetrieveError;
/// In-memory multi-vector index searched in two stages: a dot-product scan
/// over one pooled vector per document, followed by a MaxSim rerank over the
/// documents' raw token vectors.
///
/// `doc_weights`, `doc_tokens` and `doc_ids` are parallel vectors: entry `i`
/// of each describes the same document, and all three are appended together
/// when a document is added.
pub struct LemurIndex {
/// Required length of every raw token vector; read from the encoder at construction.
input_dim: usize,
/// Length of the pooled vectors; read from the encoder at construction.
hidden_dim: usize,
/// Encoder applied to each raw token before pooling.
encoder: LemurEncoder,
/// Per-document mean of the encoded token vectors (first-stage scoring).
doc_weights: Vec<Vec<f32>>,
/// Per-document raw (un-encoded) token vectors, kept for the rerank stage.
doc_tokens: Vec<Vec<Vec<f32>>>,
/// Caller-supplied identifier of each document.
doc_ids: Vec<u32>,
}
impl LemurIndex {
/// Creates an empty index that owns `encoder`.
///
/// The encoder's input and hidden dimensions are read once here and cached
/// on the index.
pub fn new(encoder: LemurEncoder) -> Self {
    // Query the dimensions before `encoder` is moved into the struct.
    let (input_dim, hidden_dim) = (encoder.input_dim(), encoder.hidden_dim());
    Self {
        encoder,
        input_dim,
        hidden_dim,
        doc_ids: Vec::new(),
        doc_tokens: Vec::new(),
        doc_weights: Vec::new(),
    }
}
/// Adds a document, given as a list of raw token vectors, to the index.
///
/// Every token is run through the encoder and the encoded vectors are
/// mean-pooled into a single weight vector used for first-stage scoring in
/// [`Self::search`]. The raw `tokens` are stored unchanged for reranking.
/// `doc_id` is stored as given; uniqueness is not checked.
///
/// # Errors
///
/// * [`RetrieveError::InvalidParameter`] if `tokens` is empty.
/// * [`RetrieveError::DimensionMismatch`] if any token's length differs from
///   the encoder's input dimension. `query_dim` carries the offending length
///   and `doc_dim` the expected one.
pub fn add_document(
    &mut self,
    doc_id: u32,
    tokens: Vec<Vec<f32>>,
) -> Result<(), RetrieveError> {
    if tokens.is_empty() {
        return Err(RetrieveError::InvalidParameter(
            "document must have at least one token".into(),
        ));
    }
    // Validate every token, not only the first: a ragged document would
    // otherwise be handed to the encoder and stored for reranking, where
    // the zipped dot product silently ignores the length difference.
    if let Some(bad) = tokens.iter().find(|t| t.len() != self.input_dim) {
        return Err(RetrieveError::DimensionMismatch {
            query_dim: bad.len(),
            doc_dim: self.input_dim,
        });
    }

    // Mean-pool the encoded tokens, accumulating in token order without
    // materialising the intermediate list of encodings.
    let n = tokens.len() as f32;
    let mut weight = vec![0.0f32; self.hidden_dim];
    for token in &tokens {
        let enc = self.encoder.forward(token);
        for (w, e) in weight.iter_mut().zip(enc.iter()) {
            *w += e;
        }
    }
    for w in &mut weight {
        *w /= n;
    }

    // The three vectors are parallel; always push to all of them together.
    self.doc_weights.push(weight);
    self.doc_tokens.push(tokens);
    self.doc_ids.push(doc_id);
    Ok(())
}
/// Returns up to `k` `(doc_id, score)` pairs for a multi-token query,
/// ordered by descending score.
///
/// Retrieval runs in two stages:
///
/// 1. Every document is scored by the dot product of the pooled query
///    vector with its stored weight vector, and the best
///    `max(rerank_pool, k)` documents are kept as candidates.
/// 2. Each candidate is re-scored with MaxSim over the raw token vectors;
///    the best `k` are returned with their MaxSim score.
///
/// # Errors
///
/// * [`RetrieveError::EmptyQuery`] if `query_tokens` is empty.
/// * [`RetrieveError::DimensionMismatch`] if any query token's length
///   differs from the encoder's input dimension.
pub fn search(
    &self,
    query_tokens: &[Vec<f32>],
    k: usize,
    rerank_pool: usize,
) -> Result<Vec<(u32, f32)>, RetrieveError> {
    if query_tokens.is_empty() {
        return Err(RetrieveError::EmptyQuery);
    }
    // Mirror the check done when documents are added: a wrong-sized query
    // token would otherwise reach the encoder and be silently truncated by
    // the zipped dot product during reranking.
    if let Some(bad) = query_tokens.iter().find(|t| t.len() != self.input_dim) {
        return Err(RetrieveError::DimensionMismatch {
            query_dim: bad.len(),
            doc_dim: self.input_dim,
        });
    }
    if k == 0 {
        // Nothing can be returned, so skip encoding and scoring entirely.
        return Ok(Vec::new());
    }

    let query_vec = self.encode_query(query_tokens);
    let pool = rerank_pool.max(k);

    let mut scores: Vec<(usize, f32)> = self
        .doc_weights
        .iter()
        .enumerate()
        .map(|(i, w)| (i, dot(&query_vec, w)))
        .collect();
    // Only the *set* of top-`pool` candidates matters here, because they are
    // re-scored and re-sorted below. A partial selection therefore replaces
    // a full sort. `select_nth_unstable_by` panics when the index is out of
    // range, hence the guard.
    if pool < scores.len() {
        scores.select_nth_unstable_by(pool, |a, b| b.1.total_cmp(&a.1));
        scores.truncate(pool);
    }

    let mut reranked: Vec<(u32, f32)> = scores
        .into_iter()
        .map(|(idx, _first_stage_score)| {
            let maxsim = self.maxsim(query_tokens, &self.doc_tokens[idx]);
            (self.doc_ids[idx], maxsim)
        })
        .collect();
    reranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
    reranked.truncate(k);
    Ok(reranked)
}
pub fn len(&self) -> usize {
self.doc_ids.len()
}
pub fn is_empty(&self) -> bool {
self.doc_ids.is_empty()
}
/// Encodes each query token and sum-pools the results into one vector of
/// length `hidden_dim`.
///
/// Unlike document pooling, the sum is not divided by the token count.
fn encode_query(&self, tokens: &[Vec<f32>]) -> Vec<f32> {
    tokens
        .iter()
        .fold(vec![0.0f32; self.hidden_dim], |mut acc, tok| {
            let encoded = self.encoder.forward(tok);
            acc.iter_mut()
                .zip(encoded.iter())
                .for_each(|(slot, value)| *slot += value);
            acc
        })
}
/// MaxSim score between a query and a document, both given as raw token
/// vectors.
///
/// For each query token, takes the largest dot product against any document
/// token, then sums those maxima over all query tokens. A query token with
/// no document token to compare against contributes `f32::NEG_INFINITY`.
fn maxsim(&self, query_tokens: &[Vec<f32>], doc_tokens: &[Vec<f32>]) -> f32 {
    query_tokens
        .iter()
        .map(|query_tok| {
            doc_tokens
                .iter()
                .map(|doc_tok| dot(query_tok, doc_tok))
                // Strict `>` keeps the first maximum and skips NaN similarities.
                .fold(f32::NEG_INFINITY, |best, sim| if sim > best { sim } else { best })
        })
        .fold(0.0f32, |sum, best| sum + best)
}
}
/// Dot product of `a` and `b`.
///
/// The slices are zipped, so when their lengths differ only the first
/// `min(a.len(), b.len())` element pairs contribute.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    std::iter::zip(a, b).map(|(lhs, rhs)| lhs * rhs).sum::<f32>()
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Fresh, empty index over the encoder configuration shared by all tests.
    fn empty_index() -> LemurIndex {
        LemurIndex::new(LemurEncoder::random(32, 64, 42))
    }

    /// Indexes ten five-token documents and checks that a top-5 search
    /// returns exactly five hits.
    #[test]
    fn add_and_search() {
        let mut index = empty_index();

        for doc_id in 0..10u32 {
            let mut tokens: Vec<Vec<f32>> = Vec::with_capacity(5);
            for t in 0..5 {
                let seed = (doc_id as f32 * 0.1 + t as f32 * 0.01) * 0.1;
                tokens.push(vec![seed; 32]);
            }
            index.add_document(doc_id, tokens).unwrap();
        }
        assert_eq!(index.len(), 10);

        let mut query_tokens: Vec<Vec<f32>> = Vec::with_capacity(5);
        for t in 0..5 {
            query_tokens.push(vec![t as f32 * 0.01; 32]);
        }

        let results = index.search(&query_tokens, 5, 10).unwrap();
        assert_eq!(results.len(), 5);
    }

    /// A query with no tokens must be rejected rather than searched.
    #[test]
    fn empty_query_returns_error() {
        let index = empty_index();
        let no_tokens: Vec<Vec<f32>> = Vec::new();
        assert!(index.search(&no_tokens, 5, 10).is_err());
    }
}