use std::sync::atomic::{AtomicU64, Ordering};
use parking_lot::RwLock;
use rustc_hash::{FxHashMap, FxHashSet};
use super::merge::merge_sorted_runs;
use super::mutable_segment::MutableSegment;
use super::types::{PostingEntry, SparseVector};
/// Document count at which the mutable segment is frozen into a `FrozenSegment`
/// (checked with `>=` after every `insert` / `insert_batch_chunk`).
pub const FREEZE_THRESHOLD: usize = 10_000;
/// A read-only batch of posting lists produced by freezing the mutable segment
/// (or supplied directly via `SparseInvertedIndex::from_frozen_segment`).
///
/// Posting lists are never rewritten after construction; deletions are recorded
/// in `tombstones` and filtered out by readers.
#[allow(dead_code)] pub(crate) struct FrozenSegment {
/// Term id -> (posting list strictly ascending by `doc_id`, max weight bound for the term).
/// Sortedness is checked by a `debug_assert!` in `FrozenSegment::new`.
pub(crate) postings: FxHashMap<u32, (Vec<PostingEntry>, f32)>,
/// Point ids deleted after this segment was frozen; entries with these ids are skipped by readers.
tombstones: FxHashSet<u64>,
/// Number of documents the segment was built with; tombstoning does not decrease it.
pub(crate) doc_count: usize,
}
impl FrozenSegment {
    /// Wraps finished posting lists into a frozen segment with no tombstones.
    ///
    /// `postings` maps each term id to its posting list and the term's max weight
    /// bound. Every list must be strictly ascending by `doc_id`; this is only
    /// verified in debug builds. `doc_count` is stored verbatim.
    #[allow(dead_code)]
    pub(crate) fn new(
        postings: FxHashMap<u32, (Vec<PostingEntry>, f32)>,
        doc_count: usize,
    ) -> Self {
        debug_assert!(
            postings.values().all(|(entries, _)| {
                entries
                    .windows(2)
                    .all(|pair| pair[0].doc_id < pair[1].doc_id)
            }),
            "FrozenSegment posting lists must be sorted ascending and deduplicated by doc_id"
        );
        let tombstones = FxHashSet::default();
        Self {
            postings,
            tombstones,
            doc_count,
        }
    }

