1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
19use std::io::{self, Read, Write};
20
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod neon {
    #[allow(unused_imports)]
    use super::OPT_P4D_BLOCK_SIZE;
    use std::arch::aarch64::*;

    /// Zero-extends `count` bytes into `u32` values, 16 lanes per NEON
    /// iteration, with a bounds-checked scalar tail.
    ///
    /// # Safety
    /// `input` must hold at least `count` bytes and `output` at least
    /// `count` elements; the vector loads/stores are not bounds-checked.
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_8bit_neon(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 16;
        let remainder = count % 16;

        for chunk in 0..chunks {
            let base = chunk * 16;
            let in_ptr = input.as_ptr().add(base);

            // 16 x u8 -> 2 x (8 x u16) -> 4 x (4 x u32) widening ladder.
            let bytes = vld1q_u8(in_ptr);

            let low8 = vget_low_u8(bytes);
            let high8 = vget_high_u8(bytes);

            let low16 = vmovl_u8(low8);
            let high16 = vmovl_u8(high8);

            let v0 = vmovl_u16(vget_low_u16(low16));
            let v1 = vmovl_u16(vget_high_u16(low16));
            let v2 = vmovl_u16(vget_low_u16(high16));
            let v3 = vmovl_u16(vget_high_u16(high16));

            let out_ptr = output.as_mut_ptr().add(base);
            vst1q_u32(out_ptr, v0);
            vst1q_u32(out_ptr.add(4), v1);
            vst1q_u32(out_ptr.add(8), v2);
            vst1q_u32(out_ptr.add(12), v3);
        }

        // Scalar tail for the last `count % 16` values (slice-indexed,
        // so this part is bounds-checked).
        let base = chunks * 16;
        for i in 0..remainder {
            output[base + i] = input[base + i] as u32;
        }
    }

    /// Decodes `count` little-endian `u16`s from `input` into `u32`s,
    /// 8 values per NEON iteration.
    ///
    /// # Safety
    /// `input` must hold at least `2 * count` bytes and `output` at least
    /// `count` elements.
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_16bit_neon(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 8;
        let remainder = count % 8;

        for chunk in 0..chunks {
            let base = chunk * 8;
            // NOTE(review): byte pointer cast to *const u16 may be unaligned;
            // this relies on vld1q_u16 tolerating unaligned addresses on
            // aarch64 — confirm against the intrinsic's contract.
            let in_ptr = input.as_ptr().add(base * 2) as *const u16;

            let vals = vld1q_u16(in_ptr);
            let low = vmovl_u16(vget_low_u16(vals));
            let high = vmovl_u16(vget_high_u16(vals));

            let out_ptr = output.as_mut_ptr().add(base);
            vst1q_u32(out_ptr, low);
            vst1q_u32(out_ptr.add(4), high);
        }

        // Bounds-checked scalar tail.
        let base = chunks * 8;
        for i in 0..remainder {
            let idx = (base + i) * 2;
            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    /// Prefix-sums `count - 1` stored deltas into absolute doc ids.
    ///
    /// Stored deltas are `gap - 1` (see `create_block`), so 1 is added back
    /// to every lane before the running sum. The lane extraction keeps the
    /// sum serial, since each output depends on the previous one.
    ///
    /// # Safety
    /// `deltas` must hold at least `count - 1` elements (loaded 4 at a time
    /// with `vld1q_u32`) and `output` at least `count` elements.
    #[target_feature(enable = "neon")]
    pub unsafe fn delta_decode_neon(
        deltas: &[u32],
        output: &mut [u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        if count == 1 {
            return;
        }

        let mut carry = first_doc_id;
        let ones = vdupq_n_u32(1);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            let d = vld1q_u32(deltas[base..].as_ptr());

            // Restore the implicit +1 per gap in one vector add.
            let gaps = vaddq_u32(d, ones);

            let g0 = vgetq_lane_u32(gaps, 0);
            let g1 = vgetq_lane_u32(gaps, 1);
            let g2 = vgetq_lane_u32(gaps, 2);
            let g3 = vgetq_lane_u32(gaps, 3);

            let v0 = carry.wrapping_add(g0);
            let v1 = v0.wrapping_add(g1);
            let v2 = v1.wrapping_add(g2);
            let v3 = v2.wrapping_add(g3);

            output[base + 1] = v0;
            output[base + 2] = v1;
            output[base + 3] = v2;
            output[base + 4] = v3;

            carry = v3;
        }

        // Scalar tail for the last (count - 1) % 4 deltas.
        let base = full_groups * 4;
        for j in 0..remainder {
            carry = carry.wrapping_add(deltas[base + j]).wrapping_add(1);
            output[base + j + 1] = carry;
        }
    }

    /// Adds 1 to the first `count` values in place (used to restore term
    /// frequencies stored as `tf - 1`).
    ///
    /// # Safety
    /// `values` must hold at least `count` elements.
    /// NOTE(review): the vector lanes wrap on overflow (vaddq_u32) while the
    /// scalar tail uses a checked `+= 1`; inputs are expected to stay far
    /// below `u32::MAX`.
    #[target_feature(enable = "neon")]
    pub unsafe fn add_one_neon(values: &mut [u32], count: usize) {
        let ones = vdupq_n_u32(1);
        let chunks = count / 4;
        let remainder = count % 4;

        for chunk in 0..chunks {
            let base = chunk * 4;
            let ptr = values.as_mut_ptr().add(base);
            let v = vld1q_u32(ptr);
            let result = vaddq_u32(v, ones);
            vst1q_u32(ptr, result);
        }

        let base = chunks * 4;
        for i in 0..remainder {
            values[base + i] += 1;
        }
    }

    /// NEON is a baseline feature of aarch64, so it is always available.
    #[inline]
    pub fn is_available() -> bool {
        true
    }
}
185
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod sse {
    use std::arch::x86_64::*;

    /// Zero-extends `count` bytes into `u32` values, 16 per SSE iteration.
    ///
    /// # Safety
    /// Caller must ensure SSE4.1 is available (`is_available`) and that
    /// `input` holds at least `count` bytes and `output` at least `count`
    /// elements; the vector loads/stores are not bounds-checked.
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_8bit_sse(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 16;
        let remainder = count % 16;

        for chunk in 0..chunks {
            let base = chunk * 16;
            let in_ptr = input.as_ptr().add(base);

            let bytes = _mm_loadu_si128(in_ptr as *const __m128i);

            // Zero-extend each group of 4 bytes to 4 x u32 (SSE4.1).
            let v0 = _mm_cvtepu8_epi32(bytes);
            let v1 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
            let v2 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 8));
            let v3 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 12));

            let out_ptr = output.as_mut_ptr().add(base);
            _mm_storeu_si128(out_ptr as *mut __m128i, v0);
            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
            _mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
            _mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
        }

        // Bounds-checked scalar tail.
        let base = chunks * 16;
        for i in 0..remainder {
            output[base + i] = input[base + i] as u32;
        }
    }

    /// Decodes `count` little-endian `u16`s into `u32`s, 8 per iteration.
    ///
    /// # Safety
    /// `input` must hold at least `2 * count` bytes and `output` at least
    /// `count` elements; SSE4.1 must be available.
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn unpack_16bit_sse(input: &[u8], output: &mut [u32], count: usize) {
        let chunks = count / 8;
        let remainder = count % 8;

        for chunk in 0..chunks {
            let base = chunk * 8;
            let in_ptr = input.as_ptr().add(base * 2);

            let vals = _mm_loadu_si128(in_ptr as *const __m128i);
            let low = _mm_cvtepu16_epi32(vals);
            let high = _mm_cvtepu16_epi32(_mm_srli_si128(vals, 8));

            let out_ptr = output.as_mut_ptr().add(base);
            _mm_storeu_si128(out_ptr as *mut __m128i, low);
            _mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
        }

        // Bounds-checked scalar tail.
        let base = chunks * 8;
        for i in 0..remainder {
            let idx = (base + i) * 2;
            output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
        }
    }

    /// Prefix-sums `count - 1` stored deltas into absolute doc ids.
    ///
    /// Stored deltas are `gap - 1` (see `create_block`), so 1 is added back
    /// to every lane before the serial running sum is applied.
    ///
    /// # Safety
    /// `deltas` must hold at least `count - 1` elements (loaded 4 at a time)
    /// and `output` at least `count` elements; SSE4.1 must be available
    /// (`_mm_extract_epi32`).
    #[target_feature(enable = "sse2", enable = "sse4.1")]
    pub unsafe fn delta_decode_sse(
        deltas: &[u32],
        output: &mut [u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        if count == 1 {
            return;
        }

        let mut carry = first_doc_id;
        let ones = _mm_set1_epi32(1);

        let full_groups = (count - 1) / 4;
        let remainder = (count - 1) % 4;

        for group in 0..full_groups {
            let base = group * 4;

            let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
            let gaps = _mm_add_epi32(d, ones);

            let g0 = _mm_extract_epi32(gaps, 0) as u32;
            let g1 = _mm_extract_epi32(gaps, 1) as u32;
            let g2 = _mm_extract_epi32(gaps, 2) as u32;
            let g3 = _mm_extract_epi32(gaps, 3) as u32;

            let v0 = carry.wrapping_add(g0);
            let v1 = v0.wrapping_add(g1);
            let v2 = v1.wrapping_add(g2);
            let v3 = v2.wrapping_add(g3);

            output[base + 1] = v0;
            output[base + 2] = v1;
            output[base + 3] = v2;
            output[base + 4] = v3;

            carry = v3;
        }

        // Scalar tail for the last (count - 1) % 4 deltas.
        let base = full_groups * 4;
        for j in 0..remainder {
            carry = carry.wrapping_add(deltas[base + j]).wrapping_add(1);
            output[base + j + 1] = carry;
        }
    }

    /// Adds 1 to the first `count` values in place.
    ///
    /// # Safety
    /// `values` must hold at least `count` elements.
    /// NOTE(review): the vector lanes wrap on overflow (_mm_add_epi32)
    /// while the scalar tail uses a checked `+= 1`.
    #[target_feature(enable = "sse2")]
    pub unsafe fn add_one_sse(values: &mut [u32], count: usize) {
        let ones = _mm_set1_epi32(1);
        let chunks = count / 4;
        let remainder = count % 4;

        for chunk in 0..chunks {
            let base = chunk * 4;
            let ptr = values.as_mut_ptr().add(base) as *mut __m128i;
            let v = _mm_loadu_si128(ptr);
            let result = _mm_add_epi32(v, ones);
            _mm_storeu_si128(ptr, result);
        }

        let base = chunks * 4;
        for i in 0..remainder {
            values[base + i] += 1;
        }
    }

    /// Runtime check: sse4.1 implies sse2, and both fast paths require it.
    #[inline]
    pub fn is_available() -> bool {
        is_x86_feature_detected!("sse4.1")
    }
}
333
mod scalar {
    /// Portable fallback: widen the first `count` bytes into `u32` values.
    #[inline]
    pub fn unpack_8bit_scalar(input: &[u8], output: &mut [u32], count: usize) {
        for (dst, &src) in output[..count].iter_mut().zip(&input[..count]) {
            *dst = u32::from(src);
        }
    }

