use crate::structures::simd;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Read, Write};
/// Number of postings encoded per compressed block.
pub const VERTICAL_BP128_BLOCK_SIZE: usize = 128;
/// u32 lanes in one 128-bit SIMD register.
#[allow(dead_code)]
const SIMD_LANES: usize = 4;
/// 4-lane groups per 128-value block (128 / 4 = 32).
#[allow(dead_code)]
const GROUPS_PER_BLOCK: usize = VERTICAL_BP128_BLOCK_SIZE / SIMD_LANES;
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
mod neon {
    //! AArch64 NEON implementations of the bit-packing kernels.
    use super::*;
    use std::arch::aarch64::*;

    // Const-evaluated lookup table: BIT_EXPAND_LUT[b][i] holds bit i of byte b
    // widened to u32 (0 or 1), so one packed byte expands to eight integer
    // lanes with two vector loads.
    static BIT_EXPAND_LUT: [[u32; 8]; 256] = {
        let mut lut = [[0u32; 8]; 256];
        let mut byte = 0usize;
        while byte < 256 {
            let mut bit = 0;
            while bit < 8 {
                lut[byte][bit] = ((byte >> bit) & 1) as u32;
                bit += 1;
            }
            byte += 1;
        }
        lut
    };

    /// Unpacks four consecutive `bit_width`-bit values from `input` into `output`.
    ///
    /// `input` must hold at least `ceil(bit_width * 4 / 8)` bytes.
    /// NOTE(review): `1u32 << bit_width` overflows when `bit_width == 32`, so
    /// callers must keep `bit_width <= 31` here — confirm against call sites.
    #[inline]
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_4_neon(input: &[u8], bit_width: u8, output: &mut [u32; 4]) {
        if bit_width == 0 {
            *output = [0; 4];
            return;
        }
        let mask = (1u32 << bit_width) - 1;
        // Widen the packed bits into one u128 so each value is a shift + mask,
        // regardless of byte alignment.
        let mut packed_bytes = [0u8; 16];
        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        packed_bytes[..bytes_needed.min(16)].copy_from_slice(&input[..bytes_needed.min(16)]);
        let packed = u128::from_le_bytes(packed_bytes);
        let v0 = (packed & mask as u128) as u32;
        let v1 = ((packed >> bit_width) & mask as u128) as u32;
        let v2 = ((packed >> (bit_width * 2)) & mask as u128) as u32;
        let v3 = ((packed >> (bit_width * 3)) & mask as u128) as u32;
        // SAFETY: `output` is a valid [u32; 4] and NEON is enabled on this fn.
        unsafe {
            let result = vld1q_u32([v0, v1, v2, v3].as_ptr());
            vst1q_u32(output.as_mut_ptr(), result);
        }
    }

