use std::io::Write;
use std::mem::size_of;
use hashbrown::HashMap;
use rustc_hash::FxHashMap;
#[cfg(not(feature = "native"))]
use super::simple_interner::{Rodeo, Spur};
#[cfg(feature = "native")]
use lasso::{Rodeo, Spur};
#[cfg(feature = "native")]
use rayon::prelude::*;
use crate::structures::{PostingList, SSTableWriter, TermInfo};
use crate::{DocId, Result};
/// Composite map key identifying one term within one field.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub(super) struct TermKey {
    // Numeric field id; serialized as 4 little-endian bytes when building
    // sortable term-dictionary keys.
    pub field: u32,
    // Interned handle for the term string; resolved through the `Rodeo`
    // interner (lasso on native builds, the simple interner otherwise).
    pub term: Spur,
}
/// Memory-compact in-flight posting: one (document, term frequency) pair.
/// Term frequency is narrowed to u16 to halve in-memory/spill footprint;
/// larger frequencies are saturated at insert time.
#[derive(Clone, Copy)]
pub(super) struct CompactPosting {
    pub doc_id: DocId,
    pub term_freq: u16,
}
// Number of in-memory postings for a single term at which the builder
// reports it should be spilled to disk (native builds only).
#[cfg(feature = "native")]
const SPILL_THRESHOLD: usize = 16384;
/// Accumulates postings for one term during indexing.
pub(super) struct PostingListBuilder {
    // Postings still held in memory; docs are appended in arrival order.
    pub postings: Vec<CompactPosting>,
    // Count of postings already written to the spill file for this term
    // (native builds only); `len()` includes them.
    #[cfg(feature = "native")]
    pub spilled_count: u32,
}
impl PostingListBuilder {
    /// Creates an empty builder with a small initial capacity, since most
    /// terms occur in only a handful of documents.
    pub fn new() -> Self {
        Self {
            postings: Vec::with_capacity(4),
            #[cfg(feature = "native")]
            spilled_count: 0,
        }
    }

    /// Records an occurrence of the term in `doc_id`.
    ///
    /// Consecutive calls for the same document merge by accumulating the term
    /// frequency. Frequencies saturate at `u16::MAX` on both the merge and the
    /// insert path (previously the merge path used a bare `as u16` cast, which
    /// silently truncated frequencies above 65535 to their low 16 bits).
    #[inline]
    pub fn add(&mut self, doc_id: DocId, term_freq: u32) {
        // Clamp before narrowing so out-of-range values saturate instead of
        // being truncated by the cast.
        let tf = term_freq.min(u16::MAX as u32) as u16;
        if let Some(last) = self.postings.last_mut() {
            if last.doc_id == doc_id {
                last.term_freq = last.term_freq.saturating_add(tf);
                return;
            }
        }
        self.postings.push(CompactPosting {
            doc_id,
            term_freq: tf,
        });
    }

    /// Total number of postings for this term, including any already spilled
    /// to disk (native builds only).
    pub fn len(&self) -> usize {
        #[cfg(feature = "native")]
        {
            self.spilled_count as usize + self.postings.len()
        }
        #[cfg(not(feature = "native"))]
        {
            self.postings.len()
        }
    }

