mod block;
mod config;
mod partitioner;
pub use block::{BlockSparsePostingIterator, BlockSparsePostingList, SparseBlock};
pub use config::{
IndexSize, QueryWeighting, SparseEntry, SparseFormat, SparseQueryConfig, SparseVector,
SparseVectorConfig, WeightQuantization,
};
pub use partitioner::optimal_partition;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Read, Write};
use super::posting_common::{read_vint, write_vint};
use crate::DocId;
/// A single posting: a document id paired with the weight stored for it.
#[derive(Debug, Clone, Copy)]
pub struct SparsePosting {
/// Identifier of the document this posting belongs to.
pub doc_id: DocId,
/// Weight stored for `doc_id`.
pub weight: f32,
}
/// Block size constant for sparse posting lists (128).
///
/// NOTE(review): this constant is not referenced anywhere in this file;
/// presumably it is the number of postings per block used by the `block`
/// submodule — confirm there.
pub const SPARSE_BLOCK_SIZE: usize = 128;
/// Skip-list entry describing one block of postings.
///
/// Serialized as 24 little-endian bytes in field order (see the `impl` for the
/// readers and writers).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseSkipEntry {
/// First document id covered by the block.
pub first_doc: DocId,
/// Last document id covered by the block; `SparseSkipList::find_block`
/// searches on this field.
pub last_doc: DocId,
/// Offset of the block.
/// NOTE(review): presumably a byte offset into the block data — confirm
/// against the writer in the `block` submodule.
pub offset: u64,
/// Length of the block.
/// NOTE(review): presumably in bytes — confirm against the writer.
pub length: u32,
/// Maximum weight within the block; multiplied by the query weight in
/// `block_max_contribution`.
pub max_weight: f32,
}
impl SparseSkipEntry {
    /// Serialized size of one entry in bytes: two `u32` doc ids, a `u64`
    /// offset, a `u32` length and an `f32` max weight (4 + 4 + 8 + 4 + 4).
    pub const SIZE: usize = 24;

    /// Creates an entry from its five fields.
    pub fn new(
        first_doc: DocId,
        last_doc: DocId,
        offset: u64,
        length: u32,
        max_weight: f32,
    ) -> Self {
        Self {
            first_doc,
            last_doc,
            offset,
            length,
            max_weight,
        }
    }

    /// Returns `query_weight` multiplied by this block's `max_weight`.
    #[inline]
    pub fn block_max_contribution(&self, query_weight: f32) -> f32 {
        let block_max = self.max_weight;
        query_weight * block_max
    }

    /// Decodes an entry from the first [`Self::SIZE`] bytes of `b`
    /// (little-endian, field order).
    ///
    /// # Panics
    ///
    /// Panics if `b` is shorter than [`Self::SIZE`] bytes.
    #[inline]
    pub fn from_bytes(b: &[u8]) -> Self {
        let first_doc = u32::from_le_bytes(b[0..4].try_into().unwrap());
        let last_doc = u32::from_le_bytes(b[4..8].try_into().unwrap());
        let offset = u64::from_le_bytes(b[8..16].try_into().unwrap());
        let length = u32::from_le_bytes(b[16..20].try_into().unwrap());
        let max_weight = f32::from_le_bytes(b[20..24].try_into().unwrap());
        Self {
            first_doc,
            last_doc,
            offset,
            length,
            max_weight,
        }
    }

    /// Appends the [`Self::SIZE`]-byte little-endian encoding of this entry to `buf`.
    #[inline]
    pub fn write_to_vec(&self, buf: &mut Vec<u8>) {
        let Self {
            first_doc,
            last_doc,
            offset,
            length,
            max_weight,
        } = *self;
        buf.extend_from_slice(&first_doc.to_le_bytes());
        buf.extend_from_slice(&last_doc.to_le_bytes());
        buf.extend_from_slice(&offset.to_le_bytes());
        buf.extend_from_slice(&length.to_le_bytes());
        buf.extend_from_slice(&max_weight.to_le_bytes());
    }

    /// Decodes the `idx`-th entry from a buffer of back-to-back serialized entries.
    ///
    /// # Panics
    ///
    /// Panics if `skip_bytes` does not contain `idx + 1` complete entries.
    #[inline]
    pub fn read_at(skip_bytes: &[u8], idx: usize) -> Self {
        let start = idx * Self::SIZE;
        let end = start + Self::SIZE;
        Self::from_bytes(&skip_bytes[start..end])
    }