    /// Portable fallback: decode `count` little-endian `u16`s into `u32`s.
    #[inline]
    pub fn unpack_16bit_scalar(input: &[u8], output: &mut [u32], count: usize) {
        let mut src = 0;
        for slot in output.iter_mut().take(count) {
            *slot = u32::from(u16::from_le_bytes([input[src], input[src + 1]]));
            src += 2;
        }
    }

    /// Portable fallback: prefix-sum `count - 1` stored deltas (each is
    /// `gap - 1`, so 1 is re-added per step) starting from `first_doc_id`.
    #[inline]
    pub fn delta_decode_scalar(
        deltas: &[u32],
        output: &mut [u32],
        first_doc_id: u32,
        count: usize,
    ) {
        if count == 0 {
            return;
        }

        output[0] = first_doc_id;
        let mut prev = first_doc_id;

        for (slot, &gap_minus_one) in output[1..count].iter_mut().zip(deltas) {
            prev = prev.wrapping_add(gap_minus_one).wrapping_add(1);
            *slot = prev;
        }
    }

    /// Portable fallback: add 1 to the first `count` values in place.
    #[inline]
    pub fn add_one_scalar(values: &mut [u32], count: usize) {
        values.iter_mut().take(count).for_each(|v| *v += 1);
    }
}
385
/// Unpacks `count` 8-bit values into `output`, dispatching to NEON on
/// aarch64 or SSE4.1 on x86_64 when available, else the scalar fallback.
///
/// Callers must supply `input.len() >= count` and `output.len() >= count`;
/// this is the safety contract forwarded to the SIMD paths.
#[inline]
fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            // SAFETY: NEON availability checked; slice-length contract is
            // forwarded to the caller (see above).
            unsafe {
                neon::unpack_8bit_neon(input, output, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            // SAFETY: SSE4.1 availability checked at runtime.
            unsafe {
                sse::unpack_8bit_sse(input, output, count);
            }
            return;
        }
    }

    scalar::unpack_8bit_scalar(input, output, count);
}
415
/// Unpacks `count` little-endian 16-bit values into `output`, dispatching
/// to NEON / SSE4.1 when available, else the scalar fallback.
///
/// Callers must supply `input.len() >= 2 * count` and
/// `output.len() >= count` (safety contract of the SIMD paths).
#[inline]
fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            // SAFETY: NEON availability checked; length contract forwarded.
            unsafe {
                neon::unpack_16bit_neon(input, output, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            // SAFETY: SSE4.1 availability checked at runtime.
            unsafe {
                sse::unpack_16bit_sse(input, output, count);
            }
            return;
        }
    }

    scalar::unpack_16bit_scalar(input, output, count);
}
441
/// Reconstructs absolute doc ids from `count - 1` stored deltas (each
/// `gap - 1`) and the verbatim `first_doc_id`, using the best available
/// SIMD implementation.
///
/// Callers must supply `deltas.len() >= count - 1` and
/// `output.len() >= count` (safety contract of the SIMD paths).
#[inline]
fn delta_decode_simd(deltas: &[u32], output: &mut [u32], first_doc_id: u32, count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            // SAFETY: NEON availability checked; length contract forwarded.
            unsafe {
                neon::delta_decode_neon(deltas, output, first_doc_id, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            // SAFETY: SSE4.1 availability checked at runtime.
            unsafe {
                sse::delta_decode_sse(deltas, output, first_doc_id, count);
            }
            return;
        }
    }

    scalar::delta_decode_scalar(deltas, output, first_doc_id, count);
}
467
/// Adds 1 to the first `count` values in place (restores term frequencies
/// stored as `tf - 1`), using the best available SIMD implementation.
///
/// Callers must supply `values.len() >= count` (safety contract of the
/// SIMD paths).
#[inline]
fn add_one_simd(values: &mut [u32], count: usize) {
    #[cfg(target_arch = "aarch64")]
    {
        if neon::is_available() {
            // SAFETY: NEON availability checked; length contract forwarded.
            unsafe {
                neon::add_one_neon(values, count);
            }
            return;
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if sse::is_available() {
            // SAFETY: SSE2 availability is implied by the SSE4.1 check.
            unsafe {
                sse::add_one_sse(values, count);
            }
            return;
        }
    }

    scalar::add_one_scalar(values, count);
}
493
/// Number of postings stored per compressed block.
pub const OPT_P4D_BLOCK_SIZE: usize = 128;

/// Largest fraction of a block's values allowed to overflow the chosen bit
/// width and be stored as patch exceptions (see `find_optimal_bit_width`).
const MAX_EXCEPTIONS_RATIO: f32 = 0.10;
500
/// Minimum number of bits required to represent `val` (0 needs 0 bits).
#[inline]
fn bits_needed(val: u32) -> u8 {
    match val {
        0 => 0,
        v => (32 - v.leading_zeros()) as u8,
    }
}
510
511fn find_optimal_bit_width(values: &[u32]) -> (u8, usize, usize) {
514 if values.is_empty() {
515 return (0, 0, 0);
516 }
517
518 let n = values.len();
519 let max_exceptions = ((n as f32) * MAX_EXCEPTIONS_RATIO).ceil() as usize;
520
521 let mut bit_counts = [0usize; 33]; for &v in values {
524 let bits = bits_needed(v) as usize;
525 bit_counts[bits] += 1;
526 }
527
528 let mut cumulative = [0usize; 33];
530 cumulative[0] = bit_counts[0];
531 for b in 1..=32 {
532 cumulative[b] = cumulative[b - 1] + bit_counts[b];
533 }
534
535 let mut best_bits = 32u8;
536 let mut best_total = usize::MAX;
537 let mut best_exceptions = 0usize;
538
539 for b in 0..=32u8 {
541 let fitting = if b == 0 {
542 bit_counts[0]
543 } else {
544 cumulative[b as usize]
545 };
546 let exceptions = n - fitting;
547
548 if exceptions > max_exceptions && b < 32 {
550 continue;
551 }
552
553 let main_bits = n * (b as usize);
557 let exception_bits = if b < 32 {
558 exceptions * (7 + (32 - b as usize))
559 } else {
560 0
561 };
562 let total = main_bits + exception_bits;
563
564 if total < best_total {
565 best_total = total;
566 best_bits = b;
567 best_exceptions = exceptions;
568 }
569 }
570
571 (best_bits, best_exceptions, best_total)
572}
573
/// Bit-packs `values` at `bit_width` bits each, recording any value that
/// does not fit as a `(position, high_bits)` exception patch.
///
/// Width 0 stores nothing inline (every non-zero value becomes a patch);
/// width >= 32 stores plain little-endian `u32`s and can have no patches.
fn pack_with_exceptions(values: &[u32], bit_width: u8) -> (Vec<u8>, Vec<(u8, u32)>) {
    if bit_width == 0 {
        // No inline payload: a patch per non-zero value carries it whole.
        let exceptions = values
            .iter()
            .enumerate()
            .filter(|&(_, &v)| v != 0)
            .map(|(i, &v)| (i as u8, v))
            .collect();
        return (Vec::new(), exceptions);
    }

    if bit_width >= 32 {
        // Full-width values: straight little-endian copy, no exceptions.
        let mut packed = Vec::with_capacity(values.len() * 4);
        for &value in values {
            packed.extend_from_slice(&value.to_le_bytes());
        }
        return (packed, Vec::new());
    }

    let mask = (1u64 << bit_width) - 1;
    let mut packed = vec![0u8; (values.len() * bit_width as usize).div_ceil(8)];
    let mut exceptions = Vec::new();

    for (i, &value) in values.iter().enumerate() {
        // Write the low `bit_width` bits, possibly straddling byte
        // boundaries, one byte-sized slice at a time.
        let mut val = (value as u64) & mask;
        let start = i * bit_width as usize;
        let mut byte = start / 8;
        let mut offset = start % 8;
        let mut left = bit_width as usize;

        while left > 0 {
            let take = (8 - offset).min(left);
            let byte_mask = ((1u64 << take) - 1) as u8;
            packed[byte] |= ((val as u8) & byte_mask) << offset;
            val >>= take;
            left -= take;
            byte += 1;
            offset = 0;
        }

        // Anything above the low bits is recorded as a patch.
        if value > mask as u32 {
            exceptions.push((i as u8, value >> bit_width));
        }
    }

    (packed, exceptions)
}
645
/// Reverses `pack_with_exceptions`: unpacks `count` values of `bit_width`
/// bits from `packed` into `output`, then ORs the recorded high bits back
/// in for each exception patch.
///
/// Widths 8 and 16 take the SIMD fast paths; width >= 32 is a plain
/// little-endian copy that can have no exceptions and returns early.
fn unpack_with_exceptions(
    packed: &[u8],
    bit_width: u8,
    exceptions: &[(u8, u32)],
    count: usize,
    output: &mut [u32],
) {
    if bit_width == 0 {
        // Nothing stored inline: every value is 0 unless patched below.
        output[..count].fill(0);
    } else if bit_width == 8 {
        unpack_8bit(packed, output, count);
    } else if bit_width == 16 {
        unpack_16bit(packed, output, count);
    } else if bit_width >= 32 {
        for (i, out) in output.iter_mut().enumerate().take(count) {
            let idx = i * 4;
            *out = u32::from_le_bytes([
                packed[idx],
                packed[idx + 1],
                packed[idx + 2],
                packed[idx + 3],
            ]);
        }
        // Full-width values never have exceptions; skip the patch loop.
        return;
    } else {
        // Generic path: read a 64-bit window at the field's byte offset and
        // shift/mask the value out. With bit_width < 32 and offset < 8 the
        // field spans at most 5 bytes, so an 8-byte window always covers it.
        let mask = (1u64 << bit_width) - 1;
        let mut bit_pos = 0usize;
        let input_ptr = packed.as_ptr();

        for out in output[..count].iter_mut() {
            let byte_idx = bit_pos >> 3;
            let bit_offset = bit_pos & 7;

            let word = if byte_idx + 8 <= packed.len() {
                // SAFETY: the check above guarantees 8 readable bytes at
                // byte_idx; read_unaligned imposes no alignment requirement.
                unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() }
            } else {
                // Near the end of the buffer: assemble the window from the
                // remaining bytes (little-endian).
                let mut word = 0u64;
                for (i, &b) in packed[byte_idx..].iter().enumerate() {
                    word |= (b as u64) << (i * 8);
                }
                word
            };

            *out = ((word >> bit_offset) & mask) as u32;
            bit_pos += bit_width as usize;
        }
    }

    // Patch loop: restore the high bits recorded at pack time.
    for &(pos, high_bits) in exceptions {
        if (pos as usize) < count {
            let low_bits = output[pos as usize];
            output[pos as usize] = (high_bits << bit_width) | low_bits;
        }
    }
}
717
/// One compressed block of up to `OPT_P4D_BLOCK_SIZE` postings
/// (patched frame-of-reference / OptP4D layout).
#[derive(Debug, Clone)]
pub struct OptP4DBlock {
    /// First doc id in the block, stored verbatim.
    pub first_doc_id: u32,
    /// Last doc id in the block (lets `seek` skip whole blocks).
    pub last_doc_id: u32,
    /// Number of postings in this block.
    pub num_docs: u16,
    /// Bit width used to pack the doc-id deltas.
    pub doc_bit_width: u8,
    /// Bit width used to pack the term-frequency values.
    pub tf_bit_width: u8,
    /// Largest term frequency in the block (input to the score bound).
    pub max_tf: u32,
    /// BM25 upper bound over all postings in the block.
    pub max_block_score: f32,
    /// Bit-packed doc-id deltas; each stored value is (gap - 1).
    pub doc_deltas: Vec<u8>,
    /// (position, high-bits) patches for deltas exceeding `doc_bit_width`.
    pub doc_exceptions: Vec<(u8, u32)>,
    /// Bit-packed term frequencies; each stored value is (tf - 1).
    pub term_freqs: Vec<u8>,
    /// (position, high-bits) patches for tf values exceeding `tf_bit_width`.
    pub tf_exceptions: Vec<(u8, u32)>,
}
744
impl OptP4DBlock {
    /// Writes the block in its little-endian on-disk layout.
    ///
    /// Field order here must mirror `deserialize` exactly. Length prefixes
    /// are deliberately narrow: payload byte lengths fit in u16 for blocks
    /// of `OPT_P4D_BLOCK_SIZE` postings, and exception counts are bounded
    /// by `MAX_EXCEPTIONS_RATIO` of the block, so they fit in u8.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
        writer.write_u16::<LittleEndian>(self.num_docs)?;
        writer.write_u8(self.doc_bit_width)?;
        writer.write_u8(self.tf_bit_width)?;
        writer.write_u32::<LittleEndian>(self.max_tf)?;
        writer.write_f32::<LittleEndian>(self.max_block_score)?;

