hermes_core/structures/
opt_p4d.rs

1//! OptP4D (Optimized Patched Frame-of-Reference Delta) posting list compression
2//!
3//! OptP4D is an improvement over PForDelta that finds the optimal bit width for each block
4//! by trying all possible bit widths and selecting the one that minimizes total storage.
5//!
6//! Key features:
7//! - Block-based compression (128 integers per block for SIMD alignment)
8//! - Delta encoding for doc IDs
9//! - Optimal bit-width selection per block
10//! - Patched coding: exceptions (values that don't fit) stored separately
11//! - Fast SIMD-friendly decoding with NEON (ARM) and SSE (x86) support
12//!
//! Conceptual per-block format (this is the cost model used to choose bit widths;
//! the on-disk serialization in `OptP4DBlock::serialize` stores byte-aligned fields):
//! - Header: bit width, exception count, and the first doc id
//! - Main array: up to 128 values packed at `bit_width` bits each (low bits of every value)
//! - Exceptions: (position: 7 bits, high_bits: 32 - bit_width bits) for each overflowing value
17
18use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
19use std::io::{self, Read, Write};
20
21// ============================================================================
22// SIMD optimizations for aarch64 (Apple Silicon, ARM servers)
23// ============================================================================
24
/// NEON (ARMv8 SIMD) kernels; selected at runtime by the dispatch functions below.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod neon {
    #[allow(unused_imports)]
    use super::OPT_P4D_BLOCK_SIZE;
    use std::arch::aarch64::*;

    /// SIMD unpack for 8-bit values using NEON
    ///
    /// # Safety
    /// Caller must ensure `input.len() >= count` and `output.len() >= count`;
    /// the main loop reads and writes 16 elements at a time through raw pointers.
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_8bit_neon(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 16;
        let remainder = count % 16;

        for chunk in 0..chunks {
            let base = chunk * 16;
            let in_ptr = input.as_ptr().add(base);

            // Load 16 bytes
            let bytes = vld1q_u8(in_ptr);

            // Widen u8 -> u16 -> u32
            let low8 = vget_low_u8(bytes);
            let high8 = vget_high_u8(bytes);

            let low16 = vmovl_u8(low8);
            let high16 = vmovl_u8(high8);

            let v0 = vmovl_u16(vget_low_u16(low16));
            let v1 = vmovl_u16(vget_high_u16(low16));
            let v2 = vmovl_u16(vget_low_u16(high16));
            let v3 = vmovl_u16(vget_high_u16(high16));

            let out_ptr = output.as_mut_ptr().add(base);
            vst1q_u32(out_ptr, v0);
            vst1q_u32(out_ptr.add(4), v1);
            vst1q_u32(out_ptr.add(8), v2);
            vst1q_u32(out_ptr.add(12), v3);
        }

        // Handle remainder with plain (bounds-checked) scalar code
        let base = chunks * 16;
        for i in 0..remainder {
            output[base + i] = input[base + i] as u32;
        }
    }

    /// SIMD unpack for 16-bit values using NEON
    ///
    /// # Safety
    /// Caller must ensure `input.len() >= 2 * count` and `output.len() >= count`.
    /// The u16 loads may be unaligned; `vld1q_u16` permits that on aarch64.
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_16bit_neon(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 8;
        let remainder = count % 8;

        for chunk in 0..chunks {
            let base = chunk * 8;
            let in_ptr = input.as_ptr().add(base * 2) as *const u16;

            let vals = vld1q_u16(in_ptr);
            let low = vmovl_u16(vget_low_u16(vals));
            let high = vmovl_u16(vget_high_u16(vals));

            let out_ptr = output.as_mut_ptr().add(base);
            vst1q_u32(out_ptr, low);
            vst1q_u32(out_ptr.add(4), high);
        }

        // Handle remainder (explicit little-endian decode, matching the SIMD path
        // on this little-endian target)
        let base = chunks * 8;
        for i in 0..remainder {
            let idx = (base + i) * 2;
            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    /// SIMD delta decode: convert deltas to absolute doc IDs
    /// deltas[i] stores (gap - 1), output[i] = first + sum(gaps[0..i])
    ///
    /// The vector add applies the +1 to four gaps at once; the prefix sum itself is
    /// carried serially through extracted lanes (a sequential dependency).
    ///
    /// # Safety
    /// Caller must ensure `deltas.len() >= count - 1` and `output.len() >= count`;
    /// each group loads 4 lanes starting at `deltas[base]`.
    #[target_feature(enable = "neon")]
    pub unsafe fn delta_decode_neon(
        deltas: &[u32],
        output: &mut [u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        if count == 1 {
            return;
        }

        let mut carry = first_doc_id;
        let ones = vdupq_n_u32(1);

        // count - 1 gaps follow the absolute first doc id
        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            // Load 4 deltas
            let d = vld1q_u32(deltas[base..].as_ptr());

            // Add 1 to each delta (since we store gap-1)
            let gaps = vaddq_u32(d, ones);

            // Extract and compute prefix sum with carry
            let g0 = vgetq_lane_u32(gaps, 0);
            let g1 = vgetq_lane_u32(gaps, 1);
            let g2 = vgetq_lane_u32(gaps, 2);
            let g3 = vgetq_lane_u32(gaps, 3);

            let v0 = carry.wrapping_add(g0);
            let v1 = v0.wrapping_add(g1);
            let v2 = v1.wrapping_add(g2);
            let v3 = v2.wrapping_add(g3);

            output[base + 1] = v0;
            output[base + 2] = v1;
            output[base + 3] = v2;
            output[base + 4] = v3;

            carry = v3;
        }

        // Handle remainder
        let base = full_groups * 4;
        for j in 0..remainder {
            carry = carry.wrapping_add(deltas[base + j]).wrapping_add(1);
            output[base + j + 1] = carry;
        }
    }

    /// SIMD add 1 to all values (for TF decoding: stored as tf-1)
    ///
    /// # Safety
    /// Caller must ensure `values.len() >= count`; the loop stores 4 lanes at a time.
    #[target_feature(enable = "neon")]
    pub unsafe fn add_one_neon(values: &mut [u32], count: usize) {
        let ones = vdupq_n_u32(1);
        let chunks = count / 4;
        let remainder = count % 4;

        for chunk in 0..chunks {
            let base = chunk * 4;
            let ptr = values.as_mut_ptr().add(base);
            let v = vld1q_u32(ptr);
            let result = vaddq_u32(v, ones);
            vst1q_u32(ptr, result);
        }

        let base = chunks * 4;
        for i in 0..remainder {
            values[base + i] += 1;
        }
    }

    /// Check if NEON is available (always true on aarch64)
    #[inline]
    pub fn is_available() -> bool {
        true
    }
}
185
186// ============================================================================
187// SIMD optimizations for x86_64 (Intel/AMD)
188// ============================================================================
189
/// SSE2/SSE4.1 kernels; `is_available` gates usage on runtime CPU feature detection.
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod sse {
    use std::arch::x86_64::*;

    /// SIMD unpack for 8-bit values using SSE
    ///
    /// # Safety
    /// Caller must ensure SSE4.1 is available, `input.len() >= count`, and
    /// `output.len() >= count`; the main loop moves 16 elements per iteration.
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_8bit_sse(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 16;
        let remainder = count % 16;

        for chunk in 0..chunks {
            let base = chunk * 16;
            let in_ptr = input.as_ptr().add(base);

            let bytes = _mm_loadu_si128(in_ptr as *const __m128i);

            // Zero extend u8 -> u32 using SSE4.1 pmovzx
            let v0 = _mm_cvtepu8_epi32(bytes);
            let v1 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
            let v2 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 8));
            let v3 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 12));

            let out_ptr = output.as_mut_ptr().add(base);
            _mm_storeu_si128(out_ptr as *mut __m128i, v0);
            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
            _mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
            _mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
        }

        // Scalar tail
        let base = chunks * 16;
        for i in 0..remainder {
            output[base + i] = input[base + i] as u32;
        }
    }

    /// SIMD unpack for 16-bit values using SSE
    ///
    /// # Safety
    /// Caller must ensure SSE4.1 is available, `input.len() >= 2 * count`, and
    /// `output.len() >= count`. Loads/stores are unaligned (`loadu`/`storeu`).
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_16bit_sse(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 8;
        let remainder = count % 8;

        for chunk in 0..chunks {
            let base = chunk * 8;
            let in_ptr = input.as_ptr().add(base * 2);

            let vals = _mm_loadu_si128(in_ptr as *const __m128i);
            let low = _mm_cvtepu16_epi32(vals);
            let high = _mm_cvtepu16_epi32(_mm_srli_si128(vals, 8));

            let out_ptr = output.as_mut_ptr().add(base);
            _mm_storeu_si128(out_ptr as *mut __m128i, low);
            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
        }

        // Scalar tail (explicit little-endian decode, matching the SIMD path)
        let base = chunks * 8;
        for i in 0..remainder {
            let idx = (base + i) * 2;
            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    /// SIMD delta decode using SSE
    ///
    /// The vector add applies the +1 to four (gap - 1) deltas at once; the prefix
    /// sum itself is carried serially through extracted lanes.
    ///
    /// # Safety
    /// Caller must ensure SSE4.1 is available, `deltas.len() >= count - 1`, and
    /// `output.len() >= count`; each group loads 4 lanes starting at `deltas[base]`.
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn delta_decode_sse(
        deltas: &[u32],
        output: &mut [u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        if count == 1 {
            return;
        }

        let mut carry = first_doc_id;
        let ones = _mm_set1_epi32(1);

        // count - 1 gaps follow the absolute first doc id
        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
            let gaps = _mm_add_epi32(d, ones);

            let g0 = _mm_extract_epi32(gaps, 0) as u32;
            let g1 = _mm_extract_epi32(gaps, 1) as u32;
            let g2 = _mm_extract_epi32(gaps, 2) as u32;
            let g3 = _mm_extract_epi32(gaps, 3) as u32;

            let v0 = carry.wrapping_add(g0);
            let v1 = v0.wrapping_add(g1);
            let v2 = v1.wrapping_add(g2);
            let v3 = v2.wrapping_add(g3);

            output[base + 1] = v0;
            output[base + 2] = v1;
            output[base + 3] = v2;
            output[base + 4] = v3;

            carry = v3;
        }

        let base = full_groups * 4;
        for j in 0..remainder {
            carry = carry.wrapping_add(deltas[base + j]).wrapping_add(1);
            output[base + j + 1] = carry;
        }
    }

    /// SIMD add 1 to all values using SSE
    ///
    /// # Safety
    /// Caller must ensure SSE2 is available and `values.len() >= count`.
    #[target_feature(enable = "sse2")]
    pub unsafe fn add_one_sse(values: &mut [u32], count: usize) {
        let ones = _mm_set1_epi32(1);
        let chunks = count / 4;
        let remainder = count % 4;

        for chunk in 0..chunks {
            let base = chunk * 4;
            let ptr = values.as_mut_ptr().add(base) as *mut __m128i;
            let v = _mm_loadu_si128(ptr);
            let result = _mm_add_epi32(v, ones);
            _mm_storeu_si128(ptr, result);
        }

        let base = chunks * 4;
        for i in 0..remainder {
            values[base + i] += 1;
        }
    }

    /// Check if SSE4.1 is available at runtime
    #[inline]
    pub fn is_available() -> bool {
        is_x86_feature_detected!("sse4.1")
    }
}
333
334// ============================================================================
335// Scalar fallback implementations
336// ============================================================================
337
/// Portable fallback kernels used when no SIMD path is available.
mod scalar {
    /// Widen `count` bytes into u32s, one per input byte.
    #[inline]
    pub fn unpack_8bit_scalar(input: &[u8], output: &mut [u32], count: usize) {
        for (out, &byte) in output[..count].iter_mut().zip(&input[..count]) {
            *out = u32::from(byte);
        }
    }

    /// Widen `count` little-endian u16 values (2 bytes each) into u32s.
    #[inline]
    pub fn unpack_16bit_scalar(input: &[u8], output: &mut [u32], count: usize) {
        for (out, i) in output.iter_mut().zip(0..count) {
            *out = u32::from(u16::from_le_bytes([input[i * 2], input[i * 2 + 1]]));
        }
    }

    /// Turn (gap - 1) deltas into absolute doc ids:
    /// output[0] = first_doc_id, output[i] = output[i - 1] + deltas[i - 1] + 1.
    #[inline]
    pub fn delta_decode_scalar(
        deltas: &[u32],
        output: &mut [u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        let mut acc = first_doc_id;

        // Running wrapping sum over the count - 1 gaps that follow the first id.
        for (slot, &delta) in output[1..count].iter_mut().zip(&deltas[..count - 1]) {
            acc = acc.wrapping_add(delta).wrapping_add(1);
            *slot = acc;
        }
    }

    /// Increment the first `count` values in place (TFs are stored as tf - 1).
    #[inline]
    pub fn add_one_scalar(values: &mut [u32], count: usize) {
        values.iter_mut().take(count).for_each(|v| *v += 1);
    }
}
385
386// ============================================================================
387// Dispatch functions that select SIMD or scalar at runtime
388// ============================================================================
389
390/// Unpack 8-bit packed values to u32 with SIMD acceleration
391#[inline]
392fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
393    #[cfg(target_arch = "aarch64")]
394    {
395        if neon::is_available() {
396            unsafe {
397                neon::unpack_8bit_neon(input, output, count);
398            }
399            return;
400        }
401    }
402
403    #[cfg(target_arch = "x86_64")]
404    {
405        if sse::is_available() {
406            unsafe {
407                sse::unpack_8bit_sse(input, output, count);
408            }
409            return;
410        }
411    }
412
413    scalar::unpack_8bit_scalar(input, output, count);
414}
415
416/// Unpack 16-bit packed values to u32 with SIMD acceleration
417#[inline]
418fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
419    #[cfg(target_arch = "aarch64")]
420    {
421        if neon::is_available() {
422            unsafe {
423                neon::unpack_16bit_neon(input, output, count);
424            }
425            return;
426        }
427    }
428
429    #[cfg(target_arch = "x86_64")]
430    {
431        if sse::is_available() {
432            unsafe {
433                sse::unpack_16bit_sse(input, output, count);
434            }
435            return;
436        }
437    }
438
439    scalar::unpack_16bit_scalar(input, output, count);
440}
441
442/// Delta decode with SIMD acceleration
443#[inline]
444fn delta_decode_simd(deltas: &[u32], output: &mut [u32], first_doc_id: u32, count: usize) {
445    #[cfg(target_arch = "aarch64")]
446    {
447        if neon::is_available() {
448            unsafe {
449                neon::delta_decode_neon(deltas, output, first_doc_id, count);
450            }
451            return;
452        }
453    }
454
455    #[cfg(target_arch = "x86_64")]
456    {
457        if sse::is_available() {
458            unsafe {
459                sse::delta_decode_sse(deltas, output, first_doc_id, count);
460            }
461            return;
462        }
463    }
464
465    scalar::delta_decode_scalar(deltas, output, first_doc_id, count);
466}
467
468/// Add 1 to all values with SIMD acceleration
469#[inline]
470fn add_one_simd(values: &mut [u32], count: usize) {
471    #[cfg(target_arch = "aarch64")]
472    {
473        if neon::is_available() {
474            unsafe {
475                neon::add_one_neon(values, count);
476            }
477            return;
478        }
479    }
480
481    #[cfg(target_arch = "x86_64")]
482    {
483        if sse::is_available() {
484            unsafe {
485                sse::add_one_sse(values, count);
486            }
487            return;
488        }
489    }
490
491    scalar::add_one_scalar(values, count);
492}
493
/// Block size for OptP4D (128 integers for SIMD alignment).
/// Also bounds the per-block exception positions so they fit the u8 used in
/// the exception lists (and the 7-bit positions of the cost model).
pub const OPT_P4D_BLOCK_SIZE: usize = 128;

