hermes_core/structures/
horizontal_bp128.rs

1//! Bitpacking utilities for compact integer encoding
2//!
3//! Implements SIMD-friendly bitpacking for posting list compression.
4//! Uses PForDelta-style encoding with exceptions for outliers.
5//!
6//! Optimizations:
7//! - SIMD-accelerated unpacking (when available)
8//! - Hillis-Steele parallel prefix sum for delta decoding
9//! - Binary search within decoded blocks
10//! - Variable block sizes based on posting list length
11
12use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
13use std::io::{self, Read, Write};
14
15// ============================================================================
16// SIMD optimizations for aarch64 (Apple Silicon, ARM servers)
17// ============================================================================
18
#[cfg(target_arch = "aarch64")]
mod neon {
    use super::HORIZONTAL_BP128_BLOCK_SIZE;
    use std::arch::aarch64::*;

    /// Vectorized unpack for 8-bit values using NEON
    /// Processes 16 bytes at a time (4x u32 per iteration)
    ///
    /// # Safety
    /// `input` must hold at least `HORIZONTAL_BP128_BLOCK_SIZE` (128) bytes:
    /// the loop reads 8 chunks of 16 bytes through raw pointers with no
    /// bounds checks.
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_block_8_neon(
        input: &[u8],
        output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
    ) {
        unsafe {
            // Process 16 u8 -> 16 u32 at a time (4 NEON registers)
            for chunk in 0..8 {
                let base = chunk * 16;
                let in_ptr = input.as_ptr().add(base);

                // Load 16 bytes
                let bytes = vld1q_u8(in_ptr);

                // Widen u8 -> u16 -> u32
                // Low 8 bytes
                let low8 = vget_low_u8(bytes);
                let high8 = vget_high_u8(bytes);

                // u8 -> u16
                let low16 = vmovl_u8(low8);
                let high16 = vmovl_u8(high8);

                // u16 -> u32 (4 vectors of 4 u32 each)
                let v0 = vmovl_u16(vget_low_u16(low16));
                let v1 = vmovl_u16(vget_high_u16(low16));
                let v2 = vmovl_u16(vget_low_u16(high16));
                let v3 = vmovl_u16(vget_high_u16(high16));

                // Store 16 u32 values
                let out_ptr = output.as_mut_ptr().add(base);
                vst1q_u32(out_ptr, v0);
                vst1q_u32(out_ptr.add(4), v1);
                vst1q_u32(out_ptr.add(8), v2);
                vst1q_u32(out_ptr.add(12), v3);
            }
        }
    }

    /// Vectorized unpack for 16-bit values using NEON
    ///
    /// # Safety
    /// `input` must hold at least 256 bytes (128 u16 values). The cast to
    /// `*const u16` assumes a little-endian layout, which holds on all
    /// supported aarch64 configurations.
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_block_16_neon(
        input: &[u8],
        output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
    ) {
        unsafe {
            // Process 8 u16 -> 8 u32 at a time (2 NEON registers)
            for chunk in 0..16 {
                let base = chunk * 8;
                let in_ptr = input.as_ptr().add(base * 2) as *const u16;

                // Load 8 u16 values
                let vals = vld1q_u16(in_ptr);

                // Widen u16 -> u32
                let low = vmovl_u16(vget_low_u16(vals));
                let high = vmovl_u16(vget_high_u16(vals));

                // Store 8 u32 values
                let out_ptr = output.as_mut_ptr().add(base);
                vst1q_u32(out_ptr, low);
                vst1q_u32(out_ptr.add(4), high);
            }
        }
    }

    /// Vectorized unpack for 32-bit values using NEON (just a fast copy)
    ///
    /// # Safety
    /// `input` must hold at least 512 bytes (128 u32 values).
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_block_32_neon(
        input: &[u8],
        output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
    ) {
        unsafe {
            let in_ptr = input.as_ptr() as *const u32;
            let out_ptr = output.as_mut_ptr();

            // Copy 128 u32 values (4 at a time)
            for i in 0..32 {
                let vals = vld1q_u32(in_ptr.add(i * 4));
                vst1q_u32(out_ptr.add(i * 4), vals);
            }
        }
    }