        writer.write_u16::<LittleEndian>(self.doc_deltas.len() as u16)?;
        writer.write_all(&self.doc_deltas)?;

        writer.write_u8(self.doc_exceptions.len() as u8)?;
        for &(pos, val) in &self.doc_exceptions {
            writer.write_u8(pos)?;
            writer.write_u32::<LittleEndian>(val)?;
        }

        writer.write_u16::<LittleEndian>(self.term_freqs.len() as u16)?;
        writer.write_all(&self.term_freqs)?;

        writer.write_u8(self.tf_exceptions.len() as u8)?;
        for &(pos, val) in &self.tf_exceptions {
            writer.write_u8(pos)?;
            writer.write_u32::<LittleEndian>(val)?;
        }

        Ok(())
    }

    /// Reads a block previously written by `serialize` (same field order).
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc_id = reader.read_u32::<LittleEndian>()?;
        let last_doc_id = reader.read_u32::<LittleEndian>()?;
        let num_docs = reader.read_u16::<LittleEndian>()?;
        let doc_bit_width = reader.read_u8()?;
        let tf_bit_width = reader.read_u8()?;
        let max_tf = reader.read_u32::<LittleEndian>()?;
        let max_block_score = reader.read_f32::<LittleEndian>()?;

        let doc_deltas_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut doc_deltas = vec![0u8; doc_deltas_len];
        reader.read_exact(&mut doc_deltas)?;

        let num_doc_exceptions = reader.read_u8()? as usize;
        let mut doc_exceptions = Vec::with_capacity(num_doc_exceptions);
        for _ in 0..num_doc_exceptions {
            let pos = reader.read_u8()?;
            let val = reader.read_u32::<LittleEndian>()?;
            doc_exceptions.push((pos, val));
        }

        let term_freqs_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut term_freqs = vec![0u8; term_freqs_len];
        reader.read_exact(&mut term_freqs)?;

        let num_tf_exceptions = reader.read_u8()? as usize;
        let mut tf_exceptions = Vec::with_capacity(num_tf_exceptions);
        for _ in 0..num_tf_exceptions {
            let pos = reader.read_u8()?;
            let val = reader.read_u32::<LittleEndian>()?;
            tf_exceptions.push((pos, val));
        }

        Ok(Self {
            first_doc_id,
            last_doc_id,
            num_docs,
            doc_bit_width,
            tf_bit_width,
            max_tf,
            max_block_score,
            doc_deltas,
            doc_exceptions,
            term_freqs,
            tf_exceptions,
        })
    }

    /// Decompresses all doc ids in the block.
    ///
    /// The first doc id is stored verbatim; the remaining `num_docs - 1`
    /// values are unpacked as (gap - 1) deltas and prefix-summed by
    /// `delta_decode_simd`, which re-adds the implicit +1 per gap.
    pub fn decode_doc_ids(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let count = self.num_docs as usize;
        // Scratch buffer for the unpacked deltas.
        let mut deltas = vec![0u32; count];

        if count > 1 {
            unpack_with_exceptions(
                &self.doc_deltas,
                self.doc_bit_width,
                &self.doc_exceptions,
                count - 1,
                &mut deltas,
            );
        }

        let mut doc_ids = vec![0u32; count];
        delta_decode_simd(&deltas, &mut doc_ids, self.first_doc_id, count);

        doc_ids
    }

    /// Decompresses all term frequencies in the block.
    ///
    /// Frequencies are stored as (tf - 1); `add_one_simd` restores the
    /// original values.
    pub fn decode_term_freqs(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let count = self.num_docs as usize;
        let mut tfs = vec![0u32; count];

        unpack_with_exceptions(
            &self.term_freqs,
            self.tf_bit_width,
            &self.tf_exceptions,
            count,
            &mut tfs,
        );

        add_one_simd(&mut tfs, count);

        tfs
    }
}
885
/// A complete posting list: a sequence of `OptP4DBlock`s plus list-level
/// metadata used for scoring and block skipping.
#[derive(Debug, Clone)]
pub struct OptP4DPostingList {
    /// Compressed blocks in doc-id order.
    pub blocks: Vec<OptP4DBlock>,
    /// Total number of postings across all blocks.
    pub doc_count: u32,
    /// Maximum `max_block_score` over all blocks.
    pub max_score: f32,
}
896
impl OptP4DPostingList {
    /// BM25 k1 (term-frequency saturation) parameter for the score bound.
    const K1: f32 = 1.2;
    /// BM25 b (length-normalization) parameter for the score bound.
    const B: f32 = 0.75;

