use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Read, Write};

pub const VERTICAL_BP128_BLOCK_SIZE: usize = 128;

#[allow(dead_code)]
const SIMD_LANES: usize = 4;

#[allow(dead_code)]
const GROUPS_PER_BLOCK: usize = VERTICAL_BP128_BLOCK_SIZE / SIMD_LANES;

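/// Returns the minimum number of bits needed to represent `max_val`
/// (0 for a value of 0).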
#[inline]
pub fn bits_needed(max_val: u32) -> u8 {
    if max_val == 0 {
        0
    } else {
        32 - max_val.leading_zeros() as u8
    }
}

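/// NEON (aarch64) kernels for unpacking vertically bit-packed blocks and
/// computing prefix sums.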
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
mod neon {
    use super::*;
    use std::arch::aarch64::*;

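    /// Lookup table mapping a byte to its eight bits expanded into `u32`
    /// lanes (`lut[b][i] == (b >> i) & 1`), evaluated at compile time.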
    static BIT_EXPAND_LUT: [[u32; 8]; 256] = {
        let mut lut = [[0u32; 8]; 256];
        let mut byte = 0usize;
        while byte < 256 {
            let mut bit = 0;
            while bit < 8 {
                lut[byte][bit] = ((byte >> bit) & 1) as u32;
                bit += 1;
            }
            byte += 1;
        }
        lut
    };

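    /// Unpacks four `bit_width`-bit values stored contiguously (LSB first)
    /// from `input` into `output` using NEON loads/stores.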
    #[inline]
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_4_neon(input: &[u8], bit_width: u8, output: &mut [u32; 4]) {
        if bit_width == 0 {
            *output = [0; 4];
            return;
        }

        let mask = (1u32 << bit_width) - 1;

        let mut packed_bytes = [0u8; 16];
        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        packed_bytes[..bytes_needed.min(16)].copy_from_slice(&input[..bytes_needed.min(16)]);
        let packed = u128::from_le_bytes(packed_bytes);

        let v0 = (packed & mask as u128) as u32;
        let v1 = ((packed >> bit_width) & mask as u128) as u32;
        let v2 = ((packed >> (bit_width * 2)) & mask as u128) as u32;
        let v3 = ((packed >> (bit_width * 3)) & mask as u128) as u32;

        unsafe {
            let result = vld1q_u32([v0, v1, v2, v3].as_ptr());
            vst1q_u32(output.as_mut_ptr(), result);
        }
    }

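    /// In-place inclusive prefix sum over four lanes using NEON shuffles.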
    #[inline]
    #[target_feature(enable = "neon")]
    pub unsafe fn prefix_sum_4_neon(values: &mut [u32; 4]) {
        unsafe {
            let mut v = vld1q_u32(values.as_ptr());

            // Shift in zeros by one lane, accumulate, then by two lanes.
            let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3);
            v = vaddq_u32(v, shifted1);
            let shifted2 = vextq_u32(vdupq_n_u32(0), v, 2);
            v = vaddq_u32(v, shifted2);

            vst1q_u32(values.as_mut_ptr(), v);
        }
    }

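    /// Unpacks a full 128-value block from the vertical layout: for each bit
    /// position, 16 consecutive input bytes carry that bit for all 128 values.
    /// Bits are expanded via `BIT_EXPAND_LUT` and OR-ed into `output`.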
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_block_neon(
        input: &[u8],
        bit_width: u8,
        output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
    ) {
        if bit_width == 0 {
            output.fill(0);
            return;
        }

        unsafe {
            let zero = vdupq_n_u32(0);
            for i in (0..VERTICAL_BP128_BLOCK_SIZE).step_by(4) {
                vst1q_u32(output[i..].as_mut_ptr(), zero);
            }
        }

        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;
            let bit_mask = 1u32 << bit_pos;

            if bit_pos + 1 < bit_width as usize {
                let next_offset = (bit_pos + 1) * 16;
                unsafe {
                    std::arch::asm!(
                        "prfm pldl1keep, [{0}]",
                        in(reg) input.as_ptr().add(next_offset),
                        options(nostack, preserves_flags)
                    );
                }
            }

            for chunk in 0..4 {
                let chunk_offset = byte_offset + chunk * 4;

                let b0 = input[chunk_offset] as usize;
                let b1 = input[chunk_offset + 1] as usize;
                let b2 = input[chunk_offset + 2] as usize;
                let b3 = input[chunk_offset + 3] as usize;

                let base_int = chunk * 32;

                unsafe {
                    let mask_vec = vdupq_n_u32(bit_mask);

                    let lut0 = &BIT_EXPAND_LUT[b0];
                    let bits_0_3 = vld1q_u32(lut0.as_ptr());
                    let bits_4_7 = vld1q_u32(lut0[4..].as_ptr());

                    let shifted_0_3 = vmulq_u32(bits_0_3, mask_vec);
                    let shifted_4_7 = vmulq_u32(bits_4_7, mask_vec);

                    let cur_0_3 = vld1q_u32(output[base_int..].as_ptr());
                    let cur_4_7 = vld1q_u32(output[base_int + 4..].as_ptr());

                    vst1q_u32(
                        output[base_int..].as_mut_ptr(),
                        vorrq_u32(cur_0_3, shifted_0_3),
                    );
                    vst1q_u32(
                        output[base_int + 4..].as_mut_ptr(),
                        vorrq_u32(cur_4_7, shifted_4_7),
                    );

                    let lut1 = &BIT_EXPAND_LUT[b1];
                    let bits_8_11 = vld1q_u32(lut1.as_ptr());
                    let bits_12_15 = vld1q_u32(lut1[4..].as_ptr());

                    let shifted_8_11 = vmulq_u32(bits_8_11, mask_vec);
                    let shifted_12_15 = vmulq_u32(bits_12_15, mask_vec);

                    let cur_8_11 = vld1q_u32(output[base_int + 8..].as_ptr());
                    let cur_12_15 = vld1q_u32(output[base_int + 12..].as_ptr());

                    vst1q_u32(
                        output[base_int + 8..].as_mut_ptr(),
                        vorrq_u32(cur_8_11, shifted_8_11),
                    );
                    vst1q_u32(
                        output[base_int + 12..].as_mut_ptr(),
                        vorrq_u32(cur_12_15, shifted_12_15),
                    );

                    let lut2 = &BIT_EXPAND_LUT[b2];
                    let bits_16_19 = vld1q_u32(lut2.as_ptr());
                    let bits_20_23 = vld1q_u32(lut2[4..].as_ptr());

                    let shifted_16_19 = vmulq_u32(bits_16_19, mask_vec);
                    let shifted_20_23 = vmulq_u32(bits_20_23, mask_vec);

                    let cur_16_19 = vld1q_u32(output[base_int + 16..].as_ptr());
                    let cur_20_23 = vld1q_u32(output[base_int + 20..].as_ptr());

                    vst1q_u32(
                        output[base_int + 16..].as_mut_ptr(),
                        vorrq_u32(cur_16_19, shifted_16_19),
                    );
                    vst1q_u32(
                        output[base_int + 20..].as_mut_ptr(),
                        vorrq_u32(cur_20_23, shifted_20_23),
                    );

                    let lut3 = &BIT_EXPAND_LUT[b3];
                    let bits_24_27 = vld1q_u32(lut3.as_ptr());
                    let bits_28_31 = vld1q_u32(lut3[4..].as_ptr());

                    let shifted_24_27 = vmulq_u32(bits_24_27, mask_vec);
                    let shifted_28_31 = vmulq_u32(bits_28_31, mask_vec);

                    let cur_24_27 = vld1q_u32(output[base_int + 24..].as_ptr());
                    let cur_28_31 = vld1q_u32(output[base_int + 28..].as_ptr());

                    vst1q_u32(
                        output[base_int + 24..].as_mut_ptr(),
                        vorrq_u32(cur_24_27, shifted_24_27),
                    );
                    vst1q_u32(
                        output[base_int + 28..].as_mut_ptr(),
                        vorrq_u32(cur_28_31, shifted_28_31),
                    );
                }
            }
        }
    }

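    /// In-place inclusive prefix sum over a whole 128-value block, seeded
    /// with `first_val`, processing four lanes at a time and carrying the
    /// running total between groups.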
    #[target_feature(enable = "neon")]
    pub unsafe fn prefix_sum_block_neon(
        deltas: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
        first_val: u32,
    ) {
        let mut carry = first_val;

        for group in 0..GROUPS_PER_BLOCK {
            let start = group * SIMD_LANES;
            let mut group_vals = [
                deltas[start],
                deltas[start + 1],
                deltas[start + 2],
                deltas[start + 3],
            ];

            group_vals[0] = group_vals[0].wrapping_add(carry);

            unsafe { prefix_sum_4_neon(&mut group_vals) };

            deltas[start..start + 4].copy_from_slice(&group_vals);

            carry = group_vals[3];
        }
    }
}

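/// Portable scalar fallbacks mirroring the NEON kernels above.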
#[allow(dead_code)]
mod scalar {
    use super::*;

    #[inline]
    pub fn pack_4_scalar(values: &[u32; 4], bit_width: u8, output: &mut [u8]) {
        if bit_width == 0 {
            return;
        }

        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        let mut packed = 0u128;
        for (i, &val) in values.iter().enumerate() {
            packed |= (val as u128) << (i * bit_width as usize);
        }

        let packed_bytes = packed.to_le_bytes();
        output[..bytes_needed].copy_from_slice(&packed_bytes[..bytes_needed]);
    }

    #[inline]
    pub fn unpack_4_scalar(input: &[u8], bit_width: u8, output: &mut [u32; 4]) {
        if bit_width == 0 {
            *output = [0; 4];
            return;
        }

        let mask = (1u32 << bit_width) - 1;
        let mut packed_bytes = [0u8; 16];
        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        packed_bytes[..bytes_needed.min(16)].copy_from_slice(&input[..bytes_needed.min(16)]);
        let packed = u128::from_le_bytes(packed_bytes);

        output[0] = (packed & mask as u128) as u32;
        output[1] = ((packed >> bit_width) & mask as u128) as u32;
        output[2] = ((packed >> (bit_width * 2)) & mask as u128) as u32;
        output[3] = ((packed >> (bit_width * 3)) & mask as u128) as u32;
    }

    #[inline]
    pub fn prefix_sum_4_scalar(vals: &mut [u32; 4]) {
        vals[1] = vals[1].wrapping_add(vals[0]);
        vals[2] = vals[2].wrapping_add(vals[1]);
        vals[3] = vals[3].wrapping_add(vals[2]);
    }

    pub fn unpack_block_scalar(
        input: &[u8],
        bit_width: u8,
        output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
    ) {
        if bit_width == 0 {
            output.fill(0);
            return;
        }

        output.fill(0);

        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;

            for byte_idx in 0..16 {
                let byte_val = input[byte_offset + byte_idx];
                let base_int = byte_idx * 8;

                output[base_int] |= (byte_val & 1) as u32 * (1 << bit_pos);
                output[base_int + 1] |= ((byte_val >> 1) & 1) as u32 * (1 << bit_pos);
                output[base_int + 2] |= ((byte_val >> 2) & 1) as u32 * (1 << bit_pos);
                output[base_int + 3] |= ((byte_val >> 3) & 1) as u32 * (1 << bit_pos);
                output[base_int + 4] |= ((byte_val >> 4) & 1) as u32 * (1 << bit_pos);
                output[base_int + 5] |= ((byte_val >> 5) & 1) as u32 * (1 << bit_pos);
                output[base_int + 6] |= ((byte_val >> 6) & 1) as u32 * (1 << bit_pos);
                output[base_int + 7] |= ((byte_val >> 7) & 1) as u32 * (1 << bit_pos);
            }
        }
    }

    pub fn prefix_sum_block_scalar(deltas: &mut [u32; VERTICAL_BP128_BLOCK_SIZE], first_val: u32) {
        let mut carry = first_val;

        for group in 0..GROUPS_PER_BLOCK {
            let start = group * SIMD_LANES;
            let mut group_vals = [
                deltas[start],
                deltas[start + 1],
                deltas[start + 2],
                deltas[start + 3],
            ];

            group_vals[0] = group_vals[0].wrapping_add(carry);
            prefix_sum_4_scalar(&mut group_vals);
            deltas[start..start + 4].copy_from_slice(&group_vals);
            carry = group_vals[3];
        }
    }
}

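/// Packs 128 values into the vertical (bit-planed) BP128 layout.
///
/// For each bit position `p` in `0..bit_width`, 16 consecutive output bytes
/// store bit `p` of all 128 values (value `i`'s bit lands in byte `i / 8`,
/// bit `i % 8`), so the packed size is always `16 * bit_width` bytes.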
pub fn pack_vertical(
    values: &[u32; VERTICAL_BP128_BLOCK_SIZE],
    bit_width: u8,
    output: &mut Vec<u8>,
) {
    if bit_width == 0 {
        return;
    }

    let total_bytes = 16 * bit_width as usize;
    let start = output.len();
    output.resize(start + total_bytes, 0);

    for bit_pos in 0..bit_width as usize {
        let byte_offset = start + bit_pos * 16;

        for byte_idx in 0..16 {
            let base_int = byte_idx * 8;
            let mut byte_val = 0u8;

            byte_val |= ((values[base_int] >> bit_pos) & 1) as u8;
            byte_val |= (((values[base_int + 1] >> bit_pos) & 1) as u8) << 1;
            byte_val |= (((values[base_int + 2] >> bit_pos) & 1) as u8) << 2;
            byte_val |= (((values[base_int + 3] >> bit_pos) & 1) as u8) << 3;
            byte_val |= (((values[base_int + 4] >> bit_pos) & 1) as u8) << 4;
            byte_val |= (((values[base_int + 5] >> bit_pos) & 1) as u8) << 5;
            byte_val |= (((values[base_int + 6] >> bit_pos) & 1) as u8) << 6;
            byte_val |= (((values[base_int + 7] >> bit_pos) & 1) as u8) << 7;

            output[byte_offset + byte_idx] = byte_val;
        }
    }
}

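/// Unpacks a vertically packed block, dispatching to NEON on aarch64,
/// SSE2 on x86_64 when available, and a scalar fallback otherwise.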
pub fn unpack_vertical(input: &[u8], bit_width: u8, output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE]) {
    if bit_width == 0 {
        output.fill(0);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        unsafe { unpack_vertical_neon(input, bit_width, output) }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("sse2") {
            unsafe { unpack_vertical_sse(input, bit_width, output) }
        } else {
            unpack_vertical_scalar(input, bit_width, output)
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        unpack_vertical_scalar(input, bit_width, output)
    }
}

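/// Scalar reference implementation of vertical unpacking.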
#[inline]
fn unpack_vertical_scalar(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    output.fill(0);

    for bit_pos in 0..bit_width as usize {
        let byte_offset = bit_pos * 16;
        let bit_mask = 1u32 << bit_pos;

        for byte_idx in 0..16 {
            let byte_val = input[byte_offset + byte_idx];
            let base_int = byte_idx * 8;

            if byte_val & 0x01 != 0 {
                output[base_int] |= bit_mask;
            }
            if byte_val & 0x02 != 0 {
                output[base_int + 1] |= bit_mask;
            }
            if byte_val & 0x04 != 0 {
                output[base_int + 2] |= bit_mask;
            }
            if byte_val & 0x08 != 0 {
                output[base_int + 3] |= bit_mask;
            }
            if byte_val & 0x10 != 0 {
                output[base_int + 4] |= bit_mask;
            }
            if byte_val & 0x20 != 0 {
                output[base_int + 5] |= bit_mask;
            }
            if byte_val & 0x40 != 0 {
                output[base_int + 6] |= bit_mask;
            }
            if byte_val & 0x80 != 0 {
                output[base_int + 7] |= bit_mask;
            }
        }
    }
}

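/// NEON path: loads each 16-byte bit plane with a vector load, then expands
/// each byte's bits branch-free via multiply-by-mask.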
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn unpack_vertical_neon(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    use std::arch::aarch64::*;

    unsafe {
        let zero = vdupq_n_u32(0);
        for i in (0..VERTICAL_BP128_BLOCK_SIZE).step_by(4) {
            vst1q_u32(output[i..].as_mut_ptr(), zero);
        }

        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;
            let bit_mask = 1u32 << bit_pos;

            let bytes = vld1q_u8(input.as_ptr().add(byte_offset));

            let mut byte_array = [0u8; 16];
            vst1q_u8(byte_array.as_mut_ptr(), bytes);

            for (byte_idx, &byte_val) in byte_array.iter().enumerate() {
                let base_int = byte_idx * 8;

                output[base_int] |= ((byte_val & 0x01) as u32) * bit_mask;
                output[base_int + 1] |= (((byte_val >> 1) & 0x01) as u32) * bit_mask;
                output[base_int + 2] |= (((byte_val >> 2) & 0x01) as u32) * bit_mask;
                output[base_int + 3] |= (((byte_val >> 3) & 0x01) as u32) * bit_mask;
                output[base_int + 4] |= (((byte_val >> 4) & 0x01) as u32) * bit_mask;
                output[base_int + 5] |= (((byte_val >> 5) & 0x01) as u32) * bit_mask;
                output[base_int + 6] |= (((byte_val >> 6) & 0x01) as u32) * bit_mask;
                output[base_int + 7] |= (((byte_val >> 7) & 0x01) as u32) * bit_mask;
            }
        }
    }
}

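/// SSE2 path: same bit-plane expansion as the scalar version, with the
/// 16-byte plane loaded and the output zeroed through SSE registers.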
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn unpack_vertical_sse(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    use std::arch::x86_64::*;

    unsafe {
        let zero = _mm_setzero_si128();
        for i in (0..VERTICAL_BP128_BLOCK_SIZE).step_by(4) {
            _mm_storeu_si128(output[i..].as_mut_ptr() as *mut __m128i, zero);
        }

        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;

            let bytes = _mm_loadu_si128(input.as_ptr().add(byte_offset) as *const __m128i);

            let mut byte_array = [0u8; 16];
            _mm_storeu_si128(byte_array.as_mut_ptr() as *mut __m128i, bytes);

            for byte_idx in 0..16 {
                let byte_val = byte_array[byte_idx];
                let base_int = byte_idx * 8;

                if byte_val & 0x01 != 0 {
                    output[base_int] |= 1u32 << bit_pos;
                }
                if byte_val & 0x02 != 0 {
                    output[base_int + 1] |= 1u32 << bit_pos;
                }
                if byte_val & 0x04 != 0 {
                    output[base_int + 2] |= 1u32 << bit_pos;
                }
                if byte_val & 0x08 != 0 {
                    output[base_int + 3] |= 1u32 << bit_pos;
                }
                if byte_val & 0x10 != 0 {
                    output[base_int + 4] |= 1u32 << bit_pos;
                }
                if byte_val & 0x20 != 0 {
                    output[base_int + 5] |= 1u32 << bit_pos;
                }
                if byte_val & 0x40 != 0 {
                    output[base_int + 6] |= 1u32 << bit_pos;
                }
                if byte_val & 0x80 != 0 {
                    output[base_int + 7] |= 1u32 << bit_pos;
                }
            }
        }
    }
}

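/// Packs 128 values in the conventional horizontal layout: each value's
/// `bit_width` bits are written contiguously, LSB first, across byte
/// boundaries.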
#[allow(dead_code)]
pub fn pack_horizontal(
    values: &[u32; VERTICAL_BP128_BLOCK_SIZE],
    bit_width: u8,
    output: &mut Vec<u8>,
) {
    if bit_width == 0 {
        return;
    }

    let bytes_needed = (VERTICAL_BP128_BLOCK_SIZE * bit_width as usize).div_ceil(8);
    let start = output.len();
    output.resize(start + bytes_needed, 0);

    let mut bit_pos = 0usize;
    for &value in values {
        let byte_idx = start + bit_pos / 8;
        let bit_offset = bit_pos % 8;

        let mut remaining_bits = bit_width as usize;
        let mut val = value;
        let mut current_byte_idx = byte_idx;
        let mut current_bit_offset = bit_offset;

        while remaining_bits > 0 {
            let bits_in_byte = (8 - current_bit_offset).min(remaining_bits);
            let mask = ((1u32 << bits_in_byte) - 1) as u8;
            output[current_byte_idx] |= ((val as u8) & mask) << current_bit_offset;
            val >>= bits_in_byte;
            remaining_bits -= bits_in_byte;
            current_byte_idx += 1;
            current_bit_offset = 0;
        }

        bit_pos += bit_width as usize;
    }
}

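/// Unpacks a horizontally packed block (inverse of `pack_horizontal`) via
/// unaligned 64-bit reads.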
#[allow(dead_code)]
pub fn unpack_horizontal(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    if bit_width == 0 {
        output.fill(0);
        return;
    }

    let mask = (1u64 << bit_width) - 1;
    let bit_width_usize = bit_width as usize;
    let mut bit_pos = 0usize;
    let input_ptr = input.as_ptr();

    for out in output.iter_mut() {
        let byte_idx = bit_pos >> 3;
        let bit_offset = bit_pos & 7;

        // Fast path: unaligned 8-byte read. Near the end of the buffer, fall
        // back to copying the remaining bytes so we never read past `input`.
        let word = if byte_idx + 8 <= input.len() {
            u64::from_le(unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() })
        } else {
            let mut buf = [0u8; 8];
            let avail = input.len() - byte_idx;
            buf[..avail].copy_from_slice(&input[byte_idx..]);
            u64::from_le_bytes(buf)
        };

        *out = ((word >> bit_offset) & mask) as u32;
        bit_pos += bit_width_usize;
    }
}

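/// In-place inclusive prefix sum over a 128-value block seeded with
/// `first_val`, using the NEON kernel on aarch64 and scalar code elsewhere.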
#[allow(dead_code)]
pub fn prefix_sum_128(deltas: &mut [u32; VERTICAL_BP128_BLOCK_SIZE], first_val: u32) {
    #[cfg(target_arch = "aarch64")]
    {
        unsafe { neon::prefix_sum_block_neon(deltas, first_val) }
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        scalar::prefix_sum_block_scalar(deltas, first_val)
    }
}

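/// Unpacks `count` delta-encoded (d1) doc IDs: deltas are stored as
/// `gap - 1`, so each decoded ID is the previous one plus `delta + 1`,
/// starting from `first_doc_id`.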
pub fn unpack_vertical_d1(
    input: &[u8],
    bit_width: u8,
    first_doc_id: u32,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
    count: usize,
) {
    if count == 0 {
        return;
    }

    if bit_width == 0 {
        let mut current = first_doc_id;
        output[0] = current;
        for out_val in output.iter_mut().take(count).skip(1) {
            current = current.wrapping_add(1);
            *out_val = current;
        }
        return;
    }

    let mut deltas = [0u32; VERTICAL_BP128_BLOCK_SIZE];
    unpack_vertical(input, bit_width, &mut deltas);

    output[0] = first_doc_id;
    let mut current = first_doc_id;

    for i in 1..count {
        current = current.wrapping_add(deltas[i - 1]).wrapping_add(1);
        output[i] = current;
    }
}

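/// A single compressed block of up to 128 postings: vertically bit-packed
/// doc-ID deltas and term frequencies plus the block metadata used for
/// skipping and scoring.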
#[derive(Debug, Clone)]
pub struct VerticalBP128Block {
    /// Vertically packed doc-ID deltas (gap - 1 encoding).
    pub doc_data: Vec<u8>,
    /// Bit width used for `doc_data`.
    pub doc_bit_width: u8,
    /// Vertically packed term frequencies (stored as tf - 1).
    pub tf_data: Vec<u8>,
    /// Bit width used for `tf_data`.
    pub tf_bit_width: u8,
    /// First doc ID in the block, stored uncompressed.
    pub first_doc_id: u32,
    /// Last doc ID in the block, used when skipping blocks.
    pub last_doc_id: u32,
    /// Number of postings in the block (at most 128).
    pub num_docs: u16,
    /// Maximum term frequency in the block.
    pub max_tf: u32,
    /// BM25 upper bound for any document in the block.
    pub max_block_score: f32,
}

impl VerticalBP128Block {
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
        writer.write_u16::<LittleEndian>(self.num_docs)?;
        writer.write_u8(self.doc_bit_width)?;
        writer.write_u8(self.tf_bit_width)?;
        writer.write_u32::<LittleEndian>(self.max_tf)?;
        writer.write_f32::<LittleEndian>(self.max_block_score)?;

        writer.write_u16::<LittleEndian>(self.doc_data.len() as u16)?;
        writer.write_all(&self.doc_data)?;

        writer.write_u16::<LittleEndian>(self.tf_data.len() as u16)?;
        writer.write_all(&self.tf_data)?;

        Ok(())
    }

    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc_id = reader.read_u32::<LittleEndian>()?;
        let last_doc_id = reader.read_u32::<LittleEndian>()?;
        let num_docs = reader.read_u16::<LittleEndian>()?;
        let doc_bit_width = reader.read_u8()?;
        let tf_bit_width = reader.read_u8()?;
        let max_tf = reader.read_u32::<LittleEndian>()?;
        let max_block_score = reader.read_f32::<LittleEndian>()?;

        let doc_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut doc_data = vec![0u8; doc_len];
        reader.read_exact(&mut doc_data)?;

        let tf_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut tf_data = vec![0u8; tf_len];
        reader.read_exact(&mut tf_data)?;

        Ok(Self {
            doc_data,
            doc_bit_width,
            tf_data,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs,
            max_tf,
            max_block_score,
        })
    }

    pub fn decode_doc_ids(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let mut output = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        unpack_vertical_d1(
            &self.doc_data,
            self.doc_bit_width,
            self.first_doc_id,
            &mut output,
            self.num_docs as usize,
        );

        output[..self.num_docs as usize].to_vec()
    }

    pub fn decode_term_freqs(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let mut output = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        unpack_vertical(&self.tf_data, self.tf_bit_width, &mut output);

        output[..self.num_docs as usize]
            .iter()
            .map(|&tf| tf + 1)
            .collect()
    }
}

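/// A full posting list compressed as a sequence of `VerticalBP128Block`s,
/// carrying the overall doc count and the maximum block score.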
#[derive(Debug, Clone)]
pub struct VerticalBP128PostingList {
    pub blocks: Vec<VerticalBP128Block>,
    pub doc_count: u32,
    pub max_score: f32,
}

impl VerticalBP128PostingList {
    const K1: f32 = 1.2;
    const B: f32 = 0.75;

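    /// Upper bound on the BM25 contribution of this term for any document,
    /// obtained by plugging `max_tf` into the TF saturation formula with the
    /// most favourable length normalization (1 - B).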
    #[inline]
    pub fn compute_bm25_upper_bound(max_tf: u32, idf: f32) -> f32 {
        let tf = max_tf as f32;
        let min_length_norm = 1.0 - Self::B;
        let tf_norm = (tf * (Self::K1 + 1.0)) / (tf + Self::K1 * min_length_norm);
        idf * tf_norm
    }

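    /// Builds a compressed posting list from parallel, strictly increasing
    /// `doc_ids` and their `term_freqs`, splitting the input into blocks of
    /// up to 128 postings.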
    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
        assert_eq!(doc_ids.len(), term_freqs.len());

        if doc_ids.is_empty() {
            return Self {
                blocks: Vec::new(),
                doc_count: 0,
                max_score: 0.0,
            };
        }

        let mut blocks = Vec::new();
        let mut max_score = 0.0f32;
        let mut i = 0;

        while i < doc_ids.len() {
            let block_end = (i + VERTICAL_BP128_BLOCK_SIZE).min(doc_ids.len());
            let block_docs = &doc_ids[i..block_end];
            let block_tfs = &term_freqs[i..block_end];

            let block = Self::create_block(block_docs, block_tfs, idf);
            max_score = max_score.max(block.max_block_score);
            blocks.push(block);

            i = block_end;
        }

        Self {
            blocks,
            doc_count: doc_ids.len() as u32,
            max_score,
        }
    }

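    /// Encodes one block: doc IDs become `gap - 1` deltas, term frequencies
    /// are stored as `tf - 1`, and both are vertically packed at the minimum
    /// bit width needed for their largest value.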
    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> VerticalBP128Block {
        let num_docs = doc_ids.len();
        let first_doc_id = doc_ids[0];
        let last_doc_id = *doc_ids.last().unwrap();

        let mut deltas = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        let mut max_delta = 0u32;
        for j in 1..num_docs {
            let delta = doc_ids[j] - doc_ids[j - 1] - 1;
            deltas[j - 1] = delta;
            max_delta = max_delta.max(delta);
        }

        let mut tfs = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        let mut max_tf = 0u32;
        for (j, &tf) in term_freqs.iter().enumerate() {
            tfs[j] = tf.saturating_sub(1);
            max_tf = max_tf.max(tf);
        }

        let doc_bit_width = bits_needed(max_delta);
        let tf_bit_width = bits_needed(max_tf.saturating_sub(1));

        let mut doc_data = Vec::new();
        pack_vertical(&deltas, doc_bit_width, &mut doc_data);

        let mut tf_data = Vec::new();
        pack_vertical(&tfs, tf_bit_width, &mut tf_data);

        let max_block_score = Self::compute_bm25_upper_bound(max_tf, idf);

        VerticalBP128Block {
            doc_data,
            doc_bit_width,
            tf_data,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs: num_docs as u16,
            max_tf,
            max_block_score,
        }
    }

    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_f32::<LittleEndian>(self.max_score)?;
        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;

        for block in &self.blocks {
            block.serialize(writer)?;
        }

        Ok(())
    }

    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let max_score = reader.read_f32::<LittleEndian>()?;
        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;

        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            blocks.push(VerticalBP128Block::deserialize(reader)?);
        }

        Ok(Self {
            blocks,
            doc_count,
            max_score,
        })
    }

    pub fn iterator(&self) -> VerticalBP128Iterator<'_> {
        VerticalBP128Iterator::new(self)
    }

    pub fn size_bytes(&self) -> usize {
        // List header: doc_count (4) + max_score (4) + block count (4).
        let mut size = 12;
        for block in &self.blocks {
            // Per-block header (20 bytes) plus the two u16 length prefixes.
            size += 24 + block.doc_data.len() + block.tf_data.len();
        }
        size
    }
}

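/// Cursor over a `VerticalBP128PostingList` that decodes one block at a
/// time and exposes `doc()` / `term_freq()` at the current position, with
/// `advance()` and `seek()` for traversal.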
pub struct VerticalBP128Iterator<'a> {
    list: &'a VerticalBP128PostingList,
    current_block: usize,
    block_doc_ids: Vec<u32>,
    block_term_freqs: Vec<u32>,
    pos_in_block: usize,
    exhausted: bool,
}

impl<'a> VerticalBP128Iterator<'a> {
    pub fn new(list: &'a VerticalBP128PostingList) -> Self {
        let mut iter = Self {
            list,
            current_block: 0,
            block_doc_ids: Vec::new(),
            block_term_freqs: Vec::new(),
            pos_in_block: 0,
            exhausted: list.blocks.is_empty(),
        };

        if !iter.exhausted {
            iter.decode_current_block();
        }

        iter
    }

    fn decode_current_block(&mut self) {
        let block = &self.list.blocks[self.current_block];
        self.block_doc_ids = block.decode_doc_ids();
        self.block_term_freqs = block.decode_term_freqs();
        self.pos_in_block = 0;
    }

    pub fn doc(&self) -> u32 {
        if self.exhausted {
            u32::MAX
        } else {
            self.block_doc_ids[self.pos_in_block]
        }
    }

    pub fn term_freq(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.block_term_freqs[self.pos_in_block]
        }
    }

    pub fn advance(&mut self) -> u32 {
        if self.exhausted {
            return u32::MAX;
        }

        self.pos_in_block += 1;

        if self.pos_in_block >= self.block_doc_ids.len() {
            self.current_block += 1;
            if self.current_block >= self.list.blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }

        self.doc()
    }

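    /// Advances to the first posting with doc ID >= `target` (never moves
    /// backwards) and returns it, or `u32::MAX` if the list is exhausted.
    /// Whole blocks are skipped using their first/last doc-ID bounds.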
    pub fn seek(&mut self, target: u32) -> u32 {
        if self.exhausted {
            return u32::MAX;
        }

        let block_idx = self.list.blocks[self.current_block..].binary_search_by(|block| {
            if block.last_doc_id < target {
                std::cmp::Ordering::Less
            } else if block.first_doc_id > target {
                std::cmp::Ordering::Greater
            } else {
                std::cmp::Ordering::Equal
            }
        });

        let target_block = match block_idx {
            Ok(idx) => self.current_block + idx,
            Err(idx) => {
                if self.current_block + idx >= self.list.blocks.len() {
                    self.exhausted = true;
                    return u32::MAX;
                }
                self.current_block + idx
            }
        };

        if target_block != self.current_block {
            self.current_block = target_block;
            self.decode_current_block();
        }

        let pos = self.block_doc_ids[self.pos_in_block..]
            .binary_search(&target)
            .unwrap_or_else(|x| x);
        self.pos_in_block += pos;

        if self.pos_in_block >= self.block_doc_ids.len() {
            self.current_block += 1;
            if self.current_block >= self.list.blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }

        self.doc()
    }

    pub fn max_remaining_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.list.blocks[self.current_block..]
            .iter()
            .map(|b| b.max_block_score)
            .fold(0.0f32, |a, b| a.max(b))
    }

    pub fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            0.0
        } else {
            self.list.blocks[self.current_block].max_block_score
        }
    }

    pub fn current_block_max_tf(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.list.blocks[self.current_block].max_tf
        }
    }

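    /// Skips forward to the first block whose `last_doc_id` is >= `target`,
    /// decodes it, and returns its `(first_doc_id, max_block_score)`;
    /// returns `None` (and marks the iterator exhausted) if no block remains.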
    pub fn skip_to_block_with_doc(&mut self, target: u32) -> Option<(u32, f32)> {
        while self.current_block < self.list.blocks.len() {
            let block = &self.list.blocks[self.current_block];
            if block.last_doc_id >= target {
                self.decode_current_block();
                return Some((block.first_doc_id, block.max_block_score));
            }
            self.current_block += 1;
        }
        self.exhausted = true;
        None
    }

    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pack_unpack_vertical() {
        let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = (i * 3) as u32;
        }

        let max_val = values.iter().max().copied().unwrap();
        let bit_width = bits_needed(max_val);

        let mut packed = Vec::new();
        pack_vertical(&values, bit_width, &mut packed);

        let mut unpacked = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        unpack_vertical(&packed, bit_width, &mut unpacked);

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_pack_unpack_vertical_various_widths() {
        for bit_width in 1..=20 {
            let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            let max_val = (1u32 << bit_width) - 1;
            for (i, v) in values.iter_mut().enumerate() {
                *v = (i as u32) % (max_val + 1);
            }

            let mut packed = Vec::new();
            pack_vertical(&values, bit_width, &mut packed);

            let mut unpacked = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            unpack_vertical(&packed, bit_width, &mut unpacked);

            assert_eq!(values, unpacked, "Failed for bit_width={}", bit_width);
        }
    }

    #[test]
    fn test_simd_bp128_posting_list() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 10) + 1).collect();

        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        assert_eq!(list.doc_count, 200);
        assert_eq!(list.blocks.len(), 2);

        let mut iter = list.iterator();
        for (i, &expected_doc) in doc_ids.iter().enumerate() {
            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
            assert_eq!(iter.term_freq(), term_freqs[i], "TF mismatch at {}", i);
            if i < doc_ids.len() - 1 {
                iter.advance();
            }
        }
    }

    #[test]
    fn test_simd_bp128_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = list.iterator();

        assert_eq!(iter.seek(25), 30);
        assert_eq!(iter.seek(100), 100);
        assert_eq!(iter.seek(500), 1000);
        assert_eq!(iter.seek(3000), u32::MAX);
    }

    #[test]
    fn test_simd_bp128_serialization() {
        let doc_ids: Vec<u32> = (0..300).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..300).map(|i| (i % 5) + 1).collect();

        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.5);

        let mut buffer = Vec::new();
        list.serialize(&mut buffer).unwrap();

        let restored = VerticalBP128PostingList::deserialize(&mut &buffer[..]).unwrap();

        assert_eq!(restored.doc_count, list.doc_count);
        assert_eq!(restored.blocks.len(), list.blocks.len());

        let mut iter1 = list.iterator();
        let mut iter2 = restored.iterator();

        while iter1.doc() != u32::MAX {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
    }

    #[test]
    fn test_vertical_layout_size() {
        let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = i as u32;
        }

        let bit_width = bits_needed(127);
        assert_eq!(bit_width, 7);

        let mut packed = Vec::new();
        pack_horizontal(&values, bit_width, &mut packed);

        let expected_bytes = (VERTICAL_BP128_BLOCK_SIZE * bit_width as usize) / 8;
        assert_eq!(expected_bytes, 112);
        assert_eq!(packed.len(), expected_bytes);
    }

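    // Added example (not part of the original tests): a minimal sketch that
    // round-trips delta-encoded doc IDs through `pack_vertical` and
    // `unpack_vertical_d1`, mirroring what `create_block` / `decode_doc_ids`
    // do per block.
    #[test]
    fn test_pack_unpack_vertical_d1_roundtrip() {
        // Strictly increasing doc IDs with a constant gap of 7.
        let doc_ids: Vec<u32> = (0..VERTICAL_BP128_BLOCK_SIZE as u32)
            .map(|i| i * 7 + 3)
            .collect();

        // d1 encoding stores `gap - 1`, so consecutive IDs pack as zeros.
        let mut deltas = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        let mut max_delta = 0u32;
        for j in 1..doc_ids.len() {
            let delta = doc_ids[j] - doc_ids[j - 1] - 1;
            deltas[j - 1] = delta;
            max_delta = max_delta.max(delta);
        }

        let bit_width = bits_needed(max_delta);
        let mut packed = Vec::new();
        pack_vertical(&deltas, bit_width, &mut packed);

        let mut decoded = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        unpack_vertical_d1(&packed, bit_width, doc_ids[0], &mut decoded, doc_ids.len());

        assert_eq!(&decoded[..doc_ids.len()], &doc_ids[..]);
    }
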
    #[test]
    fn test_simd_bp128_block_max() {
        let doc_ids: Vec<u32> = (0..500).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..500)
            .map(|i| {
                if i < 128 {
                    1
                } else if i < 256 {
                    5
                } else if i < 384 {
                    10
                } else {
                    3
                }
            })
            .collect();

        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 2.0);

        assert_eq!(list.blocks.len(), 4);
        assert_eq!(list.blocks[0].max_tf, 1);
        assert_eq!(list.blocks[1].max_tf, 5);
        assert_eq!(list.blocks[2].max_tf, 10);
        assert_eq!(list.blocks[3].max_tf, 3);

        assert!(list.blocks[2].max_block_score > list.blocks[0].max_block_score);
        assert!(list.blocks[2].max_block_score > list.blocks[1].max_block_score);
        assert!(list.blocks[2].max_block_score > list.blocks[3].max_block_score);

        assert_eq!(list.max_score, list.blocks[2].max_block_score);

        let mut iter = list.iterator();
        assert_eq!(iter.current_block_max_tf(), 1);
        iter.seek(256);
        assert_eq!(iter.current_block_max_tf(), 5);

        iter.seek(512);
        assert_eq!(iter.current_block_max_tf(), 10);

        let mut iter2 = list.iterator();
        let result = iter2.skip_to_block_with_doc(300);
        assert!(result.is_some());
        let (first_doc, score) = result.unwrap();
        assert!(first_doc <= 300);
        assert!(score > 0.0);
    }
}