    /// SIMD prefix sum for delta decoding
    /// Converts deltas to absolute values: output[i] = first + sum(deltas[0..i]) + i
    ///
    /// # Safety
    /// `output` must hold at least `count` values and `deltas` at least
    /// `count - 1` values; each full group loads `deltas[base..base + 4]`.
    #[target_feature(enable = "neon")]
    #[allow(dead_code)]
    pub unsafe fn delta_decode_block_neon(
        output: &mut [u32],
        deltas: &[u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        // Process in groups of 4 for SIMD prefix sum
        let mut carry = first_doc_id;
        output[0] = carry;

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            unsafe {
                // Load 4 deltas
                let d = vld1q_u32(deltas[base..].as_ptr());

                // Add 1 to each delta (since we store gap-1)
                let ones = vdupq_n_u32(1);
                let gaps = vaddq_u32(d, ones);

                // Extract lanes and compute prefix sum with carry.
                // The running carry makes this step inherently serial; only
                // the +1 bias above is done in vector form.
                let g0 = vgetq_lane_u32(gaps, 0);
                let g1 = vgetq_lane_u32(gaps, 1);
                let g2 = vgetq_lane_u32(gaps, 2);
                let g3 = vgetq_lane_u32(gaps, 3);

                let v0 = carry.wrapping_add(g0);
                let v1 = v0.wrapping_add(g1);
                let v2 = v1.wrapping_add(g2);
                let v3 = v2.wrapping_add(g3);

                // Store results
                output[base + 1] = v0;
                output[base + 2] = v1;
                output[base + 3] = v2;
                output[base + 4] = v3;

                carry = v3;
            }
        }

        // Handle remainder (fewer than 4 trailing deltas) with scalar code
        let base = full_groups * 4;
        for j in 0..remainder {
            carry = carry.wrapping_add(deltas[base + j]).wrapping_add(1);
            output[base + j + 1] = carry;
        }
    }
}
171
172// ============================================================================
173// SIMD optimizations for x86_64 (Intel/AMD)
174// ============================================================================
175
176#[cfg(target_arch = "x86_64")]
177#[allow(dead_code)]
178mod sse {
179    use super::HORIZONTAL_BP128_BLOCK_SIZE;
180    use std::arch::x86_64::*;
181
182    /// Vectorized unpack for 8-bit values using SSE
183    /// Processes 16 bytes at a time
184    #[target_feature(enable = "sse2", enable = "sse4.1")]
185    pub unsafe fn unpack_block_8_sse(
186        input: &[u8],
187        output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
188    ) {
189        // Process 16 u8 -> 16 u32 at a time
190        for chunk in 0..8 {
191            let base = chunk * 16;
192            let in_ptr = input.as_ptr().add(base);
193
194            // Load 16 bytes
195            let bytes = _mm_loadu_si128(in_ptr as *const __m128i);
196
197            // Zero extend u8 -> u32 using SSE4.1 pmovzx
198            // We need to do this in 4 steps (4 bytes at a time)
199            let v0 = _mm_cvtepu8_epi32(bytes);
200            let v1 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
201            let v2 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 8));
202            let v3 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 12));
203
204            // Store 16 u32 values
205            let out_ptr = output.as_mut_ptr().add(base);
206            _mm_storeu_si128(out_ptr as *mut __m128i, v0);
207            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
208            _mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
209            _mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
210        }
211    }
212
213    /// Vectorized unpack for 16-bit values using SSE
214    #[target_feature(enable = "sse2", enable = "sse4.1")]
215    pub unsafe fn unpack_block_16_sse(
216        input: &[u8],
217        output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
218    ) {
219        // Process 8 u16 -> 8 u32 at a time
220        for chunk in 0..16 {
221            let base = chunk * 8;
222            let in_ptr = input.as_ptr().add(base * 2);
223
224            // Load 16 bytes (8 u16 values)
225            let vals = _mm_loadu_si128(in_ptr as *const __m128i);
226
227            // Zero extend u16 -> u32
228            let low = _mm_cvtepu16_epi32(vals);
229            let high = _mm_cvtepu16_epi32(_mm_srli_si128(vals, 8));
230
231            // Store 8 u32 values
232            let out_ptr = output.as_mut_ptr().add(base);
233            _mm_storeu_si128(out_ptr as *mut __m128i, low);
234            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
235        }
236    }
237
238    /// Vectorized unpack for 32-bit values using SSE (fast copy)
239    #[target_feature(enable = "sse2")]
240    pub unsafe fn unpack_block_32_sse(
241        input: &[u8],
242        output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
243    ) {
244        let in_ptr = input.as_ptr() as *const __m128i;
245        let out_ptr = output.as_mut_ptr() as *mut __m128i;
246
247        // Copy 128 u32 values (4 at a time = 32 iterations)
248        for i in 0..32 {
249            let vals = _mm_loadu_si128(in_ptr.add(i));
250            _mm_storeu_si128(out_ptr.add(i), vals);
251        }
252    }
253
254    /// SIMD prefix sum for delta decoding using SSE
255    #[target_feature(enable = "sse2")]
256    pub unsafe fn delta_decode_block_sse(
257        output: &mut [u32],
258        deltas: &[u32],
259        first_doc_id: u32,
260        count: usize,
261    ) {
262        if count == 0 {
263            return;
264        }
265
266        // Process in groups of 4 for SIMD prefix sum
267        let mut carry = first_doc_id;
268        output[0] = carry;
269
270        let full_groups = (count - 1) / 4;
271        let remainder = (count - 1) % 4;
272
273        let ones = _mm_set1_epi32(1);
274
275        for group in 0..full_groups {
276            let base = group * 4;
277
278            // Load 4 deltas
279            let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
280
281            // Add 1 to each delta (since we store gap-1)
282            let gaps = _mm_add_epi32(d, ones);
283
284            // Extract lanes and compute prefix sum with carry
285            let g0 = _mm_extract_epi32(gaps, 0) as u32;
286            let g1 = _mm_extract_epi32(gaps, 1) as u32;
287            let g2 = _mm_extract_epi32(gaps, 2) as u32;
288            let g3 = _mm_extract_epi32(gaps, 3) as u32;
289
290            let v0 = carry.wrapping_add(g0);
291            let v1 = v0.wrapping_add(g1);
292            let v2 = v1.wrapping_add(g2);
293            let v3 = v2.wrapping_add(g3);
294
295            // Store results
296            output[base + 1] = v0;
297            output[base + 2] = v1;
298            output[base + 3] = v2;
299            output[base + 4] = v3;
300
301            carry = v3;
302        }
303
304        // Handle remainder
305        let base = full_groups * 4;
306        for j in 0..remainder {
307            carry = carry.wrapping_add(deltas[base + j]).wrapping_add(1);
308            output[base + j + 1] = carry;
309        }
310    }
311}
312
313// Scalar fallback implementations (used on non-aarch64 platforms)
314#[allow(dead_code)]
315mod scalar {
316    use super::HORIZONTAL_BP128_BLOCK_SIZE;
317
318    #[inline]
319    pub fn unpack_block_8_scalar(input: &[u8], output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
320        for (i, out) in output.iter_mut().enumerate() {
321            *out = input[i] as u32;
322        }
323    }
324
325    #[inline]
326    pub fn unpack_block_16_scalar(input: &[u8], output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
327        for (i, out) in output.iter_mut().enumerate() {
328            let idx = i * 2;
329            *out = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
330        }
331    }
332
333    #[inline]
334    pub fn unpack_block_32_scalar(input: &[u8], output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
335        for (i, out) in output.iter_mut().enumerate() {
336            let idx = i * 4;
337            *out = u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
338        }
339    }
340}
341
/// Block size for bitpacking (128 integers per block for SIMD alignment)
///
/// 128 is a multiple of every SIMD group size used by the unpackers
/// (16 values per iteration for u8, 8 for u16, 4 for u32), so the
/// vectorized paths never need a scalar tail loop.
pub const HORIZONTAL_BP128_BLOCK_SIZE: usize = 128;

