use ahash::AHashMap;
use crate::error::{LaurusError, Result};
use crate::storage::structured::{StructReader, StructWriter};
use crate::storage::{StorageInput, StorageOutput};
/// A single entry in a posting list: one document's occurrence data for a term.
#[derive(Debug, Clone, PartialEq)]
pub struct Posting {
    /// Identifier of the document this posting refers to.
    pub doc_id: u64,
    /// Number of occurrences of the term in the document.
    pub frequency: u32,
    /// Optional token positions of the occurrences within the document.
    pub positions: Option<Vec<u32>>,
    /// Scoring weight for this posting (defaults to 1.0).
    pub weight: f32,
}

impl Posting {
    /// Creates a posting for `doc_id` with a single occurrence and no positions.
    pub fn new(doc_id: u64) -> Self {
        Self::with_frequency(doc_id, 1)
    }

    /// Creates a posting with an explicit occurrence count and no positions.
    pub fn with_frequency(doc_id: u64, frequency: u32) -> Self {
        Posting {
            doc_id,
            frequency,
            positions: None,
            weight: 1.0,
        }
    }

    /// Creates a posting from explicit positions; the frequency is derived
    /// from the number of positions supplied.
    pub fn with_positions(doc_id: u64, positions: Vec<u32>) -> Self {
        Posting {
            doc_id,
            frequency: positions.len() as u32,
            positions: Some(positions),
            weight: 1.0,
        }
    }

    /// Builder-style setter for the scoring weight.
    pub fn with_weight(mut self, weight: f32) -> Self {
        self.weight = weight;
        self
    }

    /// Appends a position, creating the position vector on first use, and
    /// keeps `frequency` equal to the number of recorded positions.
    pub fn add_position(&mut self, position: u32) {
        let positions = self.positions.get_or_insert_with(Vec::new);
        positions.push(position);
        self.frequency = positions.len() as u32;
    }

    /// Returns the occurrence count.
    pub fn frequency(&self) -> u32 {
        self.frequency
    }

    /// Returns the recorded positions, if any, as a slice.
    pub fn positions(&self) -> Option<&[u32]> {
        self.positions.as_deref()
    }
}
/// All postings for a single term, kept sorted by ascending `doc_id`.
#[derive(Debug, Clone)]
pub struct PostingList {
// The indexed term this list belongs to.
pub term: String,
// Postings sorted by ascending `doc_id`; `add_posting` maintains the order.
pub postings: Vec<Posting>,
// Sum of term frequencies across all postings.
pub total_frequency: u64,
// Number of distinct documents containing the term.
pub doc_frequency: u64,
}
impl PostingList {
    /// Creates an empty posting list for `term`.
    pub fn new(term: String) -> Self {
        PostingList {
            term,
            postings: Vec::new(),
            total_frequency: 0,
            doc_frequency: 0,
        }
    }

    /// Inserts `posting` while keeping `postings` sorted by `doc_id`.
    ///
    /// If a posting for the same document already exists, the frequencies are
    /// summed and any positions are appended; otherwise the posting is
    /// inserted at its sorted position. `total_frequency` and `doc_frequency`
    /// are updated accordingly.
    pub fn add_posting(&mut self, posting: Posting) {
        match self
            .postings
            .binary_search_by_key(&posting.doc_id, |p| p.doc_id)
        {
            Ok(pos) => {
                // Merge into the existing posting for this document.
                let existing = &mut self.postings[pos];
                existing.frequency += posting.frequency;
                self.total_frequency += posting.frequency as u64;
                if let Some(new_positions) = posting.positions {
                    match &mut existing.positions {
                        Some(positions) => positions.extend(new_positions),
                        None => existing.positions = Some(new_positions),
                    }
                }
            }
            Err(pos) => {
                self.total_frequency += posting.frequency as u64;
                self.doc_frequency += 1;
                self.postings.insert(pos, posting);
            }
        }
    }

    /// Number of postings (distinct documents) in the list.
    pub fn len(&self) -> usize {
        self.postings.len()
    }

