use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Cursor, Read, Write};
use super::config::WeightQuantization;
use crate::DocId;
use crate::directories::OwnedBytes;
use crate::structures::postings::TERMINATED;
use crate::structures::simd;
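/// Default number of postings packed into one block by `from_postings`.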
pub const BLOCK_SIZE: usize = 128;
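/// Hard upper bound on postings per block; scratch buffers are sized to this.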
pub const MAX_BLOCK_SIZE: usize = 256;
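/// Fixed-size on-disk block header ([`BlockHeader::SIZE`] = 16 bytes).
///
/// Little-endian layout: `count: u16 | doc_id_bits: u8 | ordinal_bits: u8 |
/// weight_quant: u8 | reserved: u8 | reserved: u16 | first_doc_id: u32 |
/// max_weight: f32`.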
#[derive(Debug, Clone, Copy)]
pub struct BlockHeader {
pub count: u16,
pub doc_id_bits: u8,
pub ordinal_bits: u8,
pub weight_quant: WeightQuantization,
pub first_doc_id: DocId,
pub max_weight: f32,
}
impl BlockHeader {
pub const SIZE: usize = 16;
pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
w.write_u16::<LittleEndian>(self.count)?;
w.write_u8(self.doc_id_bits)?;
w.write_u8(self.ordinal_bits)?;
w.write_u8(self.weight_quant as u8)?;
        w.write_u8(0)?; // reserved
        w.write_u16::<LittleEndian>(0)?; // reserved (pads the header to 16 bytes)
w.write_u32::<LittleEndian>(self.first_doc_id)?;
w.write_f32::<LittleEndian>(self.max_weight)?;
Ok(())
}
pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
let count = r.read_u16::<LittleEndian>()?;
let doc_id_bits = r.read_u8()?;
let ordinal_bits = r.read_u8()?;
let weight_quant_byte = r.read_u8()?;
        let _ = r.read_u8()?; // reserved
        let _ = r.read_u16::<LittleEndian>()?; // reserved
let first_doc_id = r.read_u32::<LittleEndian>()?;
let max_weight = r.read_f32::<LittleEndian>()?;
        let weight_quant = WeightQuantization::from_u8(weight_quant_byte).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("invalid weight quantization tag: {weight_quant_byte}"),
            )
        })?;
Ok(Self {
count,
doc_id_bits,
ordinal_bits,
weight_quant,
first_doc_id,
max_weight,
})
}
}
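/// One compressed block of postings: delta-encoded, bit-packed doc ids (the
/// first doc id lives in the header), bit-packed ordinals, and quantized
/// weights. `last_doc_id` is precomputed so readers can skip whole blocks
/// without decoding them.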
#[derive(Debug, Clone)]
pub struct SparseBlock {
pub header: BlockHeader,
pub doc_ids_data: OwnedBytes,
pub ordinals_data: OwnedBytes,
pub weights_data: OwnedBytes,
last_doc_id: DocId,
}
impl SparseBlock {
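    /// Compresses up to `MAX_BLOCK_SIZE` postings, given as
    /// `(doc_id, ordinal, weight)` tuples sorted by doc id. Doc-id deltas and
    /// ordinals are bit-packed at rounded widths; weights are encoded per
    /// `weight_quant`.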
pub fn from_postings(
postings: &[(DocId, u16, f32)],
weight_quant: WeightQuantization,
) -> io::Result<Self> {
        assert!(
            !postings.is_empty() && postings.len() <= MAX_BLOCK_SIZE,
            "block must hold between 1 and MAX_BLOCK_SIZE postings"
        );
let count = postings.len();
let first_doc_id = postings[0].0;
let mut deltas = Vec::with_capacity(count);
let mut prev = first_doc_id;
for &(doc_id, _, _) in postings {
deltas.push(doc_id.saturating_sub(prev));
prev = doc_id;
}
        // The first posting carries no delta; its doc id is stored in the header.
        deltas[0] = 0;
let doc_id_bits = simd::round_bit_width(find_optimal_bit_width(&deltas[1..]));
let ordinals: Vec<u16> = postings.iter().map(|(_, o, _)| *o).collect();
let max_ordinal = ordinals.iter().copied().max().unwrap_or(0);
let ordinal_bits = if max_ordinal == 0 {
0
} else {
simd::round_bit_width(bits_needed_u16(max_ordinal))
};
let weights: Vec<f32> = postings.iter().map(|(_, _, w)| *w).collect();
let max_weight = weights
.iter()
.copied()
.fold(0.0f32, |acc, w| acc.max(w.abs()));
let doc_ids_data = OwnedBytes::new({
let rounded = simd::RoundedBitWidth::from_u8(doc_id_bits);
let num_deltas = count - 1;
let byte_count = num_deltas * rounded.bytes_per_value();
let mut data = vec![0u8; byte_count];
simd::pack_rounded(&deltas[1..], rounded, &mut data);
data
});
let ordinals_data = OwnedBytes::new(if ordinal_bits > 0 {
let rounded = simd::RoundedBitWidth::from_u8(ordinal_bits);
let byte_count = count * rounded.bytes_per_value();
let mut data = vec![0u8; byte_count];
let ord_u32: Vec<u32> = ordinals.iter().map(|&o| o as u32).collect();
simd::pack_rounded(&ord_u32, rounded, &mut data);
data
} else {
Vec::new()
});
let weights_data = OwnedBytes::new(encode_weights(&weights, weight_quant)?);
let last_doc_id = postings.last().unwrap().0;
Ok(Self {
header: BlockHeader {
count: count as u16,
doc_id_bits,
ordinal_bits,
weight_quant,
first_doc_id,
max_weight,
},
doc_ids_data,
ordinals_data,
weights_data,
last_doc_id,
})
}
#[inline]
pub fn last_doc_id(&self) -> DocId {
self.last_doc_id
}
pub fn decode_doc_ids(&self) -> Vec<DocId> {
let mut out = Vec::with_capacity(self.header.count as usize);
self.decode_doc_ids_into(&mut out);
out
}
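    /// Decodes doc ids into `out`: unpacks `count - 1` deltas, then prefix-sums
    /// them starting from `header.first_doc_id`. A zero bit width means every
    /// posting shares the first doc id (multi-valued entries for one document).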
pub fn decode_doc_ids_into(&self, out: &mut Vec<DocId>) {
let count = self.header.count as usize;
out.clear();
out.resize(count, 0);
out[0] = self.header.first_doc_id;
if count > 1 {
let bits = self.header.doc_id_bits;
if bits == 0 {
out[1..].fill(self.header.first_doc_id);
} else {
simd::unpack_rounded(
&self.doc_ids_data,
simd::RoundedBitWidth::from_u8(bits),
&mut out[1..],
count - 1,
);
for i in 1..count {
out[i] += out[i - 1];
}
}
}
}
pub fn decode_ordinals(&self) -> Vec<u16> {
let mut out = Vec::with_capacity(self.header.count as usize);
self.decode_ordinals_into(&mut out);
out
}
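    /// Decodes per-posting ordinals into `out`. A zero bit width means every
    /// ordinal is zero (single-valued field) and no payload was stored.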
pub fn decode_ordinals_into(&self, out: &mut Vec<u16>) {
let count = self.header.count as usize;
out.clear();
if self.header.ordinal_bits == 0 {
out.resize(count, 0u16);
} else {
            // Size the scratch buffer for the largest possible block, not just
            // the default BLOCK_SIZE: partitioned blocks may hold up to
            // MAX_BLOCK_SIZE postings.
            let mut temp = [0u32; MAX_BLOCK_SIZE];
simd::unpack_rounded(
&self.ordinals_data,
simd::RoundedBitWidth::from_u8(self.header.ordinal_bits),
&mut temp[..count],
count,
);
out.reserve(count);
for &v in &temp[..count] {
out.push(v as u16);
}
}
}
pub fn decode_weights(&self) -> Vec<f32> {
let mut out = Vec::with_capacity(self.header.count as usize);
self.decode_weights_into(&mut out);
out
}
pub fn decode_weights_into(&self, out: &mut Vec<f32>) {
out.clear();
decode_weights_into(
&self.weights_data,
self.header.weight_quant,
self.header.count as usize,
out,
);
}
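    /// Decodes weights pre-multiplied by `query_weight`. For UInt8 quantization
    /// the query weight is folded into the dequantization scale and bias, so
    /// scaling happens inside the single SIMD pass; other encodings decode
    /// first and scale afterwards.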
pub fn decode_scored_weights_into(&self, query_weight: f32, out: &mut Vec<f32>) {
out.clear();
let count = self.header.count as usize;
match self.header.weight_quant {
WeightQuantization::UInt8 if self.weights_data.len() >= 8 => {
let scale = f32::from_le_bytes([
self.weights_data[0],
self.weights_data[1],
self.weights_data[2],
self.weights_data[3],
]);
let min_val = f32::from_le_bytes([
self.weights_data[4],
self.weights_data[5],
self.weights_data[6],
self.weights_data[7],
]);
let eff_scale = query_weight * scale;
let eff_bias = query_weight * min_val;
out.resize(count, 0.0);
simd::dequantize_uint8(&self.weights_data[8..], out, eff_scale, eff_bias, count);
}
_ => {
decode_weights_into(&self.weights_data, self.header.weight_quant, count, out);
for w in out.iter_mut() {
*w *= query_weight;
}
}
}
}
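    /// Scatter-adds `query_weight * weight` into `flat_scores` at index
    /// `doc_id - base_doc`, pushing each newly touched doc id onto `dirty`.
    /// Returns the number of postings in the block. Callers are expected to
    /// pass `doc_ids` decoded from this block with `doc_ids[i] >= base_doc`;
    /// offsets past the end of `flat_scores` are skipped.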
#[inline]
pub fn accumulate_scored_weights(
&self,
query_weight: f32,
doc_ids: &[u32],
flat_scores: &mut [f32],
base_doc: u32,
dirty: &mut Vec<u32>,
) -> usize {
let count = self.header.count as usize;
match self.header.weight_quant {
WeightQuantization::UInt8 if self.weights_data.len() >= 8 => {
let scale = f32::from_le_bytes([
self.weights_data[0],
self.weights_data[1],
self.weights_data[2],
self.weights_data[3],
]);
let min_val = f32::from_le_bytes([
self.weights_data[4],
self.weights_data[5],
self.weights_data[6],
self.weights_data[7],
]);
let eff_scale = query_weight * scale;
let eff_bias = query_weight * min_val;
let quant_data = &self.weights_data[8..];
for i in 0..count.min(quant_data.len()).min(doc_ids.len()) {
let w = quant_data[i] as f32 * eff_scale + eff_bias;
let off = (doc_ids[i] - base_doc) as usize;
if off >= flat_scores.len() {
continue;
}
if flat_scores[off] == 0.0 {
dirty.push(doc_ids[i]);
}
flat_scores[off] += w;
}
count
}
_ => {
let mut weights_buf = Vec::with_capacity(count);
decode_weights_into(
&self.weights_data,
self.header.weight_quant,
count,
&mut weights_buf,
);
for i in 0..count.min(weights_buf.len()).min(doc_ids.len()) {
let w = weights_buf[i] * query_weight;
let off = (doc_ids[i] - base_doc) as usize;
if off >= flat_scores.len() {
continue;
}
if flat_scores[off] == 0.0 {
dirty.push(doc_ids[i]);
}
flat_scores[off] += w;
}
count
}
}
}
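    /// Serializes the block: 16-byte header, three little-endian `u16` section
    /// lengths plus a `u16` pad, then the doc-id, ordinal, and weight payloads.
    /// Fails if any section exceeds `u16::MAX` bytes.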
pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
self.header.write(w)?;
if self.doc_ids_data.len() > u16::MAX as usize
|| self.ordinals_data.len() > u16::MAX as usize
|| self.weights_data.len() > u16::MAX as usize
{
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"sparse sub-block too large for u16 length: doc_ids={}B ords={}B wts={}B",
self.doc_ids_data.len(),
self.ordinals_data.len(),
self.weights_data.len()
),
));
}
w.write_u16::<LittleEndian>(self.doc_ids_data.len() as u16)?;
w.write_u16::<LittleEndian>(self.ordinals_data.len() as u16)?;
w.write_u16::<LittleEndian>(self.weights_data.len() as u16)?;
        w.write_u16::<LittleEndian>(0)?; // reserved padding
w.write_all(&self.doc_ids_data)?;
w.write_all(&self.ordinals_data)?;
w.write_all(&self.weights_data)?;
Ok(())
}
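    /// Reads a block previously produced by [`Self::write`], recomputing
    /// `last_doc_id` from the packed deltas.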
pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
let header = BlockHeader::read(r)?;
let doc_ids_len = r.read_u16::<LittleEndian>()? as usize;
let ordinals_len = r.read_u16::<LittleEndian>()? as usize;
let weights_len = r.read_u16::<LittleEndian>()? as usize;
        let _ = r.read_u16::<LittleEndian>()?; // reserved padding
let mut doc_ids_vec = vec![0u8; doc_ids_len];
r.read_exact(&mut doc_ids_vec)?;
let mut ordinals_vec = vec![0u8; ordinals_len];
r.read_exact(&mut ordinals_vec)?;
let mut weights_vec = vec![0u8; weights_len];
r.read_exact(&mut weights_vec)?;
let last_doc_id = compute_last_doc(&header, &doc_ids_vec);
Ok(Self {
header,
doc_ids_data: OwnedBytes::new(doc_ids_vec),
ordinals_data: OwnedBytes::new(ordinals_vec),
weights_data: OwnedBytes::new(weights_vec),
last_doc_id,
})
}
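    /// Zero-copy variant of [`Self::read`]: validates section lengths against
    /// the buffer, then slices the payloads out of `data` without copying.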
    pub fn from_owned_bytes(data: OwnedBytes) -> crate::Result<Self> {
let b = data.as_slice();
if b.len() < BlockHeader::SIZE + 8 {
return Err(crate::Error::Corruption(
"sparse block too small".to_string(),
));
}
let mut cursor = Cursor::new(&b[..BlockHeader::SIZE]);
let header =
BlockHeader::read(&mut cursor).map_err(|e| crate::Error::Corruption(e.to_string()))?;
if header.count == 0 {
let hex: String = b
.iter()
.take(32)
.map(|x| format!("{x:02x}"))
.collect::<Vec<_>>()
.join(" ");
return Err(crate::Error::Corruption(format!(
"sparse block has count=0 (data_len={}, first_32_bytes=[{}])",
b.len(),
hex
)));
}
let p = BlockHeader::SIZE;
let doc_ids_len = u16::from_le_bytes([b[p], b[p + 1]]) as usize;
let ordinals_len = u16::from_le_bytes([b[p + 2], b[p + 3]]) as usize;
let weights_len = u16::from_le_bytes([b[p + 4], b[p + 5]]) as usize;
let data_start = p + 8;
let ord_start = data_start + doc_ids_len;
let wt_start = ord_start + ordinals_len;
let expected_end = wt_start + weights_len;
if expected_end > b.len() {
let hex: String = b
.iter()
.take(32)
.map(|x| format!("{x:02x}"))
.collect::<Vec<_>>()
.join(" ");
return Err(crate::Error::Corruption(format!(
"sparse block sub-block overflow: count={} doc_ids={}B ords={}B wts={}B need={}B have={}B (first_32=[{}])",
header.count,
doc_ids_len,
ordinals_len,
weights_len,
expected_end,
b.len(),
hex
)));
}
let doc_ids_slice = data.slice(data_start..ord_start);
let last_doc_id = compute_last_doc(&header, &doc_ids_slice);
Ok(Self {
header,
doc_ids_data: doc_ids_slice,
ordinals_data: data.slice(ord_start..wt_start),
weights_data: data.slice(wt_start..wt_start + weights_len),
last_doc_id,
})
}
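    /// Returns a copy of the block with every doc id shifted by `doc_offset`.
    /// Only `first_doc_id` and `last_doc_id` change: the packed deltas are
    /// offset-invariant, so the payload buffers are shared as-is.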
pub fn with_doc_offset(&self, doc_offset: u32) -> Self {
Self {
header: BlockHeader {
first_doc_id: self.header.first_doc_id + doc_offset,
..self.header
},
doc_ids_data: self.doc_ids_data.clone(),
ordinals_data: self.ordinals_data.clone(),
weights_data: self.weights_data.clone(),
last_doc_id: self.last_doc_id + doc_offset,
}
}
}
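/// A term's postings split into independently decodable [`SparseBlock`]s.
///
/// `doc_count` counts *unique* documents; multi-valued postings repeat a doc
/// id with distinct ordinals. A minimal usage sketch (illustrative only, not
/// compiled as a doctest):
///
/// ```ignore
/// let postings = vec![(0u32, 0u16, 1.0f32), (5, 0, 2.0)];
/// let list = BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32)?;
/// let mut it = list.iterator();
/// while it.doc() != TERMINATED {
///     let (doc, weight) = (it.doc(), it.weight());
///     it.advance();
/// }
/// ```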
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
pub doc_count: u32,
pub blocks: Vec<SparseBlock>,
}
impl BlockSparsePostingList {
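    /// Builds a posting list by chunking `postings` into blocks of
    /// `block_size` entries (clamped to at least 16).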
pub fn from_postings_with_block_size(
postings: &[(DocId, u16, f32)],
weight_quant: WeightQuantization,
block_size: usize,
) -> io::Result<Self> {
if postings.is_empty() {
return Ok(Self {
doc_count: 0,
blocks: Vec::new(),
});
}
        let block_size = block_size.max(16);
        let mut blocks = Vec::new();
for chunk in postings.chunks(block_size) {
blocks.push(SparseBlock::from_postings(chunk, weight_quant)?);
}
let mut unique_docs = 1u32;
for i in 1..postings.len() {
if postings[i].0 != postings[i - 1].0 {
unique_docs += 1;
}
}
Ok(Self {
doc_count: unique_docs,
blocks,
})
}
pub fn from_postings(
postings: &[(DocId, u16, f32)],
weight_quant: WeightQuantization,
) -> io::Result<Self> {
Self::from_postings_with_block_size(postings, weight_quant, BLOCK_SIZE)
}
pub fn from_postings_with_partition(
postings: &[(DocId, u16, f32)],
weight_quant: WeightQuantization,
partition: &[usize],
) -> io::Result<Self> {
if postings.is_empty() {
return Ok(Self {
doc_count: 0,
blocks: Vec::new(),
});
}
let mut blocks = Vec::with_capacity(partition.len());
let mut offset = 0;
for &block_size in partition {
let end = (offset + block_size).min(postings.len());
blocks.push(SparseBlock::from_postings(
&postings[offset..end],
weight_quant,
)?);
offset = end;
        }
        debug_assert_eq!(offset, postings.len(), "partition must cover all postings");
let mut unique_docs = 1u32;
for i in 1..postings.len() {
if postings[i].0 != postings[i - 1].0 {
unique_docs += 1;
}
}
Ok(Self {
doc_count: unique_docs,
blocks,
})
}
pub fn doc_count(&self) -> u32 {
self.doc_count
}
pub fn num_blocks(&self) -> usize {
self.blocks.len()
}
pub fn global_max_weight(&self) -> f32 {
self.blocks
.iter()
.map(|b| b.header.max_weight)
.fold(0.0f32, f32::max)
}
pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
self.blocks.get(block_idx).map(|b| b.header.max_weight)
}
pub fn size_bytes(&self) -> usize {
use std::mem::size_of;
        let header_size = size_of::<u32>() * 2;
        let blocks_size: usize = self
.blocks
.iter()
.map(|b| {
size_of::<BlockHeader>()
+ b.doc_ids_data.len()
+ b.ordinals_data.len()
+ b.weights_data.len()
})
.sum();
header_size + blocks_size
}
pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
BlockSparsePostingIterator::new(self)
}
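    /// Serializes all blocks back to back, returning the raw bytes plus one
    /// skip entry per block (first/last doc id, byte offset and length, and
    /// the block's max weight).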
pub fn serialize(&self) -> io::Result<(Vec<u8>, Vec<super::SparseSkipEntry>)> {
let mut block_data = Vec::new();
let mut skip_entries = Vec::with_capacity(self.blocks.len());
let mut offset = 0u64;
for block in &self.blocks {
let mut buf = Vec::new();
block.write(&mut buf)?;
let length = buf.len() as u32;
let first_doc = block.header.first_doc_id;
let last_doc = block.last_doc_id;
skip_entries.push(super::SparseSkipEntry::new(
first_doc,
last_doc,
offset,
length,
block.header.max_weight,
));
block_data.extend_from_slice(&buf);
offset += length as u64;
}
Ok((block_data, skip_entries))
}
#[cfg(test)]
pub fn from_parts(
doc_count: u32,
block_data: &[u8],
skip_entries: &[super::SparseSkipEntry],
) -> io::Result<Self> {
let mut blocks = Vec::with_capacity(skip_entries.len());
for entry in skip_entries {
let start = entry.offset as usize;
let end = start + entry.length as usize;
            blocks.push(SparseBlock::read(&mut Cursor::new(
                &block_data[start..end],
            ))?);
}
Ok(Self { doc_count, blocks })
}
pub fn decode_all(&self) -> Vec<(DocId, u16, f32)> {
let total_postings: usize = self.blocks.iter().map(|b| b.header.count as usize).sum();
let mut result = Vec::with_capacity(total_postings);
for block in &self.blocks {
let doc_ids = block.decode_doc_ids();
let ordinals = block.decode_ordinals();
let weights = block.decode_weights();
for i in 0..block.header.count as usize {
result.push((doc_ids[i], ordinals[i], weights[i]));
}
}
result
}
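    /// Concatenates segment posting lists, shifting each list's doc ids by its
    /// offset. Blocks are reused via [`SparseBlock::with_doc_offset`] rather
    /// than re-encoded; callers are expected to pass segments whose shifted
    /// doc-id ranges are ascending and non-overlapping.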
pub fn merge_with_offsets(lists: &[(&BlockSparsePostingList, u32)]) -> Self {
if lists.is_empty() {
return Self {
doc_count: 0,
blocks: Vec::new(),
};
}
let total_blocks: usize = lists.iter().map(|(pl, _)| pl.blocks.len()).sum();
let total_docs: u32 = lists.iter().map(|(pl, _)| pl.doc_count).sum();
let mut merged_blocks = Vec::with_capacity(total_blocks);
for (posting_list, doc_offset) in lists {
for block in &posting_list.blocks {
merged_blocks.push(block.with_doc_offset(*doc_offset));
}
}
Self {
doc_count: total_docs,
blocks: merged_blocks,
}
}
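    /// Binary-searches for the block that may contain `target`: the last block
    /// whose `first_doc_id <= target`, or block 0 if `target` precedes all
    /// blocks.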
fn find_block(&self, target: DocId) -> Option<usize> {
if self.blocks.is_empty() {
return None;
}
let idx = self
.blocks
.partition_point(|b| b.header.first_doc_id <= target);
if idx == 0 {
Some(0)
} else {
Some(idx - 1)
}
}
}
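/// Cursor over a [`BlockSparsePostingList`] that decodes one block at a time.
/// Doc ids and weights are decoded eagerly when a block is loaded; ordinals
/// are decoded lazily on first access.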
pub struct BlockSparsePostingIterator<'a> {
posting_list: &'a BlockSparsePostingList,
block_idx: usize,
in_block_idx: usize,
current_doc_ids: Vec<DocId>,
current_ordinals: Vec<u16>,
current_weights: Vec<f32>,
ordinals_decoded: bool,
exhausted: bool,
}
impl<'a> BlockSparsePostingIterator<'a> {
fn new(posting_list: &'a BlockSparsePostingList) -> Self {
let mut iter = Self {
posting_list,
block_idx: 0,
in_block_idx: 0,
            current_doc_ids: Vec::with_capacity(BLOCK_SIZE),
            current_ordinals: Vec::with_capacity(BLOCK_SIZE),
            current_weights: Vec::with_capacity(BLOCK_SIZE),
ordinals_decoded: false,
exhausted: posting_list.blocks.is_empty(),
};
if !iter.exhausted {
iter.load_block(0);
}
iter
}
fn load_block(&mut self, block_idx: usize) {
if let Some(block) = self.posting_list.blocks.get(block_idx) {
block.decode_doc_ids_into(&mut self.current_doc_ids);
block.decode_weights_into(&mut self.current_weights);
self.ordinals_decoded = false;
self.block_idx = block_idx;
self.in_block_idx = 0;
}
}
#[inline]
fn ensure_ordinals_decoded(&mut self) {
if !self.ordinals_decoded {
if let Some(block) = self.posting_list.blocks.get(self.block_idx) {
block.decode_ordinals_into(&mut self.current_ordinals);
}
self.ordinals_decoded = true;
}
}
#[inline]
pub fn doc(&self) -> DocId {
if self.exhausted {
TERMINATED
} else {
self.current_doc_ids[self.in_block_idx]
}
}
#[inline]
pub fn weight(&self) -> f32 {
if self.exhausted {
return 0.0;
}
self.current_weights[self.in_block_idx]
}
#[inline]
pub fn ordinal(&mut self) -> u16 {
if self.exhausted {
return 0;
}
self.ensure_ordinals_decoded();
self.current_ordinals[self.in_block_idx]
}
pub fn advance(&mut self) -> DocId {
if self.exhausted {
return TERMINATED;
}
self.in_block_idx += 1;
if self.in_block_idx >= self.current_doc_ids.len() {
self.block_idx += 1;
if self.block_idx >= self.posting_list.blocks.len() {
self.exhausted = true;
} else {
self.load_block(self.block_idx);
}
}
self.doc()
}
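    /// Advances to the first doc id `>= target` (a no-op if already there).
    /// Stays inside the current block when its last doc id covers `target`,
    /// otherwise binary-searches the block list and re-decodes. Seeking
    /// backwards does not move the cursor.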
pub fn seek(&mut self, target: DocId) -> DocId {
if self.exhausted {
return TERMINATED;
}
if self.doc() >= target {
return self.doc();
}
if let Some(&last_doc) = self.current_doc_ids.last()
&& last_doc >= target
{
let remaining = &self.current_doc_ids[self.in_block_idx..];
            let pos = simd::find_first_ge_u32(remaining, target);
self.in_block_idx += pos;
if self.in_block_idx >= self.current_doc_ids.len() {
self.block_idx += 1;
if self.block_idx >= self.posting_list.blocks.len() {
self.exhausted = true;
} else {
self.load_block(self.block_idx);
}
}
return self.doc();
}
if let Some(block_idx) = self.posting_list.find_block(target) {
self.load_block(block_idx);
            let pos = simd::find_first_ge_u32(&self.current_doc_ids, target);
self.in_block_idx = pos;
if self.in_block_idx >= self.current_doc_ids.len() {
self.block_idx += 1;
if self.block_idx >= self.posting_list.blocks.len() {
self.exhausted = true;
} else {
self.load_block(self.block_idx);
}
}
} else {
self.exhausted = true;
}
self.doc()
}
pub fn skip_to_next_block(&mut self) -> DocId {
if self.exhausted {
return TERMINATED;
}
let next = self.block_idx + 1;
if next >= self.posting_list.blocks.len() {
self.exhausted = true;
return TERMINATED;
}
self.load_block(next);
self.doc()
}
pub fn is_exhausted(&self) -> bool {
self.exhausted
}
pub fn current_block_max_weight(&self) -> f32 {
self.posting_list
.blocks
.get(self.block_idx)
.map(|b| b.header.max_weight)
.unwrap_or(0.0)
}
pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
query_weight * self.current_block_max_weight()
}
}
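/// Recovers a block's last doc id by unpacking its deltas and summing them
/// onto `first_doc_id`, without materializing the full doc-id list.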
fn compute_last_doc(header: &BlockHeader, doc_ids_data: &[u8]) -> DocId {
let count = header.count as usize;
if count <= 1 {
return header.first_doc_id;
}
let bits = header.doc_id_bits;
    if bits == 0 {
        // Zero-width deltas: every posting shares the first doc id.
        return header.first_doc_id;
    }
let rounded = simd::RoundedBitWidth::from_u8(bits);
let num_deltas = count - 1;
let mut deltas = [0u32; MAX_BLOCK_SIZE];
simd::unpack_rounded(doc_ids_data, rounded, &mut deltas[..num_deltas], num_deltas);
let sum: u32 = deltas[..num_deltas].iter().sum();
header.first_doc_id + sum
}
fn find_optimal_bit_width(values: &[u32]) -> u8 {
if values.is_empty() {
return 0;
}
let max_val = values.iter().copied().max().unwrap_or(0);
simd::bits_needed(max_val)
}
fn bits_needed_u16(val: u16) -> u8 {
if val == 0 {
0
} else {
16 - val.leading_zeros() as u8
}
}
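/// Encodes weights under the chosen quantization:
/// - `Float32`: raw little-endian `f32`s.
/// - `Float16`: IEEE 754 half-precision bit patterns.
/// - `UInt8`: a `scale: f32 | min: f32` prefix, then one byte per weight with
///   `w ≈ q * scale + min`.
/// - `UInt4`: the same prefix, then two 4-bit values per byte, low nibble
///   first.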
fn encode_weights(weights: &[f32], quant: WeightQuantization) -> io::Result<Vec<u8>> {
let mut data = Vec::new();
match quant {
WeightQuantization::Float32 => {
for &w in weights {
data.write_f32::<LittleEndian>(w)?;
}
}
WeightQuantization::Float16 => {
use half::f16;
for &w in weights {
data.write_u16::<LittleEndian>(f16::from_f32(w).to_bits())?;
}
}
WeightQuantization::UInt8 => {
let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let range = max - min;
let scale = if range < f32::EPSILON {
1.0
} else {
range / 255.0
};
data.write_f32::<LittleEndian>(scale)?;
data.write_f32::<LittleEndian>(min)?;
for &w in weights {
data.write_u8(((w - min) / scale).round() as u8)?;
}
}
WeightQuantization::UInt4 => {
let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let range = max - min;
let scale = if range < f32::EPSILON {
1.0
} else {
range / 15.0
};
data.write_f32::<LittleEndian>(scale)?;
data.write_f32::<LittleEndian>(min)?;
let mut i = 0;
while i < weights.len() {
let q1 = ((weights[i] - min) / scale).round() as u8 & 0x0F;
let q2 = if i + 1 < weights.len() {
((weights[i + 1] - min) / scale).round() as u8 & 0x0F
} else {
0
};
data.write_u8((q2 << 4) | q1)?;
i += 2;
}
}
}
Ok(data)
}
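/// Inverse of [`encode_weights`]: decodes `count` weights from `data` into
/// `out`. Callers pass a cleared buffer; the UInt8 path resizes `out` to
/// exactly `count` and dequantizes in one SIMD call.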
fn decode_weights_into(data: &[u8], quant: WeightQuantization, count: usize, out: &mut Vec<f32>) {
match quant {
WeightQuantization::Float32 => {
out.reserve(count);
for chunk in data[..count * 4].chunks_exact(4) {
out.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
}
}
WeightQuantization::Float16 => {
use half::f16;
use half::slice::HalfFloatSliceExt;
let byte_count = count * 2;
let src = &data[..byte_count];
let mut f16_buf: Vec<f16> = Vec::with_capacity(count);
for chunk in src.chunks_exact(2) {
f16_buf.push(f16::from_bits(u16::from_le_bytes([chunk[0], chunk[1]])));
}
let start = out.len();
out.resize(start + count, 0.0);
f16_buf.convert_to_f32_slice(&mut out[start..start + count]);
}
WeightQuantization::UInt8 => {
let mut cursor = Cursor::new(data);
let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
let min_val = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
let offset = cursor.position() as usize;
out.resize(count, 0.0);
simd::dequantize_uint8(&data[offset..], out, scale, min_val, count);
}
WeightQuantization::UInt4 => {
let mut cursor = Cursor::new(data);
let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
let mut i = 0;
while i < count {
let byte = cursor.read_u8().unwrap_or(0);
out.push((byte & 0x0F) as f32 * scale + min);
i += 1;
if i < count {
out.push((byte >> 4) as f32 * scale + min);
i += 1;
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_block_roundtrip() {
let postings = vec![
(10u32, 0u16, 1.5f32),
(15, 0, 2.0),
(20, 1, 0.5),
(100, 0, 3.0),
];
let block = SparseBlock::from_postings(&postings, WeightQuantization::Float32).unwrap();
assert_eq!(block.decode_doc_ids(), vec![10, 15, 20, 100]);
assert_eq!(block.decode_ordinals(), vec![0, 0, 1, 0]);
let weights = block.decode_weights();
assert!((weights[0] - 1.5).abs() < 0.01);
}
#[test]
fn test_posting_list() {
let postings: Vec<(DocId, u16, f32)> =
(0..300).map(|i| (i * 2, 0, i as f32 * 0.1)).collect();
let list =
BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
assert_eq!(list.doc_count(), 300);
assert_eq!(list.num_blocks(), 3);
let mut iter = list.iterator();
assert_eq!(iter.doc(), 0);
iter.advance();
assert_eq!(iter.doc(), 2);
}
#[test]
fn test_serialization() {
let postings = vec![(1u32, 0u16, 0.5f32), (10, 1, 1.5), (100, 0, 2.5)];
let list =
BlockSparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();
let (block_data, skip_entries) = list.serialize().unwrap();
let list2 =
BlockSparsePostingList::from_parts(list.doc_count(), &block_data, &skip_entries)
.unwrap();
assert_eq!(list.doc_count(), list2.doc_count());
}
#[test]
fn test_seek() {
let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();
let list =
BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
let mut iter = list.iterator();
assert_eq!(iter.seek(300), 300);
assert_eq!(iter.seek(301), 303);
assert_eq!(iter.seek(2000), TERMINATED);
}
#[test]
fn test_merge_with_offsets() {
let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
let list1 =
BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
let list2 =
BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);
assert_eq!(merged.doc_count(), 6);
let decoded = merged.decode_all();
assert_eq!(decoded.len(), 6);
assert_eq!(decoded[0].0, 0);
assert_eq!(decoded[1].0, 5);
assert_eq!(decoded[2].0, 10);
        assert_eq!(decoded[3].0, 100);
        assert_eq!(decoded[4].0, 103);
        assert_eq!(decoded[5].0, 107);
assert!((decoded[0].2 - 1.0).abs() < 0.01);
assert!((decoded[3].2 - 4.0).abs() < 0.01);
        assert_eq!(decoded[2].1, 1);
        assert_eq!(decoded[4].1, 1);
    }
#[test]
fn test_merge_with_offsets_multi_block() {
let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
let list1 =
BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
assert!(list1.num_blocks() > 1, "Should have multiple blocks");
let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
let list2 =
BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);
assert_eq!(merged.doc_count(), 350);
assert_eq!(merged.num_blocks(), list1.num_blocks() + list2.num_blocks());
let mut iter = merged.iterator();
assert_eq!(iter.doc(), 0);
let doc = iter.seek(1000);
assert_eq!(doc, 1000);
iter.advance();
        assert_eq!(iter.doc(), 1003);
    }
#[test]
fn test_merge_with_offsets_serialize_roundtrip() {
let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
let list1 =
BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
let list2 =
BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);
let (block_data, skip_entries) = merged.serialize().unwrap();
let loaded =
BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
.unwrap();
let decoded = loaded.decode_all();
assert_eq!(decoded.len(), 6);
assert_eq!(decoded[0].0, 0);
assert_eq!(decoded[1].0, 5);
assert_eq!(decoded[2].0, 10);
assert_eq!(decoded[3].0, 100, "First doc of seg2 should be 0+100=100");
assert_eq!(decoded[4].0, 103, "Second doc of seg2 should be 3+100=103");
assert_eq!(decoded[5].0, 107, "Third doc of seg2 should be 7+100=107");
let mut iter = loaded.iterator();
assert_eq!(iter.doc(), 0);
iter.advance();
assert_eq!(iter.doc(), 5);
iter.advance();
assert_eq!(iter.doc(), 10);
iter.advance();
assert_eq!(iter.doc(), 100);
iter.advance();
assert_eq!(iter.doc(), 103);
iter.advance();
assert_eq!(iter.doc(), 107);
}
#[test]
fn test_merge_seek_after_roundtrip() {
let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, 1.0)).collect();
let list1 =
BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 0, 2.0)).collect();
let list2 =
BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);
let (block_data, skip_entries) = merged.serialize().unwrap();
let loaded =
BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
.unwrap();
let mut iter = loaded.iterator();
let doc = iter.seek(100);
assert_eq!(doc, 100, "Seek to 100 in segment 1");
let doc = iter.seek(1000);
assert_eq!(doc, 1000, "Seek to 1000 (first doc of segment 2)");
let doc = iter.seek(1050);
assert!(
doc >= 1050,
"Seek to 1050 should find doc >= 1050, got {}",
doc
);
let doc = iter.seek(500);
assert!(
doc >= 1050,
"Seek backwards should not go back, got {}",
doc
);
let mut iter2 = loaded.iterator();
let mut count = 0;
let mut prev_doc = 0;
while iter2.doc() != super::TERMINATED {
let current = iter2.doc();
if count > 0 {
assert!(
current > prev_doc,
"Docs should be monotonically increasing: {} vs {}",
prev_doc,
current
);
}
prev_doc = current;
iter2.advance();
count += 1;
}
assert_eq!(count, 350, "Should have 350 total docs");
}
#[test]
fn test_doc_count_multi_value() {
let postings: Vec<(DocId, u16, f32)> = vec![
(0, 0, 1.0),
(0, 1, 1.5),
(0, 2, 2.0),
(5, 0, 3.0),
(5, 1, 3.5),
(10, 0, 4.0),
];
let list =
BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
assert_eq!(list.doc_count(), 3);
let decoded = list.decode_all();
assert_eq!(decoded.len(), 6);
}
#[test]
fn test_zero_copy_merge_patches_first_doc_id() {
use crate::structures::SparseSkipEntry;
let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
let list1 =
BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
assert!(list1.num_blocks() > 1);
let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
let list2 =
BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
let (raw1, skip1) = list1.serialize().unwrap();
let (raw2, skip2) = list2.serialize().unwrap();
        let doc_offset: u32 = 1000;
        let total_docs = list1.doc_count() + list2.doc_count();
let mut merged_skip = Vec::new();
let mut cumulative_offset = 0u64;
for entry in &skip1 {
merged_skip.push(SparseSkipEntry::new(
entry.first_doc,
entry.last_doc,
cumulative_offset + entry.offset,
entry.length,
entry.max_weight,
));
}
if let Some(last) = skip1.last() {
cumulative_offset += last.offset + last.length as u64;
}
for entry in &skip2 {
merged_skip.push(SparseSkipEntry::new(
entry.first_doc + doc_offset,
entry.last_doc + doc_offset,
cumulative_offset + entry.offset,
entry.length,
entry.max_weight,
));
}
let mut merged_block_data = Vec::new();
merged_block_data.extend_from_slice(&raw1);
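        // `first_doc_id` sits at byte offset 8 within each 16-byte block
        // header (after count, the two bit widths, the quant tag, and three
        // reserved bytes).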
const FIRST_DOC_ID_OFFSET: usize = 8;
let mut buf2 = raw2.to_vec();
for entry in &skip2 {
let off = entry.offset as usize + FIRST_DOC_ID_OFFSET;
if off + 4 <= buf2.len() {
let old = u32::from_le_bytes(buf2[off..off + 4].try_into().unwrap());
let patched = (old + doc_offset).to_le_bytes();
buf2[off..off + 4].copy_from_slice(&patched);
}
}
merged_block_data.extend_from_slice(&buf2);
let loaded =
BlockSparsePostingList::from_parts(total_docs, &merged_block_data, &merged_skip)
.unwrap();
assert_eq!(loaded.doc_count(), 350);
let mut iter = loaded.iterator();
assert_eq!(iter.doc(), 0);
let doc = iter.seek(100);
assert_eq!(doc, 100);
let doc = iter.seek(398);
assert_eq!(doc, 398);
let doc = iter.seek(1000);
assert_eq!(doc, 1000, "First doc of segment 2 should be 1000");
iter.advance();
assert_eq!(iter.doc(), 1003, "Second doc of segment 2 should be 1003");
let doc = iter.seek(1447);
assert_eq!(doc, 1447, "Last doc of segment 2 should be 1447");
iter.advance();
assert_eq!(iter.doc(), super::TERMINATED);
let reference =
BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, doc_offset)]);
let mut ref_iter = reference.iterator();
let mut zc_iter = loaded.iterator();
while ref_iter.doc() != super::TERMINATED {
assert_eq!(
ref_iter.doc(),
zc_iter.doc(),
"Zero-copy and reference merge should produce identical doc_ids"
);
assert!(
(ref_iter.weight() - zc_iter.weight()).abs() < 0.01,
"Weights should match: {} vs {}",
ref_iter.weight(),
zc_iter.weight()
);
ref_iter.advance();
zc_iter.advance();
}
assert_eq!(zc_iter.doc(), super::TERMINATED);
}
#[test]
fn test_doc_count_single_value() {
let postings: Vec<(DocId, u16, f32)> =
vec![(0, 0, 1.0), (5, 0, 2.0), (10, 0, 3.0), (15, 0, 4.0)];
let list =
BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
assert_eq!(list.doc_count(), 4);
}
#[test]
fn test_doc_count_multi_value_serialization_roundtrip() {
let postings: Vec<(DocId, u16, f32)> =
vec![(0, 0, 1.0), (0, 1, 1.5), (5, 0, 2.0), (5, 1, 2.5)];
let list =
BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
assert_eq!(list.doc_count(), 2);
let (block_data, skip_entries) = list.serialize().unwrap();
let loaded =
BlockSparsePostingList::from_parts(list.doc_count(), &block_data, &skip_entries)
.unwrap();
assert_eq!(loaded.doc_count(), 2);
}
#[test]
fn test_merge_preserves_weights_and_ordinals() {
let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.5), (5, 1, 2.5), (10, 2, 3.5)];
let list1 =
BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.5), (3, 1, 5.5), (7, 3, 6.5)];
let list2 =
BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);
let (block_data, skip_entries) = merged.serialize().unwrap();
let loaded =
BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
.unwrap();
let mut iter = loaded.iterator();
assert_eq!(iter.doc(), 0);
assert!(
(iter.weight() - 1.5).abs() < 0.01,
"Weight should be 1.5, got {}",
iter.weight()
);
assert_eq!(iter.ordinal(), 0);
iter.advance();
assert_eq!(iter.doc(), 5);
assert!(
(iter.weight() - 2.5).abs() < 0.01,
"Weight should be 2.5, got {}",
iter.weight()
);
assert_eq!(iter.ordinal(), 1);
iter.advance();
assert_eq!(iter.doc(), 10);
assert!(
(iter.weight() - 3.5).abs() < 0.01,
"Weight should be 3.5, got {}",
iter.weight()
);
assert_eq!(iter.ordinal(), 2);
iter.advance();
assert_eq!(iter.doc(), 100);
assert!(
(iter.weight() - 4.5).abs() < 0.01,
"Weight should be 4.5, got {}",
iter.weight()
);
assert_eq!(iter.ordinal(), 0);
iter.advance();
assert_eq!(iter.doc(), 103);
assert!(
(iter.weight() - 5.5).abs() < 0.01,
"Weight should be 5.5, got {}",
iter.weight()
);
assert_eq!(iter.ordinal(), 1);
iter.advance();
assert_eq!(iter.doc(), 107);
assert!(
(iter.weight() - 6.5).abs() < 0.01,
"Weight should be 6.5, got {}",
iter.weight()
);
assert_eq!(iter.ordinal(), 3);
iter.advance();
assert_eq!(iter.doc(), super::TERMINATED);
}
#[test]
fn test_merge_global_max_weight() {
let postings1: Vec<(DocId, u16, f32)> = vec![
(0, 0, 3.0),
            (1, 0, 7.0),
            (2, 0, 2.0),
];
let list1 =
BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
let postings2: Vec<(DocId, u16, f32)> = vec![
(0, 0, 5.0),
(1, 0, 4.0),
            (2, 0, 6.0),
        ];
let list2 =
BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
assert!((list1.global_max_weight() - 7.0).abs() < 0.01);
assert!((list2.global_max_weight() - 6.0).abs() < 0.01);
let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);
assert!(
(merged.global_max_weight() - 7.0).abs() < 0.01,
"Global max should be 7.0, got {}",
merged.global_max_weight()
);
let (block_data, skip_entries) = merged.serialize().unwrap();
let loaded =
BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
.unwrap();
assert!(
(loaded.global_max_weight() - 7.0).abs() < 0.01,
"After roundtrip, global max should still be 7.0, got {}",
loaded.global_max_weight()
);
}
#[test]
fn test_scoring_simulation_after_merge() {
let postings1: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 0.5),
            (5, 0, 0.8),
        ];
let list1 =
BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
let postings2: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 0.6),
            (3, 0, 0.9),
        ];
let list2 =
BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);
let (block_data, skip_entries) = merged.serialize().unwrap();
let loaded =
BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
.unwrap();
let query_weight = 2.0f32;
let mut iter = loaded.iterator();
assert_eq!(iter.doc(), 0);
let score = query_weight * iter.weight();
assert!(
(score - 1.0).abs() < 0.01,
"Doc 0 score should be 1.0, got {}",
score
);
iter.advance();
assert_eq!(iter.doc(), 5);
let score = query_weight * iter.weight();
assert!(
(score - 1.6).abs() < 0.01,
"Doc 5 score should be 1.6, got {}",
score
);
iter.advance();
assert_eq!(iter.doc(), 100);
let score = query_weight * iter.weight();
assert!(
(score - 1.2).abs() < 0.01,
"Doc 100 score should be 1.2, got {}",
score
);
iter.advance();
assert_eq!(iter.doc(), 103);
let score = query_weight * iter.weight();
assert!(
(score - 1.8).abs() < 0.01,
"Doc 103 score should be 1.8, got {}",
score
);
}
}