hermes_core/structures/
opt_p4d.rs

1//! OptP4D (Optimized Patched Frame-of-Reference Delta) posting list compression
2//!
3//! OptP4D is an improvement over PForDelta that finds the optimal bit width for each block
4//! by trying all possible bit widths and selecting the one that minimizes total storage.
5//!
6//! Key features:
7//! - Block-based compression (128 integers per block for SIMD alignment)
8//! - Delta encoding for doc IDs
9//! - Optimal bit-width selection per block
10//! - Patched coding: exceptions (values that don't fit) stored separately
11//! - Fast SIMD-friendly decoding with NEON (ARM) and SSE (x86) support
12//!
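//! Conceptual layout per block (the per-value and per-exception bit counts below are
//! the ones used when estimating the cost of a candidate bit width):
//! - Header: bit_width (5 bits) + num_exceptions (7 bits) + first_doc_id (32 bits)
//! - Main array: the low `bit_width` bits of every value in the block (up to 128 values)
//! - Exceptions: [position (7 bits), high_bits (32 - bit_width bits)] for each exception
//!
//! Note: `OptP4DBlock::serialize` writes a byte-aligned variant of this layout
//! (whole-byte header fields, and one position byte plus a 32-bit value per exception).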
17
18use super::simd;
19use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
20use std::io::{self, Read, Write};
21
22/// Block size for OptP4D (128 integers for SIMD alignment)
23pub const OPT_P4D_BLOCK_SIZE: usize = 128;
24
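/// Maximum fraction of a block's values that may be stored as exceptions.
/// Bit widths below 32 that would produce more exceptions than this are not considered
/// (keeping exceptions to roughly 10% of the block for good compression).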
27const MAX_EXCEPTIONS_RATIO: f32 = 0.10;
28
29/// Find the optimal bit width for a block of values
30/// Returns (bit_width, exception_count, total_bits)
31fn find_optimal_bit_width(values: &[u32]) -> (u8, usize, usize) {
32    if values.is_empty() {
33        return (0, 0, 0);
34    }
35
36    let n = values.len();
37    let max_exceptions = ((n as f32) * MAX_EXCEPTIONS_RATIO).ceil() as usize;
38
39    // Count how many values need each bit width
40    let mut bit_counts = [0usize; 33]; // bit_counts[b] = count of values needing exactly b bits
41    for &v in values {
42        let bits = simd::bits_needed(v) as usize;
43        bit_counts[bits] += 1;
44    }
45
46    // Compute cumulative counts: values that fit in b bits or less
47    let mut cumulative = [0usize; 33];
48    cumulative[0] = bit_counts[0];
49    for b in 1..=32 {
50        cumulative[b] = cumulative[b - 1] + bit_counts[b];
51    }
52
53    let mut best_bits = 32u8;
54    let mut best_total = usize::MAX;
55    let mut best_exceptions = 0usize;
56
57    // Try each bit width and compute total storage
58    for b in 0..=32u8 {
59        let fitting = if b == 0 {
60            bit_counts[0]
61        } else {
62            cumulative[b as usize]
63        };
64        let exceptions = n - fitting;
65
66        // Skip if too many exceptions
67        if exceptions > max_exceptions && b < 32 {
68            continue;
69        }
70
71        // Calculate total bits:
72        // - Main array: n * b bits
73        // - Exceptions: exceptions * (7 bits position + (32 - b) bits high value)
74        let main_bits = n * (b as usize);
75        let exception_bits = if b < 32 {
76            exceptions * (7 + (32 - b as usize))
77        } else {
78            0
79        };
80        let total = main_bits + exception_bits;
81
82        if total < best_total {
83            best_total = total;
84            best_bits = b;
85            best_exceptions = exceptions;
86        }
87    }
88
89    (best_bits, best_exceptions, best_total)
90}
91
/// Bit-pack `values` at `bit_width` bits each, splitting off oversized values
/// as exceptions (NewPFD/OptPFD style).
///
/// Layout, following "Decoding billions of integers per second through vectorization":
/// - the low `bit_width` bits of EVERY value go into the main array, least
///   significant bit first, values laid end to end;
/// - a value that does not fit in `bit_width` bits additionally yields an
///   exception `(position, value >> bit_width)`.
///
/// Special cases:
/// - `bit_width == 0`: the main array is empty and every non-zero value becomes an
///   exception carrying the full value;
/// - `bit_width >= 32`: each value is stored as 4 little-endian bytes, no exceptions.
///
/// Positions are stored as `u8` (`index as u8`), so indices above 255 wrap.
///
/// Returns the packed bytes and the list of exceptions `(position, high_bits)`.
fn pack_with_exceptions(values: &[u32], bit_width: u8) -> (Vec<u8>, Vec<(u8, u32)>) {
    match bit_width {
        0 => {
            // Nothing to pack: any non-zero value is carried entirely by its exception
            let patches = values
                .iter()
                .enumerate()
                .filter_map(|(slot, &v)| (v != 0).then_some((slot as u8, v)))
                .collect();
            (Vec::new(), patches)
        }
        32..=u8::MAX => {
            // Full-width storage: plain little-endian words, nothing can overflow
            let mut raw = Vec::with_capacity(values.len() * 4);
            for v in values {
                raw.extend_from_slice(&v.to_le_bytes());
            }
            (raw, Vec::new())
        }
        _ => {
            let width = u32::from(bit_width);
            let low_mask = (1u64 << width) - 1;

            let mut packed = Vec::with_capacity((values.len() * bit_width as usize).div_ceil(8));
            let mut patches = Vec::new();

            // `acc` buffers bits that have not yet filled a whole byte; at most
            // 7 + 31 bits are ever pending, so a u64 is ample.
            let mut acc = 0u64;
            let mut pending = 0u32;

            for (slot, &v) in values.iter().enumerate() {
                acc |= (u64::from(v) & low_mask) << pending;
                pending += width;
                while pending >= 8 {
                    packed.push(acc as u8);
                    acc >>= 8;
                    pending -= 8;
                }

                // Anything above the low bits must be patched in on decode
                let high = v >> width;
                if high != 0 {
                    patches.push((slot as u8, high));
                }
            }

            // Flush the final partial byte (upper bits stay zero)
            if pending > 0 {
                packed.push(acc as u8);
            }

            (packed, patches)
        }
    }
}
163
164/// Unpack values from a bitpacked array and apply exceptions (NewPFD/OptPFD style)
165///
166/// Following the paper "Decoding billions of integers per second through vectorization":
167/// - Low b bits are stored in the main array for ALL values
168/// - Exceptions store only the HIGH (32-b) bits
169/// - Reconstruct: value = (high_bits << b) | low_bits
170///
171/// Uses SIMD acceleration for common bit widths (8, 16, 32)
172fn unpack_with_exceptions(
173    packed: &[u8],
174    bit_width: u8,
175    exceptions: &[(u8, u32)],
176    count: usize,
177    output: &mut [u32],
178) {
179    if bit_width == 0 {
180        output[..count].fill(0);
181    } else if bit_width == 8 {
182        // SIMD-accelerated 8-bit unpacking
183        simd::unpack_8bit(packed, output, count);
184    } else if bit_width == 16 {
185        // SIMD-accelerated 16-bit unpacking
186        simd::unpack_16bit(packed, output, count);
187    } else if bit_width >= 32 {
188        // SIMD-accelerated 32-bit unpacking
189        simd::unpack_32bit(packed, output, count);
190        return; // No exceptions for 32-bit
191    } else {
192        // Generic bit unpacking for other bit widths
193        let mask = (1u64 << bit_width) - 1;
194        let mut bit_pos = 0usize;
195        let input_ptr = packed.as_ptr();
196
197        for out in output[..count].iter_mut() {
198            let byte_idx = bit_pos >> 3;
199            let bit_offset = bit_pos & 7;
200
201            // Read 8 bytes at once for efficiency
202            let word = if byte_idx + 8 <= packed.len() {
203                unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() }
204            } else {
205                // Handle edge case near end of buffer
206                let mut word = 0u64;
207                for (i, &b) in packed[byte_idx..].iter().enumerate() {
208                    word |= (b as u64) << (i * 8);
209                }
210                word
211            };
212
213            *out = ((word >> bit_offset) & mask) as u32;
214            bit_pos += bit_width as usize;
215        }
216    }
217
218    // Apply exceptions: combine high bits with low bits already in output
219    // value = (high_bits << bit_width) | low_bits
220    for &(pos, high_bits) in exceptions {
221        if (pos as usize) < count {
222            let low_bits = output[pos as usize];
223            output[pos as usize] = (high_bits << bit_width) | low_bits;
224        }
225    }
226}
227
/// Fused unpack + exceptions + delta decode for doc_ids
///
/// Writes `count` absolute doc ids into `output`: `output[0]` is `first_doc_id`, and
/// each following entry is `previous + gap + 1`, where gap `i` is read from `packed`
/// at `bit_width` bits per value and patched with the matching exception
/// (`gap |= high_bits << bit_width`). Additions wrap on overflow. Everything happens
/// in a single pass with no intermediate buffer.
///
/// Exceptions are honoured for every bit width below 32, including 0, where the
/// main array is empty and an exception carries the whole gap. The first entry in
/// `exceptions` matching a position wins; positions `>= count - 1` are ignored.
///
/// # Arguments
/// * `packed` - bit-packed low bits of the `count - 1` gaps, least significant bit first
/// * `bit_width` - bits stored per gap (0..=32)
/// * `exceptions` - `(position, high_bits)` pairs
/// * `output` - destination buffer; the first `count` entries are overwritten
/// * `first_doc_id` - absolute id of the first document
/// * `count` - number of doc ids to produce
///
/// # Panics
/// Panics if `output` holds fewer than `count` entries and, when `count > 1`, if
/// `bit_width > 32` or `packed` is shorter than `ceil((count - 1) * bit_width / 8)`
/// bytes. No read ever goes beyond `packed`.
#[inline]
fn unpack_exceptions_delta_decode(
    packed: &[u8],
    bit_width: u8,
    exceptions: &[(u8, u32)],
    output: &mut [u32],
    first_doc_id: u32,
    count: usize,
) {
    // Loads up to eight bytes starting at `byte_idx` as a little-endian word,
    // zero-extending when fewer than eight bytes remain.
    #[inline]
    fn load_le_word(bytes: &[u8], byte_idx: usize) -> u64 {
        let tail = &bytes[byte_idx..];
        let available = tail.len().min(8);
        let mut buf = [0u8; 8];
        buf[..available].copy_from_slice(&tail[..available]);
        u64::from_le_bytes(buf)
    }

    if count == 0 {
        return;
    }

    output[0] = first_doc_id;
    if count == 1 {
        return;
    }

    assert!(bit_width <= 32, "invalid OptP4D bit width: {bit_width}");

    let gaps = count - 1;
    let width = usize::from(bit_width);
    // Restrict the input to exactly the bytes the gaps occupy; this fails loudly
    // on a truncated block instead of decoding garbage.
    let packed = &packed[..(gaps * width).div_ceil(8)];
    let gaps_out = &mut output[1..count];

    // High bits (already shifted into place) contributed by the exception at
    // `index`, or 0 when there is none. Exceptions are sparse, so a linear scan
    // is fine. For a 32-bit width nothing can be shifted in, hence 0.
    let high_part = |index: usize| -> u32 {
        exceptions
            .iter()
            .find(|&&(pos, _)| usize::from(pos) == index)
            .map_or(0, |&(_, high_bits)| {
                high_bits.checked_shl(u32::from(bit_width)).unwrap_or(0)
            })
    };

    // Running prefix sum: gaps are stored as (delta - 1)
    let mut carry = first_doc_id;
    let mut push = |slot: &mut u32, gap: u32| {
        carry = carry.wrapping_add(gap).wrapping_add(1);
        *slot = carry;
    };

    match bit_width {
        0 => {
            // No packed bits: gaps are 0 (consecutive doc ids) unless patched
            for (index, slot) in gaps_out.iter_mut().enumerate() {
                push(slot, high_part(index));
            }
        }
        8 => {
            for (index, (slot, &byte)) in gaps_out.iter_mut().zip(packed).enumerate() {
                push(slot, u32::from(byte) | high_part(index));
            }
        }
        16 => {
            let pairs = packed.chunks_exact(2);
            for (index, (slot, pair)) in gaps_out.iter_mut().zip(pairs).enumerate() {
                let low = u32::from(u16::from_le_bytes([pair[0], pair[1]]));
                push(slot, low | high_part(index));
            }
        }
        32 => {
            // 32-bit has no exceptions
            for (slot, quad) in gaps_out.iter_mut().zip(packed.chunks_exact(4)) {
                push(slot, u32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]));
            }
        }
        _ => {
            // Generic bit width (1..=31): a value spans at most 7 + 31 bits, so a
            // single 64-bit little-endian word always covers it.
            let mask = (1u64 << bit_width) - 1;
            for (index, slot) in gaps_out.iter_mut().enumerate() {
                let bit_pos = index * width;
                let word = load_le_word(packed, bit_pos >> 3);
                let low = ((word >> (bit_pos & 7)) & mask) as u32;
                push(slot, low | high_part(index));
            }
        }
    }
}
342
/// A single OptP4D block: up to `OPT_P4D_BLOCK_SIZE` postings with their doc ids and
/// term frequencies compressed independently.
#[derive(Debug, Clone)]
pub struct OptP4DBlock {
    /// First doc_id in this block (absolute)
    pub first_doc_id: u32,
    /// Last doc_id in this block (absolute)
    pub last_doc_id: u32,
    /// Number of documents in this block
    pub num_docs: u16,
    /// Bit width used for the packed doc-id gaps
    pub doc_bit_width: u8,
    /// Bit width used for the packed term frequencies
    pub tf_bit_width: u8,
    /// Maximum term frequency in this block
    pub max_tf: u32,
    /// Maximum block score for WAND/MaxScore
    pub max_block_score: f32,
    /// Bit-packed low `doc_bit_width` bits of the `num_docs - 1` gaps between
    /// consecutive doc ids, each stored as `gap - 1`
    pub doc_deltas: Vec<u8>,
    /// Doc delta exceptions: `(position, high_bits)` where `high_bits` is the stored
    /// gap shifted right by `doc_bit_width` (the whole value when the width is 0)
    pub doc_exceptions: Vec<(u8, u32)>,
    /// Bit-packed low `tf_bit_width` bits of the term frequencies, each stored as `tf - 1`
    pub term_freqs: Vec<u8>,
    /// TF exceptions: `(position, high_bits)` where `high_bits` is the stored value
    /// shifted right by `tf_bit_width` (the whole value when the width is 0)
    pub tf_exceptions: Vec<(u8, u32)>,
}
369
370impl OptP4DBlock {
371    /// Serialize the block
372    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
373        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
374        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
375        writer.write_u16::<LittleEndian>(self.num_docs)?;
376        writer.write_u8(self.doc_bit_width)?;
377        writer.write_u8(self.tf_bit_width)?;
378        writer.write_u32::<LittleEndian>(self.max_tf)?;
379        writer.write_f32::<LittleEndian>(self.max_block_score)?;
380
381        // Write doc deltas
382        writer.write_u16::<LittleEndian>(self.doc_deltas.len() as u16)?;
383        writer.write_all(&self.doc_deltas)?;
384
385        // Write doc exceptions
386        writer.write_u8(self.doc_exceptions.len() as u8)?;
387        for &(pos, val) in &self.doc_exceptions {
388            writer.write_u8(pos)?;
389            writer.write_u32::<LittleEndian>(val)?;
390        }
391
392        // Write term freqs
393        writer.write_u16::<LittleEndian>(self.term_freqs.len() as u16)?;
394        writer.write_all(&self.term_freqs)?;
395
396        // Write tf exceptions
397        writer.write_u8(self.tf_exceptions.len() as u8)?;
398        for &(pos, val) in &self.tf_exceptions {
399            writer.write_u8(pos)?;
400            writer.write_u32::<LittleEndian>(val)?;
401        }
402
403        Ok(())
404    }
405
406    /// Deserialize a block
407    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
408        let first_doc_id = reader.read_u32::<LittleEndian>()?;
409        let last_doc_id = reader.read_u32::<LittleEndian>()?;
410        let num_docs = reader.read_u16::<LittleEndian>()?;
411        let doc_bit_width = reader.read_u8()?;
412        let tf_bit_width = reader.read_u8()?;
413        let max_tf = reader.read_u32::<LittleEndian>()?;
414        let max_block_score = reader.read_f32::<LittleEndian>()?;
415
416        // Read doc deltas
417        let doc_deltas_len = reader.read_u16::<LittleEndian>()? as usize;
418        let mut doc_deltas = vec![0u8; doc_deltas_len];
419        reader.read_exact(&mut doc_deltas)?;
420
421        // Read doc exceptions
422        let num_doc_exceptions = reader.read_u8()? as usize;
423        let mut doc_exceptions = Vec::with_capacity(num_doc_exceptions);
424        for _ in 0..num_doc_exceptions {
425            let pos = reader.read_u8()?;
426            let val = reader.read_u32::<LittleEndian>()?;
427            doc_exceptions.push((pos, val));
428        }
429
430        // Read term freqs
431        let term_freqs_len = reader.read_u16::<LittleEndian>()? as usize;
432        let mut term_freqs = vec![0u8; term_freqs_len];
433        reader.read_exact(&mut term_freqs)?;
434
435        // Read tf exceptions
436        let num_tf_exceptions = reader.read_u8()? as usize;
437        let mut tf_exceptions = Vec::with_capacity(num_tf_exceptions);
438        for _ in 0..num_tf_exceptions {
439            let pos = reader.read_u8()?;
440            let val = reader.read_u32::<LittleEndian>()?;
441            tf_exceptions.push((pos, val));
442        }
443
444        Ok(Self {
445            first_doc_id,
446            last_doc_id,
447            num_docs,
448            doc_bit_width,
449            tf_bit_width,
450            max_tf,
451            max_block_score,
452            doc_deltas,
453            doc_exceptions,
454            term_freqs,
455            tf_exceptions,
456        })
457    }
458
459    /// Decode doc_ids from this block using SIMD-accelerated delta decoding
460    pub fn decode_doc_ids(&self) -> Vec<u32> {
461        let mut output = vec![0u32; self.num_docs as usize];
462        self.decode_doc_ids_into(&mut output);
463        output
464    }
465
466    /// Decode doc_ids into a pre-allocated buffer (avoids allocation)
467    #[inline]
468    pub fn decode_doc_ids_into(&self, output: &mut [u32]) -> usize {
469        let count = self.num_docs as usize;
470        if count == 0 {
471            return 0;
472        }
473
474        // Fused unpack + exceptions + delta decode - no intermediate buffer
475        unpack_exceptions_delta_decode(
476            &self.doc_deltas,
477            self.doc_bit_width,
478            &self.doc_exceptions,
479            output,
480            self.first_doc_id,
481            count,
482        );
483
484        count
485    }
486
487    /// Decode term frequencies from this block using SIMD acceleration
488    pub fn decode_term_freqs(&self) -> Vec<u32> {
489        let mut output = vec![0u32; self.num_docs as usize];
490        self.decode_term_freqs_into(&mut output);
491        output
492    }
493
494    /// Decode term frequencies into a pre-allocated buffer (avoids allocation)
495    #[inline]
496    pub fn decode_term_freqs_into(&self, output: &mut [u32]) -> usize {
497        let count = self.num_docs as usize;
498        if count == 0 {
499            return 0;
500        }
501
502        // Unpack TFs with exceptions (SIMD-accelerated for 8/16/32-bit)
503        unpack_with_exceptions(
504            &self.term_freqs,
505            self.tf_bit_width,
506            &self.tf_exceptions,
507            count,
508            output,
509        );
510
511        // TF is stored as tf-1, so add 1 back using SIMD
512        simd::add_one(output, count);
513
514        count
515    }
516}
517
518/// OptP4D posting list
519#[derive(Debug, Clone)]
520pub struct OptP4DPostingList {
521    /// Blocks of postings
522    pub blocks: Vec<OptP4DBlock>,
523    /// Total document count
524    pub doc_count: u32,
525    /// Maximum score across all blocks
526    pub max_score: f32,
527}
528
529impl OptP4DPostingList {
530    /// BM25F parameters for block-max score calculation
531    const K1: f32 = 1.2;
532    const B: f32 = 0.75;
533
534    /// Compute BM25F upper bound score for a given max_tf and IDF
535    #[inline]
536    fn compute_bm25f_upper_bound(max_tf: u32, idf: f32) -> f32 {
537        let tf = max_tf as f32;
538        let min_length_norm = 1.0 - Self::B;
539        let tf_norm = (tf * (Self::K1 + 1.0)) / (tf + Self::K1 * min_length_norm);
540        idf * tf_norm
541    }
542
543    /// Create from raw doc_ids and term frequencies
544    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
545        assert_eq!(doc_ids.len(), term_freqs.len());
546
547        if doc_ids.is_empty() {
548            return Self {
549                blocks: Vec::new(),
550                doc_count: 0,
551                max_score: 0.0,
552            };
553        }
554
555        let mut blocks = Vec::new();
556        let mut max_score = 0.0f32;
557        let mut i = 0;
558
559        while i < doc_ids.len() {
560            let block_end = (i + OPT_P4D_BLOCK_SIZE).min(doc_ids.len());
561            let block_docs = &doc_ids[i..block_end];
562            let block_tfs = &term_freqs[i..block_end];
563
564            let block = Self::create_block(block_docs, block_tfs, idf);
565            max_score = max_score.max(block.max_block_score);
566            blocks.push(block);
567
568            i = block_end;
569        }
570
571        Self {
572            blocks,
573            doc_count: doc_ids.len() as u32,
574            max_score,
575        }
576    }
577
578    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> OptP4DBlock {
579        let num_docs = doc_ids.len();
580        let first_doc_id = doc_ids[0];
581        let last_doc_id = *doc_ids.last().unwrap();
582
583        // Compute deltas using stack array (delta - 1 to save one bit)
584        let mut deltas = [0u32; OPT_P4D_BLOCK_SIZE];
585        for j in 1..num_docs {
586            deltas[j - 1] = doc_ids[j] - doc_ids[j - 1] - 1;
587        }
588
589        // Find optimal bit width for deltas
590        let (doc_bit_width, _, _) = find_optimal_bit_width(&deltas[..num_docs.saturating_sub(1)]);
591        let (doc_deltas, doc_exceptions) =
592            pack_with_exceptions(&deltas[..num_docs.saturating_sub(1)], doc_bit_width);
593
594        // Compute max TF and prepare TF array using stack array (store tf-1)
595        let mut tfs = [0u32; OPT_P4D_BLOCK_SIZE];
596        let mut max_tf = 0u32;
597
598        for (j, &tf) in term_freqs.iter().enumerate() {
599            tfs[j] = tf - 1; // Store tf-1
600            max_tf = max_tf.max(tf);
601        }
602
603        // Find optimal bit width for TFs
604        let (tf_bit_width, _, _) = find_optimal_bit_width(&tfs[..num_docs]);
605        let (term_freqs_packed, tf_exceptions) =
606            pack_with_exceptions(&tfs[..num_docs], tf_bit_width);
607
608        // BM25F upper bound score
609        let max_block_score = Self::compute_bm25f_upper_bound(max_tf, idf);
610
611        OptP4DBlock {
612            first_doc_id,
613            last_doc_id,
614            num_docs: num_docs as u16,
615            doc_bit_width,
616            tf_bit_width,
617            max_tf,
618            max_block_score,
619            doc_deltas,
620            doc_exceptions,
621            term_freqs: term_freqs_packed,
622            tf_exceptions,
623        }
624    }
625
626    /// Serialize the posting list
627    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
628        writer.write_u32::<LittleEndian>(self.doc_count)?;
629        writer.write_f32::<LittleEndian>(self.max_score)?;
630        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
631
632        for block in &self.blocks {
633            block.serialize(writer)?;
634        }
635
636        Ok(())
637    }
638
639    /// Deserialize a posting list
640    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
641        let doc_count = reader.read_u32::<LittleEndian>()?;
642        let max_score = reader.read_f32::<LittleEndian>()?;
643        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;
644
645        let mut blocks = Vec::with_capacity(num_blocks);
646        for _ in 0..num_blocks {
647            blocks.push(OptP4DBlock::deserialize(reader)?);
648        }
649
650        Ok(Self {
651            blocks,
652            doc_count,
653            max_score,
654        })
655    }
656
657    /// Get document count
658    pub fn len(&self) -> u32 {
659        self.doc_count
660    }
661
662    /// Check if empty
663    pub fn is_empty(&self) -> bool {
664        self.doc_count == 0
665    }
666
667    /// Create an iterator
668    pub fn iterator(&self) -> OptP4DIterator<'_> {
669        OptP4DIterator::new(self)
670    }
671}
672
673/// Iterator over OptP4D posting list
674pub struct OptP4DIterator<'a> {
675    posting_list: &'a OptP4DPostingList,
676    current_block: usize,
677    /// Number of valid elements in current block
678    current_block_len: usize,
679    /// Pre-allocated buffer for decoded doc_ids (avoids allocation per block)
680    block_doc_ids: Vec<u32>,
681    /// Pre-allocated buffer for decoded term freqs
682    block_term_freqs: Vec<u32>,
683    pos_in_block: usize,
684    exhausted: bool,
685}
686
687impl<'a> OptP4DIterator<'a> {
688    pub fn new(posting_list: &'a OptP4DPostingList) -> Self {
689        // Pre-allocate buffers to block size to avoid allocations during iteration
690        let mut iter = Self {
691            posting_list,
692            current_block: 0,
693            current_block_len: 0,
694            block_doc_ids: vec![0u32; OPT_P4D_BLOCK_SIZE],
695            block_term_freqs: vec![0u32; OPT_P4D_BLOCK_SIZE],
696            pos_in_block: 0,
697            exhausted: posting_list.blocks.is_empty(),
698        };
699
700        if !iter.exhausted {
701            iter.decode_current_block();
702        }
703
704        iter
705    }
706
707    #[inline]
708    fn decode_current_block(&mut self) {
709        let block = &self.posting_list.blocks[self.current_block];
710        // Decode into pre-allocated buffers (no allocation!)
711        self.current_block_len = block.decode_doc_ids_into(&mut self.block_doc_ids);
712        block.decode_term_freqs_into(&mut self.block_term_freqs);
713        self.pos_in_block = 0;
714    }
715
716    /// Current document ID
717    #[inline]
718    pub fn doc(&self) -> u32 {
719        if self.exhausted {
720            u32::MAX
721        } else {
722            self.block_doc_ids[self.pos_in_block]
723        }
724    }
725
726    /// Current term frequency
727    #[inline]
728    pub fn term_freq(&self) -> u32 {
729        if self.exhausted {
730            0
731        } else {
732            self.block_term_freqs[self.pos_in_block]
733        }
734    }
735
736    /// Advance to next document
737    #[inline]
738    pub fn advance(&mut self) -> u32 {
739        if self.exhausted {
740            return u32::MAX;
741        }
742
743        self.pos_in_block += 1;
744
745        if self.pos_in_block >= self.current_block_len {
746            self.current_block += 1;
747            if self.current_block >= self.posting_list.blocks.len() {
748                self.exhausted = true;
749                return u32::MAX;
750            }
751            self.decode_current_block();
752        }
753
754        self.doc()
755    }
756
757    /// Seek to first doc >= target
758    pub fn seek(&mut self, target: u32) -> u32 {
759        if self.exhausted {
760            return u32::MAX;
761        }
762
763        // Skip blocks where last_doc_id < target
764        while self.current_block < self.posting_list.blocks.len() {
765            let block = &self.posting_list.blocks[self.current_block];
766            if block.last_doc_id >= target {
767                break;
768            }
769            self.current_block += 1;
770        }
771
772        if self.current_block >= self.posting_list.blocks.len() {
773            self.exhausted = true;
774            return u32::MAX;
775        }
776
777        // Decode block if needed
778        if self.current_block_len == 0 || self.current_block != self.posting_list.blocks.len() - 1 {
779            self.decode_current_block();
780        }
781
782        // Binary search within block
783        match self.block_doc_ids[self.pos_in_block..self.current_block_len].binary_search(&target) {
784            Ok(idx) => {
785                self.pos_in_block += idx;
786            }
787            Err(idx) => {
788                self.pos_in_block += idx;
789                if self.pos_in_block >= self.current_block_len {
790                    // Move to next block
791                    self.current_block += 1;
792                    if self.current_block >= self.posting_list.blocks.len() {
793                        self.exhausted = true;
794                        return u32::MAX;
795                    }
796                    self.decode_current_block();
797                }
798            }
799        }
800
801        self.doc()
802    }
803}
804
805#[cfg(test)]
806mod tests {
807    use super::*;
808
809    #[test]
810    fn test_bits_needed() {
811        assert_eq!(simd::bits_needed(0), 0);
812        assert_eq!(simd::bits_needed(1), 1);
813        assert_eq!(simd::bits_needed(2), 2);
814        assert_eq!(simd::bits_needed(3), 2);
815        assert_eq!(simd::bits_needed(4), 3);
816        assert_eq!(simd::bits_needed(255), 8);
817        assert_eq!(simd::bits_needed(256), 9);
818        assert_eq!(simd::bits_needed(u32::MAX), 32);
819    }
820
821    #[test]
822    fn test_find_optimal_bit_width() {
823        // All zeros
824        let values = vec![0u32; 100];
825        let (bits, exceptions, _) = find_optimal_bit_width(&values);
826        assert_eq!(bits, 0);
827        assert_eq!(exceptions, 0);
828
829        // All small values
830        let values: Vec<u32> = (0..100).map(|i| i % 16).collect();
831        let (bits, _, _) = find_optimal_bit_width(&values);
832        assert!(bits <= 4);
833
834        // Mix with outliers
835        let mut values: Vec<u32> = (0..100).map(|i| i % 16).collect();
836        values[50] = 1_000_000; // outlier
837        let (bits, exceptions, _) = find_optimal_bit_width(&values);
838        assert!(bits < 20); // Should use small bit width with exception
839        assert!(exceptions >= 1);
840    }
841
842    #[test]
843    fn test_pack_unpack_with_exceptions() {
844        let values = vec![1, 2, 3, 255, 4, 5, 1000, 6, 7, 8];
845        let (packed, exceptions) = pack_with_exceptions(&values, 4);
846
847        let mut output = vec![0u32; values.len()];
848        unpack_with_exceptions(&packed, 4, &exceptions, values.len(), &mut output);
849
850        assert_eq!(output, values);
851    }
852
853    #[test]
854    fn test_opt_p4d_posting_list_small() {
855        let doc_ids: Vec<u32> = (0..100).map(|i| i * 2).collect();
856        let term_freqs: Vec<u32> = vec![1; 100];
857
858        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
859
860        assert_eq!(list.len(), 100);
861        assert_eq!(list.blocks.len(), 1);
862
863        // Verify iteration
864        let mut iter = list.iterator();
865        for (i, &expected) in doc_ids.iter().enumerate() {
866            assert_eq!(iter.doc(), expected, "Mismatch at {}", i);
867            assert_eq!(iter.term_freq(), 1);
868            iter.advance();
869        }
870        assert_eq!(iter.doc(), u32::MAX);
871    }
872
873    #[test]
874    fn test_opt_p4d_posting_list_large() {
875        let doc_ids: Vec<u32> = (0..500).map(|i| i * 3).collect();
876        let term_freqs: Vec<u32> = (0..500).map(|i| (i % 10) + 1).collect();
877
878        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
879
880        assert_eq!(list.len(), 500);
881        assert_eq!(list.blocks.len(), 4); // 500 / 128 = 3.9 -> 4 blocks
882
883        // Verify iteration
884        let mut iter = list.iterator();
885        for (i, &expected) in doc_ids.iter().enumerate() {
886            assert_eq!(iter.doc(), expected, "Mismatch at {}", i);
887            assert_eq!(iter.term_freq(), term_freqs[i]);
888            iter.advance();
889        }
890    }
891
892    #[test]
893    fn test_opt_p4d_seek() {
894        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
895        let term_freqs: Vec<u32> = vec![1; 8];
896
897        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
898        let mut iter = list.iterator();
899
900        assert_eq!(iter.seek(25), 30);
901        assert_eq!(iter.seek(100), 100);
902        assert_eq!(iter.seek(500), 1000);
903        assert_eq!(iter.seek(3000), u32::MAX);
904    }
905
906    #[test]
907    fn test_opt_p4d_serialization() {
908        let doc_ids: Vec<u32> = (0..200).map(|i| i * 5).collect();
909        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 5) + 1).collect();
910
911        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
912
913        let mut buffer = Vec::new();
914        list.serialize(&mut buffer).unwrap();
915
916        let restored = OptP4DPostingList::deserialize(&mut &buffer[..]).unwrap();
917
918        assert_eq!(restored.len(), list.len());
919        assert_eq!(restored.blocks.len(), list.blocks.len());
920
921        // Verify iteration matches
922        let mut iter1 = list.iterator();
923        let mut iter2 = restored.iterator();
924
925        while iter1.doc() != u32::MAX {
926            assert_eq!(iter1.doc(), iter2.doc());
927            assert_eq!(iter1.term_freq(), iter2.term_freq());
928            iter1.advance();
929            iter2.advance();
930        }
931    }
932
933    #[test]
934    fn test_opt_p4d_with_outliers() {
935        // Create data with some outliers to test exception handling
936        let mut doc_ids: Vec<u32> = (0..128).map(|i| i * 2).collect();
937        doc_ids[64] = 1_000_000; // Large outlier
938
939        // Fix: ensure doc_ids are sorted
940        doc_ids.sort();
941
942        let term_freqs: Vec<u32> = vec![1; 128];
943
944        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
945
946        // Verify the outlier is handled correctly
947        let mut iter = list.iterator();
948        let mut found_outlier = false;
949        while iter.doc() != u32::MAX {
950            if iter.doc() == 1_000_000 {
951                found_outlier = true;
952            }
953            iter.advance();
954        }
955        assert!(found_outlier, "Outlier value should be preserved");
956    }
957
958    #[test]
959    fn test_opt_p4d_simd_full_blocks() {
960        // Test with multiple full 128-integer blocks to exercise SIMD paths
961        let doc_ids: Vec<u32> = (0..1024).map(|i| i * 2).collect();
962        let term_freqs: Vec<u32> = (0..1024).map(|i| (i % 20) + 1).collect();
963
964        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
965
966        assert_eq!(list.len(), 1024);
967        assert_eq!(list.blocks.len(), 8); // 1024 / 128 = 8 full blocks
968
969        // Verify all values are decoded correctly
970        let mut iter = list.iterator();
971        for (i, &expected_doc) in doc_ids.iter().enumerate() {
972            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
973            assert_eq!(iter.term_freq(), term_freqs[i], "TF mismatch at {}", i);
974            iter.advance();
975        }
976        assert_eq!(iter.doc(), u32::MAX);
977    }
978
979    #[test]
980    fn test_opt_p4d_simd_8bit_values() {
981        // Test with values that fit in 8 bits to exercise SIMD 8-bit unpack
982        let doc_ids: Vec<u32> = (0..256).collect();
983        let term_freqs: Vec<u32> = (0..256).map(|i| (i % 100) + 1).collect();
984
985        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
986
987        // Verify all values
988        let mut iter = list.iterator();
989        for (i, &expected_doc) in doc_ids.iter().enumerate() {
990            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
991            assert_eq!(iter.term_freq(), term_freqs[i], "TF mismatch at {}", i);
992            iter.advance();
993        }
994    }
995
996    #[test]
997    fn test_opt_p4d_simd_delta_decode() {
998        // Test SIMD delta decoding with various gap sizes
999        let mut doc_ids = Vec::with_capacity(512);
1000        let mut current = 0u32;
1001        for i in 0..512 {
1002            current += (i % 10) + 1; // Variable gaps
1003            doc_ids.push(current);
1004        }
1005        let term_freqs: Vec<u32> = vec![1; 512];
1006
1007        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
1008
1009        // Verify delta decoding is correct
1010        let mut iter = list.iterator();
1011        for (i, &expected_doc) in doc_ids.iter().enumerate() {
1012            assert_eq!(
1013                iter.doc(),
1014                expected_doc,
1015                "Doc mismatch at {} (expected {}, got {})",
1016                i,
1017                expected_doc,
1018                iter.doc()
1019            );
1020            iter.advance();
1021        }
1022    }
1023}