    /// Returns `true` if the list holds no postings.
    pub fn is_empty(&self) -> bool {
        self.postings.is_empty()
    }

    /// Iterates over the postings in ascending `doc_id` order.
    pub fn iter(&'_ self) -> std::slice::Iter<'_, Posting> {
        self.postings.iter()
    }

    /// Restores the list invariants after external mutation.
    ///
    /// `postings` is a public field, so callers can push entries out of order
    /// or with duplicate doc ids. This sorts by `doc_id`, merges duplicate
    /// documents (summing frequencies and concatenating positions), and
    /// recomputes `doc_frequency` / `total_frequency`. The previous
    /// implementation (`dedup_by_key`) silently discarded the frequencies and
    /// positions of duplicates and left both aggregate counters stale.
    pub fn optimize(&mut self) {
        self.postings.sort_by_key(|p| p.doc_id);
        // `dedup_by(same_bucket)` passes (later, earlier) and removes the
        // later element when the closure returns true, so fold `dup` into
        // `keep` before it is dropped.
        self.postings.dedup_by(|dup, keep| {
            if dup.doc_id != keep.doc_id {
                return false;
            }
            keep.frequency += dup.frequency;
            if let Some(extra) = dup.positions.take() {
                match &mut keep.positions {
                    Some(positions) => positions.extend(extra),
                    None => keep.positions = Some(extra),
                }
            }
            true
        });
        self.doc_frequency = self.postings.len() as u64;
        self.total_frequency = self.postings.iter().map(|p| p.frequency as u64).sum();
    }

    /// Serializes the list: term, aggregate counters, posting count, then one
    /// record per posting with delta-encoded doc ids and (optionally)
    /// delta-encoded positions.
    ///
    /// # Errors
    /// Propagates any I/O error from the underlying writer.
    pub fn encode<W: StorageOutput>(&self, writer: &mut StructWriter<W>) -> Result<()> {
        writer.write_string(&self.term)?;
        writer.write_varint(self.total_frequency)?;
        writer.write_varint(self.doc_frequency)?;
        writer.write_varint(self.postings.len() as u64)?;
        // Doc ids are stored as deltas from the previous id; this relies on
        // `postings` being sorted ascending (the list invariant).
        let mut prev_doc_id = 0u64;
        for posting in &self.postings {
            let delta = posting.doc_id - prev_doc_id;
            writer.write_varint(delta)?;
            prev_doc_id = posting.doc_id;
            writer.write_varint(posting.frequency as u64)?;
            writer.write_f32(posting.weight)?;
            if let Some(positions) = &posting.positions {
                // Presence flag, count, then position deltas.
                writer.write_u8(1)?;
                writer.write_varint(positions.len() as u64)?;
                let mut prev_pos = 0u32;
                for &pos in positions {
                    // NOTE(review): `saturating_sub` encodes a 0 delta for an
                    // out-of-order position, which would decode to a different
                    // sequence — positions are assumed to be ascending here.
                    let delta = pos.saturating_sub(prev_pos);
                    writer.write_varint(delta as u64)?;
                    prev_pos = pos;
                }
            } else {
                writer.write_u8(0)?;
            }
        }
        Ok(())
    }

    /// Deserializes a list previously written by [`PostingList::encode`],
    /// reversing the delta encoding of doc ids and positions.
    ///
    /// # Errors
    /// Propagates any I/O / decoding error from the underlying reader.
    pub fn decode<R: StorageInput>(reader: &mut StructReader<R>) -> Result<Self> {
        let term = reader.read_string()?;
        let total_frequency = reader.read_varint()?;
        let doc_frequency = reader.read_varint()?;
        let posting_count = reader.read_varint()? as usize;
        let mut postings = Vec::with_capacity(posting_count);
        let mut prev_doc_id = 0u64;
        for _ in 0..posting_count {
            // Doc ids were written as deltas; accumulate to recover them.
            let delta = reader.read_varint()?;
            let doc_id = prev_doc_id + delta;
            prev_doc_id = doc_id;
            let frequency = reader.read_varint()? as u32;
            let weight = reader.read_f32()?;
            let has_positions = reader.read_u8()? != 0;
            let positions = if has_positions {
                let pos_count = reader.read_varint()? as usize;
                let mut positions = Vec::with_capacity(pos_count);
                let mut prev_pos = 0u32;
                for _ in 0..pos_count {
                    let delta = reader.read_varint()? as u32;
                    let pos = prev_pos + delta;
                    positions.push(pos);
                    prev_pos = pos;
                }
                Some(positions)
            } else {
                None
            };
            postings.push(Posting {
                doc_id,
                frequency,
                positions,
                weight,
            });
        }
        Ok(PostingList {
            term,
            postings,
            total_frequency,
            doc_frequency,
        })
    }
}
/// A cursor over a materialized, doc-id-ordered list of postings.
pub struct PostingIterator {
    postings: Vec<Posting>,
    position: usize,
}

impl PostingIterator {
    /// Creates a cursor positioned at the first posting.
    pub fn new(postings: Vec<Posting>) -> Self {
        PostingIterator {
            postings,
            position: 0,
        }
    }

