Skip to main content

hermes_core/structures/postings/
opt_p4d.rs

1//! OptP4D (Optimized Patched Frame-of-Reference Delta) posting list compression
2//!
3//! OptP4D is an improvement over PForDelta that finds the optimal bit width for each block
4//! by trying all possible bit widths and selecting the one that minimizes total storage.
5//!
6//! Key features:
7//! - Block-based compression (128 integers per block for SIMD alignment)
8//! - Delta encoding for doc IDs
9//! - Optimal bit-width selection per block
10//! - Patched coding: exceptions (values that don't fit) stored separately
11//! - Fast SIMD-friendly decoding with NEON (ARM) and SSE (x86) support
12//!
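//! Serialized format per block (all multi-byte integers little-endian):
//! - Header: first_doc_id (u32), last_doc_id (u32), num_docs (u16), doc_bit_width (u8),
//!   tf_bit_width (u8), max_tf (u32), max_block_score (f32)
//! - Doc gaps: byte length (u16), then the `num_docs - 1` values `doc[i+1] - doc[i] - 1`,
//!   with the low `doc_bit_width` bits of every value bit-packed LSB-first
//! - Doc exceptions: count (u8), then [position (u8), high bits `value >> doc_bit_width` (u32)]
//! - Term frequencies: byte length (u16) + `tf - 1` packed at `tf_bit_width` bits, followed
//!   by TF exceptions in the same layout as the doc exceptions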
17
18use crate::structures::simd;
19use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
20use std::io::{self, Read, Write};
21
22/// Block size for OptP4D (128 integers for SIMD alignment)
23pub const OPT_P4D_BLOCK_SIZE: usize = 128;
24
25/// Maximum number of exceptions before we increase bit width
26/// (keeping exceptions under ~10% of block for good compression)
27const MAX_EXCEPTIONS_RATIO: f32 = 0.10;
28
29/// Find the optimal bit width for a block of values
30/// Returns (bit_width, exception_count, total_bits)
31fn find_optimal_bit_width(values: &[u32]) -> (u8, usize, usize) {
32    if values.is_empty() {
33        return (0, 0, 0);
34    }
35
36    let n = values.len();
37    let max_exceptions = ((n as f32) * MAX_EXCEPTIONS_RATIO).ceil() as usize;
38
39    // Count how many values need each bit width
40    let mut bit_counts = [0usize; 33]; // bit_counts[b] = count of values needing exactly b bits
41    for &v in values {
42        let bits = simd::bits_needed(v) as usize;
43        bit_counts[bits] += 1;
44    }
45
46    // Compute cumulative counts: values that fit in b bits or less
47    let mut cumulative = [0usize; 33];
48    cumulative[0] = bit_counts[0];
49    for b in 1..=32 {
50        cumulative[b] = cumulative[b - 1] + bit_counts[b];
51    }
52
53    let mut best_bits = 32u8;
54    let mut best_total = usize::MAX;
55    let mut best_exceptions = 0usize;
56
57    // Try each bit width and compute total storage
58    for b in 0..=32u8 {
59        let fitting = if b == 0 {
60            bit_counts[0]
61        } else {
62            cumulative[b as usize]
63        };
64        let exceptions = n - fitting;
65
66        // Skip if too many exceptions
67        if exceptions > max_exceptions && b < 32 {
68            continue;
69        }
70
71        // Calculate total bits:
72        // - Main array: n * b bits
73        // - Exceptions: exceptions * (7 bits position + (32 - b) bits high value)
74        let main_bits = n * (b as usize);
75        let exception_bits = if b < 32 {
76            exceptions * (7 + (32 - b as usize))
77        } else {
78            0
79        };
80        let total = main_bits + exception_bits;
81
82        if total < best_total {
83            best_total = total;
84            best_bits = b;
85            best_exceptions = exceptions;
86        }
87    }
88
89    (best_bits, best_exceptions, best_total)
90}
91
92/// Pack values into a bitpacked array with the given bit width (NewPFD/OptPFD style)
93///
94/// Following the paper "Decoding billions of integers per second through vectorization":
95/// - Store the first b bits (low bits) of ALL values in the main array
96/// - For exceptions (values >= 2^b), store only the HIGH (32-b) bits separately with positions
97///
98/// Returns the packed bytes and a list of exceptions (position, high_bits)
99fn pack_with_exceptions(values: &[u32], bit_width: u8) -> (Vec<u8>, Vec<(u8, u32)>) {
100    if bit_width == 0 {
101        // All values must be 0, exceptions store full value
102        let exceptions: Vec<(u8, u32)> = values
103            .iter()
104            .enumerate()
105            .filter(|&(_, &v)| v != 0)
106            .map(|(i, &v)| (i as u8, v)) // For b=0, high bits = full value
107            .collect();
108        return (Vec::new(), exceptions);
109    }
110
111    if bit_width >= 32 {
112        // No exceptions possible, just pack all 32 bits
113        let bytes_needed = values.len() * 4;
114        let mut packed = vec![0u8; bytes_needed];
115        for (i, &value) in values.iter().enumerate() {
116            let bytes = value.to_le_bytes();
117            packed[i * 4..i * 4 + 4].copy_from_slice(&bytes);
118        }
119        return (packed, Vec::new());
120    }
121
122    let mask = (1u64 << bit_width) - 1;
123    let bytes_needed = (values.len() * bit_width as usize).div_ceil(8);
124    let mut packed = vec![0u8; bytes_needed];
125    let mut exceptions = Vec::new();
126
127    let mut bit_pos = 0usize;
128    for (i, &value) in values.iter().enumerate() {
129        // Store lower b bits in main array (for ALL values, including exceptions)
130        let low_bits = (value as u64) & mask;
131
132        // Write low bits to packed array
133        let byte_idx = bit_pos / 8;
134        let bit_offset = bit_pos % 8;
135
136        let mut remaining_bits = bit_width as usize;
137        let mut val = low_bits;
138        let mut current_byte_idx = byte_idx;
139        let mut current_bit_offset = bit_offset;
140
141        while remaining_bits > 0 {
142            let bits_in_byte = (8 - current_bit_offset).min(remaining_bits);
143            let byte_mask = ((1u64 << bits_in_byte) - 1) as u8;
144            packed[current_byte_idx] |= ((val as u8) & byte_mask) << current_bit_offset;
145            val >>= bits_in_byte;
146            remaining_bits -= bits_in_byte;
147            current_byte_idx += 1;
148            current_bit_offset = 0;
149        }
150
151        bit_pos += bit_width as usize;
152
153        // Record exception: store only the HIGH (32-b) bits
154        let fits = value <= mask as u32;
155        if !fits {
156            let high_bits = value >> bit_width;
157            exceptions.push((i as u8, high_bits));
158        }
159    }
160
161    (packed, exceptions)
162}
163
164/// Unpack values from a bitpacked array and apply exceptions (NewPFD/OptPFD style)
165///
166/// Following the paper "Decoding billions of integers per second through vectorization":
167/// - Low b bits are stored in the main array for ALL values
168/// - Exceptions store only the HIGH (32-b) bits
169/// - Reconstruct: value = (high_bits << b) | low_bits
170///
171/// Uses SIMD acceleration for common bit widths (8, 16, 32)
172fn unpack_with_exceptions(
173    packed: &[u8],
174    bit_width: u8,
175    exceptions: &[(u8, u32)],
176    count: usize,
177    output: &mut [u32],
178) {
179    if bit_width == 0 {
180        output[..count].fill(0);
181    } else if bit_width == 8 {
182        // SIMD-accelerated 8-bit unpacking
183        simd::unpack_8bit(packed, output, count);
184    } else if bit_width == 16 {
185        // SIMD-accelerated 16-bit unpacking
186        simd::unpack_16bit(packed, output, count);
187    } else if bit_width >= 32 {
188        // SIMD-accelerated 32-bit unpacking
189        simd::unpack_32bit(packed, output, count);
190        return; // No exceptions for 32-bit
191    } else {
192        // Generic bit unpacking for other bit widths
193        let mask = (1u64 << bit_width) - 1;
194        let mut bit_pos = 0usize;
195        let input_ptr = packed.as_ptr();
196
197        for out in output[..count].iter_mut() {
198            let byte_idx = bit_pos >> 3;
199            let bit_offset = bit_pos & 7;
200
201            // Read 8 bytes at once for efficiency
202            let word = if byte_idx + 8 <= packed.len() {
203                unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() }
204            } else {
205                // Handle edge case near end of buffer
206                let mut word = 0u64;
207                for (i, &b) in packed[byte_idx..].iter().enumerate() {
208                    word |= (b as u64) << (i * 8);
209                }
210                word
211            };
212
213            *out = ((word >> bit_offset) & mask) as u32;
214            bit_pos += bit_width as usize;
215        }
216    }
217
218    // Apply exceptions: combine high bits with low bits already in output
219    // value = (high_bits << bit_width) | low_bits
220    for &(pos, high_bits) in exceptions {
221        if (pos as usize) < count {
222            let low_bits = output[pos as usize];
223            output[pos as usize] = (high_bits << bit_width) | low_bits;
224        }
225    }
226}
227
228/// Fused unpack + exceptions + delta decode for doc_ids
229///
230/// Combines unpacking, exception application, and prefix sum in a single pass.
231/// Avoids intermediate buffer allocation.
232#[inline]
233fn unpack_exceptions_delta_decode(
234    packed: &[u8],
235    bit_width: u8,
236    exceptions: &[(u8, u32)],
237    output: &mut [u32],
238    first_doc_id: u32,
239    count: usize,
240) {
241    if count == 0 {
242        return;
243    }
244
245    output[0] = first_doc_id;
246    if count == 1 {
247        return;
248    }
249
250    // Build exception lookup for O(1) access
251    // Since exceptions are sparse (typically <5%), a simple linear scan is fine
252    // But for very large blocks, we could use a small hashmap
253
254    let mask = if bit_width < 32 {
255        (1u64 << bit_width) - 1
256    } else {
257        u64::MAX
258    };
259
260    let mut carry = first_doc_id;
261
262    // Fast path for SIMD-friendly bit widths
263    match bit_width {
264        0 => {
265            // All zeros = consecutive doc IDs (gap of 1)
266            for item in output.iter_mut().take(count).skip(1) {
267                carry = carry.wrapping_add(1);
268                *item = carry;
269            }
270        }
271        8 => {
272            // Unpack 8-bit, apply exceptions, delta decode in one pass
273            for i in 0..count - 1 {
274                let mut delta = packed[i] as u32;
275                // Check for exception at this position
276                for &(pos, high_bits) in exceptions {
277                    if pos as usize == i {
278                        delta |= high_bits << bit_width;
279                        break;
280                    }
281                }
282                carry = carry.wrapping_add(delta).wrapping_add(1);
283                output[i + 1] = carry;
284            }
285        }
286        16 => {
287            // Unpack 16-bit, apply exceptions, delta decode in one pass
288            for i in 0..count - 1 {
289                let idx = i * 2;
290                let mut delta = u16::from_le_bytes([packed[idx], packed[idx + 1]]) as u32;
291                for &(pos, high_bits) in exceptions {
292                    if pos as usize == i {
293                        delta |= high_bits << bit_width;
294                        break;
295                    }
296                }
297                carry = carry.wrapping_add(delta).wrapping_add(1);
298                output[i + 1] = carry;
299            }
300        }
301        32 => {
302            // 32-bit has no exceptions
303            for i in 0..count - 1 {
304                let idx = i * 4;
305                let delta = u32::from_le_bytes([
306                    packed[idx],
307                    packed[idx + 1],
308                    packed[idx + 2],
309                    packed[idx + 3],
310                ]);
311                carry = carry.wrapping_add(delta).wrapping_add(1);
312                output[i + 1] = carry;
313            }
314        }
315        _ => {
316            // Generic bit width
317            let input_ptr = packed.as_ptr();
318            let mut bit_pos = 0usize;
319
320            for i in 0..count - 1 {
321                let byte_idx = bit_pos >> 3;
322                let bit_offset = bit_pos & 7;
323
324                let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
325                let mut delta = ((word >> bit_offset) & mask) as u32;
326
327                // Check for exception
328                for &(pos, high_bits) in exceptions {
329                    if pos as usize == i {
330                        delta |= high_bits << bit_width;
331                        break;
332                    }
333                }
334
335                carry = carry.wrapping_add(delta).wrapping_add(1);
336                output[i + 1] = carry;
337                bit_pos += bit_width as usize;
338            }
339        }
340    }
341}
342
343/// A single OptP4D block
344#[derive(Debug, Clone)]
345pub struct OptP4DBlock {
346    /// First doc_id in this block (absolute)
347    pub first_doc_id: u32,
348    /// Last doc_id in this block (absolute)
349    pub last_doc_id: u32,
350    /// Number of documents in this block
351    pub num_docs: u16,
352    /// Bit width for delta encoding
353    pub doc_bit_width: u8,
354    /// Bit width for term frequencies
355    pub tf_bit_width: u8,
356    /// Maximum term frequency in this block
357    pub max_tf: u32,
358    /// Maximum block score for WAND/MaxScore
359    pub max_block_score: f32,
360    /// Packed doc deltas
361    pub doc_deltas: Vec<u8>,
362    /// Doc delta exceptions: (position, full_delta)
363    pub doc_exceptions: Vec<(u8, u32)>,
364    /// Packed term frequencies
365    pub term_freqs: Vec<u8>,
366    /// TF exceptions: (position, full_tf)
367    pub tf_exceptions: Vec<(u8, u32)>,
368}
369
370impl OptP4DBlock {
371    /// Serialize the block
372    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
373        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
374        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
375        writer.write_u16::<LittleEndian>(self.num_docs)?;
376        writer.write_u8(self.doc_bit_width)?;
377        writer.write_u8(self.tf_bit_width)?;
378        writer.write_u32::<LittleEndian>(self.max_tf)?;
379        writer.write_f32::<LittleEndian>(self.max_block_score)?;
380
381        // Write doc deltas
382        writer.write_u16::<LittleEndian>(self.doc_deltas.len() as u16)?;
383        writer.write_all(&self.doc_deltas)?;
384
385        // Write doc exceptions
386        writer.write_u8(self.doc_exceptions.len() as u8)?;
387        for &(pos, val) in &self.doc_exceptions {
388            writer.write_u8(pos)?;
389            writer.write_u32::<LittleEndian>(val)?;
390        }
391
392        // Write term freqs
393        writer.write_u16::<LittleEndian>(self.term_freqs.len() as u16)?;
394        writer.write_all(&self.term_freqs)?;
395
396        // Write tf exceptions
397        writer.write_u8(self.tf_exceptions.len() as u8)?;
398        for &(pos, val) in &self.tf_exceptions {
399            writer.write_u8(pos)?;
400            writer.write_u32::<LittleEndian>(val)?;
401        }
402
403        Ok(())
404    }
405
406    /// Deserialize a block
407    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
408        let first_doc_id = reader.read_u32::<LittleEndian>()?;
409        let last_doc_id = reader.read_u32::<LittleEndian>()?;
410        let num_docs = reader.read_u16::<LittleEndian>()?;
411        let doc_bit_width = reader.read_u8()?;
412        let tf_bit_width = reader.read_u8()?;
413        let max_tf = reader.read_u32::<LittleEndian>()?;
414        let max_block_score = reader.read_f32::<LittleEndian>()?;
415
416        // Read doc deltas
417        let doc_deltas_len = reader.read_u16::<LittleEndian>()? as usize;
418        let mut doc_deltas = vec![0u8; doc_deltas_len];
419        reader.read_exact(&mut doc_deltas)?;
420
421        // Read doc exceptions
422        let num_doc_exceptions = reader.read_u8()? as usize;
423        let mut doc_exceptions = Vec::with_capacity(num_doc_exceptions);
424        for _ in 0..num_doc_exceptions {
425            let pos = reader.read_u8()?;
426            let val = reader.read_u32::<LittleEndian>()?;
427            doc_exceptions.push((pos, val));
428        }
429
430        // Read term freqs
431        let term_freqs_len = reader.read_u16::<LittleEndian>()? as usize;
432        let mut term_freqs = vec![0u8; term_freqs_len];
433        reader.read_exact(&mut term_freqs)?;
434
435        // Read tf exceptions
436        let num_tf_exceptions = reader.read_u8()? as usize;
437        let mut tf_exceptions = Vec::with_capacity(num_tf_exceptions);
438        for _ in 0..num_tf_exceptions {
439            let pos = reader.read_u8()?;
440            let val = reader.read_u32::<LittleEndian>()?;
441            tf_exceptions.push((pos, val));
442        }
443
444        Ok(Self {
445            first_doc_id,
446            last_doc_id,
447            num_docs,
448            doc_bit_width,
449            tf_bit_width,
450            max_tf,
451            max_block_score,
452            doc_deltas,
453            doc_exceptions,
454            term_freqs,
455            tf_exceptions,
456        })
457    }
458
459    /// Decode doc_ids from this block using SIMD-accelerated delta decoding
460    pub fn decode_doc_ids(&self) -> Vec<u32> {
461        let mut output = vec![0u32; self.num_docs as usize];
462        self.decode_doc_ids_into(&mut output);
463        output
464    }
465
466    /// Decode doc_ids into a pre-allocated buffer (avoids allocation)
467    #[inline]
468    pub fn decode_doc_ids_into(&self, output: &mut [u32]) -> usize {
469        let count = self.num_docs as usize;
470        if count == 0 {
471            return 0;
472        }
473
474        // Fused unpack + exceptions + delta decode - no intermediate buffer
475        unpack_exceptions_delta_decode(
476            &self.doc_deltas,
477            self.doc_bit_width,
478            &self.doc_exceptions,
479            output,
480            self.first_doc_id,
481            count,
482        );
483
484        count
485    }
486
487    /// Decode term frequencies from this block using SIMD acceleration
488    pub fn decode_term_freqs(&self) -> Vec<u32> {
489        let mut output = vec![0u32; self.num_docs as usize];
490        self.decode_term_freqs_into(&mut output);
491        output
492    }
493
494    /// Decode term frequencies into a pre-allocated buffer (avoids allocation)
495    #[inline]
496    pub fn decode_term_freqs_into(&self, output: &mut [u32]) -> usize {
497        let count = self.num_docs as usize;
498        if count == 0 {
499            return 0;
500        }
501
502        // Unpack TFs with exceptions (SIMD-accelerated for 8/16/32-bit)
503        unpack_with_exceptions(
504            &self.term_freqs,
505            self.tf_bit_width,
506            &self.tf_exceptions,
507            count,
508            output,
509        );
510
511        // TF is stored as tf-1, so add 1 back using SIMD
512        simd::add_one(output, count);
513
514        count
515    }
516}
517
518/// OptP4D posting list
519#[derive(Debug, Clone)]
520pub struct OptP4DPostingList {
521    /// Blocks of postings
522    pub blocks: Vec<OptP4DBlock>,
523    /// Total document count
524    pub doc_count: u32,
525    /// Maximum score across all blocks
526    pub max_score: f32,
527}
528
529impl OptP4DPostingList {
530    /// Create from raw doc_ids and term frequencies
531    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
532        assert_eq!(doc_ids.len(), term_freqs.len());
533
534        if doc_ids.is_empty() {
535            return Self {
536                blocks: Vec::new(),
537                doc_count: 0,
538                max_score: 0.0,
539            };
540        }
541
542        let mut blocks = Vec::new();
543        let mut max_score = 0.0f32;
544        let mut i = 0;
545
546        while i < doc_ids.len() {
547            let block_end = (i + OPT_P4D_BLOCK_SIZE).min(doc_ids.len());
548            let block_docs = &doc_ids[i..block_end];
549            let block_tfs = &term_freqs[i..block_end];
550
551            let block = Self::create_block(block_docs, block_tfs, idf);
552            max_score = max_score.max(block.max_block_score);
553            blocks.push(block);
554
555            i = block_end;
556        }
557
558        Self {
559            blocks,
560            doc_count: doc_ids.len() as u32,
561            max_score,
562        }
563    }
564
565    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> OptP4DBlock {
566        let num_docs = doc_ids.len();
567        let first_doc_id = doc_ids[0];
568        let last_doc_id = *doc_ids.last().unwrap();
569
570        // Compute deltas using stack array (delta - 1 to save one bit)
571        let mut deltas = [0u32; OPT_P4D_BLOCK_SIZE];
572        for j in 1..num_docs {
573            deltas[j - 1] = doc_ids[j] - doc_ids[j - 1] - 1;
574        }
575
576        // Find optimal bit width for deltas
577        let (doc_bit_width, _, _) = find_optimal_bit_width(&deltas[..num_docs.saturating_sub(1)]);
578        let (doc_deltas, doc_exceptions) =
579            pack_with_exceptions(&deltas[..num_docs.saturating_sub(1)], doc_bit_width);
580
581        // Compute max TF and prepare TF array using stack array (store tf-1)
582        let mut tfs = [0u32; OPT_P4D_BLOCK_SIZE];
583        let mut max_tf = 0u32;
584
585        for (j, &tf) in term_freqs.iter().enumerate() {
586            tfs[j] = tf - 1; // Store tf-1
587            max_tf = max_tf.max(tf);
588        }
589
590        // Find optimal bit width for TFs
591        let (tf_bit_width, _, _) = find_optimal_bit_width(&tfs[..num_docs]);
592        let (term_freqs_packed, tf_exceptions) =
593            pack_with_exceptions(&tfs[..num_docs], tf_bit_width);
594
595        // BM25F upper bound score
596        let max_block_score = crate::query::bm25_upper_bound(max_tf as f32, idf);
597
598        OptP4DBlock {
599            first_doc_id,
600            last_doc_id,
601            num_docs: num_docs as u16,
602            doc_bit_width,
603            tf_bit_width,
604            max_tf,
605            max_block_score,
606            doc_deltas,
607            doc_exceptions,
608            term_freqs: term_freqs_packed,
609            tf_exceptions,
610        }
611    }
612
613    /// Serialize the posting list
614    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
615        writer.write_u32::<LittleEndian>(self.doc_count)?;
616        writer.write_f32::<LittleEndian>(self.max_score)?;
617        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
618
619        for block in &self.blocks {
620            block.serialize(writer)?;
621        }
622
623        Ok(())
624    }
625
626    /// Deserialize a posting list
627    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
628        let doc_count = reader.read_u32::<LittleEndian>()?;
629        let max_score = reader.read_f32::<LittleEndian>()?;
630        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;
631
632        let mut blocks = Vec::with_capacity(num_blocks);
633        for _ in 0..num_blocks {
634            blocks.push(OptP4DBlock::deserialize(reader)?);
635        }
636
637        Ok(Self {
638            blocks,
639            doc_count,
640            max_score,
641        })
642    }
643
644    /// Get document count
645    pub fn len(&self) -> u32 {
646        self.doc_count
647    }
648
649    /// Check if empty
650    pub fn is_empty(&self) -> bool {
651        self.doc_count == 0
652    }
653
654    /// Create an iterator
655    pub fn iterator(&self) -> OptP4DIterator<'_> {
656        OptP4DIterator::new(self)
657    }
658}
659
660/// Iterator over OptP4D posting list
661pub struct OptP4DIterator<'a> {
662    posting_list: &'a OptP4DPostingList,
663    current_block: usize,
664    /// Number of valid elements in current block
665    current_block_len: usize,
666    /// Pre-allocated buffer for decoded doc_ids (avoids allocation per block)
667    block_doc_ids: Vec<u32>,
668    /// Pre-allocated buffer for decoded term freqs
669    block_term_freqs: Vec<u32>,
670    pos_in_block: usize,
671    exhausted: bool,
672}
673
674impl<'a> OptP4DIterator<'a> {
675    pub fn new(posting_list: &'a OptP4DPostingList) -> Self {
676        // Pre-allocate buffers to block size to avoid allocations during iteration
677        let mut iter = Self {
678            posting_list,
679            current_block: 0,
680            current_block_len: 0,
681            block_doc_ids: vec![0u32; OPT_P4D_BLOCK_SIZE],
682            block_term_freqs: vec![0u32; OPT_P4D_BLOCK_SIZE],
683            pos_in_block: 0,
684            exhausted: posting_list.blocks.is_empty(),
685        };
686
687        if !iter.exhausted {
688            iter.decode_current_block();
689        }
690
691        iter
692    }
693
694    #[inline]
695    fn decode_current_block(&mut self) {
696        let block = &self.posting_list.blocks[self.current_block];
697        // Decode into pre-allocated buffers (no allocation!)
698        self.current_block_len = block.decode_doc_ids_into(&mut self.block_doc_ids);
699        block.decode_term_freqs_into(&mut self.block_term_freqs);
700        self.pos_in_block = 0;
701    }
702
703    /// Current document ID
704    #[inline]
705    pub fn doc(&self) -> u32 {
706        if self.exhausted {
707            u32::MAX
708        } else {
709            self.block_doc_ids[self.pos_in_block]
710        }
711    }
712
713    /// Current term frequency
714    #[inline]
715    pub fn term_freq(&self) -> u32 {
716        if self.exhausted {
717            0
718        } else {
719            self.block_term_freqs[self.pos_in_block]
720        }
721    }
722
723    /// Advance to next document
724    #[inline]
725    pub fn advance(&mut self) -> u32 {
726        if self.exhausted {
727            return u32::MAX;
728        }
729
730        self.pos_in_block += 1;
731
732        if self.pos_in_block >= self.current_block_len {
733            self.current_block += 1;
734            if self.current_block >= self.posting_list.blocks.len() {
735                self.exhausted = true;
736                return u32::MAX;
737            }
738            self.decode_current_block();
739        }
740
741        self.doc()
742    }
743
744    /// Seek to first doc >= target
745    pub fn seek(&mut self, target: u32) -> u32 {
746        if self.exhausted {
747            return u32::MAX;
748        }
749
750        // Skip blocks where last_doc_id < target
751        while self.current_block < self.posting_list.blocks.len() {
752            let block = &self.posting_list.blocks[self.current_block];
753            if block.last_doc_id >= target {
754                break;
755            }
756            self.current_block += 1;
757        }
758
759        if self.current_block >= self.posting_list.blocks.len() {
760            self.exhausted = true;
761            return u32::MAX;
762        }
763
764        // Decode block if needed
765        if self.current_block_len == 0 || self.current_block != self.posting_list.blocks.len() - 1 {
766            self.decode_current_block();
767        }
768
769        // Binary search within block
770        match self.block_doc_ids[self.pos_in_block..self.current_block_len].binary_search(&target) {
771            Ok(idx) => {
772                self.pos_in_block += idx;
773            }
774            Err(idx) => {
775                self.pos_in_block += idx;
776                if self.pos_in_block >= self.current_block_len {
777                    // Move to next block
778                    self.current_block += 1;
779                    if self.current_block >= self.posting_list.blocks.len() {
780                        self.exhausted = true;
781                        return u32::MAX;
782                    }
783                    self.decode_current_block();
784                }
785            }
786        }
787
788        self.doc()
789    }
790}
791
792#[cfg(test)]
793mod tests {
794    use super::*;
795
796    #[test]
797    fn test_bits_needed() {
798        assert_eq!(simd::bits_needed(0), 0);
799        assert_eq!(simd::bits_needed(1), 1);
800        assert_eq!(simd::bits_needed(2), 2);
801        assert_eq!(simd::bits_needed(3), 2);
802        assert_eq!(simd::bits_needed(4), 3);
803        assert_eq!(simd::bits_needed(255), 8);
804        assert_eq!(simd::bits_needed(256), 9);
805        assert_eq!(simd::bits_needed(u32::MAX), 32);
806    }
807
808    #[test]
809    fn test_find_optimal_bit_width() {
810        // All zeros
811        let values = vec![0u32; 100];
812        let (bits, exceptions, _) = find_optimal_bit_width(&values);
813        assert_eq!(bits, 0);
814        assert_eq!(exceptions, 0);
815
816        // All small values
817        let values: Vec<u32> = (0..100).map(|i| i % 16).collect();
818        let (bits, _, _) = find_optimal_bit_width(&values);
819        assert!(bits <= 4);
820
821        // Mix with outliers
822        let mut values: Vec<u32> = (0..100).map(|i| i % 16).collect();
823        values[50] = 1_000_000; // outlier
824        let (bits, exceptions, _) = find_optimal_bit_width(&values);
825        assert!(bits < 20); // Should use small bit width with exception
826        assert!(exceptions >= 1);
827    }
828
829    #[test]
830    fn test_pack_unpack_with_exceptions() {
831        let values = vec![1, 2, 3, 255, 4, 5, 1000, 6, 7, 8];
832        let (packed, exceptions) = pack_with_exceptions(&values, 4);
833
834        let mut output = vec![0u32; values.len()];
835        unpack_with_exceptions(&packed, 4, &exceptions, values.len(), &mut output);
836
837        assert_eq!(output, values);
838    }
839
840    #[test]
841    fn test_opt_p4d_posting_list_small() {
842        let doc_ids: Vec<u32> = (0..100).map(|i| i * 2).collect();
843        let term_freqs: Vec<u32> = vec![1; 100];
844
845        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
846
847        assert_eq!(list.len(), 100);
848        assert_eq!(list.blocks.len(), 1);
849
850        // Verify iteration
851        let mut iter = list.iterator();
852        for (i, &expected) in doc_ids.iter().enumerate() {
853            assert_eq!(iter.doc(), expected, "Mismatch at {}", i);
854            assert_eq!(iter.term_freq(), 1);
855            iter.advance();
856        }
857        assert_eq!(iter.doc(), u32::MAX);
858    }
859
860    #[test]
861    fn test_opt_p4d_posting_list_large() {
862        let doc_ids: Vec<u32> = (0..500).map(|i| i * 3).collect();
863        let term_freqs: Vec<u32> = (0..500).map(|i| (i % 10) + 1).collect();
864
865        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
866
867        assert_eq!(list.len(), 500);
868        assert_eq!(list.blocks.len(), 4); // 500 / 128 = 3.9 -> 4 blocks
869
870        // Verify iteration
871        let mut iter = list.iterator();
872        for (i, &expected) in doc_ids.iter().enumerate() {
873            assert_eq!(iter.doc(), expected, "Mismatch at {}", i);
874            assert_eq!(iter.term_freq(), term_freqs[i]);
875            iter.advance();
876        }
877    }
878
879    #[test]
880    fn test_opt_p4d_seek() {
881        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
882        let term_freqs: Vec<u32> = vec![1; 8];
883
884        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
885        let mut iter = list.iterator();
886
887        assert_eq!(iter.seek(25), 30);
888        assert_eq!(iter.seek(100), 100);
889        assert_eq!(iter.seek(500), 1000);
890        assert_eq!(iter.seek(3000), u32::MAX);
891    }
892
893    #[test]
894    fn test_opt_p4d_serialization() {
895        let doc_ids: Vec<u32> = (0..200).map(|i| i * 5).collect();
896        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 5) + 1).collect();
897
898        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
899
900        let mut buffer = Vec::new();
901        list.serialize(&mut buffer).unwrap();
902
903        let restored = OptP4DPostingList::deserialize(&mut &buffer[..]).unwrap();
904
905        assert_eq!(restored.len(), list.len());
906        assert_eq!(restored.blocks.len(), list.blocks.len());
907
908        // Verify iteration matches
909        let mut iter1 = list.iterator();
910        let mut iter2 = restored.iterator();
911
912        while iter1.doc() != u32::MAX {
913            assert_eq!(iter1.doc(), iter2.doc());
914            assert_eq!(iter1.term_freq(), iter2.term_freq());
915            iter1.advance();
916            iter2.advance();
917        }
918    }
919
920    #[test]
921    fn test_opt_p4d_with_outliers() {
922        // Create data with some outliers to test exception handling
923        let mut doc_ids: Vec<u32> = (0..128).map(|i| i * 2).collect();
924        doc_ids[64] = 1_000_000; // Large outlier
925
926        // Fix: ensure doc_ids are sorted
927        doc_ids.sort();
928
929        let term_freqs: Vec<u32> = vec![1; 128];
930
931        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
932
933        // Verify the outlier is handled correctly
934        let mut iter = list.iterator();
935        let mut found_outlier = false;
936        while iter.doc() != u32::MAX {
937            if iter.doc() == 1_000_000 {
938                found_outlier = true;
939            }
940            iter.advance();
941        }
942        assert!(found_outlier, "Outlier value should be preserved");
943    }
944
945    #[test]
946    fn test_opt_p4d_simd_full_blocks() {
947        // Test with multiple full 128-integer blocks to exercise SIMD paths
948        let doc_ids: Vec<u32> = (0..1024).map(|i| i * 2).collect();
949        let term_freqs: Vec<u32> = (0..1024).map(|i| (i % 20) + 1).collect();
950
951        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
952
953        assert_eq!(list.len(), 1024);
954        assert_eq!(list.blocks.len(), 8); // 1024 / 128 = 8 full blocks
955
956        // Verify all values are decoded correctly
957        let mut iter = list.iterator();
958        for (i, &expected_doc) in doc_ids.iter().enumerate() {
959            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
960            assert_eq!(iter.term_freq(), term_freqs[i], "TF mismatch at {}", i);
961            iter.advance();
962        }
963        assert_eq!(iter.doc(), u32::MAX);
964    }
965
966    #[test]
967    fn test_opt_p4d_simd_8bit_values() {
968        // Test with values that fit in 8 bits to exercise SIMD 8-bit unpack
969        let doc_ids: Vec<u32> = (0..256).collect();
970        let term_freqs: Vec<u32> = (0..256).map(|i| (i % 100) + 1).collect();
971
972        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
973
974        // Verify all values
975        let mut iter = list.iterator();
976        for (i, &expected_doc) in doc_ids.iter().enumerate() {
977            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
978            assert_eq!(iter.term_freq(), term_freqs[i], "TF mismatch at {}", i);
979            iter.advance();
980        }
981    }
982
983    #[test]
984    fn test_opt_p4d_simd_delta_decode() {
985        // Test SIMD delta decoding with various gap sizes
986        let mut doc_ids = Vec::with_capacity(512);
987        let mut current = 0u32;
988        for i in 0..512 {
989            current += (i % 10) + 1; // Variable gaps
990            doc_ids.push(current);
991        }
992        let term_freqs: Vec<u32> = vec![1; 512];
993
994        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
995
996        // Verify delta decoding is correct
997        let mut iter = list.iterator();
998        for (i, &expected_doc) in doc_ids.iter().enumerate() {
999            assert_eq!(
1000                iter.doc(),
1001                expected_doc,
1002                "Doc mismatch at {} (expected {}, got {})",
1003                i,
1004                expected_doc,
1005                iter.doc()
1006            );
1007            iter.advance();
1008        }
1009    }
1010}