1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
13use std::io::{self, Read, Write};
14
/// NEON (aarch64) kernels for widening packed integers and delta decoding.
#[cfg(target_arch = "aarch64")]
mod neon {
    use super::HORIZONTAL_BP128_BLOCK_SIZE;
    use std::arch::aarch64::*;

    /// Widens one block of packed `u8` values into `u32` values.
    ///
    /// # Panics
    /// Panics if `input` holds fewer than `HORIZONTAL_BP128_BLOCK_SIZE` bytes.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_block_8_neon(
        input: &[u8],
        output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
    ) {
        // Slicing up front turns a short input into a panic instead of an
        // out-of-bounds vector load.
        let packed = &input[..HORIZONTAL_BP128_BLOCK_SIZE];

        for (src, dst) in packed.chunks_exact(16).zip(output.chunks_exact_mut(16)) {
            // SAFETY: `src` is exactly 16 bytes and `dst` exactly 16 `u32`s, so the
            // single 16-byte load and the four 4-lane stores stay in bounds.
            unsafe {
                let bytes = vld1q_u8(src.as_ptr());

                // u8 -> u16 -> u32 zero extension, half a register at a time.
                let low16 = vmovl_u8(vget_low_u8(bytes));
                let high16 = vmovl_u8(vget_high_u8(bytes));

                let out_ptr = dst.as_mut_ptr();
                vst1q_u32(out_ptr, vmovl_u16(vget_low_u16(low16)));
                vst1q_u32(out_ptr.add(4), vmovl_u16(vget_high_u16(low16)));
                vst1q_u32(out_ptr.add(8), vmovl_u16(vget_low_u16(high16)));
                vst1q_u32(out_ptr.add(12), vmovl_u16(vget_high_u16(high16)));
            }
        }
    }

    /// Widens one block of packed `u16` values into `u32` values.
    ///
    /// # Panics
    /// Panics if `input` holds fewer than `2 * HORIZONTAL_BP128_BLOCK_SIZE` bytes.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_block_16_neon(
        input: &[u8],
        output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
    ) {
        let packed = &input[..HORIZONTAL_BP128_BLOCK_SIZE * 2];

        for (src, dst) in packed.chunks_exact(16).zip(output.chunks_exact_mut(8)) {
            // SAFETY: `src` is exactly 16 bytes (eight `u16`s) and `dst` exactly
            // eight `u32`s; the load and both stores stay in bounds.
            unsafe {
                let vals = vld1q_u16(src.as_ptr() as *const u16);

                let out_ptr = dst.as_mut_ptr();
                vst1q_u32(out_ptr, vmovl_u16(vget_low_u16(vals)));
                vst1q_u32(out_ptr.add(4), vmovl_u16(vget_high_u16(vals)));
            }
        }
    }

    /// Copies one block of packed `u32` values into `output`.
    ///
    /// # Panics
    /// Panics if `input` holds fewer than `4 * HORIZONTAL_BP128_BLOCK_SIZE` bytes.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_block_32_neon(
        input: &[u8],
        output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
    ) {
        let packed = &input[..HORIZONTAL_BP128_BLOCK_SIZE * 4];

        for (src, dst) in packed.chunks_exact(16).zip(output.chunks_exact_mut(4)) {
            // SAFETY: `src` is exactly 16 bytes and `dst` exactly four `u32`s.
            unsafe {
                let vals = vld1q_u32(src.as_ptr() as *const u32);
                vst1q_u32(dst.as_mut_ptr(), vals);
            }
        }
    }

    /// Rebuilds `count` document ids from their gap-minus-one deltas.
    ///
    /// `output[0]` becomes `first_doc_id`; each following id is the previous id
    /// plus `deltas[i - 1] + 1`, using wrapping arithmetic.
    ///
    /// # Panics
    /// Panics if `output` has fewer than `count` elements or `deltas` fewer than
    /// `count - 1` elements.
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[target_feature(enable = "neon")]
    #[allow(dead_code)]
    pub unsafe fn delta_decode_block_neon(
        output: &mut [u32],
        deltas: &[u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        let mut carry = first_doc_id;
        output[0] = carry;

        // Bounds-checked views: a 4-lane load can no longer run past `deltas`.
        let mut delta_groups = deltas[..count - 1].chunks_exact(4);
        let mut out_groups = output[1..count].chunks_exact_mut(4);

        for (chunk, dst) in (&mut delta_groups).zip(&mut out_groups) {
            let mut gaps = [0u32; 4];
            // SAFETY: `chunk` and `gaps` both hold exactly four `u32`s.
            unsafe {
                let d = vld1q_u32(chunk.as_ptr());
                // Stored deltas are `gap - 1`; add the implicit 1 back per lane.
                vst1q_u32(gaps.as_mut_ptr(), vaddq_u32(d, vdupq_n_u32(1)));
            }

            for (slot, &gap) in dst.iter_mut().zip(gaps.iter()) {
                carry = carry.wrapping_add(gap);
                *slot = carry;
            }
        }

        // Fewer than four deltas are left; finish them one at a time.
        let tail_out = out_groups.into_remainder();
        for (slot, &delta) in tail_out.iter_mut().zip(delta_groups.remainder()) {
            carry = carry.wrapping_add(delta).wrapping_add(1);
            *slot = carry;
        }
    }
}
171
172#[cfg(target_arch = "x86_64")]
177#[allow(dead_code)]
178mod sse {
179 use super::HORIZONTAL_BP128_BLOCK_SIZE;
180 use std::arch::x86_64::*;
181
182 #[target_feature(enable = "sse2", enable = "sse4.1")]
185 pub unsafe fn unpack_block_8_sse(
186 input: &[u8],
187 output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
188 ) {
189 for chunk in 0..8 {
191 let base = chunk * 16;
192 let in_ptr = input.as_ptr().add(base);
193
194 let bytes = _mm_loadu_si128(in_ptr as *const __m128i);
196
197 let v0 = _mm_cvtepu8_epi32(bytes);
200 let v1 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
201 let v2 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 8));
202 let v3 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 12));
203
204 let out_ptr = output.as_mut_ptr().add(base);
206 _mm_storeu_si128(out_ptr as *mut __m128i, v0);
207 _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
208 _mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
209 _mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
210 }
211 }
212
213 #[target_feature(enable = "sse2", enable = "sse4.1")]
215 pub unsafe fn unpack_block_16_sse(
216 input: &[u8],
217 output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
218 ) {
219 for chunk in 0..16 {
221 let base = chunk * 8;
222 let in_ptr = input.as_ptr().add(base * 2);
223
224 let vals = _mm_loadu_si128(in_ptr as *const __m128i);
226
227 let low = _mm_cvtepu16_epi32(vals);
229 let high = _mm_cvtepu16_epi32(_mm_srli_si128(vals, 8));
230
231 let out_ptr = output.as_mut_ptr().add(base);
233 _mm_storeu_si128(out_ptr as *mut __m128i, low);
234 _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
235 }
236 }
237
238 #[target_feature(enable = "sse2")]
240 pub unsafe fn unpack_block_32_sse(
241 input: &[u8],
242 output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
243 ) {
244 let in_ptr = input.as_ptr() as *const __m128i;
245 let out_ptr = output.as_mut_ptr() as *mut __m128i;
246
247 for i in 0..32 {
249 let vals = _mm_loadu_si128(in_ptr.add(i));
250 _mm_storeu_si128(out_ptr.add(i), vals);
251 }
252 }
253
254 #[target_feature(enable = "sse2")]
256 pub unsafe fn delta_decode_block_sse(
257 output: &mut [u32],
258 deltas: &[u32],
259 first_doc_id: u32,
260 count: usize,
261 ) {
262 if count == 0 {
263 return;
264 }
265
266 let mut carry = first_doc_id;
268 output[0] = carry;
269
270 let full_groups = (count - 1) / 4;
271 let remainder = (count - 1) % 4;
272
273 let ones = _mm_set1_epi32(1);
274
275 for group in 0..full_groups {
276 let base = group * 4;
277
278 let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
280
281 let gaps = _mm_add_epi32(d, ones);
283
284 let g0 = _mm_extract_epi32(gaps, 0) as u32;
286 let g1 = _mm_extract_epi32(gaps, 1) as u32;
287 let g2 = _mm_extract_epi32(gaps, 2) as u32;
288 let g3 = _mm_extract_epi32(gaps, 3) as u32;
289
290 let v0 = carry.wrapping_add(g0);
291 let v1 = v0.wrapping_add(g1);
292 let v2 = v1.wrapping_add(g2);
293 let v3 = v2.wrapping_add(g3);
294
295 output[base + 1] = v0;
297 output[base + 2] = v1;
298 output[base + 3] = v2;
299 output[base + 4] = v3;
300
301 carry = v3;
302 }
303
304 let base = full_groups * 4;
306 for j in 0..remainder {
307 carry = carry.wrapping_add(deltas[base + j]).wrapping_add(1);
308 output[base + j + 1] = carry;
309 }
310 }
311}
312
313#[allow(dead_code)]
315mod scalar {
316 use super::HORIZONTAL_BP128_BLOCK_SIZE;
317
318 #[inline]
319 pub fn unpack_block_8_scalar(input: &[u8], output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
320 for (i, out) in output.iter_mut().enumerate() {
321 *out = input[i] as u32;
322 }
323 }
324
325 #[inline]
326 pub fn unpack_block_16_scalar(input: &[u8], output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
327 for (i, out) in output.iter_mut().enumerate() {
328 let idx = i * 2;
329 *out = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
330 }
331 }
332
333 #[inline]
334 pub fn unpack_block_32_scalar(input: &[u8], output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
335 for (i, out) in output.iter_mut().enumerate() {
336 let idx = i * 4;
337 *out = u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
338 }
339 }
340}
341
/// Number of `u32` values held by one packed block.
pub const HORIZONTAL_BP128_BLOCK_SIZE: usize = 128;