    /// Creates a cursor over an empty posting list.
    pub fn empty() -> Self {
        Self::new(Vec::new())
    }

    /// Returns the posting under the cursor without advancing.
    pub fn current(&self) -> Option<&Posting> {
        self.postings.get(self.position)
    }

    /// Returns the posting under the cursor and advances past it.
    #[allow(clippy::should_implement_trait)]
    pub fn next(&mut self) -> Option<&Posting> {
        if self.position >= self.postings.len() {
            return None;
        }
        let index = self.position;
        self.position += 1;
        Some(&self.postings[index])
    }

    /// Advances the cursor to the first posting whose `doc_id` is at least
    /// `target_doc_id`; returns `true` when such a posting exists.
    pub fn skip_to(&mut self, target_doc_id: u64) -> bool {
        let remaining = &self.postings[self.position..];
        match remaining.iter().position(|p| p.doc_id >= target_doc_id) {
            Some(offset) => {
                self.position += offset;
                true
            }
            None => {
                // Exhaust the cursor so subsequent calls stay cheap.
                self.position = self.postings.len();
                false
            }
        }
    }

    /// Returns `true` once the cursor has moved past the last posting.
    pub fn is_exhausted(&self) -> bool {
        self.position >= self.postings.len()
    }

    /// Total number of postings, independent of the cursor position.
    pub fn len(&self) -> usize {
        self.postings.len()
    }

    /// Returns `true` if the underlying list has no postings at all.
    pub fn is_empty(&self) -> bool {
        self.postings.is_empty()
    }
}

impl Iterator for PostingIterator {
    type Item = Posting;

