use arrow::array::AsArray;
use arrow_array::{Array, LargeBinaryArray};
use super::{
CompressedPositionStorage, PostingList, PostingTailCodec,
builder::BLOCK_SIZE,
encoding::{
decode_position_stream_block, decompress_positions, decompress_posting_block,
decompress_posting_remainder,
},
};
pub enum PostingListIterator<'a> {
Plain(PlainPostingListIterator<'a>),
Compressed(Box<CompressedPostingListIterator>),
}
impl<'a> PostingListIterator<'a> {
pub fn new(posting: &'a PostingList) -> Self {
match posting {
PostingList::Plain(posting) => Self::Plain(posting.iter()),
PostingList::Compressed(posting) => {
Self::Compressed(Box::new(CompressedPostingListIterator::new(
posting.length as usize,
posting.blocks.clone(),
posting.posting_tail_codec,
posting.positions.clone(),
)))
}
}
}
}
impl<'a> Iterator for PostingListIterator<'a> {
type Item = (u64, u32, Option<Box<dyn Iterator<Item = u32> + 'a>>);
fn next(&mut self) -> Option<Self::Item> {
match self {
PostingListIterator::Plain(iter) => iter
.next()
.map(|(doc_id, freq, pos)| (doc_id, freq as u32, pos)),
PostingListIterator::Compressed(iter) => iter
.next()
.map(|(doc_id, freq, pos)| (doc_id as u64, freq, pos)),
}
}
}
pub type PlainPostingListIterator<'a> =
Box<dyn Iterator<Item = (u64, f32, Option<Box<dyn Iterator<Item = u32> + 'a>>)> + 'a>;
struct OwnedPositionsIter {
positions: Box<[u32]>,
index: usize,
}
impl OwnedPositionsIter {
fn new(positions: &[u32]) -> Self {
Self {
positions: Box::<[u32]>::from(positions),
index: 0,
}
}
}
impl Iterator for OwnedPositionsIter {
type Item = u32;
fn next(&mut self) -> Option<Self::Item> {
let position = self.positions.get(self.index).copied()?;
self.index += 1;
Some(position)
}
}
pub struct CompressedPostingListIterator {
remainder: usize,
blocks: LargeBinaryArray,
next_block_idx: usize,
posting_tail_codec: PostingTailCodec,
positions: Option<CompressedPositionStorage>,
idx: usize,
doc_ids: Vec<u32>,
frequencies: Vec<u32>,
doc_idx_in_block: usize,
decoded_positions: Vec<u32>,
position_offsets: Vec<usize>,
buffer: [u32; BLOCK_SIZE],
}
impl CompressedPostingListIterator {
pub fn new(
length: usize,
blocks: LargeBinaryArray,
posting_tail_codec: PostingTailCodec,
positions: Option<CompressedPositionStorage>,
) -> Self {
debug_assert!(length > 0, "length: {}", length);
debug_assert_eq!(
length.div_ceil(BLOCK_SIZE),
blocks.len(),
"length: {}, num_blocks: {}",
length,
blocks.len(),
);
Self {
remainder: length % BLOCK_SIZE,
blocks,
next_block_idx: 0,
posting_tail_codec,
positions,
idx: 0,
doc_ids: Vec::new(),
frequencies: Vec::new(),
doc_idx_in_block: 0,
decoded_positions: Vec::new(),
position_offsets: Vec::new(),
buffer: [0; BLOCK_SIZE],
}
}
}
impl Iterator for CompressedPostingListIterator {
type Item = (u32, u32, Option<Box<dyn Iterator<Item = u32>>>);
fn next(&mut self) -> Option<Self::Item> {
if self.doc_idx_in_block < self.doc_ids.len() {
let doc_id = self.doc_ids[self.doc_idx_in_block];
let freq = self.frequencies[self.doc_idx_in_block];
let positions = self.positions.as_ref().map(|storage| match storage {
CompressedPositionStorage::LegacyPerDoc(list) => {
let compressed = list.value(self.idx);
let positions = decompress_positions(compressed.as_binary());
Box::new(positions.into_iter()) as Box<dyn Iterator<Item = u32>>
}
CompressedPositionStorage::SharedStream(_) => {
let start = self.position_offsets[self.doc_idx_in_block];
let end = self.position_offsets[self.doc_idx_in_block + 1];
Box::new(OwnedPositionsIter::new(&self.decoded_positions[start..end]))
as Box<dyn Iterator<Item = u32>>
}
});
self.idx += 1;
self.doc_idx_in_block += 1;
return Some((doc_id, freq, positions));
}
if self.next_block_idx >= self.blocks.len() {
return None;
}
let compressed = self.blocks.value(self.next_block_idx);
self.next_block_idx += 1;
self.doc_ids.clear();
self.frequencies.clear();
if self.next_block_idx == self.blocks.len() && self.remainder > 0 {
decompress_posting_remainder(
compressed,
self.remainder,
self.posting_tail_codec,
&mut self.doc_ids,
&mut self.frequencies,
);
} else {
decompress_posting_block(
compressed,
&mut self.buffer,
&mut self.doc_ids,
&mut self.frequencies,
);
}
self.doc_idx_in_block = 0;
self.decoded_positions.clear();
self.position_offsets.clear();
if let Some(CompressedPositionStorage::SharedStream(stream)) = self.positions.as_ref() {
decode_position_stream_block(
stream.block(self.next_block_idx - 1),
self.frequencies.as_slice(),
stream.codec(),
&mut self.decoded_positions,
)
.expect("shared position stream decoding should succeed");
self.position_offsets.reserve(self.frequencies.len() + 1);
self.position_offsets.push(0);
let mut offset = 0usize;
for &frequency in &self.frequencies {
offset += frequency as usize;
self.position_offsets.push(offset);
}
}
self.next()
}
}
pub fn take_fst_keys<'f, I, S>(s: I, dst: &mut Vec<String>, max_expansion: usize)
where
I: for<'a> fst::IntoStreamer<'a, Into = S, Item = (&'a [u8], u64)>,
S: 'f + for<'a> fst::Streamer<'a, Item = (&'a [u8], u64)>,
{
let mut stream = s.into_stream();
while let Some((token, _)) = stream.next() {
dst.push(String::from_utf8_lossy(token).into_owned());
if dst.len() >= max_expansion {
break;
}
}
}