use std::sync::atomic::{AtomicU64, Ordering};
use parking_lot::RwLock;
use rustc_hash::{FxHashMap, FxHashSet};
use super::types::{PostingEntry, SparseVector};
/// Document count at which the mutable segment is frozen: both insert paths
/// freeze the segment once its `doc_count` is greater than or equal to this.
pub const FREEZE_THRESHOLD: usize = 10_000;
/// Write-side segment: receives inserts and physical deletes until it is
/// swapped out and converted into a `FrozenSegment`.
struct MutableSegment {
/// Posting list per term id; the insert paths keep each list ordered by `doc_id`.
postings: FxHashMap<u32, Vec<PostingEntry>>,
/// Per term id, an upper bound on the absolute weight of its postings
/// (raised on insert, recomputed from the list on delete).
max_weights: FxHashMap<u32, f32>,
/// Ids of the documents inserted into this segment.
doc_set: FxHashSet<u64>,
/// Document counter for this segment; compared against `FREEZE_THRESHOLD`.
doc_count: usize,
}
impl MutableSegment {
/// Creates a segment with no postings, no weight bounds and no documents.
fn new() -> Self {
    let postings = FxHashMap::default();
    let max_weights = FxHashMap::default();
    let doc_set = FxHashSet::default();
    Self {
        postings,
        max_weights,
        doc_set,
        doc_count: 0,
    }
}
/// Upserts one posting per `(term, weight)` pair of `vector` for `point_id`.
///
/// Only the terms present in `vector` are touched; postings written by an
/// earlier insert of the same id under other terms are left as they are.
///
/// Returns `true` when `point_id` was not yet tracked by this segment.
fn insert(&mut self, point_id: u64, vector: &SparseVector) -> bool {
    let first_time = self.doc_set.insert(point_id);
    if first_time {
        self.doc_count += 1;
    }

    let pairs = vector.indices.iter().copied().zip(vector.values.iter().copied());
    for (term_id, weight) in pairs {
        let posting = PostingEntry {
            doc_id: point_id,
            weight,
        };

        // Lists stay ordered by doc id: overwrite on a hit, splice in on a miss.
        let list = self.postings.entry(term_id).or_default();
        match list.binary_search_by_key(&point_id, |e| e.doc_id) {
            Ok(slot) => list[slot] = posting,
            Err(slot) => list.insert(slot, posting),
        }

        // The bound is only ever raised here, never lowered.
        let magnitude = weight.abs();
        let bound = self.max_weights.entry(term_id).or_insert(0.0);
        if magnitude > *bound {
            *bound = magnitude;
        }
    }

    first_time
}
/// Merges `updates` into `entries`, which is expected to be ordered by `doc_id`.
///
/// `updates` may be unordered and may repeat a doc id; for a repeated id the
/// occurrence that came last in `updates` is kept. When an id exists on both
/// sides the update replaces the existing posting. An empty `updates` leaves
/// `entries` untouched.
#[allow(dead_code)]
fn merge_batch_postings(entries: &mut Vec<PostingEntry>, mut updates: Vec<PostingEntry>) {
    if updates.is_empty() {
        return;
    }

    // Stable sort keeps equal ids in arrival order. Reversing puts the latest
    // arrival first within each id, so `dedup_by_key` (which keeps the first of
    // a run) retains it; the second reverse restores ascending order.
    updates.sort_by_key(|posting| posting.doc_id);
    updates.reverse();
    updates.dedup_by_key(|posting| posting.doc_id);
    updates.reverse();

    let current = std::mem::take(entries);
    let mut out = Vec::with_capacity(current.len() + updates.len());
    let mut lhs = current.into_iter().peekable();
    let mut rhs = updates.into_iter().peekable();

    loop {
        let order = match (lhs.peek(), rhs.peek()) {
            (Some(old), Some(new)) => old.doc_id.cmp(&new.doc_id),
            _ => break,
        };
        match order {
            std::cmp::Ordering::Less => out.extend(lhs.next()),
            std::cmp::Ordering::Greater => out.extend(rhs.next()),
            std::cmp::Ordering::Equal => {
                // Same document on both sides: discard the old posting.
                lhs.next();
                out.extend(rhs.next());
            }
        }
    }

    // At most one of the two sides still has elements.
    out.extend(lhs);
    out.extend(rhs);
    *entries = out;
}
/// Physically removes `point_id` from this segment.
///
/// The id is dropped from `doc_set` and, when it was tracked, `doc_count` is
/// decremented so the counter keeps matching the set (it feeds the freeze
/// threshold and the `doc_count` of the frozen segment). Every posting of the
/// document is removed; a term whose list becomes empty is dropped together
/// with its max-weight entry, and the max weight of every other affected
/// term is recomputed from the postings that remain.
///
/// Returns `true` when the document was tracked by this segment or at least
/// one posting was removed. Documents that were inserted with an empty
/// vector (tracked, but without postings) are therefore reported as well.
fn delete(&mut self, point_id: u64) -> bool {
    let was_tracked = self.doc_set.remove(&point_id);
    if was_tracked {
        self.doc_count = self.doc_count.saturating_sub(1);
    }

    let mut any_removed = false;
    // Borrow the sibling field separately so the closure below can update it
    // while `postings` is being traversed.
    let max_weights = &mut self.max_weights;
    self.postings.retain(|term_id, entries| {
        let before = entries.len();
        entries.retain(|e| e.doc_id != point_id);
        if entries.len() == before {
            // Term not affected by this delete.
            return true;
        }
        any_removed = true;
        if entries.is_empty() {
            max_weights.remove(term_id);
            return false;
        }
        // The removed posting may have carried the maximum: recompute it.
        let max_w = entries
            .iter()
            .map(|e| e.weight.abs())
            .fold(0.0_f32, f32::max);
        max_weights.insert(*term_id, max_w);
        true
    });

    was_tracked || any_removed
}
}
/// Segment produced by freezing a mutable segment (or built directly via
/// `FrozenSegment::new`). Within this file its posting lists are not modified
/// after construction; deletions are recorded in `tombstones` instead.
#[allow(dead_code)]
pub(crate) struct FrozenSegment {
/// Per term id: the posting list and the max absolute weight stored with it.
pub(crate) postings: FxHashMap<u32, (Vec<PostingEntry>, f32)>,
/// Ids of documents deleted after the segment was built; readers skip them.
tombstones: FxHashSet<u64>,
/// Document count recorded when the segment was built.
pub(crate) doc_count: usize,
}
impl FrozenSegment {
/// Wraps prebuilt `postings` and a `doc_count` into a segment that starts
/// with no tombstones.
#[allow(dead_code)]
pub(crate) fn new(
    postings: FxHashMap<u32, (Vec<PostingEntry>, f32)>,
    doc_count: usize,
) -> Self {
    let tombstones = FxHashSet::default();
    Self {
        postings,
        tombstones,
        doc_count,
    }
}
/// Reports whether `point_id` has at least one posting in this segment and
/// has not been tombstoned.
///
/// Each posting list is probed with a binary search, which relies on the
/// lists being ordered by `doc_id`.
fn contains_live(&self, point_id: u64) -> bool {
    let is_tombstoned = self.tombstones.contains(&point_id);
    !is_tombstoned
        && self.postings.values().any(|(list, _)| {
            list.binary_search_by_key(&point_id, |e| e.doc_id).is_ok()
        })
}
}
/// Sparse inverted index composed of one mutable segment and a list of
/// frozen segments, each behind its own read-write lock.
pub struct SparseInvertedIndex {
/// Segment that receives inserts; swapped for a fresh one on freeze.
mutable: RwLock<MutableSegment>,
/// Frozen segments in the order they were frozen (newest last).
frozen: RwLock<Vec<FrozenSegment>>,
/// Index-wide document counter maintained by the insert and delete paths.
doc_count: AtomicU64,
}
impl Default for SparseInvertedIndex {
/// Returns an empty index; delegates to `SparseInvertedIndex::new`.
fn default() -> Self {
Self::new()
}
}
impl SparseInvertedIndex {
/// Creates an empty index: a fresh mutable segment, no frozen segments and
/// a document counter of zero.
#[must_use]
pub fn new() -> Self {
    let mutable = RwLock::new(MutableSegment::new());
    let frozen = RwLock::new(Vec::new());
    let doc_count = AtomicU64::new(0);
    Self {
        mutable,
        frozen,
        doc_count,
    }
}
/// Inserts or updates `vector` for `point_id` in the mutable segment.
///
/// The index-wide counter is bumped when the mutable segment reports the id
/// as new; frozen segments are not consulted for that decision. Afterwards
/// the segment is frozen if its document count has reached
/// `FREEZE_THRESHOLD`. The mutable write lock is held for the whole call.
pub fn insert(&self, point_id: u64, vector: &SparseVector) {
    let mut guard = self.mutable.write();

    if guard.insert(point_id, vector) {
        self.doc_count.fetch_add(1, Ordering::Relaxed);
    }

    let reached_threshold = guard.doc_count >= FREEZE_THRESHOLD;
    if reached_threshold {
        self.freeze_inner(&mut guard);
    }
}
/// Inserts a chunk of documents under a single acquisition of the mutable
/// write lock.
///
/// The per-term buffers are assembled before the lock is taken; under the
/// lock the doc ids and postings are merged into the mutable segment, the
/// index-wide counter grows by the number of ids that were new to that
/// segment, and the segment is frozen if it has reached `FREEZE_THRESHOLD`.
/// An empty `docs` slice is a no-op.
#[allow(dead_code)]
pub(crate) fn insert_batch_chunk(&self, docs: &[(u64, SparseVector)]) {
    if docs.is_empty() {
        return;
    }

    let (postings, max_weights, doc_ids) = Self::build_batch_buffers(docs);

    let mut guard = self.mutable.write();
    match Self::merge_doc_ids(&mut guard, doc_ids) {
        0 => {}
        added => {
            self.doc_count.fetch_add(added, Ordering::Relaxed);
        }
    }
    Self::merge_batch_into_segment(&mut guard, postings, &max_weights);

    let reached_threshold = guard.doc_count >= FREEZE_THRESHOLD;
    if reached_threshold {
        self.freeze_inner(&mut guard);
    }
}
/// Groups a batch of documents by term without touching the index.
///
/// Returns, in order: the postings of the batch per term id (in input
/// order, duplicates included), the largest absolute weight seen per term
/// id, and the set of doc ids contained in the batch.
#[allow(clippy::type_complexity)]
fn build_batch_buffers(
    docs: &[(u64, SparseVector)],
) -> (
    FxHashMap<u32, Vec<PostingEntry>>,
    FxHashMap<u32, f32>,
    FxHashSet<u64>,
) {
    let mut postings_by_term: FxHashMap<u32, Vec<PostingEntry>> = FxHashMap::default();
    let mut peak_by_term: FxHashMap<u32, f32> = FxHashMap::default();
    let doc_ids: FxHashSet<u64> = docs.iter().map(|(id, _)| *id).collect();

    for (point_id, vector) in docs {
        let doc_id = *point_id;
        let pairs = vector.indices.iter().copied().zip(vector.values.iter().copied());
        for (term_id, weight) in pairs {
            postings_by_term
                .entry(term_id)
                .or_default()
                .push(PostingEntry { doc_id, weight });

            let magnitude = weight.abs();
            let peak = peak_by_term.entry(term_id).or_insert(0.0);
            if magnitude > *peak {
                *peak = magnitude;
            }
        }
    }

    (postings_by_term, peak_by_term, doc_ids)
}
/// Adds the batch's doc ids to the segment's `doc_set`, grows the segment's
/// `doc_count` by the number of ids that were not present before, and
/// returns that number.
fn merge_doc_ids(seg: &mut MutableSegment, batch_doc_ids: FxHashSet<u64>) -> u64 {
    // The batch ids are unique (they come from a set), so the growth of
    // `doc_set` equals the number of successful inserts.
    let size_before = seg.doc_set.len();
    seg.doc_set.extend(batch_doc_ids);
    let added = seg.doc_set.len() - size_before;
    seg.doc_count += added;
    added as u64
}
/// Folds per-term batch postings into the segment and raises the segment's
/// per-term max weight to the batch maximum where the batch one is larger.
fn merge_batch_into_segment(
    seg: &mut MutableSegment,
    batch_postings: FxHashMap<u32, Vec<PostingEntry>>,
    batch_max_weights: &FxHashMap<u32, f32>,
) {
    for (term_id, updates) in batch_postings {
        MutableSegment::merge_batch_postings(seg.postings.entry(term_id).or_default(), updates);

        let Some(&batch_peak) = batch_max_weights.get(&term_id) else {
            continue;
        };
        let peak = seg.max_weights.entry(term_id).or_insert(0.0);
        if batch_peak > *peak {
            *peak = batch_peak;
        }
    }
}
fn freeze_inner(&self, seg: &mut MutableSegment) {
let old = std::mem::replace(seg, MutableSegment::new());
let mut frozen_postings = FxHashMap::default();
for (term_id, entries) in old.postings {
let max_w = old.max_weights.get(&term_id).copied().unwrap_or(0.0);
frozen_postings.insert(term_id, (entries, max_w));
}
let frozen_seg = FrozenSegment {
postings: frozen_postings,
tombstones: FxHashSet::default(),
doc_count: old.doc_count,
};
let mut frozen_vec = self.frozen.write();
frozen_vec.push(frozen_seg);
}
/// Deletes `point_id` from the index.
///
/// The document is removed physically from the mutable segment and
/// tombstoned in every frozen segment where it is still live. The
/// index-wide counter is decremented once if either step found it. The
/// mutable write lock is acquired first, the frozen write lock second, and
/// both are held until the function returns.
pub fn delete(&self, point_id: u64) {
    let mut mutable_guard = self.mutable.write();
    let removed_from_mutable = mutable_guard.delete(point_id);

    let mut frozen_guard = self.frozen.write();
    let mut tombstoned_any = false;
    for segment in frozen_guard.iter_mut() {
        if !segment.contains_live(point_id) {
            continue;
        }
        segment.tombstones.insert(point_id);
        tombstoned_any = true;
    }

    if removed_from_mutable || tombstoned_any {
        self.doc_count.fetch_sub(1, Ordering::Relaxed);
    }
}
/// Returns the current value of the index-wide document counter
/// (relaxed atomic load).
#[must_use]
pub fn doc_count(&self) -> u64 {
self.doc_count.load(Ordering::Relaxed)
}
/// Counts the postings stored for `term_id`: non-tombstoned postings of
/// every frozen segment plus all postings of the mutable segment.
///
/// The frozen read lock is released before the mutable read lock is taken,
/// so the two parts are not read as one atomic snapshot.
#[must_use]
pub fn posting_count(&self, term_id: u32) -> usize {
    let frozen_total: usize = self
        .frozen
        .read()
        .iter()
        .filter_map(|segment| {
            segment.postings.get(&term_id).map(|(list, _)| {
                list.iter()
                    .filter(|e| !segment.tombstones.contains(&e.doc_id))
                    .count()
            })
        })
        .sum();

    let mutable_total = self
        .mutable
        .read()
        .postings
        .get(&term_id)
        .map_or(0, Vec::len);

    frozen_total + mutable_total
}
/// Collects the postings of `term_id` from all segments, ordered by
/// `doc_id`.
///
/// Tombstoned postings of frozen segments are skipped. Entries are not
/// de-duplicated: a document stored in more than one segment appears once
/// per segment.
#[must_use]
pub fn get_all_postings(&self, term_id: u32) -> Vec<PostingEntry> {
    let mut collected: Vec<PostingEntry> = Vec::new();

    {
        let segments = self.frozen.read();
        for segment in segments.iter() {
            let Some((list, _)) = segment.postings.get(&term_id) else {
                continue;
            };
            collected.extend(
                list.iter()
                    .filter(|e| !segment.tombstones.contains(&e.doc_id))
                    .copied(),
            );
        }
    }

    if let Some(list) = self.mutable.read().postings.get(&term_id) {
        collected.extend_from_slice(list);
    }

    collected.sort_by_key(|e| e.doc_id);
    collected
}
/// Returns the largest max weight recorded for `term_id` across the frozen
/// segments and the mutable segment, or `0.0` when no segment knows the term.
///
/// Tombstones are not taken into account for the frozen values.
#[must_use]
pub fn get_global_max_weight(&self, term_id: u32) -> f32 {
    let frozen_bound = self
        .frozen
        .read()
        .iter()
        .filter_map(|segment| segment.postings.get(&term_id))
        .map(|&(_, weight)| weight)
        .fold(0.0_f32, f32::max);

    let mutable_bound = self.mutable.read().max_weights.get(&term_id).copied();
    match mutable_bound {
        Some(weight) => frozen_bound.max(weight),
        None => frozen_bound,
    }
}
fn collect_term_ids(&self) -> FxHashSet<u32> {
let mut terms: FxHashSet<u32> = FxHashSet::default();
{
let frozen_vec = self.frozen.read();
for frozen_seg in frozen_vec.iter() {
terms.extend(frozen_seg.postings.keys());
}
}
{
let seg = self.mutable.read();
terms.extend(seg.postings.keys());
}
terms
}
/// Returns the number of distinct term ids present in any segment, as
/// gathered by `collect_term_ids`.
#[must_use]
pub fn term_count(&self) -> usize {
self.collect_term_ids().len()
}
/// Builds an index whose only content is `segment`: the mutable segment is
/// empty and the index-wide counter starts at the segment's `doc_count`.
#[must_use]
#[allow(dead_code)]
pub(crate) fn from_frozen_segment(segment: FrozenSegment) -> Self {
    let initial_docs = AtomicU64::new(segment.doc_count as u64);
    let frozen = RwLock::new(vec![segment]);
    Self {
        mutable: RwLock::new(MutableSegment::new()),
        frozen,
        doc_count: initial_docs,
    }
}
/// Returns every distinct term id known to the index in ascending order.
#[must_use]
pub fn all_term_ids(&self) -> Vec<u32> {
    let mut sorted_ids = Vec::from_iter(self.collect_term_ids());
    sorted_ids.sort_unstable();
    sorted_ids
}
/// Merges the postings of all segments into one map suitable for building a
/// single compacted segment.
///
/// For every term the result holds the posting list ordered by `doc_id`
/// with one entry per document, plus the maximum absolute weight of that
/// list. Tombstoned postings of frozen segments are dropped, and terms left
/// without postings are omitted.
///
/// When a document has postings for the same term in several segments, the
/// most recent one wins: sources are gathered newest-first (mutable segment,
/// then frozen segments from the most recently frozen to the oldest), and
/// the stable sort followed by `dedup_by_key` keeps the first entry of each
/// run of equal ids.
#[must_use]
pub fn get_merged_postings_for_compaction(&self) -> FxHashMap<u32, (Vec<PostingEntry>, f32)> {
    let mut merged: FxHashMap<u32, Vec<PostingEntry>> = FxHashMap::default();

    {
        let seg = self.mutable.read();
        for (&term_id, entries) in &seg.postings {
            merged.entry(term_id).or_default().extend_from_slice(entries);
        }
    }

    {
        let frozen_vec = self.frozen.read();
        // Segments are appended on freeze, so reverse order is newest-first.
        // Walking them oldest-first would let a stale posting shadow a newer
        // one during de-duplication.
        for frozen_seg in frozen_vec.iter().rev() {
            for (&term_id, (entries, _)) in &frozen_seg.postings {
                let live = entries
                    .iter()
                    .filter(|e| !frozen_seg.tombstones.contains(&e.doc_id))
                    .copied();
                merged.entry(term_id).or_default().extend(live);
            }
        }
    }

    merged
        .into_iter()
        .filter_map(|(term_id, mut entries)| {
            // `sort_by_key` is stable, so equal ids keep the newest-first
            // order established above and `dedup_by_key` keeps the newest.
            entries.sort_by_key(|e| e.doc_id);
            entries.dedup_by_key(|e| e.doc_id);
            if entries.is_empty() {
                return None;
            }
            let max_w = entries
                .iter()
                .map(|e| e.weight.abs())
                .fold(0.0_f32, f32::max);
            Some((term_id, (entries, max_w)))
        })
        .collect()
}
/// Test-only: number of frozen segments currently held by the index.
#[cfg(test)]
fn frozen_count(&self) -> usize {
self.frozen.read().len()
}
/// Test-only: the mutable segment's own `doc_count` field.
#[cfg(test)]
fn mutable_doc_count(&self) -> usize {
self.mutable.read().doc_count
}
}
#[cfg(test)]
#[path = "inverted_index_tests.rs"]
mod tests;