    /// Reports whether `point_id` is present in this segment and not tombstoned.
    ///
    /// Presence is detected by binary-searching every term's posting list, so the
    /// cost is one binary search per term in the segment. A document indexed with
    /// no terms has no postings here and is therefore never reported as live.
    fn contains_live(&self, point_id: u64) -> bool {
        let tombstoned = self.tombstones.contains(&point_id);
        !tombstoned
            && self
                .postings
                .values()
                .any(|(entries, _)| entries.binary_search_by_key(&point_id, |e| e.doc_id).is_ok())
    }
}
/// Sparse inverted index made of one writable (`mutable`) segment plus a list of
/// read-only frozen segments.
///
/// Methods that hold both locks at once acquire `mutable` before `frozen`.
pub struct SparseInvertedIndex {
/// Segment receiving new documents; frozen by `freeze_inner` once its document
/// count reaches `FREEZE_THRESHOLD`.
mutable: RwLock<MutableSegment>,
/// Frozen segments in creation order (`freeze_inner` pushes new ones to the end).
frozen: RwLock<Vec<FrozenSegment>>,
/// Running document counter: incremented by the insert paths for ids new to the
/// mutable segment, decremented by `delete`.
doc_count: AtomicU64,
}
impl Default for SparseInvertedIndex {
fn default() -> Self {
Self::new()
}
}
impl SparseInvertedIndex {
/// Creates an empty index: a fresh mutable segment, no frozen segments, and a
/// document counter of zero.
#[must_use]
pub fn new() -> Self {
    let mutable = RwLock::new(MutableSegment::new());
    let frozen = RwLock::new(Vec::new());
    let doc_count = AtomicU64::new(0);
    Self {
        mutable,
        frozen,
        doc_count,
    }
}
/// Inserts `vector` for `point_id` into the mutable segment.
///
/// The shared `doc_count` is incremented only when `MutableSegment::insert`
/// reports the id as new. If the mutable segment then holds at least
/// [`FREEZE_THRESHOLD`] documents, it is frozen while the write guard is still
/// held.
///
/// NOTE(review): `MutableSegment::insert` cannot see the frozen segments, so
/// re-inserting an id that is still live in a frozen segment appears to bump
/// `doc_count` a second time (while `delete` decrements only once) and leaves the
/// frozen copy untombstoned. Confirm whether upserts across a freeze are expected.
pub fn insert(&self, point_id: u64, vector: &SparseVector) {
    let mut guard = self.mutable.write();
    if guard.insert(point_id, vector) {
        self.doc_count.fetch_add(1, Ordering::Relaxed);
    }
    if guard.doc_count < FREEZE_THRESHOLD {
        return;
    }
    self.freeze_inner(&mut guard);
}
/// Inserts a batch of documents under a single mutable-segment write lock.
///
/// Per-term grouping (`build_batch_buffers`) happens before the lock is taken,
/// so the critical section only merges:
/// 1. the batch ids are registered and `doc_count` is raised by the number of new ones;
/// 2. postings and max weights are merged into the segment;
/// 3. the segment is frozen if it reached [`FREEZE_THRESHOLD`].
///
/// An empty batch is a no-op and takes no lock.
#[allow(dead_code)]
pub(crate) fn insert_batch_chunk(&self, docs: &[(u64, SparseVector)]) {
    if docs.is_empty() {
        return;
    }
    let (postings_by_term, max_by_term, ids) = Self::build_batch_buffers(docs);

    let mut guard = self.mutable.write();
    let added = Self::merge_doc_ids(&mut guard, ids);
    if added != 0 {
        self.doc_count.fetch_add(added, Ordering::Relaxed);
    }
    Self::merge_batch_into_segment(&mut guard, postings_by_term, &max_by_term);

    let reached_threshold = guard.doc_count >= FREEZE_THRESHOLD;
    if reached_threshold {
        self.freeze_inner(&mut guard);
    }
}
/// Groups a batch of documents by term without touching any shared state.
///
/// Returns, in order:
/// - term id -> posting entries, appended in `docs` order (entries are neither
///   sorted nor deduplicated here; presumably `MutableSegment::merge_batch_postings`
///   handles that — confirm);
/// - term id -> largest absolute weight seen in the batch, floored at `0.0`;
/// - the set of distinct point ids in the batch.
///
/// Terms and weights are paired with `zip`, so surplus entries in the longer of
/// `indices` / `values` are ignored.
#[allow(clippy::type_complexity)]
fn build_batch_buffers(
    docs: &[(u64, SparseVector)],
) -> (
    FxHashMap<u32, Vec<PostingEntry>>,
    FxHashMap<u32, f32>,
    FxHashSet<u64>,
) {
    let mut postings_by_term: FxHashMap<u32, Vec<PostingEntry>> = FxHashMap::default();
    let mut max_by_term: FxHashMap<u32, f32> = FxHashMap::default();
    let mut ids: FxHashSet<u64> = FxHashSet::default();

    for (doc_id, vector) in docs {
        let doc_id = *doc_id;
        ids.insert(doc_id);

        let pairs = vector.indices.iter().zip(vector.values.iter());
        for (&term_id, &weight) in pairs {
            postings_by_term
                .entry(term_id)
                .or_default()
                .push(PostingEntry { doc_id, weight });

            // f32::max keeps the current value when the candidate is NaN, matching
            // a plain `>` comparison.
            let slot = max_by_term.entry(term_id).or_insert(0.0);
            *slot = slot.max(weight.abs());
        }
    }

    (postings_by_term, max_by_term, ids)
}
/// Adds every batch id to `seg.doc_set`, bumping `seg.doc_count` for each id that
/// was not already present.
///
/// Returns how many ids were newly added.
fn merge_doc_ids(seg: &mut MutableSegment, batch_doc_ids: FxHashSet<u64>) -> u64 {
    batch_doc_ids.into_iter().fold(0_u64, |added, point_id| {
        if seg.doc_set.insert(point_id) {
            seg.doc_count += 1;
            added + 1
        } else {
            added
        }
    })
}
/// Folds per-term batch postings into the mutable segment.
///
/// Each term's updates go through `MutableSegment::merge_batch_postings`. The
/// segment's per-term max weight is then raised (never lowered) to the batch
/// maximum, starting from `0.0` for terms the segment has not seen yet.
fn merge_batch_into_segment(
    seg: &mut MutableSegment,
    batch_postings: FxHashMap<u32, Vec<PostingEntry>>,
    batch_max_weights: &FxHashMap<u32, f32>,
) {
    for (term_id, updates) in batch_postings {
        let target = seg.postings.entry(term_id).or_default();
        MutableSegment::merge_batch_postings(target, updates);

        if let Some(&batch_max) = batch_max_weights.get(&term_id) {
            let slot = seg.max_weights.entry(term_id).or_insert(0.0);
            if *slot < batch_max {
                *slot = batch_max;
            }
        }
    }
}
fn freeze_inner(&self, seg: &mut MutableSegment) {
let old = std::mem::replace(seg, MutableSegment::new());
let mut frozen_postings = FxHashMap::default();
for (term_id, entries) in old.postings {
let max_w = old.max_weights.get(&term_id).copied().unwrap_or(0.0);
frozen_postings.insert(term_id, (entries, max_w));
}
let frozen_seg = FrozenSegment::new(frozen_postings, old.doc_count);
let mut frozen_vec = self.frozen.write();
frozen_vec.push(frozen_seg);
}
/// Deletes `point_id` from every segment that still holds it.
///
/// The mutable segment removes it directly through `MutableSegment::delete`.
/// Every frozen segment in which it is still live (`contains_live`) records a
/// tombstone instead. `doc_count` is decremented once if the id was found
/// anywhere. Both write locks are held, mutable first, then frozen.
///
/// `contains_live` only finds ids that have postings, so an id indexed with no
/// terms in a frozen segment is not tombstoned or counted here.
pub fn delete(&self, point_id: u64) {
    let mut mutable = self.mutable.write();
    let removed_from_mutable = mutable.delete(point_id);

    let mut frozen = self.frozen.write();
    let mut tombstoned = false;
    for segment in frozen
        .iter_mut()
        .filter(|segment| segment.contains_live(point_id))
    {
        segment.tombstones.insert(point_id);
        tombstoned = true;
    }

    if removed_from_mutable || tombstoned {
        self.doc_count.fetch_sub(1, Ordering::Relaxed);
    }
}
/// Returns the current document counter (a relaxed atomic load).
///
/// The counter is maintained incrementally by the insert paths and `delete`;
/// it is not recomputed from the segments.
#[must_use]
pub fn doc_count(&self) -> u64 {
self.doc_count.load(Ordering::Relaxed)
}
/// Counts the postings for `term_id`.
///
/// The count is the live (non-tombstoned) entries in every frozen segment plus
/// all entries in the mutable segment.
///
/// The mutable read guard is acquired first and held while the frozen segments
/// are read. This matches the writers' lock order (mutable, then frozen), so
/// there is no deadlock. It also means no freeze can run mid-count: otherwise a
/// freeze between the two reads would move the mutable postings into a frozen
/// segment this call never saw, and they would drop out of the count.
///
/// NOTE(review): `insert` does not tombstone a frozen copy when an id is
/// re-inserted after a freeze, so such an id can be counted once per segment
/// that holds it.
#[must_use]
pub fn posting_count(&self, term_id: u32) -> usize {
    let seg = self.mutable.read();
    let frozen_vec = self.frozen.read();

    let frozen_live: usize = frozen_vec
        .iter()
        .filter_map(|frozen_seg| {
            let (entries, _) = frozen_seg.postings.get(&term_id)?;
            Some(
                entries
                    .iter()
                    .filter(|e| !frozen_seg.tombstones.contains(&e.doc_id))
                    .count(),
            )
        })
        .sum();
    let mutable_total = seg
        .postings
        .get(&term_id)
        .map_or(0, |entries| entries.len());

    frozen_live + mutable_total
}
/// Returns every live posting for `term_id`, merged across all segments by
/// `merge_sorted_runs`.
///
/// The mutable read guard is held while the frozen runs are collected. This
/// uses the same lock order as the writers (mutable, then frozen). Without it, a
/// freeze between the two reads would move the mutable postings into a frozen
/// segment this call had already passed over, hiding them from the result. The
/// guard is released before the merge, which needs no lock.
#[must_use]
pub fn get_all_postings(&self, term_id: u32) -> Vec<PostingEntry> {
    let (frozen_runs, mutable_run) = {
        let mutable = self.mutable.read();
        let frozen_runs = self.collect_frozen_runs(term_id);
        // Cloned from the guard already held. Re-locking `mutable` here could
        // deadlock behind a queued writer under parking_lot's fair policy.
        let mutable_run = mutable
            .postings
            .get(&term_id)
            .cloned()
            .unwrap_or_default();
        (frozen_runs, mutable_run)
    };
    merge_sorted_runs(frozen_runs, mutable_run)
}
/// Collects one owned run per frozen segment that has live postings for `term_id`.
///
/// Runs are returned in segment (creation) order, with tombstoned entries
/// removed; segments with no live entries for the term are skipped. The frozen
/// read lock is taken here. A caller may already hold the mutable read guard,
/// but must not hold `frozen`.
#[inline]
fn collect_frozen_runs(&self, term_id: u32) -> Vec<Vec<PostingEntry>> {
    let frozen_vec = self.frozen.read();
    frozen_vec
        .iter()
        .filter_map(|seg| {
            let (entries, _) = seg.postings.get(&term_id)?;
            if entries.is_empty() {
                return None;
            }
            let owned: Vec<PostingEntry> = if seg.tombstones.is_empty() {
                entries.clone()
            } else {
                entries
                    .iter()
                    .copied()
                    .filter(|e| !seg.tombstones.contains(&e.doc_id))
                    .collect()
            };
            (!owned.is_empty()).then_some(owned)
        })
        .collect()
}
/// Returns the largest stored max-weight bound for `term_id` across all segments.
///
/// Considers every frozen segment's bound and the mutable segment's entry in
/// `max_weights`; returns `0.0` if the term is unknown. Bounds are not lowered
/// when postings are deleted.
///
/// Both read guards are held together, mutable first (the writers' lock order).
/// Without this, a freeze between the frozen read and the mutable read would move
/// the term's bound out of `max_weights` into a frozen segment this call never
/// saw, and an underestimated bound would be returned.
#[must_use]
pub fn get_global_max_weight(&self, term_id: u32) -> f32 {
    let seg = self.mutable.read();
    let frozen_vec = self.frozen.read();

    let mut max_w = 0.0_f32;
    for frozen_seg in frozen_vec.iter() {
        if let Some(&(_, w)) = frozen_seg.postings.get(&term_id) {
            max_w = max_w.max(w);
        }
    }
    if let Some(&w) = seg.max_weights.get(&term_id) {
        max_w = max_w.max(w);
    }
    max_w
}
/// Returns the set of term ids that have a posting list in any segment.
///
/// Keys are taken as-is, so a term whose frozen postings are all tombstoned is
/// still included. Both read guards are held together, mutable first (the
/// writers' lock order). Otherwise a freeze between the two reads could move
/// terms that exist only in the mutable segment into a frozen segment this call
/// never saw, and those terms would be omitted.
fn collect_term_ids(&self) -> FxHashSet<u32> {
    let seg = self.mutable.read();
    let frozen_vec = self.frozen.read();

    let mut terms: FxHashSet<u32> = FxHashSet::default();
    for frozen_seg in frozen_vec.iter() {
        terms.extend(frozen_seg.postings.keys());
    }
    terms.extend(seg.postings.keys());
    terms
}
/// Number of distinct term ids with a posting list in any segment.
///
/// Builds the full term set on each call; terms whose frozen postings are all
/// tombstoned are still counted.
#[must_use]
pub fn term_count(&self) -> usize {
self.collect_term_ids().len()
}
/// Builds an index whose only data is `segment`, with an empty mutable segment.
///
/// The document counter starts at `segment.doc_count`.
#[must_use]
#[allow(dead_code)]
pub(crate) fn from_frozen_segment(segment: FrozenSegment) -> Self {
    let initial_docs = segment.doc_count as u64;
    let mutable = RwLock::new(MutableSegment::new());
    let frozen = RwLock::new(vec![segment]);
    Self {
        mutable,
        frozen,
        doc_count: AtomicU64::new(initial_docs),
    }
}
/// Returns every term id with a posting list in any segment, sorted ascending.
#[must_use]
pub fn all_term_ids(&self) -> Vec<u32> {
    let terms = self.collect_term_ids();
    let mut sorted = Vec::with_capacity(terms.len());
    sorted.extend(terms);
    sorted.sort_unstable();
    sorted
}
/// Builds a compacted view of every live posting, keyed by term id.
///
/// For each term, the value is the deduplicated posting list (ascending by
/// `doc_id`) and its maximum absolute weight. Terms left with no entries after
/// tombstone filtering are omitted. When a `doc_id` appears in several segments,
/// the mutable segment wins, then newer frozen segments over older ones.
///
/// Both read guards are held while collecting, mutable first (the writers' lock
/// order). The result is then one consistent snapshot: a concurrent `delete`, or
/// an upsert followed by a freeze, cannot land between reading the mutable
/// segment and reading the frozen ones. Such a gap could otherwise let a stale or
/// deleted mutable copy into the output.
#[must_use]
pub fn get_merged_postings_for_compaction(&self) -> FxHashMap<u32, (Vec<PostingEntry>, f32)> {
    let mut merged: FxHashMap<u32, Vec<PostingEntry>> = FxHashMap::default();
    {
        let seg = self.mutable.read();
        let frozen_vec = self.frozen.read();

        // Append in precedence order: mutable first, then frozen newest -> oldest.
        for (&term_id, entries) in &seg.postings {
            merged.entry(term_id).or_default().extend_from_slice(entries);
        }
        for frozen_seg in frozen_vec.iter().rev() {
            for (&term_id, (entries, _)) in &frozen_seg.postings {
                merged.entry(term_id).or_default().extend(
                    entries
                        .iter()
                        .copied()
                        .filter(|e| !frozen_seg.tombstones.contains(&e.doc_id)),
                );
            }
        }
    }

    let mut result: FxHashMap<u32, (Vec<PostingEntry>, f32)> = FxHashMap::default();
    for (term_id, mut entries) in merged {
        // Must stay a *stable* sort: it keeps the precedence order among equal
        // doc_ids, and `dedup_by_key` keeps the first of each run, so the
        // highest-precedence copy survives.
        entries.sort_by_key(|e| e.doc_id);
        entries.dedup_by_key(|e| e.doc_id);
        if entries.is_empty() {
            continue;
        }
        let max_w = entries
            .iter()
            .map(|e| e.weight.abs())
            .fold(0.0_f32, f32::max);
        result.insert(term_id, (entries, max_w));
    }
    result
}
/// Test-only: number of frozen segments currently held.
#[cfg(test)]
fn frozen_count(&self) -> usize {
self.frozen.read().len()
}
/// Test-only: document count of the current mutable segment.
#[cfg(test)]
fn mutable_doc_count(&self) -> usize {
self.mutable.read().doc_count
}
}
// Unit tests live in a separate file and are compiled only for test builds.
#[cfg(test)]
#[path = "inverted_index_tests.rs"]
mod tests;