    /// Upper-bounds the BM25 contribution of any posting in a block by
    /// plugging the block's maximum term frequency and the most favorable
    /// length normalization (1 - B) into the saturation formula.
    #[inline]
    fn compute_bm25f_upper_bound(max_tf: u32, idf: f32) -> f32 {
        let tf = max_tf as f32;
        let min_length_norm = 1.0 - Self::B;
        let tf_norm = (tf * (Self::K1 + 1.0)) / (tf + Self::K1 * min_length_norm);
        idf * tf_norm
    }

    /// Builds a posting list by splitting the inputs into blocks of
    /// `OPT_P4D_BLOCK_SIZE` postings.
    ///
    /// `doc_ids` must be strictly increasing and every term frequency must
    /// be >= 1: `create_block` computes `doc_ids[j] - doc_ids[j - 1] - 1`
    /// and `tf - 1`, which underflow (panicking in debug builds) otherwise.
    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
        assert_eq!(doc_ids.len(), term_freqs.len());

        if doc_ids.is_empty() {
            return Self {
                blocks: Vec::new(),
                doc_count: 0,
                max_score: 0.0,
            };
        }

        let mut blocks = Vec::new();
        let mut max_score = 0.0f32;
        let mut i = 0;

        while i < doc_ids.len() {
            let block_end = (i + OPT_P4D_BLOCK_SIZE).min(doc_ids.len());
            let block_docs = &doc_ids[i..block_end];
            let block_tfs = &term_freqs[i..block_end];

            let block = Self::create_block(block_docs, block_tfs, idf);
            max_score = max_score.max(block.max_block_score);
            blocks.push(block);

            i = block_end;
        }

