use std::collections::HashMap;
use crate::index::IndexResult;
use super::SparseVector;
/// Inverted index over sparse (token id, weight) vectors, searched by dot product.
///
/// `Default` produces an empty index, equivalent to building from no chunks.
#[derive(Debug, Clone, Default)]
pub struct SpladeIndex {
    /// Token id -> postings of `(chunk index, weight)`; the chunk index points into `id_map`.
    postings: HashMap<u32, Vec<(usize, f32)>>,
    /// Chunk ids in build order; a chunk's position here is its internal index.
    id_map: Vec<String>,
}
impl SpladeIndex {
    /// Builds an inverted index from `(chunk_id, sparse_vector)` pairs.
    ///
    /// Each chunk's internal index is its position in `chunks`. Every `(token_id, weight)`
    /// entry of its sparse vector is appended to that token's posting list. Repeated token
    /// ids within one vector produce repeated postings, so their weights all contribute to
    /// the score at query time.
    pub fn build(chunks: Vec<(String, SparseVector)>) -> Self {
        let _span = tracing::info_span!("splade_index_build", chunks = chunks.len()).entered();
        let mut postings: HashMap<u32, Vec<(usize, f32)>> = HashMap::new();
        let mut id_map = Vec::with_capacity(chunks.len());
        for (idx, (chunk_id, sparse)) in chunks.into_iter().enumerate() {
            for &(token_id, weight) in &sparse {
                postings.entry(token_id).or_default().push((idx, weight));
            }
            id_map.push(chunk_id);
        }
        tracing::info!(
            unique_tokens = postings.len(),
            chunks = id_map.len(),
            "SPLADE index built"
        );
        Self { postings, id_map }
    }

    /// Returns the top `k` chunks by dot-product score against `query`.
    ///
    /// Equivalent to [`Self::search_with_filter`] with a filter that accepts every chunk.
    pub fn search(&self, query: &SparseVector, k: usize) -> Vec<IndexResult> {
        self.search_with_filter(query, k, &|_: &str| true)
    }

    /// Returns at most `k` chunks accepted by `filter`, ranked by dot-product score.
    ///
    /// A chunk's score is the sum of `query_weight * doc_weight` over every posting it
    /// shares with `query`. Chunks sharing no token with the query are never returned.
    /// Results are ordered by descending score. Ties are broken by build order, so the
    /// output is deterministic. NaN scores rank below every other score.
    ///
    /// `filter` receives the chunk id and is invoked at most once per candidate chunk,
    /// so it is expected to be pure.
    pub fn search_with_filter(
        &self,
        query: &SparseVector,
        k: usize,
        filter: &dyn Fn(&str) -> bool,
    ) -> Vec<IndexResult> {
        let _span = tracing::debug_span!(
            "splade_index_search",
            k,
            query_terms = query.len(),
            index_size = self.id_map.len()
        )
        .entered();
        if k == 0 || query.is_empty() || self.id_map.is_empty() {
            return Vec::new();
        }

        // `None` marks a chunk rejected by `filter`. Caching the verdict means the filter
        // runs once per chunk instead of once per matching posting.
        let mut scores: HashMap<usize, Option<f32>> = HashMap::new();
        for &(token_id, query_weight) in query {
            if let Some(posting_list) = self.postings.get(&token_id) {
                for &(chunk_idx, doc_weight) in posting_list {
                    if let Some(chunk_id) = self.id_map.get(chunk_idx) {
                        let slot = scores
                            .entry(chunk_idx)
                            .or_insert_with(|| filter(chunk_id).then_some(0.0));
                        if let Some(score) = slot {
                            *score += query_weight * doc_weight;
                        }
                    }
                }
            }
        }

        let mut ranked: Vec<(usize, f32)> = scores
            .into_iter()
            .filter_map(|(idx, score)| score.map(|s| (idx, s)))
            .collect();

        // A strict total order: descending score (NaN treated as lowest), then ascending
        // chunk index. A non-total comparator may make `sort_by` panic on newer Rust.
        let by_rank = |a: &(usize, f32), b: &(usize, f32)| {
            Self::rank_score(b.1)
                .total_cmp(&Self::rank_score(a.1))
                .then_with(|| a.0.cmp(&b.0))
        };
        if ranked.len() > k {
            // Move the best `k` to the front in O(n), then drop the rest before sorting.
            ranked.select_nth_unstable_by(k, by_rank);
            ranked.truncate(k);
        }
        ranked.sort_unstable_by(by_rank);

        // Every index in `ranked` was validated via `self.id_map.get` above.
        let results: Vec<IndexResult> = ranked
            .into_iter()
            .map(|(idx, score)| IndexResult {
                id: self.id_map[idx].clone(),
                score,
            })
            .collect();
        tracing::debug!(results = results.len(), "SPLADE search complete");
        results
    }