    /// Clones out the posting under the cursor and advances.
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.postings.get(self.position).cloned()?;
        self.position += 1;
        Some(item)
    }
}
/// An in-memory inverted index mapping terms to their posting lists.
#[derive(Debug)]
pub struct TermPostingIndex {
// Term -> posting list for that term.
terms: AHashMap<String, PostingList>,
// Upper bound on document ids seen so far (max doc_id + 1).
doc_count: u64,
// Number of distinct terms inserted; also written to the serialized header.
term_count: u64,
}
impl TermPostingIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        TermPostingIndex {
            terms: AHashMap::new(),
            doc_count: 0,
            term_count: 0,
        }
    }

    /// Adds `posting` to the posting list for `term`, creating the list on
    /// first use and bumping `term_count` when it does.
    ///
    /// The previous implementation cloned the term `String` on every call to
    /// satisfy the entry API; the common case (term already indexed) now pays
    /// no clone — only the creation path does.
    pub fn add_posting(&mut self, term: String, posting: Posting) {
        if let Some(posting_list) = self.terms.get_mut(&term) {
            posting_list.add_posting(posting);
        } else {
            self.term_count += 1;
            // The term is stored both as the map key and inside the list.
            let mut posting_list = PostingList::new(term.clone());
            posting_list.add_posting(posting);
            self.terms.insert(term, posting_list);
        }
    }

    /// Indexes all `(term, frequency, positions)` tuples for `doc_id` and
    /// extends `doc_count` to cover the document id.
    pub fn add_document(&mut self, doc_id: u64, terms: Vec<(String, u32, Option<Vec<u32>>)>) {
        for (term, frequency, positions) in terms {
            let posting = if let Some(positions) = positions {
                // When positions are supplied, frequency is derived from
                // their count and the explicit frequency is ignored.
                Posting::with_positions(doc_id, positions)
            } else {
                Posting::with_frequency(doc_id, frequency)
            };
            self.add_posting(term, posting);
        }
        // doc_count tracks max doc_id + 1, not the number of calls.
        self.doc_count = self.doc_count.max(doc_id + 1);
    }

    /// Looks up the posting list for `term`, if indexed.
    pub fn get_posting_list(&self, term: &str) -> Option<&PostingList> {
        self.terms.get(term)
    }

    /// Returns a cursor over a copy of the postings for `term`
    /// (an empty cursor if the term is not indexed).
    pub fn get_posting_iterator(&self, term: &str) -> PostingIterator {
        match self.terms.get(term) {
            Some(posting_list) => PostingIterator::new(posting_list.postings.clone()),
            None => PostingIterator::empty(),
        }
    }

    /// Number of documents covered by the index (max doc id + 1).
    pub fn doc_count(&self) -> u64 {
        self.doc_count
    }

    /// Number of distinct terms in the index.
    pub fn term_count(&self) -> u64 {
        self.term_count
    }

    /// Iterates over all indexed terms in unspecified order.
    pub fn terms(&self) -> impl Iterator<Item = &String> {
        self.terms.keys()
    }

    /// Re-establishes the invariants of every posting list.
    pub fn optimize(&mut self) {
        for posting_list in self.terms.values_mut() {
            posting_list.optimize();
        }
    }

    /// Serializes the index: magic, version, counters, then the posting
    /// lists sorted by term so equal indexes produce identical bytes.
    ///
    /// # Errors
    /// Propagates any I/O error from the underlying writer.
    pub fn write_to_storage<W: StorageOutput>(&self, writer: &mut StructWriter<W>) -> Result<()> {
        writer.write_u32(0x494E5658)?; // magic: "INVX"
        writer.write_u32(1)?; // format version
        writer.write_varint(self.doc_count)?;
        writer.write_varint(self.term_count)?;
        writer.write_varint(self.terms.len() as u64)?;
        // Sort for deterministic output; map iteration order is unstable.
        let mut sorted_terms: Vec<_> = self.terms.iter().collect();
        sorted_terms.sort_by_key(|(term, _)| *term);
        for (_, posting_list) in sorted_terms {
            posting_list.encode(writer)?;
        }
        Ok(())
    }

    /// Deserializes an index written by [`TermPostingIndex::write_to_storage`].
    ///
    /// # Errors
    /// Returns an error on a bad magic number, an unsupported version, or any
    /// underlying I/O / decoding failure.
    pub fn read_from_storage<R: StorageInput>(reader: &mut StructReader<R>) -> Result<Self> {
        let magic = reader.read_u32()?;
        if magic != 0x494E5658 {
            return Err(LaurusError::index("Invalid inverted index file format"));
        }
        let version = reader.read_u32()?;
        if version != 1 {
            return Err(LaurusError::index(format!(
                "Unsupported index version: {version}"
            )));
        }
        let doc_count = reader.read_varint()?;
        let term_count = reader.read_varint()?;
        let posting_list_count = reader.read_varint()? as usize;
        let mut terms = AHashMap::with_capacity(posting_list_count);
        for _ in 0..posting_list_count {
            let posting_list = PostingList::decode(reader)?;
            terms.insert(posting_list.term.clone(), posting_list);
        }
        Ok(TermPostingIndex {
            terms,
            doc_count,
            term_count,
        })
    }
}
impl Default for TermPostingIndex {
fn default() -> Self {
Self::new()
}
}
/// Summary statistics computed over a `TermPostingIndex`.
#[derive(Debug, Clone)]
pub struct PostingStats {
// Number of posting lists (distinct terms).
pub posting_list_count: usize,
// Total number of postings across all lists.
pub total_postings: usize,
// Mean postings per list; 0.0 for an empty index.
pub avg_postings_per_list: f64,
// Size of the largest posting list.
pub max_posting_list_size: usize,
// Always 0 in `stats()`: the serialized size is not tracked in memory.
pub compressed_size: usize,
}
impl TermPostingIndex {
    /// Computes summary statistics over the current posting lists.
    ///
    /// `compressed_size` is reported as 0; the in-memory index does not
    /// track its serialized size.
    pub fn stats(&self) -> PostingStats {
        let posting_list_count = self.terms.len();
        // One pass over the list lengths yields both the sum and the maximum.
        let (total_postings, max_posting_list_size) = self
            .terms
            .values()
            .map(|pl| pl.postings.len())
            .fold((0usize, 0usize), |(total, largest), len| {
                (total + len, largest.max(len))
            });
        let avg_postings_per_list = if posting_list_count > 0 {
            total_postings as f64 / posting_list_count as f64
        } else {
            0.0
        };
        PostingStats {
            posting_list_count,
            total_postings,
            avg_postings_per_list,
            max_posting_list_size,
            compressed_size: 0,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::Storage;
use crate::storage::memory::MemoryStorage;
use crate::storage::memory::MemoryStorageConfig;
use std::sync::Arc;
// Constructors: default frequency/weight, and frequency derived from positions.
#[test]
fn test_posting_creation() {
let posting = Posting::new(1);
assert_eq!(posting.doc_id, 1);
assert_eq!(posting.frequency, 1);
assert_eq!(posting.positions, None);
assert_eq!(posting.weight, 1.0);
let posting = Posting::with_frequency(2, 5);
assert_eq!(posting.doc_id, 2);
assert_eq!(posting.frequency, 5);
let posting = Posting::with_positions(3, vec![10, 20, 30]);
assert_eq!(posting.doc_id, 3);
// frequency is derived from the number of positions.
assert_eq!(posting.frequency, 3);
assert_eq!(posting.positions, Some(vec![10, 20, 30]));
}
// add_posting keeps the list sorted by doc_id regardless of insertion order.
#[test]
fn test_posting_list() {
let mut list = PostingList::new("test".to_string());
assert!(list.is_empty());
list.add_posting(Posting::new(1));
list.add_posting(Posting::new(3));
list.add_posting(Posting::new(2));
assert_eq!(list.len(), 3);
assert_eq!(list.doc_frequency, 3);
let doc_ids: Vec<u64> = list.postings.iter().map(|p| p.doc_id).collect();
assert_eq!(doc_ids, vec![1, 2, 3]);
}
// Cursor semantics: inherent next() returns the current posting, then advances.
#[test]
fn test_posting_iterator() {
let postings = vec![
Posting::new(1),
Posting::new(3),
Posting::new(5),
Posting::new(7),
];
let mut iter = PostingIterator::new(postings);
assert_eq!(iter.current().unwrap().doc_id, 1);
assert_eq!(iter.next().unwrap().doc_id, 1);
assert_eq!(iter.current().unwrap().doc_id, 3);
// skip_to(4) would land on 5; skip_to to an exact doc id also works.
assert!(iter.skip_to(5));
assert_eq!(iter.current().map(|p| p.doc_id), Some(5));
assert_eq!(iter.current().unwrap().doc_id, 5);
// No doc_id >= 10 exists, so the cursor is exhausted.
assert!(!iter.skip_to(10));
assert!(iter.is_exhausted());
}
#[test]
fn test_inverted_index() {
let mut index = TermPostingIndex::new();
index.add_document(
1,
vec![
("hello".to_string(), 1, Some(vec![0])),
("world".to_string(), 1, Some(vec![1])),
],
);
index.add_document(
2,
vec![
("hello".to_string(), 1, Some(vec![0])),
("rust".to_string(), 1, Some(vec![1])),
("world".to_string(), 1, Some(vec![2])),
],
);
// doc_count is max doc_id + 1 (ids 1 and 2 -> 3), not the number of
// documents added; term_count is the number of distinct terms.
assert_eq!(index.doc_count(), 3); assert_eq!(index.term_count(), 3);
let hello_list = index.get_posting_list("hello").unwrap();
assert_eq!(hello_list.postings.len(), 2);
assert_eq!(hello_list.doc_frequency, 2);
let rust_list = index.get_posting_list("rust").unwrap();
assert_eq!(rust_list.postings.len(), 1);
assert_eq!(rust_list.doc_frequency, 1);
assert!(index.get_posting_list("nonexistent").is_none());
}
// Round-trips a single posting list through the in-memory storage backend.
#[test]
fn test_posting_list_encoding() {
let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
let mut original_list = PostingList::new("test".to_string());
original_list.add_posting(Posting::with_positions(1, vec![0, 5, 10]));
original_list.add_posting(Posting::with_frequency(3, 2));
original_list.add_posting(Posting::new(5));
{
let output = storage.create_output("test_posting.bin").unwrap();
let mut writer = StructWriter::new(output);
original_list.encode(&mut writer).unwrap();
writer.close().unwrap();
}
{
let input = storage.open_input("test_posting.bin").unwrap();
let mut reader = StructReader::new(input).unwrap();
let decoded_list = PostingList::decode(&mut reader).unwrap();
assert_eq!(decoded_list.term, original_list.term);
assert_eq!(decoded_list.postings.len(), original_list.postings.len());
assert_eq!(decoded_list.doc_frequency, original_list.doc_frequency);
assert_eq!(decoded_list.total_frequency, original_list.total_frequency);
// Field-by-field comparison (weight is intentionally not compared here).
for (orig, decoded) in original_list
.postings
.iter()
.zip(decoded_list.postings.iter())
{
assert_eq!(orig.doc_id, decoded.doc_id);
assert_eq!(orig.frequency, decoded.frequency);
assert_eq!(orig.positions, decoded.positions);
}
}
}
// Round-trips a whole index through write_to_storage / read_from_storage.
#[test]
fn test_inverted_index_serialization() {
let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
let mut original_index = TermPostingIndex::new();
original_index.add_document(
1,
vec![
("hello".to_string(), 2, Some(vec![0, 5])),
("world".to_string(), 1, Some(vec![1])),
],
);
original_index.add_document(
2,
vec![
("hello".to_string(), 1, Some(vec![2])),
("rust".to_string(), 3, Some(vec![0, 3, 6])),
],
);
{
let output = storage.create_output("test_index.bin").unwrap();
let mut writer = StructWriter::new(output);
original_index.write_to_storage(&mut writer).unwrap();
writer.close().unwrap();
}
{
let input = storage.open_input("test_index.bin").unwrap();
let mut reader = StructReader::new(input).unwrap();
let loaded_index = TermPostingIndex::read_from_storage(&mut reader).unwrap();
assert_eq!(loaded_index.doc_count(), original_index.doc_count());
assert_eq!(loaded_index.term_count(), original_index.term_count());
for term in ["hello", "world", "rust"] {
let orig_list = original_index.get_posting_list(term);
let loaded_list = loaded_index.get_posting_list(term);
match (orig_list, loaded_list) {
(Some(orig), Some(loaded)) => {
assert_eq!(orig.postings.len(), loaded.postings.len());
assert_eq!(orig.doc_frequency, loaded.doc_frequency);
}
(None, None) => {}
_ => panic!("Mismatch in term existence: {term}"),
}
}
}
}
// Sanity checks on the aggregate statistics for a small synthetic corpus.
#[test]
fn test_posting_stats() {
let mut index = TermPostingIndex::new();
for doc_id in 0..100 {
index.add_document(
doc_id,
vec![
("common".to_string(), 1, None),
(format!("term_{}", doc_id % 10), 1, None),
],
);
}
let stats = index.stats();
assert!(stats.posting_list_count > 0);
assert!(stats.total_postings > 0);
assert!(stats.avg_postings_per_list > 0.0);
assert!(stats.max_posting_list_size > 0);
}
}