        Self {
            blocks,
            doc_count: doc_ids.len() as u32,
            max_score,
        }
    }

    /// Compresses one block of postings (see `from_postings` for the input
    /// requirements).
    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> OptP4DBlock {
        let num_docs = doc_ids.len();
        let first_doc_id = doc_ids[0];
        let last_doc_id = *doc_ids.last().unwrap();

        // Store gaps minus one: consecutive doc ids compress to delta 0.
        let mut deltas = Vec::with_capacity(num_docs.saturating_sub(1));
        for j in 1..num_docs {
            let delta = doc_ids[j] - doc_ids[j - 1] - 1;
            deltas.push(delta);
        }

        let (doc_bit_width, _, _) = find_optimal_bit_width(&deltas);
        let (doc_deltas, doc_exceptions) = pack_with_exceptions(&deltas, doc_bit_width);

        // Term frequencies are stored as (tf - 1); decode adds the 1 back.
        let mut tfs = Vec::with_capacity(num_docs);
        let mut max_tf = 0u32;

        for &tf in term_freqs {
            tfs.push(tf - 1);
            max_tf = max_tf.max(tf);
        }

        let (tf_bit_width, _, _) = find_optimal_bit_width(&tfs);
        let (term_freqs_packed, tf_exceptions) = pack_with_exceptions(&tfs, tf_bit_width);

        let max_block_score = Self::compute_bm25f_upper_bound(max_tf, idf);

        OptP4DBlock {
            first_doc_id,
            last_doc_id,
            num_docs: num_docs as u16,
            doc_bit_width,
            tf_bit_width,
            max_tf,
            max_block_score,
            doc_deltas,
            doc_exceptions,
            term_freqs: term_freqs_packed,
            tf_exceptions,
        }
    }

    /// Writes the list header (doc count, max score, block count) followed
    /// by each block; order must mirror `deserialize`.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_f32::<LittleEndian>(self.max_score)?;
        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;

        for block in &self.blocks {
            block.serialize(writer)?;
        }

        Ok(())
    }

    /// Inverse of `serialize`.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let max_score = reader.read_f32::<LittleEndian>()?;
        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;

        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            blocks.push(OptP4DBlock::deserialize(reader)?);
        }

        Ok(Self {
            blocks,
            doc_count,
            max_score,
        })
    }

    /// Total number of postings in the list.
    pub fn len(&self) -> u32 {
        self.doc_count
    }

    /// True when the list holds no postings.
    pub fn is_empty(&self) -> bool {
        self.doc_count == 0
    }

    /// Creates a decoding iterator positioned at the first posting.
    pub fn iterator(&self) -> OptP4DIterator<'_> {
        OptP4DIterator::new(self)
    }
}
1039
/// Cursor over an `OptP4DPostingList` that decompresses one block at a time.
pub struct OptP4DIterator<'a> {
    /// The list being iterated.
    posting_list: &'a OptP4DPostingList,
    /// Index of the block currently decoded into the caches below.
    current_block: usize,
    /// Decoded doc ids of `current_block`.
    block_doc_ids: Vec<u32>,
    /// Decoded term frequencies of `current_block`.
    block_term_freqs: Vec<u32>,
    /// Cursor within the decoded block.
    pos_in_block: usize,
    /// True once iteration has moved past the last posting.
    exhausted: bool,
}
1049
1050impl<'a> OptP4DIterator<'a> {
1051 pub fn new(posting_list: &'a OptP4DPostingList) -> Self {
1052 let mut iter = Self {
1053 posting_list,
1054 current_block: 0,
1055 block_doc_ids: Vec::new(),
1056 block_term_freqs: Vec::new(),
1057 pos_in_block: 0,
1058 exhausted: posting_list.blocks.is_empty(),
1059 };
1060
1061 if !iter.exhausted {
1062 iter.decode_current_block();
1063 }
1064
1065 iter
1066 }
1067
1068 fn decode_current_block(&mut self) {
1069 let block = &self.posting_list.blocks[self.current_block];
1070 self.block_doc_ids = block.decode_doc_ids();
1071 self.block_term_freqs = block.decode_term_freqs();
1072 self.pos_in_block = 0;
1073 }
1074
1075 pub fn doc(&self) -> u32 {
1077 if self.exhausted {
1078 u32::MAX
1079 } else {
1080 self.block_doc_ids[self.pos_in_block]
1081 }
1082 }
1083
1084 pub fn term_freq(&self) -> u32 {
1086 if self.exhausted {
1087 0
1088 } else {
1089 self.block_term_freqs[self.pos_in_block]
1090 }
1091 }
1092
1093 pub fn advance(&mut self) -> u32 {
1095 if self.exhausted {
1096 return u32::MAX;
1097 }
1098
1099 self.pos_in_block += 1;
1100
1101 if self.pos_in_block >= self.block_doc_ids.len() {
1102 self.current_block += 1;
1103 if self.current_block >= self.posting_list.blocks.len() {
1104 self.exhausted = true;
1105 return u32::MAX;
1106 }
1107 self.decode_current_block();
1108 }
1109
1110 self.doc()
1111 }
1112
1113 pub fn seek(&mut self, target: u32) -> u32 {
1115 if self.exhausted {
1116 return u32::MAX;
1117 }
1118
1119 while self.current_block < self.posting_list.blocks.len() {
1121 let block = &self.posting_list.blocks[self.current_block];
1122 if block.last_doc_id >= target {
1123 break;
1124 }
1125 self.current_block += 1;
1126 }
1127
1128 if self.current_block >= self.posting_list.blocks.len() {
1129 self.exhausted = true;
1130 return u32::MAX;
1131 }
1132
1133 if self.block_doc_ids.is_empty() || self.current_block != self.posting_list.blocks.len() - 1
1135 {
1136 self.decode_current_block();
1137 }
1138
1139 match self.block_doc_ids[self.pos_in_block..].binary_search(&target) {
1141 Ok(idx) => {
1142 self.pos_in_block += idx;
1143 }
1144 Err(idx) => {
1145 self.pos_in_block += idx;
1146 if self.pos_in_block >= self.block_doc_ids.len() {
1147 self.current_block += 1;
1149 if self.current_block >= self.posting_list.blocks.len() {
1150 self.exhausted = true;
1151 return u32::MAX;
1152 }
1153 self.decode_current_block();
1154 }
1155 }
1156 }
1157
1158 self.doc()
1159 }
1160}
1161
#[cfg(test)]
mod tests {
    use super::*;

