hermes_core/structures/postings/
rounded_bp128.rs

//! Rounded BP128 posting list format with SIMD-friendly bit widths
//!
//! This format rounds bit widths up to 0, 8, 16, or 32 bits for faster SIMD decoding,
//! at the cost of ~10-100% more space compared to exact bitpacking.
//!
//! Use this format when:
//! - Query latency is more important than index size
//! - You have sufficient storage/memory
//! - Your workload is read-heavy
//!
//! The tradeoff: ~2-4x faster decoding for ~20-60% larger posting lists.
12
13use crate::structures::simd::{self, RoundedBitWidth};
14use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
15use std::io::{self, Read, Write};
16
/// Maximum number of postings held by a single rounded-BP128 block.
///
/// Blocks are capped at 128 integers so the packed data stays SIMD-aligned.
pub const ROUNDED_BP128_BLOCK_SIZE: usize = 128;
19
/// One block of a rounded-BP128 posting list, carrying the skip metadata
/// (doc_id range and score bound) used by BlockWAND.
#[derive(Clone, Debug)]
pub struct RoundedBP128Block {
    /// Packed doc_id deltas, each stored at `doc_bit_width` bits (8/16/32).
    pub doc_deltas: Vec<u8>,
    /// Width of every packed doc delta; one of 0, 8, 16, or 32.
    pub doc_bit_width: u8,
    /// Packed term frequencies, each stored at `tf_bit_width` bits (8/16/32).
    pub term_freqs: Vec<u8>,
    /// Width of every packed term frequency; one of 0, 8, 16, or 32.
    pub tf_bit_width: u8,
    /// Absolute doc_id of the first posting in the block.
    pub first_doc_id: u32,
    /// Absolute doc_id of the last posting in the block.
    pub last_doc_id: u32,
    /// How many postings the block contains.
    pub num_docs: u16,
    /// Largest term frequency of any posting in the block.
    pub max_tf: u32,
    /// Upper bound on the impact score within the block (for MaxScore/WAND).
    pub max_block_score: f32,
}
42
43impl RoundedBP128Block {
44    /// Serialize the block
45    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
46        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
47        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
48        writer.write_u16::<LittleEndian>(self.num_docs)?;
49        writer.write_u8(self.doc_bit_width)?;
50        writer.write_u8(self.tf_bit_width)?;
51        writer.write_u32::<LittleEndian>(self.max_tf)?;
52        writer.write_f32::<LittleEndian>(self.max_block_score)?;
53
54        // Write doc deltas
55        writer.write_u16::<LittleEndian>(self.doc_deltas.len() as u16)?;
56        writer.write_all(&self.doc_deltas)?;
57
58        // Write term freqs
59        writer.write_u16::<LittleEndian>(self.term_freqs.len() as u16)?;
60        writer.write_all(&self.term_freqs)?;
61
62        Ok(())
63    }
64
65    /// Deserialize a block
66    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
67        let first_doc_id = reader.read_u32::<LittleEndian>()?;
68        let last_doc_id = reader.read_u32::<LittleEndian>()?;
69        let num_docs = reader.read_u16::<LittleEndian>()?;
70        let doc_bit_width = reader.read_u8()?;
71        let tf_bit_width = reader.read_u8()?;
72        let max_tf = reader.read_u32::<LittleEndian>()?;
73        let max_block_score = reader.read_f32::<LittleEndian>()?;
74
75        let doc_deltas_len = reader.read_u16::<LittleEndian>()? as usize;
76        let mut doc_deltas = vec![0u8; doc_deltas_len];
77        reader.read_exact(&mut doc_deltas)?;
78
79        let term_freqs_len = reader.read_u16::<LittleEndian>()? as usize;
80        let mut term_freqs = vec![0u8; term_freqs_len];
81        reader.read_exact(&mut term_freqs)?;
82
83        Ok(Self {
84            doc_deltas,
85            doc_bit_width,
86            term_freqs,
87            tf_bit_width,
88            first_doc_id,
89            last_doc_id,
90            num_docs,
91            max_tf,
92            max_block_score,
93        })
94    }
95
96    /// Decode doc_ids from this block using SIMD-friendly rounded unpacking
97    pub fn decode_doc_ids(&self) -> Vec<u32> {
98        let mut doc_ids = vec![0u32; self.num_docs as usize];
99        self.decode_doc_ids_into(&mut doc_ids);
100        doc_ids
101    }
102
103    /// Decode doc_ids into a pre-allocated buffer (avoids allocation)
104    #[inline]
105    pub fn decode_doc_ids_into(&self, output: &mut [u32]) -> usize {
106        let n = self.num_docs as usize;
107
108        if n == 0 {
109            return 0;
110        }
111
112        output[0] = self.first_doc_id;
113
114        if n == 1 {
115            return 1;
116        }
117
118        // Use fused unpack + delta decode for best performance
119        let rounded_width = RoundedBitWidth::from_u8(self.doc_bit_width);
120        simd::unpack_rounded_delta_decode(
121            &self.doc_deltas,
122            rounded_width,
123            output,
124            self.first_doc_id,
125            n,
126        );
127
128        n
129    }
130
131    /// Decode term frequencies using SIMD-friendly rounded unpacking
132    pub fn decode_term_freqs(&self) -> Vec<u32> {
133        let mut tfs = vec![0u32; self.num_docs as usize];
134        self.decode_term_freqs_into(&mut tfs);
135        tfs
136    }
137
138    /// Decode term frequencies into a pre-allocated buffer (avoids allocation)
139    #[inline]
140    pub fn decode_term_freqs_into(&self, output: &mut [u32]) -> usize {
141        let n = self.num_docs as usize;
142
143        if n == 0 {
144            return 0;
145        }
146
147        // Unpack using rounded bit width (fast SIMD path)
148        let rounded_width = RoundedBitWidth::from_u8(self.tf_bit_width);
149        simd::unpack_rounded(&self.term_freqs, rounded_width, output, n);
150
151        // Add 1 back (we stored tf-1)
152        simd::add_one(output, n);
153
154        n
155    }
156}
157
158/// Rounded BP128 posting list with block-level skip info
159///
160/// Uses rounded bit widths (0, 8, 16, 32) for faster SIMD decoding
161/// at the cost of larger index size.
162#[derive(Debug, Clone)]
163pub struct RoundedBP128PostingList {
164    /// Blocks of postings
165    pub blocks: Vec<RoundedBP128Block>,
166    /// Total document count
167    pub doc_count: u32,
168    /// Maximum score across all blocks (for MaxScore pruning)
169    pub max_score: f32,
170}
171
172impl RoundedBP128PostingList {
173    /// Create from raw doc_ids and term frequencies
174    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
175        assert_eq!(doc_ids.len(), term_freqs.len());
176
177        if doc_ids.is_empty() {
178            return Self {
179                blocks: Vec::new(),
180                doc_count: 0,
181                max_score: 0.0,
182            };
183        }
184
185        let mut blocks = Vec::new();
186        let mut max_score = 0.0f32;
187        let mut i = 0;
188
189        while i < doc_ids.len() {
190            let block_end = (i + ROUNDED_BP128_BLOCK_SIZE).min(doc_ids.len());
191            let block_docs = &doc_ids[i..block_end];
192            let block_tfs = &term_freqs[i..block_end];
193
194            let block = Self::create_block(block_docs, block_tfs, idf);
195            max_score = max_score.max(block.max_block_score);
196            blocks.push(block);
197
198            i = block_end;
199        }
200
201        Self {
202            blocks,
203            doc_count: doc_ids.len() as u32,
204            max_score,
205        }
206    }
207
208    /// BM25F parameters for block-max score calculation
209    const K1: f32 = 1.2;
210    const B: f32 = 0.75;
211
212    /// Compute BM25F upper bound score for a given max_tf and IDF
213    #[inline]
214    pub fn compute_bm25f_upper_bound(max_tf: u32, idf: f32, field_boost: f32) -> f32 {
215        let tf = max_tf as f32;
216        let min_length_norm = 1.0 - Self::B;
217        let tf_norm =
218            (tf * field_boost * (Self::K1 + 1.0)) / (tf * field_boost + Self::K1 * min_length_norm);
219        idf * tf_norm
220    }
221
222    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> RoundedBP128Block {
223        let num_docs = doc_ids.len();
224        let first_doc_id = doc_ids[0];
225        let last_doc_id = *doc_ids.last().unwrap();
226
227        // Compute deltas (delta - 1 to save one bit since deltas are always >= 1)
228        let mut deltas = [0u32; ROUNDED_BP128_BLOCK_SIZE];
229        let mut max_delta = 0u32;
230        for j in 1..num_docs {
231            let delta = doc_ids[j] - doc_ids[j - 1] - 1;
232            deltas[j - 1] = delta;
233            max_delta = max_delta.max(delta);
234        }
235
236        // Compute max TF and prepare TF array (store tf-1)
237        let mut tfs = [0u32; ROUNDED_BP128_BLOCK_SIZE];
238        let mut max_tf = 0u32;
239
240        for (j, &tf) in term_freqs.iter().enumerate() {
241            tfs[j] = tf - 1; // Store tf-1
242            max_tf = max_tf.max(tf);
243        }
244
245        let max_block_score = Self::compute_bm25f_upper_bound(max_tf, idf, 1.0);
246
247        // Use rounded bit widths for SIMD-friendly decoding
248        let exact_doc_bits = simd::bits_needed(max_delta);
249        let exact_tf_bits = simd::bits_needed(max_tf.saturating_sub(1));
250
251        let doc_rounded = RoundedBitWidth::from_exact(exact_doc_bits);
252        let tf_rounded = RoundedBitWidth::from_exact(exact_tf_bits);
253
254        // Pack with rounded bit widths
255        let mut doc_deltas = vec![0u8; num_docs.saturating_sub(1) * doc_rounded.bytes_per_value()];
256        if num_docs > 1 {
257            simd::pack_rounded(&deltas[..num_docs - 1], doc_rounded, &mut doc_deltas);
258        }
259
260        let mut term_freqs_packed = vec![0u8; num_docs * tf_rounded.bytes_per_value()];
261        simd::pack_rounded(&tfs[..num_docs], tf_rounded, &mut term_freqs_packed);
262
263        RoundedBP128Block {
264            doc_deltas,
265            doc_bit_width: doc_rounded.as_u8(),
266            term_freqs: term_freqs_packed,
267            tf_bit_width: tf_rounded.as_u8(),
268            first_doc_id,
269            last_doc_id,
270            num_docs: num_docs as u16,
271            max_tf,
272            max_block_score,
273        }
274    }
275
276    /// Serialize the posting list
277    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
278        writer.write_u32::<LittleEndian>(self.doc_count)?;
279        writer.write_f32::<LittleEndian>(self.max_score)?;
280        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
281
282        for block in &self.blocks {
283            block.serialize(writer)?;
284        }
285
286        Ok(())
287    }
288
289    /// Deserialize a posting list
290    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
291        let doc_count = reader.read_u32::<LittleEndian>()?;
292        let max_score = reader.read_f32::<LittleEndian>()?;
293        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;
294
295        let mut blocks = Vec::with_capacity(num_blocks);
296        for _ in 0..num_blocks {
297            blocks.push(RoundedBP128Block::deserialize(reader)?);
298        }
299
300        Ok(Self {
301            blocks,
302            doc_count,
303            max_score,
304        })
305    }
306
307    /// Create an iterator
308    pub fn iterator(&self) -> RoundedBP128Iterator<'_> {
309        RoundedBP128Iterator::new(self)
310    }
311
312    /// Get number of documents
313    pub fn len(&self) -> u32 {
314        self.doc_count
315    }
316
317    /// Check if empty
318    pub fn is_empty(&self) -> bool {
319        self.doc_count == 0
320    }
321}
322
323/// Iterator over rounded BP128 posting list with block skipping support
324pub struct RoundedBP128Iterator<'a> {
325    posting_list: &'a RoundedBP128PostingList,
326    current_block: usize,
327    position_in_block: usize,
328    /// Number of valid elements in current block
329    current_block_len: usize,
330    /// Pre-allocated buffer for decoded doc_ids (avoids allocation per block)
331    decoded_doc_ids: Vec<u32>,
332    /// Pre-allocated buffer for decoded term frequencies
333    decoded_tfs: Vec<u32>,
334}
335
336impl<'a> RoundedBP128Iterator<'a> {
337    pub fn new(posting_list: &'a RoundedBP128PostingList) -> Self {
338        // Pre-allocate buffers to block size to avoid allocations during iteration
339        let mut iter = Self {
340            posting_list,
341            current_block: 0,
342            position_in_block: 0,
343            current_block_len: 0,
344            decoded_doc_ids: vec![0u32; ROUNDED_BP128_BLOCK_SIZE],
345            decoded_tfs: vec![0u32; ROUNDED_BP128_BLOCK_SIZE],
346        };
347
348        if !posting_list.blocks.is_empty() {
349            iter.decode_current_block();
350        }
351
352        iter
353    }
354
355    #[inline]
356    fn decode_current_block(&mut self) {
357        if self.current_block < self.posting_list.blocks.len() {
358            let block = &self.posting_list.blocks[self.current_block];
359            // Decode into pre-allocated buffers (no allocation!)
360            self.current_block_len = block.decode_doc_ids_into(&mut self.decoded_doc_ids);
361            block.decode_term_freqs_into(&mut self.decoded_tfs);
362        } else {
363            self.current_block_len = 0;
364        }
365    }
366
367    /// Current document ID
368    #[inline]
369    pub fn doc(&self) -> u32 {
370        if self.current_block >= self.posting_list.blocks.len() {
371            return u32::MAX;
372        }
373        if self.position_in_block >= self.current_block_len {
374            return u32::MAX;
375        }
376        self.decoded_doc_ids[self.position_in_block]
377    }
378
379    /// Current term frequency
380    #[inline]
381    pub fn term_freq(&self) -> u32 {
382        if self.current_block >= self.posting_list.blocks.len() {
383            return 0;
384        }
385        if self.position_in_block >= self.current_block_len {
386            return 0;
387        }
388        self.decoded_tfs[self.position_in_block]
389    }
390
391    /// Advance to next posting
392    #[inline]
393    pub fn advance(&mut self) -> u32 {
394        self.position_in_block += 1;
395
396        if self.position_in_block >= self.current_block_len {
397            self.current_block += 1;
398            self.position_in_block = 0;
399
400            if self.current_block < self.posting_list.blocks.len() {
401                self.decode_current_block();
402            }
403        }
404
405        self.doc()
406    }
407
408    /// Seek to first doc_id >= target
409    pub fn seek(&mut self, target: u32) -> u32 {
410        // Skip blocks where last_doc_id < target
411        while self.current_block < self.posting_list.blocks.len() {
412            let block = &self.posting_list.blocks[self.current_block];
413            if block.last_doc_id >= target {
414                break;
415            }
416            self.current_block += 1;
417            self.position_in_block = 0;
418        }
419
420        if self.current_block >= self.posting_list.blocks.len() {
421            return u32::MAX;
422        }
423
424        // Decode block if needed (check if we're on the right block)
425        let block = &self.posting_list.blocks[self.current_block];
426        if self.current_block_len == 0
427            || self.position_in_block >= self.current_block_len
428            || (self.position_in_block == 0 && self.decoded_doc_ids[0] != block.first_doc_id)
429        {
430            self.decode_current_block();
431            self.position_in_block = 0;
432        }
433
434        // Binary search within block
435        let start = self.position_in_block;
436        let slice = &self.decoded_doc_ids[start..self.current_block_len];
437        match slice.binary_search(&target) {
438            Ok(pos) => {
439                self.position_in_block = start + pos;
440            }
441            Err(pos) => {
442                if pos < slice.len() {
443                    self.position_in_block = start + pos;
444                } else {
445                    // Move to next block
446                    self.current_block += 1;
447                    self.position_in_block = 0;
448                    if self.current_block < self.posting_list.blocks.len() {
449                        self.decode_current_block();
450                        return self.seek(target);
451                    }
452                    return u32::MAX;
453                }
454            }
455        }
456
457        self.doc()
458    }
459
460    /// Get block max score for current block (for WAND/MaxScore)
461    #[inline]
462    pub fn block_max_score(&self) -> f32 {
463        if self.current_block < self.posting_list.blocks.len() {
464            self.posting_list.blocks[self.current_block].max_block_score
465        } else {
466            0.0
467        }
468    }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Iterating a small single-block list yields every posting in order and
    /// then the `u32::MAX` end marker.
    #[test]
    fn test_rounded_bp128_basic() {
        let doc_ids: Vec<u32> = vec![1, 5, 10, 15, 20];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5];

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        assert_eq!(posting_list.doc_count, 5);

        let mut iter = posting_list.iterator();
        for (i, (&expected_doc, &expected_tf)) in doc_ids.iter().zip(term_freqs.iter()).enumerate()
        {
            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
            assert_eq!(iter.term_freq(), expected_tf, "TF mismatch at {}", i);
            iter.advance();
        }
        assert_eq!(iter.doc(), u32::MAX);
    }

    /// A completely full block decodes back to the original doc ids.
    #[test]
    fn test_rounded_bp128_large_block() {
        // Test with a full 128-element block
        let doc_ids: Vec<u32> = (0..128).map(|i| i * 5 + 100).collect();
        let term_freqs: Vec<u32> = vec![1; 128];

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let decoded = posting_list.blocks[0].decode_doc_ids();

        assert_eq!(decoded.len(), 128);
        for (i, (&expected, &actual)) in doc_ids.iter().zip(decoded.iter()).enumerate() {
            assert_eq!(expected, actual, "Mismatch at position {}", i);
        }
    }

    /// A multi-block list survives a serialize/deserialize round trip.
    #[test]
    fn test_rounded_bp128_serialization() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 7 + 100).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 5) as u32 + 1).collect();

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let mut buffer = Vec::new();
        posting_list.serialize(&mut buffer).unwrap();

        let restored = RoundedBP128PostingList::deserialize(&mut &buffer[..]).unwrap();
        assert_eq!(restored.doc_count, posting_list.doc_count);
        assert_eq!(restored.blocks.len(), posting_list.blocks.len());

        // Verify iteration produces same results
        let mut iter1 = posting_list.iterator();
        let mut iter2 = restored.iterator();

        while iter1.doc() != u32::MAX {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
        // The restored list must end at the same point, not carry extra postings.
        assert_eq!(iter2.doc(), u32::MAX);
    }

    /// Seeking inside a single block lands on the first doc id >= target.
    #[test]
    fn test_rounded_bp128_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = posting_list.iterator();

        assert_eq!(iter.seek(25), 30);
        assert_eq!(iter.seek(100), 100);
        assert_eq!(iter.seek(500), 1000);
        assert_eq!(iter.seek(3000), u32::MAX);
    }

    /// Seeking past a block boundary skips to the right block and decodes it.
    #[test]
    fn test_rounded_bp128_seek_across_blocks() {
        // 300 postings -> blocks covering 0..=381, 384..=765 and 768..=897.
        let doc_ids: Vec<u32> = (0..300).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..300).map(|i| (i % 7) as u32 + 1).collect();

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        assert_eq!(posting_list.blocks.len(), 3);

        let mut iter = posting_list.iterator();

        // 400 is not a stored doc id; the next one is 402 (index 134, tf 134 % 7 + 1).
        assert_eq!(iter.seek(400), 402);
        assert_eq!(iter.term_freq(), 2);

        // Last posting of the last block (index 299, tf 299 % 7 + 1).
        assert_eq!(iter.seek(897), 897);
        assert_eq!(iter.term_freq(), 6);

        assert_eq!(iter.seek(898), u32::MAX);
        assert_eq!(iter.term_freq(), 0);
    }

    /// Stored bit widths are always one of the rounded values.
    #[test]
    fn test_rounded_bit_widths() {
        // Gaps of 100 are stored as 99, which fits in 7 bits and so is
        // expected to round up to a byte-aligned width.
        let doc_ids: Vec<u32> = (0..128).map(|i| i * 100).collect();
        let term_freqs: Vec<u32> = vec![1; 128]; // Small TFs -> 0 bits (all zeros after -1)

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let block = &posting_list.blocks[0];

        // Doc bit width should be rounded to 0, 8, 16, or 32
        assert!(
            block.doc_bit_width == 0
                || block.doc_bit_width == 8
                || block.doc_bit_width == 16
                || block.doc_bit_width == 32,
            "Doc bit width {} is not rounded",
            block.doc_bit_width
        );

        // TF bit width should be rounded
        assert!(
            block.tf_bit_width == 0
                || block.tf_bit_width == 8
                || block.tf_bit_width == 16
                || block.tf_bit_width == 32,
            "TF bit width {} is not rounded",
            block.tf_bit_width
        );
    }

    /// The rounded format decodes to the same postings as the exact format
    /// while never serializing smaller.
    #[test]
    fn test_rounded_vs_exact_correctness() {
        // Verify rounded produces same decoded values as exact
        use super::super::horizontal_bp128::HorizontalBP128PostingList;

        let doc_ids: Vec<u32> = (0..200).map(|i| i * 7 + 100).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 5) as u32 + 1).collect();

        let exact = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let rounded = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        // Rounded should be larger (worse compression)
        let mut exact_buf = Vec::new();
        exact.serialize(&mut exact_buf).unwrap();
        let mut rounded_buf = Vec::new();
        rounded.serialize(&mut rounded_buf).unwrap();

        assert!(
            rounded_buf.len() >= exact_buf.len(),
            "Rounded ({}) should be >= exact ({})",
            rounded_buf.len(),
            exact_buf.len()
        );

        // But both should decode to the same values
        let mut exact_iter = exact.iterator();
        let mut rounded_iter = rounded.iterator();

        while exact_iter.doc() != u32::MAX {
            assert_eq!(exact_iter.doc(), rounded_iter.doc());
            assert_eq!(exact_iter.term_freq(), rounded_iter.term_freq());
            exact_iter.advance();
            rounded_iter.advance();
        }
        assert_eq!(rounded_iter.doc(), u32::MAX);
    }
}