use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Read, Write};
use crate::DocId;
/// Target number of postings per encoded block of a `PositionPostingList`.
pub const POSITION_BLOCK_SIZE: usize = 128;
/// Largest token position that fits in the low 20 bits of an encoded position.
pub const MAX_TOKEN_POSITION: u32 = 0x000F_FFFF;
/// Largest element ordinal that fits in the high 12 bits of an encoded position.
pub const MAX_ELEMENT_ORDINAL: u32 = 0x0000_0FFF;
/// Packs an element ordinal and a token position into a single `u32`.
///
/// Layout: bits 20..32 carry `element_ordinal`, bits 0..20 carry `token_position`.
/// Out-of-range arguments trip a `debug_assert!`. In release builds the excess high
/// bits of either argument are silently discarded.
#[inline]
pub fn encode_position(element_ordinal: u32, token_position: u32) -> u32 {
    debug_assert!(
        element_ordinal <= MAX_ELEMENT_ORDINAL,
        "Element ordinal {} exceeds maximum {}",
        element_ordinal,
        MAX_ELEMENT_ORDINAL
    );
    debug_assert!(
        token_position <= MAX_TOKEN_POSITION,
        "Token position {} exceeds maximum {}",
        token_position,
        MAX_TOKEN_POSITION
    );
    let ordinal_bits = element_ordinal << 20;
    let token_bits = token_position & MAX_TOKEN_POSITION;
    ordinal_bits | token_bits
}
/// Returns the element ordinal held in the upper 12 bits of a packed position.
#[inline]
pub fn decode_element_ordinal(position: u32) -> u32 {
    const TOKEN_BITS: u32 = 20;
    position >> TOKEN_BITS
}

/// Returns the token position held in the lower 20 bits of a packed position.
#[inline]
pub fn decode_token_position(position: u32) -> u32 {
    const TOKEN_MASK: u32 = MAX_TOKEN_POSITION;
    position & TOKEN_MASK
}
/// One document's entry in a position posting list.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PostingWithPositions {
    /// Document the positions belong to.
    pub doc_id: DocId,
    /// Occurrence count. The list encoder stores `positions.len()` rather than this field.
    pub term_freq: u32,
    /// Positions in caller order, presumably packed with `encode_position`.
    pub positions: Vec<u32>,
}
#[derive(Debug, Clone)]
pub struct PositionPostingList {
skip_list: Vec<(DocId, DocId, u32)>,
data: Vec<u8>,
doc_count: u32,
}
impl Default for PositionPostingList {
fn default() -> Self {
Self::new()
}
}
impl PositionPostingList {
    /// Creates an empty posting list.
    pub fn new() -> Self {
        Self {
            skip_list: Vec::new(),
            data: Vec::new(),
            doc_count: 0,
        }
    }