    // bits_needed is the bit-length of the value (0 -> 0 bits).
    #[test]
    fn test_bits_needed() {
        assert_eq!(bits_needed(0), 0);
        assert_eq!(bits_needed(1), 1);
        assert_eq!(bits_needed(2), 2);
        assert_eq!(bits_needed(3), 2);
        assert_eq!(bits_needed(4), 3);
        assert_eq!(bits_needed(255), 8);
        assert_eq!(bits_needed(256), 9);
        assert_eq!(bits_needed(u32::MAX), 32);
    }

    #[test]
    fn test_find_optimal_bit_width() {
        // All zeros: width 0, no exceptions.
        let values = vec![0u32; 100];
        let (bits, exceptions, _) = find_optimal_bit_width(&values);
        assert_eq!(bits, 0);
        assert_eq!(exceptions, 0);

        // Values below 16 need at most 4 bits.
        let values: Vec<u32> = (0..100).map(|i| i % 16).collect();
        let (bits, _, _) = find_optimal_bit_width(&values);
        assert!(bits <= 4);

        // A single outlier should become an exception rather than force a
        // wide encoding for every value.
        let mut values: Vec<u32> = (0..100).map(|i| i % 16).collect();
        values[50] = 1_000_000;
        let (bits, exceptions, _) = find_optimal_bit_width(&values);
        assert!(bits < 20);
        assert!(exceptions >= 1);
    }

