//! Vertical BP128 compression: 128-value blocks of bit-packed doc-id deltas
//! and term frequencies in a SIMD-friendly bit-plane layout, with block-max
//! metadata for WAND-style skipping.

use super::simd;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Read, Write};

/// Number of values packed per block.
pub const VERTICAL_BP128_BLOCK_SIZE: usize = 128;

/// Lanes per SIMD group: one 128-bit register holds four u32 values.
#[allow(dead_code)]
const SIMD_LANES: usize = 4;

/// Number of 4-lane groups in one block.
#[allow(dead_code)]
const GROUPS_PER_BLOCK: usize = VERTICAL_BP128_BLOCK_SIZE / SIMD_LANES;

/// NEON-accelerated kernels for aarch64.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
mod neon {
    use super::*;
    use std::arch::aarch64::*;

    /// Maps each byte to its eight bits expanded into u32 lanes, LSB first.
    static BIT_EXPAND_LUT: [[u32; 8]; 256] = {
        let mut lut = [[0u32; 8]; 256];
        let mut byte = 0usize;
        while byte < 256 {
            let mut bit = 0;
            while bit < 8 {
                lut[byte][bit] = ((byte >> bit) & 1) as u32;
                bit += 1;
            }
            byte += 1;
        }
        lut
    };

    /// Unpacks four bit-packed values into `output` using NEON stores.
    #[inline]
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_4_neon(input: &[u8], bit_width: u8, output: &mut [u32; 4]) {
        if bit_width == 0 {
            *output = [0; 4];
            return;
        }

        let mask = (1u32 << bit_width) - 1;

        // Widen the packed bytes to a u128 so each value is one shift + mask.
        let mut packed_bytes = [0u8; 16];
        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        packed_bytes[..bytes_needed.min(16)].copy_from_slice(&input[..bytes_needed.min(16)]);
        let packed = u128::from_le_bytes(packed_bytes);

        let v0 = (packed & mask as u128) as u32;
        let v1 = ((packed >> bit_width) & mask as u128) as u32;
        let v2 = ((packed >> (bit_width * 2)) & mask as u128) as u32;
        let v3 = ((packed >> (bit_width * 3)) & mask as u128) as u32;

        unsafe {
            let result = vld1q_u32([v0, v1, v2, v3].as_ptr());
            vst1q_u32(output.as_mut_ptr(), result);
        }
    }

    /// In-register inclusive prefix sum over four u32 lanes.
    #[inline]
    #[target_feature(enable = "neon")]
    pub unsafe fn prefix_sum_4_neon(values: &mut [u32; 4]) {
        unsafe {
            let mut v = vld1q_u32(values.as_ptr());

            // Shift lanes up by one and add: [v0, v0+v1, v1+v2, v2+v3].
            let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3);
            v = vaddq_u32(v, shifted1);

            // Shift lanes up by two and add to complete the inclusive scan.
            let shifted2 = vextq_u32(vdupq_n_u32(0), v, 2);
            v = vaddq_u32(v, shifted2);

            vst1q_u32(values.as_mut_ptr(), v);
        }
    }

    /// Unpacks a full 128-value block from the vertical layout, expanding
    /// one bit plane at a time through the bit-expansion LUT.
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_block_neon(
        input: &[u8],
        bit_width: u8,
        output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
    ) {
        if bit_width == 0 {
            output.fill(0);
            return;
        }

        // Zero the output with vector stores.
        unsafe {
            let zero = vdupq_n_u32(0);
            for i in (0..VERTICAL_BP128_BLOCK_SIZE).step_by(4) {
                vst1q_u32(output[i..].as_mut_ptr(), zero);
            }
        }

        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;
            let bit_mask = 1u32 << bit_pos;

            // Prefetch the next bit plane while this one is being expanded.
            if bit_pos + 1 < bit_width as usize {
                let next_offset = (bit_pos + 1) * 16;
                unsafe {
                    std::arch::asm!(
                        "prfm pldl1keep, [{0}]",
                        in(reg) input.as_ptr().add(next_offset),
                        options(nostack, preserves_flags)
                    );
                }
            }

            for chunk in 0..4 {
                let chunk_offset = byte_offset + chunk * 4;

                // Each byte of the plane holds this bit for 8 consecutive
                // outputs, so four bytes cover 32 outputs per chunk.
                let b0 = input[chunk_offset] as usize;
                let b1 = input[chunk_offset + 1] as usize;
                let b2 = input[chunk_offset + 2] as usize;
                let b3 = input[chunk_offset + 3] as usize;

                let base_int = chunk * 32;

                unsafe {
                    let mask_vec = vdupq_n_u32(bit_mask);

                    // Outputs 0..8 of the chunk come from b0.
                    let lut0 = &BIT_EXPAND_LUT[b0];
                    let bits_0_3 = vld1q_u32(lut0.as_ptr());
                    let bits_4_7 = vld1q_u32(lut0[4..].as_ptr());

                    let shifted_0_3 = vmulq_u32(bits_0_3, mask_vec);
                    let shifted_4_7 = vmulq_u32(bits_4_7, mask_vec);

                    let cur_0_3 = vld1q_u32(output[base_int..].as_ptr());
                    let cur_4_7 = vld1q_u32(output[base_int + 4..].as_ptr());

                    vst1q_u32(
                        output[base_int..].as_mut_ptr(),
                        vorrq_u32(cur_0_3, shifted_0_3),
                    );
                    vst1q_u32(
                        output[base_int + 4..].as_mut_ptr(),
                        vorrq_u32(cur_4_7, shifted_4_7),
                    );

                    // Outputs 8..16 come from b1.
                    let lut1 = &BIT_EXPAND_LUT[b1];
                    let bits_8_11 = vld1q_u32(lut1.as_ptr());
                    let bits_12_15 = vld1q_u32(lut1[4..].as_ptr());

                    let shifted_8_11 = vmulq_u32(bits_8_11, mask_vec);
                    let shifted_12_15 = vmulq_u32(bits_12_15, mask_vec);

                    let cur_8_11 = vld1q_u32(output[base_int + 8..].as_ptr());
                    let cur_12_15 = vld1q_u32(output[base_int + 12..].as_ptr());

                    vst1q_u32(
                        output[base_int + 8..].as_mut_ptr(),
                        vorrq_u32(cur_8_11, shifted_8_11),
                    );
                    vst1q_u32(
                        output[base_int + 12..].as_mut_ptr(),
                        vorrq_u32(cur_12_15, shifted_12_15),
                    );

                    // Outputs 16..24 come from b2.
                    let lut2 = &BIT_EXPAND_LUT[b2];
                    let bits_16_19 = vld1q_u32(lut2.as_ptr());
                    let bits_20_23 = vld1q_u32(lut2[4..].as_ptr());

                    let shifted_16_19 = vmulq_u32(bits_16_19, mask_vec);
                    let shifted_20_23 = vmulq_u32(bits_20_23, mask_vec);

                    let cur_16_19 = vld1q_u32(output[base_int + 16..].as_ptr());
                    let cur_20_23 = vld1q_u32(output[base_int + 20..].as_ptr());

                    vst1q_u32(
                        output[base_int + 16..].as_mut_ptr(),
                        vorrq_u32(cur_16_19, shifted_16_19),
                    );
                    vst1q_u32(
                        output[base_int + 20..].as_mut_ptr(),
                        vorrq_u32(cur_20_23, shifted_20_23),
                    );

                    // Outputs 24..32 come from b3.
                    let lut3 = &BIT_EXPAND_LUT[b3];
                    let bits_24_27 = vld1q_u32(lut3.as_ptr());
                    let bits_28_31 = vld1q_u32(lut3[4..].as_ptr());

                    let shifted_24_27 = vmulq_u32(bits_24_27, mask_vec);
                    let shifted_28_31 = vmulq_u32(bits_28_31, mask_vec);

                    let cur_24_27 = vld1q_u32(output[base_int + 24..].as_ptr());
                    let cur_28_31 = vld1q_u32(output[base_int + 28..].as_ptr());

                    vst1q_u32(
                        output[base_int + 24..].as_mut_ptr(),
                        vorrq_u32(cur_24_27, shifted_24_27),
                    );
                    vst1q_u32(
                        output[base_int + 28..].as_mut_ptr(),
                        vorrq_u32(cur_28_31, shifted_28_31),
                    );
                }
            }
        }
    }

    /// Prefix-sums a full block of deltas, seeded with `first_val`.
    #[target_feature(enable = "neon")]
    pub unsafe fn prefix_sum_block_neon(
        deltas: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
        first_val: u32,
    ) {
        let mut carry = first_val;

        for group in 0..GROUPS_PER_BLOCK {
            let start = group * SIMD_LANES;
            let mut group_vals = [
                deltas[start],
                deltas[start + 1],
                deltas[start + 2],
                deltas[start + 3],
            ];

            // Fold the running carry into the first lane of the group.
            group_vals[0] = group_vals[0].wrapping_add(carry);

            unsafe { prefix_sum_4_neon(&mut group_vals) };

            deltas[start..start + 4].copy_from_slice(&group_vals);

            // The last lane carries into the next group.
            carry = group_vals[3];
        }
    }
}
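
// A cross-check sketch (aarch64 only): the NEON block unpack above should
// agree with the scalar reference below for every bit width. The value
// pattern is an arbitrary illustrative choice, not part of the format.
#[cfg(all(test, target_arch = "aarch64"))]
mod neon_vs_scalar_examples {
    use super::*;

    #[test]
    fn unpack_block_matches_scalar() {
        for bit_width in 1u8..=20 {
            let max_val = (1u32 << bit_width) - 1;
            let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            for (i, v) in values.iter_mut().enumerate() {
                *v = (i as u32 * 7) % (max_val + 1);
            }

            let mut packed = Vec::new();
            pack_vertical(&values, bit_width, &mut packed);

            let mut out_neon = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            // Safe to call on aarch64, where NEON is always available.
            unsafe { neon::unpack_block_neon(&packed, bit_width, &mut out_neon) };

            let mut out_scalar = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            scalar::unpack_block_scalar(&packed, bit_width, &mut out_scalar);

            assert_eq!(out_neon, out_scalar, "bit_width={}", bit_width);
        }
    }
}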

/// Portable scalar fallbacks that mirror the NEON routines and serve as the
/// reference implementation.
#[allow(dead_code)]
mod scalar {
    use super::*;

    /// Packs four values at `bit_width` bits each, little-endian.
    #[inline]
    pub fn pack_4_scalar(values: &[u32; 4], bit_width: u8, output: &mut [u8]) {
        if bit_width == 0 {
            return;
        }

        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        let mut packed = 0u128;
        for (i, &val) in values.iter().enumerate() {
            packed |= (val as u128) << (i * bit_width as usize);
        }

        let packed_bytes = packed.to_le_bytes();
        output[..bytes_needed].copy_from_slice(&packed_bytes[..bytes_needed]);
    }

    /// Scalar counterpart of `unpack_4_neon`.
    #[inline]
    pub fn unpack_4_scalar(input: &[u8], bit_width: u8, output: &mut [u32; 4]) {
        if bit_width == 0 {
            *output = [0; 4];
            return;
        }

        let mask = (1u32 << bit_width) - 1;
        let mut packed_bytes = [0u8; 16];
        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        packed_bytes[..bytes_needed.min(16)].copy_from_slice(&input[..bytes_needed.min(16)]);
        let packed = u128::from_le_bytes(packed_bytes);

        output[0] = (packed & mask as u128) as u32;
        output[1] = ((packed >> bit_width) & mask as u128) as u32;
        output[2] = ((packed >> (bit_width * 2)) & mask as u128) as u32;
        output[3] = ((packed >> (bit_width * 3)) & mask as u128) as u32;
    }

    /// In-place inclusive prefix sum over four values.
    #[inline]
    pub fn prefix_sum_4_scalar(vals: &mut [u32; 4]) {
        vals[1] = vals[1].wrapping_add(vals[0]);
        vals[2] = vals[2].wrapping_add(vals[1]);
        vals[3] = vals[3].wrapping_add(vals[2]);
    }

    /// Scalar reference implementation of the vertical block unpack.
    pub fn unpack_block_scalar(
        input: &[u8],
        bit_width: u8,
        output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
    ) {
        if bit_width == 0 {
            output.fill(0);
            return;
        }

        output.fill(0);

        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;

            for byte_idx in 0..16 {
                let byte_val = input[byte_offset + byte_idx];
                let base_int = byte_idx * 8;

                output[base_int] |= (byte_val & 1) as u32 * (1 << bit_pos);
                output[base_int + 1] |= ((byte_val >> 1) & 1) as u32 * (1 << bit_pos);
                output[base_int + 2] |= ((byte_val >> 2) & 1) as u32 * (1 << bit_pos);
                output[base_int + 3] |= ((byte_val >> 3) & 1) as u32 * (1 << bit_pos);
                output[base_int + 4] |= ((byte_val >> 4) & 1) as u32 * (1 << bit_pos);
                output[base_int + 5] |= ((byte_val >> 5) & 1) as u32 * (1 << bit_pos);
                output[base_int + 6] |= ((byte_val >> 6) & 1) as u32 * (1 << bit_pos);
                output[base_int + 7] |= ((byte_val >> 7) & 1) as u32 * (1 << bit_pos);
            }
        }
    }

    /// Scalar counterpart of `prefix_sum_block_neon`.
    pub fn prefix_sum_block_scalar(deltas: &mut [u32; VERTICAL_BP128_BLOCK_SIZE], first_val: u32) {
        let mut carry = first_val;

        for group in 0..GROUPS_PER_BLOCK {
            let start = group * SIMD_LANES;
            let mut group_vals = [
                deltas[start],
                deltas[start + 1],
                deltas[start + 2],
                deltas[start + 3],
            ];

            group_vals[0] = group_vals[0].wrapping_add(carry);
            prefix_sum_4_scalar(&mut group_vals);
            deltas[start..start + 4].copy_from_slice(&group_vals);
            carry = group_vals[3];
        }
    }
}
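
// A minimal round-trip sketch for the scalar group helpers above. The
// concrete deltas and bit width are illustrative assumptions; they just
// exercise pack_4_scalar, unpack_4_scalar, and prefix_sum_4_scalar together.
#[cfg(test)]
mod scalar_group_examples {
    use super::scalar::*;

    #[test]
    fn pack_unpack_prefix_sum_4() {
        // Four deltas that all fit in 3 bits (max value 7).
        let deltas = [3u32, 7, 1, 5];
        let bit_width = 3u8;

        let mut buf = [0u8; 16];
        pack_4_scalar(&deltas, bit_width, &mut buf);

        let mut out = [0u32; 4];
        unpack_4_scalar(&buf, bit_width, &mut out);
        assert_eq!(out, deltas);

        // The inclusive prefix sum turns deltas into running totals.
        prefix_sum_4_scalar(&mut out);
        assert_eq!(out, [3, 10, 11, 16]);
    }
}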

/// Packs 128 values into the vertical (bit-plane) layout: bit `p` of all 128
/// values is stored contiguously as 16 bytes at offset `p * 16`, with value
/// `i` mapped to bit `i % 8` of byte `i / 8` within the plane. Output size
/// is exactly `16 * bit_width` bytes.
pub fn pack_vertical(
    values: &[u32; VERTICAL_BP128_BLOCK_SIZE],
    bit_width: u8,
    output: &mut Vec<u8>,
) {
    if bit_width == 0 {
        return;
    }

    // 128 bits (16 bytes) per bit plane, one plane per bit of width.
    let total_bytes = 16 * bit_width as usize;
    let start = output.len();
    output.resize(start + total_bytes, 0);

    for bit_pos in 0..bit_width as usize {
        let byte_offset = start + bit_pos * 16;

        for byte_idx in 0..16 {
            let base_int = byte_idx * 8;
            let mut byte_val = 0u8;

            byte_val |= ((values[base_int] >> bit_pos) & 1) as u8;
            byte_val |= (((values[base_int + 1] >> bit_pos) & 1) as u8) << 1;
            byte_val |= (((values[base_int + 2] >> bit_pos) & 1) as u8) << 2;
            byte_val |= (((values[base_int + 3] >> bit_pos) & 1) as u8) << 3;
            byte_val |= (((values[base_int + 4] >> bit_pos) & 1) as u8) << 4;
            byte_val |= (((values[base_int + 5] >> bit_pos) & 1) as u8) << 5;
            byte_val |= (((values[base_int + 6] >> bit_pos) & 1) as u8) << 6;
            byte_val |= (((values[base_int + 7] >> bit_pos) & 1) as u8) << 7;

            output[byte_offset + byte_idx] = byte_val;
        }
    }
}
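
// A worked example of the vertical layout, written as a test: with
// bit_width = 1 the output is a single 16-byte bit plane, and byte 0 packs
// bit 0 of values[0..8] with values[0] in the least significant bit. The
// chosen values are illustrative.
#[cfg(test)]
mod vertical_layout_examples {
    use super::*;

    #[test]
    fn bit_plane_byte_layout() {
        let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        values[0] = 1; // bit 0 of output byte 0
        values[3] = 1; // bit 3 of output byte 0
        values[8] = 1; // bit 0 of output byte 1

        let mut packed = Vec::new();
        pack_vertical(&values, 1, &mut packed);

        assert_eq!(packed.len(), 16); // one 16-byte plane per bit of width
        assert_eq!(packed[0], 0b0000_1001);
        assert_eq!(packed[1], 0b0000_0001);
    }
}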

/// Unpacks a vertical-layout block, dispatching to the best implementation
/// available on the target architecture.
pub fn unpack_vertical(input: &[u8], bit_width: u8, output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE]) {
    if bit_width == 0 {
        output.fill(0);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        unsafe { unpack_vertical_neon(input, bit_width, output) }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("sse2") {
            unsafe { unpack_vertical_sse(input, bit_width, output) }
        } else {
            unpack_vertical_scalar(input, bit_width, output)
        }
    }

    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        unpack_vertical_scalar(input, bit_width, output)
    }
}

#[inline]
fn unpack_vertical_scalar(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    output.fill(0);

    for bit_pos in 0..bit_width as usize {
        let byte_offset = bit_pos * 16;
        let bit_mask = 1u32 << bit_pos;

        for byte_idx in 0..16 {
            let byte_val = input[byte_offset + byte_idx];
            let base_int = byte_idx * 8;

            if byte_val & 0x01 != 0 {
                output[base_int] |= bit_mask;
            }
            if byte_val & 0x02 != 0 {
                output[base_int + 1] |= bit_mask;
            }
            if byte_val & 0x04 != 0 {
                output[base_int + 2] |= bit_mask;
            }
            if byte_val & 0x08 != 0 {
                output[base_int + 3] |= bit_mask;
            }
            if byte_val & 0x10 != 0 {
                output[base_int + 4] |= bit_mask;
            }
            if byte_val & 0x20 != 0 {
                output[base_int + 5] |= bit_mask;
            }
            if byte_val & 0x40 != 0 {
                output[base_int + 6] |= bit_mask;
            }
            if byte_val & 0x80 != 0 {
                output[base_int + 7] |= bit_mask;
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn unpack_vertical_neon(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    use std::arch::aarch64::*;

    unsafe {
        let zero = vdupq_n_u32(0);
        for i in (0..VERTICAL_BP128_BLOCK_SIZE).step_by(4) {
            vst1q_u32(output[i..].as_mut_ptr(), zero);
        }

        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;
            let bit_mask = 1u32 << bit_pos;

            let bytes = vld1q_u8(input.as_ptr().add(byte_offset));

            let mut byte_array = [0u8; 16];
            vst1q_u8(byte_array.as_mut_ptr(), bytes);

            for (byte_idx, &byte_val) in byte_array.iter().enumerate() {
                let base_int = byte_idx * 8;

                output[base_int] |= ((byte_val & 0x01) as u32) * bit_mask;
                output[base_int + 1] |= (((byte_val >> 1) & 0x01) as u32) * bit_mask;
                output[base_int + 2] |= (((byte_val >> 2) & 0x01) as u32) * bit_mask;
                output[base_int + 3] |= (((byte_val >> 3) & 0x01) as u32) * bit_mask;
                output[base_int + 4] |= (((byte_val >> 4) & 0x01) as u32) * bit_mask;
                output[base_int + 5] |= (((byte_val >> 5) & 0x01) as u32) * bit_mask;
                output[base_int + 6] |= (((byte_val >> 6) & 0x01) as u32) * bit_mask;
                output[base_int + 7] |= (((byte_val >> 7) & 0x01) as u32) * bit_mask;
            }
        }
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn unpack_vertical_sse(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    use std::arch::x86_64::*;

    unsafe {
        let zero = _mm_setzero_si128();
        for i in (0..VERTICAL_BP128_BLOCK_SIZE).step_by(4) {
            _mm_storeu_si128(output[i..].as_mut_ptr() as *mut __m128i, zero);
        }

        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * 16;

            let bytes = _mm_loadu_si128(input.as_ptr().add(byte_offset) as *const __m128i);

            let mut byte_array = [0u8; 16];
            _mm_storeu_si128(byte_array.as_mut_ptr() as *mut __m128i, bytes);

            for byte_idx in 0..16 {
                let byte_val = byte_array[byte_idx];
                let base_int = byte_idx * 8;

                if byte_val & 0x01 != 0 {
                    output[base_int] |= 1u32 << bit_pos;
                }
                if byte_val & 0x02 != 0 {
                    output[base_int + 1] |= 1u32 << bit_pos;
                }
                if byte_val & 0x04 != 0 {
                    output[base_int + 2] |= 1u32 << bit_pos;
                }
                if byte_val & 0x08 != 0 {
                    output[base_int + 3] |= 1u32 << bit_pos;
                }
                if byte_val & 0x10 != 0 {
                    output[base_int + 4] |= 1u32 << bit_pos;
                }
                if byte_val & 0x20 != 0 {
                    output[base_int + 5] |= 1u32 << bit_pos;
                }
                if byte_val & 0x40 != 0 {
                    output[base_int + 6] |= 1u32 << bit_pos;
                }
                if byte_val & 0x80 != 0 {
                    output[base_int + 7] |= 1u32 << bit_pos;
                }
            }
        }
    }
}

/// Packs values in the conventional horizontal (value-after-value) layout,
/// kept for comparison with the vertical layout.
#[allow(dead_code)]
pub fn pack_horizontal(
    values: &[u32; VERTICAL_BP128_BLOCK_SIZE],
    bit_width: u8,
    output: &mut Vec<u8>,
) {
    if bit_width == 0 {
        return;
    }

    let bytes_needed = (VERTICAL_BP128_BLOCK_SIZE * bit_width as usize).div_ceil(8);
    let start = output.len();
    output.resize(start + bytes_needed, 0);

    let mut bit_pos = 0usize;
    for &value in values {
        let byte_idx = start + bit_pos / 8;
        let bit_offset = bit_pos % 8;

        let mut remaining_bits = bit_width as usize;
        let mut val = value;
        let mut current_byte_idx = byte_idx;
        let mut current_bit_offset = bit_offset;

        // Spill the value across byte boundaries, low bits first.
        while remaining_bits > 0 {
            let bits_in_byte = (8 - current_bit_offset).min(remaining_bits);
            let mask = ((1u32 << bits_in_byte) - 1) as u8;
            output[current_byte_idx] |= ((val as u8) & mask) << current_bit_offset;
            val >>= bits_in_byte;
            remaining_bits -= bits_in_byte;
            current_byte_idx += 1;
            current_bit_offset = 0;
        }

        bit_pos += bit_width as usize;
    }
}

/// Unpacks a horizontally packed block. The unaligned u64 read below can
/// touch up to 7 bytes past the packed payload, so the input slice must
/// carry enough slack beyond the packed data.
#[allow(dead_code)]
pub fn unpack_horizontal(
    input: &[u8],
    bit_width: u8,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
) {
    if bit_width == 0 {
        output.fill(0);
        return;
    }

    let mask = (1u64 << bit_width) - 1;
    let bit_width_usize = bit_width as usize;
    let mut bit_pos = 0usize;
    let input_ptr = input.as_ptr();

    for out in output.iter_mut() {
        let byte_idx = bit_pos >> 3;
        let bit_offset = bit_pos & 7;

        // One unaligned 64-bit read covers a value of up to 32 bits plus the
        // at-most-7-bit offset within the first byte.
        let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };

        *out = ((word >> bit_offset) & mask) as u32;
        bit_pos += bit_width_usize;
    }
}
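
// Horizontal round-trip sketch. Because unpack_horizontal reads unaligned
// u64 words, this example pads the buffer with 8 slack bytes before
// decoding; the padding requirement is an assumption about intended usage
// rather than a documented contract.
#[cfg(test)]
mod horizontal_examples {
    use super::*;

    #[test]
    fn pack_unpack_horizontal_roundtrip() {
        let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = (i as u32) & 0x1F; // 5-bit values
        }

        let mut packed = Vec::new();
        pack_horizontal(&values, 5, &mut packed);
        packed.extend_from_slice(&[0u8; 8]); // slack for the u64 reads

        let mut unpacked = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        unpack_horizontal(&packed, 5, &mut unpacked);
        assert_eq!(values, unpacked);
    }
}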

/// Inclusive prefix sum over a full block, seeded with `first_val`.
#[allow(dead_code)]
pub fn prefix_sum_128(deltas: &mut [u32; VERTICAL_BP128_BLOCK_SIZE], first_val: u32) {
    #[cfg(target_arch = "aarch64")]
    {
        unsafe { neon::prefix_sum_block_neon(deltas, first_val) }
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        scalar::prefix_sum_block_scalar(deltas, first_val)
    }
}
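
// A quick sketch of prefix_sum_128 semantics: deltas are replaced by their
// inclusive running sums, with first_val folded into lane 0. The constants
// here are illustrative.
#[cfg(test)]
mod prefix_sum_examples {
    use super::*;

    #[test]
    fn prefix_sum_seeds_first_value() {
        let mut deltas = [1u32; VERTICAL_BP128_BLOCK_SIZE];
        prefix_sum_128(&mut deltas, 100);

        // Lane 0 becomes 100 + 1; each later lane adds one more.
        assert_eq!(deltas[0], 101);
        assert_eq!(deltas[127], 228);
    }
}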

/// Decodes doc ids from delta-minus-one encoding: unpacks the deltas, then
/// reconstructs absolute ids as `id[i] = id[i-1] + delta[i-1] + 1`, starting
/// from `first_doc_id`. With `bit_width == 0` the ids are consecutive.
pub fn unpack_vertical_d1(
    input: &[u8],
    bit_width: u8,
    first_doc_id: u32,
    output: &mut [u32; VERTICAL_BP128_BLOCK_SIZE],
    count: usize,
) {
    if count == 0 {
        return;
    }

    if bit_width == 0 {
        // All gaps are zero, so ids are consecutive starting at first_doc_id.
        let mut current = first_doc_id;
        output[0] = current;
        for out_val in output.iter_mut().take(count).skip(1) {
            current = current.wrapping_add(1);
            *out_val = current;
        }
        return;
    }

    let mut deltas = [0u32; VERTICAL_BP128_BLOCK_SIZE];
    unpack_vertical(input, bit_width, &mut deltas);

    output[0] = first_doc_id;
    let mut current = first_doc_id;

    for i in 1..count {
        current = current.wrapping_add(deltas[i - 1]).wrapping_add(1);
        output[i] = current;
    }
}
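
// A sketch of the delta-minus-one convention used by unpack_vertical_d1:
// strictly increasing doc ids are stored as (gap - 1), so consecutive ids
// encode as zeros and compress down to bit_width 0. The doc ids here are
// illustrative.
#[cfg(test)]
mod d1_examples {
    use super::*;

    #[test]
    fn d1_roundtrip() {
        let doc_ids = [10u32, 11, 15, 40];

        // Gaps minus one: 11-10-1 = 0, 15-11-1 = 3, 40-15-1 = 24.
        let mut deltas = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        for i in 1..doc_ids.len() {
            deltas[i - 1] = doc_ids[i] - doc_ids[i - 1] - 1;
        }

        let bit_width = simd::bits_needed(24);
        let mut packed = Vec::new();
        pack_vertical(&deltas, bit_width, &mut packed);

        let mut out = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        unpack_vertical_d1(&packed, bit_width, 10, &mut out, doc_ids.len());
        assert_eq!(&out[..doc_ids.len()], &doc_ids[..]);
    }
}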

/// A single compressed block of up to 128 postings.
#[derive(Debug, Clone)]
pub struct VerticalBP128Block {
    /// Vertically packed doc-id deltas (delta-minus-one encoding).
    pub doc_data: Vec<u8>,
    /// Bits per doc-id delta.
    pub doc_bit_width: u8,
    /// Vertically packed term frequencies, stored as `tf - 1`.
    pub tf_data: Vec<u8>,
    /// Bits per term frequency.
    pub tf_bit_width: u8,
    /// First doc id in the block, stored uncompressed.
    pub first_doc_id: u32,
    /// Last doc id in the block, used for skipping.
    pub last_doc_id: u32,
    /// Number of postings in the block (at most 128).
    pub num_docs: u16,
    /// Maximum term frequency in the block.
    pub max_tf: u32,
    /// Block-max BM25 upper bound for this block.
    pub max_block_score: f32,
}

impl VerticalBP128Block {
    /// Writes the block header and packed payloads in on-disk order.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
        writer.write_u16::<LittleEndian>(self.num_docs)?;
        writer.write_u8(self.doc_bit_width)?;
        writer.write_u8(self.tf_bit_width)?;
        writer.write_u32::<LittleEndian>(self.max_tf)?;
        writer.write_f32::<LittleEndian>(self.max_block_score)?;

        writer.write_u16::<LittleEndian>(self.doc_data.len() as u16)?;
        writer.write_all(&self.doc_data)?;

        writer.write_u16::<LittleEndian>(self.tf_data.len() as u16)?;
        writer.write_all(&self.tf_data)?;

        Ok(())
    }

    /// Reads a block previously written by `serialize`.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc_id = reader.read_u32::<LittleEndian>()?;
        let last_doc_id = reader.read_u32::<LittleEndian>()?;
        let num_docs = reader.read_u16::<LittleEndian>()?;
        let doc_bit_width = reader.read_u8()?;
        let tf_bit_width = reader.read_u8()?;
        let max_tf = reader.read_u32::<LittleEndian>()?;
        let max_block_score = reader.read_f32::<LittleEndian>()?;

        let doc_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut doc_data = vec![0u8; doc_len];
        reader.read_exact(&mut doc_data)?;

        let tf_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut tf_data = vec![0u8; tf_len];
        reader.read_exact(&mut tf_data)?;

        Ok(Self {
            doc_data,
            doc_bit_width,
            tf_data,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs,
            max_tf,
            max_block_score,
        })
    }

    /// Decodes the absolute doc ids of the block.
    pub fn decode_doc_ids(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let mut output = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        unpack_vertical_d1(
            &self.doc_data,
            self.doc_bit_width,
            self.first_doc_id,
            &mut output,
            self.num_docs as usize,
        );

        output[..self.num_docs as usize].to_vec()
    }

    /// Decodes the term frequencies of the block.
    pub fn decode_term_freqs(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let mut output = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        unpack_vertical(&self.tf_data, self.tf_bit_width, &mut output);

        // Frequencies are stored as tf - 1, so add the 1 back here.
        output[..self.num_docs as usize]
            .iter()
            .map(|&tf| tf + 1)
            .collect()
    }
}

/// A posting list compressed as a sequence of VerticalBP128 blocks, with
/// block-max metadata for WAND-style skipping.
#[derive(Debug, Clone)]
pub struct VerticalBP128PostingList {
    /// Compressed blocks in doc-id order.
    pub blocks: Vec<VerticalBP128Block>,
    /// Total number of postings across all blocks.
    pub doc_count: u32,
    /// Maximum block-max score across the list.
    pub max_score: f32,
}

impl VerticalBP128PostingList {
    const K1: f32 = 1.2;
    const B: f32 = 0.75;

    /// BM25 upper bound for a block: assumes the shortest possible document
    /// (length norm of 1 - B), which maximizes the term's contribution.
    #[inline]
    pub fn compute_bm25_upper_bound(max_tf: u32, idf: f32) -> f32 {
        let tf = max_tf as f32;
        let min_length_norm = 1.0 - Self::B;
        let tf_norm = (tf * (Self::K1 + 1.0)) / (tf + Self::K1 * min_length_norm);
        idf * tf_norm
    }

    /// Builds a compressed posting list from parallel doc-id and term
    /// frequency slices. Doc ids must be strictly increasing.
    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
        assert_eq!(doc_ids.len(), term_freqs.len());

        if doc_ids.is_empty() {
            return Self {
                blocks: Vec::new(),
                doc_count: 0,
                max_score: 0.0,
            };
        }

        let mut blocks = Vec::new();
        let mut max_score = 0.0f32;
        let mut i = 0;

        while i < doc_ids.len() {
            let block_end = (i + VERTICAL_BP128_BLOCK_SIZE).min(doc_ids.len());
            let block_docs = &doc_ids[i..block_end];
            let block_tfs = &term_freqs[i..block_end];

            let block = Self::create_block(block_docs, block_tfs, idf);
            max_score = max_score.max(block.max_block_score);
            blocks.push(block);

            i = block_end;
        }

        Self {
            blocks,
            doc_count: doc_ids.len() as u32,
            max_score,
        }
    }

    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> VerticalBP128Block {
        let num_docs = doc_ids.len();
        let first_doc_id = doc_ids[0];
        let last_doc_id = *doc_ids.last().unwrap();

        // Delta-minus-one encode the doc-id gaps.
        let mut deltas = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        let mut max_delta = 0u32;
        for j in 1..num_docs {
            let delta = doc_ids[j] - doc_ids[j - 1] - 1;
            deltas[j - 1] = delta;
            max_delta = max_delta.max(delta);
        }

        // Term frequencies are stored as tf - 1 (a tf of 0 never occurs).
        let mut tfs = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        let mut max_tf = 0u32;
        for (j, &tf) in term_freqs.iter().enumerate() {
            tfs[j] = tf.saturating_sub(1);
            max_tf = max_tf.max(tf);
        }

        let doc_bit_width = simd::bits_needed(max_delta);
        let tf_bit_width = simd::bits_needed(max_tf.saturating_sub(1));

        let mut doc_data = Vec::new();
        pack_vertical(&deltas, doc_bit_width, &mut doc_data);

        let mut tf_data = Vec::new();
        pack_vertical(&tfs, tf_bit_width, &mut tf_data);

        let max_block_score = Self::compute_bm25_upper_bound(max_tf, idf);

        VerticalBP128Block {
            doc_data,
            doc_bit_width,
            tf_data,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs: num_docs as u16,
            max_tf,
            max_block_score,
        }
    }

    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_f32::<LittleEndian>(self.max_score)?;
        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;

        for block in &self.blocks {
            block.serialize(writer)?;
        }

        Ok(())
    }

    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let max_score = reader.read_f32::<LittleEndian>()?;
        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;

        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            blocks.push(VerticalBP128Block::deserialize(reader)?);
        }

        Ok(Self {
            blocks,
            doc_count,
            max_score,
        })
    }

    pub fn iterator(&self) -> VerticalBP128Iterator<'_> {
        VerticalBP128Iterator::new(self)
    }

    /// Serialized size in bytes: a 12-byte list header plus, per block, a
    /// 24-byte fixed header (4 + 4 + 2 + 1 + 1 + 4 + 4 field bytes and two
    /// u16 length prefixes) and the packed payloads.
    pub fn size_bytes(&self) -> usize {
        let mut size = 12;
        for block in &self.blocks {
            size += 24 + block.doc_data.len() + block.tf_data.len();
        }
        size
    }
}
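
// Consistency sketch: size_bytes should equal the exact length produced by
// serialize, assuming the 24-byte fixed block header counted above mirrors
// the fields written in VerticalBP128Block::serialize.
#[cfg(test)]
mod size_accounting_examples {
    use super::*;

    #[test]
    fn size_bytes_matches_serialized_len() {
        let doc_ids: Vec<u32> = (0..300).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..300).map(|i| (i % 5) + 1).collect();
        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let mut buffer = Vec::new();
        list.serialize(&mut buffer).unwrap();
        assert_eq!(list.size_bytes(), buffer.len());
    }
}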

pub struct VerticalBP128Iterator<'a> {
    list: &'a VerticalBP128PostingList,
    current_block: usize,
    block_doc_ids: Vec<u32>,
    block_term_freqs: Vec<u32>,
    pos_in_block: usize,
    exhausted: bool,
}

impl<'a> VerticalBP128Iterator<'a> {
    pub fn new(list: &'a VerticalBP128PostingList) -> Self {
        let mut iter = Self {
            list,
            current_block: 0,
            block_doc_ids: Vec::new(),
            block_term_freqs: Vec::new(),
            pos_in_block: 0,
            exhausted: list.blocks.is_empty(),
        };

        if !iter.exhausted {
            iter.decode_current_block();
        }

        iter
    }

    fn decode_current_block(&mut self) {
        let block = &self.list.blocks[self.current_block];
        self.block_doc_ids = block.decode_doc_ids();
        self.block_term_freqs = block.decode_term_freqs();
        self.pos_in_block = 0;
    }

    /// Current doc id, or u32::MAX once the iterator is exhausted.
    pub fn doc(&self) -> u32 {
        if self.exhausted {
            u32::MAX
        } else {
            self.block_doc_ids[self.pos_in_block]
        }
    }

    /// Term frequency at the current position, or 0 once exhausted.
    pub fn term_freq(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.block_term_freqs[self.pos_in_block]
        }
    }

    /// Moves to the next posting and returns its doc id.
    pub fn advance(&mut self) -> u32 {
        if self.exhausted {
            return u32::MAX;
        }

        self.pos_in_block += 1;

        if self.pos_in_block >= self.block_doc_ids.len() {
            self.current_block += 1;
            if self.current_block >= self.list.blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }

        self.doc()
    }

    /// Advances to the first posting with doc id >= target, returning its
    /// doc id or u32::MAX when no such posting exists.
    pub fn seek(&mut self, target: u32) -> u32 {
        if self.exhausted {
            return u32::MAX;
        }

        // Binary search over the remaining blocks by [first, last] range.
        let block_idx = self.list.blocks[self.current_block..].binary_search_by(|block| {
            if block.last_doc_id < target {
                std::cmp::Ordering::Less
            } else if block.first_doc_id > target {
                std::cmp::Ordering::Greater
            } else {
                std::cmp::Ordering::Equal
            }
        });

        let target_block = match block_idx {
            Ok(idx) => self.current_block + idx,
            Err(idx) => {
                if self.current_block + idx >= self.list.blocks.len() {
                    self.exhausted = true;
                    return u32::MAX;
                }
                self.current_block + idx
            }
        };

        if target_block != self.current_block {
            self.current_block = target_block;
            self.decode_current_block();
        }

        // Binary search within the decoded block.
        let pos = self.block_doc_ids[self.pos_in_block..]
            .binary_search(&target)
            .unwrap_or_else(|x| x);
        self.pos_in_block += pos;

        if self.pos_in_block >= self.block_doc_ids.len() {
            self.current_block += 1;
            if self.current_block >= self.list.blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }

        self.doc()
    }

    /// Upper bound on the score of any remaining posting.
    pub fn max_remaining_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.list.blocks[self.current_block..]
            .iter()
            .map(|b| b.max_block_score)
            .fold(0.0f32, |a, b| a.max(b))
    }

    /// Block-max score of the current block.
    pub fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            0.0
        } else {
            self.list.blocks[self.current_block].max_block_score
        }
    }

    /// Maximum term frequency of the current block.
    pub fn current_block_max_tf(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.list.blocks[self.current_block].max_tf
        }
    }

    /// Skips to the first block whose last doc id is >= target and decodes
    /// it, returning that block's first doc id and block-max score.
    pub fn skip_to_block_with_doc(&mut self, target: u32) -> Option<(u32, f32)> {
        while self.current_block < self.list.blocks.len() {
            let block = &self.list.blocks[self.current_block];
            if block.last_doc_id >= target {
                self.decode_current_block();
                return Some((block.first_doc_id, block.max_block_score));
            }
            self.current_block += 1;
        }
        self.exhausted = true;
        None
    }

    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }
}
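
// A block-max traversal sketch: how a WAND-style caller might combine
// current_block_max_score with skip_to_block_with_doc to hop over blocks
// that cannot beat a score threshold. The corpus and threshold here are
// illustrative assumptions.
#[cfg(test)]
mod block_skip_examples {
    use super::*;

    #[test]
    fn skip_blocks_below_threshold() {
        let doc_ids: Vec<u32> = (0..256).collect();
        let term_freqs: Vec<u32> = (0..256).map(|i| if i < 128 { 1 } else { 10 }).collect();
        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let mut iter = list.iterator();
        let threshold = list.blocks[0].max_block_score;

        // Hop forward while the current block cannot beat the threshold.
        while !iter.is_exhausted() && iter.current_block_max_score() <= threshold {
            let next_doc = list.blocks[iter.current_block].last_doc_id.saturating_add(1);
            if iter.skip_to_block_with_doc(next_doc).is_none() {
                break;
            }
        }

        assert!(!iter.is_exhausted());
        assert!(iter.current_block_max_score() > threshold);
    }
}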

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pack_unpack_vertical() {
        let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = (i * 3) as u32;
        }

        let max_val = values.iter().max().copied().unwrap();
        let bit_width = simd::bits_needed(max_val);

        let mut packed = Vec::new();
        pack_vertical(&values, bit_width, &mut packed);

        let mut unpacked = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        unpack_vertical(&packed, bit_width, &mut unpacked);

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_pack_unpack_vertical_various_widths() {
        for bit_width in 1..=20 {
            let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            let max_val = (1u32 << bit_width) - 1;
            for (i, v) in values.iter_mut().enumerate() {
                *v = (i as u32) % (max_val + 1);
            }

            let mut packed = Vec::new();
            pack_vertical(&values, bit_width, &mut packed);

            let mut unpacked = [0u32; VERTICAL_BP128_BLOCK_SIZE];
            unpack_vertical(&packed, bit_width, &mut unpacked);

            assert_eq!(values, unpacked, "Failed for bit_width={}", bit_width);
        }
    }

    #[test]
    fn test_simd_bp128_posting_list() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 10) + 1).collect();

        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        assert_eq!(list.doc_count, 200);
        assert_eq!(list.blocks.len(), 2); // 200 docs split into 128 + 72

        let mut iter = list.iterator();
        for (i, &expected_doc) in doc_ids.iter().enumerate() {
            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
            assert_eq!(iter.term_freq(), term_freqs[i], "TF mismatch at {}", i);
            if i < doc_ids.len() - 1 {
                iter.advance();
            }
        }
    }

    #[test]
    fn test_simd_bp128_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = list.iterator();

        assert_eq!(iter.seek(25), 30);
        assert_eq!(iter.seek(100), 100);
        assert_eq!(iter.seek(500), 1000);
        assert_eq!(iter.seek(3000), u32::MAX);
    }

    #[test]
    fn test_simd_bp128_serialization() {
        let doc_ids: Vec<u32> = (0..300).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..300).map(|i| (i % 5) + 1).collect();

        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.5);

        let mut buffer = Vec::new();
        list.serialize(&mut buffer).unwrap();

        let restored = VerticalBP128PostingList::deserialize(&mut &buffer[..]).unwrap();

        assert_eq!(restored.doc_count, list.doc_count);
        assert_eq!(restored.blocks.len(), list.blocks.len());

        let mut iter1 = list.iterator();
        let mut iter2 = restored.iterator();

        while iter1.doc() != u32::MAX {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
    }

    #[test]
    fn test_vertical_layout_size() {
        let mut values = [0u32; VERTICAL_BP128_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = i as u32;
        }

        // Values 0..=127 need 7 bits each.
        let bit_width = simd::bits_needed(127);
        assert_eq!(bit_width, 7);

        let mut packed = Vec::new();
        pack_horizontal(&values, bit_width, &mut packed);

        // 128 values * 7 bits = 896 bits = 112 bytes.
        let expected_bytes = (VERTICAL_BP128_BLOCK_SIZE * bit_width as usize) / 8;
        assert_eq!(expected_bytes, 112);
        assert_eq!(packed.len(), expected_bytes);
    }

    #[test]
    fn test_simd_bp128_block_max() {
        // 500 docs split into blocks of 128 + 128 + 128 + 116.
        let doc_ids: Vec<u32> = (0..500).map(|i| i * 2).collect();
        // Give each block a distinct maximum term frequency.
        let term_freqs: Vec<u32> = (0..500)
            .map(|i| {
                if i < 128 {
                    1
                } else if i < 256 {
                    5
                } else if i < 384 {
                    10
                } else {
                    3
                }
            })
            .collect();

        let list = VerticalBP128PostingList::from_postings(&doc_ids, &term_freqs, 2.0);

        assert_eq!(list.blocks.len(), 4);
        assert_eq!(list.blocks[0].max_tf, 1);
        assert_eq!(list.blocks[1].max_tf, 5);
        assert_eq!(list.blocks[2].max_tf, 10);
        assert_eq!(list.blocks[3].max_tf, 3);

        assert!(list.blocks[2].max_block_score > list.blocks[0].max_block_score);
        assert!(list.blocks[2].max_block_score > list.blocks[1].max_block_score);
        assert!(list.blocks[2].max_block_score > list.blocks[3].max_block_score);

        assert_eq!(list.max_score, list.blocks[2].max_block_score);

        let mut iter = list.iterator();
        assert_eq!(iter.current_block_max_tf(), 1);

        // Doc 256 is index 128, the first entry of block 1.
        iter.seek(256);
        assert_eq!(iter.current_block_max_tf(), 5);

        // Doc 512 is index 256, the first entry of block 2.
        iter.seek(512);
        assert_eq!(iter.current_block_max_tf(), 10);

        let mut iter2 = list.iterator();
        let result = iter2.skip_to_block_with_doc(300);
        assert!(result.is_some());
        let (first_doc, score) = result.unwrap();
        assert!(first_doc <= 300);
        assert!(score > 0.0);
    }
}