1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
13use std::io::{self, Read, Write};
14
/// Number of values in one full bit-packed block.
///
/// `pack_block` and `unpack_block` operate on arrays of exactly this length,
/// and `BitpackedPostingList::from_postings` cuts posting lists into chunks of
/// at most this many entries.
pub const BITPACK_BLOCK_SIZE: usize = 128;

/// Size of a "small" block.
///
/// NOTE(review): not referenced in the visible part of this file — presumably
/// the block size used for short posting lists; confirm against callers.
pub const SMALL_BLOCK_SIZE: usize = 32;

/// Threshold associated with small blocks.
///
/// NOTE(review): not referenced in the visible part of this file — presumably
/// the posting-list length below which `SMALL_BLOCK_SIZE` is used; confirm
/// against callers.
pub const SMALL_BLOCK_THRESHOLD: usize = 256;
23
/// Returns the minimum number of bits required to represent `max_val`.
///
/// Zero needs no bits at all; any other value needs as many bits as the
/// position of its highest set bit (`255` needs 8 bits, `256` needs 9).
#[inline]
pub fn bits_needed(max_val: u32) -> u8 {
    // `leading_zeros` is 32 for zero, so the subtraction covers that case too.
    (u32::BITS - max_val.leading_zeros()) as u8
}
33
34pub fn pack_block(values: &[u32; BITPACK_BLOCK_SIZE], bit_width: u8, output: &mut Vec<u8>) {
36 if bit_width == 0 {
37 return;
38 }
39
40 let bytes_needed = (BITPACK_BLOCK_SIZE * bit_width as usize).div_ceil(8);
41 let start = output.len();
42 output.resize(start + bytes_needed, 0);
43
44 let mut bit_pos = 0usize;
45 for &value in values {
46 let byte_idx = start + bit_pos / 8;
47 let bit_offset = bit_pos % 8;
48
49 let mut remaining_bits = bit_width as usize;
51 let mut val = value;
52 let mut current_byte_idx = byte_idx;
53 let mut current_bit_offset = bit_offset;
54
55 while remaining_bits > 0 {
56 let bits_in_byte = (8 - current_bit_offset).min(remaining_bits);
57 let mask = ((1u32 << bits_in_byte) - 1) as u8;
58 output[current_byte_idx] |= ((val as u8) & mask) << current_bit_offset;
59 val >>= bits_in_byte;
60 remaining_bits -= bits_in_byte;
61 current_byte_idx += 1;
62 current_bit_offset = 0;
63 }
64
65 bit_pos += bit_width as usize;
66 }
67}
68
69pub fn unpack_block(input: &[u8], bit_width: u8, output: &mut [u32; BITPACK_BLOCK_SIZE]) {
72 if bit_width == 0 {
73 output.fill(0);
74 return;
75 }
76
77 match bit_width {
79 8 => unpack_block_8(input, output),
80 16 => unpack_block_16(input, output),
81 _ => unpack_block_generic(input, bit_width, output),
82 }
83}
84
85#[inline]
87fn unpack_block_8(input: &[u8], output: &mut [u32; BITPACK_BLOCK_SIZE]) {
88 for (i, out) in output.iter_mut().enumerate() {
89 *out = input[i] as u32;
90 }
91}
92
93#[inline]
95fn unpack_block_16(input: &[u8], output: &mut [u32; BITPACK_BLOCK_SIZE]) {
96 for (i, out) in output.iter_mut().enumerate() {
97 let idx = i * 2;
98 *out = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
99 }
100}
101
102fn unpack_block_generic(input: &[u8], bit_width: u8, output: &mut [u32; BITPACK_BLOCK_SIZE]) {
104 let mut bit_pos = 0usize;
105 for out in output.iter_mut() {
106 let byte_idx = bit_pos / 8;
107 let bit_offset = bit_pos % 8;
108
109 let mut value = 0u32;
110 let mut remaining_bits = bit_width as usize;
111 let mut current_byte_idx = byte_idx;
112 let mut current_bit_offset = bit_offset;
113 let mut shift = 0;
114
115 while remaining_bits > 0 {
116 let bits_in_byte = (8 - current_bit_offset).min(remaining_bits);
117 let mask = ((1u32 << bits_in_byte) - 1) as u8;
118 let byte_val = (input[current_byte_idx] >> current_bit_offset) & mask;
119 value |= (byte_val as u32) << shift;
120 shift += bits_in_byte;
121 remaining_bits -= bits_in_byte;
122 current_byte_idx += 1;
123 current_bit_offset = 0;
124 }
125
126 *out = value;
127 bit_pos += bit_width as usize;
128 }
129}
130
/// Decodes the first `n` values of a bit-packed run into `output[..n]`.
///
/// Uses the same layout as `pack_block`: `bit_width` bits per value,
/// least-significant bit first, no padding between values. A width of zero
/// zero-fills `output[..n]` without reading `input`. Elements of `output`
/// past `n` are not modified.
///
/// Panics if `output` has fewer than `n` elements or `input` is too short to
/// hold `n` values at this width.
pub fn unpack_block_n(input: &[u8], bit_width: u8, output: &mut [u32], n: usize) {
    let decoded = &mut output[..n];
    if bit_width == 0 {
        decoded.fill(0);
        return;
    }

    let width = bit_width as usize;
    // Absolute bit position inside `input`; advances by `width` per value.
    let mut cursor = 0usize;

    for slot in decoded.iter_mut() {
        let end = cursor + width;
        let mut acc = 0u32;
        let mut filled = 0;

        while cursor < end {
            let offset = cursor % 8;
            let take = (8 - offset).min(end - cursor);
            let mask = ((1u32 << take) - 1) as u8;
            let bits = (input[cursor / 8] >> offset) & mask;
            acc |= (bits as u32) << filled;
            filled += take;
            cursor += take;
        }

        *slot = acc;
    }
}
164
/// Returns the index of the first element of `block` that is `>= target`.
///
/// `block` must be sorted in ascending order. If every element is smaller
/// than `target` (or `block` is empty) the result is `block.len()`.
///
/// Unlike `slice::binary_search`, which may return the position of any
/// matching element, this always returns the lowest qualifying index, so the
/// result is deterministic even when `block` contains repeated values.
#[inline]
pub fn binary_search_block(block: &[u32], target: u32) -> usize {
    block.partition_point(|&doc_id| doc_id < target)
}
174
// Only called from the unit tests in the visible part of this file.
#[allow(dead_code)]
/// Replaces `deltas` with its inclusive prefix sum, in place.
///
/// Works in three log-step passes (strides 1, 2 and 4) in the style of a
/// Hillis–Steele scan. Each pass walks from the back so that an element is
/// always combined with a value that the current pass has not yet updated.
/// Additions wrap on overflow.
#[inline]
fn prefix_sum_8(deltas: &mut [u32; 8]) {
    for stride in [1usize, 2, 4] {
        for idx in (stride..8).rev() {
            deltas[idx] = deltas[idx].wrapping_add(deltas[idx - stride]);
        }
    }
}
194
/// Rebuilds `count` document IDs from their gap encoding.
///
/// `output[0]` receives `first_doc_id`; every following ID is the previous ID
/// plus `deltas[i - 1] + 1`, because gaps are stored minus one. Arithmetic
/// wraps on overflow. Nothing is written when `count` is zero.
///
/// Panics if `output` has fewer than `count` elements or `deltas` has fewer
/// than `count - 1`.
#[inline]
pub fn delta_decode_block(output: &mut [u32], deltas: &[u32], first_doc_id: u32, count: usize) {
    let mut current = first_doc_id;
    for idx in 0..count {
        // The first slot is the anchor; all later slots apply one stored gap.
        if idx > 0 {
            current = current.wrapping_add(deltas[idx - 1]).wrapping_add(1);
        }
        output[idx] = current;
    }
}
218
/// One block of a posting list: bit-packed doc-ID gaps and term frequencies
/// plus the metadata needed to skip or score the block without decoding it.
#[derive(Debug, Clone)]
pub struct BitpackedBlock {
    /// Bit-packed gaps between consecutive doc IDs, each stored as
    /// `gap - 1` using `doc_bit_width` bits.
    pub doc_deltas: Vec<u8>,
    /// Number of bits used per entry in `doc_deltas`.
    pub doc_bit_width: u8,
    /// Bit-packed term frequencies, each stored as `tf - 1` using
    /// `tf_bit_width` bits.
    pub term_freqs: Vec<u8>,
    /// Number of bits used per entry in `term_freqs`.
    pub tf_bit_width: u8,
    /// Doc ID of the first posting in the block.
    pub first_doc_id: u32,
    /// Doc ID of the last posting in the block.
    pub last_doc_id: u32,
    /// Number of postings stored in the block.
    pub num_docs: u16,
    /// Largest term frequency among the block's postings.
    pub max_tf: u32,
    /// Score upper bound for the block, computed from `max_tf` by
    /// `BitpackedPostingList::compute_bm25f_upper_bound` when the block is built.
    pub max_block_score: f32,
}
242
243impl BitpackedBlock {
244 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
246 writer.write_u32::<LittleEndian>(self.first_doc_id)?;
247 writer.write_u32::<LittleEndian>(self.last_doc_id)?;
248 writer.write_u16::<LittleEndian>(self.num_docs)?;
249 writer.write_u8(self.doc_bit_width)?;
250 writer.write_u8(self.tf_bit_width)?;
251 writer.write_u32::<LittleEndian>(self.max_tf)?;
252 writer.write_f32::<LittleEndian>(self.max_block_score)?;
253
254 writer.write_u16::<LittleEndian>(self.doc_deltas.len() as u16)?;
256 writer.write_all(&self.doc_deltas)?;
257
258 writer.write_u16::<LittleEndian>(self.term_freqs.len() as u16)?;
260 writer.write_all(&self.term_freqs)?;
261
262 Ok(())
263 }
264
265 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
267 let first_doc_id = reader.read_u32::<LittleEndian>()?;
268 let last_doc_id = reader.read_u32::<LittleEndian>()?;
269 let num_docs = reader.read_u16::<LittleEndian>()?;
270 let doc_bit_width = reader.read_u8()?;
271 let tf_bit_width = reader.read_u8()?;
272 let max_tf = reader.read_u32::<LittleEndian>()?;
273 let max_block_score = reader.read_f32::<LittleEndian>()?;
274
275 let doc_deltas_len = reader.read_u16::<LittleEndian>()? as usize;
276 let mut doc_deltas = vec![0u8; doc_deltas_len];
277 reader.read_exact(&mut doc_deltas)?;
278
279 let term_freqs_len = reader.read_u16::<LittleEndian>()? as usize;
280 let mut term_freqs = vec![0u8; term_freqs_len];
281 reader.read_exact(&mut term_freqs)?;
282
283 Ok(Self {
284 doc_deltas,
285 doc_bit_width,
286 term_freqs,
287 tf_bit_width,
288 first_doc_id,
289 last_doc_id,
290 num_docs,
291 max_tf,
292 max_block_score,
293 })
294 }
295
296 pub fn decode_doc_ids(&self) -> Vec<u32> {
298 if self.num_docs == 0 {
299 return Vec::new();
300 }
301
302 let count = self.num_docs as usize;
303 let mut deltas = [0u32; BITPACK_BLOCK_SIZE];
304 unpack_block(&self.doc_deltas, self.doc_bit_width, &mut deltas);
305
306 let mut output = [0u32; BITPACK_BLOCK_SIZE];
307 delta_decode_block(&mut output, &deltas, self.first_doc_id, count);
308
309 output[..count].to_vec()
310 }
311
312 pub fn decode_term_freqs(&self) -> Vec<u32> {
314 if self.num_docs == 0 {
315 return Vec::new();
316 }
317
318 let mut tfs = [0u32; BITPACK_BLOCK_SIZE];
319 unpack_block(&self.term_freqs, self.tf_bit_width, &mut tfs);
320
321 tfs[..self.num_docs as usize]
323 .iter()
324 .map(|&tf| tf + 1)
325 .collect()
326 }
327}
328
/// A posting list stored as a sequence of bit-packed blocks.
#[derive(Debug, Clone)]
pub struct BitpackedPostingList {
    /// Blocks in the order they were built from the input postings.
    pub blocks: Vec<BitpackedBlock>,
    /// Total number of postings across all blocks.
    pub doc_count: u32,
    /// Largest `max_block_score` of any block (`0.0` for an empty list).
    pub max_score: f32,
}
339
340impl BitpackedPostingList {
341 pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
343 assert_eq!(doc_ids.len(), term_freqs.len());
344
345 if doc_ids.is_empty() {
346 return Self {
347 blocks: Vec::new(),
348 doc_count: 0,
349 max_score: 0.0,
350 };
351 }
352
353 let mut blocks = Vec::new();
354 let mut max_score = 0.0f32;
355 let mut i = 0;
356
357 while i < doc_ids.len() {
358 let block_end = (i + BITPACK_BLOCK_SIZE).min(doc_ids.len());
359 let block_docs = &doc_ids[i..block_end];
360 let block_tfs = &term_freqs[i..block_end];
361
362 let block = Self::create_block(block_docs, block_tfs, idf);
363 max_score = max_score.max(block.max_block_score);
364 blocks.push(block);
365
366 i = block_end;
367 }
368
369 Self {
370 blocks,
371 doc_count: doc_ids.len() as u32,
372 max_score,
373 }
374 }
375
376 const K1: f32 = 1.2;
378 const B: f32 = 0.75;
379
380 #[inline]
383 pub fn compute_bm25f_upper_bound(max_tf: u32, idf: f32, field_boost: f32) -> f32 {
384 let tf = max_tf as f32;
385 let min_length_norm = 1.0 - Self::B;
388 let tf_norm =
389 (tf * field_boost * (Self::K1 + 1.0)) / (tf * field_boost + Self::K1 * min_length_norm);
390 idf * tf_norm
391 }
392
393 fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> BitpackedBlock {
394 let num_docs = doc_ids.len();
395 let first_doc_id = doc_ids[0];
396 let last_doc_id = *doc_ids.last().unwrap();
397
398 let mut deltas = [0u32; BITPACK_BLOCK_SIZE];
400 let mut max_delta = 0u32;
401 for j in 1..num_docs {
402 let delta = doc_ids[j] - doc_ids[j - 1] - 1;
403 deltas[j - 1] = delta;
404 max_delta = max_delta.max(delta);
405 }
406
407 let mut tfs = [0u32; BITPACK_BLOCK_SIZE];
409 let mut max_tf = 0u32;
410
411 for (j, &tf) in term_freqs.iter().enumerate() {
412 tfs[j] = tf - 1; max_tf = max_tf.max(tf);
414 }
415
416 let max_block_score = Self::compute_bm25f_upper_bound(max_tf, idf, 1.0);
419
420 let doc_bit_width = bits_needed(max_delta);
421 let tf_bit_width = bits_needed(max_tf.saturating_sub(1)); let mut doc_deltas = Vec::new();
424 pack_block(&deltas, doc_bit_width, &mut doc_deltas);
425
426 let mut term_freqs_packed = Vec::new();
427 pack_block(&tfs, tf_bit_width, &mut term_freqs_packed);
428
429 BitpackedBlock {
430 doc_deltas,
431 doc_bit_width,
432 term_freqs: term_freqs_packed,
433 tf_bit_width,
434 first_doc_id,
435 last_doc_id,
436 num_docs: num_docs as u16,
437 max_tf,
438 max_block_score,
439 }
440 }
441
442 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
444 writer.write_u32::<LittleEndian>(self.doc_count)?;
445 writer.write_f32::<LittleEndian>(self.max_score)?;
446 writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
447
448 for block in &self.blocks {
449 block.serialize(writer)?;
450 }
451
452 Ok(())
453 }
454
455 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
457 let doc_count = reader.read_u32::<LittleEndian>()?;
458 let max_score = reader.read_f32::<LittleEndian>()?;
459 let num_blocks = reader.read_u32::<LittleEndian>()? as usize;
460
461 let mut blocks = Vec::with_capacity(num_blocks);
462 for _ in 0..num_blocks {
463 blocks.push(BitpackedBlock::deserialize(reader)?);
464 }
465
466 Ok(Self {
467 blocks,
468 doc_count,
469 max_score,
470 })
471 }
472
473 pub fn iterator(&self) -> BitpackedPostingIterator<'_> {
475 BitpackedPostingIterator::new(self)
476 }
477}
478
/// Cursor over a [`BitpackedPostingList`] that decodes one block at a time.
pub struct BitpackedPostingIterator<'a> {
    /// The posting list being iterated.
    posting_list: &'a BitpackedPostingList,
    /// Index into `posting_list.blocks` of the block the cursor is in.
    current_block: usize,
    /// Doc IDs decoded by the most recent call to `decode_current_block`.
    block_doc_ids: Vec<u32>,
    /// Term frequencies decoded by the most recent call to
    /// `decode_current_block`, parallel to `block_doc_ids`.
    block_term_freqs: Vec<u32>,
    /// Position of the cursor inside the decoded block.
    pos_in_block: usize,
    /// Set once the cursor has moved past the last posting; `doc` then
    /// returns `u32::MAX`.
    exhausted: bool,
}
493
494impl<'a> BitpackedPostingIterator<'a> {
495 pub fn new(posting_list: &'a BitpackedPostingList) -> Self {
496 let mut iter = Self {
497 posting_list,
498 current_block: 0,
499 block_doc_ids: Vec::new(),
500 block_term_freqs: Vec::new(),
501 pos_in_block: 0,
502 exhausted: posting_list.blocks.is_empty(),
503 };
504
505 if !iter.exhausted {
506 iter.decode_current_block();
507 }
508
509 iter
510 }
511
512 fn decode_current_block(&mut self) {
513 let block = &self.posting_list.blocks[self.current_block];
514 self.block_doc_ids = block.decode_doc_ids();
515 self.block_term_freqs = block.decode_term_freqs();
516 self.pos_in_block = 0;
517 }
518
519 pub fn doc(&self) -> u32 {
521 if self.exhausted {
522 u32::MAX
523 } else {
524 self.block_doc_ids[self.pos_in_block]
525 }
526 }
527
528 pub fn term_freq(&self) -> u32 {
530 if self.exhausted {
531 0
532 } else {
533 self.block_term_freqs[self.pos_in_block]
534 }
535 }
536
537 pub fn advance(&mut self) -> u32 {
539 if self.exhausted {
540 return u32::MAX;
541 }
542
543 self.pos_in_block += 1;
544
545 if self.pos_in_block >= self.block_doc_ids.len() {
546 self.current_block += 1;
547 if self.current_block >= self.posting_list.blocks.len() {
548 self.exhausted = true;
549 return u32::MAX;
550 }
551 self.decode_current_block();
552 }
553
554 self.doc()
555 }
556
557 pub fn seek(&mut self, target: u32) -> u32 {
559 if self.exhausted {
560 return u32::MAX;
561 }
562
563 let block_idx = self.posting_list.blocks[self.current_block..].binary_search_by(|block| {
565 if block.last_doc_id < target {
566 std::cmp::Ordering::Less
567 } else if block.first_doc_id > target {
568 std::cmp::Ordering::Greater
569 } else {
570 std::cmp::Ordering::Equal
571 }
572 });
573
574 let target_block = match block_idx {
575 Ok(idx) => self.current_block + idx,
576 Err(idx) => {
577 if self.current_block + idx >= self.posting_list.blocks.len() {
578 self.exhausted = true;
579 return u32::MAX;
580 }
581 self.current_block + idx
582 }
583 };
584
585 if target_block != self.current_block {
587 self.current_block = target_block;
588 self.decode_current_block();
589 } else if self.block_doc_ids.is_empty() {
590 self.decode_current_block();
591 }
592
593 let pos = binary_search_block(&self.block_doc_ids[self.pos_in_block..], target);
595 self.pos_in_block += pos;
596
597 if self.pos_in_block >= self.block_doc_ids.len() {
598 self.current_block += 1;
600 if self.current_block >= self.posting_list.blocks.len() {
601 self.exhausted = true;
602 return u32::MAX;
603 }
604 self.decode_current_block();
605 }
606
607 self.doc()
608 }
609
610 pub fn max_remaining_score(&self) -> f32 {
612 if self.exhausted {
613 return 0.0;
614 }
615
616 self.posting_list.blocks[self.current_block..]
617 .iter()
618 .map(|b| b.max_block_score)
619 .fold(0.0f32, |a, b| a.max(b))
620 }
621
622 pub fn skip_to_block_with_doc(&mut self, target: u32) -> Option<(u32, f32)> {
624 while self.current_block < self.posting_list.blocks.len() {
625 let block = &self.posting_list.blocks[self.current_block];
626 if block.last_doc_id >= target {
627 return Some((block.first_doc_id, block.max_block_score));
628 }
629 self.current_block += 1;
630 }
631 self.exhausted = true;
632 None
633 }
634
635 pub fn current_block_max_score(&self) -> f32 {
637 if self.exhausted {
638 0.0
639 } else {
640 self.posting_list.blocks[self.current_block].max_block_score
641 }
642 }
643
644 pub fn current_block_max_tf(&self) -> u32 {
646 if self.exhausted {
647 0
648 } else {
649 self.posting_list.blocks[self.current_block].max_tf
650 }
651 }
652}
653
#[cfg(test)]
mod tests {
    use super::*;