    // Round-trip through pack/unpack with values that overflow 4 bits.
    #[test]
    fn test_pack_unpack_with_exceptions() {
        let values = vec![1, 2, 3, 255, 4, 5, 1000, 6, 7, 8];
        let (packed, exceptions) = pack_with_exceptions(&values, 4);

        let mut output = vec![0u32; values.len()];
        unpack_with_exceptions(&packed, 4, &exceptions, values.len(), &mut output);

        assert_eq!(output, values);
    }

    // 100 postings fit in a single 128-posting block.
    #[test]
    fn test_opt_p4d_posting_list_small() {
        let doc_ids: Vec<u32> = (0..100).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = vec![1; 100];

        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        assert_eq!(list.len(), 100);
        assert_eq!(list.blocks.len(), 1);

        let mut iter = list.iterator();
        for (i, &expected) in doc_ids.iter().enumerate() {
            assert_eq!(iter.doc(), expected, "Mismatch at {}", i);
            assert_eq!(iter.term_freq(), 1);
            iter.advance();
        }
        assert_eq!(iter.doc(), u32::MAX);
    }

    // 500 postings span ceil(500 / 128) = 4 blocks.
    #[test]
    fn test_opt_p4d_posting_list_large() {
        let doc_ids: Vec<u32> = (0..500).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..500).map(|i| (i % 10) + 1).collect();

        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        assert_eq!(list.len(), 500);
        assert_eq!(list.blocks.len(), 4);

