1use crate::structures::simd;
13use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
14use std::io::{self, Read, Write};
15
/// Number of values in one horizontal bit-packed block. `pack_block` /
/// `unpack_block` operate on arrays of exactly this many values, and posting
/// lists are split into blocks of at most this many postings.
pub const HORIZONTAL_BP128_BLOCK_SIZE: usize = 128;

/// NOTE(review): not referenced in the code visible here; presumably the block
/// size for an alternative small-block layout — confirm against callers.
pub const SMALL_BLOCK_SIZE: usize = 32;

/// NOTE(review): not referenced in the code visible here; presumably the
/// posting-list length below which a small-block layout is chosen — confirm
/// against callers.
pub const SMALL_BLOCK_THRESHOLD: usize = 256;
24
25pub fn pack_block(
27 values: &[u32; HORIZONTAL_BP128_BLOCK_SIZE],
28 bit_width: u8,
29 output: &mut Vec<u8>,
30) {
31 if bit_width == 0 {
32 return;
33 }
34
35 let bytes_needed = (HORIZONTAL_BP128_BLOCK_SIZE * bit_width as usize).div_ceil(8);
36 let start = output.len();
37 output.resize(start + bytes_needed, 0);
38
39 let mut bit_pos = 0usize;
40 for &value in values {
41 let byte_idx = start + bit_pos / 8;
42 let bit_offset = bit_pos % 8;
43
44 let mut remaining_bits = bit_width as usize;
46 let mut val = value;
47 let mut current_byte_idx = byte_idx;
48 let mut current_bit_offset = bit_offset;
49
50 while remaining_bits > 0 {
51 let bits_in_byte = (8 - current_bit_offset).min(remaining_bits);
52 let mask = ((1u32 << bits_in_byte) - 1) as u8;
53 output[current_byte_idx] |= ((val as u8) & mask) << current_bit_offset;
54 val >>= bits_in_byte;
55 remaining_bits -= bits_in_byte;
56 current_byte_idx += 1;
57 current_bit_offset = 0;
58 }
59
60 bit_pos += bit_width as usize;
61 }
62}
63
64pub fn unpack_block(input: &[u8], bit_width: u8, output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE]) {
67 if bit_width == 0 {
68 output.fill(0);
69 return;
70 }
71
72 match bit_width {
74 8 => simd::unpack_8bit(input, output, HORIZONTAL_BP128_BLOCK_SIZE),
75 16 => simd::unpack_16bit(input, output, HORIZONTAL_BP128_BLOCK_SIZE),
76 32 => simd::unpack_32bit(input, output, HORIZONTAL_BP128_BLOCK_SIZE),
77 _ => unpack_block_generic(input, bit_width, output),
78 }
79}
80
81#[inline]
84fn unpack_block_generic(
85 input: &[u8],
86 bit_width: u8,
87 output: &mut [u32; HORIZONTAL_BP128_BLOCK_SIZE],
88) {
89 let mask = (1u64 << bit_width) - 1;
90 let bit_width_usize = bit_width as usize;
91 let mut bit_pos = 0usize;
92
93 let input_ptr = input.as_ptr();
97
98 for out in output.iter_mut() {
99 let byte_idx = bit_pos >> 3; let bit_offset = bit_pos & 7; let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
105
106 *out = ((word >> bit_offset) & mask) as u32;
107 bit_pos += bit_width_usize;
108 }
109}
110
111#[inline]
114pub fn unpack_block_n(input: &[u8], bit_width: u8, output: &mut [u32], n: usize) {
115 if bit_width == 0 {
116 output[..n].fill(0);
117 return;
118 }
119
120 let mask = (1u64 << bit_width) - 1;
121 let bit_width_usize = bit_width as usize;
122 let mut bit_pos = 0usize;
123 let input_ptr = input.as_ptr();
124
125 for out in output[..n].iter_mut() {
126 let byte_idx = bit_pos >> 3;
127 let bit_offset = bit_pos & 7;
128
129 let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
131
132 *out = ((word >> bit_offset) & mask) as u32;
133 bit_pos += bit_width_usize;
134 }
135}
136
/// Returns the index of the first element of the sorted slice `block` that is
/// `>= target`, or `block.len()` if every element is smaller.
///
/// Uses `partition_point` so the result is the deterministic lower bound even
/// if `block` contains duplicates. `slice::binary_search` may return any of
/// several equal elements.
#[inline]
pub fn binary_search_block(block: &[u32], target: u32) -> usize {
    block.partition_point(|&doc| doc < target)
}
146
/// In-place inclusive prefix sum of eight values with wrapping addition,
/// computed as a Hillis–Steele scan: three rounds with strides 1, 2 and 4.
// No non-test caller appears in this file, hence the `allow`.
#[allow(dead_code)]
#[inline]
fn prefix_sum_8(deltas: &mut [u32; 8]) {
    for stride in [1usize, 2, 4] {
        // Walk from the top down so each element adds its partner's value from
        // the previous round before that partner is overwritten in this round.
        for i in (stride..deltas.len()).rev() {
            deltas[i] = deltas[i].wrapping_add(deltas[i - stride]);
        }
    }
}
166
/// One block of up to `HORIZONTAL_BP128_BLOCK_SIZE` postings: bit-packed
/// doc-ID gaps and term frequencies, plus the per-block metadata (doc-ID range,
/// max tf, max score) that lets iterators skip the block without decoding it.
#[derive(Debug, Clone)]
pub struct HorizontalBP128Block {
    /// Packed gaps. Entry `j - 1` holds `doc[j] - doc[j - 1] - 1`, using
    /// `doc_bit_width` bits per value. The first doc ID is in `first_doc_id`.
    pub doc_deltas: Vec<u8>,
    /// Bits per packed gap in `doc_deltas`; 0 means every gap is zero.
    pub doc_bit_width: u8,
    /// Packed term frequencies, each stored as `tf - 1` using `tf_bit_width` bits.
    pub term_freqs: Vec<u8>,
    /// Bits per packed value in `term_freqs`; 0 means every tf is 1.
    pub tf_bit_width: u8,
    /// Doc ID of the first posting in the block.
    pub first_doc_id: u32,
    /// Doc ID of the last posting in the block.
    pub last_doc_id: u32,
    /// Number of postings actually stored in the block.
    pub num_docs: u16,
    /// Largest term frequency among the block's postings.
    pub max_tf: u32,
    /// Upper bound on any posting's score in this block. When built by
    /// `HorizontalBP128PostingList::from_postings`, this is
    /// `compute_bm25f_upper_bound(max_tf, idf, 1.0)`.
    pub max_block_score: f32,
}
190
191impl HorizontalBP128Block {
192 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
194 writer.write_u32::<LittleEndian>(self.first_doc_id)?;
195 writer.write_u32::<LittleEndian>(self.last_doc_id)?;
196 writer.write_u16::<LittleEndian>(self.num_docs)?;
197 writer.write_u8(self.doc_bit_width)?;
198 writer.write_u8(self.tf_bit_width)?;
199 writer.write_u32::<LittleEndian>(self.max_tf)?;
200 writer.write_f32::<LittleEndian>(self.max_block_score)?;
201
202 writer.write_u16::<LittleEndian>(self.doc_deltas.len() as u16)?;
204 writer.write_all(&self.doc_deltas)?;
205
206 writer.write_u16::<LittleEndian>(self.term_freqs.len() as u16)?;
208 writer.write_all(&self.term_freqs)?;
209
210 Ok(())
211 }
212
213 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
215 let first_doc_id = reader.read_u32::<LittleEndian>()?;
216 let last_doc_id = reader.read_u32::<LittleEndian>()?;
217 let num_docs = reader.read_u16::<LittleEndian>()?;
218 let doc_bit_width = reader.read_u8()?;
219 let tf_bit_width = reader.read_u8()?;
220 let max_tf = reader.read_u32::<LittleEndian>()?;
221 let max_block_score = reader.read_f32::<LittleEndian>()?;
222
223 let doc_deltas_len = reader.read_u16::<LittleEndian>()? as usize;
224 let mut doc_deltas = vec![0u8; doc_deltas_len];
225 reader.read_exact(&mut doc_deltas)?;
226
227 let term_freqs_len = reader.read_u16::<LittleEndian>()? as usize;
228 let mut term_freqs = vec![0u8; term_freqs_len];
229 reader.read_exact(&mut term_freqs)?;
230
231 Ok(Self {
232 doc_deltas,
233 doc_bit_width,
234 term_freqs,
235 tf_bit_width,
236 first_doc_id,
237 last_doc_id,
238 num_docs,
239 max_tf,
240 max_block_score,
241 })
242 }
243
244 pub fn decode_doc_ids(&self) -> Vec<u32> {
246 let mut output = vec![0u32; self.num_docs as usize];
247 self.decode_doc_ids_into(&mut output);
248 output
249 }
250
251 #[inline]
253 pub fn decode_doc_ids_into(&self, output: &mut [u32]) -> usize {
254 let count = self.num_docs as usize;
255 if count == 0 {
256 return 0;
257 }
258
259 simd::unpack_delta_decode(
261 &self.doc_deltas,
262 self.doc_bit_width,
263 output,
264 self.first_doc_id,
265 count,
266 );
267
268 count
269 }
270
271 pub fn decode_term_freqs(&self) -> Vec<u32> {
273 let mut output = vec![0u32; self.num_docs as usize];
274 self.decode_term_freqs_into(&mut output);
275 output
276 }
277
278 #[inline]
280 pub fn decode_term_freqs_into(&self, output: &mut [u32]) -> usize {
281 let count = self.num_docs as usize;
282 if count == 0 {
283 return 0;
284 }
285
286 unpack_block_n(&self.term_freqs, self.tf_bit_width, output, count);
288
289 simd::add_one(output, count);
291
292 count
293 }
294}
295
/// A posting list split into consecutive `HorizontalBP128Block`s of up to
/// `HORIZONTAL_BP128_BLOCK_SIZE` postings each.
#[derive(Debug, Clone)]
pub struct HorizontalBP128PostingList {
    /// Blocks in ascending doc-ID order.
    pub blocks: Vec<HorizontalBP128Block>,
    /// Total number of postings across all blocks.
    pub doc_count: u32,
    /// Maximum of the blocks' `max_block_score` values (0.0 for an empty list).
    pub max_score: f32,
}
306
307impl HorizontalBP128PostingList {
308 pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
310 assert_eq!(doc_ids.len(), term_freqs.len());
311
312 if doc_ids.is_empty() {
313 return Self {
314 blocks: Vec::new(),
315 doc_count: 0,
316 max_score: 0.0,
317 };
318 }
319
320 let mut blocks = Vec::new();
321 let mut max_score = 0.0f32;
322 let mut i = 0;
323
324 while i < doc_ids.len() {
325 let block_end = (i + HORIZONTAL_BP128_BLOCK_SIZE).min(doc_ids.len());
326 let block_docs = &doc_ids[i..block_end];
327 let block_tfs = &term_freqs[i..block_end];
328
329 let block = Self::create_block(block_docs, block_tfs, idf);
330 max_score = max_score.max(block.max_block_score);
331 blocks.push(block);
332
333 i = block_end;
334 }
335
336 Self {
337 blocks,
338 doc_count: doc_ids.len() as u32,
339 max_score,
340 }
341 }
342
343 const K1: f32 = 1.2;
345 const B: f32 = 0.75;
346
347 #[inline]
350 pub fn compute_bm25f_upper_bound(max_tf: u32, idf: f32, field_boost: f32) -> f32 {
351 let tf = max_tf as f32;
352 let min_length_norm = 1.0 - Self::B;
355 let tf_norm =
356 (tf * field_boost * (Self::K1 + 1.0)) / (tf * field_boost + Self::K1 * min_length_norm);
357 idf * tf_norm
358 }
359
360 fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> HorizontalBP128Block {
361 let num_docs = doc_ids.len();
362 let first_doc_id = doc_ids[0];
363 let last_doc_id = *doc_ids.last().unwrap();
364
365 let mut deltas = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
367 let mut max_delta = 0u32;
368 for j in 1..num_docs {
369 let delta = doc_ids[j] - doc_ids[j - 1] - 1;
370 deltas[j - 1] = delta;
371 max_delta = max_delta.max(delta);
372 }
373
374 let mut tfs = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
376 let mut max_tf = 0u32;
377
378 for (j, &tf) in term_freqs.iter().enumerate() {
379 tfs[j] = tf - 1; max_tf = max_tf.max(tf);
381 }
382
383 let max_block_score = Self::compute_bm25f_upper_bound(max_tf, idf, 1.0);
386
387 let doc_bit_width = simd::bits_needed(max_delta);
388 let tf_bit_width = simd::bits_needed(max_tf.saturating_sub(1)); let mut doc_deltas = Vec::new();
391 pack_block(&deltas, doc_bit_width, &mut doc_deltas);
392
393 let mut term_freqs_packed = Vec::new();
394 pack_block(&tfs, tf_bit_width, &mut term_freqs_packed);
395
396 HorizontalBP128Block {
397 doc_deltas,
398 doc_bit_width,
399 term_freqs: term_freqs_packed,
400 tf_bit_width,
401 first_doc_id,
402 last_doc_id,
403 num_docs: num_docs as u16,
404 max_tf,
405 max_block_score,
406 }
407 }
408
409 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
411 writer.write_u32::<LittleEndian>(self.doc_count)?;
412 writer.write_f32::<LittleEndian>(self.max_score)?;
413 writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
414
415 for block in &self.blocks {
416 block.serialize(writer)?;
417 }
418
419 Ok(())
420 }
421
422 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
424 let doc_count = reader.read_u32::<LittleEndian>()?;
425 let max_score = reader.read_f32::<LittleEndian>()?;
426 let num_blocks = reader.read_u32::<LittleEndian>()? as usize;
427
428 let mut blocks = Vec::with_capacity(num_blocks);
429 for _ in 0..num_blocks {
430 blocks.push(HorizontalBP128Block::deserialize(reader)?);
431 }
432
433 Ok(Self {
434 blocks,
435 doc_count,
436 max_score,
437 })
438 }
439
440 pub fn iterator(&self) -> HorizontalBP128Iterator<'_> {
442 HorizontalBP128Iterator::new(self)
443 }
444}
445
/// Cursor over a `HorizontalBP128PostingList` that decodes one block at a time
/// into reusable scratch buffers.
pub struct HorizontalBP128Iterator<'a> {
    /// The posting list being iterated.
    posting_list: &'a HorizontalBP128PostingList,
    /// Index into `posting_list.blocks` of the block being read.
    current_block: usize,
    /// Number of valid entries in the scratch buffers for the decoded block.
    current_block_len: usize,
    /// Decoded doc IDs of the current block (`HORIZONTAL_BP128_BLOCK_SIZE` entries).
    block_doc_ids: Vec<u32>,
    /// Decoded term frequencies of the current block (`HORIZONTAL_BP128_BLOCK_SIZE` entries).
    block_term_freqs: Vec<u32>,
    /// Position of the current posting within the decoded block.
    pos_in_block: usize,
    /// Set once the iterator has moved past the last posting.
    exhausted: bool,
}
462
463impl<'a> HorizontalBP128Iterator<'a> {
464 pub fn new(posting_list: &'a HorizontalBP128PostingList) -> Self {
465 let mut iter = Self {
467 posting_list,
468 current_block: 0,
469 current_block_len: 0,
470 block_doc_ids: vec![0u32; HORIZONTAL_BP128_BLOCK_SIZE],
471 block_term_freqs: vec![0u32; HORIZONTAL_BP128_BLOCK_SIZE],
472 pos_in_block: 0,
473 exhausted: posting_list.blocks.is_empty(),
474 };
475
476 if !iter.exhausted {
477 iter.decode_current_block();
478 }
479
480 iter
481 }
482
483 #[inline]
484 fn decode_current_block(&mut self) {
485 let block = &self.posting_list.blocks[self.current_block];
486 self.current_block_len = block.decode_doc_ids_into(&mut self.block_doc_ids);
488 block.decode_term_freqs_into(&mut self.block_term_freqs);
489 self.pos_in_block = 0;
490 }
491
492 #[inline]
494 pub fn doc(&self) -> u32 {
495 if self.exhausted {
496 u32::MAX
497 } else {
498 self.block_doc_ids[self.pos_in_block]
499 }
500 }
501
502 #[inline]
504 pub fn term_freq(&self) -> u32 {
505 if self.exhausted {
506 0
507 } else {
508 self.block_term_freqs[self.pos_in_block]
509 }
510 }
511
512 #[inline]
514 pub fn advance(&mut self) -> u32 {
515 if self.exhausted {
516 return u32::MAX;
517 }
518
519 self.pos_in_block += 1;
520
521 if self.pos_in_block >= self.current_block_len {
522 self.current_block += 1;
523 if self.current_block >= self.posting_list.blocks.len() {
524 self.exhausted = true;
525 return u32::MAX;
526 }
527 self.decode_current_block();
528 }
529
530 self.doc()
531 }
532
533 pub fn seek(&mut self, target: u32) -> u32 {
535 if self.exhausted {
536 return u32::MAX;
537 }
538
539 let block_idx = self.posting_list.blocks[self.current_block..].binary_search_by(|block| {
541 if block.last_doc_id < target {
542 std::cmp::Ordering::Less
543 } else if block.first_doc_id > target {
544 std::cmp::Ordering::Greater
545 } else {
546 std::cmp::Ordering::Equal
547 }
548 });
549
550 let target_block = match block_idx {
551 Ok(idx) => self.current_block + idx,
552 Err(idx) => {
553 if self.current_block + idx >= self.posting_list.blocks.len() {
554 self.exhausted = true;
555 return u32::MAX;
556 }
557 self.current_block + idx
558 }
559 };
560
561 if target_block != self.current_block {
563 self.current_block = target_block;
564 self.decode_current_block();
565 } else if self.current_block_len == 0 {
566 self.decode_current_block();
567 }
568
569 let pos = binary_search_block(
571 &self.block_doc_ids[self.pos_in_block..self.current_block_len],
572 target,
573 );
574 self.pos_in_block += pos;
575
576 if self.pos_in_block >= self.current_block_len {
577 self.current_block += 1;
579 if self.current_block >= self.posting_list.blocks.len() {
580 self.exhausted = true;
581 return u32::MAX;
582 }
583 self.decode_current_block();
584 }
585
586 self.doc()
587 }
588
589 pub fn max_remaining_score(&self) -> f32 {
591 if self.exhausted {
592 return 0.0;
593 }
594
595 self.posting_list.blocks[self.current_block..]
596 .iter()
597 .map(|b| b.max_block_score)
598 .fold(0.0f32, |a, b| a.max(b))
599 }
600
601 pub fn skip_to_block_with_doc(&mut self, target: u32) -> Option<(u32, f32)> {
603 while self.current_block < self.posting_list.blocks.len() {
604 let block = &self.posting_list.blocks[self.current_block];
605 if block.last_doc_id >= target {
606 return Some((block.first_doc_id, block.max_block_score));
607 }
608 self.current_block += 1;
609 }
610 self.exhausted = true;
611 None
612 }
613
614 pub fn current_block_max_score(&self) -> f32 {
616 if self.exhausted {
617 0.0
618 } else {
619 self.posting_list.blocks[self.current_block].max_block_score
620 }
621 }
622
623 pub fn current_block_max_tf(&self) -> u32 {
625 if self.exhausted {
626 0
627 } else {
628 self.posting_list.blocks[self.current_block].max_tf
629 }
630 }
631}
632
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bits_needed() {
        assert_eq!(simd::bits_needed(0), 0);
        assert_eq!(simd::bits_needed(1), 1);
        assert_eq!(simd::bits_needed(2), 2);
        assert_eq!(simd::bits_needed(3), 2);
        assert_eq!(simd::bits_needed(255), 8);
        assert_eq!(simd::bits_needed(256), 9);
    }

    #[test]
    fn test_pack_unpack() {
        let mut values = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
        for (i, value) in values.iter_mut().enumerate() {
            *value = (i * 3) as u32;
        }

        let max_val = values.iter().max().copied().unwrap();
        let bit_width = simd::bits_needed(max_val);

        let mut packed = Vec::new();
        pack_block(&values, bit_width, &mut packed);

        let mut unpacked = [0u32; HORIZONTAL_BP128_BLOCK_SIZE];
        unpack_block(&packed, bit_width, &mut unpacked);

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_bitpacked_posting_list() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 10) + 1).collect();

        let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        assert_eq!(posting_list.doc_count, 200);
        assert_eq!(posting_list.blocks.len(), 2);

        let mut iter = posting_list.iterator();
        for (i, &expected_doc) in doc_ids.iter().enumerate() {
            assert_eq!(iter.doc(), expected_doc, "Mismatch at position {}", i);
            assert_eq!(iter.term_freq(), term_freqs[i]);
            if i < doc_ids.len() - 1 {
                iter.advance();
            }
        }

        // Advancing past the last posting must exhaust the iterator.
        assert_eq!(iter.advance(), u32::MAX);
        assert_eq!(iter.doc(), u32::MAX);
        assert_eq!(iter.term_freq(), 0);
    }

    #[test]
    fn test_bitpacked_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = posting_list.iterator();

        assert_eq!(iter.seek(25), 30);
        assert_eq!(iter.seek(100), 100);
        assert_eq!(iter.term_freq(), 4);
        assert_eq!(iter.seek(500), 1000);
        assert_eq!(iter.seek(3000), u32::MAX);
        assert_eq!(iter.doc(), u32::MAX);
    }

    #[test]
    fn test_seek_across_blocks() {
        // Three blocks: docs 0..=381, 384..=765, 768..=897 (step 3).
        let doc_ids: Vec<u32> = (0..300).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = vec![1; 300];

        let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        assert_eq!(posting_list.blocks.len(), 3);

        let mut iter = posting_list.iterator();
        // Inside the second block.
        assert_eq!(iter.seek(601), 603);
        // In the gap between the second and third blocks.
        assert_eq!(iter.seek(766), 768);
        // Seeking backwards leaves the position unchanged.
        assert_eq!(iter.seek(10), 768);
        assert_eq!(iter.seek(898), u32::MAX);
    }

    #[test]
    fn test_serialization() {
        let doc_ids: Vec<u32> = (0..50).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..50).map(|_| 1).collect();

        let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.5);

        let mut buffer = Vec::new();
        posting_list.serialize(&mut buffer).unwrap();

        let restored = HorizontalBP128PostingList::deserialize(&mut &buffer[..]).unwrap();

        assert_eq!(restored.doc_count, posting_list.doc_count);
        assert_eq!(restored.blocks.len(), posting_list.blocks.len());

        let mut iter1 = posting_list.iterator();
        let mut iter2 = restored.iterator();

        while iter1.doc() != u32::MAX {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }

        // Both iterators must run out at the same point.
        assert_eq!(iter2.doc(), u32::MAX);
    }

    #[test]
    fn test_empty_posting_list_round_trip() {
        let posting_list = HorizontalBP128PostingList::from_postings(&[], &[], 1.0);
        assert_eq!(posting_list.doc_count, 0);
        assert!(posting_list.blocks.is_empty());
        assert_eq!(posting_list.iterator().doc(), u32::MAX);

        let mut buffer = Vec::new();
        posting_list.serialize(&mut buffer).unwrap();
        let restored = HorizontalBP128PostingList::deserialize(&mut &buffer[..]).unwrap();

        assert_eq!(restored.doc_count, 0);
        assert!(restored.blocks.is_empty());
        assert_eq!(restored.iterator().doc(), u32::MAX);
    }

    #[test]
    fn test_hillis_steele_prefix_sum() {
        let mut deltas = [1u32, 2, 3, 4, 5, 6, 7, 8];
        prefix_sum_8(&mut deltas);
        assert_eq!(deltas, [1, 3, 6, 10, 15, 21, 28, 36]);

        let deltas2 = [0u32; 16];
        let mut output2 = [0u32; 16];
        simd::delta_decode(&mut output2, &deltas2, 100, 8);
        assert_eq!(&output2[..8], &[100, 101, 102, 103, 104, 105, 106, 107]);

        let deltas3 = [1u32, 0, 2, 0, 4, 0, 0, 0];
        let mut output3 = [0u32; 8];
        simd::delta_decode(&mut output3, &deltas3, 10, 8);
        assert_eq!(&output3[..8], &[10, 12, 13, 16, 17, 22, 23, 24]);
    }

    #[test]
    fn test_delta_decode_large_block() {
        let doc_ids: Vec<u32> = (0..128).map(|i| i * 5 + 100).collect();
        let term_freqs: Vec<u32> = vec![1; 128];

        let posting_list = HorizontalBP128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let decoded = posting_list.blocks[0].decode_doc_ids();

        assert_eq!(decoded.len(), 128);
        for (i, (&expected, &actual)) in doc_ids.iter().zip(decoded.iter()).enumerate() {
            assert_eq!(expected, actual, "Mismatch at position {}", i);
        }
    }
}