    /// `bits_needed` reports the position of the highest set bit.
    #[test]
    fn test_bits_needed() {
        let cases = [(0u32, 0u8), (1, 1), (2, 2), (3, 2), (255, 8), (256, 9)];
        for (value, expected) in cases.iter().copied() {
            assert_eq!(bits_needed(value), expected);
        }
    }

    /// Packing a block and unpacking it again returns the original values.
    #[test]
    fn test_pack_unpack() {
        let mut values = [0u32; BITPACK_BLOCK_SIZE];
        for (idx, slot) in values.iter_mut().enumerate() {
            *slot = (idx * 3) as u32;
        }

        let widest = *values.iter().max().unwrap();
        let bit_width = bits_needed(widest);

        let mut packed = Vec::new();
        pack_block(&values, bit_width, &mut packed);

        let mut unpacked = [0u32; BITPACK_BLOCK_SIZE];
        unpack_block(&packed, bit_width, &mut unpacked);

        assert_eq!(values, unpacked);
    }

    /// Iterating a two-block list yields every doc ID and term frequency in order.
    #[test]
    fn test_bitpacked_posting_list() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 10) + 1).collect();

        let posting_list = BitpackedPostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        assert_eq!(posting_list.doc_count, 200);
        assert_eq!(posting_list.blocks.len(), 2);

        let mut iter = posting_list.iterator();
        let last = doc_ids.len() - 1;
        for (i, (&expected_doc, &expected_tf)) in doc_ids.iter().zip(&term_freqs).enumerate() {
            assert_eq!(iter.doc(), expected_doc, "Mismatch at position {}", i);
            assert_eq!(iter.term_freq(), expected_tf);
            if i < last {
                iter.advance();
            }
        }
    }

    /// `seek` lands on the first doc ID at or after the target.
    #[test]
    fn test_bitpacked_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let posting_list = BitpackedPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = posting_list.iterator();

        let steps = [(25u32, 30u32), (100, 100), (500, 1000), (3000, u32::MAX)];
        for (target, expected) in steps.iter().copied() {
            assert_eq!(iter.seek(target), expected);
        }
    }

    /// A serialized list deserializes to one that iterates identically.
    #[test]
    fn test_serialization() {
        let doc_ids: Vec<u32> = (0..50).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = vec![1; 50];

        let posting_list = BitpackedPostingList::from_postings(&doc_ids, &term_freqs, 1.5);

        let mut buffer = Vec::new();
        posting_list.serialize(&mut buffer).unwrap();

        let mut cursor = buffer.as_slice();
        let restored = BitpackedPostingList::deserialize(&mut cursor).unwrap();

        assert_eq!(restored.doc_count, posting_list.doc_count);
        assert_eq!(restored.blocks.len(), posting_list.blocks.len());

        let mut original_iter = posting_list.iterator();
        let mut restored_iter = restored.iterator();

        while original_iter.doc() != u32::MAX {
            assert_eq!(original_iter.doc(), restored_iter.doc());
            assert_eq!(original_iter.term_freq(), restored_iter.term_freq());
            original_iter.advance();
            restored_iter.advance();
        }
    }

    /// Checks the in-place prefix sum and the gap decoder on small inputs.
    #[test]
    fn test_hillis_steele_prefix_sum() {
        let mut sums = [1u32, 2, 3, 4, 5, 6, 7, 8];
        prefix_sum_8(&mut sums);
        assert_eq!(sums, [1, 3, 6, 10, 15, 21, 28, 36]);

        // All-zero gaps decode to consecutive doc IDs.
        let zero_gaps = [0u32; 16];
        let mut consecutive = [0u32; 16];
        delta_decode_block(&mut consecutive, &zero_gaps, 100, 8);
        assert_eq!(&consecutive[..8], &[100, 101, 102, 103, 104, 105, 106, 107]);

        let mixed_gaps = [1u32, 0, 2, 0, 4, 0, 0, 0];
        let mut decoded = [0u32; 8];
        delta_decode_block(&mut decoded, &mixed_gaps, 10, 8);
        assert_eq!(&decoded[..8], &[10, 12, 13, 16, 17, 22, 23, 24]);
    }

    /// A full 128-document block decodes back to the original doc IDs.
    #[test]
    fn test_delta_decode_large_block() {
        let doc_ids: Vec<u32> = (0..128).map(|i| i * 5 + 100).collect();
        let term_freqs: Vec<u32> = vec![1; 128];

        let posting_list = BitpackedPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let decoded = posting_list.blocks[0].decode_doc_ids();

        assert_eq!(decoded.len(), 128);
        for (i, (&expected, &actual)) in doc_ids.iter().zip(decoded.iter()).enumerate() {
            assert_eq!(expected, actual, "Mismatch at position {}", i);
        }
    }
}