    /// In-place inclusive prefix sum over four u32 lanes (wrapping adds).
    #[inline]
    #[target_feature(enable = "neon")]
    pub unsafe fn prefix_sum_4_neon(values: &mut [u32; 4]) {
        // SAFETY: `values` is a valid [u32; 4] and NEON is enabled on this fn.
        unsafe {
            let mut v = vld1q_u32(values.as_ptr());
            // Log-step scan: shift lanes up by one and add, then by two and
            // add, yielding the inclusive prefix sum across the 4 lanes.
            let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3);
            v = vaddq_u32(v, shifted1);
            let shifted2 = vextq_u32(vdupq_n_u32(0), v, 2);
            v = vaddq_u32(v, shifted2);
            vst1q_u32(values.as_mut_ptr(), v);
        }
    }

    /// Unpacks one full 128-value block stored in vertical (bit-plane) layout:
    /// bit plane `p` occupies bytes `[p * 16, p * 16 + 16)` of `input`, and
    /// byte `b` of a plane carries bit `p` of values `b * 8 .. b * 8 + 8`.
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_block_neon(
        input: &[u8],
        bit_width: u8,
        output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
    ) {
        if bit_width == 0 {
            output.fill(0);
            return;
        }
        // Zero the output with vector stores before OR-ing bit planes in.
        // SAFETY: output has 128 u32s, written four lanes at a time.
        unsafe {
            let zero = vdupq_n_u32(0);
            for i in (0..VERTICAL_BP128_BLOCK_SIZE).step_by(4) {
                vst1q_u32(output[i..].as_mut_ptr(), zero);
            }
        }
        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;
            let bit_mask = 1u32 << bit_pos;
            // Prefetch the next 16-byte bit plane while processing this one.
            if bit_pos + 1 < bit_width as usize {
                let next_offset = (bit_pos + 1) * 16;
                // SAFETY: prfm is a prefetch hint and does not fault; the
                // address is derived from `input` within the planes the loop
                // will actually read.
                unsafe {
                    std::arch::asm!(
                        "prfm pldl1keep, [{0}]",
                        in(reg) input.as_ptr().add(next_offset),
                        options(nostack, preserves_flags)
                    );
                }
            }
            // Each 4-byte chunk of the plane feeds 32 consecutive outputs.
            for chunk in 0..4 {
                let chunk_offset = byte_offset + chunk * 4;
                let b0 = input[chunk_offset] as usize;
                let b1 = input[chunk_offset + 1] as usize;
                let b2 = input[chunk_offset + 2] as usize;
                let b3 = input[chunk_offset + 3] as usize;
                let base_int = chunk * 32;
                // SAFETY: base_int is at most 96, and the highest store below
                // covers output[base_int + 28 .. base_int + 32], which stays
                // within the 128-element array.
                unsafe {
                    let mask_vec = vdupq_n_u32(bit_mask);
                    // Expand b0 into eight 0/1 lanes, scale each by the plane
                    // mask, and OR into outputs base_int .. base_int + 8.
                    let lut0 = &BIT_EXPAND_LUT[b0];
                    let bits_0_3 = vld1q_u32(lut0.as_ptr());
                    let bits_4_7 = vld1q_u32(lut0[4..].as_ptr());
                    let shifted_0_3 = vmulq_u32(bits_0_3, mask_vec);
                    let shifted_4_7 = vmulq_u32(bits_4_7, mask_vec);
                    let cur_0_3 = vld1q_u32(output[base_int..].as_ptr());
                    let cur_4_7 = vld1q_u32(output[base_int + 4..].as_ptr());
                    vst1q_u32(
                        output[base_int..].as_mut_ptr(),
                        vorrq_u32(cur_0_3, shifted_0_3),
                    );
                    vst1q_u32(
                        output[base_int + 4..].as_mut_ptr(),
                        vorrq_u32(cur_4_7, shifted_4_7),
                    );
                    // Same for b1 → outputs base_int + 8 .. base_int + 16.
                    let lut1 = &BIT_EXPAND_LUT[b1];
                    let bits_8_11 = vld1q_u32(lut1.as_ptr());
                    let bits_12_15 = vld1q_u32(lut1[4..].as_ptr());
                    let shifted_8_11 = vmulq_u32(bits_8_11, mask_vec);
                    let shifted_12_15 = vmulq_u32(bits_12_15, mask_vec);
                    let cur_8_11 = vld1q_u32(output[base_int + 8..].as_ptr());
                    let cur_12_15 = vld1q_u32(output[base_int + 12..].as_ptr());
                    vst1q_u32(
                        output[base_int + 8..].as_mut_ptr(),
                        vorrq_u32(cur_8_11, shifted_8_11),
                    );
                    vst1q_u32(
                        output[base_int + 12..].as_mut_ptr(),
                        vorrq_u32(cur_12_15, shifted_12_15),
                    );
                    // Same for b2 → outputs base_int + 16 .. base_int + 24.
                    let lut2 = &BIT_EXPAND_LUT[b2];
                    let bits_16_19 = vld1q_u32(lut2.as_ptr());
                    let bits_20_23 = vld1q_u32(lut2[4..].as_ptr());
                    let shifted_16_19 = vmulq_u32(bits_16_19, mask_vec);
                    let shifted_20_23 = vmulq_u32(bits_20_23, mask_vec);
                    let cur_16_19 = vld1q_u32(output[base_int + 16..].as_ptr());
                    let cur_20_23 = vld1q_u32(output[base_int + 20..].as_ptr());
                    vst1q_u32(
                        output[base_int + 16..].as_mut_ptr(),
                        vorrq_u32(cur_16_19, shifted_16_19),
                    );
                    vst1q_u32(
                        output[base_int + 20..].as_mut_ptr(),
                        vorrq_u32(cur_20_23, shifted_20_23),
                    );
                    // Same for b3 → outputs base_int + 24 .. base_int + 32.
                    let lut3 = &BIT_EXPAND_LUT[b3];
                    let bits_24_27 = vld1q_u32(lut3.as_ptr());
                    let bits_28_31 = vld1q_u32(lut3[4..].as_ptr());
                    let shifted_24_27 = vmulq_u32(bits_24_27, mask_vec);
                    let shifted_28_31 = vmulq_u32(bits_28_31, mask_vec);
                    let cur_24_27 = vld1q_u32(output[base_int + 24..].as_ptr());
                    let cur_28_31 = vld1q_u32(output[base_int + 28..].as_ptr());
                    vst1q_u32(
                        output[base_int + 24..].as_mut_ptr(),
                        vorrq_u32(cur_24_27, shifted_24_27),
                    );
                    vst1q_u32(
                        output[base_int + 28..].as_mut_ptr(),
                        vorrq_u32(cur_28_31, shifted_28_31),
                    );
                }
            }
        }
    }

    /// Prefix-sums a 128-delta block in groups of four lanes, seeding the
    /// first lane with `first_val`; `carry` propagates the running total
    /// between groups (wrapping adds throughout).
    #[target_feature(enable = "neon")]
    pub unsafe fn prefix_sum_block_neon(
        deltas: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
        first_val: u32,
    ) {
        let mut carry = first_val;
        for group in 0..GROUPS_PER_BLOCK {
            let start = group * SIMD_LANES;
            let mut group_vals = [
                deltas[start],
                deltas[start + 1],
                deltas[start + 2],
                deltas[start + 3],
            ];
            group_vals[0] = group_vals[0].wrapping_add(carry);
            // SAFETY: NEON is enabled on this fn, satisfying
            // prefix_sum_4_neon's target-feature requirement.
            unsafe { prefix_sum_4_neon(&mut group_vals) };
            deltas[start..start + 4].copy_from_slice(&group_vals);
            carry = group_vals[3];
        }
    }
}
#[allow(dead_code)]
mod scalar {
    //! Portable scalar fallbacks for the bit-packing kernels.
    use super::*;

    /// Packs four `bit_width`-bit values into `output` contiguously,
    /// least-significant value first, writing `ceil(bit_width * 4 / 8)` bytes.
    /// Values must already fit in `bit_width` bits; high bits are not masked.
    #[inline]
    pub fn pack_4_scalar(values: &[u32; 4], bit_width: u8, output: &mut [u8]) {
        if bit_width == 0 {
            return;
        }
        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        let mut packed = 0u128;
        for (i, &val) in values.iter().enumerate() {
            packed |= (val as u128) << (i * bit_width as usize);
        }
        let packed_bytes = packed.to_le_bytes();
        output[..bytes_needed].copy_from_slice(&packed_bytes[..bytes_needed]);
    }