/// Maximum number of exceptions before we increase bit width
/// (keeping exceptions under ~10% of block for good compression)
const MAX_EXCEPTIONS_RATIO: f32 = 0.10;
500
/// Compute the number of bits needed to represent a value (0 for the value 0).
#[inline]
fn bits_needed(val: u32) -> u8 {
    // u32::BITS - leading_zeros(val) is the bit length of `val`. No zero branch is
    // required: leading_zeros(0) == 32, so the expression already yields 0 there.
    (u32::BITS - val.leading_zeros()) as u8
}
510
511/// Find the optimal bit width for a block of values
512/// Returns (bit_width, exception_count, total_bits)
513fn find_optimal_bit_width(values: &[u32]) -> (u8, usize, usize) {
514    if values.is_empty() {
515        return (0, 0, 0);
516    }
517
518    let n = values.len();
519    let max_exceptions = ((n as f32) * MAX_EXCEPTIONS_RATIO).ceil() as usize;
520
521    // Count how many values need each bit width
522    let mut bit_counts = [0usize; 33]; // bit_counts[b] = count of values needing exactly b bits
523    for &v in values {
524        let bits = bits_needed(v) as usize;
525        bit_counts[bits] += 1;
526    }
527
528    // Compute cumulative counts: values that fit in b bits or less
529    let mut cumulative = [0usize; 33];
530    cumulative[0] = bit_counts[0];
531    for b in 1..=32 {
532        cumulative[b] = cumulative[b - 1] + bit_counts[b];
533    }
534
535    let mut best_bits = 32u8;
536    let mut best_total = usize::MAX;
537    let mut best_exceptions = 0usize;
538
539    // Try each bit width and compute total storage
540    for b in 0..=32u8 {
541        let fitting = if b == 0 {
542            bit_counts[0]
543        } else {
544            cumulative[b as usize]
545        };
546        let exceptions = n - fitting;
547
548        // Skip if too many exceptions
549        if exceptions > max_exceptions && b < 32 {
550            continue;
551        }
552
553        // Calculate total bits:
554        // - Main array: n * b bits
555        // - Exceptions: exceptions * (7 bits position + (32 - b) bits high value)
556        let main_bits = n * (b as usize);
557        let exception_bits = if b < 32 {
558            exceptions * (7 + (32 - b as usize))
559        } else {
560            0
561        };
562        let total = main_bits + exception_bits;
563
564        if total < best_total {
565            best_total = total;
566            best_bits = b;
567            best_exceptions = exceptions;
568        }
569    }
570
571    (best_bits, best_exceptions, best_total)
572}
573
/// Pack values into a bitpacked array with the given bit width (NewPFD/OptPFD style)
///
/// Following the paper "Decoding billions of integers per second through vectorization":
/// - Store the first b bits (low bits) of ALL values in the main array
/// - For exceptions (values >= 2^b), store only the HIGH (32-b) bits separately with positions
///
/// Returns the packed bytes and a list of exceptions (position, high_bits)
fn pack_with_exceptions(values: &[u32], bit_width: u8) -> (Vec<u8>, Vec<(u8, u32)>) {
    // bit_width == 0: the main array is empty and every non-zero value is an
    // exception whose "high bits" are the full value (value >> 0).
    if bit_width == 0 {
        let exceptions = values
            .iter()
            .enumerate()
            .filter_map(|(i, &v)| (v != 0).then_some((i as u8, v)))
            .collect();
        return (Vec::new(), exceptions);
    }

    // bit_width >= 32: everything fits, so emit plain little-endian u32s.
    if bit_width >= 32 {
        let mut packed = Vec::with_capacity(values.len() * 4);
        for &v in values {
            packed.extend_from_slice(&v.to_le_bytes());
        }
        return (packed, Vec::new());
    }

    let mask = (1u64 << bit_width) - 1;
    let mut packed = Vec::with_capacity((values.len() * bit_width as usize).div_ceil(8));
    let mut exceptions = Vec::new();

    // LSB-first bit accumulator: the low bits of each value are appended above
    // the `pending` bits already buffered, and whole bytes are flushed as they
    // fill up. `acc` never holds more than 7 + 31 bits, so u64 is sufficient.
    let mut acc = 0u64;
    let mut pending = 0usize;

    for (i, &value) in values.iter().enumerate() {
        // Low b bits go into the main array for ALL values, exceptions included.
        acc |= ((value as u64) & mask) << pending;
        pending += bit_width as usize;
        while pending >= 8 {
            packed.push(acc as u8);
            acc >>= 8;
            pending -= 8;
        }

        // Values wider than `bit_width` record their high (32 - b) bits only.
        if u64::from(value) > mask {
            exceptions.push((i as u8, value >> bit_width));
        }
    }

    // Flush the final partial byte, zero-padded at the top.
    if pending > 0 {
        packed.push(acc as u8);
    }

    (packed, exceptions)
}
645
/// Unpack values from a bitpacked array and apply exceptions (NewPFD/OptPFD style)
///
/// Following the paper "Decoding billions of integers per second through vectorization":
/// - Low b bits are stored in the main array for ALL values
/// - Exceptions store only the HIGH (32-b) bits
/// - Reconstruct: value = (high_bits << b) | low_bits
///
/// Uses SIMD acceleration for common bit widths (8, 16, 32)
///
/// `packed` must hold at least `count * bit_width` bits (`count * 4` bytes when
/// `bit_width >= 32`), and `output` must hold at least `count` elements.
fn unpack_with_exceptions(
    packed: &[u8],
    bit_width: u8,
    exceptions: &[(u8, u32)],
    count: usize,
    output: &mut [u32],
) {
    if bit_width == 0 {
        // No main array; values are 0 unless patched by an exception below.
        output[..count].fill(0);
    } else if bit_width == 8 {
        // SIMD-accelerated 8-bit unpacking
        unpack_8bit(packed, output, count);
    } else if bit_width == 16 {
        // SIMD-accelerated 16-bit unpacking
        unpack_16bit(packed, output, count);
    } else if bit_width >= 32 {
        // Direct copy for 32-bit values (no exceptions possible)
        for (i, out) in output.iter_mut().enumerate().take(count) {
            let idx = i * 4;
            *out = u32::from_le_bytes([
                packed[idx],
                packed[idx + 1],
                packed[idx + 2],
                packed[idx + 3],
            ]);
        }
        return; // No exceptions for 32-bit
    } else {
        // Generic bit unpacking for other bit widths (1..=31, except 8/16).
        // bit_offset (<= 7) + bit_width (<= 31) <= 38, so one u64 read always
        // covers a whole value.
        let mask = (1u64 << bit_width) - 1;
        let mut bit_pos = 0usize;
        let input_ptr = packed.as_ptr();

        for out in output[..count].iter_mut() {
            let byte_idx = bit_pos >> 3;
            let bit_offset = bit_pos & 7;

            // Read 8 bytes at once for efficiency
            let word = if byte_idx + 8 <= packed.len() {
                // SAFETY: byte_idx + 8 <= packed.len() was just checked, so the
                // unaligned 8-byte read stays inside the `packed` buffer.
                unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() }
            } else {
                // Handle edge case near end of buffer: assemble the word from
                // whatever bytes remain, zero-extended.
                let mut word = 0u64;
                for (i, &b) in packed[byte_idx..].iter().enumerate() {
                    word |= (b as u64) << (i * 8);
                }
                word
            };

            *out = ((word >> bit_offset) & mask) as u32;
            bit_pos += bit_width as usize;
        }
    }

    // Apply exceptions: combine high bits with low bits already in output
    // value = (high_bits << bit_width) | low_bits
    // (for bit_width == 0 the shift is a no-op and the exception carries the full value)
    for &(pos, high_bits) in exceptions {
        if (pos as usize) < count {
            let low_bits = output[pos as usize];
            output[pos as usize] = (high_bits << bit_width) | low_bits;
        }
    }
}
717
/// A single OptP4D block holding up to `OPT_P4D_BLOCK_SIZE` postings.
#[derive(Debug, Clone)]
pub struct OptP4DBlock {
    /// First doc_id in this block (absolute)
    pub first_doc_id: u32,
    /// Last doc_id in this block (absolute)
    pub last_doc_id: u32,
    /// Number of documents in this block
    pub num_docs: u16,
    /// Bit width for delta encoding (deltas are stored as gap - 1)
    pub doc_bit_width: u8,
    /// Bit width for term frequencies (TFs are stored as tf - 1)
    pub tf_bit_width: u8,
    /// Maximum term frequency in this block
    pub max_tf: u32,
    /// Maximum block score for WAND/MaxScore
    pub max_block_score: f32,
    /// Packed doc deltas: low `doc_bit_width` bits of each (gap - 1)
    pub doc_deltas: Vec<u8>,
    /// Doc delta exceptions: (position, high bits of the delta above
    /// `doc_bit_width`); when `doc_bit_width == 0` the value is the full delta
    pub doc_exceptions: Vec<(u8, u32)>,
    /// Packed term frequencies: low `tf_bit_width` bits of each (tf - 1)
    pub term_freqs: Vec<u8>,
    /// TF exceptions: (position, high bits of tf - 1 above `tf_bit_width`);
    /// when `tf_bit_width == 0` the value is the full (tf - 1)
    pub tf_exceptions: Vec<(u8, u32)>,
}
744
impl OptP4DBlock {
    /// Serialize the block.
    ///
    /// Wire layout (little-endian, byte-aligned): first_doc_id (u32),
    /// last_doc_id (u32), num_docs (u16), doc_bit_width (u8), tf_bit_width (u8),
    /// max_tf (u32), max_block_score (f32), then a u16-length-prefixed doc-delta
    /// byte array, a u8-counted list of (u8 pos, u32 value) doc exceptions, a
    /// u16-length-prefixed TF byte array, and a u8-counted list of TF exceptions.
    /// `deserialize` must read fields in exactly this order.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
        writer.write_u16::<LittleEndian>(self.num_docs)?;
        writer.write_u8(self.doc_bit_width)?;
        writer.write_u8(self.tf_bit_width)?;
        writer.write_u32::<LittleEndian>(self.max_tf)?;
        writer.write_f32::<LittleEndian>(self.max_block_score)?;