/// Small block size for short posting lists (better cache locality)
pub const SMALL_BLOCK_SIZE: usize = 32;

/// Threshold for using small blocks (posting lists shorter than this use small blocks)
///
/// NOTE(review): `SMALL_BLOCK_SIZE` and `SMALL_BLOCK_THRESHOLD` are not
/// referenced in the code visible here; presumably the variable-block-size
/// path lives elsewhere — confirm before removing.
pub const SMALL_BLOCK_THRESHOLD: usize = 256;
350
/// Compute the number of bits needed to represent the maximum value
///
/// Returns 0 when `max_val` is 0 (a zero-width block stores no bytes at all).
#[inline]
pub fn bits_needed(max_val: u32) -> u8 {
    match max_val {
        0 => 0,
        v => (u32::BITS - v.leading_zeros()) as u8,
    }
}
360
/// Pack a block of 128 u32 values using the specified bit width
///
/// Values are written LSB-first into a little-endian bit stream appended to
/// `output`. Bits of a value above `bit_width` are silently dropped (the
/// caller computes `bit_width` from the block maximum via `bits_needed`).
/// A `bit_width` of 0 writes nothing: the decoder reconstructs an all-zero
/// block from the width alone.
pub fn pack_block(
    values: &[u32; HORIZONTAL_BP128_BLOCK_SIZE],
    bit_width: u8,
    output: &mut Vec<u8>,
) {
    if bit_width == 0 {
        return;
    }

    // Append exactly enough zeroed bytes for 128 values at `bit_width` bits;
    // the loop below ORs bits into this zero-initialized region.
    let bytes_needed = (HORIZONTAL_BP128_BLOCK_SIZE * bit_width as usize).div_ceil(8);
    let start = output.len();
    output.resize(start + bytes_needed, 0);

    let mut bit_pos = 0usize;
    for &value in values {
        let byte_idx = start + bit_pos / 8;
        let bit_offset = bit_pos % 8;

        // Write value across potentially multiple bytes
        let mut remaining_bits = bit_width as usize;
        let mut val = value;
        let mut current_byte_idx = byte_idx;
        let mut current_bit_offset = bit_offset;

        while remaining_bits > 0 {
            // How many of the value's remaining bits fit in the current byte.
            let bits_in_byte = (8 - current_bit_offset).min(remaining_bits);
            let mask = ((1u32 << bits_in_byte) - 1) as u8;
            output[current_byte_idx] |= ((val as u8) & mask) << current_bit_offset;
            // Consume the bits just written; continue at the next byte's bit 0.
            val >>= bits_in_byte;
            remaining_bits -= bits_in_byte;
            current_byte_idx += 1;
            current_bit_offset = 0;
        }

        bit_pos += bit_width as usize;
    }
}
399
400/// Unpack a block of 128 u32 values
401/// Uses SIMD-optimized unpacking for common bit widths on supported architectures
402pub fn unpack_block(input: &[u8], bit_width: u8, output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
403    if bit_width == 0 {
404        output.fill(0);
405        return;
406    }
407
408    // Fast path for byte-aligned bit widths with SIMD
409    match bit_width {
410        8 => unpack_block_8(input, output),
411        16 => unpack_block_16(input, output),
412        32 => unpack_block_32(input, output),
413        _ => unpack_block_generic(input, bit_width, output),
414    }
415}
416
/// Optimized unpacking for 8-bit values - uses NEON on aarch64, SSE on x86_64
///
/// `input` must hold at least 128 bytes: the SIMD paths read it through raw
/// unchecked pointers, while the scalar fallback bounds-checks each index.
#[inline]
fn unpack_block_8(input: &[u8], output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64
        unsafe { neon::unpack_block_8_neon(input, output) }
    }

    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: SSE4.1 is available on virtually all x86_64 CPUs (2006+)
        // Runtime check for older CPUs that lack SSE4.1
        if is_x86_feature_detected!("sse4.1") {
            unsafe { sse::unpack_block_8_sse(input, output) }
        } else {
            scalar::unpack_block_8_scalar(input, output)
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        scalar::unpack_block_8_scalar(input, output)
    }
}
442
/// Optimized unpacking for 16-bit values - uses NEON on aarch64, SSE on x86_64
///
/// `input` must hold at least 256 bytes (128 little-endian u16 values).
#[inline]
fn unpack_block_16(input: &[u8], output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64
        unsafe { neon::unpack_block_16_neon(input, output) }
    }

    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: SSE4.1 is available on virtually all x86_64 CPUs (2006+)
        if is_x86_feature_detected!("sse4.1") {
            unsafe { sse::unpack_block_16_sse(input, output) }
        } else {
            scalar::unpack_block_16_scalar(input, output)
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        scalar::unpack_block_16_scalar(input, output)
    }
}
467
/// Optimized unpacking for 32-bit values - uses NEON on aarch64, SSE on x86_64
///
/// `input` must hold at least 512 bytes (128 little-endian u32 values).
#[inline]
fn unpack_block_32(input: &[u8], output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64
        unsafe { neon::unpack_block_32_neon(input, output) }
    }

    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: SSE2 is always available on x86_64
        unsafe { sse::unpack_block_32_sse(input, output) }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        scalar::unpack_block_32_scalar(input, output)
    }
}
488
489/// Generic unpacking for arbitrary bit widths
490/// Optimized: reads 64 bits at a time using unaligned pointer read
491#[inline]
492fn unpack_block_generic(
493    input: &[u8],
494    bit_width: u8,
495    output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
496) {
497    let mask = (1u64 << bit_width) - 1;
498    let bit_width_usize = bit_width as usize;
499    let mut bit_pos = 0usize;
500
501    // Ensure we have enough padding for the last read
502    // Max bytes needed: (127 * 32 + 32 + 7) / 8 = 516 bytes for 32-bit width
503    // For typical widths (1-20 bits), we need much less
504    let input_ptr = input.as_ptr();
505
506    for out in output.iter_mut() {
507        let byte_idx = bit_pos >> 3; // bit_pos / 8
508        let bit_offset = bit_pos & 7; // bit_pos % 8
509
510        // SAFETY: We read up to 8 bytes. The caller guarantees input has enough data.
511        // For 128 values at max 32 bits = 512 bytes, plus up to 7 bits offset = 513 bytes max.
512        let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
513
514        *out = ((word >> bit_offset) & mask) as u32;
515        bit_pos += bit_width_usize;
516    }
517}
518
/// Unpack the first `n` values of a smaller block (for variable block sizes)
///
/// Reads a 64-bit little-endian window per value, so a value of up to 32 bits
/// at any bit offset (0-7) is always contained in a single read. A
/// `bit_width` of 0 simply zero-fills `output[..n]`.
///
/// Bug fix: the previous version did an unconditional unaligned 8-byte read
/// at every byte offset, which reads past the end of `input` for the last few
/// values whenever `input` is sized exactly to the packed data (as produced
/// by `pack_block`). The tail now falls back to a bounds-checked, zero-padded
/// copy.
///
/// # Panics
/// Panics if `output.len() < n`.
#[inline]
pub fn unpack_block_n(input: &[u8], bit_width: u8, output: &mut [u32], n: usize) {
    if bit_width == 0 {
        output[..n].fill(0);
        return;
    }

    let mask = (1u64 << bit_width) - 1;
    let bit_width_usize = bit_width as usize;
    let mut bit_pos = 0usize;
    let input_ptr = input.as_ptr();
    let len = input.len();

    for out in output[..n].iter_mut() {
        let byte_idx = bit_pos >> 3;
        let bit_offset = bit_pos & 7;

        let word = if byte_idx + 8 <= len {
            // SAFETY: byte_idx + 8 <= len, so the full 8-byte unaligned read
            // stays inside `input`.
            unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() }
        } else {
            // Tail: assemble the remaining bytes into a zero-padded u64
            // instead of reading out of bounds.
            let avail = len.saturating_sub(byte_idx);
            let mut buf = [0u8; 8];
            buf[..avail].copy_from_slice(&input[byte_idx..byte_idx + avail]);
            u64::from_le_bytes(buf)
        };

        *out = ((word >> bit_offset) & mask) as u32;
        bit_pos += bit_width_usize;
    }
}
544
/// Binary search within a decoded block to find first element >= target
/// Returns the index within the block, or block.len() if not found
#[inline]
pub fn binary_search_block(block: &[u32], target: u32) -> usize {
    // An exact hit and a miss both yield the insertion point we want.
    block.binary_search(&target).unwrap_or_else(|insert_at| insert_at)
}
554
/// Hillis-Steele inclusive prefix sum for 8 elements
/// Computes: out[i] = sum(input[0..=i])
/// This is the scalar fallback; SIMD version uses AVX2 intrinsics
#[allow(dead_code)]
#[inline]
fn prefix_sum_8(deltas: &mut [u32; 8]) {
    // log2(8) = 3 passes with doubling strides; iterating downward within a
    // pass keeps each read at the value produced by the previous pass.
    for stride in [1usize, 2, 4] {
        for i in (stride..8).rev() {
            deltas[i] = deltas[i].wrapping_add(deltas[i - stride]);
        }
    }
}
574
/// Apply prefix sum to convert deltas to absolute doc_ids
///
/// Input: deltas array where deltas[i] = doc_id[i+1] - doc_id[i] - 1
/// Output: absolute doc_ids starting from first_doc_id
///
/// Note: This uses a simple sequential algorithm. The Hillis-Steele parallel
/// prefix sum could be used for SIMD optimization but requires careful handling
/// of the delta-1 encoding and carry propagation across chunks.
#[inline]
pub fn delta_decode_block(output: &mut [u32], deltas: &[u32], first_doc_id: u32, count: usize) {
    if count == 0 {
        return;
    }

    output[0] = first_doc_id;

    let mut running = first_doc_id;
    // Each stored value is (gap - 1), so the real gap is stored + 1.
    for (slot, &stored) in output[1..count].iter_mut().zip(&deltas[..count - 1]) {
        running = running.wrapping_add(stored).wrapping_add(1);
        *slot = running;
    }
}
598
/// Bitpacked block with skip info for BlockWAND
///
/// Holds up to `HORIZONTAL_BP128_BLOCK_SIZE` postings. Doc ids are
/// delta-encoded as (gap - 1) and term frequencies stored as (tf - 1) before
/// bitpacking, so each stream uses the minimal bit width for its largest
/// value.
#[derive(Debug, Clone)]
pub struct HorizontalBP128Block {
    /// Delta-encoded doc_ids (bitpacked)
    pub doc_deltas: Vec<u8>,
    /// Bit width for doc deltas
    pub doc_bit_width: u8,
    /// Term frequencies (bitpacked)
    pub term_freqs: Vec<u8>,
    /// Bit width for term frequencies
    pub tf_bit_width: u8,
    /// First doc_id in this block (absolute)
    pub first_doc_id: u32,
    /// Last doc_id in this block (absolute)
    pub last_doc_id: u32,
    /// Number of docs in this block
    pub num_docs: u16,
    /// Maximum term frequency in this block (for BM25F upper bound calculation)
    pub max_tf: u32,
    /// Maximum impact score in this block (for MaxScore/WAND)
    /// This is computed using BM25F with conservative length normalization
    pub max_block_score: f32,
}
622
impl HorizontalBP128Block {
    /// Serialize the block
    ///
    /// Layout (all little-endian): `first_doc_id` (u32), `last_doc_id` (u32),
    /// `num_docs` (u16), `doc_bit_width` (u8), `tf_bit_width` (u8),
    /// `max_tf` (u32), `max_block_score` (f32), then a u16 length prefix plus
    /// raw bytes for `doc_deltas`, and the same for `term_freqs`.
    ///
    /// # Errors
    /// Propagates any I/O error from `writer`.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
        writer.write_u16::<LittleEndian>(self.num_docs)?;
        writer.write_u8(self.doc_bit_width)?;
        writer.write_u8(self.tf_bit_width)?;
        writer.write_u32::<LittleEndian>(self.max_tf)?;
        writer.write_f32::<LittleEndian>(self.max_block_score)?;