    /// Whether the in-memory portion has reached `SPILL_THRESHOLD` and should
    /// be flushed to the spill file.
    #[cfg(feature = "native")]
    #[inline]
    pub fn should_spill(&self) -> bool {
        self.postings.len() >= SPILL_THRESHOLD
    }
}
/// Accumulates per-document position lists for one term during indexing.
pub(super) struct PositionPostingListBuilder {
    // One entry per document, in arrival order: (doc id, token positions).
    pub postings: Vec<(DocId, Vec<u32>)>,
}
impl PositionPostingListBuilder {
pub fn new() -> Self {
Self {
postings: Vec::new(),
}
}
#[inline]
pub fn add_position(&mut self, doc_id: DocId, position: u32) {
if let Some((last_doc, positions)) = self.postings.last_mut()
&& *last_doc == doc_id
{
positions.push(position);
return;
}
let mut positions = Vec::with_capacity(4);
positions.push(position);
self.postings.push((doc_id, positions));
}
}
/// Result of serializing one term's posting list.
pub(super) enum SerializedPosting {
    // Small list embedded directly inside the TermInfo; nothing written to
    // the postings file.
    Inline(TermInfo),
    // Encoded block posting list to be appended to the postings file, plus
    // the number of documents it covers.
    External { bytes: Vec<u8>, doc_count: u32 },
}
// Per-term list of (byte offset, posting count) ranges inside the spill
// file, produced when over-threshold posting lists are flushed to disk.
#[cfg(feature = "native")]
pub(super) type SpillIndex = HashMap<TermKey, Vec<(u64, u32)>>;
/// Streams the in-memory inverted index out as a sorted term dictionary plus
/// a postings file.
///
/// Phases:
///   1. (native only) read back any postings previously spilled to disk and
///      merge them into their in-memory builders;
///   2. turn each `(field, term)` pair into a sortable byte key — 4-byte
///      little-endian field id followed by the raw term bytes — and sort all
///      entries by that key;
///   3. serialize each posting list: inline inside the `TermInfo` when small
///      enough and the term has no positions, otherwise as a block posting
///      list appended to `postings_writer`;
///   4. insert `key -> TermInfo` pairs, in key order, into an SSTable written
///      through `term_dict_writer`.
///
/// `position_offsets` maps term keys to `(offset, len)` ranges produced by
/// `build_positions_streaming`; any term present there always gets an
/// external `TermInfo` carrying those position coordinates.
pub(super) fn build_postings_streaming(
    inverted_index: HashMap<TermKey, PostingListBuilder>,
    term_interner: Rodeo,
    position_offsets: &FxHashMap<Vec<u8>, (u64, u64)>,
    term_dict_writer: &mut dyn Write,
    postings_writer: &mut dyn Write,
    // Reader over the spill file plus the per-term (offset, count) ranges
    // describing what was spilled (native builds only).
    #[cfg(feature = "native")] spill_reader: Option<(
        &mut std::io::BufReader<std::fs::File>,
        &SpillIndex,
    )>,
) -> Result<()> {
    // Phase 1 (native): fold spilled postings back into the in-memory map.
    #[cfg(feature = "native")]
    let inverted_index = {
        let mut index = inverted_index;
        if let Some((reader, spill_index)) = spill_reader {
            use byteorder::{LittleEndian, ReadBytesExt};
            use std::io::{Seek, SeekFrom};
            // Flatten every (term, range) pair and sort by file offset so the
            // spill file is read mostly sequentially rather than seeking
            // term-by-term.
            let mut all_ranges: Vec<(TermKey, u64, u32)> = Vec::new();
            for (term_key, ranges) in spill_index {
                for &(offset, count) in ranges {
                    all_ranges.push((*term_key, offset, count));
                }
            }
            all_ranges.sort_unstable_by_key(|&(_, offset, _)| offset);
            let mut per_term: HashMap<TermKey, Vec<CompactPosting>> = HashMap::new();
            for (term_key, offset, count) in all_ranges {
                reader.seek(SeekFrom::Start(offset))?;
                let buf = per_term.entry(term_key).or_default();
                // Each spilled posting is a fixed-width little-endian record:
                // u32 doc_id followed by u16 term frequency.
                for _ in 0..count {
                    let doc_id = reader.read_u32::<LittleEndian>()?;
                    let tf = reader.read_u16::<LittleEndian>()?;
                    buf.push(CompactPosting {
                        doc_id,
                        term_freq: tf,
                    });
                }
            }
            // Spilled postings were written before the still-in-memory tail,
            // so put them first: spilled ranges, then the builder's remainder.
            // NOTE(review): this assumes spill order (and offset order within
            // a term) matches ascending doc-id order — confirm against the
            // spill-writing path.
            for (term_key, mut spilled) in per_term {
                if let Some(builder) = index.get_mut(&term_key) {
                    spilled.append(&mut builder.postings);
                    builder.postings = spilled;
                    builder.spilled_count = 0;
                }
                // NOTE(review): spilled postings whose term is absent from
                // `index` are silently dropped here — presumably spilling
                // always keeps the builder in the map; verify.
            }
        }
        index
    };
    // Phase 2: build sortable byte keys (4-byte LE field id + term bytes),
    // mirroring the key format used by `build_positions_streaming`.
    let mut term_entries: Vec<(Vec<u8>, PostingListBuilder)> = inverted_index
        .into_iter()
        .map(|(term_key, posting_list)| {
            let term_str = term_interner.resolve(&term_key.term);
            let mut key = Vec::with_capacity(4 + term_str.len());
            key.extend_from_slice(&term_key.field.to_le_bytes());
            key.extend_from_slice(term_str.as_bytes());
            (key, posting_list)
        })
        .collect();
    // The interner is no longer needed; release it before the memory-heavy
    // sort/serialize phases.
    drop(term_interner);
    // SSTable insertion below requires ascending key order; sort in parallel
    // when rayon is available.
    #[cfg(feature = "native")]
    term_entries.par_sort_unstable_by(|a, b| a.0.cmp(&b.0));
    #[cfg(not(feature = "native"))]
    term_entries.sort_unstable_by(|a, b| a.0.cmp(&b.0));
    // Phase 3: serialize one term's postings. Inline into the TermInfo when
    // possible (small list, no positions); otherwise encode a block posting
    // list whose bytes get appended to the postings file in phase 4.
    let serialize_fn = |(key, posting_builder): (Vec<u8>, PostingListBuilder)| -> Result<(Vec<u8>, SerializedPosting)> {
        let has_positions = position_offsets.contains_key(&key);
        if !has_positions
            && let Some(inline) = TermInfo::try_inline_iter(
                posting_builder.postings.len(),
                posting_builder
                    .postings
                    .iter()
                    .map(|p| (p.doc_id, p.term_freq as u32)),
            )
        {
            return Ok((key, SerializedPosting::Inline(inline)));
        }
        let mut full_postings = PostingList::with_capacity(posting_builder.len());
        for p in &posting_builder.postings {
            full_postings.push(p.doc_id, p.term_freq as u32);
        }
        let mut posting_bytes = Vec::new();
        let block_list =
            crate::structures::BlockPostingList::from_posting_list(&full_postings)?;
        block_list.serialize(&mut posting_bytes)?;
        let result = SerializedPosting::External {
            bytes: posting_bytes,
            doc_count: full_postings.doc_count(),
        };
        Ok((key, result))
    };
    // Terms are independent, so serialize in parallel on native builds; the
    // fallible collect short-circuits on the first error.
    #[cfg(feature = "native")]
    let serialized: Vec<(Vec<u8>, SerializedPosting)> = term_entries
        .into_par_iter()
        .map(serialize_fn)
        .collect::<Result<Vec<_>>>()?;
    #[cfg(not(feature = "native"))]
    let serialized: Vec<(Vec<u8>, SerializedPosting)> = term_entries
        .into_iter()
        .map(serialize_fn)
        .collect::<Result<Vec<_>>>()?;
    // Phase 4 (single-threaded): append external posting blocks, tracking the
    // running offset, and record each term's TermInfo in the SSTable.
    let mut postings_offset = 0u64;
    let mut writer = SSTableWriter::<_, TermInfo>::new(term_dict_writer);
    for (key, serialized_posting) in serialized {
        let term_info = match serialized_posting {
            SerializedPosting::Inline(info) => info,
            SerializedPosting::External { bytes, doc_count } => {
                let posting_len = bytes.len() as u64;
                postings_writer.write_all(&bytes)?;
                // Attach position coordinates when the positions pass wrote a
                // range for this term.
                let info = if let Some(&(pos_offset, pos_len)) = position_offsets.get(&key) {
                    TermInfo::external_with_positions(
                        postings_offset,
                        posting_len,
                        doc_count,
                        pos_offset,
                        pos_len,
                    )
                } else {
                    TermInfo::external(postings_offset, posting_len, doc_count)
                };
                postings_offset += posting_len;
                info
            }
        };
        writer.insert(&key, &term_info)?;
    }
    // Finalize the SSTable; its return value is not needed here.
    let _ = writer.finish()?;
    Ok(())
}
pub(super) fn build_positions_streaming(
position_index: HashMap<TermKey, PositionPostingListBuilder>,
term_interner: &Rodeo,
writer: &mut dyn Write,
) -> Result<FxHashMap<Vec<u8>, (u64, u64)>> {
use crate::structures::PositionPostingList;
let mut position_offsets: FxHashMap<Vec<u8>, (u64, u64)> = FxHashMap::default();
let mut entries: Vec<(Vec<u8>, PositionPostingListBuilder)> = position_index
.into_iter()
.map(|(term_key, pos_builder)| {
let term_str = term_interner.resolve(&term_key.term);
let mut key = Vec::with_capacity(size_of::<u32>() + term_str.len());
key.extend_from_slice(&term_key.field.to_le_bytes());
key.extend_from_slice(term_str.as_bytes());
(key, pos_builder)
})
.collect();
entries.sort_by(|a, b| a.0.cmp(&b.0));
let mut current_offset = 0u64;
let mut buf = Vec::new();
for (key, pos_builder) in entries {
let mut pos_list = PositionPostingList::with_capacity(pos_builder.postings.len());
for (doc_id, positions) in pos_builder.postings {
pos_list.push(doc_id, positions);
}
buf.clear();
pos_list.serialize(&mut buf).map_err(crate::Error::Io)?;
writer.write_all(&buf)?;
position_offsets.insert(key, (current_offset, buf.len() as u64));
current_offset += buf.len() as u64;
}
Ok(position_offsets)
}