    /// Unpacks four `bit_width`-bit values previously written by
    /// [`pack_4_scalar`].
    #[inline]
    pub fn unpack_4_scalar(input: &[u8], bit_width: u8, output: &mut [u32; 4]) {
        if bit_width == 0 {
            *output = [0; 4];
            return;
        }
        // Derive the mask from u32::MAX instead of `(1u32 << bit_width) - 1`:
        // the latter overflows (debug panic / wrong wrap in release) when
        // bit_width == 32. bit_width >= 1 here, so the shift is at most 31.
        let mask = u32::MAX >> (32 - (bit_width as u32).min(32));
        let mut packed_bytes = [0u8; 16];
        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        packed_bytes[..bytes_needed.min(16)].copy_from_slice(&input[..bytes_needed.min(16)]);
        let packed = u128::from_le_bytes(packed_bytes);
        output[0] = (packed & mask as u128) as u32;
        output[1] = ((packed >> bit_width) & mask as u128) as u32;
        output[2] = ((packed >> (bit_width * 2)) & mask as u128) as u32;
        output[3] = ((packed >> (bit_width * 3)) & mask as u128) as u32;
    }

    /// In-place inclusive prefix sum over four values (wrapping adds).
    #[inline]
    pub fn prefix_sum_4_scalar(vals: &mut [u32; 4]) {
        vals[1] = vals[1].wrapping_add(vals[0]);
        vals[2] = vals[2].wrapping_add(vals[1]);
        vals[3] = vals[3].wrapping_add(vals[2]);
    }