        // Write doc deltas
        // (A u16 length prefix suffices: a block packs at most 128 values at
        // 32 bits = 512 bytes.)
        writer.write_u16::<LittleEndian>(self.doc_deltas.len() as u16)?;
        writer.write_all(&self.doc_deltas)?;

        // Write term freqs
        writer.write_u16::<LittleEndian>(self.term_freqs.len() as u16)?;
        writer.write_all(&self.term_freqs)?;

        Ok(())
    }

    /// Deserialize a block
    ///
    /// Reads fields in the exact order written by [`Self::serialize`].
    ///
    /// # Errors
    /// Propagates any I/O error, including unexpected EOF on truncated input.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc_id = reader.read_u32::<LittleEndian>()?;
        let last_doc_id = reader.read_u32::<LittleEndian>()?;
        let num_docs = reader.read_u16::<LittleEndian>()?;
        let doc_bit_width = reader.read_u8()?;
        let tf_bit_width = reader.read_u8()?;
        let max_tf = reader.read_u32::<LittleEndian>()?;
        let max_block_score = reader.read_f32::<LittleEndian>()?;

        let doc_deltas_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut doc_deltas = vec![0u8; doc_deltas_len];
        reader.read_exact(&mut doc_deltas)?;

        let term_freqs_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut term_freqs = vec![0u8; term_freqs_len];
        reader.read_exact(&mut term_freqs)?;

        Ok(Self {
            doc_deltas,
            doc_bit_width,
            term_freqs,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs,
            max_tf,
            max_block_score,
        })
    }

    /// Decode doc_ids from this block
    ///
    /// Unpacks the (gap - 1) deltas into a full-size stack scratch buffer,
    /// then prefix-sums them starting from `first_doc_id`. Returns exactly
    /// `num_docs` absolute doc ids.
    pub fn decode_doc_ids(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let count = self.num_docs as usize;
        let mut deltas = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
        unpack_block(&self.doc_deltas, self.doc_bit_width, &mut deltas);

        let mut output = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
        delta_decode_block(&mut output, &deltas, self.first_doc_id, count);

        output[..count].to_vec()
    }

    /// Decode term frequencies from this block
    ///
    /// Frequencies are packed as (tf - 1), so 1 is added back after
    /// unpacking; all returned values are therefore >= 1.
    pub fn decode_term_freqs(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let mut tfs = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
        unpack_block(&self.term_freqs, self.tf_bit_width, &mut tfs);

        // TF is stored as tf-1, so add 1 back
        tfs[..self.num_docs as usize]
            .iter()
            .map(|&tf| tf + 1)
            .collect()
    }
}
708
/// Bitpacked posting list with block-level skip info
///
/// Postings are split into [`HorizontalBP128Block`]s of up to 128 entries
/// each; the per-block metadata enables block skipping and score-based
/// pruning (MaxScore/WAND).
#[derive(Debug, Clone)]
pub struct HorizontalBP128PostingList {
    /// Blocks of postings
    pub blocks: Vec<HorizontalBP128Block>,
    /// Total document count
    pub doc_count: u32,
    /// Maximum score across all blocks (for MaxScore pruning)
    pub max_score: f32,
}
719
impl HorizontalBP128PostingList {
    /// Create from raw doc_ids and term frequencies
    ///
    /// Splits the postings into fixed-size blocks of
    /// `HORIZONTAL_BP128_BLOCK_SIZE` and records per-block plus global BM25F
    /// upper bounds for MaxScore/WAND pruning.
    ///
    /// Assumes `doc_ids` is strictly increasing and every term frequency is
    /// >= 1 (both are encoded with a -1 bias in `create_block`).
    ///
    /// # Panics
    /// Panics if `doc_ids` and `term_freqs` differ in length.
    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
        assert_eq!(doc_ids.len(), term_freqs.len());