    /// Maps NaN to negative infinity so a corrupted score ranks last instead of first
    /// under `f32::total_cmp`.
    fn rank_score(score: f32) -> f32 {
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score
        }
    }

    /// Number of chunks in the index.
    pub fn len(&self) -> usize {
        self.id_map.len()
    }

    /// Returns `true` when the index holds no chunks.
    pub fn is_empty(&self) -> bool {
        self.id_map.is_empty()
    }

    /// Number of distinct token ids that have at least one posting.
    pub fn unique_tokens(&self) -> usize {
        self.postings.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a three-chunk fixture.
    ///
    /// Postings: token 1 -> a 0.5, b 0.7; token 2 -> a 0.3, c 0.9; token 3 -> a 0.8, c 0.1;
    /// token 4 -> b 0.6; token 5 -> c 0.4.
    fn make_test_index() -> SpladeIndex {
        SpladeIndex::build(vec![
            ("chunk_a".to_string(), vec![(1, 0.5), (2, 0.3), (3, 0.8)]),
            ("chunk_b".to_string(), vec![(1, 0.7), (4, 0.6)]),
            ("chunk_c".to_string(), vec![(2, 0.9), (3, 0.1), (5, 0.4)]),
        ])
    }

    /// Extracts result ids in ranked order for compact assertions.
    fn ids(results: &[IndexResult]) -> Vec<&str> {
        results.iter().map(|r| r.id.as_str()).collect()
    }

    /// Asserts two scores agree within float tolerance.
    fn assert_score(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-5,
            "score {} != expected {}",
            actual,
            expected
        );
    }

    #[test]
    fn test_build_empty() {
        let index = SpladeIndex::build(vec![]);
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
        assert_eq!(index.unique_tokens(), 0);
    }

    #[test]
    fn test_build_and_search() {
        let index = make_test_index();
        assert_eq!(index.len(), 3);
        assert_eq!(index.unique_tokens(), 5);
        let results = index.search(&vec![(1, 1.0)], 10);
        // Only chunk_a and chunk_b contain token 1.
        assert_eq!(ids(&results), ["chunk_b", "chunk_a"]);
        assert_score(results[0].score, 0.7);
        assert_score(results[1].score, 0.5);
    }

    #[test]
    fn test_dot_product_correct() {
        let index = make_test_index();
        let results = index.search(&vec![(1, 1.0), (2, 1.0)], 10);
        assert_eq!(ids(&results), ["chunk_c", "chunk_a", "chunk_b"]);
        assert_score(results[0].score, 0.9);
        assert_score(results[1].score, 0.8);
        assert_score(results[2].score, 0.7);
    }

    #[test]
    fn test_search_filter() {
        let index = make_test_index();
        let results = index.search_with_filter(&vec![(1, 1.0)], 10, &|id: &str| id == "chunk_a");
        assert_eq!(ids(&results), ["chunk_a"]);
        assert_score(results[0].score, 0.5);
    }

    #[test]
    fn test_search_filter_rejects_all() {
        let index = make_test_index();
        let results = index.search_with_filter(&vec![(1, 1.0), (2, 1.0)], 10, &|_: &str| false);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_no_match() {
        let index = make_test_index();
        let results = index.search(&vec![(999, 1.0)], 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_empty_query() {
        let index = make_test_index();
        let results = index.search(&vec![], 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_empty_index() {
        let index = SpladeIndex::build(vec![]);
        let results = index.search(&vec![(1, 1.0)], 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_respects_k() {
        let index = make_test_index();
        // Scores: a = 0.5 + 0.3 + 0.8 = 1.6, c = 0.9 + 0.1 = 1.0, b = 0.7.
        let results = index.search(&vec![(1, 1.0), (2, 1.0), (3, 1.0)], 2);
        assert_eq!(ids(&results), ["chunk_a", "chunk_c"]);
        assert_score(results[0].score, 1.6);
        assert_score(results[1].score, 1.0);
    }

    #[test]
    fn test_search_k_zero() {
        let index = make_test_index();
        let results = index.search(&vec![(1, 1.0)], 0);
        assert!(results.is_empty());
    }
}