Skip to main content

hermes_core/structures/postings/
rounded_bp128.rs

//! Rounded BP128 posting list format with SIMD-friendly bit widths
//!
//! This format rounds each block's bit width up to 8, 16, or 32 bits (a width of 0
//! stays 0) for faster SIMD decoding, at the cost of ~10-100% more space compared
//! to exact bitpacking.
//!
//! Use this format when:
//! - Query latency is more important than index size
//! - You have sufficient storage/memory
//! - Your workload is read-heavy
//!
//! The tradeoff: ~2-4x faster decoding for ~20-60% larger posting lists.
12
13use crate::structures::simd::{self, RoundedBitWidth};
14use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
15use std::io::{self, Read, Write};
16
/// Maximum number of postings stored in one rounded-bitpacked block.
///
/// Fixed at 128 integers per block so blocks line up with SIMD processing.
pub const ROUNDED_BP128_BLOCK_SIZE: usize = 128;
19
/// One block of postings packed with rounded bit widths, plus the per-block
/// metadata used for skipping (BlockWAND / MaxScore).
#[derive(Debug, Clone)]
pub struct RoundedBP128Block {
    /// Packed doc_id gaps between consecutive documents of the block.
    /// Each value occupies `doc_bit_width` bits (8, 16 or 32).
    pub doc_deltas: Vec<u8>,
    /// Width in bits of every packed doc delta: always 0, 8, 16 or 32.
    pub doc_bit_width: u8,
    /// Packed term frequencies, one per document.
    /// Each value occupies `tf_bit_width` bits (8, 16 or 32).
    pub term_freqs: Vec<u8>,
    /// Width in bits of every packed term frequency: always 0, 8, 16 or 32.
    pub tf_bit_width: u8,
    /// Absolute doc_id of the first document in the block.
    pub first_doc_id: u32,
    /// Absolute doc_id of the last document in the block.
    pub last_doc_id: u32,
    /// How many documents the block holds.
    pub num_docs: u16,
    /// Largest term frequency found in the block.
    pub max_tf: u32,
    /// Upper bound on the impact score of any posting in the block
    /// (consumed by MaxScore/WAND pruning).
    pub max_block_score: f32,
}
42
43impl RoundedBP128Block {
44    /// Serialize the block
45    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
46        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
47        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
48        writer.write_u16::<LittleEndian>(self.num_docs)?;
49        writer.write_u8(self.doc_bit_width)?;
50        writer.write_u8(self.tf_bit_width)?;
51        writer.write_u32::<LittleEndian>(self.max_tf)?;
52        writer.write_f32::<LittleEndian>(self.max_block_score)?;
53
54        // Write doc deltas
55        writer.write_u16::<LittleEndian>(self.doc_deltas.len() as u16)?;
56        writer.write_all(&self.doc_deltas)?;
57
58        // Write term freqs
59        writer.write_u16::<LittleEndian>(self.term_freqs.len() as u16)?;
60        writer.write_all(&self.term_freqs)?;
61
62        Ok(())
63    }
64
65    /// Deserialize a block
66    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
67        let first_doc_id = reader.read_u32::<LittleEndian>()?;
68        let last_doc_id = reader.read_u32::<LittleEndian>()?;
69        let num_docs = reader.read_u16::<LittleEndian>()?;
70        let doc_bit_width = reader.read_u8()?;
71        let tf_bit_width = reader.read_u8()?;
72        let max_tf = reader.read_u32::<LittleEndian>()?;
73        let max_block_score = reader.read_f32::<LittleEndian>()?;
74
75        let doc_deltas_len = reader.read_u16::<LittleEndian>()? as usize;
76        let mut doc_deltas = vec![0u8; doc_deltas_len];
77        reader.read_exact(&mut doc_deltas)?;
78
79        let term_freqs_len = reader.read_u16::<LittleEndian>()? as usize;
80        let mut term_freqs = vec![0u8; term_freqs_len];
81        reader.read_exact(&mut term_freqs)?;
82
83        Ok(Self {
84            doc_deltas,
85            doc_bit_width,
86            term_freqs,
87            tf_bit_width,
88            first_doc_id,
89            last_doc_id,
90            num_docs,
91            max_tf,
92            max_block_score,
93        })
94    }
95
96    /// Decode doc_ids from this block using SIMD-friendly rounded unpacking
97    pub fn decode_doc_ids(&self) -> Vec<u32> {
98        let mut doc_ids = vec![0u32; self.num_docs as usize];
99        self.decode_doc_ids_into(&mut doc_ids);
100        doc_ids
101    }
102
103    /// Decode doc_ids into a pre-allocated buffer (avoids allocation)
104    #[inline]
105    pub fn decode_doc_ids_into(&self, output: &mut [u32]) -> usize {
106        let n = self.num_docs as usize;
107
108        if n == 0 {
109            return 0;
110        }
111
112        output[0] = self.first_doc_id;
113
114        if n == 1 {
115            return 1;
116        }
117
118        // Use fused unpack + delta decode for best performance
119        let rounded_width = RoundedBitWidth::from_u8(self.doc_bit_width);
120        simd::unpack_rounded_delta_decode(
121            &self.doc_deltas,
122            rounded_width,
123            output,
124            self.first_doc_id,
125            n,
126        );
127
128        n
129    }
130
131    /// Decode term frequencies using SIMD-friendly rounded unpacking
132    pub fn decode_term_freqs(&self) -> Vec<u32> {
133        let mut tfs = vec![0u32; self.num_docs as usize];
134        self.decode_term_freqs_into(&mut tfs);
135        tfs
136    }
137
138    /// Decode term frequencies into a pre-allocated buffer (avoids allocation)
139    #[inline]
140    pub fn decode_term_freqs_into(&self, output: &mut [u32]) -> usize {
141        let n = self.num_docs as usize;
142
143        if n == 0 {
144            return 0;
145        }
146
147        // Unpack using rounded bit width (fast SIMD path)
148        let rounded_width = RoundedBitWidth::from_u8(self.tf_bit_width);
149        simd::unpack_rounded(&self.term_freqs, rounded_width, output, n);
150
151        // Add 1 back (we stored tf-1)
152        simd::add_one(output, n);
153
154        n
155    }
156}
157
158/// Rounded BP128 posting list with block-level skip info
159///
160/// Uses rounded bit widths (0, 8, 16, 32) for faster SIMD decoding
161/// at the cost of larger index size.
162#[derive(Debug, Clone)]
163pub struct RoundedBP128PostingList {
164    /// Blocks of postings
165    pub blocks: Vec<RoundedBP128Block>,
166    /// Total document count
167    pub doc_count: u32,
168    /// Maximum score across all blocks (for MaxScore pruning)
169    pub max_score: f32,
170}
171
172impl RoundedBP128PostingList {
173    /// Create from raw doc_ids and term frequencies
174    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
175        assert_eq!(doc_ids.len(), term_freqs.len());
176
177        if doc_ids.is_empty() {
178            return Self {
179                blocks: Vec::new(),
180                doc_count: 0,
181                max_score: 0.0,
182            };
183        }
184
185        let mut blocks = Vec::new();
186        let mut max_score = 0.0f32;
187        let mut i = 0;
188
189        while i < doc_ids.len() {
190            let block_end = (i + ROUNDED_BP128_BLOCK_SIZE).min(doc_ids.len());
191            let block_docs = &doc_ids[i..block_end];
192            let block_tfs = &term_freqs[i..block_end];
193
194            let block = Self::create_block(block_docs, block_tfs, idf);
195            max_score = max_score.max(block.max_block_score);
196            blocks.push(block);
197
198            i = block_end;
199        }
200
201        Self {
202            blocks,
203            doc_count: doc_ids.len() as u32,
204            max_score,
205        }
206    }
207
208    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> RoundedBP128Block {
209        let num_docs = doc_ids.len();
210        let first_doc_id = doc_ids[0];
211        let last_doc_id = *doc_ids.last().unwrap();
212
213        // Compute deltas (delta - 1 to save one bit since deltas are always >= 1)
214        let mut deltas = [0u32; ROUNDED_BP128_BLOCK_SIZE];
215        let mut max_delta = 0u32;
216        for j in 1..num_docs {
217            let delta = doc_ids[j] - doc_ids[j - 1] - 1;
218            deltas[j - 1] = delta;
219            max_delta = max_delta.max(delta);
220        }
221
222        // Compute max TF and prepare TF array (store tf-1)
223        let mut tfs = [0u32; ROUNDED_BP128_BLOCK_SIZE];
224        let mut max_tf = 0u32;
225
226        for (j, &tf) in term_freqs.iter().enumerate() {
227            tfs[j] = tf - 1; // Store tf-1
228            max_tf = max_tf.max(tf);
229        }
230
231        let max_block_score = crate::query::bm25_upper_bound(max_tf as f32, idf);
232
233        // Use rounded bit widths for SIMD-friendly decoding
234        let exact_doc_bits = simd::bits_needed(max_delta);
235        let exact_tf_bits = simd::bits_needed(max_tf.saturating_sub(1));
236
237        let doc_rounded = RoundedBitWidth::from_exact(exact_doc_bits);
238        let tf_rounded = RoundedBitWidth::from_exact(exact_tf_bits);
239
240        // Pack with rounded bit widths
241        let mut doc_deltas = vec![0u8; num_docs.saturating_sub(1) * doc_rounded.bytes_per_value()];
242        if num_docs > 1 {
243            simd::pack_rounded(&deltas[..num_docs - 1], doc_rounded, &mut doc_deltas);
244        }
245
246        let mut term_freqs_packed = vec![0u8; num_docs * tf_rounded.bytes_per_value()];
247        simd::pack_rounded(&tfs[..num_docs], tf_rounded, &mut term_freqs_packed);
248
249        RoundedBP128Block {
250            doc_deltas,
251            doc_bit_width: doc_rounded.as_u8(),
252            term_freqs: term_freqs_packed,
253            tf_bit_width: tf_rounded.as_u8(),
254            first_doc_id,
255            last_doc_id,
256            num_docs: num_docs as u16,
257            max_tf,
258            max_block_score,
259        }
260    }
261
262    /// Serialize the posting list
263    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
264        writer.write_u32::<LittleEndian>(self.doc_count)?;
265        writer.write_f32::<LittleEndian>(self.max_score)?;
266        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
267
268        for block in &self.blocks {
269            block.serialize(writer)?;
270        }
271
272        Ok(())
273    }
274
275    /// Deserialize a posting list
276    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
277        let doc_count = reader.read_u32::<LittleEndian>()?;
278        let max_score = reader.read_f32::<LittleEndian>()?;
279        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;
280
281        let mut blocks = Vec::with_capacity(num_blocks);
282        for _ in 0..num_blocks {
283            blocks.push(RoundedBP128Block::deserialize(reader)?);
284        }
285
286        Ok(Self {
287            blocks,
288            doc_count,
289            max_score,
290        })
291    }
292
293    /// Create an iterator
294    pub fn iterator(&self) -> RoundedBP128Iterator<'_> {
295        RoundedBP128Iterator::new(self)
296    }
297
298    /// Get number of documents
299    pub fn len(&self) -> u32 {
300        self.doc_count
301    }
302
303    /// Check if empty
304    pub fn is_empty(&self) -> bool {
305        self.doc_count == 0
306    }
307}
308
309/// Iterator over rounded BP128 posting list with block skipping support
310pub struct RoundedBP128Iterator<'a> {
311    posting_list: &'a RoundedBP128PostingList,
312    current_block: usize,
313    position_in_block: usize,
314    /// Number of valid elements in current block
315    current_block_len: usize,
316    /// Pre-allocated buffer for decoded doc_ids (avoids allocation per block)
317    decoded_doc_ids: Vec<u32>,
318    /// Pre-allocated buffer for decoded term frequencies
319    decoded_tfs: Vec<u32>,
320}
321
322impl<'a> RoundedBP128Iterator<'a> {
323    pub fn new(posting_list: &'a RoundedBP128PostingList) -> Self {
324        // Pre-allocate buffers to block size to avoid allocations during iteration
325        let mut iter = Self {
326            posting_list,
327            current_block: 0,
328            position_in_block: 0,
329            current_block_len: 0,
330            decoded_doc_ids: vec![0u32; ROUNDED_BP128_BLOCK_SIZE],
331            decoded_tfs: vec![0u32; ROUNDED_BP128_BLOCK_SIZE],
332        };
333
334        if !posting_list.blocks.is_empty() {
335            iter.decode_current_block();
336        }
337
338        iter
339    }
340
341    #[inline]
342    fn decode_current_block(&mut self) {
343        if self.current_block < self.posting_list.blocks.len() {
344            let block = &self.posting_list.blocks[self.current_block];
345            // Decode into pre-allocated buffers (no allocation!)
346            self.current_block_len = block.decode_doc_ids_into(&mut self.decoded_doc_ids);
347            block.decode_term_freqs_into(&mut self.decoded_tfs);
348        } else {
349            self.current_block_len = 0;
350        }
351    }
352
353    /// Current document ID
354    #[inline]
355    pub fn doc(&self) -> u32 {
356        if self.current_block >= self.posting_list.blocks.len() {
357            return u32::MAX;
358        }
359        if self.position_in_block >= self.current_block_len {
360            return u32::MAX;
361        }
362        self.decoded_doc_ids[self.position_in_block]
363    }
364
365    /// Current term frequency
366    #[inline]
367    pub fn term_freq(&self) -> u32 {
368        if self.current_block >= self.posting_list.blocks.len() {
369            return 0;
370        }
371        if self.position_in_block >= self.current_block_len {
372            return 0;
373        }
374        self.decoded_tfs[self.position_in_block]
375    }
376
377    /// Advance to next posting
378    #[inline]
379    pub fn advance(&mut self) -> u32 {
380        self.position_in_block += 1;
381
382        if self.position_in_block >= self.current_block_len {
383            self.current_block += 1;
384            self.position_in_block = 0;
385
386            if self.current_block < self.posting_list.blocks.len() {
387                self.decode_current_block();
388            }
389        }
390
391        self.doc()
392    }
393
394    /// Seek to first doc_id >= target
395    pub fn seek(&mut self, target: u32) -> u32 {
396        // Skip blocks where last_doc_id < target
397        while self.current_block < self.posting_list.blocks.len() {
398            let block = &self.posting_list.blocks[self.current_block];
399            if block.last_doc_id >= target {
400                break;
401            }
402            self.current_block += 1;
403            self.position_in_block = 0;
404        }
405
406        if self.current_block >= self.posting_list.blocks.len() {
407            return u32::MAX;
408        }
409
410        // Decode block if needed (check if we're on the right block)
411        let block = &self.posting_list.blocks[self.current_block];
412        if self.current_block_len == 0
413            || self.position_in_block >= self.current_block_len
414            || (self.position_in_block == 0 && self.decoded_doc_ids[0] != block.first_doc_id)
415        {
416            self.decode_current_block();
417            self.position_in_block = 0;
418        }
419
420        // Binary search within block
421        let start = self.position_in_block;
422        let slice = &self.decoded_doc_ids[start..self.current_block_len];
423        match slice.binary_search(&target) {
424            Ok(pos) => {
425                self.position_in_block = start + pos;
426            }
427            Err(pos) => {
428                if pos < slice.len() {
429                    self.position_in_block = start + pos;
430                } else {
431                    // Move to next block
432                    self.current_block += 1;
433                    self.position_in_block = 0;
434                    if self.current_block < self.posting_list.blocks.len() {
435                        self.decode_current_block();
436                        return self.seek(target);
437                    }
438                    return u32::MAX;
439                }
440            }
441        }
442
443        self.doc()
444    }
445
446    /// Get block max score for current block (for WAND/MaxScore)
447    #[inline]
448    pub fn block_max_score(&self) -> f32 {
449        if self.current_block < self.posting_list.blocks.len() {
450            self.posting_list.blocks[self.current_block].max_block_score
451        } else {
452            0.0
453        }
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `width` is one of the rounded widths (0, 8, 16 or 32).
    fn assert_rounded(width: u8, label: &str) {
        assert!(
            matches!(width, 0 | 8 | 16 | 32),
            "{} bit width {} is not rounded",
            label,
            width
        );
    }

    /// Iterating a small list yields every (doc_id, tf) pair and then `u32::MAX`.
    #[test]
    fn test_rounded_bp128_basic() {
        let doc_ids: Vec<u32> = vec![1, 5, 10, 15, 20];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5];

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        assert_eq!(posting_list.doc_count, 5);

        let mut iter = posting_list.iterator();
        for (i, (&expected_doc, &expected_tf)) in doc_ids.iter().zip(term_freqs.iter()).enumerate()
        {
            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
            assert_eq!(iter.term_freq(), expected_tf, "TF mismatch at {}", i);
            iter.advance();
        }
        assert_eq!(iter.doc(), u32::MAX);
    }

    /// A completely full block decodes back to the original doc_ids and tfs.
    #[test]
    fn test_rounded_bp128_large_block() {
        // Test with a full 128-element block
        let doc_ids: Vec<u32> = (0..128).map(|i| i * 5 + 100).collect();
        let term_freqs: Vec<u32> = vec![1; 128];

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let decoded = posting_list.blocks[0].decode_doc_ids();

        assert_eq!(decoded.len(), 128);
        for (i, (&expected, &actual)) in doc_ids.iter().zip(decoded.iter()).enumerate() {
            assert_eq!(expected, actual, "Mismatch at position {}", i);
        }

        // The term frequencies of the same block must round-trip as well.
        assert_eq!(posting_list.blocks[0].decode_term_freqs(), term_freqs);
    }

    /// A multi-block list survives a serialize/deserialize round trip.
    #[test]
    fn test_rounded_bp128_serialization() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 7 + 100).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 5) as u32 + 1).collect();

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let mut buffer = Vec::new();
        posting_list.serialize(&mut buffer).unwrap();

        let restored = RoundedBP128PostingList::deserialize(&mut &buffer[..]).unwrap();
        assert_eq!(restored.doc_count, posting_list.doc_count);
        assert_eq!(restored.blocks.len(), posting_list.blocks.len());

        // Verify iteration produces same results
        let mut iter1 = posting_list.iterator();
        let mut iter2 = restored.iterator();

        while iter1.doc() != u32::MAX {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
        // The restored list must not contain postings beyond the original ones.
        assert_eq!(iter2.doc(), u32::MAX);
    }

    /// `seek` lands on the first doc_id >= target and reports exhaustion.
    #[test]
    fn test_rounded_bp128_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = posting_list.iterator();

        assert_eq!(iter.seek(25), 30);
        assert_eq!(iter.seek(100), 100);
        assert_eq!(iter.seek(500), 1000);
        assert_eq!(iter.seek(3000), u32::MAX);
    }

    /// Stored bit widths are always one of the rounded values.
    #[test]
    fn test_rounded_bit_widths() {
        // Test that bit widths are actually rounded
        let doc_ids: Vec<u32> = (0..128).map(|i| i * 100).collect(); // Gaps of 100 -> stored delta is 99
        let term_freqs: Vec<u32> = vec![1; 128]; // Small TFs -> all zeros after -1

        let posting_list = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let block = &posting_list.blocks[0];

        // Doc bit width should be rounded to 8, 16, or 32
        assert_rounded(block.doc_bit_width, "Doc");

        // TF bit width should be rounded
        assert_rounded(block.tf_bit_width, "TF");
    }

    /// The rounded format decodes to the same postings as the exact format and
    /// is never smaller on disk.
    #[test]
    fn test_rounded_vs_exact_correctness() {
        // Verify rounded produces same decoded values as exact
        use super::super::horizontal_bp128::HorizontalBP128PostingList;

        let doc_ids: Vec<u32> = (0..200).map(|i| i * 7 + 100).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 5) as u32 + 1).collect();

        let exact = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let rounded = RoundedBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        // Rounded should be larger (worse compression)
        let mut exact_buf = Vec::new();
        exact.serialize(&mut exact_buf).unwrap();
        let mut rounded_buf = Vec::new();
        rounded.serialize(&mut rounded_buf).unwrap();

        assert!(
            rounded_buf.len() >= exact_buf.len(),
            "Rounded ({}) should be >= exact ({})",
            rounded_buf.len(),
            exact_buf.len()
        );

        // But both should decode to the same values
        let mut exact_iter = exact.iterator();
        let mut rounded_iter = rounded.iterator();

        while exact_iter.doc() != u32::MAX {
            assert_eq!(exact_iter.doc(), rounded_iter.doc());
            assert_eq!(exact_iter.term_freq(), rounded_iter.term_freq());
            exact_iter.advance();
            rounded_iter.advance();
        }
        assert_eq!(rounded_iter.doc(), u32::MAX);
    }
}