    /// Unpacks one 128-value block from vertical (bit-plane) layout: plane `p`
    /// occupies bytes `[p * 16, p * 16 + 16)`, and byte `b` of a plane holds
    /// bit `p` of values `b * 8 .. b * 8 + 8`.
    pub fn unpack_block_scalar(
        input: &[u8],
        bit_width: u8,
        output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
    ) {
        // Clear once up front; the bit-plane loop below only ORs bits in.
        output.fill(0);
        if bit_width == 0 {
            return;
        }
        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;
            for byte_idx in 0..16 {
                let byte_val = input[byte_offset + byte_idx];
                let base_int = byte_idx * 8;
                output[base_int] |= (byte_val & 1) as u32 * (1 << bit_pos);
                output[base_int + 1] |= ((byte_val >> 1) & 1) as u32 * (1 << bit_pos);
                output[base_int + 2] |= ((byte_val >> 2) & 1) as u32 * (1 << bit_pos);
                output[base_int + 3] |= ((byte_val >> 3) & 1) as u32 * (1 << bit_pos);
                output[base_int + 4] |= ((byte_val >> 4) & 1) as u32 * (1 << bit_pos);
                output[base_int + 5] |= ((byte_val >> 5) & 1) as u32 * (1 << bit_pos);
                output[base_int + 6] |= ((byte_val >> 6) & 1) as u32 * (1 << bit_pos);
                output[base_int + 7] |= ((byte_val >> 7) & 1) as u32 * (1 << bit_pos);
            }
        }
    }

    /// Prefix-sums a 128-delta block in groups of four, seeding the first lane
    /// with `first_val` and carrying the running total between groups.
    pub fn prefix_sum_block_scalar(deltas: &mut [u32; VERTICAL_BP128_BLOCK_SIZE], first_val: u32) {
        let mut carry = first_val;
        for group in 0..GROUPS_PER_BLOCK {
            let start = group * SIMD_LANES;
            let mut group_vals = [
                deltas[start],
                deltas[start + 1],
                deltas[start + 2],
                deltas[start + 3],
            ];
            group_vals[0] = group_vals[0].wrapping_add(carry);
            prefix_sum_4_scalar(&mut group_vals);
            deltas[start..start + 4].copy_from_slice(&group_vals);
            carry = group_vals[3];
        }
    }
}
/// Packs 128 values into the vertical (bit-plane) layout: plane `p` occupies
/// 16 bytes, and byte `b` of a plane holds bit `p` of values `8b .. 8b + 8`.
/// Appends exactly `16 * bit_width` bytes to `output`; appends nothing when
/// `bit_width` is 0.
pub fn pack_vertical(
    values: &[u32; VERTICAL_BP128_BLOCK_SIZE],
    bit_width: u8,
    output: &mut Vec<u8>,
) {
    if bit_width == 0 {
        return;
    }
    let start = output.len();
    output.resize(start + 16 * bit_width as usize, 0);
    for bit_pos in 0..bit_width as usize {
        let plane_start = start + bit_pos * 16;
        for byte_idx in 0..16 {
            // Gather bit `bit_pos` from the eight values this byte covers.
            let group = byte_idx * 8;
            let mut packed = 0u8;
            for lane in 0..8 {
                packed |= (((values[group + lane] >> bit_pos) & 1) as u8) << lane;
            }
            output[plane_start + byte_idx] = packed;
        }
    }
}
/// Unpacks a 128-value block from the vertical (bit-plane) layout, dispatching
/// to the best implementation for the target architecture. `input` must hold
/// at least `16 * bit_width` bytes.
pub fn unpack_vertical(input: &[u8], bit_width: u8, output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE]) {
    if bit_width == 0 {
        output.fill(0);
        return;
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is a mandatory feature on aarch64.
        unsafe { unpack_vertical_neon(input, bit_width, output) }
    }
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("sse2") {
            // SAFETY: guarded by the runtime SSE2 feature check above.
            unsafe { unpack_vertical_sse(input, bit_width, output) }
        } else {
            unpack_vertical_scalar(input, bit_width, output)
        }
    }
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        unpack_vertical_scalar(input, bit_width, output)
    }
}
/// Portable fallback for [`unpack_vertical`]: reconstructs 128 values from the
/// bit-plane layout by OR-ing one plane at a time into the output.
#[inline]
fn unpack_vertical_scalar(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    output.fill(0);
    for bit_pos in 0..bit_width as usize {
        let bit_mask = 1u32 << bit_pos;
        let plane_start = bit_pos * 16;
        for byte_idx in 0..16 {
            // Byte `byte_idx` of this plane carries bit `bit_pos` of the
            // eight values starting at `byte_idx * 8`.
            let plane_byte = input[plane_start + byte_idx];
            let base = byte_idx * 8;
            for bit in 0..8 {
                if (plane_byte >> bit) & 1 != 0 {
                    output[base + bit] |= bit_mask;
                }
            }
        }
    }
}
/// NEON-assisted unpack of a full 128-value vertical block (aarch64).
///
/// The vector unit is used to zero the output and to load each 16-byte bit
/// plane; the bit expansion itself is done per byte in scalar code.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn unpack_vertical_neon(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    use std::arch::aarch64::*;
    // SAFETY: all vector stores target output (128 u32s, four lanes at a
    // time); each vld1q_u8 reads one 16-byte plane from `input`, which the
    // caller must size to at least 16 * bit_width bytes.
    unsafe {
        let zero = vdupq_n_u32(0);
        for i in (0..VERTICAL_BP128_BLOCK_SIZE).step_by(4) {
            vst1q_u32(output[i..].as_mut_ptr(), zero);
        }
        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;
            let bit_mask = 1u32 << bit_pos;
            let bytes = vld1q_u8(input.as_ptr().add(byte_offset));
            let mut byte_array = [0u8; 16];
            vst1q_u8(byte_array.as_mut_ptr(), bytes);
            for (byte_idx, &byte_val) in byte_array.iter().enumerate() {
                // Multiply by the plane mask turns each 0/1 bit into the
                // plane's contribution, OR-ed into the outputs.
                let base_int = byte_idx * 8;
                output[base_int] |= ((byte_val & 0x01) as u32) * bit_mask;
                output[base_int + 1] |= (((byte_val >> 1) & 0x01) as u32) * bit_mask;
                output[base_int + 2] |= (((byte_val >> 2) & 0x01) as u32) * bit_mask;
                output[base_int + 3] |= (((byte_val >> 3) & 0x01) as u32) * bit_mask;
                output[base_int + 4] |= (((byte_val >> 4) & 0x01) as u32) * bit_mask;
                output[base_int + 5] |= (((byte_val >> 5) & 0x01) as u32) * bit_mask;
                output[base_int + 6] |= (((byte_val >> 6) & 0x01) as u32) * bit_mask;
                output[base_int + 7] |= (((byte_val >> 7) & 0x01) as u32) * bit_mask;
            }
        }
    }
}
/// SSE2-assisted unpack of a full 128-value vertical block (x86_64).
///
/// Mirrors the NEON path: vector ops zero the output and load each 16-byte
/// bit plane; bit expansion is scalar per byte.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn unpack_vertical_sse(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    use std::arch::x86_64::*;
    // SAFETY: unaligned load/store intrinsics are used throughout; stores stay
    // within output (128 u32s) and each load reads one 16-byte plane from
    // `input`, which the caller must size to at least 16 * bit_width bytes.
    unsafe {
        let zero = _mm_setzero_si128();
        for i in (0..VERTICAL_BP128_BLOCK_SIZE).step_by(4) {
            _mm_storeu_si128(output[i..].as_mut_ptr() as *mut __m128i, zero);
        }
        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;
            let bytes = _mm_loadu_si128(input.as_ptr().add(byte_offset) as *const __m128i);
            let mut byte_array = [0u8; 16];
            _mm_storeu_si128(byte_array.as_mut_ptr() as *mut __m128i, bytes);
            for (byte_idx, &byte_val) in byte_array.iter().enumerate() {
                let base_int = byte_idx * 8;
                if byte_val & 0x01 != 0 {
                    output[base_int] |= 1u32 << bit_pos;
                }
                if byte_val & 0x02 != 0 {
                    output[base_int + 1] |= 1u32 << bit_pos;
                }
                if byte_val & 0x04 != 0 {
                    output[base_int + 2] |= 1u32 << bit_pos;
                }
                if byte_val & 0x08 != 0 {
                    output[base_int + 3] |= 1u32 << bit_pos;
                }
                if byte_val & 0x10 != 0 {
                    output[base_int + 4] |= 1u32 << bit_pos;
                }
                if byte_val & 0x20 != 0 {
                    output[base_int + 5] |= 1u32 << bit_pos;
                }
                if byte_val & 0x40 != 0 {
                    output[base_int + 6] |= 1u32 << bit_pos;
                }
                if byte_val & 0x80 != 0 {
                    output[base_int + 7] |= 1u32 << bit_pos;
                }
            }
        }
    }
}
/// Packs 128 values back-to-back in little-endian bit order (horizontal
/// layout), appending `ceil(128 * bit_width / 8)` bytes to `output`.
/// Appends nothing when `bit_width` is 0.
#[allow(dead_code)]
pub fn pack_horizontal(
    values: &[u32; VERTICAL_BP128_BLOCK_SIZE],
    bit_width: u8,
    output: &mut Vec<u8>,
) {
    if bit_width == 0 {
        return;
    }
    let width = bit_width as usize;
    let start = output.len();
    output.resize(start + (VERTICAL_BP128_BLOCK_SIZE * width).div_ceil(8), 0);
    let mut abs_bit = 0usize;
    for &value in values {
        // Write this value's bits, possibly straddling several bytes.
        let mut pending = value;
        let mut left = width;
        let mut byte = start + abs_bit / 8;
        let mut offset = abs_bit % 8;
        while left > 0 {
            let take = left.min(8 - offset);
            let mask = ((1u32 << take) - 1) as u8;
            output[byte] |= ((pending as u8) & mask) << offset;
            pending >>= take;
            left -= take;
            byte += 1;
            offset = 0;
        }
        abs_bit += width;
    }
}
/// Unpacks 128 `bit_width`-bit values from the horizontal layout written by
/// [`pack_horizontal`]: values are stored back-to-back in little-endian bit
/// order, so each one is extracted by reading a 64-bit window at its starting
/// byte and shifting.
#[allow(dead_code)]
pub fn unpack_horizontal(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    if bit_width == 0 {
        output.fill(0);
        return;
    }
    let mask = (1u64 << bit_width) - 1;
    let bit_width_usize = bit_width as usize;
    let mut bit_pos = 0usize;
    for out in output.iter_mut() {
        let byte_idx = bit_pos >> 3;
        let bit_offset = bit_pos & 7;
        // Assemble a little-endian u64 window starting at byte_idx. Near the
        // end of `input` fewer than 8 bytes remain, so copy what exists into a
        // zero-padded buffer: the previous raw unaligned pointer read could
        // run past the end of the slice (e.g. bit_width = 1 reads 8 bytes
        // starting at offset 15 of a 16-byte buffer) and was also
        // endianness-dependent on big-endian targets.
        let word = if byte_idx + 8 <= input.len() {
            u64::from_le_bytes(input[byte_idx..byte_idx + 8].try_into().unwrap())
        } else {
            let mut buf = [0u8; 8];
            let avail = input.len().saturating_sub(byte_idx);
            buf[..avail].copy_from_slice(&input[byte_idx..byte_idx + avail]);
            u64::from_le_bytes(buf)
        };
        *out = ((word >> bit_offset) & mask) as u32;
        bit_pos += bit_width_usize;
    }
}
/// In-place prefix sum over a 128-delta block seeded with `first_val`,
/// dispatching to NEON on aarch64 and the scalar fallback elsewhere.
#[allow(dead_code)]
pub fn prefix_sum_128(deltas: &mut [u32; VERTICAL_BP128_BLOCK_SIZE], first_val: u32) {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is a mandatory feature on aarch64.
        unsafe { neon::prefix_sum_block_neon(deltas, first_val) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        scalar::prefix_sum_block_scalar(deltas, first_val)
    }
}
/// Decodes a delta-compressed doc-id block: `input` holds gaps stored as
/// `(delta - 1)`, and the first `count` ids are reconstructed starting from
/// `first_doc_id` (wrapping arithmetic throughout).
pub fn unpack_vertical_d1(
    input: &[u8],
    bit_width: u8,
    first_doc_id: u32,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
    count: usize,
) {
    if count == 0 {
        return;
    }
    if bit_width == 0 {
        // Every stored gap was zero, i.e. the doc ids are consecutive.
        for (offset, slot) in output.iter_mut().take(count).enumerate() {
            *slot = first_doc_id.wrapping_add(offset as u32);
        }
        return;
    }
    let mut gaps = [0u32; VERTICAL_BP128_BLOCK_SIZE];
    unpack_vertical(input, bit_width, &mut gaps);
    let mut doc = first_doc_id;
    output[0] = doc;
    // Each successor id is previous + stored gap + 1.
    for (slot, &gap) in output[1..count].iter_mut().zip(gaps.iter()) {
        doc = doc.wrapping_add(gap).wrapping_add(1);
        *slot = doc;
    }
}
/// One compressed block of up to 128 postings in vertical (bit-plane) layout,
/// plus the metadata needed for block skipping and block-max scoring.
#[derive(Debug, Clone)]
pub struct VerticalBP128Block {
    /// Bit-packed doc-id gaps, stored as (delta - 1).
    pub doc_data: Vec<u8>,
    /// Bits per packed doc-id gap.
    pub doc_bit_width: u8,
    /// Bit-packed term frequencies, stored as (tf - 1).
    pub tf_data: Vec<u8>,
    /// Bits per packed term frequency.
    pub tf_bit_width: u8,
    /// First doc id in the block (stored uncompressed).
    pub first_doc_id: u32,
    /// Last doc id in the block (allows skipping without decoding).
    pub last_doc_id: u32,
    /// Number of postings in this block (at most 128).
    pub num_docs: u16,
    /// Maximum raw term frequency in the block.
    pub max_tf: u32,
    /// BM25 upper bound for any posting in the block.
    pub max_block_score: f32,
}
impl VerticalBP128Block {
    /// Writes the block in its serialized form: fixed header fields followed
    /// by the length-prefixed doc and tf payloads (integers little-endian).
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
        writer.write_u16::<LittleEndian>(self.num_docs)?;
        writer.write_u8(self.doc_bit_width)?;
        writer.write_u8(self.tf_bit_width)?;
        writer.write_u32::<LittleEndian>(self.max_tf)?;
        writer.write_f32::<LittleEndian>(self.max_block_score)?;
        writer.write_u16::<LittleEndian>(self.doc_data.len() as u16)?;
        writer.write_all(&self.doc_data)?;
        writer.write_u16::<LittleEndian>(self.tf_data.len() as u16)?;
        writer.write_all(&self.tf_data)?;
        Ok(())
    }

    /// Reads a block previously written by [`serialize`](Self::serialize);
    /// the field order must match that method exactly.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc_id = reader.read_u32::<LittleEndian>()?;
        let last_doc_id = reader.read_u32::<LittleEndian>()?;
        let num_docs = reader.read_u16::<LittleEndian>()?;
        let doc_bit_width = reader.read_u8()?;
        let tf_bit_width = reader.read_u8()?;
        let max_tf = reader.read_u32::<LittleEndian>()?;
        let max_block_score = reader.read_f32::<LittleEndian>()?;
        let doc_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut doc_data = vec![0u8; doc_len];
        reader.read_exact(&mut doc_data)?;
        let tf_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut tf_data = vec![0u8; tf_len];
        reader.read_exact(&mut tf_data)?;
        Ok(Self {
            doc_data,
            doc_bit_width,
            tf_data,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs,
            max_tf,
            max_block_score,
        })
    }

    /// Decodes all doc ids into a freshly allocated Vec.
    pub fn decode_doc_ids(&self) -> Vec<u32> {
        let mut output = vec![0u32; self.num_docs as usize];
        self.decode_doc_ids_into(&mut output);
        output
    }

    /// Decodes doc ids into `output` and returns how many were written.
    ///
    /// `output` must hold at least `num_docs` entries; only the first
    /// `num_docs` slots are written.
    #[inline]
    pub fn decode_doc_ids_into(&self, output: &mut [u32]) -> usize {
        let count = self.num_docs as usize;
        if count == 0 {
            return 0;
        }
        if count == VERTICAL_BP128_BLOCK_SIZE && output.len() >= VERTICAL_BP128_BLOCK_SIZE {
            // Full block: decode straight into the caller's buffer.
            let out_array: &mut [u32; VERTICAL_BP128_BLOCK_SIZE] = (&mut output
                [..VERTICAL_BP128_BLOCK_SIZE])
                .try_into()
                .unwrap();
            unpack_vertical_d1(
                &self.doc_data,
                self.doc_bit_width,
                self.first_doc_id,
                out_array,
                count,
            );
        } else {
            // Partial block: the unpack routines always need a 128-slot
            // target, so decode into a stack temporary and copy the prefix.
            let mut temp = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            unpack_vertical_d1(
                &self.doc_data,
                self.doc_bit_width,
                self.first_doc_id,
                &mut temp,
                count,
            );
            output[..count].copy_from_slice(&temp[..count]);
        }
        count
    }

    /// Decodes all term frequencies into a freshly allocated Vec.
    pub fn decode_term_freqs(&self) -> Vec<u32> {
        let mut output = vec![0u32; self.num_docs as usize];
        self.decode_term_freqs_into(&mut output);
        output
    }

    /// Decodes term frequencies into `output` and returns how many were
    /// written. Stored values are biased by -1, so 1 is added back after
    /// unpacking.
    #[inline]
    pub fn decode_term_freqs_into(&self, output: &mut [u32]) -> usize {
        let count = self.num_docs as usize;
        if count == 0 {
            return 0;
        }
        if count == VERTICAL_BP128_BLOCK_SIZE && output.len() >= VERTICAL_BP128_BLOCK_SIZE {
            // Full block: decode straight into the caller's buffer.
            let out_array: &mut [u32; VERTICAL_BP128_BLOCK_SIZE] = (&mut output
                [..VERTICAL_BP128_BLOCK_SIZE])
                .try_into()
                .unwrap();
            unpack_vertical(&self.tf_data, self.tf_bit_width, out_array);
        } else {
            let mut temp = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            unpack_vertical(&self.tf_data, self.tf_bit_width, &mut temp);
            output[..count].copy_from_slice(&temp[..count]);
        }
        // Restore the +1 removed at encode time (tfs are stored as tf - 1).
        simd::add_one(output, count);
        count
    }
}
/// A posting list made of [`VerticalBP128Block`]s plus per-list metadata.
#[derive(Debug, Clone)]
pub struct VerticalBP128PostingList {
    /// Compressed blocks in ascending doc-id order.
    pub blocks: Vec<VerticalBP128Block>,
    /// Total number of postings across all blocks.
    pub doc_count: u32,
    /// Maximum block score across the list.
    pub max_score: f32,
}
impl VerticalBP128PostingList {
    /// Builds a posting list from parallel doc-id / term-frequency slices,
    /// splitting the input into 128-posting blocks.
    ///
    /// `doc_ids` must be strictly increasing: the block encoder stores gaps as
    /// `delta - 1`, so equal or decreasing ids would underflow.
    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
        assert_eq!(doc_ids.len(), term_freqs.len());
        if doc_ids.is_empty() {
            return Self {
                blocks: Vec::new(),
                doc_count: 0,
                max_score: 0.0,
            };
        }
        let mut blocks = Vec::new();
        let mut max_score = 0.0f32;
        let mut i = 0;
        while i < doc_ids.len() {
            let block_end = (i + VERTICAL_BP128_BLOCK_SIZE).min(doc_ids.len());
            let block_docs = &doc_ids[i..block_end];
            let block_tfs = &term_freqs[i..block_end];
            let block = Self::create_block(block_docs, block_tfs, idf);
            max_score = max_score.max(block.max_block_score);
            blocks.push(block);
            i = block_end;
        }
        Self {
            blocks,
            doc_count: doc_ids.len() as u32,
            max_score,
        }
    }

    /// Encodes one block: doc-id gaps as `(delta - 1)`, term frequencies as
    /// `(tf - 1)`, each packed at the minimum bit width for the block.
    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> VerticalBP128Block {
        let num_docs = doc_ids.len();
        let first_doc_id = doc_ids[0];
        let last_doc_id = *doc_ids.last().unwrap();
        let mut deltas = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        let mut max_delta = 0u32;
        for j in 1..num_docs {
            // Requires strictly increasing doc ids (gap >= 1).
            let delta = doc_ids[j] - doc_ids[j - 1] - 1;
            deltas[j - 1] = delta;
            max_delta = max_delta.max(delta);
        }
        let mut tfs = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        let mut max_tf = 0u32;
        for (j, &tf) in term_freqs.iter().enumerate() {
            tfs[j] = tf.saturating_sub(1);
            max_tf = max_tf.max(tf);
        }
        let doc_bit_width = simd::bits_needed(max_delta);
        let tf_bit_width = simd::bits_needed(max_tf.saturating_sub(1));
        let mut doc_data = Vec::new();
        pack_vertical(&deltas, doc_bit_width, &mut doc_data);
        let mut tf_data = Vec::new();
        pack_vertical(&tfs, tf_bit_width, &mut tf_data);
        let max_block_score = crate::query::bm25_upper_bound(max_tf as f32, idf);
        VerticalBP128Block {
            doc_data,
            doc_bit_width,
            tf_data,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs: num_docs as u16,
            max_tf,
            max_block_score,
        }
    }

    /// Writes the list header followed by every block's serialized form.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_f32::<LittleEndian>(self.max_score)?;
        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
        for block in &self.blocks {
            block.serialize(writer)?;
        }
        Ok(())
    }

    /// Reads a list previously written by [`serialize`](Self::serialize).
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let max_score = reader.read_f32::<LittleEndian>()?;
        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;
        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            blocks.push(VerticalBP128Block::deserialize(reader)?);
        }
        Ok(Self {
            blocks,
            doc_count,
            max_score,
        })
    }

    /// Returns a cursor positioned at the first posting.
    pub fn iterator(&self) -> VerticalBP128Iterator<'_> {
        VerticalBP128Iterator::new(self)
    }

    /// Serialized size of the list in bytes (mirrors `serialize`).
    pub fn size_bytes(&self) -> usize {
        // List header: doc_count (4) + max_score (4) + block count (4).
        let mut size = 12;
        for block in &self.blocks {
            // Fixed per-block header written by `VerticalBP128Block::serialize`:
            // first_doc_id (4) + last_doc_id (4) + num_docs (2)
            // + doc_bit_width (1) + tf_bit_width (1) + max_tf (4)
            // + max_block_score (4) + two u16 length prefixes (4) = 24 bytes
            // (previously under-counted as 22).
            size += 24 + block.doc_data.len() + block.tf_data.len();
        }
        size
    }
}
/// Cursor over a [`VerticalBP128PostingList`] that decodes one block at a
/// time into reusable scratch buffers.
pub struct VerticalBP128Iterator<'a> {
    list: &'a VerticalBP128PostingList,
    /// Index of the block currently decoded into the scratch buffers.
    current_block: usize,
    /// Number of valid entries decoded from the current block.
    current_block_len: usize,
    /// Decoded doc ids for the current block (always 128 slots).
    block_doc_ids: Vec<u32>,
    /// Decoded term frequencies for the current block (always 128 slots).
    block_term_freqs: Vec<u32>,
    /// Cursor within the decoded block.
    pos_in_block: usize,
    /// True once iteration has moved past the last posting.
    exhausted: bool,
}
impl<'a> VerticalBP128Iterator<'a> {
    /// Creates an iterator positioned at the first posting, or already
    /// exhausted when the list has no blocks.
    pub fn new(list: &'a VerticalBP128PostingList) -> Self {
        let mut iter = Self {
            list,
            current_block: 0,
            current_block_len: 0,
            block_doc_ids: vec![0u32; VERTICAL_BP128_BLOCK_SIZE],
            block_term_freqs: vec![0u32; VERTICAL_BP128_BLOCK_SIZE],
            pos_in_block: 0,
            exhausted: list.blocks.is_empty(),
        };
        if !iter.exhausted {
            iter.decode_current_block();
        }
        iter
    }