/// Size of a "small" block.
// NOTE(review): not referenced in this part of the file; presumably the value
// count of a reduced block format — confirm against the callers.
pub const SMALL_BLOCK_SIZE: usize = 32;

/// Threshold associated with small blocks.
// NOTE(review): not referenced in this part of the file; presumably the posting
// count below which small blocks are chosen — confirm against the callers.
pub const SMALL_BLOCK_THRESHOLD: usize = 256;
350
/// Returns how many bits are required to represent `max_val`.
///
/// Zero needs no bits; any other value needs the 1-based position of its highest
/// set bit (e.g. `255 -> 8`, `256 -> 9`).
#[inline]
pub fn bits_needed(max_val: u32) -> u8 {
    // `leading_zeros()` is 32 for zero, so the subtraction covers that case too.
    (u32::BITS - max_val.leading_zeros()) as u8
}
360
361pub fn pack_block(
363 values: &[u32; HORIZONTAL_BP128_BLOCK_SIZE],
364 bit_width: u8,
365 output: &mut Vec<u8>,
366) {
367 if bit_width == 0 {
368 return;
369 }
370
371 let bytes_needed = (HORIZONTAL_BP128_BLOCK_SIZE * bit_width as usize).div_ceil(8);
372 let start = output.len();
373 output.resize(start + bytes_needed, 0);
374
375 let mut bit_pos = 0usize;
376 for &value in values {
377 let byte_idx = start + bit_pos / 8;
378 let bit_offset = bit_pos % 8;
379
380 let mut remaining_bits = bit_width as usize;
382 let mut val = value;
383 let mut current_byte_idx = byte_idx;
384 let mut current_bit_offset = bit_offset;
385
386 while remaining_bits > 0 {
387 let bits_in_byte = (8 - current_bit_offset).min(remaining_bits);
388 let mask = ((1u32 << bits_in_byte) - 1) as u8;
389 output[current_byte_idx] |= ((val as u8) & mask) << current_bit_offset;
390 val >>= bits_in_byte;
391 remaining_bits -= bits_in_byte;
392 current_byte_idx += 1;
393 current_bit_offset = 0;
394 }
395
396 bit_pos += bit_width as usize;
397 }
398}
399
400pub fn unpack_block(input: &[u8], bit_width: u8, output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
403 if bit_width == 0 {
404 output.fill(0);
405 return;
406 }
407
408 match bit_width {
410 8 => unpack_block_8(input, output),
411 16 => unpack_block_16(input, output),
412 32 => unpack_block_32(input, output),
413 _ => unpack_block_generic(input, bit_width, output),
414 }
415}
416
417#[inline]
419fn unpack_block_8(input: &[u8], output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
420 #[cfg(target_arch = "aarch64")]
421 {
422 unsafe { neon::unpack_block_8_neon(input, output) }
424 }
425
426 #[cfg(target_arch = "x86_64")]
427 {
428 if is_x86_feature_detected!("sse4.1") {
431 unsafe { sse::unpack_block_8_sse(input, output) }
432 } else {
433 scalar::unpack_block_8_scalar(input, output)
434 }
435 }
436
437 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
438 {
439 scalar::unpack_block_8_scalar(input, output)
440 }
441}
442
443#[inline]
445fn unpack_block_16(input: &[u8], output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
446 #[cfg(target_arch = "aarch64")]
447 {
448 unsafe { neon::unpack_block_16_neon(input, output) }
450 }
451
452 #[cfg(target_arch = "x86_64")]
453 {
454 if is_x86_feature_detected!("sse4.1") {
456 unsafe { sse::unpack_block_16_sse(input, output) }
457 } else {
458 scalar::unpack_block_16_scalar(input, output)
459 }
460 }
461
462 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
463 {
464 scalar::unpack_block_16_scalar(input, output)
465 }
466}
467
468#[inline]
470fn unpack_block_32(input: &[u8], output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
471 #[cfg(target_arch = "aarch64")]
472 {
473 unsafe { neon::unpack_block_32_neon(input, output) }
475 }
476
477 #[cfg(target_arch = "x86_64")]
478 {
479 unsafe { sse::unpack_block_32_sse(input, output) }
481 }
482
483 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
484 {
485 scalar::unpack_block_32_scalar(input, output)
486 }
487}
488
489#[inline]
492fn unpack_block_generic(
493 input: &[u8],
494 bit_width: u8,
495 output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
496) {
497 let mask = (1u64 << bit_width) - 1;
498 let bit_width_usize = bit_width as usize;
499 let mut bit_pos = 0usize;
500
501 let input_ptr = input.as_ptr();
505
506 for out in output.iter_mut() {
507 let byte_idx = bit_pos >> 3; let bit_offset = bit_pos & 7; let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
513
514 *out = ((word >> bit_offset) & mask) as u32;
515 bit_pos += bit_width_usize;
516 }
517}
518
/// Decodes the first `n` values of a `bit_width`-bit packed stream into `output[..n]`.
///
/// Values are read least-significant bit first from a little-endian bit stream.
/// A width of zero fills `output[..n]` with zeros and ignores `input`.
///
/// # Panics
/// Panics if `output` has fewer than `n` elements, or if `input` is shorter than
/// the `ceil(n * bit_width / 8)` bytes that `n` values occupy.
#[inline]
pub fn unpack_block_n(input: &[u8], bit_width: u8, output: &mut [u32], n: usize) {
    /// Loads up to eight bytes starting at `start` as a little-endian `u64`,
    /// treating bytes past the end of `bytes` as zero.
    #[inline]
    fn load_le_window(bytes: &[u8], start: usize) -> u64 {
        let mut window = [0u8; 8];
        if let Some(full) = bytes.get(start..start + 8) {
            window.copy_from_slice(full);
        } else {
            let tail = bytes.get(start..).unwrap_or(&[]);
            window[..tail.len()].copy_from_slice(tail);
        }
        u64::from_le_bytes(window)
    }

    let dest = &mut output[..n];
    if bit_width == 0 {
        dest.fill(0);
        return;
    }

    let width = bit_width as usize;
    let needed = (n * width).div_ceil(8);
    assert!(
        input.len() >= needed,
        "{} values of {} bits need {} input bytes, got {}",
        n,
        bit_width,
        needed,
        input.len()
    );

    let mask = (1u64 << bit_width) - 1;

    for (slot, out) in dest.iter_mut().enumerate() {
        let bit_pos = slot * width;
        // The window may reach past the packed data near the end of the stream;
        // those bytes read as zero and are masked off anyway.
        let word = load_le_window(input, bit_pos >> 3);
        *out = ((word >> (bit_pos & 7)) & mask) as u32;
    }
}
544
/// Binary-searches the sorted `block` for `target`.
///
/// Returns the index of a matching element if one exists, otherwise the index at
/// which `target` could be inserted while keeping the slice sorted (which equals
/// `block.len()` when every element is smaller).
#[inline]
pub fn binary_search_block(block: &[u32], target: u32) -> usize {
    block
        .binary_search(&target)
        .unwrap_or_else(|insert_at| insert_at)
}
554
/// Replaces eight values with their inclusive prefix sums, in place.
///
/// Runs three passes with strides 1, 2 and 4. Each pass walks from the back so
/// that every addition reads a value the same pass has not yet modified.
/// Additions wrap on overflow.
#[allow(dead_code)]
#[inline]
fn prefix_sum_8(deltas: &mut [u32; 8]) {
    for stride in [1usize, 2, 4].iter().copied() {
        let mut idx = 8;
        while idx > stride {
            idx -= 1;
            deltas[idx] = deltas[idx].wrapping_add(deltas[idx - stride]);
        }
    }
}
574
/// Rebuilds `count` document ids from their gap-minus-one deltas.
///
/// `output[0]` receives `first_doc_id`; every later id is the previous id plus
/// `deltas[i - 1] + 1`, using wrapping arithmetic. Nothing is written when
/// `count` is zero.
///
/// Panics if `output` has fewer than `count` elements or `deltas` fewer than
/// `count - 1` elements.
#[inline]
pub fn delta_decode_block(output: &mut [u32], deltas: &[u32], first_doc_id: u32, count: usize) {
    let Some((head, rest)) = output[..count].split_first_mut() else {
        return;
    };

    *head = first_doc_id;

    let mut running = first_doc_id;
    for (slot, &delta) in rest.iter_mut().zip(&deltas[..count - 1]) {
        running = running.wrapping_add(delta).wrapping_add(1);
        *slot = running;
    }
}
598
/// One bit-packed block of up to `HORIZONTAL_BP128_BLOCK_SIZE` postings.
#[derive(Debug, Clone)]
pub struct HorizontalBP128Block {
    /// Packed doc-id deltas; each delta is the gap to the previous doc id minus one.
    pub doc_deltas: Vec<u8>,
    /// Bits used per packed delta.
    pub doc_bit_width: u8,
    /// Packed term frequencies, each stored as `tf - 1`.
    pub term_freqs: Vec<u8>,
    /// Bits used per packed term frequency.
    pub tf_bit_width: u8,
    /// First doc id in the block.
    pub first_doc_id: u32,
    /// Last doc id in the block.
    pub last_doc_id: u32,
    /// Number of postings stored in the block.
    pub num_docs: u16,
    /// Largest (unshifted) term frequency in the block.
    pub max_tf: u32,
    /// Score upper bound computed for the block from `max_tf`.
    pub max_block_score: f32,
}
622
623impl HorizontalBP128Block {
624 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
626 writer.write_u32::<LittleEndian>(self.first_doc_id)?;
627 writer.write_u32::<LittleEndian>(self.last_doc_id)?;
628 writer.write_u16::<LittleEndian>(self.num_docs)?;
629 writer.write_u8(self.doc_bit_width)?;
630 writer.write_u8(self.tf_bit_width)?;
631 writer.write_u32::<LittleEndian>(self.max_tf)?;
632 writer.write_f32::<LittleEndian>(self.max_block_score)?;
633
634 writer.write_u16::<LittleEndian>(self.doc_deltas.len() as u16)?;
636 writer.write_all(&self.doc_deltas)?;
637
638 writer.write_u16::<LittleEndian>(self.term_freqs.len() as u16)?;
640 writer.write_all(&self.term_freqs)?;
641
642 Ok(())
643 }
644
645 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
647 let first_doc_id = reader.read_u32::<LittleEndian>()?;
648 let last_doc_id = reader.read_u32::<LittleEndian>()?;
649 let num_docs = reader.read_u16::<LittleEndian>()?;
650 let doc_bit_width = reader.read_u8()?;
651 let tf_bit_width = reader.read_u8()?;
652 let max_tf = reader.read_u32::<LittleEndian>()?;
653 let max_block_score = reader.read_f32::<LittleEndian>()?;
654
655 let doc_deltas_len = reader.read_u16::<LittleEndian>()? as usize;
656 let mut doc_deltas = vec![0u8; doc_deltas_len];
657 reader.read_exact(&mut doc_deltas)?;
658
659 let term_freqs_len = reader.read_u16::<LittleEndian>()? as usize;
660 let mut term_freqs = vec![0u8; term_freqs_len];
661 reader.read_exact(&mut term_freqs)?;
662
663 Ok(Self {
664 doc_deltas,
665 doc_bit_width,
666 term_freqs,
667 tf_bit_width,
668 first_doc_id,
669 last_doc_id,
670 num_docs,
671 max_tf,
672 max_block_score,
673 })
674 }
675
676 pub fn decode_doc_ids(&self) -> Vec<u32> {
678 if self.num_docs == 0 {
679 return Vec::new();
680 }
681
682 let count = self.num_docs as usize;
683 let mut deltas = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
684 unpack_block(&self.doc_deltas, self.doc_bit_width, &mut deltas);
685
686 let mut output = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
687 delta_decode_block(&mut output, &deltas, self.first_doc_id, count);
688
689 output[..count].to_vec()
690 }
691
692 pub fn decode_term_freqs(&self) -> Vec<u32> {
694 if self.num_docs == 0 {
695 return Vec::new();
696 }
697
698 let mut tfs = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
699 unpack_block(&self.term_freqs, self.tf_bit_width, &mut tfs);
700
701 tfs[..self.num_docs as usize]
703 .iter()
704 .map(|&tf| tf + 1)
705 .collect()
706 }
707}
708
709#[derive(Debug, Clone)]
711pub struct HorizontalBP128PostingList {
712 pub blocks: Vec<HorizontalBP128Block>,
714 pub doc_count: u32,
716 pub max_score: f32,
718}
719
720impl HorizontalBP128PostingList {
721 pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
723 assert_eq!(doc_ids.len(), term_freqs.len());
724
725 if doc_ids.is_empty() {
726 return Self {
727 blocks: Vec::new(),
728 doc_count: 0,
729 max_score: 0.0,
730 };
731 }
732
733 let mut blocks = Vec::new();
734 let mut max_score = 0.0f32;
735 let mut i = 0;
736
737 while i < doc_ids.len() {
738 let block_end = (i + HORIZONTAL_BP128_BLOCK_SIZE).min(doc_ids.len());
739 let block_docs = &doc_ids[i..block_end];
740 let block_tfs = &term_freqs[i..block_end];
741
742 let block = Self::create_block(block_docs, block_tfs, idf);
743 max_score = max_score.max(block.max_block_score);
744 blocks.push(block);
745
746 i = block_end;
747 }
748
749 Self {
750 blocks,
751 doc_count: doc_ids.len() as u32,
752 max_score,
753 }
754 }
755
756 const K1: f32 = 1.2;
758 const B: f32 = 0.75;
759
760 #[inline]
763 pub fn compute_bm25f_upper_bound(max_tf: u32, idf: f32, field_boost: f32) -> f32 {
764 let tf = max_tf as f32;
765 let min_length_norm = 1.0 - Self::B;
768 let tf_norm =
769 (tf * field_boost * (Self::K1 + 1.0)) / (tf * field_boost + Self::K1 * min_length_norm);
770 idf * tf_norm
771 }
772
773 fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> HorizontalBP128Block {
774 let num_docs = doc_ids.len();
775 let first_doc_id = doc_ids[0];
776 let last_doc_id = *doc_ids.last().unwrap();
777
778 let mut deltas = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
780 let mut max_delta = 0u32;
781 for j in 1..num_docs {
782 let delta = doc_ids[j] - doc_ids[j - 1] - 1;
783 deltas[j - 1] = delta;
784 max_delta = max_delta.max(delta);
785 }
786
787 let mut tfs = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
789 let mut max_tf = 0u32;
790
791 for (j, &tf) in term_freqs.iter().enumerate() {
792 tfs[j] = tf - 1; max_tf = max_tf.max(tf);
794 }
795
796 let max_block_score = Self::compute_bm25f_upper_bound(max_tf, idf, 1.0);
799
800 let doc_bit_width = bits_needed(max_delta);
801 let tf_bit_width = bits_needed(max_tf.saturating_sub(1)); let mut doc_deltas = Vec::new();
804 pack_block(&deltas, doc_bit_width, &mut doc_deltas);
805
806 let mut term_freqs_packed = Vec::new();
807 pack_block(&tfs, tf_bit_width, &mut term_freqs_packed);
808
809 HorizontalBP128Block {
810 doc_deltas,
811 doc_bit_width,
812 term_freqs: term_freqs_packed,
813 tf_bit_width,
814 first_doc_id,
815 last_doc_id,
816 num_docs: num_docs as u16,
817 max_tf,
818 max_block_score,
819 }
820 }
821
822 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
824 writer.write_u32::<LittleEndian>(self.doc_count)?;
825 writer.write_f32::<LittleEndian>(self.max_score)?;
826 writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
827
828 for block in &self.blocks {
829 block.serialize(writer)?;
830 }
831
832 Ok(())
833 }
834
835 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
837 let doc_count = reader.read_u32::<LittleEndian>()?;
838 let max_score = reader.read_f32::<LittleEndian>()?;
839 let num_blocks = reader.read_u32::<LittleEndian>()? as usize;
840
841 let mut blocks = Vec::with_capacity(num_blocks);
842 for _ in 0..num_blocks {
843 blocks.push(HorizontalBP128Block::deserialize(reader)?);
844 }
845
846 Ok(Self {
847 blocks,
848 doc_count,
849 max_score,
850 })
851 }
852
853 pub fn iterator(&self) -> HorizontalBP128Iterator<'_> {
855 HorizontalBP128Iterator::new(self)
856 }
857}
858
859pub struct HorizontalBP128Iterator<'a> {
861 posting_list: &'a HorizontalBP128PostingList,
862 current_block: usize,
864 block_doc_ids: Vec<u32>,
866 block_term_freqs: Vec<u32>,
868 pos_in_block: usize,
870 exhausted: bool,
872}
873
874impl<'a> HorizontalBP128Iterator<'a> {
875 pub fn new(posting_list: &'a HorizontalBP128PostingList) -> Self {
876 let mut iter = Self {
877 posting_list,
878 current_block: 0,
879 block_doc_ids: Vec::new(),
880 block_term_freqs: Vec::new(),
881 pos_in_block: 0,
882 exhausted: posting_list.blocks.is_empty(),
883 };
884
885 if !iter.exhausted {
886 iter.decode_current_block();
887 }
888
889 iter
890 }
891
892 fn decode_current_block(&mut self) {
893 let block = &self.posting_list.blocks[self.current_block];
894 self.block_doc_ids = block.decode_doc_ids();
895 self.block_term_freqs = block.decode_term_freqs();
896 self.pos_in_block = 0;
897 }
898
899 pub fn doc(&self) -> u32 {
901 if self.exhausted {
902 u32::MAX
903 } else {
904 self.block_doc_ids[self.pos_in_block]
905 }
906 }
907
908 pub fn term_freq(&self) -> u32 {
910 if self.exhausted {
911 0
912 } else {
913 self.block_term_freqs[self.pos_in_block]
914 }
915 }
916
917 pub fn advance(&mut self) -> u32 {
919 if self.exhausted {
920 return u32::MAX;
921 }
922
923 self.pos_in_block += 1;
924
925 if self.pos_in_block >= self.block_doc_ids.len() {
926 self.current_block += 1;
927 if self.current_block >= self.posting_list.blocks.len() {
928 self.exhausted = true;
929 return u32::MAX;
930 }
931 self.decode_current_block();
932 }
933
934 self.doc()
935 }
936
937 pub fn seek(&mut self, target: u32) -> u32 {
939 if self.exhausted {
940 return u32::MAX;
941 }
942
943 let block_idx = self.posting_list.blocks[self.current_block..].binary_search_by(|block| {
945 if block.last_doc_id < target {
946 std::cmp::Ordering::Less
947 } else if block.first_doc_id > target {
948 std::cmp::Ordering::Greater
949 } else {
950 std::cmp::Ordering::Equal
951 }
952 });
953
954 let target_block = match block_idx {
955 Ok(idx) => self.current_block + idx,
956 Err(idx) => {
957 if self.current_block + idx >= self.posting_list.blocks.len() {
958 self.exhausted = true;
959 return u32::MAX;
960 }
961 self.current_block + idx
962 }
963 };
964
965 if target_block != self.current_block {
967 self.current_block = target_block;
968 self.decode_current_block();
969 } else if self.block_doc_ids.is_empty() {
970 self.decode_current_block();
971 }
972
973 let pos = binary_search_block(&self.block_doc_ids[self.pos_in_block..], target);
975 self.pos_in_block += pos;
976
977 if self.pos_in_block >= self.block_doc_ids.len() {
978 self.current_block += 1;
980 if self.current_block >= self.posting_list.blocks.len() {
981 self.exhausted = true;
982 return u32::MAX;
983 }
984 self.decode_current_block();
985 }
986
987 self.doc()
988 }
989
990 pub fn max_remaining_score(&self) -> f32 {
992 if self.exhausted {
993 return 0.0;
994 }
995
996 self.posting_list.blocks[self.current_block..]
997 .iter()
998 .map(|b| b.max_block_score)
999 .fold(0.0f32, |a, b| a.max(b))
1000 }
1001
1002 pub fn skip_to_block_with_doc(&mut self, target: u32) -> Option<(u32, f32)> {
1004 while self.current_block < self.posting_list.blocks.len() {
1005 let block = &self.posting_list.blocks[self.current_block];
1006 if block.last_doc_id >= target {
1007 return Some((block.first_doc_id, block.max_block_score));
1008 }
1009 self.current_block += 1;
1010 }
1011 self.exhausted = true;
1012 None
1013 }
1014
1015 pub fn current_block_max_score(&self) -> f32 {
1017 if self.exhausted {
1018 0.0
1019 } else {
1020 self.posting_list.blocks[self.current_block].max_block_score
1021 }
1022 }
1023
1024 pub fn current_block_max_tf(&self) -> u32 {
1026 if self.exhausted {
1027 0
1028 } else {
1029 self.posting_list.blocks[self.current_block].max_tf
1030 }
1031 }
1032}
1033
#[cfg(test)]
mod tests {
    use super::*;