        if doc_ids.is_empty() {
            return Self {
                blocks: Vec::new(),
                doc_count: 0,
                max_score: 0.0,
            };
        }

        let mut blocks = Vec::new();
        let mut max_score = 0.0f32;
        let mut i = 0;

        // Chop the postings into full 128-entry blocks; the last block may be
        // shorter.
        while i < doc_ids.len() {
            let block_end = (i + HORIZONTAL_BP128_BLOCK_SIZE).min(doc_ids.len());
            let block_docs = &doc_ids[i..block_end];
            let block_tfs = &term_freqs[i..block_end];

            let block = Self::create_block(block_docs, block_tfs, idf);
            max_score = max_score.max(block.max_block_score);
            blocks.push(block);

            i = block_end;
        }

        Self {
            blocks,
            doc_count: doc_ids.len() as u32,
            max_score,
        }
    }

    /// BM25F parameters for block-max score calculation
    const K1: f32 = 1.2;
    const B: f32 = 0.75;

    /// Compute BM25F upper bound score for a given max_tf and IDF
    /// Uses conservative length normalization (assumes shortest possible document)
    #[inline]
    pub fn compute_bm25f_upper_bound(max_tf: u32, idf: f32, field_boost: f32) -> f32 {
        let tf = max_tf as f32;
        // Conservative upper bound: assume dl=0, so length_norm = 1 - b = 0.25
        // This gives the maximum possible score for this tf
        let min_length_norm = 1.0 - Self::B;
        let tf_norm =
            (tf * field_boost * (Self::K1 + 1.0)) / (tf * field_boost + Self::K1 * min_length_norm);
        idf * tf_norm
    }

    /// Build a single block from up to 128 postings.
    ///
    /// Doc ids are delta-encoded as (gap - 1) and term frequencies stored as
    /// (tf - 1) so each stream packs with the minimal bit width.
    /// NOTE(review): `doc_ids[j] - doc_ids[j-1] - 1` and `tf - 1` underflow
    /// (debug panic / release wrap) if ids are not strictly increasing or a
    /// tf is 0 — callers are expected to guarantee both.
    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> HorizontalBP128Block {
        let num_docs = doc_ids.len();
        let first_doc_id = doc_ids[0];
        let last_doc_id = *doc_ids.last().unwrap();

        // Compute deltas (delta - 1 to save one bit since deltas are always >= 1)
        let mut deltas = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
        let mut max_delta = 0u32;
        for j in 1..num_docs {
            let delta = doc_ids[j] - doc_ids[j - 1] - 1;
            deltas[j - 1] = delta;
            max_delta = max_delta.max(delta);
        }

        // Compute max TF and prepare TF array (store tf-1)
        let mut tfs = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
        let mut max_tf = 0u32;

        for (j, &tf) in term_freqs.iter().enumerate() {
            tfs[j] = tf - 1; // Store tf-1
            max_tf = max_tf.max(tf);
        }

        // BM25F upper bound score using conservative length normalization
        // field_boost defaults to 1.0 at index time; can be adjusted at query time
        let max_block_score = Self::compute_bm25f_upper_bound(max_tf, idf, 1.0);

        let doc_bit_width = bits_needed(max_delta);
        let tf_bit_width = bits_needed(max_tf.saturating_sub(1)); // Store tf-1

        let mut doc_deltas = Vec::new();
        pack_block(&deltas, doc_bit_width, &mut doc_deltas);

        let mut term_freqs_packed = Vec::new();
        pack_block(&tfs, tf_bit_width, &mut term_freqs_packed);

        HorizontalBP128Block {
            doc_deltas,
            doc_bit_width,
            term_freqs: term_freqs_packed,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs: num_docs as u16,
            max_tf,
            max_block_score,
        }
    }

    /// Serialize the posting list
    ///
    /// Layout (little-endian): `doc_count` (u32), `max_score` (f32), block
    /// count (u32), then each block via [`HorizontalBP128Block::serialize`].
    ///
    /// # Errors
    /// Propagates any I/O error from `writer`.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_f32::<LittleEndian>(self.max_score)?;
        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;

        for block in &self.blocks {
            block.serialize(writer)?;
        }

        Ok(())
    }

    /// Deserialize a posting list
    ///
    /// Reads data in the exact order written by [`Self::serialize`].
    ///
    /// # Errors
    /// Propagates any I/O error, including unexpected EOF on truncated input.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let max_score = reader.read_f32::<LittleEndian>()?;
        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;

        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            blocks.push(HorizontalBP128Block::deserialize(reader)?);
        }

        Ok(Self {
            blocks,
            doc_count,
            max_score,
        })
    }

    /// Create an iterator over the postings (decodes one block at a time)
    pub fn iterator(&self) -> HorizontalBP128Iterator<'_> {
        HorizontalBP128Iterator::new(self)
    }
}
858
/// Iterator over bitpacked posting list with block skipping support
///
/// Decodes one block at a time into owned vectors and walks them by index;
/// `seek` can skip whole blocks using each block's doc-id bounds.
pub struct HorizontalBP128Iterator<'a> {
    posting_list: &'a HorizontalBP128PostingList,
    /// Current block index
    current_block: usize,
    /// Decoded doc_ids for current block
    block_doc_ids: Vec<u32>,
    /// Decoded term freqs for current block
    block_term_freqs: Vec<u32>,
    /// Position within current block
    pos_in_block: usize,
    /// Whether we've exhausted all postings
    exhausted: bool,
}
873
874impl<'a> HorizontalBP128Iterator<'a> {
875    pub fn new(posting_list: &'a HorizontalBP128PostingList) -> Self {
876        let mut iter = Self {
877            posting_list,
878            current_block: 0,
879            block_doc_ids: Vec::new(),
880            block_term_freqs: Vec::new(),
881            pos_in_block: 0,
882            exhausted: posting_list.blocks.is_empty(),
883        };
884
885        if !iter.exhausted {
886            iter.decode_current_block();
887        }
888
889        iter
890    }
891
892    fn decode_current_block(&mut self) {
893        let block = &self.posting_list.blocks[self.current_block];
894        self.block_doc_ids = block.decode_doc_ids();
895        self.block_term_freqs = block.decode_term_freqs();
896        self.pos_in_block = 0;
897    }
898
899    /// Current document ID
900    pub fn doc(&self) -> u32 {
901        if self.exhausted {
902            u32::MAX
903        } else {
904            self.block_doc_ids[self.pos_in_block]
905        }
906    }
907
908    /// Current term frequency
909    pub fn term_freq(&self) -> u32 {
910        if self.exhausted {
911            0
912        } else {
913            self.block_term_freqs[self.pos_in_block]
914        }
915    }
916
917    /// Advance to next document
918    pub fn advance(&mut self) -> u32 {
919        if self.exhausted {
920            return u32::MAX;
921        }
922
923        self.pos_in_block += 1;
924
925        if self.pos_in_block >= self.block_doc_ids.len() {
926            self.current_block += 1;
927            if self.current_block >= self.posting_list.blocks.len() {
928                self.exhausted = true;
929                return u32::MAX;
930            }
931            self.decode_current_block();
932        }
933
934        self.doc()
935    }
936
937    /// Seek to first doc >= target (with block skipping and binary search)
938    pub fn seek(&mut self, target: u32) -> u32 {
939        if self.exhausted {
940            return u32::MAX;
941        }
942
943        // Binary search to find the right block
944        let block_idx = self.posting_list.blocks[self.current_block..].binary_search_by(|block| {
945            if block.last_doc_id < target {
946                std::cmp::Ordering::Less
947            } else if block.first_doc_id > target {
948                std::cmp::Ordering::Greater
949            } else {
950                std::cmp::Ordering::Equal
951            }
952        });
953
954        let target_block = match block_idx {
955            Ok(idx) => self.current_block + idx,
956            Err(idx) => {
957                if self.current_block + idx >= self.posting_list.blocks.len() {
958                    self.exhausted = true;
959                    return u32::MAX;
960                }
961                self.current_block + idx
962            }
963        };
964
965        // Move to target block if different
966        if target_block != self.current_block {
967            self.current_block = target_block;
968            self.decode_current_block();
969        } else if self.block_doc_ids.is_empty() {
970            self.decode_current_block();
971        }
972
973        // Binary search within the block
974        let pos = binary_search_block(&self.block_doc_ids[self.pos_in_block..], target);
975        self.pos_in_block += pos;
976
977        if self.pos_in_block >= self.block_doc_ids.len() {
978            // Target not in this block, move to next
979            self.current_block += 1;
980            if self.current_block >= self.posting_list.blocks.len() {
981                self.exhausted = true;
982                return u32::MAX;
983            }
984            self.decode_current_block();
985        }
986
987        self.doc()
988    }
989
990    /// Get max score for remaining blocks (for MaxScore optimization)
991    pub fn max_remaining_score(&self) -> f32 {
992        if self.exhausted {
993            return 0.0;
994        }
995
996        self.posting_list.blocks[self.current_block..]
997            .iter()
998            .map(|b| b.max_block_score)
999            .fold(0.0f32, |a, b| a.max(b))
1000    }
1001
1002    /// Skip to next block (for BlockWAND)
1003    pub fn skip_to_block_with_doc(&mut self, target: u32) -> Option<(u32, f32)> {
1004        while self.current_block < self.posting_list.blocks.len() {
1005            let block = &self.posting_list.blocks[self.current_block];
1006            if block.last_doc_id >= target {
1007                return Some((block.first_doc_id, block.max_block_score));
1008            }
1009            self.current_block += 1;
1010        }
1011        self.exhausted = true;
1012        None
1013    }
1014
1015    /// Get current block's max score
1016    pub fn current_block_max_score(&self) -> f32 {
1017        if self.exhausted {
1018            0.0
1019        } else {
1020            self.posting_list.blocks[self.current_block].max_block_score
1021        }
1022    }
1023
1024    /// Get current block's max term frequency (for BM25F upper bound recalculation)
1025    pub fn current_block_max_tf(&self) -> u32 {
1026        if self.exhausted {
1027            0
1028        } else {
1029            self.posting_list.blocks[self.current_block].max_tf
1030        }
1031    }
1032}
1033
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bits_needed() {
        assert_eq!(bits_needed(0), 0);
        assert_eq!(bits_needed(1), 1);
        assert_eq!(bits_needed(2), 2);
        assert_eq!(bits_needed(3), 2);
        assert_eq!(bits_needed(255), 8);
        assert_eq!(bits_needed(256), 9);
    }

    #[test]
    fn test_pack_unpack() {
        let mut values = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
        for (i, value) in values.iter_mut().enumerate() {
            *value = (i * 3) as u32;
        }

        let max_val = values.iter().max().copied().unwrap();
        let bit_width = bits_needed(max_val);

        let mut packed = Vec::new();
        pack_block(&values, bit_width, &mut packed);

        let mut unpacked = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
        unpack_block(&packed, bit_width, &mut unpacked);

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_bitpacked_posting_list() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 10) + 1).collect();

        let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        assert_eq!(posting_list.doc_count, 200);
        assert_eq!(posting_list.blocks.len(), 2); // 128 + 72

        // Test iteration
        let mut iter = posting_list.iterator();
        for (i, &expected_doc) in doc_ids.iter().enumerate() {
            assert_eq!(iter.doc(), expected_doc, "Mismatch at position {}", i);
            assert_eq!(iter.term_freq(), term_freqs[i]);
            if i < doc_ids.len() - 1 {
                iter.advance();
            }
        }

        // Advancing past the last posting must report exhaustion.
        assert_eq!(iter.advance(), u32::MAX);
    }

    #[test]
    fn test_bitpacked_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = posting_list.iterator();

        assert_eq!(iter.seek(25), 30);
        assert_eq!(iter.seek(100), 100);
        assert_eq!(iter.seek(500), 1000);
        assert_eq!(iter.seek(3000), u32::MAX);
    }

    #[test]
    fn test_serialization() {
        let doc_ids: Vec<u32> = (0..50).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..50).map(|_| 1).collect();

        let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.5);

        let mut buffer = Vec::new();
        posting_list.serialize(&mut buffer).unwrap();

        let restored = HorizontalBP128PostingList::deserialize(&mut &buffer[..]).unwrap();

        assert_eq!(restored.doc_count, posting_list.doc_count);
        // max_score must survive the round trip bit-for-bit.
        assert_eq!(restored.max_score, posting_list.max_score);
        assert_eq!(restored.blocks.len(), posting_list.blocks.len());

        // Verify iteration produces same results
        let mut iter1 = posting_list.iterator();
        let mut iter2 = restored.iterator();

        while iter1.doc() != u32::MAX {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
        // The restored list must exhaust at the same point — no extra postings.
        assert_eq!(iter2.doc(), u32::MAX);
    }

    #[test]
    fn test_hillis_steele_prefix_sum() {
        // Test the prefix_sum_8 function directly
        let mut deltas = [1u32, 2, 3, 4, 5, 6, 7, 8];
        prefix_sum_8(&mut deltas);
        // Expected: [1, 1+2, 1+2+3, 1+2+3+4, ...]
        assert_eq!(deltas, [1, 3, 6, 10, 15, 21, 28, 36]);

        // Test delta_decode_block
        let deltas2 = [0u32; 16]; // gaps of 1 (stored as 0)
        let mut output2 = [0u32; 16];
        delta_decode_block(&mut output2, &deltas2, 100, 8);
        // first_doc_id=100, then +1 each
        assert_eq!(&output2[..8], &[100, 101, 102, 103, 104, 105, 106, 107]);

        // Test with varying deltas (stored as gap-1)
        // gaps: 2, 1, 3, 1, 5, 1, 1 → stored as: 1, 0, 2, 0, 4, 0, 0
        let deltas3 = [1u32, 0, 2, 0, 4, 0, 0, 0];
        let mut output3 = [0u32; 8];
        delta_decode_block(&mut output3, &deltas3, 10, 8);
        // 10, 10+2=12, 12+1=13, 13+3=16, 16+1=17, 17+5=22, 22+1=23, 23+1=24
        assert_eq!(&output3[..8], &[10, 12, 13, 16, 17, 22, 23, 24]);
    }

    #[test]
    fn test_delta_decode_large_block() {
        // Test with a full 128-element block
        let doc_ids: Vec<u32> = (0..128).map(|i| i * 5 + 100).collect();
        let term_freqs: Vec<u32> = vec![1; 128];

        let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let decoded = posting_list.blocks[0].decode_doc_ids();

        assert_eq!(decoded.len(), 128);
        for (i, (&expected, &actual)) in doc_ids.iter().zip(decoded.iter()).enumerate() {
            assert_eq!(expected, actual, "Mismatch at position {}", i);
        }
    }
}