    /// Decodes block `current_block` into the scratch buffers and resets the
    /// in-block cursor to its first posting.
    #[inline]
    fn decode_current_block(&mut self) {
        let block = &self.list.blocks[self.current_block];
        self.current_block_len = block.decode_doc_ids_into(&mut self.block_doc_ids);
        block.decode_term_freqs_into(&mut self.block_term_freqs);
        self.pos_in_block = 0;
    }

    /// Current doc id, or `u32::MAX` once exhausted.
    #[inline]
    pub fn doc(&self) -> u32 {
        if self.exhausted {
            u32::MAX
        } else {
            self.block_doc_ids[self.pos_in_block]
        }
    }

    /// Current term frequency, or 0 once exhausted.
    #[inline]
    pub fn term_freq(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.block_term_freqs[self.pos_in_block]
        }
    }

    /// Moves to the next posting and returns its doc id (`u32::MAX` at the
    /// end of the list).
    #[inline]
    pub fn advance(&mut self) -> u32 {
        if self.exhausted {
            return u32::MAX;
        }
        self.pos_in_block += 1;
        if self.pos_in_block >= self.current_block_len {
            // Finished this block; decode the next one if any.
            self.current_block += 1;
            if self.current_block >= self.list.blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }
        self.doc()
    }

    /// Advances to the first posting with doc id >= `target` and returns it
    /// (`u32::MAX` when none exists). Only searches forward from the current
    /// position; seeking to an earlier target leaves the cursor in place.
    pub fn seek(&mut self, target: u32) -> u32 {
        if self.exhausted {
            return u32::MAX;
        }
        // Binary-search the remaining blocks for one whose
        // [first_doc_id, last_doc_id] range could contain `target`.
        let block_idx = self.list.blocks[self.current_block..].binary_search_by(|block| {
            if block.last_doc_id < target {
                std::cmp::Ordering::Less
            } else if block.first_doc_id > target {
                std::cmp::Ordering::Greater
            } else {
                std::cmp::Ordering::Equal
            }
        });
        let target_block = match block_idx {
            Ok(idx) => self.current_block + idx,
            // Err: no block's range contains `target`; idx is the first block
            // that starts after it (past the end means the list is done).
            Err(idx) => {
                if self.current_block + idx >= self.list.blocks.len() {
                    self.exhausted = true;
                    return u32::MAX;
                }
                self.current_block + idx
            }
        };
        if target_block != self.current_block {
            self.current_block = target_block;
            // Resets pos_in_block to 0 for the newly decoded block.
            self.decode_current_block();
        }
        // Binary-search within the decoded block, starting at the cursor so a
        // backward target cannot move the iterator backwards.
        let pos = self.block_doc_ids[self.pos_in_block..self.current_block_len]
            .binary_search(&target)
            .unwrap_or_else(|x| x);
        self.pos_in_block += pos;
        if self.pos_in_block >= self.current_block_len {
            // Target is past every doc in this block; step into the next one.
            self.current_block += 1;
            if self.current_block >= self.list.blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }
        self.doc()
    }

    /// Maximum block score over the current and all remaining blocks
    /// (an upper bound usable for WAND-style pruning).
    pub fn max_remaining_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.list.blocks[self.current_block..]
            .iter()
            .map(|b| b.max_block_score)
            .fold(0.0f32, |a, b| a.max(b))
    }

    /// Score upper bound of the current block (0.0 once exhausted).
    pub fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            0.0
        } else {
            self.list.blocks[self.current_block].max_block_score
        }
    }

    /// Maximum term frequency of the current block (0 once exhausted).
    pub fn current_block_max_tf(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.list.blocks[self.current_block].max_tf
        }
    }

    /// Linearly skips forward to the first block whose last doc id >= `target`,
    /// decodes it, and returns its `(first_doc_id, max_block_score)`; returns
    /// `None` (and marks the iterator exhausted) when no block qualifies.
    /// Note: decoding resets the in-block cursor to the block's first posting.
    pub fn skip_to_block_with_doc(&mut self, target: u32) -> Option<(u32, f32)> {
        while self.current_block < self.list.blocks.len() {
            let block = &self.list.blocks[self.current_block];
            if block.last_doc_id >= target {
                self.decode_current_block();
                return Some((block.first_doc_id, block.max_block_score));
            }
            self.current_block += 1;
        }
        self.exhausted = true;
        None
    }

    /// True once the iterator has moved past the last posting.
    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip: vertical pack then unpack reproduces the input exactly.
    #[test]
    fn test_pack_unpack_vertical() {
        let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = (i * 3) as u32;
        }
        let max_val = values.iter().max().copied().unwrap();
        let bit_width = simd::bits_needed(max_val);
        let mut packed = Vec::new();
        pack_vertical(&values, bit_width, &mut packed);
        let mut unpacked = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        unpack_vertical(&packed, bit_width, &mut unpacked);
        assert_eq!(values, unpacked);
    }

    // Round-trip for every bit width from 1 to 20.
    #[test]
    fn test_pack_unpack_vertical_various_widths() {
        for bit_width in 1..=20 {
            let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            let max_val = (1u32 << bit_width) - 1;
            for (i, v) in values.iter_mut().enumerate() {
                *v = (i as u32) % (max_val + 1);
            }
            let mut packed = Vec::new();
            pack_vertical(&values, bit_width, &mut packed);
            let mut unpacked = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            unpack_vertical(&packed, bit_width, &mut unpacked);
            assert_eq!(values, unpacked, "Failed for bit_width={}", bit_width);
        }
    }

    // Iterating a two-block list yields every (doc, tf) pair in order.
    #[test]
    fn test_simd_bp128_posting_list() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 10) + 1).collect();
        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        assert_eq!(list.doc_count, 200);
        assert_eq!(list.blocks.len(), 2);
        let mut iter = list.iterator();
        for (i, &expected_doc) in doc_ids.iter().enumerate() {
            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
            assert_eq!(iter.term_freq(), term_freqs[i], "TF mismatch at {}", i);
            if i < doc_ids.len() - 1 {
                iter.advance();
            }
        }
    }

    // seek lands on the first doc id >= target, or u32::MAX past the end.
    #[test]
    fn test_simd_bp128_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = list.iterator();
        assert_eq!(iter.seek(25), 30);
        assert_eq!(iter.seek(100), 100);
        assert_eq!(iter.seek(500), 1000);
        assert_eq!(iter.seek(3000), u32::MAX);
    }

    // Serialization round-trips to an identical posting stream.
    #[test]
    fn test_simd_bp128_serialization() {
        let doc_ids: Vec<u32> = (0..300).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..300).map(|i| (i % 5) + 1).collect();
        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.5);
        let mut buffer = Vec::new();
        list.serialize(&mut buffer).unwrap();
        let restored = VerticalBP128PostingList::deserialize(&mut &buffer[..]).unwrap();
        assert_eq!(restored.doc_count, list.doc_count);
        assert_eq!(restored.blocks.len(), list.blocks.len());
        let mut iter1 = list.iterator();
        let mut iter2 = restored.iterator();
        while iter1.doc() != u32::MAX {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
    }

    // A full block of 7-bit values packs to exactly 128 * 7 / 8 = 112 bytes.
    #[test]
    fn test_vertical_layout_size() {
        let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = i as u32;
        }
        let bit_width = simd::bits_needed(127);
        assert_eq!(bit_width, 7);
        let mut packed = Vec::new();
        pack_horizontal(&values, bit_width, &mut packed);
        let expected_bytes = (VERTICAL_BP128_BLOCK_SIZE * bit_width as usize) / 8;
        assert_eq!(expected_bytes, 112);
        assert_eq!(packed.len(), expected_bytes);
    }

    // Block-max metadata tracks the per-block maximum tf and score.
    #[test]
    fn test_simd_bp128_block_max() {
        let doc_ids: Vec<u32> = (0..500).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..500)
            .map(|i| {
                if i < 128 {
                    1
                } else if i < 256 {
                    5
                } else if i < 384 {
                    10
                } else {
                    3
                }
            })
            .collect();
        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 2.0);
        assert_eq!(list.blocks.len(), 4);
        assert_eq!(list.blocks[0].max_tf, 1);
        assert_eq!(list.blocks[1].max_tf, 5);
        assert_eq!(list.blocks[2].max_tf, 10);
        assert_eq!(list.blocks[3].max_tf, 3);
        assert!(list.blocks[2].max_block_score > list.blocks[0].max_block_score);
        assert!(list.blocks[2].max_block_score > list.blocks[1].max_block_score);
        assert!(list.blocks[2].max_block_score > list.blocks[3].max_block_score);
        assert_eq!(list.max_score, list.blocks[2].max_block_score);
        let mut iter = list.iterator();
        assert_eq!(iter.current_block_max_tf(), 1);
        iter.seek(256);
        assert_eq!(iter.current_block_max_tf(), 5);
        iter.seek(512);
        assert_eq!(iter.current_block_max_tf(), 10);
        let mut iter2 = list.iterator();
        let result = iter2.skip_to_block_with_doc(300);
        assert!(result.is_some());
        let (first_doc, score) = result.unwrap();
        assert!(first_doc <= 300);
        assert!(score > 0.0);
    }
}