use serde::{Deserialize, Serialize};
use std::collections::BinaryHeap;
use crate::core::types::RetrievalResult;
/// Inverted index mapping each sparse dimension to the ids of the documents
/// that were added with that dimension active.
///
/// NOTE(review): serde's derive serializes fields in declaration order; keep
/// the field order stable so previously stored indexes still deserialize.
#[derive(Default, Serialize, Deserialize)]
pub struct SparseInvertedIndex {
    /// `postings[d]` holds the ids of the documents added with dimension `d`.
    /// Lists are appended to by `add_doc` and sorted by `finalize`.
    pub postings: Vec<Vec<u32>>,
    /// Scoring parameter: `query` scores a document as `inter / (2*m - inter)`.
    /// Presumably the number of active dimensions per document — TODO confirm.
    pub m: usize,
    /// Number of dimensions accepted by `add_doc`; dimension ids `>=` this are ignored.
    pub dimensions: usize,
    /// Number of `add_doc` calls made so far.
    pub doc_count: usize,
    /// `df[d]` is incremented once for every in-range occurrence of `d`
    /// passed to `add_doc` (document frequency of dimension `d`).
    pub df: Vec<u32>,
}
impl SparseInvertedIndex {
    /// Creates an empty index over `dimensions` sparse dimensions.
    ///
    /// `m` enters the similarity denominator `2*m - inter` in
    /// [`SparseInvertedIndex::query`]; it presumably equals the number of
    /// active dimensions per document (and per query) — TODO confirm.
    pub fn new(dimensions: usize, m: usize) -> Self {
        Self {
            postings: vec![Vec::new(); dimensions],
            m,
            dimensions,
            doc_count: 0,
            df: vec![0; dimensions],
        }
    }

    /// Appends `doc_id` to the posting list of every dimension in `dims`.
    ///
    /// Dimensions `>= self.dimensions` are silently ignored. Posting lists are
    /// not kept sorted here; call [`SparseInvertedIndex::finalize`] after the
    /// last `add_doc` and before querying.
    pub fn add_doc(&mut self, doc_id: u32, dims: &[u32]) {
        for &dim in dims {
            let d = dim as usize;
            if d < self.dimensions {
                self.postings[d].push(doc_id);
                self.df[d] += 1;
            }
        }
        self.doc_count += 1;
    }

    /// Sorts every posting list in ascending doc-id order.
    ///
    /// Must be called after the last `add_doc` and before `query`. Debug builds
    /// assert this in `query`.
    pub fn finalize(&mut self) {
        for list in &mut self.postings {
            list.sort_unstable();
        }
    }

    /// Scores documents sharing at least one dimension with `query_dims` and
    /// returns at most `k` of them.
    ///
    /// For each touched document, `inter` is the number of *distinct* query
    /// dimensions whose posting list contains it. The score is
    /// `inter / (2*m - inter)`, i.e. Jaccard similarity of two sets that each
    /// have exactly `m` elements. `2*m` is used regardless of
    /// `query_dims.len()`. If the denominator is not positive the score is `1.0`.
    ///
    /// A heap keeps the `k` smallest results under `RetrievalResult`'s `Ord`, and
    /// the result is returned in ascending `Ord` order (`into_sorted_vec`).
    /// Presumably that `Ord` ranks higher similarity first — confirm against
    /// `RetrievalResult`.
    ///
    /// `accumulator` is reset with `next_epoch` and used as scratch space.
    /// Returns an empty vector if `k == 0`, `query_dims` is empty, or no
    /// documents were added.
    pub fn query(
        &self,
        query_dims: &[u32],
        k: usize,
        accumulator: &mut Accumulator,
    ) -> Vec<RetrievalResult> {
        debug_assert!(
            self.postings
                .iter()
                .all(|list| list.windows(2).all(|w| w[0] <= w[1])),
            "posting lists must be sorted before query"
        );
        if k == 0 || query_dims.is_empty() || self.doc_count == 0 {
            return Vec::new();
        }

        // Visit the rarest dimensions first (ties broken by dimension id, so equal
        // ids end up adjacent). Then drop repeats, so a dimension listed twice in
        // the query cannot inflate `inter` past the number of distinct shared
        // dimensions (which could push the score above 1).
        let mut dims: Vec<u32> = query_dims.to_vec();
        dims.sort_unstable_by_key(|&d| {
            (self.df.get(d as usize).copied().unwrap_or(u32::MAX), d)
        });
        dims.dedup();

        accumulator.next_epoch();
        for &dim in &dims {
            // Out-of-range dimensions have no posting list and are skipped.
            if let Some(list) = self.postings.get(dim as usize) {
                for &doc_id in list {
                    accumulator.increment(doc_id);
                }
            }
        }

        let two_m = 2.0 * self.m as f64;
        // The heap never holds more than min(k, hits) + 1 items; avoid
        // over-allocating (or overflowing `k + 1`) for very large `k`.
        let mut heap = BinaryHeap::with_capacity(k.min(accumulator.touched.len()) + 1);
        for &doc_id in &accumulator.touched {
            let inter = f64::from(accumulator.get_count(doc_id));
            let denom = two_m - inter;
            let similarity = if denom > f64::EPSILON {
                inter / denom
            } else {
                1.0
            };
            heap.push(RetrievalResult {
                id: doc_id.to_string(),
                similarity,
            });
            if heap.len() > k {
                heap.pop();
            }
        }
        heap.into_sorted_vec()
    }
}
/// Reusable scratch space that counts, per document id, how many posting lists
/// containing that id were visited in the current round.
///
/// Counts are never zeroed between rounds. Instead each slot is tagged in
/// `seen` with the epoch in which it was last written, and a slot whose tag
/// differs from the current `epoch` reads as zero. Starting a new round
/// therefore only clears `touched`.
#[derive(Debug)]
pub struct Accumulator {
    /// Per-document hit counts; only meaningful where `seen[i] == epoch`.
    counts: Vec<u16>,
    /// Epoch in which each `counts` slot was last written (0 = never written).
    seen: Vec<u32>,
    /// Current epoch. Kept non-zero so zero-initialised `seen` slots never match it.
    epoch: u32,
    /// Document ids first touched in the current epoch, in first-touch order.
    touched: Vec<u32>,
}