    /// `bits_needed` returns the position of the highest set bit.
    #[test]
    fn test_bits_needed() {
        let cases: [(u32, u8); 6] = [(0, 0), (1, 1), (2, 2), (3, 2), (255, 8), (256, 9)];
        for &(value, expected) in cases.iter() {
            assert_eq!(bits_needed(value), expected);
        }
    }

    /// Packing then unpacking a block returns the original values.
    #[test]
    fn test_pack_unpack() {
        let values: [u32; HORIZONTAL_BP128_BLOCK_SIZE] = std::array::from_fn(|i| (i * 3) as u32);
        let bit_width = bits_needed(values.iter().copied().max().unwrap());

        let mut packed = Vec::new();
        pack_block(&values, bit_width, &mut packed);

        let mut unpacked = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
        unpack_block(&packed, bit_width, &mut unpacked);

        assert_eq!(values, unpacked);
    }

    /// Iterating a two-block list yields every posting in order.
    #[test]
    fn test_bitpacked_posting_list() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 10) + 1).collect();

        let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        assert_eq!(posting_list.doc_count, 200);
        assert_eq!(posting_list.blocks.len(), 2);

        let mut iter = posting_list.iterator();
        let last = doc_ids.len() - 1;
        for (i, (&expected_doc, &expected_tf)) in doc_ids.iter().zip(&term_freqs).enumerate() {
            assert_eq!(iter.doc(), expected_doc, "Mismatch at position {}", i);
            assert_eq!(iter.term_freq(), expected_tf);
            if i < last {
                iter.advance();
            }
        }
    }

    /// `seek` lands on the first doc id at or after the target.
    #[test]
    fn test_bitpacked_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = posting_list.iterator();

        let expectations = [(25u32, 30u32), (100, 100), (500, 1000), (3000, u32::MAX)];
        for &(target, landed_on) in expectations.iter() {
            assert_eq!(iter.seek(target), landed_on);
        }
    }

    /// A serialised list decodes to the same postings.
    #[test]
    fn test_serialization() {
        let doc_ids: Vec<u32> = (0..50).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = vec![1; 50];

        let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.5);

        let mut buffer = Vec::new();
        posting_list.serialize(&mut buffer).unwrap();

        let restored = HorizontalBP128PostingList::deserialize(&mut &buffer[..]).unwrap();

        assert_eq!(restored.doc_count, posting_list.doc_count);
        assert_eq!(restored.blocks.len(), posting_list.blocks.len());

        let mut original_iter = posting_list.iterator();
        let mut restored_iter = restored.iterator();

        loop {
            let doc = original_iter.doc();
            if doc == u32::MAX {
                break;
            }
            assert_eq!(doc, restored_iter.doc());
            assert_eq!(original_iter.term_freq(), restored_iter.term_freq());
            original_iter.advance();
            restored_iter.advance();
        }
    }

    /// Checks the in-place prefix sum and the scalar delta decoder.
    #[test]
    fn test_hillis_steele_prefix_sum() {
        let mut deltas = [1u32, 2, 3, 4, 5, 6, 7, 8];
        prefix_sum_8(&mut deltas);
        assert_eq!(deltas, [1, 3, 6, 10, 15, 21, 28, 36]);

        // All-zero deltas decode to consecutive doc ids.
        let zero_deltas = [0u32; 16];
        let mut consecutive = [0u32; 16];
        delta_decode_block(&mut consecutive, &zero_deltas, 100, 8);
        assert_eq!(&consecutive[..8], &[100, 101, 102, 103, 104, 105, 106, 107]);

        // Mixed deltas: each step adds `delta + 1`.
        let mixed_deltas = [1u32, 0, 2, 0, 4, 0, 0, 0];
        let mut mixed = [0u32; 8];
        delta_decode_block(&mut mixed, &mixed_deltas, 10, 8);
        assert_eq!(&mixed[..8], &[10, 12, 13, 16, 17, 22, 23, 24]);
    }

    /// A completely full block decodes back to its doc ids.
    #[test]
    fn test_delta_decode_large_block() {
        let doc_ids: Vec<u32> = (0..128).map(|i| i * 5 + 100).collect();
        let term_freqs: Vec<u32> = vec![1; 128];

        let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let decoded = posting_list.blocks[0].decode_doc_ids();

        assert_eq!(decoded.len(), 128);
        for (i, (&expected, &actual)) in doc_ids.iter().zip(decoded.iter()).enumerate() {
            assert_eq!(expected, actual, "Mismatch at position {}", i);
        }
    }
}