    /// Creates an empty posting list with buffers pre-sized for about `capacity` postings.
    ///
    /// The byte-buffer reservation assumes roughly 8 encoded bytes per posting. It is only a
    /// sizing hint and does not limit how many postings can be added.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            skip_list: Vec::with_capacity(capacity / POSITION_BLOCK_SIZE + 1),
            data: Vec::with_capacity(capacity.saturating_mul(8)),
            doc_count: 0,
        }
    }

    /// Builds a posting list from postings sorted by ascending `doc_id`.
    ///
    /// Postings are grouped into blocks of [`POSITION_BLOCK_SIZE`]. Each block is a vint
    /// posting count, followed for every posting by:
    /// - its doc id (absolute for the first posting of the block, a delta otherwise);
    /// - its number of positions;
    /// - the positions themselves.
    ///
    /// All values are vints. `term_freq` is not stored; readers recompute it from the
    /// number of positions.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] in these cases:
    /// - doc ids decrease anywhere in `postings`, including across block boundaries
    ///   (this would break the skip-list binary search);
    /// - the encoded data or the posting count does not fit the format's `u32` fields.
    pub fn from_postings(postings: &[PostingWithPositions]) -> io::Result<Self> {
        let mut skip_list: Vec<(DocId, DocId, u32)> =
            Vec::with_capacity(postings.len().div_ceil(POSITION_BLOCK_SIZE));
        let mut data: Vec<u8> = Vec::new();
        let mut prev_doc_id: Option<DocId> = None;
        for block in postings.chunks(POSITION_BLOCK_SIZE) {
            let block_start =
                Self::checked_u32(data.len(), "position data exceeds u32::MAX bytes")?;
            // `chunks` never yields an empty slice, so both indexes are in bounds.
            let base_doc_id = block[0].doc_id;
            let last_doc_id = block[block.len() - 1].doc_id;
            skip_list.push((base_doc_id, last_doc_id, block_start));
            write_vint(&mut data, block.len() as u64)?;
            for (j, posting) in block.iter().enumerate() {
                let doc_id = posting.doc_id;
                let encoded_doc = match prev_doc_id {
                    Some(prev) if doc_id < prev => {
                        return Err(Self::invalid_input(
                            "postings must be sorted by ascending doc_id",
                        ));
                    }
                    // Only the first posting of a block stores an absolute doc id, so each
                    // block decodes independently of the ones before it.
                    Some(prev) if j > 0 => doc_id - prev,
                    _ => doc_id,
                };
                prev_doc_id = Some(doc_id);
                write_vint(&mut data, u64::from(encoded_doc))?;
                write_vint(&mut data, posting.positions.len() as u64)?;
                for &pos in &posting.positions {
                    write_vint(&mut data, u64::from(pos))?;
                }
            }
        }
        let doc_count = Self::checked_u32(postings.len(), "too many postings for a u32 count")?;
        Ok(Self {
            skip_list,
            data,
            doc_count,
        })
    }

    /// Appends a posting for `doc_id` with the given positions.
    ///
    /// The posting goes into the trailing block while it holds fewer than
    /// [`POSITION_BLOCK_SIZE`] postings; otherwise a new block is started. The trailing
    /// block's count header is rewritten on every append so that readers see every posting.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - `doc_id` is smaller than the last doc id already in the list (postings must be
    ///   appended in ascending doc-id order);
    /// - the encoded data grows past `u32::MAX` bytes.
    pub fn push(&mut self, doc_id: DocId, positions: Vec<u32>) {
        const VEC_WRITE: &str = "writing to a Vec<u8> cannot fail";
        if let Some(&(_, last_doc_id, _)) = self.skip_list.last() {
            assert!(
                doc_id >= last_doc_id,
                "doc_id {doc_id} pushed after {last_doc_id}; postings must be in ascending order"
            );
        }
        match self.open_tail_block() {
            Some((offset, header_len, count)) => {
                let tail = self
                    .skip_list
                    .last_mut()
                    .expect("an open tail block implies a skip-list entry");
                let delta = doc_id - tail.1;
                tail.1 = doc_id;
                // Re-encode the block's posting count. Its vint may grow by one byte (at 128),
                // which shifts only this last block's bytes, so no skip-list offset changes.
                let mut header: Vec<u8> = Vec::with_capacity(2);
                write_vint(&mut header, count + 1).expect(VEC_WRITE);
                self.data.splice(offset..offset + header_len, header);
                write_vint(&mut self.data, u64::from(delta)).expect(VEC_WRITE);
            }
            None => {
                let block_start = u32::try_from(self.data.len())
                    .expect("position data exceeds u32::MAX bytes");
                self.skip_list.push((doc_id, doc_id, block_start));
                write_vint(&mut self.data, 1).expect(VEC_WRITE);
                write_vint(&mut self.data, u64::from(doc_id)).expect(VEC_WRITE);
            }
        }
        write_vint(&mut self.data, positions.len() as u64).expect(VEC_WRITE);
        for &pos in &positions {
            write_vint(&mut self.data, u64::from(pos)).expect(VEC_WRITE);
        }
        self.doc_count += 1;
    }

    /// Number of postings (documents) in the list.
    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    /// Number of postings in the list, as a `usize`.
    pub fn len(&self) -> usize {
        self.doc_count as usize
    }

    /// Returns `true` if the list holds no postings.
    pub fn is_empty(&self) -> bool {
        self.doc_count == 0
    }

    /// Returns the positions stored for `target_doc_id`, or `None` if the doc is absent.
    ///
    /// The skip list is binary-searched for the block whose `[base, last]` range covers the
    /// target, and only that block is decoded. Malformed data also yields `None`.
    pub fn get_positions(&self, target_doc_id: DocId) -> Option<Vec<u32>> {
        let block_idx = self.find_block(target_doc_id).ok()?;
        let offset = self.skip_list[block_idx].2 as usize;
        let mut reader = self.data.get(offset..)?;
        let count = read_vint(&mut reader).ok()?;
        let mut doc_id: DocId = 0;
        for i in 0..count {
            let value = read_vint(&mut reader).ok()? as u32;
            doc_id = if i == 0 { value } else { doc_id.checked_add(value)? };
            let num_positions = read_vint(&mut reader).ok()? as usize;
            if doc_id == target_doc_id {
                // Every vint takes at least one byte, so the remaining input bounds the count;
                // this keeps a corrupt count from forcing a huge allocation.
                let mut positions = Vec::with_capacity(num_positions.min(reader.len()));
                for _ in 0..num_positions {
                    positions.push(read_vint(&mut reader).ok()? as u32);
                }
                return Some(positions);
            }
            if doc_id > target_doc_id {
                // Doc ids ascend within a block, so the target cannot appear later.
                return None;
            }
            for _ in 0..num_positions {
                read_vint(&mut reader).ok()?;
            }
        }
        None
    }

    /// Writes the list in this order, with every integer a little-endian `u32`:
    /// - the doc count;
    /// - the skip-list length;
    /// - one `(base, last, offset)` triple per block;
    /// - the data length;
    /// - the raw block bytes.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from `writer`. Returns [`io::ErrorKind::InvalidInput`] (before
    /// writing anything) if the skip list or data is too large for the `u32` length fields.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        let skip_len = Self::checked_u32(self.skip_list.len(), "too many skip-list entries")?;
        let data_len = Self::checked_u32(self.data.len(), "position data exceeds u32::MAX bytes")?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_u32::<LittleEndian>(skip_len)?;
        for &(base_doc_id, last_doc_id, offset) in &self.skip_list {
            writer.write_u32::<LittleEndian>(base_doc_id)?;
            writer.write_u32::<LittleEndian>(last_doc_id)?;
            writer.write_u32::<LittleEndian>(offset)?;
        }
        writer.write_u32::<LittleEndian>(data_len)?;
        writer.write_all(&self.data)
    }

    /// Reads a list previously written by [`PositionPostingList::serialize`].
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from `reader`, including `UnexpectedEof` on truncated input.
    /// Returns [`io::ErrorKind::InvalidData`] in either of these cases:
    /// - a skip-list entry points outside the data;
    /// - entries are not ordered by offset and doc id.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        // Counts come from the stream and may be corrupt. Cap the up-front reservation so a
        // bogus length cannot force a huge allocation before any entries are actually read.
        const MAX_PREALLOCATED_BLOCKS: usize = 1 << 16;
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let skip_count = reader.read_u32::<LittleEndian>()? as usize;
        let mut skip_list = Vec::with_capacity(skip_count.min(MAX_PREALLOCATED_BLOCKS));
        for _ in 0..skip_count {
            let base_doc_id = reader.read_u32::<LittleEndian>()?;
            let last_doc_id = reader.read_u32::<LittleEndian>()?;
            let offset = reader.read_u32::<LittleEndian>()?;
            skip_list.push((base_doc_id, last_doc_id, offset));
        }
        let data_len = reader.read_u32::<LittleEndian>()? as usize;
        // `take` + `read_to_end` grows the buffer only as bytes actually arrive.
        let mut data = Vec::new();
        reader
            .by_ref()
            .take(data_len as u64)
            .read_to_end(&mut data)?;
        if data.len() != data_len {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "position data is truncated",
            ));
        }
        // Lookups slice `data[offset..]` and binary-search the skip list, so both the offsets
        // and the doc-id ordering must be sound before the list is handed out.
        let well_formed = skip_list
            .iter()
            .all(|&(base, last, offset)| base <= last && (offset as usize) < data.len())
            && skip_list
                .windows(2)
                .all(|pair| pair[0].2 < pair[1].2 && pair[0].1 <= pair[1].0);
        if !well_formed {
            return Err(Self::invalid_data("skip list does not match position data"));
        }
        Ok(Self {
            skip_list,
            data,
            doc_count,
        })
    }

    /// Concatenates the blocks of several lists, shifting every doc id by the paired offset.
    ///
    /// Blocks are copied without full re-encoding. Only each block's count header and first
    /// (absolute) doc id are rewritten, because the remaining doc ids are deltas and are
    /// unaffected by the shift. Sources must be given in ascending order of shifted doc ids.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] in these cases:
    /// - shifted doc ids overflow `u32`;
    /// - shifted doc ids decrease between blocks;
    /// - the result exceeds the format's `u32` limits.
    ///
    /// Returns [`io::ErrorKind::InvalidData`] (or a vint decode error) if a source block is
    /// malformed.
    pub fn concatenate_blocks(sources: &[(PositionPostingList, u32)]) -> io::Result<Self> {
        let mut skip_list: Vec<(DocId, DocId, u32)> = Vec::new();
        let mut data: Vec<u8> = Vec::new();
        let mut total_docs = 0u32;
        for (source, doc_offset) in sources {
            let shift = |doc_id: DocId| {
                doc_id
                    .checked_add(*doc_offset)
                    .ok_or_else(|| Self::invalid_input("shifted doc id overflows u32"))
            };
            for (block_idx, &(base, last, src_offset)) in source.skip_list.iter().enumerate() {
                // A block runs up to the next block's offset, or to the end of the data.
                let block_end = source
                    .skip_list
                    .get(block_idx + 1)
                    .map_or(source.data.len(), |next| next.2 as usize);
                let mut reader = source
                    .data
                    .get(src_offset as usize..block_end)
                    .ok_or_else(|| Self::invalid_data("block offsets are out of range"))?;
                let new_base = shift(base)?;
                let new_last = shift(last)?;
                if let Some(&(_, prev_last, _)) = skip_list.last()
                    && new_base < prev_last
                {
                    return Err(Self::invalid_input(
                        "sources must be ordered by shifted doc id",
                    ));
                }
                let new_offset =
                    Self::checked_u32(data.len(), "position data exceeds u32::MAX bytes")?;
                let count = read_vint(&mut reader)?;
                let first_doc = read_vint(&mut reader)? as u32;
                write_vint(&mut data, count)?;
                write_vint(&mut data, u64::from(shift(first_doc)?))?;
                data.extend_from_slice(reader);
                skip_list.push((new_base, new_last, new_offset));
                total_docs = u32::try_from(count)
                    .ok()
                    .and_then(|count| total_docs.checked_add(count))
                    .ok_or_else(|| Self::invalid_input("total doc count overflows u32"))?;
            }
        }
        Ok(Self {
            skip_list,
            data,
            doc_count: total_docs,
        })
    }

    /// Returns a forward-only cursor over the postings of this list.
    pub fn iter(&self) -> PositionPostingIterator<'_> {
        PositionPostingIterator::new(self)
    }

    /// Binary-searches the skip list for the block whose `[base, last]` range holds `target`.
    ///
    /// `Err(i)` is the index of the first block that starts after `target`.
    fn find_block(&self, target: DocId) -> Result<usize, usize> {
        self.skip_list.binary_search_by(|&(base, last, _)| {
            if target < base {
                std::cmp::Ordering::Greater
            } else if target > last {
                std::cmp::Ordering::Less
            } else {
                std::cmp::Ordering::Equal
            }
        })
    }

    /// Describes the last block when it can take another posting.
    ///
    /// Returns `(byte_offset, count_header_len, posting_count)`. Returns `None` when there is
    /// no block, the block is already full, or its count header cannot be decoded.
    fn open_tail_block(&self) -> Option<(usize, usize, u64)> {
        let &(_, _, offset) = self.skip_list.last()?;
        let offset = offset as usize;
        let mut header = self.data.get(offset..)?;
        let available = header.len();
        let count = read_vint(&mut header).ok()?;
        (count < POSITION_BLOCK_SIZE as u64).then_some((offset, available - header.len(), count))
    }

    /// Converts a length or offset to the `u32` used by the format, rejecting overflow.
    fn checked_u32(value: usize, what: &'static str) -> io::Result<u32> {
        u32::try_from(value).map_err(|_| Self::invalid_input(what))
    }

    /// Builds an `InvalidInput` error with a static message.
    fn invalid_input(message: &'static str) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidInput, message)
    }

    /// Builds an `InvalidData` error with a static message.
    fn invalid_data(message: &'static str) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidData, message)
    }
}
/// Writes `value` as an unsigned varint.
///
/// Each byte carries 7 bits, least-significant group first. The high bit is set on
/// every byte except the last.
fn write_vint<W: Write>(writer: &mut W, mut value: u64) -> io::Result<()> {
    while value >= 0x80 {
        writer.write_all(&[(value as u8 & 0x7F) | 0x80])?;
        value >>= 7;
    }
    writer.write_all(&[value as u8])
}
/// Reads one unsigned varint as produced by `write_vint` (7 bits per byte, least-significant
/// group first, high bit set on all but the last byte).
///
/// Non-minimal encodings are accepted; for example `[0x81, 0x00]` decodes to 1.
///
/// # Errors
///
/// Propagates reader errors, including `UnexpectedEof` on truncated input.
/// Returns [`io::ErrorKind::InvalidData`] in either of these cases:
/// - the encoding runs past 10 bytes;
/// - the encoded value does not fit in a `u64`.
fn read_vint<R: Read>(reader: &mut R) -> io::Result<u64> {
    let mut result = 0u64;
    let mut shift = 0u32;
    loop {
        let mut buf = [0u8; 1];
        reader.read_exact(&mut buf)?;
        let byte = buf[0];
        let payload = u64::from(byte & 0x7F);
        // The tenth byte lands at shift 63, which leaves room for a single bit.
        if shift == 63 && payload > 1 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "vint overflows u64",
            ));
        }
        result |= payload << shift;
        if byte & 0x80 == 0 {
            return Ok(result);
        }
        shift += 7;
        // Without this guard a malformed stream would drive `shift` past 63 and the shift
        // above would overflow (panic in debug, garbage in release).
        if shift > 63 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "vint is longer than 10 bytes",
            ));
        }
    }
}
/// Forward-only cursor over a `PositionPostingList` that decodes one block at a time.
pub struct PositionPostingIterator<'a> {
    /// List being iterated.
    list: &'a PositionPostingList,
    /// Decoded postings of the block at index `current_block`.
    block_postings: Vec<PostingWithPositions>,
    /// Index of the block currently decoded into `block_postings`.
    current_block: usize,
    /// Index of the current posting within `block_postings`.
    position_in_block: usize,
    /// Set once the cursor has moved past the last block, or from the start for an empty list.
    exhausted: bool,
}
impl<'a> PositionPostingIterator<'a> {
pub fn new(list: &'a PositionPostingList) -> Self {
let exhausted = list.skip_list.is_empty();
let mut iter = Self {
list,
current_block: 0,
position_in_block: 0,
block_postings: Vec::new(),
exhausted,
};
if !iter.exhausted {
iter.load_block(0);
}
iter
}
fn load_block(&mut self, block_idx: usize) {
if block_idx >= self.list.skip_list.len() {
self.exhausted = true;
return;
}
self.current_block = block_idx;
self.position_in_block = 0;
let offset = self.list.skip_list[block_idx].2 as usize;
let mut reader = &self.list.data[offset..];
let count = read_vint(&mut reader).unwrap_or(0) as usize;
self.block_postings.clear();
self.block_postings.reserve(count);
let mut prev_doc_id = 0u32;
for i in 0..count {
let doc_id = if i == 0 {
read_vint(&mut reader).unwrap_or(0) as u32
} else {
let delta = read_vint(&mut reader).unwrap_or(0) as u32;
prev_doc_id + delta
};
prev_doc_id = doc_id;
let num_positions = read_vint(&mut reader).unwrap_or(0) as usize;
let mut positions = Vec::with_capacity(num_positions);
for _ in 0..num_positions {
let pos = read_vint(&mut reader).unwrap_or(0) as u32;
positions.push(pos);
}
self.block_postings.push(PostingWithPositions {
doc_id,
term_freq: num_positions as u32,
positions,
});
}
}
pub fn doc(&self) -> DocId {
if self.exhausted || self.position_in_block >= self.block_postings.len() {
u32::MAX
} else {
self.block_postings[self.position_in_block].doc_id
}
}
pub fn term_freq(&self) -> u32 {
if self.exhausted || self.position_in_block >= self.block_postings.len() {
0
} else {
self.block_postings[self.position_in_block].term_freq
}
}
pub fn positions(&self) -> &[u32] {
if self.exhausted || self.position_in_block >= self.block_postings.len() {
&[]
} else {
&self.block_postings[self.position_in_block].positions
}
}
pub fn advance(&mut self) {
if self.exhausted {
return;
}
self.position_in_block += 1;
if self.position_in_block >= self.block_postings.len() {
self.load_block(self.current_block + 1);
}
}
pub fn seek(&mut self, target: DocId) {
if self.exhausted {
return;
}
if let Some((_, last, _)) = self.list.skip_list.get(self.current_block)
&& target <= *last
{
while self.position_in_block < self.block_postings.len()
&& self.block_postings[self.position_in_block].doc_id < target
{
self.position_in_block += 1;
}
if self.position_in_block >= self.block_postings.len() {
self.load_block(self.current_block + 1);
self.seek(target); }
return;
}
let block_idx = match self.list.skip_list.binary_search_by(|&(base, last, _)| {
if target < base {
std::cmp::Ordering::Greater
} else if target > last {
std::cmp::Ordering::Less
} else {
std::cmp::Ordering::Equal
}
}) {
Ok(idx) => idx,
Err(idx) => idx, };
if block_idx >= self.list.skip_list.len() {
self.exhausted = true;
return;
}
self.load_block(block_idx);
while self.position_in_block < self.block_postings.len()
&& self.block_postings[self.position_in_block].doc_id < target
{
self.position_in_block += 1;
}
if self.position_in_block >= self.block_postings.len() {
self.load_block(self.current_block + 1);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a posting whose `term_freq` equals its number of positions.
    fn posting(doc_id: DocId, positions: Vec<u32>) -> PostingWithPositions {
        PostingWithPositions {
            doc_id,
            term_freq: positions.len() as u32,
            positions,
        }
    }

    #[test]
    fn test_position_encoding() {
        let cases = [(0, 5), (3, 100), (MAX_ELEMENT_ORDINAL, MAX_TOKEN_POSITION)];
        for (ordinal, token) in cases {
            let packed = encode_position(ordinal, token);
            assert_eq!(decode_element_ordinal(packed), ordinal);
            assert_eq!(decode_token_position(packed), token);
        }
    }

    #[test]
    fn test_position_posting_list_build() {
        let list = PositionPostingList::from_postings(&[
            posting(1, vec![encode_position(0, 0), encode_position(0, 2)]),
            posting(3, vec![encode_position(1, 0)]),
        ])
        .unwrap();
        assert_eq!(list.doc_count(), 2);
        assert_eq!(list.get_positions(1).map(|p| p.len()), Some(2));
        assert_eq!(list.get_positions(3).map(|p| p.len()), Some(1));
        for absent in [2, 99] {
            assert!(list.get_positions(absent).is_none());
        }
    }

    #[test]
    fn test_serialization_roundtrip() {
        let list = PositionPostingList::from_postings(&[
            posting(1, vec![encode_position(0, 0), encode_position(0, 5)]),
            posting(3, vec![encode_position(1, 0)]),
            posting(5, vec![encode_position(0, 10)]),
        ])
        .unwrap();
        let mut bytes = Vec::new();
        list.serialize(&mut bytes).unwrap();
        let restored =
            PositionPostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();
        assert_eq!(restored.doc_count(), list.doc_count());
        assert_eq!(
            restored.get_positions(1),
            Some(vec![encode_position(0, 0), encode_position(0, 5)])
        );
        assert_eq!(restored.get_positions(3), Some(vec![encode_position(1, 0)]));
    }

    #[test]
    fn test_binary_search_many_blocks() {
        // Even doc ids 0..598, so 300 postings span three 128-posting blocks.
        let postings: Vec<_> = (0..300u32)
            .map(|i| posting(i * 2, vec![encode_position(0, i)]))
            .collect();
        let list = PositionPostingList::from_postings(&postings).unwrap();
        assert_eq!(list.doc_count(), 300);
        assert_eq!(list.skip_list.len(), 3);
        for (doc, token) in [(0, 0), (256, 128), (598, 299)] {
            assert_eq!(list.get_positions(doc), Some(vec![encode_position(0, token)]));
        }
        for absent in [1, 257] {
            assert!(list.get_positions(absent).is_none());
        }
    }

    #[test]
    fn test_concatenate_blocks_merge() {
        let first = PositionPostingList::from_postings(&[
            posting(0, vec![0]),
            posting(1, vec![5]),
            posting(2, vec![10]),
        ])
        .unwrap();
        let second =
            PositionPostingList::from_postings(&[posting(0, vec![100]), posting(1, vec![105])])
                .unwrap();
        // The second list is shifted past the first list's three docs.
        let combined = PositionPostingList::concatenate_blocks(&[(first, 0), (second, 3)]).unwrap();
        assert_eq!(combined.doc_count(), 5);
        for doc in 0..5 {
            assert!(combined.get_positions(doc).is_some());
        }
    }

    #[test]
    fn test_iterator() {
        let list = PositionPostingList::from_postings(&[
            posting(1, vec![0, 5]),
            posting(3, vec![10]),
            posting(5, vec![15]),
        ])
        .unwrap();
        let mut iter = list.iter();
        assert_eq!(iter.doc(), 1);
        assert_eq!(iter.positions(), &[0, 5]);
        iter.advance();
        assert_eq!(iter.doc(), 3);
        iter.seek(5);
        assert_eq!(iter.doc(), 5);
        assert_eq!(iter.positions(), &[15]);
        iter.advance();
        // Past the end the iterator reports the u32::MAX sentinel.
        assert_eq!(iter.doc(), u32::MAX);
    }
}