// hermes_core/structures/simd_bp128.rs

//! SIMD-BP128: Vectorized bitpacking with NEON/SSE intrinsics
//!
//! Based on Lemire & Boytsov (2015), "Decoding billions of integers per second
//! through vectorization", and on Quickwit's bitpacking crate architecture.
//!
//! Key optimizations:
//! - **True vertical layout**: Optimal compression (BLOCK_SIZE * bit_width / 8 bytes)
//! - **Integrated delta decoding**: Fused unpack + prefix sum in a single pass
//! - **128-integer blocks**: 32 groups of 4 integers each
//! - **NEON intrinsics on ARM**: Uses vld1q_u32, vaddq_u32, etc.
//! - **Block-level metadata**: Skip info for BlockMax WAND

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Read, Write};

/// Block size: 128 integers (32 groups of 4 for SIMD lanes)
pub const SIMD_BLOCK_SIZE: usize = 128;

/// Number of 32-bit lanes in NEON/SSE (4 x 32-bit = 128-bit)
#[allow(dead_code)]
const SIMD_LANES: usize = 4;

/// Number of groups per block (128 / 4 = 32)
#[allow(dead_code)]
const GROUPS_PER_BLOCK: usize = SIMD_BLOCK_SIZE / SIMD_LANES;

/// Compute the number of bits needed to represent `max_val`
/// (0 for an all-zero block, which needs no payload bytes)
#[inline]
pub fn bits_needed(max_val: u32) -> u8 {
    if max_val == 0 {
        0
    } else {
        32 - max_val.leading_zeros() as u8
    }
}
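
// Illustrative values, following directly from the definition above:
// bits_needed(1) == 1, bits_needed(255) == 8, bits_needed(256) == 9,
// bits_needed(u32::MAX) == 32.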

// ============================================================================
// NEON intrinsics for aarch64 (Apple Silicon, ARM servers)
// ============================================================================

#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
mod neon {
    use super::*;
    use std::arch::aarch64::*;

    /// Lookup table expanding a byte into 8 u32 values (one per bit):
    /// BIT_EXPAND_LUT[byte][bit_position] = (byte >> bit_position) & 1.
    /// Loading 8 precomputed bit values at once lets the unpacker avoid
    /// per-bit shifting and masking.
    static BIT_EXPAND_LUT: [[u32; 8]; 256] = {
        let mut lut = [[0u32; 8]; 256];
        let mut byte = 0usize;
        while byte < 256 {
            let mut bit = 0;
            while bit < 8 {
                lut[byte][bit] = ((byte >> bit) & 1) as u32;
                bit += 1;
            }
            byte += 1;
        }
        lut
    };

    /// Unpack 4 u32 values from packed data using NEON
    #[inline]
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_4_neon(input: &[u8], bit_width: u8, output: &mut [u32; 4]) {
        if bit_width == 0 {
            *output = [0; 4];
            return;
        }

        let mask = (1u32 << bit_width) - 1;

        // Load packed data
        let mut packed_bytes = [0u8; 16];
        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        packed_bytes[..bytes_needed.min(16)].copy_from_slice(&input[..bytes_needed.min(16)]);
        let packed = u128::from_le_bytes(packed_bytes);

        // Extract 4 values
        let v0 = (packed & mask as u128) as u32;
        let v1 = ((packed >> bit_width) & mask as u128) as u32;
        let v2 = ((packed >> (bit_width * 2)) & mask as u128) as u32;
        let v3 = ((packed >> (bit_width * 3)) & mask as u128) as u32;

        // Store using NEON
        unsafe {
            let result = vld1q_u32([v0, v1, v2, v3].as_ptr());
            vst1q_u32(output.as_mut_ptr(), result);
        }
    }

    /// SIMD prefix sum for 4 elements using NEON
    #[inline]
    #[target_feature(enable = "neon")]
    pub unsafe fn prefix_sum_4_neon(values: &mut [u32; 4]) {
        unsafe {
            // Load values
            let mut v = vld1q_u32(values.as_ptr());

            // Prefix sum using NEON shuffles and adds, starting from
            // v = [a, b, c, d].
            // Step 1: shift in a zero and add, giving v = [a, a+b, b+c, c+d]
            let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3); // [0, a, b, c]
            v = vaddq_u32(v, shifted1);

            // Step 2: shift by two lanes and add, giving
            // v = [a, a+b, a+b+c, a+b+c+d]
            let shifted2 = vextq_u32(vdupq_n_u32(0), v, 2); // [0, 0, a, a+b]
            v = vaddq_u32(v, shifted2);

            // Store result
            vst1q_u32(values.as_mut_ptr(), v);
        }
    }

    /// Unpack 128 integers from true vertical layout using NEON (optimized)
    ///
    /// Optimizations:
    /// 1. Lookup table for bit extraction (avoids per-bit shifts)
    /// 2. Process 4 bytes at once (32 integers per iteration)
    /// 3. Prefetch next bit position's data
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_block_neon(
        input: &[u8],
        bit_width: u8,
        output: &mut [u32; SIMD_BLOCK_SIZE],
    ) {
        if bit_width == 0 {
            output.fill(0);
            return;
        }

        // Clear output using NEON
        unsafe {
            let zero = vdupq_n_u32(0);
            for i in (0..SIMD_BLOCK_SIZE).step_by(4) {
                vst1q_u32(output[i..].as_mut_ptr(), zero);
            }
        }

        // For each bit position, scatter that bit to all 128 integers
        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;
            let bit_mask = 1u32 << bit_pos;

            // Prefetch next bit position's data (if not last)
            if bit_pos + 1 < bit_width as usize {
                let next_offset = (bit_pos + 1) * 16;
                unsafe {
                    // Use inline asm for prefetch on aarch64
                    std::arch::asm!(
                        "prfm pldl1keep, [{0}]",
                        in(reg) input.as_ptr().add(next_offset),
                        options(nostack, preserves_flags)
                    );
                }
            }

            // Process 4 bytes at a time (32 integers)
            for chunk in 0..4 {
                let chunk_offset = byte_offset + chunk * 4;

                // Load 4 bytes at once
                let b0 = input[chunk_offset] as usize;
                let b1 = input[chunk_offset + 1] as usize;
                let b2 = input[chunk_offset + 2] as usize;
                let b3 = input[chunk_offset + 3] as usize;

                let base_int = chunk * 32;

                unsafe {
                    let mask_vec = vdupq_n_u32(bit_mask);

                    // Process byte 0 (integers 0-7)
                    let lut0 = &BIT_EXPAND_LUT[b0];
                    let bits_0_3 = vld1q_u32(lut0.as_ptr());
                    let bits_4_7 = vld1q_u32(lut0[4..].as_ptr());

                    let shifted_0_3 = vmulq_u32(bits_0_3, mask_vec);
                    let shifted_4_7 = vmulq_u32(bits_4_7, mask_vec);

                    let cur_0_3 = vld1q_u32(output[base_int..].as_ptr());
                    let cur_4_7 = vld1q_u32(output[base_int + 4..].as_ptr());

                    vst1q_u32(
                        output[base_int..].as_mut_ptr(),
                        vorrq_u32(cur_0_3, shifted_0_3),
                    );
                    vst1q_u32(
                        output[base_int + 4..].as_mut_ptr(),
                        vorrq_u32(cur_4_7, shifted_4_7),
                    );

                    // Process byte 1 (integers 8-15)
                    let lut1 = &BIT_EXPAND_LUT[b1];
                    let bits_8_11 = vld1q_u32(lut1.as_ptr());
                    let bits_12_15 = vld1q_u32(lut1[4..].as_ptr());

                    let shifted_8_11 = vmulq_u32(bits_8_11, mask_vec);
                    let shifted_12_15 = vmulq_u32(bits_12_15, mask_vec);

                    let cur_8_11 = vld1q_u32(output[base_int + 8..].as_ptr());
                    let cur_12_15 = vld1q_u32(output[base_int + 12..].as_ptr());

                    vst1q_u32(
                        output[base_int + 8..].as_mut_ptr(),
                        vorrq_u32(cur_8_11, shifted_8_11),
                    );
                    vst1q_u32(
                        output[base_int + 12..].as_mut_ptr(),
                        vorrq_u32(cur_12_15, shifted_12_15),
                    );

                    // Process byte 2 (integers 16-23)
                    let lut2 = &BIT_EXPAND_LUT[b2];
                    let bits_16_19 = vld1q_u32(lut2.as_ptr());
                    let bits_20_23 = vld1q_u32(lut2[4..].as_ptr());

                    let shifted_16_19 = vmulq_u32(bits_16_19, mask_vec);
                    let shifted_20_23 = vmulq_u32(bits_20_23, mask_vec);

                    let cur_16_19 = vld1q_u32(output[base_int + 16..].as_ptr());
                    let cur_20_23 = vld1q_u32(output[base_int + 20..].as_ptr());

                    vst1q_u32(
                        output[base_int + 16..].as_mut_ptr(),
                        vorrq_u32(cur_16_19, shifted_16_19),
                    );
                    vst1q_u32(
                        output[base_int + 20..].as_mut_ptr(),
                        vorrq_u32(cur_20_23, shifted_20_23),
                    );

                    // Process byte 3 (integers 24-31)
                    let lut3 = &BIT_EXPAND_LUT[b3];
                    let bits_24_27 = vld1q_u32(lut3.as_ptr());
                    let bits_28_31 = vld1q_u32(lut3[4..].as_ptr());

                    let shifted_24_27 = vmulq_u32(bits_24_27, mask_vec);
                    let shifted_28_31 = vmulq_u32(bits_28_31, mask_vec);

                    let cur_24_27 = vld1q_u32(output[base_int + 24..].as_ptr());
                    let cur_28_31 = vld1q_u32(output[base_int + 28..].as_ptr());

                    vst1q_u32(
                        output[base_int + 24..].as_mut_ptr(),
                        vorrq_u32(cur_24_27, shifted_24_27),
                    );
                    vst1q_u32(
                        output[base_int + 28..].as_mut_ptr(),
                        vorrq_u32(cur_28_31, shifted_28_31),
                    );
                }
            }
        }
    }

    /// Prefix sum for 128 elements using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn prefix_sum_block_neon(deltas: &mut [u32; SIMD_BLOCK_SIZE], first_val: u32) {
        let mut carry = first_val;

        for group in 0..GROUPS_PER_BLOCK {
            let start = group * SIMD_LANES;
            let mut group_vals = [
                deltas[start],
                deltas[start + 1],
                deltas[start + 2],
                deltas[start + 3],
            ];

            // Add carry to first element
            group_vals[0] = group_vals[0].wrapping_add(carry);

            // SIMD prefix sum
            unsafe { prefix_sum_4_neon(&mut group_vals) };

            // Write back
            deltas[start..start + 4].copy_from_slice(&group_vals);

            // Carry for next group
            carry = group_vals[3];
        }
    }
}

// ============================================================================
// Scalar fallback for other architectures
// ============================================================================

#[allow(dead_code)]
mod scalar {
    use super::*;

    /// Pack 4 u32 values into output
    #[inline]
    pub fn pack_4_scalar(values: &[u32; 4], bit_width: u8, output: &mut [u8]) {
        if bit_width == 0 {
            return;
        }

        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        let mut packed = 0u128;
        for (i, &val) in values.iter().enumerate() {
            packed |= (val as u128) << (i * bit_width as usize);
        }

        let packed_bytes = packed.to_le_bytes();
        output[..bytes_needed].copy_from_slice(&packed_bytes[..bytes_needed]);
    }

    /// Unpack 4 u32 values from packed data
    #[inline]
    pub fn unpack_4_scalar(input: &[u8], bit_width: u8, output: &mut [u32; 4]) {
        if bit_width == 0 {
            *output = [0; 4];
            return;
        }

        let mask = (1u32 << bit_width) - 1;
        let mut packed_bytes = [0u8; 16];
        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        packed_bytes[..bytes_needed.min(16)].copy_from_slice(&input[..bytes_needed.min(16)]);
        let packed = u128::from_le_bytes(packed_bytes);

        output[0] = (packed & mask as u128) as u32;
        output[1] = ((packed >> bit_width) & mask as u128) as u32;
        output[2] = ((packed >> (bit_width * 2)) & mask as u128) as u32;
        output[3] = ((packed >> (bit_width * 3)) & mask as u128) as u32;
    }

    /// Prefix sum for 4 elements
    #[inline]
    pub fn prefix_sum_4_scalar(vals: &mut [u32; 4]) {
        vals[1] = vals[1].wrapping_add(vals[0]);
        vals[2] = vals[2].wrapping_add(vals[1]);
        vals[3] = vals[3].wrapping_add(vals[2]);
    }

    /// Unpack 128 integers from true vertical layout
    pub fn unpack_block_scalar(input: &[u8], bit_width: u8, output: &mut [u32; SIMD_BLOCK_SIZE]) {
        if bit_width == 0 {
            output.fill(0);
            return;
        }

        // Clear output first
        output.fill(0);

        // Unpack from vertical bit-interleaved layout
        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16; // 128/8 = 16 bytes per bit position

            for byte_idx in 0..16 {
                let byte_val = input[byte_offset + byte_idx];
                let base_int = byte_idx * 8;

                // Extract 8 bits from this byte
                output[base_int] |= (byte_val & 1) as u32 * (1 << bit_pos);
                output[base_int + 1] |= ((byte_val >> 1) & 1) as u32 * (1 << bit_pos);
                output[base_int + 2] |= ((byte_val >> 2) & 1) as u32 * (1 << bit_pos);
                output[base_int + 3] |= ((byte_val >> 3) & 1) as u32 * (1 << bit_pos);
                output[base_int + 4] |= ((byte_val >> 4) & 1) as u32 * (1 << bit_pos);
                output[base_int + 5] |= ((byte_val >> 5) & 1) as u32 * (1 << bit_pos);
                output[base_int + 6] |= ((byte_val >> 6) & 1) as u32 * (1 << bit_pos);
                output[base_int + 7] |= ((byte_val >> 7) & 1) as u32 * (1 << bit_pos);
            }
        }
    }

    /// Prefix sum for 128 elements
    pub fn prefix_sum_block_scalar(deltas: &mut [u32; SIMD_BLOCK_SIZE], first_val: u32) {
        let mut carry = first_val;

        for group in 0..GROUPS_PER_BLOCK {
            let start = group * SIMD_LANES;
            let mut group_vals = [
                deltas[start],
                deltas[start + 1],
                deltas[start + 2],
                deltas[start + 3],
            ];

            group_vals[0] = group_vals[0].wrapping_add(carry);
            prefix_sum_4_scalar(&mut group_vals);
            deltas[start..start + 4].copy_from_slice(&group_vals);
            carry = group_vals[3];
        }
    }
}

// ============================================================================
// Public API - dispatches to NEON or scalar
// ============================================================================

/// Pack 128 integers using true vertical layout (optimal compression)
///
/// Vertical layout stores bit i of all 128 integers together.
/// Total size: exactly BLOCK_SIZE * bit_width / 8 bytes (no padding waste).
/// Note: despite the name, this produces the vertical layout; `pack_vertical`
/// below is an alias kept for compatibility.
pub fn pack_horizontal(values: &[u32; SIMD_BLOCK_SIZE], bit_width: u8, output: &mut Vec<u8>) {
    if bit_width == 0 {
        return;
    }

    // True vertical layout: exactly (128 * bit_width) / 8 bytes
    let total_bytes = (SIMD_BLOCK_SIZE * bit_width as usize) / 8;
    let start = output.len();
    output.resize(start + total_bytes, 0);

    // Pack using vertical bit-interleaved layout:
    // for each bit position, pack that bit from all 128 integers.
    for bit_pos in 0..bit_width as usize {
        let byte_offset = start + bit_pos * (SIMD_BLOCK_SIZE / 8);
        for (int_idx, &val) in values.iter().enumerate() {
            let bit = (val >> bit_pos) & 1;
            let byte_idx = byte_offset + int_idx / 8;
            let bit_in_byte = int_idx % 8;
            output[byte_idx] |= (bit as u8) << bit_in_byte;
        }
    }
}
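
// Layout sketch (illustrative, following directly from the loop above): bit
// `b` of integer `i` lands in byte `b * 16 + i / 8` relative to the block
// start, at bit position `i % 8`, so each bit plane occupies 16 contiguous
// bytes (128 integers / 8 bits per byte).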

/// Unpack 128 integers from true vertical layout (optimized)
///
/// This is the inverse of pack_horizontal - extracts bits from the vertical
/// layout, processing the 8 integers covered by each byte at a time (the
/// NEON path in `unpack_block_neon` additionally uses a lookup table).
pub fn unpack_horizontal(input: &[u8], bit_width: u8, output: &mut [u32; SIMD_BLOCK_SIZE]) {
    if bit_width == 0 {
        output.fill(0);
        return;
    }

    // Clear output first
    output.fill(0);

    // One byte holds bit `bit_pos` of 8 consecutive integers, so each bit
    // position spans 16 bytes (128 integers / 8).
    for bit_pos in 0..bit_width as usize {
        let byte_offset = bit_pos * 16; // 128/8 = 16 bytes per bit position

        // Process 16 bytes (128 integers) for this bit position
        for byte_idx in 0..16 {
            let byte_val = input[byte_offset + byte_idx];
            let base_int = byte_idx * 8;

            // Unroll: extract 8 bits from this byte
            output[base_int] |= (byte_val & 1) as u32 * (1 << bit_pos);
            output[base_int + 1] |= ((byte_val >> 1) & 1) as u32 * (1 << bit_pos);
            output[base_int + 2] |= ((byte_val >> 2) & 1) as u32 * (1 << bit_pos);
            output[base_int + 3] |= ((byte_val >> 3) & 1) as u32 * (1 << bit_pos);
            output[base_int + 4] |= ((byte_val >> 4) & 1) as u32 * (1 << bit_pos);
            output[base_int + 5] |= ((byte_val >> 5) & 1) as u32 * (1 << bit_pos);
            output[base_int + 6] |= ((byte_val >> 6) & 1) as u32 * (1 << bit_pos);
            output[base_int + 7] |= ((byte_val >> 7) & 1) as u32 * (1 << bit_pos);
        }
    }
}

/// Prefix sum for 128 elements - uses NEON on aarch64, scalar elsewhere
#[allow(dead_code)]
pub fn prefix_sum_128(deltas: &mut [u32; SIMD_BLOCK_SIZE], first_val: u32) {
    #[cfg(target_arch = "aarch64")]
    {
        unsafe { neon::prefix_sum_block_neon(deltas, first_val) }
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        scalar::prefix_sum_block_scalar(deltas, first_val)
    }
}

// Aliases under the old names, kept so existing callers keep compiling.
pub fn pack_vertical(values: &[u32; SIMD_BLOCK_SIZE], bit_width: u8, output: &mut Vec<u8>) {
    pack_horizontal(values, bit_width, output)
}

pub fn unpack_vertical(input: &[u8], bit_width: u8, output: &mut [u32; SIMD_BLOCK_SIZE]) {
    unpack_horizontal(input, bit_width, output)
}

/// Unpack with integrated delta decoding (fused for better performance)
///
/// The encoding stores deltas[i] = doc_ids[i+1] - doc_ids[i] - 1,
/// so we have (count-1) deltas for count doc_ids.
/// first_doc_id is doc_ids[0], and we compute the rest from deltas.
///
/// This fused version avoids a separate prefix sum pass by computing
/// doc_ids inline during unpacking.
pub fn unpack_vertical_d1(
    input: &[u8],
    bit_width: u8,
    first_doc_id: u32,
    output: &mut [u32; SIMD_BLOCK_SIZE],
    count: usize,
) {
    if count == 0 {
        return;
    }

    if bit_width == 0 {
        // All deltas are 0, so gaps are all 1
        let mut current = first_doc_id;
        output[0] = current;
        for out_val in output.iter_mut().take(count).skip(1) {
            current = current.wrapping_add(1);
            *out_val = current;
        }
        return;
    }

    // Fused unpack + prefix sum: compute doc_ids inline
    output[0] = first_doc_id;
    let mut current = first_doc_id;

    // Process in groups of 4 for better cache locality
    let full_groups = (count - 1) / 4;
    let remainder = (count - 1) % 4;

    for group in 0..full_groups {
        let base_idx = group * 4;

        // Extract 4 deltas from vertical layout
        let mut deltas = [0u32; 4];
        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * (SIMD_BLOCK_SIZE / 8);
            for (j, delta) in deltas.iter_mut().enumerate() {
                let int_idx = base_idx + j;
                let byte_idx = byte_offset + int_idx / 8;
                let bit_in_byte = int_idx % 8;
                let bit = ((input[byte_idx] >> bit_in_byte) & 1) as u32;
                *delta |= bit << bit_pos;
            }
        }

        // Apply prefix sum inline
        for j in 0..4 {
            current = current.wrapping_add(deltas[j]).wrapping_add(1);
            output[base_idx + j + 1] = current;
        }
    }

    // Handle remainder
    let base_idx = full_groups * 4;
    for j in 0..remainder {
        let int_idx = base_idx + j;
        let mut delta = 0u32;
        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * (SIMD_BLOCK_SIZE / 8);
            let byte_idx = byte_offset + int_idx / 8;
            let bit_in_byte = int_idx % 8;
            let bit = ((input[byte_idx] >> bit_in_byte) & 1) as u32;
            delta |= bit << bit_pos;
        }
        current = current.wrapping_add(delta).wrapping_add(1);
        output[base_idx + j + 1] = current;
    }
}
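
// Worked example (a sketch, following the encoding described above): doc_ids
// [5, 7, 8, 12] are stored as the deltas [7-5-1, 8-7-1, 12-8-1] = [1, 0, 3];
// decoding starts at first_doc_id = 5 and re-adds each delta plus the
// implicit gap of 1 to recover the originals.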

/// A single SIMD-BP128 block with metadata
#[derive(Debug, Clone)]
pub struct SimdBp128Block {
    /// Vertically-packed delta-encoded doc_ids
    pub doc_data: Vec<u8>,
    /// Bit width for doc deltas
    pub doc_bit_width: u8,
    /// Vertically-packed term frequencies (tf - 1)
    pub tf_data: Vec<u8>,
    /// Bit width for term frequencies
    pub tf_bit_width: u8,
    /// First doc_id in block (absolute)
    pub first_doc_id: u32,
    /// Last doc_id in block (absolute)
    pub last_doc_id: u32,
    /// Number of docs in this block
    pub num_docs: u16,
    /// Maximum term frequency in block
    pub max_tf: u32,
    /// Maximum BM25 score upper bound for BlockMax WAND
    pub max_block_score: f32,
}

impl SimdBp128Block {
    /// Serialize block
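    ///
    /// On-disk layout (little-endian), matching the writes below:
    /// first_doc_id u32 | last_doc_id u32 | num_docs u16 | doc_bit_width u8 |
    /// tf_bit_width u8 | max_tf u32 | max_block_score f32 |
    /// doc_len u16 | doc_data bytes | tf_len u16 | tf_data bytes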
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
        writer.write_u16::<LittleEndian>(self.num_docs)?;
        writer.write_u8(self.doc_bit_width)?;
        writer.write_u8(self.tf_bit_width)?;
        writer.write_u32::<LittleEndian>(self.max_tf)?;
        writer.write_f32::<LittleEndian>(self.max_block_score)?;

        writer.write_u16::<LittleEndian>(self.doc_data.len() as u16)?;
        writer.write_all(&self.doc_data)?;

        writer.write_u16::<LittleEndian>(self.tf_data.len() as u16)?;
        writer.write_all(&self.tf_data)?;

        Ok(())
    }

    /// Deserialize block
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc_id = reader.read_u32::<LittleEndian>()?;
        let last_doc_id = reader.read_u32::<LittleEndian>()?;
        let num_docs = reader.read_u16::<LittleEndian>()?;
        let doc_bit_width = reader.read_u8()?;
        let tf_bit_width = reader.read_u8()?;
        let max_tf = reader.read_u32::<LittleEndian>()?;
        let max_block_score = reader.read_f32::<LittleEndian>()?;

        let doc_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut doc_data = vec![0u8; doc_len];
        reader.read_exact(&mut doc_data)?;

        let tf_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut tf_data = vec![0u8; tf_len];
        reader.read_exact(&mut tf_data)?;

        Ok(Self {
            doc_data,
            doc_bit_width,
            tf_data,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs,
            max_tf,
            max_block_score,
        })
    }

    /// Decode doc_ids from this block
    pub fn decode_doc_ids(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let mut output = [0u32; SIMD_BLOCK_SIZE];
        unpack_vertical_d1(
            &self.doc_data,
            self.doc_bit_width,
            self.first_doc_id,
            &mut output,
            self.num_docs as usize,
        );

        output[..self.num_docs as usize].to_vec()
    }

    /// Decode term frequencies from this block
    pub fn decode_term_freqs(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let mut output = [0u32; SIMD_BLOCK_SIZE];
        unpack_vertical(&self.tf_data, self.tf_bit_width, &mut output);

        // TF is stored as tf - 1; add 1 back
        output[..self.num_docs as usize]
            .iter()
            .map(|&tf| tf + 1)
            .collect()
    }
}

/// SIMD-BP128 posting list with vertical layout and BlockMax support
#[derive(Debug, Clone)]
pub struct SimdBp128PostingList {
    /// Blocks of postings
    pub blocks: Vec<SimdBp128Block>,
    /// Total document count
    pub doc_count: u32,
    /// Maximum score across all blocks
    pub max_score: f32,
}

impl SimdBp128PostingList {
    /// BM25 parameters
    const K1: f32 = 1.2;
    const B: f32 = 0.75;

    /// Compute a BM25 upper-bound score for a block.
    ///
    /// The bound assumes the shortest possible document: as document length
    /// approaches zero, the BM25 length norm bottoms out at 1 - B, which
    /// maximizes tf_norm, so no posting in the block can score higher.
    #[inline]
    pub fn compute_bm25_upper_bound(max_tf: u32, idf: f32) -> f32 {
        let tf = max_tf as f32;
        let min_length_norm = 1.0 - Self::B;
        let tf_norm = (tf * (Self::K1 + 1.0)) / (tf + Self::K1 * min_length_norm);
        idf * tf_norm
    }

    /// Create from raw postings
    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
        assert_eq!(doc_ids.len(), term_freqs.len());

        if doc_ids.is_empty() {
            return Self {
                blocks: Vec::new(),
                doc_count: 0,
                max_score: 0.0,
            };
        }

        let mut blocks = Vec::new();
        let mut max_score = 0.0f32;
        let mut i = 0;

        while i < doc_ids.len() {
            let block_end = (i + SIMD_BLOCK_SIZE).min(doc_ids.len());
            let block_docs = &doc_ids[i..block_end];
            let block_tfs = &term_freqs[i..block_end];

            let block = Self::create_block(block_docs, block_tfs, idf);
            max_score = max_score.max(block.max_block_score);
            blocks.push(block);

            i = block_end;
        }

        Self {
            blocks,
            doc_count: doc_ids.len() as u32,
            max_score,
        }
    }

    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> SimdBp128Block {
        let num_docs = doc_ids.len();
        let first_doc_id = doc_ids[0];
        let last_doc_id = *doc_ids.last().unwrap();

        // Compute deltas (gap - 1)
        let mut deltas = [0u32; SIMD_BLOCK_SIZE];
        let mut max_delta = 0u32;
        for j in 1..num_docs {
            let delta = doc_ids[j] - doc_ids[j - 1] - 1;
            deltas[j - 1] = delta;
            max_delta = max_delta.max(delta);
        }

        // Compute TFs (tf - 1)
        let mut tfs = [0u32; SIMD_BLOCK_SIZE];
        let mut max_tf = 0u32;
        for (j, &tf) in term_freqs.iter().enumerate() {
            tfs[j] = tf.saturating_sub(1);
            max_tf = max_tf.max(tf);
        }

        let doc_bit_width = bits_needed(max_delta);
        let tf_bit_width = bits_needed(max_tf.saturating_sub(1));

        let mut doc_data = Vec::new();
        pack_vertical(&deltas, doc_bit_width, &mut doc_data);

        let mut tf_data = Vec::new();
        pack_vertical(&tfs, tf_bit_width, &mut tf_data);

        let max_block_score = Self::compute_bm25_upper_bound(max_tf, idf);

        SimdBp128Block {
            doc_data,
            doc_bit_width,
            tf_data,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs: num_docs as u16,
            max_tf,
            max_block_score,
        }
    }

    /// Serialize
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_f32::<LittleEndian>(self.max_score)?;
        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;

        for block in &self.blocks {
            block.serialize(writer)?;
        }

        Ok(())
    }

    /// Deserialize
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let max_score = reader.read_f32::<LittleEndian>()?;
        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;

        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            blocks.push(SimdBp128Block::deserialize(reader)?);
        }

        Ok(Self {
            blocks,
            doc_count,
            max_score,
        })
    }

    /// Create iterator
    pub fn iterator(&self) -> SimdBp128Iterator<'_> {
        SimdBp128Iterator::new(self)
    }

    /// Get approximate size in bytes
    pub fn size_bytes(&self) -> usize {
        let mut size = 12; // list header: doc_count + max_score + num_blocks
        for block in &self.blocks {
            // 20-byte fixed block header plus two u16 length prefixes
            size += 24 + block.doc_data.len() + block.tf_data.len();
        }
        size
    }
}

/// Iterator over SIMD-BP128 posting list
pub struct SimdBp128Iterator<'a> {
    list: &'a SimdBp128PostingList,
    current_block: usize,
    block_doc_ids: Vec<u32>,
    block_term_freqs: Vec<u32>,
    pos_in_block: usize,
    exhausted: bool,
}

impl<'a> SimdBp128Iterator<'a> {
    pub fn new(list: &'a SimdBp128PostingList) -> Self {
        let mut iter = Self {
            list,
            current_block: 0,
            block_doc_ids: Vec::new(),
            block_term_freqs: Vec::new(),
            pos_in_block: 0,
            exhausted: list.blocks.is_empty(),
        };

        if !iter.exhausted {
            iter.decode_current_block();
        }

        iter
    }

    fn decode_current_block(&mut self) {
        let block = &self.list.blocks[self.current_block];
        self.block_doc_ids = block.decode_doc_ids();
        self.block_term_freqs = block.decode_term_freqs();
        self.pos_in_block = 0;
    }

    /// Current document ID
    pub fn doc(&self) -> u32 {
        if self.exhausted {
            u32::MAX
        } else {
            self.block_doc_ids[self.pos_in_block]
        }
    }

    /// Current term frequency
    pub fn term_freq(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.block_term_freqs[self.pos_in_block]
        }
    }

    /// Advance to next document
    pub fn advance(&mut self) -> u32 {
        if self.exhausted {
            return u32::MAX;
        }

        self.pos_in_block += 1;

        if self.pos_in_block >= self.block_doc_ids.len() {
            self.current_block += 1;
            if self.current_block >= self.list.blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }

        self.doc()
    }

    /// Seek to first doc >= target with block skipping
    pub fn seek(&mut self, target: u32) -> u32 {
        if self.exhausted {
            return u32::MAX;
        }

        // Binary search for the target block
        let block_idx = self.list.blocks[self.current_block..].binary_search_by(|block| {
            if block.last_doc_id < target {
                std::cmp::Ordering::Less
            } else if block.first_doc_id > target {
                std::cmp::Ordering::Greater
            } else {
                std::cmp::Ordering::Equal
            }
        });

        let target_block = match block_idx {
            Ok(idx) => self.current_block + idx,
            Err(idx) => {
                if self.current_block + idx >= self.list.blocks.len() {
                    self.exhausted = true;
                    return u32::MAX;
                }
                self.current_block + idx
            }
        };

        if target_block != self.current_block {
            self.current_block = target_block;
            self.decode_current_block();
        }

        // Binary search within the block
        let pos = self.block_doc_ids[self.pos_in_block..]
            .binary_search(&target)
            .unwrap_or_else(|x| x);
        self.pos_in_block += pos;

        if self.pos_in_block >= self.block_doc_ids.len() {
            self.current_block += 1;
            if self.current_block >= self.list.blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }

        self.doc()
    }

    /// Get max score for remaining blocks
    pub fn max_remaining_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.list.blocks[self.current_block..]
            .iter()
            .map(|b| b.max_block_score)
            .fold(0.0f32, |a, b| a.max(b))
    }

    /// Get current block's max score
    pub fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            0.0
        } else {
            self.list.blocks[self.current_block].max_block_score
        }
    }

    /// Get current block's max TF
    pub fn current_block_max_tf(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.list.blocks[self.current_block].max_tf
        }
    }

    /// Skip to next block containing doc >= target (for BlockWAND).
    /// Returns (first_doc_in_block, block_max_score), or None if exhausted.
    pub fn skip_to_block_with_doc(&mut self, target: u32) -> Option<(u32, f32)> {
        while self.current_block < self.list.blocks.len() {
            let block = &self.list.blocks[self.current_block];
            if block.last_doc_id >= target {
                // Decode this block and position at start
                self.decode_current_block();
                return Some((block.first_doc_id, block.max_block_score));
            }
            self.current_block += 1;
        }
        self.exhausted = true;
        None
    }

    /// Check if iterator is exhausted
    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pack_unpack_vertical() {
        let mut values = [0u32; SIMD_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = (i * 3) as u32;
        }

        let max_val = values.iter().max().copied().unwrap();
        let bit_width = bits_needed(max_val);

        let mut packed = Vec::new();
        pack_vertical(&values, bit_width, &mut packed);

        let mut unpacked = [0u32; SIMD_BLOCK_SIZE];
        unpack_vertical(&packed, bit_width, &mut unpacked);

        assert_eq!(values, unpacked);
    }
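
    // Round-trip sketch for the private 4-value scalar helpers above (kept
    // behind #[allow(dead_code)]): packing then unpacking should be lossless.
    #[test]
    fn test_pack_unpack_4_scalar_roundtrip() {
        let values = [3u32, 0, 7, 5];
        let mut packed = [0u8; 16];
        scalar::pack_4_scalar(&values, 3, &mut packed);

        let mut unpacked = [0u32; 4];
        scalar::unpack_4_scalar(&packed, 3, &mut unpacked);
        assert_eq!(values, unpacked);
    }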

    #[test]
    fn test_pack_unpack_vertical_various_widths() {
        for bit_width in 1..=20 {
            let mut values = [0u32; SIMD_BLOCK_SIZE];
            let max_val = (1u32 << bit_width) - 1;
            for (i, v) in values.iter_mut().enumerate() {
                *v = (i as u32) % (max_val + 1);
            }

            let mut packed = Vec::new();
            pack_vertical(&values, bit_width, &mut packed);

            let mut unpacked = [0u32; SIMD_BLOCK_SIZE];
            unpack_vertical(&packed, bit_width, &mut unpacked);

            assert_eq!(values, unpacked, "Failed for bit_width={}", bit_width);
        }
    }
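
    // Sketch exercising prefix_sum_128 (NEON path on aarch64, scalar
    // elsewhere): with every delta equal to 1 and a starting value of 10,
    // element i must come out as 10 + (i + 1).
    #[test]
    fn test_prefix_sum_128_all_ones() {
        let mut deltas = [1u32; SIMD_BLOCK_SIZE];
        prefix_sum_128(&mut deltas, 10);
        for (i, &v) in deltas.iter().enumerate() {
            assert_eq!(v, 10 + i as u32 + 1, "Mismatch at {}", i);
        }
    }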

    #[test]
    fn test_simd_bp128_posting_list() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 10) + 1).collect();

        let list = SimdBp128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        assert_eq!(list.doc_count, 200);
        assert_eq!(list.blocks.len(), 2); // 128 + 72

        let mut iter = list.iterator();
        for (i, &expected_doc) in doc_ids.iter().enumerate() {
            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
            assert_eq!(iter.term_freq(), term_freqs[i], "TF mismatch at {}", i);
            if i < doc_ids.len() - 1 {
                iter.advance();
            }
        }
    }

    #[test]
    fn test_simd_bp128_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let list = SimdBp128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = list.iterator();

        assert_eq!(iter.seek(25), 30);
        assert_eq!(iter.seek(100), 100);
        assert_eq!(iter.seek(500), 1000);
        assert_eq!(iter.seek(3000), u32::MAX);
    }
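
    // Round-trip sketch for the fused delta decoder: pack the gaps-minus-one
    // with pack_vertical, then recover absolute doc_ids via unpack_vertical_d1
    // (mirrors the worked example next to that function).
    #[test]
    fn test_unpack_vertical_d1_roundtrip() {
        let doc_ids = [5u32, 7, 8, 12, 100];
        let mut deltas = [0u32; SIMD_BLOCK_SIZE];
        let mut max_delta = 0;
        for j in 1..doc_ids.len() {
            deltas[j - 1] = doc_ids[j] - doc_ids[j - 1] - 1;
            max_delta = max_delta.max(deltas[j - 1]);
        }

        let bit_width = bits_needed(max_delta);
        let mut packed = Vec::new();
        pack_vertical(&deltas, bit_width, &mut packed);

        let mut output = [0u32; SIMD_BLOCK_SIZE];
        unpack_vertical_d1(&packed, bit_width, doc_ids[0], &mut output, doc_ids.len());
        assert_eq!(&output[..doc_ids.len()], &doc_ids);
    }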

    #[test]
    fn test_simd_bp128_serialization() {
        let doc_ids: Vec<u32> = (0..300).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..300).map(|i| (i % 5) + 1).collect();

        let list = SimdBp128PostingList::from_postings(&doc_ids, &term_freqs, 1.5);

        let mut buffer = Vec::new();
        list.serialize(&mut buffer).unwrap();

        let restored = SimdBp128PostingList::deserialize(&mut &buffer[..]).unwrap();

        assert_eq!(restored.doc_count, list.doc_count);
        assert_eq!(restored.blocks.len(), list.blocks.len());

        let mut iter1 = list.iterator();
        let mut iter2 = restored.iterator();

        while iter1.doc() != u32::MAX {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
    }

    #[test]
    fn test_vertical_layout_size() {
        // True vertical layout: BLOCK_SIZE * bit_width / 8 bytes (optimal)
        let mut values = [0u32; SIMD_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = i as u32;
        }

        let bit_width = bits_needed(127); // 7 bits
        assert_eq!(bit_width, 7);

        let mut packed = Vec::new();
        pack_horizontal(&values, bit_width, &mut packed);

        // True vertical layout: 128 * 7 / 8 = 112 bytes (optimal, no padding)
        let expected_bytes = (SIMD_BLOCK_SIZE * bit_width as usize) / 8;
        assert_eq!(expected_bytes, 112);
        assert_eq!(packed.len(), expected_bytes);
    }
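
    // Numeric sketch of the BM25 upper bound: with max_tf = 1 and idf = 2.0,
    // tf_norm = (1 * 2.2) / (1 + 1.2 * 0.25) = 2.2 / 1.3, so the bound is
    // 2.0 * 2.2 / 1.3 (the expected value below mirrors the formula with
    // K1 = 1.2 and B = 0.75 written out).
    #[test]
    fn test_bm25_upper_bound_value() {
        let bound = SimdBp128PostingList::compute_bm25_upper_bound(1, 2.0);
        let tf = 1.0f32;
        let tf_norm = (tf * (1.2 + 1.0)) / (tf + 1.2 * (1.0 - 0.75));
        let expected = 2.0 * tf_norm;
        assert!((bound - expected).abs() < 1e-5);
    }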

    #[test]
    fn test_simd_bp128_block_max() {
        // Create a large posting list that spans multiple blocks
        let doc_ids: Vec<u32> = (0..500).map(|i| i * 2).collect();
        // Vary term frequencies so different blocks have different max_tf
        let term_freqs: Vec<u32> = (0..500)
            .map(|i| {
                if i < 128 {
                    1 // Block 0: max_tf = 1
                } else if i < 256 {
                    5 // Block 1: max_tf = 5
                } else if i < 384 {
                    10 // Block 2: max_tf = 10
                } else {
                    3 // Block 3: max_tf = 3
                }
            })
            .collect();

        let list = SimdBp128PostingList::from_postings(&doc_ids, &term_freqs, 2.0);

        // Should have 4 blocks (500 docs / 128 per block)
        assert_eq!(list.blocks.len(), 4);
        assert_eq!(list.blocks[0].max_tf, 1);
        assert_eq!(list.blocks[1].max_tf, 5);
        assert_eq!(list.blocks[2].max_tf, 10);
        assert_eq!(list.blocks[3].max_tf, 3);

        // Block 2 should have the highest score (max_tf = 10)
        assert!(list.blocks[2].max_block_score > list.blocks[0].max_block_score);
        assert!(list.blocks[2].max_block_score > list.blocks[1].max_block_score);
        assert!(list.blocks[2].max_block_score > list.blocks[3].max_block_score);

        // Global max_score should equal block 2's score
        assert_eq!(list.max_score, list.blocks[2].max_block_score);

        // Test iterator block-max methods
        let mut iter = list.iterator();
        assert_eq!(iter.current_block_max_tf(), 1); // Block 0

        // Seek to block 1
        iter.seek(256); // first doc in block 1
        assert_eq!(iter.current_block_max_tf(), 5);

        // Seek to block 2
        iter.seek(512); // first doc in block 2
        assert_eq!(iter.current_block_max_tf(), 10);

        // Test skip_to_block_with_doc
        let mut iter2 = list.iterator();
        let result = iter2.skip_to_block_with_doc(300);
        assert!(result.is_some());
        let (first_doc, score) = result.unwrap();
        assert!(first_doc <= 300);
        assert!(score > 0.0);
    }
}