    /// Writes the entry to `writer`, field by field, in little-endian order.
    ///
    /// # Errors
    ///
    /// Returns the first I/O error reported by `writer`.
    pub fn write<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
        let Self {
            first_doc,
            last_doc,
            offset,
            length,
            max_weight,
        } = *self;
        writer.write_u32::<LittleEndian>(first_doc)?;
        writer.write_u32::<LittleEndian>(last_doc)?;
        writer.write_u64::<LittleEndian>(offset)?;
        writer.write_u32::<LittleEndian>(length)?;
        writer.write_f32::<LittleEndian>(max_weight)
    }

    /// Reads one entry from `reader`; the inverse of [`Self::write`].
    ///
    /// # Errors
    ///
    /// Returns the first I/O error reported by `reader`.
    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        // Field initializers run in the order written, which matches the wire order.
        Ok(Self {
            first_doc: reader.read_u32::<LittleEndian>()?,
            last_doc: reader.read_u32::<LittleEndian>()?,
            offset: reader.read_u64::<LittleEndian>()?,
            length: reader.read_u32::<LittleEndian>()?,
            max_weight: reader.read_f32::<LittleEndian>()?,
        })
    }
}
/// In-memory list of [`SparseSkipEntry`] values plus the largest `max_weight`
/// pushed so far.
#[derive(Debug, Clone, Default)]
pub struct SparseSkipList {
/// Entries in the order they were pushed or read.
entries: Vec<SparseSkipEntry>,
/// Running maximum of the `max_weight` values passed to `push`; starts at
/// `0.0` (the `Default` value).
global_max_weight: f32,
}
impl SparseSkipList {
    /// Upper bound on the number of entries pre-allocated by [`Self::read`].
    ///
    /// The entry count comes straight from the input stream, so it must not be
    /// trusted as an allocation size; the vector still grows past this bound
    /// when the stream really contains more entries.
    const MAX_PREALLOC_ENTRIES: usize = 4096;

    /// Creates an empty skip list with a global max weight of `0.0`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends an entry and folds `max_weight` into the global maximum.
    ///
    /// The global maximum starts at `0.0`, so it never drops below zero even
    /// when every pushed weight is negative.
    pub fn push(
        &mut self,
        first_doc: DocId,
        last_doc: DocId,
        offset: u64,
        length: u32,
        max_weight: f32,
    ) {
        self.global_max_weight = self.global_max_weight.max(max_weight);
        self.entries.push(SparseSkipEntry::new(
            first_doc, last_doc, offset, length, max_weight,
        ));
    }

    /// Number of entries in the list.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when the list holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Returns the entry at `index`, or `None` when out of range.
    pub fn get(&self, index: usize) -> Option<&SparseSkipEntry> {
        self.entries.get(index)
    }

    /// Largest `max_weight` seen by [`Self::push`] (or the value loaded by [`Self::read`]).
    pub fn global_max_weight(&self) -> f32 {
        self.global_max_weight
    }

    /// Returns the index of the first entry whose `last_doc >= target`, or
    /// `None` when every entry ends before `target` (including the empty list).
    ///
    /// Uses a binary search, so entries must be ordered by `last_doc`.
    pub fn find_block(&self, target: DocId) -> Option<usize> {
        // `partition_point` yields 0 on an empty slice, so no separate empty check is needed.
        let idx = self.entries.partition_point(|e| e.last_doc < target);
        (idx < self.entries.len()).then_some(idx)
    }

    /// Iterates over the entries in order.
    pub fn iter(&self) -> impl Iterator<Item = &SparseSkipEntry> {
        self.entries.iter()
    }