        let mut iter = list.iterator();
        for (i, &expected) in doc_ids.iter().enumerate() {
            assert_eq!(iter.doc(), expected, "Mismatch at {}", i);
            assert_eq!(iter.term_freq(), term_freqs[i]);
            iter.advance();
        }
    }

    // Seek lands on the first doc id >= target (single-block list).
    #[test]
    fn test_opt_p4d_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1; 8];

        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = list.iterator();

        assert_eq!(iter.seek(25), 30);
        assert_eq!(iter.seek(100), 100);
        assert_eq!(iter.seek(500), 1000);
        assert_eq!(iter.seek(3000), u32::MAX);
    }

    // Round-trip through the binary serialization format.
    #[test]
    fn test_opt_p4d_serialization() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 5).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 5) + 1).collect();

        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let mut buffer = Vec::new();
        list.serialize(&mut buffer).unwrap();

        let restored = OptP4DPostingList::deserialize(&mut &buffer[..]).unwrap();

        assert_eq!(restored.len(), list.len());
        assert_eq!(restored.blocks.len(), list.blocks.len());

        let mut iter1 = list.iterator();
        let mut iter2 = restored.iterator();

        while iter1.doc() != u32::MAX {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
    }

    // An extreme delta must survive via the exception-patch mechanism.
    #[test]
    fn test_opt_p4d_with_outliers() {
        let mut doc_ids: Vec<u32> = (0..128).map(|i| i * 2).collect();
        doc_ids[64] = 1_000_000;
        doc_ids.sort();

        let term_freqs: Vec<u32> = vec![1; 128];

        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let mut iter = list.iterator();
        let mut found_outlier = false;
        while iter.doc() != u32::MAX {
            if iter.doc() == 1_000_000 {
                found_outlier = true;
            }
            iter.advance();
        }
        assert!(found_outlier, "Outlier value should be preserved");
    }

    // 1024 postings give exactly 8 full blocks, exercising the SIMD paths
    // with no scalar tail in the block bodies.
    #[test]
    fn test_opt_p4d_simd_full_blocks() {
        let doc_ids: Vec<u32> = (0..1024).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..1024).map(|i| (i % 20) + 1).collect();

        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        assert_eq!(list.len(), 1024);
        assert_eq!(list.blocks.len(), 8);

        let mut iter = list.iterator();
        for (i, &expected_doc) in doc_ids.iter().enumerate() {
            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
            assert_eq!(iter.term_freq(), term_freqs[i], "TF mismatch at {}", i);
            iter.advance();
        }
        assert_eq!(iter.doc(), u32::MAX);
    }

    // Values that fit in 8 bits exercise the 8-bit SIMD unpack path.
    #[test]
    fn test_opt_p4d_simd_8bit_values() {
        let doc_ids: Vec<u32> = (0..256).collect();
        let term_freqs: Vec<u32> = (0..256).map(|i| (i % 100) + 1).collect();

        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let mut iter = list.iterator();
        for (i, &expected_doc) in doc_ids.iter().enumerate() {
            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
            assert_eq!(iter.term_freq(), term_freqs[i], "TF mismatch at {}", i);
            iter.advance();
        }
    }

    // Varying gaps (1..=10) exercise the SIMD delta-decode prefix sum.
    #[test]
    fn test_opt_p4d_simd_delta_decode() {
        let mut doc_ids = Vec::with_capacity(512);
        let mut current = 0u32;
        for i in 0..512 {
            current += (i % 10) + 1;
            doc_ids.push(current);
        }
        let term_freqs: Vec<u32> = vec![1; 512];

        let list = OptP4DPostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        let mut iter = list.iterator();
        for (i, &expected_doc) in doc_ids.iter().enumerate() {
            assert_eq!(
                iter.doc(),
                expected_doc,
                "Doc mismatch at {} (expected {}, got {})",
                i,
                expected_doc,
                iter.doc()
            );
            iter.advance();
        }
    }
}
1380}