impl Accumulator {
    /// Creates an accumulator pre-sized for document ids in `0..max_docs`.
    ///
    /// Larger ids are still accepted; storage grows on demand in
    /// [`Accumulator::increment`].
    pub fn new(max_docs: usize) -> Self {
        Self {
            counts: vec![0; max_docs],
            seen: vec![0; max_docs],
            // Start at 1, not 0. `seen` is zero-initialised, so epoch 0 would make
            // every slot look already touched, and `increment` would never record
            // ids in `touched` until `next_epoch` had been called.
            epoch: 1,
            touched: Vec::with_capacity(1024),
        }
    }

    /// Starts a new counting round, logically resetting every count to zero.
    pub fn next_epoch(&mut self) {
        self.touched.clear();
        // `wrapping_add`: a plain `+= 1` panics on overflow in debug builds,
        // before the wrap-around handling below could ever run.
        self.epoch = self.epoch.wrapping_add(1);
        if self.epoch == 0 {
            // Tags wrapped around: old tags could collide with future epochs, so
            // wipe them and skip 0, which is reserved for "never written".
            self.seen.fill(0);
            self.epoch = 1;
        }
    }

    /// Adds one hit for `doc_id` in the current epoch.
    ///
    /// The first hit in an epoch records `doc_id` in `touched`. Ids beyond the
    /// current capacity grow the storage (at least doubling it). Counts saturate
    /// at `u16::MAX`.
    #[inline]
    pub fn increment(&mut self, doc_id: u32) {
        let idx = doc_id as usize;
        if idx >= self.seen.len() {
            let new_len = (idx + 1).max(self.seen.len() * 2);
            self.seen.resize(new_len, 0);
            self.counts.resize(new_len, 0);
        }
        if self.seen[idx] == self.epoch {
            // Saturate instead of wrapping/panicking on pathological hit counts.
            self.counts[idx] = self.counts[idx].saturating_add(1);
        } else {
            self.seen[idx] = self.epoch;
            self.counts[idx] = 1;
            self.touched.push(doc_id);
        }
    }

    /// Returns the hit count of `doc_id` in the current epoch (0 if untouched
    /// or out of range).
    #[inline]
    pub fn get_count(&self, doc_id: u32) -> u16 {
        let idx = doc_id as usize;
        match self.seen.get(idx) {
            Some(&tag) if tag == self.epoch => self.counts[idx],
            _ => 0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three documents of four dimensions each, all sharing dimension 10.
    /// The query overlaps doc 0 on three dimensions, doc 2 on two and doc 1 on one.
    #[test]
    fn test_inverted_index_basic() {
        let docs: [(u32, [u32; 4]); 3] = [
            (0, [10, 20, 30, 40]),
            (1, [10, 25, 35, 45]),
            (2, [10, 20, 50, 60]),
        ];
        let mut index = SparseInvertedIndex::new(1000, 4);
        for (id, dims) in &docs {
            index.add_doc(*id, dims);
        }
        index.finalize();

        let mut scratch = Accumulator::new(10);
        let hits = index.query(&[10, 20, 30], 5, &mut scratch);
        assert_eq!(hits.len(), 3);

        // inter = 3, so 3 / (2*4 - 3) = 0.6
        let best = &hits[0];
        assert_eq!(best.id, "0");
        assert!((best.similarity - 0.6).abs() < 1e-6);

        // inter = 2, so 2 / (2*4 - 2) = 1/3
        let runner_up = &hits[1];
        assert_eq!(runner_up.id, "2");
        assert!((runner_up.similarity - 0.333333).abs() < 1e-5);
    }
}