    /// Serializes the list as `u32` entry count, `f32` global max weight, then
    /// every entry, all little-endian.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when the entry count does not fit in a `u32`
    /// (previously the count was silently truncated, producing an unreadable
    /// stream), or any I/O error reported by `writer`.
    pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        let count = u32::try_from(self.entries.len()).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "skip list has too many entries to serialize",
            )
        })?;
        writer.write_u32::<LittleEndian>(count)?;
        writer.write_f32::<LittleEndian>(self.global_max_weight)?;
        for entry in &self.entries {
            entry.write(writer)?;
        }
        Ok(())
    }

    /// Reads a list written by [`Self::write`].
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by `reader`, including `UnexpectedEof`
    /// when the stream holds fewer entries than its header announces.
    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        let count = reader.read_u32::<LittleEndian>()? as usize;
        let global_max_weight = reader.read_f32::<LittleEndian>()?;
        // Cap the up-front allocation: a corrupt header could otherwise request
        // tens of gigabytes before a single entry has been read.
        let mut entries = Vec::with_capacity(count.min(Self::MAX_PREALLOC_ENTRIES));
        for _ in 0..count {
            entries.push(SparseSkipEntry::read(reader)?);
        }
        Ok(Self {
            entries,
            global_max_weight,
        })
    }
}
/// Encoded posting list: delta-encoded doc ids followed by weights stored in
/// the chosen quantization.
#[derive(Debug, Clone)]
pub struct SparsePostingList {
/// How weights are stored in `data`.
quantization: WeightQuantization,
/// Multiplier applied to a quantized value when decoding `UInt8`/`UInt4`
/// weights; `1.0` for the float formats.
scale: f32,
/// Offset added after scaling when decoding `UInt8`/`UInt4` weights; `0.0`
/// for the float formats.
min_val: f32,
/// Number of postings encoded in `data`.
doc_count: u32,
/// Encoded bytes: `doc_count` vint doc-id deltas, then the weight section.
data: Vec<u8>,
}
impl SparsePostingList {
    /// Upper bound on the bytes pre-allocated by [`Self::deserialize`] before
    /// any payload has actually been read from the stream.
    const MAX_PREALLOC_BYTES: usize = 1 << 20;

