use super::simd;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Read, Write};

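/// Number of values in one vertically packed block.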
pub const VERTICAL_BP128_BLOCK_SIZE: usize = 128;

/// Lanes in a 128-bit SIMD register of `u32`s.
#[allow(dead_code)]
const SIMD_LANES: usize = 4;

/// Number of 4-lane SIMD groups per block.
#[allow(dead_code)]
const GROUPS_PER_BLOCK: usize = VERTICAL_BP128_BLOCK_SIZE / SIMD_LANES;

/// NEON-accelerated kernels (aarch64 only).
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
mod neon {
    use super::*;
    use std::arch::aarch64::*;

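    /// Maps a byte to eight `u32` lanes, one per bit:
    /// `BIT_EXPAND_LUT[b][i] == (b >> i) & 1`. Evaluated at compile time.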
    static BIT_EXPAND_LUT: [[u32; 8]; 256] = {
        let mut lut = [[0u32; 8]; 256];
        let mut byte = 0usize;
        while byte < 256 {
            let mut bit = 0;
            while bit < 8 {
                lut[byte][bit] = ((byte >> bit) & 1) as u32;
                bit += 1;
            }
            byte += 1;
        }
        lut
    };

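    /// Unpacks four `bit_width`-bit values stored contiguously little-endian
    /// at the start of `input` into four `u32` lanes.
    ///
    /// # Safety
    /// Requires NEON. `input` must hold at least `ceil(bit_width * 4 / 8)` bytes.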
    #[inline]
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_4_neon(input: &[u8], bit_width: u8, output: &mut [u32; 4]) {
        if bit_width == 0 {
            *output = [0; 4];
            return;
        }

        let mask = (1u32 << bit_width) - 1;

        let mut packed_bytes = [0u8; 16];
        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        packed_bytes[..bytes_needed.min(16)].copy_from_slice(&input[..bytes_needed.min(16)]);
        let packed = u128::from_le_bytes(packed_bytes);

        let v0 = (packed & mask as u128) as u32;
        let v1 = ((packed >> bit_width) & mask as u128) as u32;
        let v2 = ((packed >> (bit_width * 2)) & mask as u128) as u32;
        let v3 = ((packed >> (bit_width * 3)) & mask as u128) as u32;

        unsafe {
            let result = vld1q_u32([v0, v1, v2, v3].as_ptr());
            vst1q_u32(output.as_mut_ptr(), result);
        }
    }

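    /// In-place inclusive prefix sum over four lanes via the two-step
    /// shift-and-add network.
    ///
    /// # Safety
    /// Requires NEON.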
    #[inline]
    #[target_feature(enable = "neon")]
    pub unsafe fn prefix_sum_4_neon(values: &mut [u32; 4]) {
        unsafe {
            let mut v = vld1q_u32(values.as_ptr());

            // Shift one lane of zeros in: [a, b, c, d] + [0, a, b, c].
            let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3);
            v = vaddq_u32(v, shifted1);
            // Shift two lanes in: yields [a, a+b, a+b+c, a+b+c+d].
            let shifted2 = vextq_u32(vdupq_n_u32(0), v, 2);
            v = vaddq_u32(v, shifted2);

            vst1q_u32(values.as_mut_ptr(), v);
        }
    }

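    /// Unpacks one 128-value block from the vertical bit-plane layout: bit
    /// `p` of `output[byte_idx * 8 + k]` is bit `k` of `input[p * 16 + byte_idx]`.
    /// Each byte is expanded through `BIT_EXPAND_LUT` and merged branch-free
    /// with a multiply by `1 << p`.
    ///
    /// # Safety
    /// Requires NEON. `input` must hold at least `bit_width * 16` bytes.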
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_block_neon(
        input: &[u8],
        bit_width: u8,
        output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
    ) {
        if bit_width == 0 {
            output.fill(0);
            return;
        }

        // Zero the output with vector stores.
        unsafe {
            let zero = vdupq_n_u32(0);
            for i in (0..VERTICAL_BP128_BLOCK_SIZE).step_by(4) {
                vst1q_u32(output[i..].as_mut_ptr(), zero);
            }
        }

        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;
            let bit_mask = 1u32 << bit_pos;

            // Prefetch the next 16-byte bit plane while this one is processed.
            if bit_pos + 1 < bit_width as usize {
                let next_offset = (bit_pos + 1) * 16;
                unsafe {
                    std::arch::asm!(
                        "prfm pldl1keep, [{0}]",
                        in(reg) input.as_ptr().add(next_offset),
                        options(nostack, preserves_flags)
                    );
                }
            }

            for chunk in 0..4 {
                let chunk_offset = byte_offset + chunk * 4;

                let b0 = input[chunk_offset] as usize;
                let b1 = input[chunk_offset + 1] as usize;
                let b2 = input[chunk_offset + 2] as usize;
                let b3 = input[chunk_offset + 3] as usize;

                let base_int = chunk * 32;

                unsafe {
                    let mask_vec = vdupq_n_u32(bit_mask);

                    // Expand each byte to eight 0/1 lanes, scale the set lanes
                    // to `1 << bit_pos`, and OR into the accumulated output.
                    let lut0 = &BIT_EXPAND_LUT[b0];
                    let bits_0_3 = vld1q_u32(lut0.as_ptr());
                    let bits_4_7 = vld1q_u32(lut0[4..].as_ptr());

                    let shifted_0_3 = vmulq_u32(bits_0_3, mask_vec);
                    let shifted_4_7 = vmulq_u32(bits_4_7, mask_vec);

                    let cur_0_3 = vld1q_u32(output[base_int..].as_ptr());
                    let cur_4_7 = vld1q_u32(output[base_int + 4..].as_ptr());

                    vst1q_u32(
                        output[base_int..].as_mut_ptr(),
                        vorrq_u32(cur_0_3, shifted_0_3),
                    );
                    vst1q_u32(
                        output[base_int + 4..].as_mut_ptr(),
                        vorrq_u32(cur_4_7, shifted_4_7),
                    );

                    let lut1 = &BIT_EXPAND_LUT[b1];
                    let bits_8_11 = vld1q_u32(lut1.as_ptr());
                    let bits_12_15 = vld1q_u32(lut1[4..].as_ptr());

                    let shifted_8_11 = vmulq_u32(bits_8_11, mask_vec);
                    let shifted_12_15 = vmulq_u32(bits_12_15, mask_vec);

                    let cur_8_11 = vld1q_u32(output[base_int + 8..].as_ptr());
                    let cur_12_15 = vld1q_u32(output[base_int + 12..].as_ptr());

                    vst1q_u32(
                        output[base_int + 8..].as_mut_ptr(),
                        vorrq_u32(cur_8_11, shifted_8_11),
                    );
                    vst1q_u32(
                        output[base_int + 12..].as_mut_ptr(),
                        vorrq_u32(cur_12_15, shifted_12_15),
                    );

                    let lut2 = &BIT_EXPAND_LUT[b2];
                    let bits_16_19 = vld1q_u32(lut2.as_ptr());
                    let bits_20_23 = vld1q_u32(lut2[4..].as_ptr());

                    let shifted_16_19 = vmulq_u32(bits_16_19, mask_vec);
                    let shifted_20_23 = vmulq_u32(bits_20_23, mask_vec);

                    let cur_16_19 = vld1q_u32(output[base_int + 16..].as_ptr());
                    let cur_20_23 = vld1q_u32(output[base_int + 20..].as_ptr());

                    vst1q_u32(
                        output[base_int + 16..].as_mut_ptr(),
                        vorrq_u32(cur_16_19, shifted_16_19),
                    );
                    vst1q_u32(
                        output[base_int + 20..].as_mut_ptr(),
                        vorrq_u32(cur_20_23, shifted_20_23),
                    );

                    let lut3 = &BIT_EXPAND_LUT[b3];
                    let bits_24_27 = vld1q_u32(lut3.as_ptr());
                    let bits_28_31 = vld1q_u32(lut3[4..].as_ptr());

                    let shifted_24_27 = vmulq_u32(bits_24_27, mask_vec);
                    let shifted_28_31 = vmulq_u32(bits_28_31, mask_vec);

                    let cur_24_27 = vld1q_u32(output[base_int + 24..].as_ptr());
                    let cur_28_31 = vld1q_u32(output[base_int + 28..].as_ptr());

                    vst1q_u32(
                        output[base_int + 24..].as_mut_ptr(),
                        vorrq_u32(cur_24_27, shifted_24_27),
                    );
                    vst1q_u32(
                        output[base_int + 28..].as_mut_ptr(),
                        vorrq_u32(cur_28_31, shifted_28_31),
                    );
                }
            }
        }
    }

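    /// Inclusive prefix sum over a whole block: seeds the first lane with
    /// `first_val` and carries each group's last lane into the next group.
    ///
    /// # Safety
    /// Requires NEON.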
    #[target_feature(enable = "neon")]
    pub unsafe fn prefix_sum_block_neon(
        deltas: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
        first_val: u32,
    ) {
        let mut carry = first_val;

        for group in 0..GROUPS_PER_BLOCK {
            let start = group * SIMD_LANES;
            let mut group_vals = [
                deltas[start],
                deltas[start + 1],
                deltas[start + 2],
                deltas[start + 3],
            ];

            group_vals[0] = group_vals[0].wrapping_add(carry);

            unsafe { prefix_sum_4_neon(&mut group_vals) };

            deltas[start..start + 4].copy_from_slice(&group_vals);

            carry = group_vals[3];
        }
    }
}

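/// Portable scalar fallbacks mirroring the NEON routines above.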
#[allow(dead_code)]
mod scalar {
    use super::*;

    /// Packs four values at `bit_width` bits each, little-endian, into `output`.
    #[inline]
    pub fn pack_4_scalar(values: &[u32; 4], bit_width: u8, output: &mut [u8]) {
        if bit_width == 0 {
            return;
        }

        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        let mut packed = 0u128;
        for (i, &val) in values.iter().enumerate() {
            packed |= (val as u128) << (i * bit_width as usize);
        }

        let packed_bytes = packed.to_le_bytes();
        output[..bytes_needed].copy_from_slice(&packed_bytes[..bytes_needed]);
    }

    /// Scalar counterpart of `unpack_4_neon`.
    #[inline]
    pub fn unpack_4_scalar(input: &[u8], bit_width: u8, output: &mut [u32; 4]) {
        if bit_width == 0 {
            *output = [0; 4];
            return;
        }

        let mask = (1u32 << bit_width) - 1;
        let mut packed_bytes = [0u8; 16];
        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        packed_bytes[..bytes_needed.min(16)].copy_from_slice(&input[..bytes_needed.min(16)]);
        let packed = u128::from_le_bytes(packed_bytes);

        output[0] = (packed & mask as u128) as u32;
        output[1] = ((packed >> bit_width) & mask as u128) as u32;
        output[2] = ((packed >> (bit_width * 2)) & mask as u128) as u32;
        output[3] = ((packed >> (bit_width * 3)) & mask as u128) as u32;
    }

    /// In-place inclusive prefix sum over four values.
    #[inline]
    pub fn prefix_sum_4_scalar(vals: &mut [u32; 4]) {
        vals[1] = vals[1].wrapping_add(vals[0]);
        vals[2] = vals[2].wrapping_add(vals[1]);
        vals[3] = vals[3].wrapping_add(vals[2]);
    }

    /// Scalar counterpart of `unpack_block_neon`: decodes the vertical
    /// bit-plane layout one byte at a time.
    pub fn unpack_block_scalar(
        input: &[u8],
        bit_width: u8,
        output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
    ) {
        if bit_width == 0 {
            output.fill(0);
            return;
        }

        output.fill(0);

        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;
            for byte_idx in 0..16 {
                let byte_val = input[byte_offset + byte_idx];
                let base_int = byte_idx * 8;

                output[base_int] |= (byte_val & 1) as u32 * (1 << bit_pos);
                output[base_int + 1] |= ((byte_val >> 1) & 1) as u32 * (1 << bit_pos);
                output[base_int + 2] |= ((byte_val >> 2) & 1) as u32 * (1 << bit_pos);
                output[base_int + 3] |= ((byte_val >> 3) & 1) as u32 * (1 << bit_pos);
                output[base_int + 4] |= ((byte_val >> 4) & 1) as u32 * (1 << bit_pos);
                output[base_int + 5] |= ((byte_val >> 5) & 1) as u32 * (1 << bit_pos);
                output[base_int + 6] |= ((byte_val >> 6) & 1) as u32 * (1 << bit_pos);
                output[base_int + 7] |= ((byte_val >> 7) & 1) as u32 * (1 << bit_pos);
            }
        }
    }

    /// Scalar counterpart of `prefix_sum_block_neon`.
    pub fn prefix_sum_block_scalar(deltas: &mut [u32; VERTICAL_BP128_BLOCK_SIZE], first_val: u32) {
        let mut carry = first_val;

        for group in 0..GROUPS_PER_BLOCK {
            let start = group * SIMD_LANES;
            let mut group_vals = [
                deltas[start],
                deltas[start + 1],
                deltas[start + 2],
                deltas[start + 3],
            ];

            group_vals[0] = group_vals[0].wrapping_add(carry);
            prefix_sum_4_scalar(&mut group_vals);
            deltas[start..start + 4].copy_from_slice(&group_vals);
            carry = group_vals[3];
        }
    }
}

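/// Packs 128 values into the vertical (bit-plane) layout and appends
/// `16 * bit_width` bytes to `output`: plane `p` occupies 16 consecutive
/// bytes, and byte `byte_idx` of plane `p` collects bit `p` from
/// `values[byte_idx * 8 .. byte_idx * 8 + 8]`, one value per bit position.
/// Values must fit in `bit_width` bits.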
pub fn pack_vertical(
    values: &[u32; VERTICAL_BP128_BLOCK_SIZE],
    bit_width: u8,
    output: &mut Vec<u8>,
) {
    if bit_width == 0 {
        return;
    }

    let total_bytes = 16 * bit_width as usize;
    let start = output.len();
    output.resize(start + total_bytes, 0);

    for bit_pos in 0..bit_width as usize {
        let byte_offset = start + bit_pos * 16;

        for byte_idx in 0..16 {
            let base_int = byte_idx * 8;
            let mut byte_val = 0u8;

            byte_val |= ((values[base_int] >> bit_pos) & 1) as u8;
            byte_val |= (((values[base_int + 1] >> bit_pos) & 1) as u8) << 1;
            byte_val |= (((values[base_int + 2] >> bit_pos) & 1) as u8) << 2;
            byte_val |= (((values[base_int + 3] >> bit_pos) & 1) as u8) << 3;
            byte_val |= (((values[base_int + 4] >> bit_pos) & 1) as u8) << 4;
            byte_val |= (((values[base_int + 5] >> bit_pos) & 1) as u8) << 5;
            byte_val |= (((values[base_int + 6] >> bit_pos) & 1) as u8) << 6;
            byte_val |= (((values[base_int + 7] >> bit_pos) & 1) as u8) << 7;

            output[byte_offset + byte_idx] = byte_val;
        }
    }
}

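/// Unpacks one vertical-layout block, dispatching to NEON on aarch64, SSE2 on
/// x86_64 when available, and the scalar path otherwise. `input` must hold at
/// least `16 * bit_width` bytes.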
pub fn unpack_vertical(input: &[u8], bit_width: u8, output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE]) {
    if bit_width == 0 {
        output.fill(0);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        unsafe { unpack_vertical_neon(input, bit_width, output) }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("sse2") {
            unsafe { unpack_vertical_sse(input, bit_width, output) }
        } else {
            unpack_vertical_scalar(input, bit_width, output)
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        unpack_vertical_scalar(input, bit_width, output)
    }
}

/// Portable bit-plane decoder used when no SIMD path is available.
#[inline]
fn unpack_vertical_scalar(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    output.fill(0);

    for bit_pos in 0..bit_width as usize {
        let byte_offset = bit_pos * 16;
        let bit_mask = 1u32 << bit_pos;

        for byte_idx in 0..16 {
            let byte_val = input[byte_offset + byte_idx];
            let base_int = byte_idx * 8;

            if byte_val & 0x01 != 0 {
                output[base_int] |= bit_mask;
            }
            if byte_val & 0x02 != 0 {
                output[base_int + 1] |= bit_mask;
            }
            if byte_val & 0x04 != 0 {
                output[base_int + 2] |= bit_mask;
            }
            if byte_val & 0x08 != 0 {
                output[base_int + 3] |= bit_mask;
            }
            if byte_val & 0x10 != 0 {
                output[base_int + 4] |= bit_mask;
            }
            if byte_val & 0x20 != 0 {
                output[base_int + 5] |= bit_mask;
            }
            if byte_val & 0x40 != 0 {
                output[base_int + 6] |= bit_mask;
            }
            if byte_val & 0x80 != 0 {
                output[base_int + 7] |= bit_mask;
            }
        }
    }
}

/// NEON bit-plane decoder: loads each 16-byte plane with a vector load, then
/// expands bytes with branch-free multiply-by-mask updates.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn unpack_vertical_neon(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    use std::arch::aarch64::*;

    unsafe {
        let zero = vdupq_n_u32(0);
        for i in (0..VERTICAL_BP128_BLOCK_SIZE).step_by(4) {
            vst1q_u32(output[i..].as_mut_ptr(), zero);
        }

        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;
            let bit_mask = 1u32 << bit_pos;

            let bytes = vld1q_u8(input.as_ptr().add(byte_offset));

            let mut byte_array = [0u8; 16];
            vst1q_u8(byte_array.as_mut_ptr(), bytes);

            for (byte_idx, &byte_val) in byte_array.iter().enumerate() {
                let base_int = byte_idx * 8;

                output[base_int] |= ((byte_val & 0x01) as u32) * bit_mask;
                output[base_int + 1] |= (((byte_val >> 1) & 0x01) as u32) * bit_mask;
                output[base_int + 2] |= (((byte_val >> 2) & 0x01) as u32) * bit_mask;
                output[base_int + 3] |= (((byte_val >> 3) & 0x01) as u32) * bit_mask;
                output[base_int + 4] |= (((byte_val >> 4) & 0x01) as u32) * bit_mask;
                output[base_int + 5] |= (((byte_val >> 5) & 0x01) as u32) * bit_mask;
                output[base_int + 6] |= (((byte_val >> 6) & 0x01) as u32) * bit_mask;
                output[base_int + 7] |= (((byte_val >> 7) & 0x01) as u32) * bit_mask;
            }
        }
    }
}

/// SSE2 bit-plane decoder for x86_64.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn unpack_vertical_sse(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    use std::arch::x86_64::*;

    unsafe {
        let zero = _mm_setzero_si128();
        for i in (0..VERTICAL_BP128_BLOCK_SIZE).step_by(4) {
            _mm_storeu_si128(output[i..].as_mut_ptr() as *mut __m128i, zero);
        }

        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;

            let bytes = _mm_loadu_si128(input.as_ptr().add(byte_offset) as *const __m128i);

            let mut byte_array = [0u8; 16];
            _mm_storeu_si128(byte_array.as_mut_ptr() as *mut __m128i, bytes);

            for (byte_idx, &byte_val) in byte_array.iter().enumerate() {
                let base_int = byte_idx * 8;

                if byte_val & 0x01 != 0 {
                    output[base_int] |= 1u32 << bit_pos;
                }
                if byte_val & 0x02 != 0 {
                    output[base_int + 1] |= 1u32 << bit_pos;
                }
                if byte_val & 0x04 != 0 {
                    output[base_int + 2] |= 1u32 << bit_pos;
                }
                if byte_val & 0x08 != 0 {
                    output[base_int + 3] |= 1u32 << bit_pos;
                }
                if byte_val & 0x10 != 0 {
                    output[base_int + 4] |= 1u32 << bit_pos;
                }
                if byte_val & 0x20 != 0 {
                    output[base_int + 5] |= 1u32 << bit_pos;
                }
                if byte_val & 0x40 != 0 {
                    output[base_int + 6] |= 1u32 << bit_pos;
                }
                if byte_val & 0x80 != 0 {
                    output[base_int + 7] |= 1u32 << bit_pos;
                }
            }
        }
    }
}

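/// Packs 128 values in the conventional horizontal layout: each value
/// occupies `bit_width` consecutive bits, values laid end to end
/// little-endian.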
#[allow(dead_code)]
pub fn pack_horizontal(
    values: &[u32; VERTICAL_BP128_BLOCK_SIZE],
    bit_width: u8,
    output: &mut Vec<u8>,
) {
    if bit_width == 0 {
        return;
    }

    let bytes_needed = (VERTICAL_BP128_BLOCK_SIZE * bit_width as usize).div_ceil(8);
    let start = output.len();
    output.resize(start + bytes_needed, 0);

    let mut bit_pos = 0usize;
    for &value in values {
        let byte_idx = start + bit_pos / 8;
        let bit_offset = bit_pos % 8;

        let mut remaining_bits = bit_width as usize;
        let mut val = value;
        let mut current_byte_idx = byte_idx;
        let mut current_bit_offset = bit_offset;

        while remaining_bits > 0 {
            let bits_in_byte = (8 - current_bit_offset).min(remaining_bits);
            let mask = ((1u32 << bits_in_byte) - 1) as u8;
            output[current_byte_idx] |= ((val as u8) & mask) << current_bit_offset;
            val >>= bits_in_byte;
            remaining_bits -= bits_in_byte;
            current_byte_idx += 1;
            current_bit_offset = 0;
        }

        bit_pos += bit_width as usize;
    }
}

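/// Unpacks a horizontally packed block; the mirror of [`pack_horizontal`].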
#[allow(dead_code)]
pub fn unpack_horizontal(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    if bit_width == 0 {
        output.fill(0);
        return;
    }

    let mask = (1u64 << bit_width) - 1;
    let bit_width_usize = bit_width as usize;
    let mut bit_pos = 0usize;
    let input_ptr = input.as_ptr();

    for out in output.iter_mut() {
        let byte_idx = bit_pos >> 3;
        let bit_offset = bit_pos & 7;

        // Fast path: unaligned 8-byte read. Near the end of `input` that read
        // would run past the slice, so fall back to a zero-padded copy.
        let word = if byte_idx + 8 <= input.len() {
            u64::from_le(unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() })
        } else {
            let mut buf = [0u8; 8];
            let avail = input.len() - byte_idx;
            buf[..avail].copy_from_slice(&input[byte_idx..]);
            u64::from_le_bytes(buf)
        };

        *out = ((word >> bit_offset) & mask) as u32;
        bit_pos += bit_width_usize;
    }
}

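/// Inclusive prefix sum over a 128-delta block seeded with `first_val`,
/// dispatching to NEON on aarch64 and the scalar path elsewhere.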
#[allow(dead_code)]
pub fn prefix_sum_128(deltas: &mut [u32; VERTICAL_BP128_BLOCK_SIZE], first_val: u32) {
    #[cfg(target_arch = "aarch64")]
    {
        unsafe { neon::prefix_sum_block_neon(deltas, first_val) }
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        scalar::prefix_sum_block_scalar(deltas, first_val)
    }
}

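/// Decodes delta-1 encoded doc IDs: each stored delta is `gap - 1`, so runs
/// of consecutive doc IDs pack to zero bits. Writes the first `count`
/// absolute doc IDs into `output`, starting from `first_doc_id`.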
pub fn unpack_vertical_d1(
    input: &[u8],
    bit_width: u8,
    first_doc_id: u32,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
    count: usize,
) {
    if count == 0 {
        return;
    }

    // A zero bit width means every stored delta is 0, i.e. consecutive IDs.
    if bit_width == 0 {
        let mut current = first_doc_id;
        output[0] = current;
        for out_val in output.iter_mut().take(count).skip(1) {
            current = current.wrapping_add(1);
            *out_val = current;
        }
        return;
    }

    let mut deltas = [0u32; VERTICAL_BP128_BLOCK_SIZE];
    unpack_vertical(input, bit_width, &mut deltas);

    output[0] = first_doc_id;
    let mut current = first_doc_id;

    for i in 1..count {
        current = current.wrapping_add(deltas[i - 1]).wrapping_add(1);
        output[i] = current;
    }
}

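/// One compressed block of up to 128 postings: vertically packed delta-1 doc
/// ID gaps, packed `tf - 1` term frequencies, and metadata for block-level
/// skipping and score upper bounds.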
#[derive(Debug, Clone)]
pub struct VerticalBP128Block {
    /// Vertically packed doc ID gaps (`gap - 1`).
    pub doc_data: Vec<u8>,
    /// Bits per packed doc gap.
    pub doc_bit_width: u8,
    /// Vertically packed `tf - 1` values.
    pub tf_data: Vec<u8>,
    /// Bits per packed term frequency.
    pub tf_bit_width: u8,
    /// First doc ID in the block, stored uncompressed.
    pub first_doc_id: u32,
    /// Last doc ID in the block; lets readers skip blocks without decoding.
    pub last_doc_id: u32,
    /// Number of postings in the block (at most 128).
    pub num_docs: u16,
    /// Maximum term frequency in the block.
    pub max_tf: u32,
    /// Upper bound on any document's BM25 score within the block.
    pub max_block_score: f32,
}

impl VerticalBP128Block {
    /// Writes the block as a fixed 20-byte header followed by two
    /// length-prefixed (u16) payloads.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
        writer.write_u16::<LittleEndian>(self.num_docs)?;
        writer.write_u8(self.doc_bit_width)?;
        writer.write_u8(self.tf_bit_width)?;
        writer.write_u32::<LittleEndian>(self.max_tf)?;
        writer.write_f32::<LittleEndian>(self.max_block_score)?;

        writer.write_u16::<LittleEndian>(self.doc_data.len() as u16)?;
        writer.write_all(&self.doc_data)?;

        writer.write_u16::<LittleEndian>(self.tf_data.len() as u16)?;
        writer.write_all(&self.tf_data)?;

        Ok(())
    }

    /// Reads a block previously written by [`Self::serialize`].
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc_id = reader.read_u32::<LittleEndian>()?;
        let last_doc_id = reader.read_u32::<LittleEndian>()?;
        let num_docs = reader.read_u16::<LittleEndian>()?;
        let doc_bit_width = reader.read_u8()?;
        let tf_bit_width = reader.read_u8()?;
        let max_tf = reader.read_u32::<LittleEndian>()?;
        let max_block_score = reader.read_f32::<LittleEndian>()?;

        let doc_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut doc_data = vec![0u8; doc_len];
        reader.read_exact(&mut doc_data)?;

        let tf_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut tf_data = vec![0u8; tf_len];
        reader.read_exact(&mut tf_data)?;

        Ok(Self {
            doc_data,
            doc_bit_width,
            tf_data,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs,
            max_tf,
            max_block_score,
        })
    }

    /// Decodes all doc IDs into a freshly allocated `Vec`.
    pub fn decode_doc_ids(&self) -> Vec<u32> {
        let mut output = vec![0u32; self.num_docs as usize];
        self.decode_doc_ids_into(&mut output);
        output
    }

    /// Decodes doc IDs into `output`, returning the posting count. Full
    /// blocks decode straight into `output`; partial blocks go through a
    /// stack buffer because the unpackers always produce 128 values.
    #[inline]
    pub fn decode_doc_ids_into(&self, output: &mut [u32]) -> usize {
        let count = self.num_docs as usize;
        if count == 0 {
            return 0;
        }

        if count == VERTICAL_BP128_BLOCK_SIZE && output.len() >= VERTICAL_BP128_BLOCK_SIZE {
            let out_array: &mut [u32; VERTICAL_BP128_BLOCK_SIZE] = (&mut output
                [..VERTICAL_BP128_BLOCK_SIZE])
                .try_into()
                .unwrap();
            unpack_vertical_d1(
                &self.doc_data,
                self.doc_bit_width,
                self.first_doc_id,
                out_array,
                count,
            );
        } else {
            let mut temp = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            unpack_vertical_d1(
                &self.doc_data,
                self.doc_bit_width,
                self.first_doc_id,
                &mut temp,
                count,
            );
            output[..count].copy_from_slice(&temp[..count]);
        }

        count
    }

    /// Decodes all term frequencies into a freshly allocated `Vec`.
    pub fn decode_term_freqs(&self) -> Vec<u32> {
        let mut output = vec![0u32; self.num_docs as usize];
        self.decode_term_freqs_into(&mut output);
        output
    }

    /// Decodes term frequencies into `output`, returning the count. Stored
    /// values are `tf - 1`, so 1 is added back after unpacking.
    #[inline]
    pub fn decode_term_freqs_into(&self, output: &mut [u32]) -> usize {
        let count = self.num_docs as usize;
        if count == 0 {
            return 0;
        }

        if count == VERTICAL_BP128_BLOCK_SIZE && output.len() >= VERTICAL_BP128_BLOCK_SIZE {
            let out_array: &mut [u32; VERTICAL_BP128_BLOCK_SIZE] = (&mut output
                [..VERTICAL_BP128_BLOCK_SIZE])
                .try_into()
                .unwrap();
            unpack_vertical(&self.tf_data, self.tf_bit_width, out_array);
        } else {
            let mut temp = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            unpack_vertical(&self.tf_data, self.tf_bit_width, &mut temp);
            output[..count].copy_from_slice(&temp[..count]);
        }

        simd::add_one(output, count);

        count
    }
}

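/// A posting list split into 128-document blocks, with per-block metadata for
/// skipping and dynamic pruning.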
#[derive(Debug, Clone)]
pub struct VerticalBP128PostingList {
    /// Compressed blocks in doc ID order.
    pub blocks: Vec<VerticalBP128Block>,
    /// Total number of postings across all blocks.
    pub doc_count: u32,
    /// Maximum block score upper bound across the list.
    pub max_score: f32,
}

impl VerticalBP128PostingList {
    /// BM25 `k1` parameter.
    const K1: f32 = 1.2;
    /// BM25 `b` parameter.
    const B: f32 = 0.75;

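    /// Upper bound on BM25 for a block: combines the block's maximum term
    /// frequency with the most favorable length normalization (`1 - b`), so
    /// for any document in the block
    /// `score <= idf * tf * (k1 + 1) / (tf + k1 * (1 - b))`,
    /// which is monotone increasing in `tf`.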
    #[inline]
    pub fn compute_bm25_upper_bound(max_tf: u32, idf: f32) -> f32 {
        let tf = max_tf as f32;
        let min_length_norm = 1.0 - Self::B;
        let tf_norm = (tf * (Self::K1 + 1.0)) / (tf + Self::K1 * min_length_norm);
        idf * tf_norm
    }

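    /// Builds a posting list from parallel `doc_ids` / `term_freqs` slices
    /// (doc IDs strictly increasing, term frequencies >= 1), chunking the
    /// input into blocks of up to 128 postings.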
    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
        assert_eq!(doc_ids.len(), term_freqs.len());

        if doc_ids.is_empty() {
            return Self {
                blocks: Vec::new(),
                doc_count: 0,
                max_score: 0.0,
            };
        }

        let mut blocks = Vec::new();
        let mut max_score = 0.0f32;
        let mut i = 0;

        while i < doc_ids.len() {
            let block_end = (i + VERTICAL_BP128_BLOCK_SIZE).min(doc_ids.len());
            let block_docs = &doc_ids[i..block_end];
            let block_tfs = &term_freqs[i..block_end];

            let block = Self::create_block(block_docs, block_tfs, idf);
            max_score = max_score.max(block.max_block_score);
            blocks.push(block);

            i = block_end;
        }

        Self {
            blocks,
            doc_count: doc_ids.len() as u32,
            max_score,
        }
    }

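    /// Encodes one block: doc gaps are stored as `gap - 1` and term
    /// frequencies as `tf - 1`, which shrinks the required bit widths (dense
    /// runs of doc IDs and tf = 1 both pack to zero bits).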
    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> VerticalBP128Block {
        let num_docs = doc_ids.len();
        let first_doc_id = doc_ids[0];
        let last_doc_id = *doc_ids.last().unwrap();

        let mut deltas = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        let mut max_delta = 0u32;
        for j in 1..num_docs {
            let delta = doc_ids[j] - doc_ids[j - 1] - 1;
            deltas[j - 1] = delta;
            max_delta = max_delta.max(delta);
        }

        let mut tfs = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        let mut max_tf = 0u32;
        for (j, &tf) in term_freqs.iter().enumerate() {
            tfs[j] = tf.saturating_sub(1);
            max_tf = max_tf.max(tf);
        }

        let doc_bit_width = simd::bits_needed(max_delta);
        let tf_bit_width = simd::bits_needed(max_tf.saturating_sub(1));

        let mut doc_data = Vec::new();
        pack_vertical(&deltas, doc_bit_width, &mut doc_data);

        let mut tf_data = Vec::new();
        pack_vertical(&tfs, tf_bit_width, &mut tf_data);

        let max_block_score = Self::compute_bm25_upper_bound(max_tf, idf);

        VerticalBP128Block {
            doc_data,
            doc_bit_width,
            tf_data,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs: num_docs as u16,
            max_tf,
            max_block_score,
        }
    }

    /// Writes the list header (doc count, max score, block count) followed by
    /// each block.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_f32::<LittleEndian>(self.max_score)?;
        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;

        for block in &self.blocks {
            block.serialize(writer)?;
        }

        Ok(())
    }

    /// Reads a list previously written by [`Self::serialize`].
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let max_score = reader.read_f32::<LittleEndian>()?;
        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;

        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            blocks.push(VerticalBP128Block::deserialize(reader)?);
        }

        Ok(Self {
            blocks,
            doc_count,
            max_score,
        })
    }

    /// Returns a cursor positioned at the first posting.
    pub fn iterator(&self) -> VerticalBP128Iterator<'_> {
        VerticalBP128Iterator::new(self)
    }

    /// Approximate serialized size: a 12-byte list header plus, per block, a
    /// 24-byte fixed header (see [`VerticalBP128Block::serialize`]) and the
    /// two packed payloads.
    pub fn size_bytes(&self) -> usize {
        let mut size = 12;
        for block in &self.blocks {
            size += 24 + block.doc_data.len() + block.tf_data.len();
        }
        size
    }
}

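/// Forward-only cursor over a [`VerticalBP128PostingList`], buffering one
/// decoded block at a time.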
pub struct VerticalBP128Iterator<'a> {
    list: &'a VerticalBP128PostingList,
    current_block: usize,
    /// Number of postings decoded from the current block.
    current_block_len: usize,
    /// Decoded doc IDs of the current block.
    block_doc_ids: Vec<u32>,
    /// Decoded term frequencies of the current block.
    block_term_freqs: Vec<u32>,
    pos_in_block: usize,
    exhausted: bool,
}

impl<'a> VerticalBP128Iterator<'a> {
    pub fn new(list: &'a VerticalBP128PostingList) -> Self {
        let mut iter = Self {
            list,
            current_block: 0,
            current_block_len: 0,
            block_doc_ids: vec![0u32; VERTICAL_BP128_BLOCK_SIZE],
            block_term_freqs: vec![0u32; VERTICAL_BP128_BLOCK_SIZE],
            pos_in_block: 0,
            exhausted: list.blocks.is_empty(),
        };

        if !iter.exhausted {
            iter.decode_current_block();
        }

        iter
    }

    #[inline]
    fn decode_current_block(&mut self) {
        let block = &self.list.blocks[self.current_block];
        self.current_block_len = block.decode_doc_ids_into(&mut self.block_doc_ids);
        block.decode_term_freqs_into(&mut self.block_term_freqs);
        self.pos_in_block = 0;
    }

    /// Current doc ID, or `u32::MAX` once exhausted.
    #[inline]
    pub fn doc(&self) -> u32 {
        if self.exhausted {
            u32::MAX
        } else {
            self.block_doc_ids[self.pos_in_block]
        }
    }

    /// Term frequency of the current posting, or 0 once exhausted.
    #[inline]
    pub fn term_freq(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.block_term_freqs[self.pos_in_block]
        }
    }

    /// Advances to the next posting and returns its doc ID, or `u32::MAX`
    /// when the list is exhausted.
    #[inline]
    pub fn advance(&mut self) -> u32 {
        if self.exhausted {
            return u32::MAX;
        }

        self.pos_in_block += 1;

        if self.pos_in_block >= self.current_block_len {
            self.current_block += 1;
            if self.current_block >= self.list.blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }

        self.doc()
    }

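    /// Advances to the first posting with doc ID >= `target` and returns it,
    /// or `u32::MAX` if none exists. Binary-searches the remaining blocks by
    /// their `[first_doc_id, last_doc_id]` ranges, then binary-searches
    /// within the decoded block. Forward-only: `target` should not be behind
    /// the current position.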
    pub fn seek(&mut self, target: u32) -> u32 {
        if self.exhausted {
            return u32::MAX;
        }

        let block_idx = self.list.blocks[self.current_block..].binary_search_by(|block| {
            if block.last_doc_id < target {
                std::cmp::Ordering::Less
            } else if block.first_doc_id > target {
                std::cmp::Ordering::Greater
            } else {
                std::cmp::Ordering::Equal
            }
        });

        let target_block = match block_idx {
            Ok(idx) => self.current_block + idx,
            Err(idx) => {
                if self.current_block + idx >= self.list.blocks.len() {
                    self.exhausted = true;
                    return u32::MAX;
                }
                self.current_block + idx
            }
        };

        if target_block != self.current_block {
            self.current_block = target_block;
            self.decode_current_block();
        }

        let pos = self.block_doc_ids[self.pos_in_block..self.current_block_len]
            .binary_search(&target)
            .unwrap_or_else(|x| x);
        self.pos_in_block += pos;

        if self.pos_in_block >= self.current_block_len {
            self.current_block += 1;
            if self.current_block >= self.list.blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }

        self.doc()
    }

    /// Largest block score upper bound from the current block onward; useful
    /// for MaxScore/WAND-style pruning.
    pub fn max_remaining_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.list.blocks[self.current_block..]
            .iter()
            .map(|b| b.max_block_score)
            .fold(0.0f32, |a, b| a.max(b))
    }

    /// Score upper bound of the current block, or 0 once exhausted.
    pub fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            0.0
        } else {
            self.list.blocks[self.current_block].max_block_score
        }
    }

    /// Maximum term frequency of the current block, or 0 once exhausted.
    pub fn current_block_max_tf(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.list.blocks[self.current_block].max_tf
        }
    }

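    /// Skips whole blocks until one with `last_doc_id >= target` is found,
    /// decodes it, and returns its `(first_doc_id, max_block_score)`; returns
    /// `None` (and exhausts the cursor) if no such block exists.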
    pub fn skip_to_block_with_doc(&mut self, target: u32) -> Option<(u32, f32)> {
        while self.current_block < self.list.blocks.len() {
            let block = &self.list.blocks[self.current_block];
            if block.last_doc_id >= target {
                self.decode_current_block();
                return Some((block.first_doc_id, block.max_block_score));
            }
            self.current_block += 1;
        }
        self.exhausted = true;
        None
    }

    /// True once the cursor has moved past the last posting.
    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pack_unpack_vertical() {
        let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = (i * 3) as u32;
        }

        let max_val = values.iter().max().copied().unwrap();
        let bit_width = simd::bits_needed(max_val);

        let mut packed = Vec::new();
        pack_vertical(&values, bit_width, &mut packed);

        let mut unpacked = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        unpack_vertical(&packed, bit_width, &mut unpacked);

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_pack_unpack_vertical_various_widths() {
        for bit_width in 1..=20 {
            let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            let max_val = (1u32 << bit_width) - 1;
            for (i, v) in values.iter_mut().enumerate() {
                *v = (i as u32) % (max_val + 1);
            }

            let mut packed = Vec::new();
            pack_vertical(&values, bit_width, &mut packed);

            let mut unpacked = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            unpack_vertical(&packed, bit_width, &mut unpacked);

            assert_eq!(values, unpacked, "Failed for bit_width={}", bit_width);
        }
    }

    #[test]
    fn test_simd_bp128_posting_list() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 10) + 1).collect();

        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        assert_eq!(list.doc_count, 200);
        assert_eq!(list.blocks.len(), 2);

        let mut iter = list.iterator();
        for (i, &expected_doc) in doc_ids.iter().enumerate() {
            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
            assert_eq!(iter.term_freq(), term_freqs[i], "TF mismatch at {}", i);
            if i < doc_ids.len() - 1 {
                iter.advance();
            }
        }
    }

    #[test]
    fn test_simd_bp128_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = list.iterator();

        assert_eq!(iter.seek(25), 30);
        assert_eq!(iter.seek(100), 100);
        assert_eq!(iter.seek(500), 1000);
        assert_eq!(iter.seek(3000), u32::MAX);
    }

    #[test]
    fn test_simd_bp128_serialization() {
        let doc_ids: Vec<u32> = (0..300).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..300).map(|i| (i % 5) + 1).collect();

        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.5);

        let mut buffer = Vec::new();
        list.serialize(&mut buffer).unwrap();

        let restored = VerticalBP128PostingList::deserialize(&mut &buffer[..]).unwrap();

        assert_eq!(restored.doc_count, list.doc_count);
        assert_eq!(restored.blocks.len(), list.blocks.len());

        let mut iter1 = list.iterator();
        let mut iter2 = restored.iterator();

        while iter1.doc() != u32::MAX {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
    }

    #[test]
    fn test_vertical_layout_size() {
        let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = i as u32;
        }

        // Values 0..=127 need 7 bits each.
        let bit_width = simd::bits_needed(127);
        assert_eq!(bit_width, 7);

        let mut packed = Vec::new();
        pack_horizontal(&values, bit_width, &mut packed);

        // 128 values * 7 bits = 896 bits = 112 bytes.
        let expected_bytes = (VERTICAL_BP128_BLOCK_SIZE * bit_width as usize) / 8;
        assert_eq!(expected_bytes, 112);
        assert_eq!(packed.len(), expected_bytes);
    }

    #[test]
    fn test_simd_bp128_block_max() {
        let doc_ids: Vec<u32> = (0..500).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..500)
            .map(|i| {
                if i < 128 {
                    1
                } else if i < 256 {
                    5
                } else if i < 384 {
                    10
                } else {
                    3
                }
            })
            .collect();

        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 2.0);

        assert_eq!(list.blocks.len(), 4);
        assert_eq!(list.blocks[0].max_tf, 1);
        assert_eq!(list.blocks[1].max_tf, 5);
        assert_eq!(list.blocks[2].max_tf, 10);
        assert_eq!(list.blocks[3].max_tf, 3);

        assert!(list.blocks[2].max_block_score > list.blocks[0].max_block_score);
        assert!(list.blocks[2].max_block_score > list.blocks[1].max_block_score);
        assert!(list.blocks[2].max_block_score > list.blocks[3].max_block_score);

        assert_eq!(list.max_score, list.blocks[2].max_block_score);

        let mut iter = list.iterator();
        assert_eq!(iter.current_block_max_tf(), 1);
        iter.seek(256);
        assert_eq!(iter.current_block_max_tf(), 5);

        iter.seek(512);
        assert_eq!(iter.current_block_max_tf(), 10);

        let mut iter2 = list.iterator();
        let result = iter2.skip_to_block_with_doc(300);
        assert!(result.is_some());
        let (first_doc, score) = result.unwrap();
        assert!(first_doc <= 300);
        assert!(score > 0.0);
    }
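
    // Additional round-trip sanity checks for the delta-1 doc ID path, the
    // horizontal layout, and prefix_sum_128. These exercise only helpers
    // defined in this file; the expected values follow directly from the
    // encodings described above.

    #[test]
    fn test_unpack_vertical_d1_roundtrip() {
        // Strictly increasing doc IDs with irregular gaps.
        let mut doc_ids = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        let mut current = 5u32;
        for (i, d) in doc_ids.iter_mut().enumerate() {
            current += (i as u32 % 7) + 1;
            *d = current;
        }

        // Delta-1 encode exactly as `create_block` does.
        let mut deltas = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        let mut max_delta = 0u32;
        for j in 1..VERTICAL_BP128_BLOCK_SIZE {
            let delta = doc_ids[j] - doc_ids[j - 1] - 1;
            deltas[j - 1] = delta;
            max_delta = max_delta.max(delta);
        }

        let bit_width = simd::bits_needed(max_delta);
        let mut packed = Vec::new();
        pack_vertical(&deltas, bit_width, &mut packed);

        let mut decoded = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        unpack_vertical_d1(
            &packed,
            bit_width,
            doc_ids[0],
            &mut decoded,
            VERTICAL_BP128_BLOCK_SIZE,
        );
        assert_eq!(doc_ids, decoded);
    }

    #[test]
    fn test_horizontal_roundtrip_and_prefix_sum() {
        let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = (i as u32 * 37) % 1000;
        }

        let bit_width = simd::bits_needed(999);
        let mut packed = Vec::new();
        pack_horizontal(&values, bit_width, &mut packed);

        let mut unpacked = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        unpack_horizontal(&packed, bit_width, &mut unpacked);
        assert_eq!(values, unpacked);

        // prefix_sum_128 computes an inclusive scan seeded with `first_val`:
        // with all deltas 1 and seed 10, element i becomes 11 + i.
        let mut deltas = [1u32; VERTICAL_BP128_BLOCK_SIZE];
        prefix_sum_128(&mut deltas, 10);
        for (i, &v) in deltas.iter().enumerate() {
            assert_eq!(v, 11 + i as u32);
        }
    }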
}