        // Write doc deltas
        // (the u16/u8 length prefixes are safe for 128-entry blocks: at most
        // 512 packed bytes and 128 exceptions per block)
        writer.write_u16::<LittleEndian>(self.doc_deltas.len() as u16)?;
        writer.write_all(&self.doc_deltas)?;

        // Write doc exceptions
        writer.write_u8(self.doc_exceptions.len() as u8)?;
        for &(pos, val) in &self.doc_exceptions {
            writer.write_u8(pos)?;
            writer.write_u32::<LittleEndian>(val)?;
        }

        // Write term freqs
        writer.write_u16::<LittleEndian>(self.term_freqs.len() as u16)?;
        writer.write_all(&self.term_freqs)?;

        // Write tf exceptions
        writer.write_u8(self.tf_exceptions.len() as u8)?;
        for &(pos, val) in &self.tf_exceptions {
            writer.write_u8(pos)?;
            writer.write_u32::<LittleEndian>(val)?;
        }

        Ok(())
    }

    /// Deserialize a block (exact mirror of `serialize`).
    ///
    /// NOTE(review): lengths/widths are trusted as read; assumes the input was
    /// produced by `serialize` — no validation is performed here.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc_id = reader.read_u32::<LittleEndian>()?;
        let last_doc_id = reader.read_u32::<LittleEndian>()?;
        let num_docs = reader.read_u16::<LittleEndian>()?;
        let doc_bit_width = reader.read_u8()?;
        let tf_bit_width = reader.read_u8()?;
        let max_tf = reader.read_u32::<LittleEndian>()?;
        let max_block_score = reader.read_f32::<LittleEndian>()?;

        // Read doc deltas
        let doc_deltas_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut doc_deltas = vec![0u8; doc_deltas_len];
        reader.read_exact(&mut doc_deltas)?;

        // Read doc exceptions
        let num_doc_exceptions = reader.read_u8()? as usize;
        let mut doc_exceptions = Vec::with_capacity(num_doc_exceptions);
        for _ in 0..num_doc_exceptions {
            let pos = reader.read_u8()?;
            let val = reader.read_u32::<LittleEndian>()?;
            doc_exceptions.push((pos, val));
        }

        // Read term freqs
        let term_freqs_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut term_freqs = vec![0u8; term_freqs_len];
        reader.read_exact(&mut term_freqs)?;

        // Read tf exceptions
        let num_tf_exceptions = reader.read_u8()? as usize;
        let mut tf_exceptions = Vec::with_capacity(num_tf_exceptions);
        for _ in 0..num_tf_exceptions {
            let pos = reader.read_u8()?;
            let val = reader.read_u32::<LittleEndian>()?;
            tf_exceptions.push((pos, val));
        }

        Ok(Self {
            first_doc_id,
            last_doc_id,
            num_docs,
            doc_bit_width,
            tf_bit_width,
            max_tf,
            max_block_score,
            doc_deltas,
            doc_exceptions,
            term_freqs,
            tf_exceptions,
        })
    }

    /// Decode doc_ids from this block using SIMD-accelerated delta decoding
    pub fn decode_doc_ids(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let count = self.num_docs as usize;
        // Sized `count` although only the first count - 1 slots carry gaps:
        // the first doc id is stored absolutely, so there are count - 1 deltas.
        let mut deltas = vec![0u32; count];

        // Unpack deltas with exceptions (SIMD-accelerated for 8/16/32-bit)
        if count > 1 {
            unpack_with_exceptions(
                &self.doc_deltas,
                self.doc_bit_width,
                &self.doc_exceptions,
                count - 1,
                &mut deltas,
            );
        }

        // Convert deltas to absolute doc_ids using SIMD-accelerated prefix sum
        let mut doc_ids = vec![0u32; count];
        delta_decode_simd(&deltas, &mut doc_ids, self.first_doc_id, count);

        doc_ids
    }

    /// Decode term frequencies from this block using SIMD acceleration
    pub fn decode_term_freqs(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let count = self.num_docs as usize;
        let mut tfs = vec![0u32; count];

        // Unpack TFs with exceptions (SIMD-accelerated for 8/16/32-bit)
        unpack_with_exceptions(
            &self.term_freqs,
            self.tf_bit_width,
            &self.tf_exceptions,
            count,
            &mut tfs,
        );

        // TF is stored as tf-1, so add 1 back using SIMD
        add_one_simd(&mut tfs, count);

        tfs
    }
}
885
/// OptP4D posting list: a sequence of compressed blocks for one term.
#[derive(Debug, Clone)]
pub struct OptP4DPostingList {
    /// Blocks of postings, each holding up to `OPT_P4D_BLOCK_SIZE` documents
    pub blocks: Vec<OptP4DBlock>,
    /// Total document count across all blocks
    pub doc_count: u32,
    /// Maximum score across all blocks (the largest `max_block_score`)
    pub max_score: f32,
}
896
897impl OptP4DPostingList {
    /// BM25F parameters for block-max score calculation:
    /// K1 controls term-frequency saturation, B controls length normalization.
    const K1: f32 = 1.2;
    const B: f32 = 0.75;
901
902    /// Compute BM25F upper bound score for a given max_tf and IDF
903    #[inline]
904    fn compute_bm25f_upper_bound(max_tf: u32, idf: f32) -> f32 {
905        let tf = max_tf as f32;
906        let min_length_norm = 1.0 - Self::B;
907        let tf_norm = (tf * (Self::K1 + 1.0)) / (tf + Self::K1 * min_length_norm);
908        idf * tf_norm
909    }
910
911    /// Create from raw doc_ids and term frequencies
912    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
913        assert_eq!(doc_ids.len(), term_freqs.len());
914
915        if doc_ids.is_empty() {
916            return Self {
917                blocks: Vec::new(),
918                doc_count: 0,
919                max_score: 0.0,
920            };
921        }
922
923        let mut blocks = Vec::new();
924        let mut max_score = 0.0f32;
925        let mut i = 0;
926
927        while i < doc_ids.len() {
928            let block_end = (i + OPT_P4D_BLOCK_SIZE).min(doc_ids.len());
929            let block_docs = &doc_ids[i..block_end];
930            let block_tfs = &term_freqs[i..block_end];
931
932            let block = Self::create_block(block_docs, block_tfs, idf);
933            max_score = max_score.max(block.max_block_score);
934            blocks.push(block);
935
936            i = block_end;
937        }
938
939        Self {
940            blocks,
941            doc_count: doc_ids.len() as u32,
942            max_score,
943        }
944    }
945
946    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> OptP4DBlock {
947        let num_docs = doc_ids.len();
948        let first_doc_id = doc_ids[0];
949        let last_doc_id = *doc_ids.last().unwrap();
950
951        // Compute deltas (delta - 1 to save one bit since deltas are always >= 1)
952        let mut deltas = Vec::with_capacity(num_docs.saturating_sub(1));
953        for j in 1..num_docs {
954            let delta = doc_ids[j] - doc_ids[j - 1] - 1;
955            deltas.push(delta);
956        }
957
958        // Find optimal bit width for deltas
959        let (doc_bit_width, _, _) = find_optimal_bit_width(&deltas);
960        let (doc_deltas, doc_exceptions) = pack_with_exceptions(&deltas, doc_bit_width);
961
962        // Compute max TF and prepare TF array (store tf-1)
963        let mut tfs = Vec::with_capacity(num_docs);
964        let mut max_tf = 0u32;
965
966        for &tf in term_freqs {
967            tfs.push(tf - 1); // Store tf-1
968            max_tf = max_tf.max(tf);
969        }
970
971        // Find optimal bit width for TFs
972        let (tf_bit_width, _, _) = find_optimal_bit_width(&tfs);
973        let (term_freqs_packed, tf_exceptions) = pack_with_exceptions(&tfs, tf_bit_width);
974
975        // BM25F upper bound score
976        let max_block_score = Self::compute_bm25f_upper_bound(max_tf, idf);
977
978        OptP4DBlock {
979            first_doc_id,
980            last_doc_id,
981            num_docs: num_docs as u16,
982            doc_bit_width,
983            tf_bit_width,
984            max_tf,
985            max_block_score,
986            doc_deltas,
987            doc_exceptions,
988            term_freqs: term_freqs_packed,
989            tf_exceptions,
990        }
991    }
992
993    /// Serialize the posting list
994    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
995        writer.write_u32::<LittleEndian>(self.doc_count)?;
996        writer.write_f32::<LittleEndian>(self.max_score)?;
997        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
998
999        for block in &self.blocks {
1000            block.serialize(writer)?;
1001        }
1002
1003        Ok(())
1004    }
1005
1006    /// Deserialize a posting list
1007    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
1008        let doc_count = reader.read_u32::<LittleEndian>()?;
1009        let max_score = reader.read_f32::<LittleEndian>()?;
1010        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;
1011
1012        let mut blocks = Vec::with_capacity(num_blocks);
1013        for _ in 0..num_blocks {
1014            blocks.push(OptP4DBlock::deserialize(reader)?);
1015        }
1016
1017        Ok(Self {
1018            blocks,
1019            doc_count,
1020            max_score,
1021        })
1022    }
1023
1024    /// Get document count
1025    pub fn len(&self) -> u32 {
1026        self.doc_count
1027    }
1028
1029    /// Check if empty
1030    pub fn is_empty(&self) -> bool {
1031        self.doc_count == 0
1032    }
1033
1034    /// Create an iterator
1035    pub fn iterator(&self) -> OptP4DIterator<'_> {
1036        OptP4DIterator::new(self)
1037    }
1038}
1039
/// Iterator over OptP4D posting list
///
/// Decodes one block at a time into scratch buffers and walks them by index.
pub struct OptP4DIterator<'a> {
    // The list being traversed
    posting_list: &'a OptP4DPostingList,
    // Index of the block whose contents are in the buffers below
    current_block: usize,
    // Decoded doc IDs for the current block
    block_doc_ids: Vec<u32>,
    // Decoded term frequencies for the current block
    block_term_freqs: Vec<u32>,
    // Cursor into the decoded buffers
    pos_in_block: usize,
    // True once iteration has moved past the final posting
    exhausted: bool,
}
1049
1050impl<'a> OptP4DIterator<'a> {
1051    pub fn new(posting_list: &'a OptP4DPostingList) -> Self {
1052        let mut iter = Self {
1053            posting_list,
1054            current_block: 0,
1055            block_doc_ids: Vec::new(),
1056            block_term_freqs: Vec::new(),
1057            pos_in_block: 0,
1058            exhausted: posting_list.blocks.is_empty(),
1059        };
1060
1061        if !iter.exhausted {
1062            iter.decode_current_block();
1063        }
1064
1065        iter
1066    }
1067
1068    fn decode_current_block(&mut self) {
1069        let block = &self.posting_list.blocks[self.current_block];
1070        self.block_doc_ids = block.decode_doc_ids();
1071        self.block_term_freqs = block.decode_term_freqs();
1072        self.pos_in_block = 0;
1073    }
1074
1075    /// Current document ID
1076    pub fn doc(&self) -> u32 {
1077        if self.exhausted {
1078            u32::MAX
1079        } else {
1080            self.block_doc_ids[self.pos_in_block]
1081        }
1082    }
1083
1084    /// Current term frequency
1085    pub fn term_freq(&self) -> u32 {
1086        if self.exhausted {
1087            0
1088        } else {
1089            self.block_term_freqs[self.pos_in_block]
1090        }
1091    }
1092
1093    /// Advance to next document
1094    pub fn advance(&mut self) -> u32 {
1095        if self.exhausted {
1096            return u32::MAX;
1097        }
1098
1099        self.pos_in_block += 1;
1100
1101        if self.pos_in_block >= self.block_doc_ids.len() {
1102            self.current_block += 1;
1103            if self.current_block >= self.posting_list.blocks.len() {
1104                self.exhausted = true;
1105                return u32::MAX;
1106            }
1107            self.decode_current_block();
1108        }
1109
1110        self.doc()
1111    }
1112
1113    /// Seek to first doc >= target
1114    pub fn seek(&mut self, target: u32) -> u32 {
1115        if self.exhausted {
1116            return u32::MAX;
1117        }
1118
1119        // Skip blocks where last_doc_id < target
1120        while self.current_block < self.posting_list.blocks.len() {
1121            let block = &self.posting_list.blocks[self.current_block];
1122            if block.last_doc_id >= target {
1123                break;
1124            }
1125            self.current_block += 1;
1126        }
1127
1128        if self.current_block >= self.posting_list.blocks.len() {
1129            self.exhausted = true;
1130            return u32::MAX;
1131        }
1132
1133        // Decode block if needed
1134        if self.block_doc_ids.is_empty() || self.current_block != self.posting_list.blocks.len() - 1
1135        {
1136            self.decode_current_block();
1137        }
1138
1139        // Binary search within block
1140        match self.block_doc_ids[self.pos_in_block..].binary_search(&target) {
1141            Ok(idx) => {
1142                self.pos_in_block += idx;
1143            }
1144            Err(idx) => {
1145                self.pos_in_block += idx;
1146                if self.pos_in_block >= self.block_doc_ids.len() {
1147                    // Move to next block
1148                    self.current_block += 1;
1149                    if self.current_block >= self.posting_list.blocks.len() {
1150                        self.exhausted = true;
1151                        return u32::MAX;
1152                    }
1153                    self.decode_current_block();
1154                }
1155            }
1156        }
1157
1158        self.doc()
1159    }
1160}
1161
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bits_needed() {
        // Table-driven: (input, expected minimum bit width).
        for (value, expected) in [
            (0, 0),
            (1, 1),
            (2, 2),
            (3, 2),
            (4, 3),
            (255, 8),
            (256, 9),
            (u32::MAX, 32),
        ] {
            assert_eq!(bits_needed(value), expected, "bits_needed({})", value);
        }
    }

    #[test]
    fn test_find_optimal_bit_width() {
        // Uniform zeros need no bits and no exceptions.
        let zeros = vec![0u32; 100];
        let (bits, exceptions, _) = find_optimal_bit_width(&zeros);
        assert_eq!(bits, 0);
        assert_eq!(exceptions, 0);

        // Values bounded by 15 fit in 4 bits.
        let small: Vec<u32> = (0..100).map(|i| i % 16).collect();
        let (bits, _, _) = find_optimal_bit_width(&small);
        assert!(bits <= 4);

        // A single huge outlier should become an exception rather than
        // inflating the main bit width.
        let mut with_outlier: Vec<u32> = (0..100).map(|i| i % 16).collect();
        with_outlier[50] = 1_000_000;
        let (bits, exceptions, _) = find_optimal_bit_width(&with_outlier);
        assert!(bits < 20);
        assert!(exceptions >= 1);
    }

    #[test]
    fn test_pack_unpack_with_exceptions() {
        // Round-trip values at 4 bits; 255 and 1000 must go through exceptions.
        let original = vec![1, 2, 3, 255, 4, 5, 1000, 6, 7, 8];
        let (packed, exceptions) = pack_with_exceptions(&original, 4);

        let mut decoded = vec![0u32; original.len()];
        unpack_with_exceptions(&packed, 4, &exceptions, original.len(), &mut decoded);

        assert_eq!(decoded, original);
    }

    #[test]
    fn test_opt_p4d_posting_list_small() {
        let doc_ids: Vec<u32> = (0..100).map(|i| i * 2).collect();
        let term_freqs = vec![1u32; 100];

        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        assert_eq!(list.len(), 100);
        assert_eq!(list.blocks.len(), 1);

        // Drain the iterator and compare against the source arrays.
        let mut iter = list.iterator();
        let mut decoded = Vec::new();
        while iter.doc() != u32::MAX {
            decoded.push(iter.doc());
            assert_eq!(iter.term_freq(), 1);
            iter.advance();
        }
        assert_eq!(decoded, doc_ids);
        assert_eq!(iter.doc(), u32::MAX);
    }

    #[test]
    fn test_opt_p4d_posting_list_large() {
        let doc_ids: Vec<u32> = (0..500).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..500).map(|i| (i % 10) + 1).collect();

        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        assert_eq!(list.len(), 500);
        // ceil(500 / 128) = 4 blocks
        assert_eq!(list.blocks.len(), 4);

        let mut iter = list.iterator();
        let mut docs = Vec::new();
        let mut tfs = Vec::new();
        while iter.doc() != u32::MAX {
            docs.push(iter.doc());
            tfs.push(iter.term_freq());
            iter.advance();
        }
        assert_eq!(docs, doc_ids);
        assert_eq!(tfs, term_freqs);
    }

    #[test]
    fn test_opt_p4d_seek() {
        let doc_ids = vec![10u32, 20, 30, 100, 200, 300, 1000, 2000];
        let list = OptP4DPostingList::from_postings(&doc_ids, &vec![1; 8], 1.0);
        let mut iter = list.iterator();

        // Monotone seeks: (target, first doc >= target).
        for (target, expected) in [(25, 30), (100, 100), (500, 1000), (3000, u32::MAX)] {
            assert_eq!(iter.seek(target), expected);
        }
    }

    #[test]
    fn test_opt_p4d_serialization() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 5).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 5) + 1).collect();

        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        // Round-trip through a byte buffer.
        let mut bytes = Vec::new();
        list.serialize(&mut bytes).unwrap();
        let restored = OptP4DPostingList::deserialize(&mut bytes.as_slice()).unwrap();

        assert_eq!(restored.len(), list.len());
        assert_eq!(restored.blocks.len(), list.blocks.len());

        // Both iterators must yield identical (doc, tf) streams.
        let mut original_iter = list.iterator();
        let mut restored_iter = restored.iterator();
        loop {
            assert_eq!(original_iter.doc(), restored_iter.doc());
            if original_iter.doc() == u32::MAX {
                break;
            }
            assert_eq!(original_iter.term_freq(), restored_iter.term_freq());
            original_iter.advance();
            restored_iter.advance();
        }
    }

    #[test]
    fn test_opt_p4d_with_outliers() {
        // One huge value forces an exception entry in the delta encoding.
        let mut doc_ids: Vec<u32> = (0..128).map(|i| i * 2).collect();
        doc_ids[64] = 1_000_000;
        doc_ids.sort(); // doc IDs must stay sorted after inserting the outlier

        let list = OptP4DPostingList::from_postings(&doc_ids, &vec![1u32; 128], 1.0);

        let mut iter = list.iterator();
        let mut seen_outlier = false;
        while iter.doc() != u32::MAX {
            seen_outlier |= iter.doc() == 1_000_000;
            iter.advance();
        }
        assert!(seen_outlier, "Outlier value should be preserved");
    }

    #[test]
    fn test_opt_p4d_simd_full_blocks() {
        // Eight exactly-full 128-integer blocks exercise the SIMD decode paths.
        let doc_ids: Vec<u32> = (0..1024).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..1024).map(|i| (i % 20) + 1).collect();

        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        assert_eq!(list.len(), 1024);
        assert_eq!(list.blocks.len(), 8);

        let mut iter = list.iterator();
        let mut docs = Vec::new();
        let mut tfs = Vec::new();
        while iter.doc() != u32::MAX {
            docs.push(iter.doc());
            tfs.push(iter.term_freq());
            iter.advance();
        }
        assert_eq!(docs, doc_ids);
        assert_eq!(tfs, term_freqs);
        assert_eq!(iter.doc(), u32::MAX);
    }

    #[test]
    fn test_opt_p4d_simd_8bit_values() {
        // Small deltas/TFs keep bit widths <= 8 and hit the 8-bit unpack path.
        let doc_ids: Vec<u32> = (0..256).collect();
        let term_freqs: Vec<u32> = (0..256).map(|i| (i % 100) + 1).collect();

        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let mut iter = list.iterator();
        let mut docs = Vec::new();
        let mut tfs = Vec::new();
        while iter.doc() != u32::MAX {
            docs.push(iter.doc());
            tfs.push(iter.term_freq());
            iter.advance();
        }
        assert_eq!(docs, doc_ids);
        assert_eq!(tfs, term_freqs);
    }

    #[test]
    fn test_opt_p4d_simd_delta_decode() {
        // Variable-size gaps (1..=10) stress the delta decoding path.
        let mut doc_ids = Vec::with_capacity(512);
        let mut cursor = 0u32;
        for step in 0..512 {
            cursor += (step % 10) + 1;
            doc_ids.push(cursor);
        }

        let list = OptP4DPostingList::from_postings(&doc_ids, &vec![1u32; 512], 1.0);

        let mut iter = list.iterator();
        let mut decoded = Vec::new();
        while iter.doc() != u32::MAX {
            decoded.push(iter.doc());
            iter.advance();
        }
        assert_eq!(decoded, doc_ids);
    }
}