    /// Encodes `postings` into a posting list using `quantization` for the weights.
    ///
    /// Layout of the encoded data: one vint per posting holding the doc-id
    /// delta to the previous posting (the first delta is relative to 0),
    /// followed by the weight section:
    /// * `Float32` — 4 little-endian bytes per weight,
    /// * `Float16` — 2 little-endian bytes per weight,
    /// * `UInt8` — 1 byte per weight, linearly mapped from `[min, max]` to `0..=255`,
    /// * `UInt4` — two weights per byte (even index in the low nibble),
    ///   linearly mapped from `[min, max]` to `0..=15`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when `postings` is not sorted by ascending doc id
    /// (the delta would underflow) or holds more than `u32::MAX` items, and
    /// propagates any error from the underlying writers.
    pub fn from_postings(
        postings: &[(DocId, f32)],
        quantization: WeightQuantization,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                quantization,
                scale: 1.0,
                min_val: 0.0,
                doc_count: 0,
                data: Vec::new(),
            });
        }
        let doc_count = u32::try_from(postings.len()).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "too many postings for a single sparse posting list",
            )
        })?;

        // One pass over the weights, no temporary vector.
        let (min_val, max_val) = postings.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(lo, hi), &(_, w)| (lo.min(w), hi.max(w)),
        );
        // Linear quantization parameters for a format with `levels` steps above zero.
        // A (near-)zero range keeps scale at 1.0 to avoid dividing by zero.
        let linear_params = |levels: f32| {
            let range = max_val - min_val;
            if range < f32::EPSILON {
                (1.0, min_val)
            } else {
                (range / levels, min_val)
            }
        };
        let (scale, adjusted_min) = match quantization {
            WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
            WeightQuantization::UInt8 => linear_params(255.0),
            WeightQuantization::UInt4 => linear_params(15.0),
        };

        // Doc-id section: vint-encoded deltas.
        let mut data = Vec::new();
        let mut prev_doc_id: DocId = 0;
        for &(doc_id, _) in postings {
            let delta = doc_id.checked_sub(prev_doc_id).ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "postings must be sorted by ascending doc id",
                )
            })?;
            write_vint(&mut data, u64::from(delta))?;
            prev_doc_id = doc_id;
        }

        // Weight section. The float-to-u8 cast saturates, so out-of-range
        // values cannot wrap around.
        let quantize = |w: f32| ((w - adjusted_min) / scale).round() as u8;
        match quantization {
            WeightQuantization::Float32 => {
                for &(_, weight) in postings {
                    data.write_f32::<LittleEndian>(weight)?;
                }
            }
            WeightQuantization::Float16 => {
                use half::slice::HalfFloatSliceExt;
                // The slice conversion needs contiguous f32 input, so the
                // weights are gathered only on this path.
                let weights: Vec<f32> = postings.iter().map(|&(_, w)| w).collect();
                let mut halves: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
                halves.convert_from_f32_slice(&weights);
                for h in halves {
                    data.write_u16::<LittleEndian>(h.to_bits())?;
                }
            }
            WeightQuantization::UInt8 => {
                for &(_, weight) in postings {
                    data.write_u8(quantize(weight))?;
                }
            }
            WeightQuantization::UInt4 => {
                for pair in postings.chunks(2) {
                    let low = quantize(pair[0].1) & 0x0F;
                    // An odd posting count leaves the final high nibble at zero.
                    let high = pair.get(1).map_or(0, |p| quantize(p.1) & 0x0F);
                    data.write_u8((high << 4) | low)?;
                }
            }
        }

        Ok(Self {
            quantization,
            scale,
            min_val: adjusted_min,
            doc_count,
            data,
        })
    }

    /// Writes the list as: quantization byte, `f32` scale, `f32` min value,
    /// `u32` doc count, `u32` data length, then the data bytes (little-endian).
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when the data section exceeds `u32::MAX` bytes
    /// (previously the length was silently truncated), or any I/O error from
    /// `writer`.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        let data_len = u32::try_from(self.data.len()).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "sparse posting list data is too large to serialize",
            )
        })?;
        writer.write_u8(self.quantization as u8)?;
        writer.write_f32::<LittleEndian>(self.scale)?;
        writer.write_f32::<LittleEndian>(self.min_val)?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_u32::<LittleEndian>(data_len)?;
        writer.write_all(&self.data)?;
        Ok(())
    }

    /// Reads a list written by [`Self::serialize`].
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` for an unknown quantization byte, `UnexpectedEof`
    /// when the stream ends before the announced data length, and any other
    /// I/O error from `reader`.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let quant_byte = reader.read_u8()?;
        let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
        })?;
        let scale = reader.read_f32::<LittleEndian>()?;
        let min_val = reader.read_f32::<LittleEndian>()?;
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let data_len = reader.read_u32::<LittleEndian>()? as usize;

        // The announced length is untrusted: allocate a bounded amount up
        // front and let the vector grow only as real bytes arrive.
        let mut data = Vec::with_capacity(data_len.min(Self::MAX_PREALLOC_BYTES));
        reader
            .by_ref()
            .take(data_len as u64)
            .read_to_end(&mut data)?;
        if data.len() != data_len {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "sparse posting list data is truncated",
            ));
        }

        Ok(Self {
            quantization,
            scale,
            min_val,
            doc_count,
            data,
        })
    }

    /// Number of postings in the list.
    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    /// Quantization used for the stored weights.
    pub fn quantization(&self) -> WeightQuantization {
        self.quantization
    }

    /// Returns an iterator positioned on the first posting.
    pub fn iterator(&self) -> SparsePostingIterator<'_> {
        SparsePostingIterator::new(self)
    }

    /// Decodes every posting into `(doc_id, weight)` pairs, in stored order.
    pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
        // Every posting occupies at least one vint byte, so a well-formed list
        // never has more postings than data bytes; this keeps a corrupt
        // `doc_count` from driving a huge allocation.
        let capacity = (self.doc_count as usize).min(self.data.len());
        let mut result = Vec::with_capacity(capacity);
        let mut iter = self.iterator();
        while !iter.exhausted {
            result.push((iter.doc_id, iter.weight));
            iter.advance();
        }
        Ok(result)
    }
}
/// Forward cursor over a [`SparsePostingList`], decoding one posting at a time.
pub struct SparsePostingIterator<'a> {
/// The list being iterated.
posting_list: &'a SparsePostingList,
/// Byte offset in the list's data of the next doc-id delta to decode.
doc_id_offset: usize,
/// Byte offset in the list's data where the weight section starts.
weight_offset: usize,
/// Zero-based position of the current posting.
index: usize,
/// Doc id of the current posting (running sum of the decoded deltas).
doc_id: DocId,
/// Decoded weight of the current posting.
weight: f32,
/// Set once the iterator has moved past the last posting.
exhausted: bool,
}
impl<'a> SparsePostingIterator<'a> {
    /// Creates an iterator positioned on the first posting of `posting_list`,
    /// or already exhausted when the list is empty.
    ///
    /// Locating the weight section requires scanning every doc-id delta once,
    /// so construction cost grows with the number of postings.
    fn new(posting_list: &'a SparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            doc_id_offset: 0,
            weight_offset: 0,
            index: 0,
            doc_id: 0,
            weight: 0.0,
            exhausted: posting_list.doc_count == 0,
        };
        if !iter.exhausted {
            iter.weight_offset = iter.calculate_weight_offset();
            iter.load_current();
        }
        iter
    }

    /// Returns the byte offset just past the last doc-id delta that decodes
    /// successfully, which is where the weight section begins in a well-formed list.
    fn calculate_weight_offset(&self) -> usize {
        let data = &self.posting_list.data;
        let mut reader = &data[..];
        let mut offset = 0;
        for _ in 0..self.posting_list.doc_count {
            // Stop at the first undecodable delta instead of retrying for
            // every remaining posting of a truncated or corrupt list.
            if read_vint(&mut reader).is_err() {
                break;
            }
            offset = data.len() - reader.len();
        }
        offset
    }

    /// Decodes the posting at `self.index` into `doc_id` and `weight`.
    ///
    /// Marks the iterator exhausted when the index is past the end or when the
    /// doc-id delta cannot be decoded. A weight whose bytes fall outside the
    /// data decodes as `0.0`.
    fn load_current(&mut self) {
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return;
        }
        let pl = self.posting_list;

        let mut reader = &pl.data[self.doc_id_offset..];
        match read_vint(&mut reader) {
            Ok(delta) => {
                self.doc_id = self.doc_id.wrapping_add(delta as u32);
                self.doc_id_offset = pl.data.len() - reader.len();
            }
            Err(_) => {
                // A missing delta means the list is truncated or corrupt;
                // ending iteration beats repeating the previous doc id.
                self.exhausted = true;
                return;
            }
        }

        let weight_idx = self.index;
        let base = self.weight_offset;
        self.weight = match pl.quantization {
            WeightQuantization::Float32 => {
                let start = base + weight_idx * 4;
                pl.data
                    .get(start..start + 4)
                    .map_or(0.0, |b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            }
            WeightQuantization::Float16 => {
                let start = base + weight_idx * 2;
                pl.data.get(start..start + 2).map_or(0.0, |b| {
                    half::f16::from_bits(u16::from_le_bytes([b[0], b[1]])).to_f32()
                })
            }
            WeightQuantization::UInt8 => pl
                .data
                .get(base + weight_idx)
                .map_or(0.0, |&quantized| quantized as f32 * pl.scale + pl.min_val),
            WeightQuantization::UInt4 => {
                // Two weights per byte: even indices in the low nibble, odd in the high.
                pl.data.get(base + weight_idx / 2).map_or(0.0, |&byte| {
                    let quantized = if weight_idx.is_multiple_of(2) {
                        byte & 0x0F
                    } else {
                        (byte >> 4) & 0x0F
                    };
                    quantized as f32 * pl.scale + pl.min_val
                })
            }
        };
    }

    /// Current doc id, or `super::TERMINATED` once the iterator is exhausted.
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else {
            self.doc_id
        }
    }

    /// Current weight, or `0.0` once the iterator is exhausted.
    pub fn weight(&self) -> f32 {
        if self.exhausted { 0.0 } else { self.weight }
    }

    /// Moves to the next posting and returns its doc id, or
    /// `super::TERMINATED` when there is none.
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }
        self.index += 1;
        // `load_current` sets `exhausted` when the index runs past the end or
        // the data cannot be decoded; `doc()` then reports TERMINATED.
        self.load_current();
        self.doc()
    }

    /// Advances until the current doc id is at least `target` and returns it,
    /// or `super::TERMINATED` when no such posting exists.
    ///
    /// This is a linear scan; the iterator never moves backwards.
    pub fn seek(&mut self, target: DocId) -> DocId {
        while !self.exhausted && self.doc_id < target {
            self.advance();
        }
        self.doc()
    }
}
// Unit tests for the sparse vector types, the flat posting list, the block
// posting list and the skip list.
#[cfg(test)]
mod tests {
use super::*;
// Only indices 2 and 5 are shared: 2.0 * 4.0 + 3.0 * 2.0 = 14.0.
#[test]
fn test_sparse_vector_dot_product() {
let v1 = SparseVector::from_entries(&[0, 2, 5], &[1.0, 2.0, 3.0]);
let v2 = SparseVector::from_entries(&[1, 2, 5], &[1.0, 4.0, 2.0]);
assert!((v1.dot(&v2) - 14.0).abs() < 1e-6);
}
// Float32 weights must come back unchanged, and advancing past the last
// posting must report TERMINATED.
#[test]
fn test_sparse_posting_list_float32() {
let postings = vec![(0, 1.5), (5, 2.3), (10, 0.8), (100, 3.15)];
let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
assert_eq!(pl.doc_count(), 4);
let mut iter = pl.iterator();
assert_eq!(iter.doc(), 0);
assert!((iter.weight() - 1.5).abs() < 1e-6);
iter.advance();
assert_eq!(iter.doc(), 5);
assert!((iter.weight() - 2.3).abs() < 1e-6);
iter.advance();
assert_eq!(iter.doc(), 10);
iter.advance();
assert_eq!(iter.doc(), 100);
assert!((iter.weight() - 3.15).abs() < 1e-6);
iter.advance();
assert_eq!(iter.doc(), super::super::TERMINATED);
}
// UInt8 quantization is lossy, so only the ordering of the decoded weights
// is checked.
#[test]
fn test_sparse_posting_list_uint8() {
let postings = vec![(0, 0.0), (5, 0.5), (10, 1.0)];
let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();
let decoded = pl.decode_all().unwrap();
assert_eq!(decoded.len(), 3);
assert!(decoded[0].1 < decoded[1].1);
assert!(decoded[1].1 < decoded[2].1);
}
// 300 postings must span at least two blocks and iterate back in order.
#[test]
fn test_block_sparse_posting_list() {
let postings: Vec<(DocId, u16, f32)> =
(0..300).map(|i| (i * 2, 0, (i as f32) * 0.1)).collect();
let pl =
BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
assert_eq!(pl.doc_count(), 300);
assert!(pl.num_blocks() >= 2);
let mut iter = pl.iterator();
for (expected_doc, _, expected_weight) in &postings {
assert_eq!(iter.doc(), *expected_doc);
assert!((iter.weight() - expected_weight).abs() < 1e-6);
iter.advance();
}
assert_eq!(iter.doc(), super::super::TERMINATED);
}
// Doc ids are multiples of 3: seeking to an existing id returns it, seeking
// between ids returns the next one, seeking past the end terminates.
#[test]
fn test_block_sparse_seek() {
let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();
let pl =
BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
let mut iter = pl.iterator();
assert_eq!(iter.seek(300), 300);
assert_eq!(iter.seek(301), 303);
assert_eq!(iter.seek(2000), super::super::TERMINATED);
}
// serialize -> from_parts must preserve doc ids exactly and weights within
// 0.1 for each tested quantization.
#[test]
fn test_serialization_roundtrip() {
let postings: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (10, 0, 2.0), (100, 0, 3.0)];
for quant in [
WeightQuantization::Float32,
WeightQuantization::Float16,
WeightQuantization::UInt8,
] {
let pl = BlockSparsePostingList::from_postings(&postings, quant).unwrap();
let (block_data, skip_entries) = pl.serialize().unwrap();
let pl2 =
BlockSparsePostingList::from_parts(pl.doc_count(), &block_data, &skip_entries)
.unwrap();
assert_eq!(pl.doc_count(), pl2.doc_count());
let mut iter1 = pl.iterator();
let mut iter2 = pl2.iterator();
while iter1.doc() != super::super::TERMINATED {
assert_eq!(iter1.doc(), iter2.doc());
assert!((iter1.weight() - iter2.weight()).abs() < 0.1);
iter1.advance();
iter2.advance();
}
}
}
// Decodes two lists, shifts the second list's doc ids by 100, and re-encodes
// the combined postings as one list.
#[test]
fn test_concatenate() {
let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 1, 2.0)];
let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 3.0), (10, 1, 4.0)];
let pl1 =
BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
let pl2 =
BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
let mut all: Vec<(DocId, u16, f32)> = pl1.decode_all();
for (doc_id, ord, w) in pl2.decode_all() {
all.push((doc_id + 100, ord, w));
}
let merged =
BlockSparsePostingList::from_postings(&all, WeightQuantization::Float32).unwrap();
assert_eq!(merged.doc_count(), 4);
let decoded = merged.decode_all();
assert_eq!(decoded[0], (0, 0, 1.0));
assert_eq!(decoded[1], (5, 1, 2.0));
assert_eq!(decoded[2], (100, 0, 3.0));
assert_eq!(decoded[3], (110, 1, 4.0));
}
// Checks the preset configurations (default, splade, compact, conservative)
// and the to_byte/from_byte round trip.
#[test]
fn test_sparse_vector_config() {
let default = SparseVectorConfig::default();
assert_eq!(default.index_size, IndexSize::U32);
assert_eq!(default.weight_quantization, WeightQuantization::Float32);
assert_eq!(default.bytes_per_entry(), 8.0);
let splade = SparseVectorConfig::splade();
assert_eq!(splade.index_size, IndexSize::U16);
assert_eq!(splade.weight_quantization, WeightQuantization::UInt8);
assert_eq!(splade.bytes_per_entry(), 3.0); assert_eq!(splade.weight_threshold, 0.01);
assert_eq!(splade.pruning, Some(0.1));
assert!(splade.query_config.is_some());
let query_cfg = splade.query_config.as_ref().unwrap();
assert_eq!(query_cfg.heap_factor, 0.8);
assert_eq!(query_cfg.max_query_dims, Some(20));
let compact = SparseVectorConfig::compact();
assert_eq!(compact.index_size, IndexSize::U16);
assert_eq!(compact.weight_quantization, WeightQuantization::UInt4);
assert_eq!(compact.bytes_per_entry(), 2.5);
let conservative = SparseVectorConfig::conservative();
assert_eq!(conservative.index_size, IndexSize::U32);
assert_eq!(
conservative.weight_quantization,
WeightQuantization::Float16
);
assert_eq!(conservative.weight_threshold, 0.005);
assert_eq!(conservative.pruning, None);
// Only index size and weight quantization are checked after the byte round trip.
let byte = splade.to_byte();
let restored = SparseVectorConfig::from_byte(byte).unwrap();
assert_eq!(restored.index_size, splade.index_size);
assert_eq!(restored.weight_quantization, splade.weight_quantization);
}
// Byte widths and maximum representable index for each index size.
#[test]
fn test_index_size() {
assert_eq!(IndexSize::U16.bytes(), 2);
assert_eq!(IndexSize::U32.bytes(), 4);
assert_eq!(IndexSize::U16.max_value(), 65535);
assert_eq!(IndexSize::U32.max_value(), u32::MAX);
}
// Weights are i * 0.1 for doc i, so the expected per-block maxima
// (12.7, 25.5, 29.9) correspond to docs 127, 255 and 299.
#[test]
fn test_block_max_weight() {
let postings: Vec<(DocId, u16, f32)> = (0..300)
.map(|i| (i as DocId, 0, (i as f32) * 0.1))
.collect();
let pl =
BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
assert!((pl.global_max_weight() - 29.9).abs() < 0.01);
assert!(pl.num_blocks() >= 3);
let block0_max = pl.block_max_weight(0).unwrap();
assert!((block0_max - 12.7).abs() < 0.01);
let block1_max = pl.block_max_weight(1).unwrap();
assert!((block1_max - 25.5).abs() < 0.01);
let block2_max = pl.block_max_weight(2).unwrap();
assert!((block2_max - 29.9).abs() < 0.01);
let query_weight = 2.0;
let mut iter = pl.iterator();
assert!((iter.current_block_max_weight() - 12.7).abs() < 0.01);
assert!((iter.current_block_max_contribution(query_weight) - 25.4).abs() < 0.1);
// Seeking to doc 128 must move the iterator into the second block.
iter.seek(128);
assert!((iter.current_block_max_weight() - 25.5).abs() < 0.01);
}
// write -> read must preserve the entry count, the global max weight and
// the fields of individual entries.
#[test]
fn test_sparse_skip_list_serialization() {
let mut skip_list = SparseSkipList::new();
skip_list.push(0, 127, 0, 50, 12.7);
skip_list.push(128, 255, 100, 60, 25.5);
skip_list.push(256, 299, 200, 40, 29.9);
assert_eq!(skip_list.len(), 3);
assert!((skip_list.global_max_weight() - 29.9).abs() < 0.01);
let mut buffer = Vec::new();
skip_list.write(&mut buffer).unwrap();
let restored = SparseSkipList::read(&mut buffer.as_slice()).unwrap();
assert_eq!(restored.len(), 3);
assert!((restored.global_max_weight() - 29.9).abs() < 0.01);
let e0 = restored.get(0).unwrap();
assert_eq!(e0.first_doc, 0);
assert_eq!(e0.last_doc, 127);
assert!((e0.max_weight - 12.7).abs() < 0.01);
let e1 = restored.get(1).unwrap();
assert_eq!(e1.first_doc, 128);
assert!((e1.max_weight - 25.5).abs() < 0.01);
}
}