use crate::structures::simd;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Read, Write};
/// Number of `u32` values in one packed block; `pack_block` and `unpack_block`
/// operate on arrays of exactly this length.
pub const HORIZONTAL_BP128_BLOCK_SIZE: usize = 128;
// NOTE(review): not referenced in the visible part of this file; presumably a
// smaller block size used for short posting lists. Confirm against callers.
pub const SMALL_BLOCK_SIZE: usize = 32;
// NOTE(review): not referenced in the visible part of this file; presumably a
// posting-count cutoff that selects `SMALL_BLOCK_SIZE`. Confirm against callers.
pub const SMALL_BLOCK_THRESHOLD: usize = 256;
/// Appends the low `bit_width` bits of every value in `values` to `output`.
///
/// Values are laid out back to back, least-significant bit first, so value `i`
/// starts at bit `i * bit_width` of the appended region. The region is
/// `ceil(128 * bit_width / 8)` bytes long. Bits of a value above `bit_width`
/// are dropped. A `bit_width` of zero appends nothing.
pub fn pack_block(
    values: &[u32; HORIZONTAL_BP128_BLOCK_SIZE],
    bit_width: u8,
    output: &mut Vec<u8>,
) {
    let width = bit_width as usize;
    if width == 0 {
        return;
    }

    // Reserve the zeroed destination region up front; bits are OR-ed into it.
    let base = output.len();
    output.resize(base + (HORIZONTAL_BP128_BLOCK_SIZE * width).div_ceil(8), 0);
    let packed = &mut output[base..];

    for (slot, &value) in values.iter().enumerate() {
        // `cursor` is the absolute bit position inside `packed`.
        let mut cursor = slot * width;
        let end = cursor + width;
        let mut pending = value;
        while cursor < end {
            let shift = cursor % 8;
            // Take as many bits as fit in the current byte.
            let take = (8 - shift).min(end - cursor);
            let low_bits = (pending as u8) & (((1u32 << take) - 1) as u8);
            packed[cursor / 8] |= low_bits << shift;
            pending >>= take;
            cursor += take;
        }
    }
}
/// Unpacks one full block of `HORIZONTAL_BP128_BLOCK_SIZE` values from `input`.
///
/// A `bit_width` of zero fills `output` with zeros without reading `input`.
/// Widths 8, 16 and 32 are handed to the dedicated `simd` routines; every
/// other width goes through `unpack_block_generic`.
pub fn unpack_block(input: &[u8], bit_width: u8, output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
    const LEN: usize = HORIZONTAL_BP128_BLOCK_SIZE;
    match bit_width {
        0 => output.fill(0),
        8 => simd::unpack_8bit(input, output, LEN),
        16 => simd::unpack_16bit(input, output, LEN),
        32 => simd::unpack_32bit(input, output, LEN),
        width => unpack_block_generic(input, width, output),
    }
}
/// Unpacks a full block of values stored with an arbitrary `bit_width`.
///
/// Value `i` is read from bit `i * bit_width` of `input`, least-significant
/// bit first (the layout written by `pack_block`). Each value is extracted
/// from a 64-bit little-endian window starting at the byte that holds its
/// first bit; with widths up to 32 and a bit offset of at most 7, the value
/// always fits inside that window.
///
/// The window load is bounds-checked: bytes past the end of `input` read as
/// zero, so an exactly-sized (or truncated) buffer never causes an
/// out-of-bounds read. Widths of 32 or more yield the low 32 bits.
#[inline]
fn unpack_block_generic(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
) {
    /// Loads eight little-endian bytes at `byte_idx`, zero-padding past the end.
    #[inline]
    fn load_window(input: &[u8], byte_idx: usize) -> u64 {
        let mut window = [0u8; 8];
        match input.get(byte_idx..byte_idx + 8) {
            // Fast path: the whole window lies inside the buffer.
            Some(full) => window.copy_from_slice(full),
            // Tail of the buffer: copy what exists, leave the rest zero.
            None => {
                let tail = input.get(byte_idx..).unwrap_or(&[]);
                window[..tail.len()].copy_from_slice(tail);
            }
        }
        u64::from_le_bytes(window)
    }

    // Output is `u32`, so 32 mask bits are always enough; this also keeps the
    // shift below 64 for out-of-range widths.
    let mask = if bit_width >= 32 {
        u32::MAX as u64
    } else {
        (1u64 << bit_width) - 1
    };
    let stride = bit_width as usize;
    let mut bit_pos = 0usize;
    for out in output.iter_mut() {
        let byte_idx = bit_pos >> 3;
        let bit_offset = bit_pos & 7;
        let word = load_window(input, byte_idx);
        *out = ((word >> bit_offset) & mask) as u32;
        bit_pos += stride;
    }
}
/// Unpacks the first `n` values of a bit-packed buffer into `output[..n]`.
///
/// Value `i` is read from bit `i * bit_width` of `input`, least-significant
/// bit first. A `bit_width` of zero writes `n` zeros without touching `input`.
/// Entries of `output` beyond `n` are left unchanged.
///
/// The 64-bit window load is bounds-checked: bytes past the end of `input`
/// read as zero, so an exactly-sized (or truncated) buffer never causes an
/// out-of-bounds read. Widths of 32 or more yield the low 32 bits.
///
/// # Panics
///
/// Panics if `n > output.len()`.
#[inline]
pub fn unpack_block_n(input: &[u8], bit_width: u8, output: &mut [u32], n: usize) {
    /// Loads eight little-endian bytes at `byte_idx`, zero-padding past the end.
    #[inline]
    fn load_window(input: &[u8], byte_idx: usize) -> u64 {
        let mut window = [0u8; 8];
        match input.get(byte_idx..byte_idx + 8) {
            // Fast path: the whole window lies inside the buffer.
            Some(full) => window.copy_from_slice(full),
            // Tail of the buffer: copy what exists, leave the rest zero.
            None => {
                let tail = input.get(byte_idx..).unwrap_or(&[]);
                window[..tail.len()].copy_from_slice(tail);
            }
        }
        u64::from_le_bytes(window)
    }

    let dest = &mut output[..n];
    if bit_width == 0 {
        dest.fill(0);
        return;
    }

    // Output is `u32`, so 32 mask bits are always enough; this also keeps the
    // shift below 64 for out-of-range widths.
    let mask = if bit_width >= 32 {
        u32::MAX as u64
    } else {
        (1u64 << bit_width) - 1
    };
    let stride = bit_width as usize;
    let mut bit_pos = 0usize;
    for out in dest.iter_mut() {
        let byte_idx = bit_pos >> 3;
        let bit_offset = bit_pos & 7;
        let word = load_window(input, byte_idx);
        *out = ((word >> bit_offset) & mask) as u32;
        bit_pos += stride;
    }
}
/// Locates `target` in the sorted slice `block`.
///
/// Returns the index of a matching element when one exists, otherwise the
/// index at which `target` would have to be inserted to keep `block` sorted
/// (which is `block.len()` when every element is smaller).
#[inline]
pub fn binary_search_block(block: &[u32], target: u32) -> usize {
    block
        .binary_search(&target)
        .unwrap_or_else(|insert_at| insert_at)
}
/// Replaces `deltas` with its inclusive prefix sum, using wrapping addition.
///
/// Runs three passes with strides 1, 2 and 4 (the Hillis-Steele scheme). Each
/// pass walks from the highest index down so that every addition reads the
/// value its neighbour held before the pass started.
// Only exercised by the unit tests in this file, hence the lint allowance.
#[allow(dead_code)]
#[inline]
fn prefix_sum_8(deltas: &mut [u32; 8]) {
    for &stride in &[1usize, 2, 4] {
        let mut idx = 8;
        while idx > stride {
            idx -= 1;
            deltas[idx] = deltas[idx].wrapping_add(deltas[idx - stride]);
        }
    }
}
/// One bit-packed block of postings (doc ids plus term frequencies).
///
/// Blocks are built by `HorizontalBP128PostingList::create_block` and hold at
/// most `HORIZONTAL_BP128_BLOCK_SIZE` postings each.
#[derive(Debug, Clone)]
pub struct HorizontalBP128Block {
/// Bit-packed doc-id gaps; `create_block` stores `doc_ids[j] - doc_ids[j - 1] - 1`.
pub doc_deltas: Vec<u8>,
/// Bits per packed entry in `doc_deltas`.
pub doc_bit_width: u8,
/// Bit-packed term frequencies; `create_block` stores `tf - 1`.
pub term_freqs: Vec<u8>,
/// Bits per packed entry in `term_freqs`.
pub tf_bit_width: u8,
/// First doc id of the block; passed to the delta decoder as the starting value.
pub first_doc_id: u32,
/// Last doc id of the block; used by the iterator to pick blocks without decoding.
pub last_doc_id: u32,
/// Number of postings stored in the block.
pub num_docs: u16,
/// Largest term frequency seen in the block.
pub max_tf: u32,
/// `bm25_upper_bound(max_tf, idf)` as computed by `create_block`.
pub max_block_score: f32,
}
impl HorizontalBP128Block {
    /// Writes the block to `writer` in little-endian binary form.
    ///
    /// Layout: `first_doc_id: u32`, `last_doc_id: u32`, `num_docs: u16`,
    /// `doc_bit_width: u8`, `tf_bit_width: u8`, `max_tf: u32`,
    /// `max_block_score: f32`, then each payload (`doc_deltas`, `term_freqs`)
    /// prefixed by its length as a `u16`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` (before writing anything) if a payload is longer
    /// than a `u16` length prefix can express, and propagates any error from
    /// `writer`.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        /// Converts a payload length to its `u16` prefix, rejecting overflow
        /// instead of silently truncating it.
        fn length_prefix(payload: &[u8], field: &str) -> io::Result<u16> {
            if payload.len() > u16::MAX as usize {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!(
                        "{} is {} bytes, which does not fit the u16 length prefix",
                        field,
                        payload.len()
                    ),
                ));
            }
            Ok(payload.len() as u16)
        }

        let doc_deltas_len = length_prefix(&self.doc_deltas, "doc_deltas")?;
        let term_freqs_len = length_prefix(&self.term_freqs, "term_freqs")?;

        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
        writer.write_u16::<LittleEndian>(self.num_docs)?;
        writer.write_u8(self.doc_bit_width)?;
        writer.write_u8(self.tf_bit_width)?;
        writer.write_u32::<LittleEndian>(self.max_tf)?;
        writer.write_f32::<LittleEndian>(self.max_block_score)?;
        writer.write_u16::<LittleEndian>(doc_deltas_len)?;
        writer.write_all(&self.doc_deltas)?;
        writer.write_u16::<LittleEndian>(term_freqs_len)?;
        writer.write_all(&self.term_freqs)?;
        Ok(())
    }

    /// Reads a block previously written by [`serialize`](Self::serialize).
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` if either bit width exceeds 32 (packed values are
    /// `u32`), and propagates any error from `reader`, including
    /// `UnexpectedEof` on truncated input.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc_id = reader.read_u32::<LittleEndian>()?;
        let last_doc_id = reader.read_u32::<LittleEndian>()?;
        let num_docs = reader.read_u16::<LittleEndian>()?;
        let doc_bit_width = reader.read_u8()?;
        let tf_bit_width = reader.read_u8()?;
        // The widths drive shift amounts in the decoders, so corrupt values
        // must be rejected here rather than reach the unpacking code.
        if doc_bit_width > 32 || tf_bit_width > 32 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "invalid bit widths (doc: {}, tf: {}); the maximum is 32",
                    doc_bit_width, tf_bit_width
                ),
            ));
        }
        let max_tf = reader.read_u32::<LittleEndian>()?;
        let max_block_score = reader.read_f32::<LittleEndian>()?;

        let doc_deltas_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut doc_deltas = vec![0u8; doc_deltas_len];
        reader.read_exact(&mut doc_deltas)?;

        let term_freqs_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut term_freqs = vec![0u8; term_freqs_len];
        reader.read_exact(&mut term_freqs)?;

        Ok(Self {
            doc_deltas,
            doc_bit_width,
            term_freqs,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs,
            max_tf,
            max_block_score,
        })
    }

    /// Decodes the block's doc ids into a freshly allocated vector of
    /// `num_docs` entries.
    pub fn decode_doc_ids(&self) -> Vec<u32> {
        let mut output = vec![0u32; self.num_docs as usize];
        self.decode_doc_ids_into(&mut output);
        output
    }

    /// Decodes the block's doc ids into `output` and returns how many were
    /// written (`num_docs`).
    ///
    /// Decoding is delegated to `simd::unpack_delta_decode`, seeded with
    /// `first_doc_id`. An empty block returns 0 without touching `output`.
    // NOTE(review): `output` presumably has to hold at least `num_docs`
    // entries; confirm against `simd::unpack_delta_decode`.
    #[inline]
    pub fn decode_doc_ids_into(&self, output: &mut [u32]) -> usize {
        let count = self.num_docs as usize;
        if count == 0 {
            return 0;
        }
        simd::unpack_delta_decode(
            &self.doc_deltas,
            self.doc_bit_width,
            output,
            self.first_doc_id,
            count,
        );
        count
    }

    /// Decodes the block's term frequencies into a freshly allocated vector
    /// of `num_docs` entries.
    pub fn decode_term_freqs(&self) -> Vec<u32> {
        let mut output = vec![0u32; self.num_docs as usize];
        self.decode_term_freqs_into(&mut output);
        output
    }

    /// Decodes the block's term frequencies into `output` and returns how
    /// many were written (`num_docs`).
    ///
    /// The packed values are unpacked with `unpack_block_n` and then passed
    /// through `simd::add_one`, which presumably restores the `tf - 1` bias
    /// applied when the block was built. An empty block returns 0 without
    /// touching `output`.
    ///
    /// # Panics
    ///
    /// Panics if `output` is shorter than `num_docs`.
    #[inline]
    pub fn decode_term_freqs_into(&self, output: &mut [u32]) -> usize {
        let count = self.num_docs as usize;
        if count == 0 {
            return 0;
        }
        unpack_block_n(&self.term_freqs, self.tf_bit_width, output, count);
        simd::add_one(output, count);
        count
    }
}
/// A posting list stored as a sequence of bit-packed blocks.
#[derive(Debug, Clone)]
pub struct HorizontalBP128PostingList {
/// Blocks in posting order; `from_postings` puts up to
/// `HORIZONTAL_BP128_BLOCK_SIZE` postings in each.
pub blocks: Vec<HorizontalBP128Block>,
/// Total number of postings across all blocks.
pub doc_count: u32,
/// Largest `max_block_score` of any block (0.0 for an empty list).
pub max_score: f32,
}
impl HorizontalBP128PostingList {
    /// Builds a posting list from parallel slices of doc ids and term
    /// frequencies, packing `HORIZONTAL_BP128_BLOCK_SIZE` postings per block
    /// (the last block may be shorter).
    ///
    /// `doc_ids` must be strictly increasing. A term frequency of 0 cannot be
    /// represented (frequencies are stored as `tf - 1`) and is stored as 1.
    /// `idf` is forwarded to `bm25_upper_bound` for the per-block score bound.
    ///
    /// # Panics
    ///
    /// Panics if the two slices differ in length.
    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
        assert_eq!(doc_ids.len(), term_freqs.len());

        // Empty input yields no chunks, hence no blocks and a zero max score.
        let mut blocks = Vec::with_capacity(doc_ids.len().div_ceil(HORIZONTAL_BP128_BLOCK_SIZE));
        let mut max_score = 0.0f32;
        let doc_chunks = doc_ids.chunks(HORIZONTAL_BP128_BLOCK_SIZE);
        let tf_chunks = term_freqs.chunks(HORIZONTAL_BP128_BLOCK_SIZE);
        for (block_docs, block_tfs) in doc_chunks.zip(tf_chunks) {
            let block = Self::create_block(block_docs, block_tfs, idf);
            max_score = max_score.max(block.max_block_score);
            blocks.push(block);
        }

        Self {
            blocks,
            doc_count: doc_ids.len() as u32,
            max_score,
        }
    }

    /// Packs one chunk of postings (1 to `HORIZONTAL_BP128_BLOCK_SIZE`
    /// entries) into a block.
    ///
    /// Doc ids are stored as `gap - 1` relative to the previous doc id, term
    /// frequencies as `tf - 1`; both arrays are zero-padded to a full block
    /// before packing.
    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> HorizontalBP128Block {
        use crate::query::bm25_upper_bound;

        debug_assert!(doc_ids.len() <= HORIZONTAL_BP128_BLOCK_SIZE);
        let num_docs = doc_ids.len();
        let first_doc_id = doc_ids[0];
        let last_doc_id = *doc_ids.last().expect("a block holds at least one posting");

        let mut deltas = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
        let mut max_delta = 0u32;
        for (slot, pair) in deltas.iter_mut().zip(doc_ids.windows(2)) {
            debug_assert!(pair[1] > pair[0], "doc_ids must be strictly increasing");
            let gap = pair[1] - pair[0] - 1;
            *slot = gap;
            max_delta = max_delta.max(gap);
        }

        let mut tfs = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
        let mut max_tf = 0u32;
        for (slot, &raw_tf) in tfs.iter_mut().zip(term_freqs) {
            // A frequency of 0 would underflow `tf - 1`; clamp it to the
            // smallest representable frequency so the packed value, `max_tf`
            // and the score bound all agree with what the decoder returns.
            let tf = raw_tf.max(1);
            *slot = tf - 1;
            max_tf = max_tf.max(tf);
        }

        let max_block_score = bm25_upper_bound(max_tf as f32, idf);
        let doc_bit_width = simd::bits_needed(max_delta);
        let tf_bit_width = simd::bits_needed(max_tf.saturating_sub(1));

        let mut doc_deltas = Vec::new();
        pack_block(&deltas, doc_bit_width, &mut doc_deltas);
        let mut term_freqs_packed = Vec::new();
        pack_block(&tfs, tf_bit_width, &mut term_freqs_packed);

        HorizontalBP128Block {
            doc_deltas,
            doc_bit_width,
            term_freqs: term_freqs_packed,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs: num_docs as u16,
            max_tf,
            max_block_score,
        }
    }

    /// Writes the list to `writer`: `doc_count: u32`, `max_score: f32`, the
    /// block count as `u32` (all little-endian), then every block in order.
    ///
    /// # Errors
    ///
    /// Propagates any error from `writer` or from serializing a block.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_f32::<LittleEndian>(self.max_score)?;
        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
        for block in &self.blocks {
            block.serialize(writer)?;
        }
        Ok(())
    }

    /// Reads a list previously written by [`serialize`](Self::serialize).
    ///
    /// # Errors
    ///
    /// Propagates any error from `reader` or from deserializing a block,
    /// including `UnexpectedEof` on truncated input.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        // Upper bound on speculative preallocation: the block count comes from
        // untrusted bytes, so a corrupt header must not be able to request a
        // multi-gigabyte allocation before any block has been read.
        const MAX_PREALLOCATED_BLOCKS: usize = 1024;

        let doc_count = reader.read_u32::<LittleEndian>()?;
        let max_score = reader.read_f32::<LittleEndian>()?;
        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;

        let mut blocks = Vec::with_capacity(num_blocks.min(MAX_PREALLOCATED_BLOCKS));
        for _ in 0..num_blocks {
            blocks.push(HorizontalBP128Block::deserialize(reader)?);
        }

        Ok(Self {
            blocks,
            doc_count,
            max_score,
        })
    }

    /// Returns a cursor positioned on the first posting of the list.
    pub fn iterator(&self) -> HorizontalBP128Iterator<'_> {
        HorizontalBP128Iterator::new(self)
    }
}
/// Cursor over a `HorizontalBP128PostingList` that decodes one block at a time.
pub struct HorizontalBP128Iterator<'a> {
// The list being iterated.
posting_list: &'a HorizontalBP128PostingList,
// Index into `posting_list.blocks` of the block the cursor is on.
current_block: usize,
// Number of valid entries in the decode buffers, as returned by the last decode.
current_block_len: usize,
// Decoded doc ids of the current block (buffer of `HORIZONTAL_BP128_BLOCK_SIZE` entries).
block_doc_ids: Vec<u32>,
// Decoded term frequencies of the current block (same size as `block_doc_ids`).
block_term_freqs: Vec<u32>,
// Position of the cursor inside the decode buffers.
pos_in_block: usize,
// Set once the cursor has moved past the last posting; `doc()` then returns `u32::MAX`.
exhausted: bool,
}
impl<'a> HorizontalBP128Iterator<'a> {
    /// Creates a cursor on the first posting of `posting_list`, decoding the
    /// first block eagerly. A list without blocks yields an exhausted cursor.
    pub fn new(posting_list: &'a HorizontalBP128PostingList) -> Self {
        let mut iter = Self {
            posting_list,
            current_block: 0,
            current_block_len: 0,
            block_doc_ids: vec![0u32; HORIZONTAL_BP128_BLOCK_SIZE],
            block_term_freqs: vec![0u32; HORIZONTAL_BP128_BLOCK_SIZE],
            pos_in_block: 0,
            exhausted: posting_list.blocks.is_empty(),
        };
        if !iter.exhausted {
            iter.decode_current_block();
        }
        iter
    }

    /// Decodes the block at `current_block` into the buffers and rewinds the
    /// in-block position to its first posting.
    #[inline]
    fn decode_current_block(&mut self) {
        let block = &self.posting_list.blocks[self.current_block];
        self.current_block_len = block.decode_doc_ids_into(&mut self.block_doc_ids);
        block.decode_term_freqs_into(&mut self.block_term_freqs);
        self.pos_in_block = 0;
    }

    /// Returns the doc id under the cursor, or `u32::MAX` once exhausted.
    #[inline]
    pub fn doc(&self) -> u32 {
        if self.exhausted {
            u32::MAX
        } else {
            self.block_doc_ids[self.pos_in_block]
        }
    }

    /// Returns the term frequency under the cursor, or 0 once exhausted.
    #[inline]
    pub fn term_freq(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.block_term_freqs[self.pos_in_block]
        }
    }

    /// Moves to the next posting and returns its doc id, decoding the next
    /// block when the current one is used up. Returns `u32::MAX` when there is
    /// no next posting.
    #[inline]
    pub fn advance(&mut self) -> u32 {
        if self.exhausted {
            return u32::MAX;
        }
        self.pos_in_block += 1;
        if self.pos_in_block >= self.current_block_len {
            self.current_block += 1;
            if self.current_block >= self.posting_list.blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }
        self.doc()
    }

    /// Moves forward to the first posting whose doc id is at least `target`
    /// and returns that doc id, or `u32::MAX` if no such posting remains.
    ///
    /// The cursor never moves backwards: a `target` at or before the current
    /// doc id leaves it where it is. Blocks are selected by binary search on
    /// their `first_doc_id`/`last_doc_id` bounds, so only the landing block is
    /// decoded. Assumes doc ids increase strictly across the list.
    pub fn seek(&mut self, target: u32) -> u32 {
        use std::cmp::Ordering;

        if self.exhausted {
            return u32::MAX;
        }

        // Copy the `&'a` reference so block borrows are independent of `self`.
        let posting_list = self.posting_list;
        let blocks = &posting_list.blocks;

        // `Ok` is the block whose range contains `target`; `Err` is the first
        // block lying entirely after it. Both are the block to land on.
        let offset = match blocks[self.current_block..].binary_search_by(|block| {
            if block.last_doc_id < target {
                Ordering::Less
            } else if block.first_doc_id > target {
                Ordering::Greater
            } else {
                Ordering::Equal
            }
        }) {
            Ok(idx) | Err(idx) => idx,
        };
        let target_block = self.current_block + offset;
        if target_block >= blocks.len() {
            self.exhausted = true;
            return u32::MAX;
        }

        if target_block != self.current_block {
            self.current_block = target_block;
            self.decode_current_block();
        } else if self.current_block_len == 0 {
            self.decode_current_block();
        }

        // Search only from the current position onward within the block.
        let pos = binary_search_block(
            &self.block_doc_ids[self.pos_in_block..self.current_block_len],
            target,
        );
        self.pos_in_block += pos;
        if self.pos_in_block >= self.current_block_len {
            self.current_block += 1;
            if self.current_block >= blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }
        self.doc()
    }

    /// Returns the largest `max_block_score` among the current and all later
    /// blocks, or 0.0 once exhausted.
    pub fn max_remaining_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.posting_list.blocks[self.current_block..]
            .iter()
            .map(|block| block.max_block_score)
            .fold(0.0f32, f32::max)
    }

    /// Skips whole blocks until one whose `last_doc_id` is at least `target`
    /// and returns that block's `(first_doc_id, max_block_score)`.
    ///
    /// Skipped blocks are never decoded. If the cursor lands on a different
    /// block, that block is decoded and the cursor is placed on its first
    /// posting, so `doc()`, `term_freq()`, `advance()` and `seek()` keep
    /// operating on the block reported here rather than on stale buffers.
    /// Returns `None` and marks the cursor exhausted when no block qualifies,
    /// or if the cursor was already exhausted.
    pub fn skip_to_block_with_doc(&mut self, target: u32) -> Option<(u32, f32)> {
        if self.exhausted {
            return None;
        }

        let posting_list = self.posting_list;
        let starting_block = self.current_block;
        while let Some(block) = posting_list.blocks.get(self.current_block) {
            if block.last_doc_id >= target {
                if self.current_block != starting_block {
                    self.decode_current_block();
                }
                return Some((block.first_doc_id, block.max_block_score));
            }
            self.current_block += 1;
        }
        self.exhausted = true;
        None
    }

    /// Returns the score upper bound of the current block, or 0.0 once
    /// exhausted.
    pub fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            0.0
        } else {
            self.posting_list.blocks[self.current_block].max_block_score
        }
    }

    /// Returns the largest term frequency of the current block, or 0 once
    /// exhausted.
    pub fn current_block_max_tf(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.posting_list.blocks[self.current_block].max_tf
        }
    }
}
// Unit tests for packing, posting-list construction, iteration and serialization.
#[cfg(test)]
mod tests {
use super::*;
// `bits_needed` returns the minimum bit count for a value; zero needs no bits.
#[test]
fn test_bits_needed() {
assert_eq!(simd::bits_needed(0), 0);
assert_eq!(simd::bits_needed(1), 1);
assert_eq!(simd::bits_needed(2), 2);
assert_eq!(simd::bits_needed(3), 2);
assert_eq!(simd::bits_needed(255), 8);
assert_eq!(simd::bits_needed(256), 9);
}
// Round-trips a full block through `pack_block` / `unpack_block` at the
// width required by its largest value.
#[test]
fn test_pack_unpack() {
let mut values = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
for (i, value) in values.iter_mut().enumerate() {
*value = (i * 3) as u32;
}
let max_val = values.iter().max().copied().unwrap();
let bit_width = simd::bits_needed(max_val);
let mut packed = Vec::new();
pack_block(&values, bit_width, &mut packed);
let mut unpacked = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
unpack_block(&packed, bit_width, &mut unpacked);
assert_eq!(values, unpacked);
}
// 200 postings span two blocks; walking the iterator must reproduce every
// doc id and term frequency in order.
#[test]
fn test_bitpacked_posting_list() {
let doc_ids: Vec<u32> = (0..200).map(|i| i * 2).collect();
let term_freqs: Vec<u32> = (0..200).map(|i| (i % 10) + 1).collect();
let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
assert_eq!(posting_list.doc_count, 200);
assert_eq!(posting_list.blocks.len(), 2);
let mut iter = posting_list.iterator();
for (i, &expected_doc) in doc_ids.iter().enumerate() {
assert_eq!(iter.doc(), expected_doc, "Mismatch at position {}", i);
assert_eq!(iter.term_freq(), term_freqs[i]);
if i < doc_ids.len() - 1 {
iter.advance();
}
}
}
// `seek` lands on the first doc id >= target and returns `u32::MAX` past the end.
#[test]
fn test_bitpacked_seek() {
let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
let mut iter = posting_list.iterator();
assert_eq!(iter.seek(25), 30);
assert_eq!(iter.seek(100), 100);
assert_eq!(iter.seek(500), 1000);
assert_eq!(iter.seek(3000), u32::MAX);
}
// A serialized and re-read list must iterate identically to the original.
#[test]
fn test_serialization() {
let doc_ids: Vec<u32> = (0..50).map(|i| i * 3).collect();
let term_freqs: Vec<u32> = (0..50).map(|_| 1).collect();
let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.5);
let mut buffer = Vec::new();
posting_list.serialize(&mut buffer).unwrap();
let restored = HorizontalBP128PostingList::deserialize(&mut &buffer[..]).unwrap();
assert_eq!(restored.doc_count, posting_list.doc_count);
assert_eq!(restored.blocks.len(), posting_list.blocks.len());
let mut iter1 = posting_list.iterator();
let mut iter2 = restored.iterator();
while iter1.doc() != u32::MAX {
assert_eq!(iter1.doc(), iter2.doc());
assert_eq!(iter1.term_freq(), iter2.term_freq());
iter1.advance();
iter2.advance();
}
}
// Checks the scalar prefix sum, then `simd::delta_decode`, whose expected
// outputs show that each stored delta is the doc-id gap minus one.
#[test]
fn test_hillis_steele_prefix_sum() {
let mut deltas = [1u32, 2, 3, 4, 5, 6, 7, 8];
prefix_sum_8(&mut deltas);
assert_eq!(deltas, [1, 3, 6, 10, 15, 21, 28, 36]);
let deltas2 = [0u32; 16]; let mut output2 = [0u32; 16];
simd::delta_decode(&mut output2, &deltas2, 100, 8);
assert_eq!(&output2[..8], &[100, 101, 102, 103, 104, 105, 106, 107]);
let deltas3 = [1u32, 0, 2, 0, 4, 0, 0, 0];
let mut output3 = [0u32; 8];
simd::delta_decode(&mut output3, &deltas3, 10, 8);
assert_eq!(&output3[..8], &[10, 12, 13, 16, 17, 22, 23, 24]);
}
// A completely full block (128 postings) must decode back to its doc ids.
#[test]
fn test_delta_decode_large_block() {
let doc_ids: Vec<u32> = (0..128).map(|i| i * 5 + 100).collect();
let term_freqs: Vec<u32> = vec![1; 128];
let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
let decoded = posting_list.blocks[0].decode_doc_ids();
assert_eq!(decoded.len(), 128);
for (i, (&expected, &actual)) in doc_ids.iter().zip(decoded.iter()).enumerate() {
assert_eq!(expected, actual, "Mismatch at position {}", i);
}
}
}