use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Read, Write};

/// Number of integers packed per SIMD-BP128 block.
pub const SIMD_BLOCK_SIZE: usize = 128;

/// Lanes per SIMD group (128-bit registers holding 4 x u32).
#[allow(dead_code)]
const SIMD_LANES: usize = 4;

/// Number of 4-lane groups in one block.
#[allow(dead_code)]
const GROUPS_PER_BLOCK: usize = SIMD_BLOCK_SIZE / SIMD_LANES;

/// Returns the minimum number of bits needed to represent `max_val`.
#[inline]
pub fn bits_needed(max_val: u32) -> u8 {
    if max_val == 0 {
        0
    } else {
        32 - max_val.leading_zeros() as u8
    }
}
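
// Illustrative sanity checks for `bits_needed` (these follow directly from
// `leading_zeros`; the values are not asserted anywhere else in this file):
//   bits_needed(0)   == 0
//   bits_needed(1)   == 1
//   bits_needed(127) == 7
//   bits_needed(128) == 8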

/// NEON-accelerated kernels for aarch64.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
mod neon {
    use super::*;
    use std::arch::aarch64::*;

    /// Lookup table expanding each byte into eight u32 lanes, one per bit:
    /// lane `i` holds bit `i` of the byte.
    static BIT_EXPAND_LUT: [[u32; 8]; 256] = {
        let mut lut = [[0u32; 8]; 256];
        let mut byte = 0usize;
        while byte < 256 {
            let mut bit = 0;
            while bit < 8 {
                lut[byte][bit] = ((byte >> bit) & 1) as u32;
                bit += 1;
            }
            byte += 1;
        }
        lut
    };
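
    // Example entry (follows from the initializer above): for byte
    // 0b0000_0101, BIT_EXPAND_LUT[5] == [1, 0, 1, 0, 0, 0, 0, 0],
    // since bits 0 and 2 are set.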
63
64 #[inline]
66 #[target_feature(enable = "neon")]
67 pub unsafe fn unpack_4_neon(input: &[u8], bit_width: u8, output: &mut [u32; 4]) {
68 if bit_width == 0 {
69 *output = [0; 4];
70 return;
71 }
72
73 let mask = (1u32 << bit_width) - 1;
74
75 let mut packed_bytes = [0u8; 16];
77 let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
78 packed_bytes[..bytes_needed.min(16)].copy_from_slice(&input[..bytes_needed.min(16)]);
79 let packed = u128::from_le_bytes(packed_bytes);
80
81 let v0 = (packed & mask as u128) as u32;
83 let v1 = ((packed >> bit_width) & mask as u128) as u32;
84 let v2 = ((packed >> (bit_width * 2)) & mask as u128) as u32;
85 let v3 = ((packed >> (bit_width * 3)) & mask as u128) as u32;
86
87 unsafe {
89 let result = vld1q_u32([v0, v1, v2, v3].as_ptr());
90 vst1q_u32(output.as_mut_ptr(), result);
91 }
92 }
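
    // Layout assumed above (little-endian, values packed LSB-first): with
    // bit_width = 3, v0 occupies bits 0..3 of the stream, v1 bits 3..6,
    // v2 bits 6..9, and v3 bits 9..12.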
93
94 #[inline]
96 #[target_feature(enable = "neon")]
97 pub unsafe fn prefix_sum_4_neon(values: &mut [u32; 4]) {
98 unsafe {
99 let mut v = vld1q_u32(values.as_ptr());
101
102 let shifted1 = vextq_u32(vdupq_n_u32(0), v, 3); v = vaddq_u32(v, shifted1);
107 let shifted2 = vextq_u32(vdupq_n_u32(0), v, 2); v = vaddq_u32(v, shifted2);
112
113 vst1q_u32(values.as_mut_ptr(), v);
115 }
116 }
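
    // Worked trace (plain arithmetic, matching the step comments above):
    // starting from [1, 2, 3, 4], step 1 yields [1, 3, 5, 7] and step 2
    // yields [1, 3, 6, 10] -- the inclusive prefix sums.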

    /// Unpacks a full 128-integer block stored in vertical (bit-plane) layout.
    #[target_feature(enable = "neon")]
    pub unsafe fn unpack_block_neon(
        input: &[u8],
        bit_width: u8,
        output: &mut [u32; SIMD_BLOCK_SIZE],
    ) {
        if bit_width == 0 {
            output.fill(0);
            return;
        }

        // Zero the output with vector stores; bit planes are OR-ed in below.
        unsafe {
            let zero = vdupq_n_u32(0);
            for i in (0..SIMD_BLOCK_SIZE).step_by(4) {
                vst1q_u32(output[i..].as_mut_ptr(), zero);
            }
        }

        for bit_pos in 0..bit_width as usize {
            // Each bit plane holds one bit of all 128 integers: 16 bytes.
            let byte_offset = bit_pos * 16;
            let bit_mask = 1u32 << bit_pos;

            // Prefetch the next bit plane while this one is processed.
            if bit_pos + 1 < bit_width as usize {
                let next_offset = (bit_pos + 1) * 16;
                unsafe {
                    std::arch::asm!(
                        "prfm pldl1keep, [{0}]",
                        in(reg) input.as_ptr().add(next_offset),
                        options(nostack, preserves_flags)
                    );
                }
            }

            // Each chunk of four plane bytes covers 32 output integers.
            for chunk in 0..4 {
                let chunk_offset = byte_offset + chunk * 4;

                let b0 = input[chunk_offset] as usize;
                let b1 = input[chunk_offset + 1] as usize;
                let b2 = input[chunk_offset + 2] as usize;
                let b3 = input[chunk_offset + 3] as usize;

                let base_int = chunk * 32;

                unsafe {
                    let mask_vec = vdupq_n_u32(bit_mask);

                    // Expand each plane byte to eight 0/1 lanes via the LUT,
                    // scale by the bit mask, and OR into the accumulators.
                    let lut0 = &BIT_EXPAND_LUT[b0];
                    let bits_0_3 = vld1q_u32(lut0.as_ptr());
                    let bits_4_7 = vld1q_u32(lut0[4..].as_ptr());

                    let shifted_0_3 = vmulq_u32(bits_0_3, mask_vec);
                    let shifted_4_7 = vmulq_u32(bits_4_7, mask_vec);

                    let cur_0_3 = vld1q_u32(output[base_int..].as_ptr());
                    let cur_4_7 = vld1q_u32(output[base_int + 4..].as_ptr());

                    vst1q_u32(
                        output[base_int..].as_mut_ptr(),
                        vorrq_u32(cur_0_3, shifted_0_3),
                    );
                    vst1q_u32(
                        output[base_int + 4..].as_mut_ptr(),
                        vorrq_u32(cur_4_7, shifted_4_7),
                    );

                    let lut1 = &BIT_EXPAND_LUT[b1];
                    let bits_8_11 = vld1q_u32(lut1.as_ptr());
                    let bits_12_15 = vld1q_u32(lut1[4..].as_ptr());

                    let shifted_8_11 = vmulq_u32(bits_8_11, mask_vec);
                    let shifted_12_15 = vmulq_u32(bits_12_15, mask_vec);

                    let cur_8_11 = vld1q_u32(output[base_int + 8..].as_ptr());
                    let cur_12_15 = vld1q_u32(output[base_int + 12..].as_ptr());

                    vst1q_u32(
                        output[base_int + 8..].as_mut_ptr(),
                        vorrq_u32(cur_8_11, shifted_8_11),
                    );
                    vst1q_u32(
                        output[base_int + 12..].as_mut_ptr(),
                        vorrq_u32(cur_12_15, shifted_12_15),
                    );

                    let lut2 = &BIT_EXPAND_LUT[b2];
                    let bits_16_19 = vld1q_u32(lut2.as_ptr());
                    let bits_20_23 = vld1q_u32(lut2[4..].as_ptr());

                    let shifted_16_19 = vmulq_u32(bits_16_19, mask_vec);
                    let shifted_20_23 = vmulq_u32(bits_20_23, mask_vec);

                    let cur_16_19 = vld1q_u32(output[base_int + 16..].as_ptr());
                    let cur_20_23 = vld1q_u32(output[base_int + 20..].as_ptr());

                    vst1q_u32(
                        output[base_int + 16..].as_mut_ptr(),
                        vorrq_u32(cur_16_19, shifted_16_19),
                    );
                    vst1q_u32(
                        output[base_int + 20..].as_mut_ptr(),
                        vorrq_u32(cur_20_23, shifted_20_23),
                    );

                    let lut3 = &BIT_EXPAND_LUT[b3];
                    let bits_24_27 = vld1q_u32(lut3.as_ptr());
                    let bits_28_31 = vld1q_u32(lut3[4..].as_ptr());

                    let shifted_24_27 = vmulq_u32(bits_24_27, mask_vec);
                    let shifted_28_31 = vmulq_u32(bits_28_31, mask_vec);

                    let cur_24_27 = vld1q_u32(output[base_int + 24..].as_ptr());
                    let cur_28_31 = vld1q_u32(output[base_int + 28..].as_ptr());

                    vst1q_u32(
                        output[base_int + 24..].as_mut_ptr(),
                        vorrq_u32(cur_24_27, shifted_24_27),
                    );
                    vst1q_u32(
                        output[base_int + 28..].as_mut_ptr(),
                        vorrq_u32(cur_28_31, shifted_28_31),
                    );
                }
            }
        }
    }

    /// Inclusive prefix sum over a whole block: seeds the first lane with
    /// `first_val` and carries each group's last lane into the next group.
    #[target_feature(enable = "neon")]
    pub unsafe fn prefix_sum_block_neon(deltas: &mut [u32; SIMD_BLOCK_SIZE], first_val: u32) {
        let mut carry = first_val;

        for group in 0..GROUPS_PER_BLOCK {
            let start = group * SIMD_LANES;
            let mut group_vals = [
                deltas[start],
                deltas[start + 1],
                deltas[start + 2],
                deltas[start + 3],
            ];

            // Fold the running total into the group's first lane.
            group_vals[0] = group_vals[0].wrapping_add(carry);

            unsafe { prefix_sum_4_neon(&mut group_vals) };

            deltas[start..start + 4].copy_from_slice(&group_vals);

            // The last lane now holds the running total.
            carry = group_vals[3];
        }
    }
}

/// Portable scalar fallbacks mirroring the NEON kernels.
#[allow(dead_code)]
mod scalar {
    use super::*;

    /// Packs four values at `bit_width` bits each into `output`.
    #[inline]
    pub fn pack_4_scalar(values: &[u32; 4], bit_width: u8, output: &mut [u8]) {
        if bit_width == 0 {
            return;
        }

        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        let mut packed = 0u128;
        for (i, &val) in values.iter().enumerate() {
            packed |= (val as u128) << (i * bit_width as usize);
        }

        let packed_bytes = packed.to_le_bytes();
        output[..bytes_needed].copy_from_slice(&packed_bytes[..bytes_needed]);
    }

    /// Unpacks four `bit_width`-bit values from `input` into `output`.
    #[inline]
    pub fn unpack_4_scalar(input: &[u8], bit_width: u8, output: &mut [u32; 4]) {
        if bit_width == 0 {
            *output = [0; 4];
            return;
        }

        let mask = (1u32 << bit_width) - 1;
        let mut packed_bytes = [0u8; 16];
        let bytes_needed = ((bit_width as usize) * 4).div_ceil(8);
        packed_bytes[..bytes_needed.min(16)].copy_from_slice(&input[..bytes_needed.min(16)]);
        let packed = u128::from_le_bytes(packed_bytes);

        output[0] = (packed & mask as u128) as u32;
        output[1] = ((packed >> bit_width) & mask as u128) as u32;
        output[2] = ((packed >> (bit_width * 2)) & mask as u128) as u32;
        output[3] = ((packed >> (bit_width * 3)) & mask as u128) as u32;
    }

    /// In-place inclusive prefix sum over four values.
    #[inline]
    pub fn prefix_sum_4_scalar(vals: &mut [u32; 4]) {
        vals[1] = vals[1].wrapping_add(vals[0]);
        vals[2] = vals[2].wrapping_add(vals[1]);
        vals[3] = vals[3].wrapping_add(vals[2]);
    }

    /// Unpacks a 128-integer block from vertical (bit-plane) layout.
    pub fn unpack_block_scalar(input: &[u8], bit_width: u8, output: &mut [u32; SIMD_BLOCK_SIZE]) {
        if bit_width == 0 {
            output.fill(0);
            return;
        }

        output.fill(0);

        for bit_pos in 0..bit_width as usize {
            // 16 bytes per bit plane (one bit of each of the 128 integers).
            let byte_offset = bit_pos * 16;
            for byte_idx in 0..16 {
                let byte_val = input[byte_offset + byte_idx];
                let base_int = byte_idx * 8;

                output[base_int] |= (byte_val & 1) as u32 * (1 << bit_pos);
                output[base_int + 1] |= ((byte_val >> 1) & 1) as u32 * (1 << bit_pos);
                output[base_int + 2] |= ((byte_val >> 2) & 1) as u32 * (1 << bit_pos);
                output[base_int + 3] |= ((byte_val >> 3) & 1) as u32 * (1 << bit_pos);
                output[base_int + 4] |= ((byte_val >> 4) & 1) as u32 * (1 << bit_pos);
                output[base_int + 5] |= ((byte_val >> 5) & 1) as u32 * (1 << bit_pos);
                output[base_int + 6] |= ((byte_val >> 6) & 1) as u32 * (1 << bit_pos);
                output[base_int + 7] |= ((byte_val >> 7) & 1) as u32 * (1 << bit_pos);
            }
        }
    }

    /// Block prefix sum matching `neon::prefix_sum_block_neon`.
    pub fn prefix_sum_block_scalar(deltas: &mut [u32; SIMD_BLOCK_SIZE], first_val: u32) {
        let mut carry = first_val;

        for group in 0..GROUPS_PER_BLOCK {
            let start = group * SIMD_LANES;
            let mut group_vals = [
                deltas[start],
                deltas[start + 1],
                deltas[start + 2],
                deltas[start + 3],
            ];

            group_vals[0] = group_vals[0].wrapping_add(carry);
            prefix_sum_4_scalar(&mut group_vals);
            deltas[start..start + 4].copy_from_slice(&group_vals);
            carry = group_vals[3];
        }
    }
}

/// Packs 128 values into bit-plane layout: plane `p` (16 bytes) stores bit `p`
/// of every value, with value `i`'s bit at bit `i % 8` of byte `i / 8`.
pub fn pack_horizontal(values: &[u32; SIMD_BLOCK_SIZE], bit_width: u8, output: &mut Vec<u8>) {
    if bit_width == 0 {
        return;
    }

    let total_bytes = (SIMD_BLOCK_SIZE * bit_width as usize) / 8;
    let start = output.len();
    output.resize(start + total_bytes, 0);

    for bit_pos in 0..bit_width as usize {
        let byte_offset = start + bit_pos * (SIMD_BLOCK_SIZE / 8);
        for (int_idx, &val) in values.iter().enumerate() {
            let bit = (val >> bit_pos) & 1;
            let byte_idx = byte_offset + int_idx / 8;
            let bit_in_byte = int_idx % 8;
            output[byte_idx] |= (bit as u8) << bit_in_byte;
        }
    }
}
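
// Layout example (follows from the loop above): with bit_width = 7 the output
// is 7 planes x 16 bytes = 112 bytes; byte 0 of plane p packs bit p of
// values[0..8], with values[0]'s bit in the least significant position.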

/// Unpacks 128 values from the bit-plane layout written by `pack_horizontal`.
pub fn unpack_horizontal(input: &[u8], bit_width: u8, output: &mut [u32; SIMD_BLOCK_SIZE]) {
    if bit_width == 0 {
        output.fill(0);
        return;
    }

    output.fill(0);

    for bit_pos in 0..bit_width as usize {
        // 16 bytes per bit plane.
        let byte_offset = bit_pos * 16;
        for byte_idx in 0..16 {
            let byte_val = input[byte_offset + byte_idx];
            let base_int = byte_idx * 8;

            output[base_int] |= (byte_val & 1) as u32 * (1 << bit_pos);
            output[base_int + 1] |= ((byte_val >> 1) & 1) as u32 * (1 << bit_pos);
            output[base_int + 2] |= ((byte_val >> 2) & 1) as u32 * (1 << bit_pos);
            output[base_int + 3] |= ((byte_val >> 3) & 1) as u32 * (1 << bit_pos);
            output[base_int + 4] |= ((byte_val >> 4) & 1) as u32 * (1 << bit_pos);
            output[base_int + 5] |= ((byte_val >> 5) & 1) as u32 * (1 << bit_pos);
            output[base_int + 6] |= ((byte_val >> 6) & 1) as u32 * (1 << bit_pos);
            output[base_int + 7] |= ((byte_val >> 7) & 1) as u32 * (1 << bit_pos);
        }
    }
}

/// Prefix-sums a 128-delta block, dispatching to NEON on aarch64.
#[allow(dead_code)]
pub fn prefix_sum_128(deltas: &mut [u32; SIMD_BLOCK_SIZE], first_val: u32) {
    #[cfg(target_arch = "aarch64")]
    {
        unsafe { neon::prefix_sum_block_neon(deltas, first_val) }
    }

    #[cfg(not(target_arch = "aarch64"))]
    {
        scalar::prefix_sum_block_scalar(deltas, first_val)
    }
}

/// Vertical packing currently shares the bit-plane implementation.
pub fn pack_vertical(values: &[u32; SIMD_BLOCK_SIZE], bit_width: u8, output: &mut Vec<u8>) {
    pack_horizontal(values, bit_width, output)
}

pub fn unpack_vertical(input: &[u8], bit_width: u8, output: &mut [u32; SIMD_BLOCK_SIZE]) {
    unpack_horizontal(input, bit_width, output)
}

/// Unpacks delta-encoded doc IDs and reconstructs absolute values in one pass.
/// Deltas are stored as gap - 1 (doc IDs strictly increase), so
/// `doc[i] = doc[i - 1] + delta + 1`; `output[0]` is `first_doc_id` itself.
pub fn unpack_vertical_d1(
    input: &[u8],
    bit_width: u8,
    first_doc_id: u32,
    output: &mut [u32; SIMD_BLOCK_SIZE],
    count: usize,
) {
    if count == 0 {
        return;
    }

    if bit_width == 0 {
        // All gaps are 1: the block is a run of consecutive doc IDs.
        let mut current = first_doc_id;
        output[0] = current;
        for out_val in output.iter_mut().take(count).skip(1) {
            current = current.wrapping_add(1);
            *out_val = current;
        }
        return;
    }

    output[0] = first_doc_id;
    let mut current = first_doc_id;

    let full_groups = (count - 1) / 4;
    let remainder = (count - 1) % 4;

    for group in 0..full_groups {
        let base_idx = group * 4;

        // Gather four deltas from the bit planes.
        let mut deltas = [0u32; 4];
        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * (SIMD_BLOCK_SIZE / 8);
            for (j, delta) in deltas.iter_mut().enumerate() {
                let int_idx = base_idx + j;
                let byte_idx = byte_offset + int_idx / 8;
                let bit_in_byte = int_idx % 8;
                let bit = ((input[byte_idx] >> bit_in_byte) & 1) as u32;
                *delta |= bit << bit_pos;
            }
        }

        for j in 0..4 {
            current = current.wrapping_add(deltas[j]).wrapping_add(1);
            output[base_idx + j + 1] = current;
        }
    }

    // Decode the trailing partial group, if any.
    let base_idx = full_groups * 4;
    for j in 0..remainder {
        let int_idx = base_idx + j;
        let mut delta = 0u32;
        for bit_pos in 0..bit_width as usize {
            let byte_offset = bit_pos * (SIMD_BLOCK_SIZE / 8);
            let byte_idx = byte_offset + int_idx / 8;
            let bit_in_byte = int_idx % 8;
            let bit = ((input[byte_idx] >> bit_in_byte) & 1) as u32;
            delta |= bit << bit_pos;
        }
        current = current.wrapping_add(delta).wrapping_add(1);
        output[base_idx + j + 1] = current;
    }
}
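
// Decoding example (matches the gap - 1 convention above): first_doc_id = 10
// with stored deltas [4, 0, 9] yields docs [10, 15, 16, 26], since each step
// adds delta + 1.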

/// One compressed block of up to 128 postings.
#[derive(Debug, Clone)]
pub struct SimdBp128Block {
    /// Bit-packed doc-ID deltas (stored as gap - 1).
    pub doc_data: Vec<u8>,
    /// Bits per doc-ID delta.
    pub doc_bit_width: u8,
    /// Bit-packed term frequencies (stored as tf - 1).
    pub tf_data: Vec<u8>,
    /// Bits per term frequency.
    pub tf_bit_width: u8,
    /// First doc ID in the block, stored uncompressed.
    pub first_doc_id: u32,
    /// Last doc ID in the block, used for skipping without decoding.
    pub last_doc_id: u32,
    /// Number of postings in the block (at most 128).
    pub num_docs: u16,
    /// Maximum term frequency in the block.
    pub max_tf: u32,
    /// Upper bound on the BM25 score of any posting in the block.
    pub max_block_score: f32,
}

impl SimdBp128Block {
    /// Writes the block header followed by the length-prefixed data sections.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc_id)?;
        writer.write_u32::<LittleEndian>(self.last_doc_id)?;
        writer.write_u16::<LittleEndian>(self.num_docs)?;
        writer.write_u8(self.doc_bit_width)?;
        writer.write_u8(self.tf_bit_width)?;
        writer.write_u32::<LittleEndian>(self.max_tf)?;
        writer.write_f32::<LittleEndian>(self.max_block_score)?;

        writer.write_u16::<LittleEndian>(self.doc_data.len() as u16)?;
        writer.write_all(&self.doc_data)?;

        writer.write_u16::<LittleEndian>(self.tf_data.len() as u16)?;
        writer.write_all(&self.tf_data)?;

        Ok(())
    }
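
    // Wire format produced above (all little-endian): a 20-byte header
    // (first_doc_id u32, last_doc_id u32, num_docs u16, doc_bit_width u8,
    // tf_bit_width u8, max_tf u32, max_block_score f32), then a u16 length
    // and the doc data, then a u16 length and the tf data.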

    /// Reads a block in the format written by `serialize`.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc_id = reader.read_u32::<LittleEndian>()?;
        let last_doc_id = reader.read_u32::<LittleEndian>()?;
        let num_docs = reader.read_u16::<LittleEndian>()?;
        let doc_bit_width = reader.read_u8()?;
        let tf_bit_width = reader.read_u8()?;
        let max_tf = reader.read_u32::<LittleEndian>()?;
        let max_block_score = reader.read_f32::<LittleEndian>()?;

        let doc_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut doc_data = vec![0u8; doc_len];
        reader.read_exact(&mut doc_data)?;

        let tf_len = reader.read_u16::<LittleEndian>()? as usize;
        let mut tf_data = vec![0u8; tf_len];
        reader.read_exact(&mut tf_data)?;

        Ok(Self {
            doc_data,
            doc_bit_width,
            tf_data,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs,
            max_tf,
            max_block_score,
        })
    }

    /// Decodes the block's absolute doc IDs.
    pub fn decode_doc_ids(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let mut output = [0u32; SIMD_BLOCK_SIZE];
        unpack_vertical_d1(
            &self.doc_data,
            self.doc_bit_width,
            self.first_doc_id,
            &mut output,
            self.num_docs as usize,
        );

        output[..self.num_docs as usize].to_vec()
    }

    /// Decodes the block's term frequencies.
    pub fn decode_term_freqs(&self) -> Vec<u32> {
        if self.num_docs == 0 {
            return Vec::new();
        }

        let mut output = [0u32; SIMD_BLOCK_SIZE];
        unpack_vertical(&self.tf_data, self.tf_bit_width, &mut output);

        // Frequencies are stored as tf - 1, so add the 1 back.
        output[..self.num_docs as usize]
            .iter()
            .map(|&tf| tf + 1)
            .collect()
    }
}

/// A posting list compressed as a sequence of SIMD-BP128 blocks.
#[derive(Debug, Clone)]
pub struct SimdBp128PostingList {
    /// Compressed blocks in doc-ID order.
    pub blocks: Vec<SimdBp128Block>,
    /// Total number of postings across all blocks.
    pub doc_count: u32,
    /// Maximum block-max score across the list.
    pub max_score: f32,
}

impl SimdBp128PostingList {
    // BM25 parameters used for block-max score upper bounds.
    const K1: f32 = 1.2;
    const B: f32 = 0.75;

    /// BM25 upper bound for a block, assuming the most favorable length
    /// normalization (the shortest possible document).
    #[inline]
    pub fn compute_bm25_upper_bound(max_tf: u32, idf: f32) -> f32 {
        let tf = max_tf as f32;
        let min_length_norm = 1.0 - Self::B;
        let tf_norm = (tf * (Self::K1 + 1.0)) / (tf + Self::K1 * min_length_norm);
        idf * tf_norm
    }
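
    // Worked example (plain arithmetic on the formula above): with K1 = 1.2
    // and B = 0.75, min_length_norm = 0.25; for max_tf = 10 and idf = 2.0,
    // tf_norm = (10 * 2.2) / (10 + 0.3) ~= 2.136, so the bound is ~= 4.27.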

    /// Builds a compressed posting list from parallel doc-ID and tf slices.
    pub fn from_postings(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> Self {
        assert_eq!(doc_ids.len(), term_freqs.len());

        if doc_ids.is_empty() {
            return Self {
                blocks: Vec::new(),
                doc_count: 0,
                max_score: 0.0,
            };
        }

        let mut blocks = Vec::new();
        let mut max_score = 0.0f32;
        let mut i = 0;

        while i < doc_ids.len() {
            let block_end = (i + SIMD_BLOCK_SIZE).min(doc_ids.len());
            let block_docs = &doc_ids[i..block_end];
            let block_tfs = &term_freqs[i..block_end];

            let block = Self::create_block(block_docs, block_tfs, idf);
            max_score = max_score.max(block.max_block_score);
            blocks.push(block);

            i = block_end;
        }

        Self {
            blocks,
            doc_count: doc_ids.len() as u32,
            max_score,
        }
    }

    /// Compresses one block. Doc IDs must be strictly increasing.
    fn create_block(doc_ids: &[u32], term_freqs: &[u32], idf: f32) -> SimdBp128Block {
        let num_docs = doc_ids.len();
        let first_doc_id = doc_ids[0];
        let last_doc_id = *doc_ids.last().unwrap();

        // Delta-encode as gap - 1 (valid because IDs strictly increase).
        let mut deltas = [0u32; SIMD_BLOCK_SIZE];
        let mut max_delta = 0u32;
        for j in 1..num_docs {
            let delta = doc_ids[j] - doc_ids[j - 1] - 1;
            deltas[j - 1] = delta;
            max_delta = max_delta.max(delta);
        }

        // Store tf - 1 so the common all-ones case packs to zero bits.
        let mut tfs = [0u32; SIMD_BLOCK_SIZE];
        let mut max_tf = 0u32;
        for (j, &tf) in term_freqs.iter().enumerate() {
            tfs[j] = tf.saturating_sub(1);
            max_tf = max_tf.max(tf);
        }

        let doc_bit_width = bits_needed(max_delta);
        let tf_bit_width = bits_needed(max_tf.saturating_sub(1));

        let mut doc_data = Vec::new();
        pack_vertical(&deltas, doc_bit_width, &mut doc_data);

        let mut tf_data = Vec::new();
        pack_vertical(&tfs, tf_bit_width, &mut tf_data);

        let max_block_score = Self::compute_bm25_upper_bound(max_tf, idf);

        SimdBp128Block {
            doc_data,
            doc_bit_width,
            tf_data,
            tf_bit_width,
            first_doc_id,
            last_doc_id,
            num_docs: num_docs as u16,
            max_tf,
            max_block_score,
        }
    }

    /// Writes the list header followed by every block.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_f32::<LittleEndian>(self.max_score)?;
        writer.write_u32::<LittleEndian>(self.blocks.len() as u32)?;

        for block in &self.blocks {
            block.serialize(writer)?;
        }

        Ok(())
    }

    /// Reads a list in the format written by `serialize`.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let max_score = reader.read_f32::<LittleEndian>()?;
        let num_blocks = reader.read_u32::<LittleEndian>()? as usize;

        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            blocks.push(SimdBp128Block::deserialize(reader)?);
        }

        Ok(Self {
            blocks,
            doc_count,
            max_score,
        })
    }

    /// Returns a cursor positioned at the first posting.
    pub fn iterator(&self) -> SimdBp128Iterator<'_> {
        SimdBp128Iterator::new(self)
    }

    /// Serialized size in bytes.
    pub fn size_bytes(&self) -> usize {
        // 12-byte list header: doc_count (4) + max_score (4) + block count (4).
        let mut size = 12;
        for block in &self.blocks {
            // 20-byte block header plus two u16 length prefixes.
            size += 24 + block.doc_data.len() + block.tf_data.len();
        }
        size
    }
}
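
// Round-trip sketch (uses only the APIs above; the postings and idf value are
// illustrative):
//     let list = SimdBp128PostingList::from_postings(&[2, 5, 9], &[1, 3, 2], 1.0);
//     let mut buf = Vec::new();
//     list.serialize(&mut buf).unwrap();
//     let back = SimdBp128PostingList::deserialize(&mut &buf[..]).unwrap();
//     assert_eq!(back.doc_count, 3);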

/// Streaming cursor over a posting list, decoding one block at a time.
pub struct SimdBp128Iterator<'a> {
    list: &'a SimdBp128PostingList,
    current_block: usize,
    block_doc_ids: Vec<u32>,
    block_term_freqs: Vec<u32>,
    pos_in_block: usize,
    exhausted: bool,
}

impl<'a> SimdBp128Iterator<'a> {
    pub fn new(list: &'a SimdBp128PostingList) -> Self {
        let mut iter = Self {
            list,
            current_block: 0,
            block_doc_ids: Vec::new(),
            block_term_freqs: Vec::new(),
            pos_in_block: 0,
            exhausted: list.blocks.is_empty(),
        };

        if !iter.exhausted {
            iter.decode_current_block();
        }

        iter
    }

    fn decode_current_block(&mut self) {
        let block = &self.list.blocks[self.current_block];
        self.block_doc_ids = block.decode_doc_ids();
        self.block_term_freqs = block.decode_term_freqs();
        self.pos_in_block = 0;
    }

    /// Current doc ID, or `u32::MAX` once the iterator is exhausted.
    pub fn doc(&self) -> u32 {
        if self.exhausted {
            u32::MAX
        } else {
            self.block_doc_ids[self.pos_in_block]
        }
    }

    /// Term frequency at the current position, or 0 once exhausted.
    pub fn term_freq(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.block_term_freqs[self.pos_in_block]
        }
    }

    /// Moves to the next posting and returns its doc ID.
    pub fn advance(&mut self) -> u32 {
        if self.exhausted {
            return u32::MAX;
        }

        self.pos_in_block += 1;

        if self.pos_in_block >= self.block_doc_ids.len() {
            self.current_block += 1;
            if self.current_block >= self.list.blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }

        self.doc()
    }

    /// Moves to the first posting with doc ID >= `target` and returns it.
    pub fn seek(&mut self, target: u32) -> u32 {
        if self.exhausted {
            return u32::MAX;
        }

        // Binary-search the remaining blocks by their [first, last] ranges.
        let block_idx = self.list.blocks[self.current_block..].binary_search_by(|block| {
            if block.last_doc_id < target {
                std::cmp::Ordering::Less
            } else if block.first_doc_id > target {
                std::cmp::Ordering::Greater
            } else {
                std::cmp::Ordering::Equal
            }
        });

        let target_block = match block_idx {
            Ok(idx) => self.current_block + idx,
            Err(idx) => {
                if self.current_block + idx >= self.list.blocks.len() {
                    self.exhausted = true;
                    return u32::MAX;
                }
                self.current_block + idx
            }
        };

        if target_block != self.current_block {
            self.current_block = target_block;
            self.decode_current_block();
        }

        // Binary-search within the decoded block.
        let pos = self.block_doc_ids[self.pos_in_block..]
            .binary_search(&target)
            .unwrap_or_else(|x| x);
        self.pos_in_block += pos;

        if self.pos_in_block >= self.block_doc_ids.len() {
            self.current_block += 1;
            if self.current_block >= self.list.blocks.len() {
                self.exhausted = true;
                return u32::MAX;
            }
            self.decode_current_block();
        }

        self.doc()
    }

    /// Maximum block-max score over the current and all remaining blocks.
    pub fn max_remaining_score(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.list.blocks[self.current_block..]
            .iter()
            .map(|b| b.max_block_score)
            .fold(0.0f32, |a, b| a.max(b))
    }

    pub fn current_block_max_score(&self) -> f32 {
        if self.exhausted {
            0.0
        } else {
            self.list.blocks[self.current_block].max_block_score
        }
    }

    pub fn current_block_max_tf(&self) -> u32 {
        if self.exhausted {
            0
        } else {
            self.list.blocks[self.current_block].max_tf
        }
    }

    /// Skips whole blocks until one whose `last_doc_id` reaches `target`,
    /// decodes it (resetting the in-block position to its first posting),
    /// and returns its first doc ID and block-max score.
    pub fn skip_to_block_with_doc(&mut self, target: u32) -> Option<(u32, f32)> {
        while self.current_block < self.list.blocks.len() {
            let block = &self.list.blocks[self.current_block];
            if block.last_doc_id >= target {
                self.decode_current_block();
                return Some((block.first_doc_id, block.max_block_score));
            }
            self.current_block += 1;
        }
        self.exhausted = true;
        None
    }

    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }
}
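
// Traversal sketch (mirrors the accessors above; the postings and idf value
// are illustrative):
//     let list = SimdBp128PostingList::from_postings(&[1, 4, 7], &[1, 1, 2], 1.0);
//     let mut it = list.iterator();
//     while !it.is_exhausted() {
//         let (doc, tf) = (it.doc(), it.term_freq());
//         // ...score (doc, tf) here...
//         it.advance();
//     }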

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pack_unpack_vertical() {
        let mut values = [0u32; SIMD_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = (i * 3) as u32;
        }

        let max_val = values.iter().max().copied().unwrap();
        let bit_width = bits_needed(max_val);

        let mut packed = Vec::new();
        pack_vertical(&values, bit_width, &mut packed);

        let mut unpacked = [0u32; SIMD_BLOCK_SIZE];
        unpack_vertical(&packed, bit_width, &mut unpacked);

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_pack_unpack_vertical_various_widths() {
        for bit_width in 1..=20 {
            let mut values = [0u32; SIMD_BLOCK_SIZE];
            let max_val = (1u32 << bit_width) - 1;
            for (i, v) in values.iter_mut().enumerate() {
                *v = (i as u32) % (max_val + 1);
            }

            let mut packed = Vec::new();
            pack_vertical(&values, bit_width, &mut packed);

            let mut unpacked = [0u32; SIMD_BLOCK_SIZE];
            unpack_vertical(&packed, bit_width, &mut unpacked);

            assert_eq!(values, unpacked, "Failed for bit_width={}", bit_width);
        }
    }

    #[test]
    fn test_simd_bp128_posting_list() {
        let doc_ids: Vec<u32> = (0..200).map(|i| i * 2).collect();
        let term_freqs: Vec<u32> = (0..200).map(|i| (i % 10) + 1).collect();

        let list = SimdBp128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);

        assert_eq!(list.doc_count, 200);
        // 200 postings split into blocks of 128: two blocks.
        assert_eq!(list.blocks.len(), 2);

        let mut iter = list.iterator();
        for (i, &expected_doc) in doc_ids.iter().enumerate() {
            assert_eq!(iter.doc(), expected_doc, "Doc mismatch at {}", i);
            assert_eq!(iter.term_freq(), term_freqs[i], "TF mismatch at {}", i);
            if i < doc_ids.len() - 1 {
                iter.advance();
            }
        }
    }

    #[test]
    fn test_simd_bp128_seek() {
        let doc_ids: Vec<u32> = vec![10, 20, 30, 100, 200, 300, 1000, 2000];
        let term_freqs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        let list = SimdBp128PostingList::from_postings(&doc_ids, &term_freqs, 1.0);
        let mut iter = list.iterator();

        assert_eq!(iter.seek(25), 30);
        assert_eq!(iter.seek(100), 100);
        assert_eq!(iter.seek(500), 1000);
        assert_eq!(iter.seek(3000), u32::MAX);
    }

    #[test]
    fn test_simd_bp128_serialization() {
        let doc_ids: Vec<u32> = (0..300).map(|i| i * 3).collect();
        let term_freqs: Vec<u32> = (0..300).map(|i| (i % 5) + 1).collect();

        let list = SimdBp128PostingList::from_postings(&doc_ids, &term_freqs, 1.5);

        let mut buffer = Vec::new();
        list.serialize(&mut buffer).unwrap();

        let restored = SimdBp128PostingList::deserialize(&mut &buffer[..]).unwrap();

        assert_eq!(restored.doc_count, list.doc_count);
        assert_eq!(restored.blocks.len(), list.blocks.len());

        let mut iter1 = list.iterator();
        let mut iter2 = restored.iterator();

        while iter1.doc() != u32::MAX {
            assert_eq!(iter1.doc(), iter2.doc());
            assert_eq!(iter1.term_freq(), iter2.term_freq());
            iter1.advance();
            iter2.advance();
        }
    }

    #[test]
    fn test_vertical_layout_size() {
        let mut values = [0u32; SIMD_BLOCK_SIZE];
        for (i, v) in values.iter_mut().enumerate() {
            *v = i as u32;
        }

        // Values 0..=127 need 7 bits each.
        let bit_width = bits_needed(127);
        assert_eq!(bit_width, 7);

        let mut packed = Vec::new();
        pack_horizontal(&values, bit_width, &mut packed);

        // 128 values * 7 bits / 8 = 112 bytes.
        let expected_bytes = (SIMD_BLOCK_SIZE * bit_width as usize) / 8;
        assert_eq!(expected_bytes, 112);
        assert_eq!(packed.len(), expected_bytes);
    }

    #[test]
    fn test_simd_bp128_block_max() {
        let doc_ids: Vec<u32> = (0..500).map(|i| i * 2).collect();
        // One tf level per 128-doc block: 1, 5, 10, then 3.
        let term_freqs: Vec<u32> = (0..500)
            .map(|i| {
                if i < 128 {
                    1
                } else if i < 256 {
                    5
                } else if i < 384 {
                    10
                } else {
                    3
                }
            })
            .collect();

        let list = SimdBp128PostingList::from_postings(&doc_ids, &term_freqs, 2.0);

        assert_eq!(list.blocks.len(), 4);
        assert_eq!(list.blocks[0].max_tf, 1);
        assert_eq!(list.blocks[1].max_tf, 5);
        assert_eq!(list.blocks[2].max_tf, 10);
        assert_eq!(list.blocks[3].max_tf, 3);

        // The block with the highest tf carries the highest score bound.
        assert!(list.blocks[2].max_block_score > list.blocks[0].max_block_score);
        assert!(list.blocks[2].max_block_score > list.blocks[1].max_block_score);
        assert!(list.blocks[2].max_block_score > list.blocks[3].max_block_score);

        assert_eq!(list.max_score, list.blocks[2].max_block_score);

        let mut iter = list.iterator();
        assert_eq!(iter.current_block_max_tf(), 1);
        iter.seek(256);
        assert_eq!(iter.current_block_max_tf(), 5);

        iter.seek(512);
        assert_eq!(iter.current_block_max_tf(), 10);

        let mut iter2 = list.iterator();
        let result = iter2.skip_to_block_with_doc(300);
        assert!(result.is_some());
        let (first_doc, score) = result.unwrap();
        assert!(first_doc <= 300);
        assert!(score > 0.0);
    }
}