use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Cursor, Read, Write};

use super::config::WeightQuantization;
use crate::DocId;
use crate::directories::OwnedBytes;
use crate::structures::postings::TERMINATED;
use crate::structures::simd;

/// Default number of postings per block.
pub const BLOCK_SIZE: usize = 128;
/// Hard upper bound on postings per block (custom partitions may exceed BLOCK_SIZE).
pub const MAX_BLOCK_SIZE: usize = 256;

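/// Fixed-size header at the start of every sparse block. The serialized
/// layout (16 bytes, little-endian) mirrors `write`/`read` below:
/// count (u16), doc_id_bits (u8), ordinal_bits (u8), weight_quant (u8),
/// three bytes of padding, first_doc_id (u32), max_weight (f32).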
#[derive(Debug, Clone, Copy)]
pub struct BlockHeader {
    pub count: u16,
    pub doc_id_bits: u8,
    pub ordinal_bits: u8,
    pub weight_quant: WeightQuantization,
    pub first_doc_id: DocId,
    pub max_weight: f32,
}

impl BlockHeader {
    /// Serialized size of the header in bytes.
    pub const SIZE: usize = 16;

    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        w.write_u16::<LittleEndian>(self.count)?;
        w.write_u8(self.doc_id_bits)?;
        w.write_u8(self.ordinal_bits)?;
        w.write_u8(self.weight_quant as u8)?;
        // One u8 plus one u16 of padding keeps the header at exactly 16 bytes.
        w.write_u8(0)?;
        w.write_u16::<LittleEndian>(0)?;
        w.write_u32::<LittleEndian>(self.first_doc_id)?;
        w.write_f32::<LittleEndian>(self.max_weight)?;
        Ok(())
    }

    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let count = r.read_u16::<LittleEndian>()?;
        let doc_id_bits = r.read_u8()?;
        let ordinal_bits = r.read_u8()?;
        let weight_quant_byte = r.read_u8()?;
        let _ = r.read_u8()?;
        let _ = r.read_u16::<LittleEndian>()?;
        let first_doc_id = r.read_u32::<LittleEndian>()?;
        let max_weight = r.read_f32::<LittleEndian>()?;

        let weight_quant = WeightQuantization::from_u8(weight_quant_byte)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid weight quant"))?;

        Ok(Self {
            count,
            doc_id_bits,
            ordinal_bits,
            weight_quant,
            first_doc_id,
            max_weight,
        })
    }
}

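/// One compressed block of postings. Doc ids are bit-packed deltas
/// against `header.first_doc_id`, ordinals are bit-packed values, and
/// weights are encoded according to `header.weight_quant`.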
#[derive(Debug, Clone)]
pub struct SparseBlock {
    pub header: BlockHeader,
    /// Bit-packed doc-id deltas (the first doc id lives in the header).
    pub doc_ids_data: OwnedBytes,
    /// Bit-packed ordinals; empty when all ordinals are zero.
    pub ordinals_data: OwnedBytes,
    /// Weights encoded per `header.weight_quant`.
    pub weights_data: OwnedBytes,
}

impl SparseBlock {
    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        assert!(!postings.is_empty() && postings.len() <= MAX_BLOCK_SIZE);

        let count = postings.len();
        let first_doc_id = postings[0].0;

        // Delta-encode doc ids; deltas[0] is always 0 because the first
        // doc id is stored verbatim in the header.
        let mut deltas = Vec::with_capacity(count);
        let mut prev = first_doc_id;
        for &(doc_id, _, _) in postings {
            deltas.push(doc_id.saturating_sub(prev));
            prev = doc_id;
        }
        deltas[0] = 0;

        let doc_id_bits = simd::round_bit_width(find_optimal_bit_width(&deltas[1..]));
        let ordinals: Vec<u16> = postings.iter().map(|(_, o, _)| *o).collect();
        let max_ordinal = ordinals.iter().copied().max().unwrap_or(0);
        let ordinal_bits = if max_ordinal == 0 {
            0
        } else {
            simd::round_bit_width(bits_needed_u16(max_ordinal))
        };

        let weights: Vec<f32> = postings.iter().map(|(_, _, w)| *w).collect();
        let max_weight = weights
            .iter()
            .copied()
            .fold(0.0f32, |acc, w| acc.max(w.abs()));

        let doc_ids_data = OwnedBytes::new({
            let rounded = simd::RoundedBitWidth::from_u8(doc_id_bits);
            let num_deltas = count - 1;
            let byte_count = num_deltas * rounded.bytes_per_value();
            let mut data = vec![0u8; byte_count];
            simd::pack_rounded(&deltas[1..], rounded, &mut data);
            data
        });
        let ordinals_data = OwnedBytes::new(if ordinal_bits > 0 {
            let rounded = simd::RoundedBitWidth::from_u8(ordinal_bits);
            let byte_count = count * rounded.bytes_per_value();
            let mut data = vec![0u8; byte_count];
            let ord_u32: Vec<u32> = ordinals.iter().map(|&o| o as u32).collect();
            simd::pack_rounded(&ord_u32, rounded, &mut data);
            data
        } else {
            Vec::new()
        });
        let weights_data = OwnedBytes::new(encode_weights(&weights, weight_quant)?);

        Ok(Self {
            header: BlockHeader {
                count: count as u16,
                doc_id_bits,
                ordinal_bits,
                weight_quant,
                first_doc_id,
                max_weight,
            },
            doc_ids_data,
            ordinals_data,
            weights_data,
        })
    }

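    /// Decodes this block's doc ids into a fresh `Vec`. The first id
    /// comes from the header; the rest are rebuilt by a prefix sum over
    /// the unpacked deltas (see `decode_doc_ids_into`).
    ///
    /// A minimal round-trip sketch, assuming this module's types are in
    /// scope:
    ///
    /// ```ignore
    /// let block =
    ///     SparseBlock::from_postings(&[(3, 0, 0.5), (9, 0, 1.0)], WeightQuantization::Float32)?;
    /// assert_eq!(block.decode_doc_ids(), vec![3, 9]);
    /// ```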
    pub fn decode_doc_ids(&self) -> Vec<DocId> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_doc_ids_into(&mut out);
        out
    }

    pub fn decode_doc_ids_into(&self, out: &mut Vec<DocId>) {
        let count = self.header.count as usize;
        out.clear();
        out.resize(count, 0);
        out[0] = self.header.first_doc_id;

        if count > 1 {
            let bits = self.header.doc_id_bits;
            if bits == 0 {
                // All deltas are zero: every posting shares the first doc id.
                out[1..].fill(self.header.first_doc_id);
            } else {
                simd::unpack_rounded(
                    &self.doc_ids_data,
                    simd::RoundedBitWidth::from_u8(bits),
                    &mut out[1..],
                    count - 1,
                );
                // Prefix sum turns the deltas back into absolute doc ids.
                for i in 1..count {
                    out[i] += out[i - 1];
                }
            }
        }
    }

    pub fn decode_ordinals(&self) -> Vec<u16> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_ordinals_into(&mut out);
        out
    }

    pub fn decode_ordinals_into(&self, out: &mut Vec<u16>) {
        let count = self.header.count as usize;
        out.clear();
        if self.header.ordinal_bits == 0 {
            out.resize(count, 0u16);
        } else {
            // Sized for the largest possible block: counts up to
            // MAX_BLOCK_SIZE are legal, so BLOCK_SIZE would be too small.
            let mut temp = [0u32; MAX_BLOCK_SIZE];
            simd::unpack_rounded(
                &self.ordinals_data,
                simd::RoundedBitWidth::from_u8(self.header.ordinal_bits),
                &mut temp[..count],
                count,
            );
            out.reserve(count);
            for &v in &temp[..count] {
                out.push(v as u16);
            }
        }
    }

    pub fn decode_weights(&self) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_weights_into(&mut out);
        out
    }

    pub fn decode_weights_into(&self, out: &mut Vec<f32>) {
        out.clear();
        decode_weights_into(
            &self.weights_data,
            self.header.weight_quant,
            self.header.count as usize,
            out,
        );
    }

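    /// Decodes weights pre-multiplied by `query_weight`. For `UInt8`
    /// blocks the multiplication folds into the dequantization
    /// constants: since `w = q * scale + min`, we have
    /// `query_weight * w = q * (query_weight * scale) + query_weight * min`,
    /// so a single SIMD pass yields the scored weights.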
    pub fn decode_scored_weights_into(&self, query_weight: f32, out: &mut Vec<f32>) {
        out.clear();
        let count = self.header.count as usize;
        match self.header.weight_quant {
            WeightQuantization::UInt8 if self.weights_data.len() >= 8 => {
                // Payload layout: scale (f32) and min (f32), then one byte per weight.
                let scale = f32::from_le_bytes([
                    self.weights_data[0],
                    self.weights_data[1],
                    self.weights_data[2],
                    self.weights_data[3],
                ]);
                let min_val = f32::from_le_bytes([
                    self.weights_data[4],
                    self.weights_data[5],
                    self.weights_data[6],
                    self.weights_data[7],
                ]);
                let eff_scale = query_weight * scale;
                let eff_bias = query_weight * min_val;
                out.resize(count, 0.0);
                simd::dequantize_uint8(&self.weights_data[8..], out, eff_scale, eff_bias, count);
            }
            _ => {
                decode_weights_into(&self.weights_data, self.header.weight_quant, count, out);
                for w in out.iter_mut() {
                    *w *= query_weight;
                }
            }
        }
    }

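    /// Scores this block directly into a flat accumulator: `flat_scores`
    /// is indexed by `doc_id - base_doc`, and doc ids whose slot was
    /// previously zero are appended to `dirty` so the caller knows which
    /// slots were touched. Returns the number of postings in the block.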
    #[inline]
    pub fn accumulate_scored_weights(
        &self,
        query_weight: f32,
        doc_ids: &[u32],
        flat_scores: &mut [f32],
        base_doc: u32,
        dirty: &mut Vec<u32>,
    ) -> usize {
        let count = self.header.count as usize;
        match self.header.weight_quant {
            WeightQuantization::UInt8 if self.weights_data.len() >= 8 => {
                let scale = f32::from_le_bytes([
                    self.weights_data[0],
                    self.weights_data[1],
                    self.weights_data[2],
                    self.weights_data[3],
                ]);
                let min_val = f32::from_le_bytes([
                    self.weights_data[4],
                    self.weights_data[5],
                    self.weights_data[6],
                    self.weights_data[7],
                ]);
                let eff_scale = query_weight * scale;
                let eff_bias = query_weight * min_val;
                let quant_data = &self.weights_data[8..];

                for i in 0..count.min(quant_data.len()).min(doc_ids.len()) {
                    let w = quant_data[i] as f32 * eff_scale + eff_bias;
                    let off = (doc_ids[i] - base_doc) as usize;
                    if off >= flat_scores.len() {
                        continue;
                    }
                    if flat_scores[off] == 0.0 {
                        dirty.push(doc_ids[i]);
                    }
                    flat_scores[off] += w;
                }
                count
            }
            _ => {
                let mut weights_buf = Vec::with_capacity(count);
                decode_weights_into(
                    &self.weights_data,
                    self.header.weight_quant,
                    count,
                    &mut weights_buf,
                );
                for i in 0..count.min(weights_buf.len()).min(doc_ids.len()) {
                    let w = weights_buf[i] * query_weight;
                    let off = (doc_ids[i] - base_doc) as usize;
                    if off >= flat_scores.len() {
                        continue;
                    }
                    if flat_scores[off] == 0.0 {
                        dirty.push(doc_ids[i]);
                    }
                    flat_scores[off] += w;
                }
                count
            }
        }
    }

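    /// Serializes the block: the 16-byte header, four little-endian u16
    /// values (doc-id bytes, ordinal bytes, weight bytes, and a zero
    /// pad), then the three payloads. Each sub-block length must fit in
    /// a u16.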
    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        self.header.write(w)?;
        if self.doc_ids_data.len() > u16::MAX as usize
            || self.ordinals_data.len() > u16::MAX as usize
            || self.weights_data.len() > u16::MAX as usize
        {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "sparse sub-block too large for u16 length: doc_ids={}B ords={}B wts={}B",
                    self.doc_ids_data.len(),
                    self.ordinals_data.len(),
                    self.weights_data.len()
                ),
            ));
        }
        w.write_u16::<LittleEndian>(self.doc_ids_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.ordinals_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.weights_data.len() as u16)?;
        w.write_u16::<LittleEndian>(0)?;
        w.write_all(&self.doc_ids_data)?;
        w.write_all(&self.ordinals_data)?;
        w.write_all(&self.weights_data)?;
        Ok(())
    }

    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let header = BlockHeader::read(r)?;
        let doc_ids_len = r.read_u16::<LittleEndian>()? as usize;
        let ordinals_len = r.read_u16::<LittleEndian>()? as usize;
        let weights_len = r.read_u16::<LittleEndian>()? as usize;
        let _ = r.read_u16::<LittleEndian>()?;

        let mut doc_ids_vec = vec![0u8; doc_ids_len];
        r.read_exact(&mut doc_ids_vec)?;
        let mut ordinals_vec = vec![0u8; ordinals_len];
        r.read_exact(&mut ordinals_vec)?;
        let mut weights_vec = vec![0u8; weights_len];
        r.read_exact(&mut weights_vec)?;

        Ok(Self {
            header,
            doc_ids_data: OwnedBytes::new(doc_ids_vec),
            ordinals_data: OwnedBytes::new(ordinals_vec),
            weights_data: OwnedBytes::new(weights_vec),
        })
    }

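    /// Zero-copy variant of `read`: parses the header and length table,
    /// validates the advertised sub-block sizes against the available
    /// bytes, then slices the payloads out of `data` without copying.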
    pub fn from_owned_bytes(data: crate::directories::OwnedBytes) -> crate::Result<Self> {
        let b = data.as_slice();
        if b.len() < BlockHeader::SIZE + 8 {
            return Err(crate::Error::Corruption(
                "sparse block too small".to_string(),
            ));
        }
        let mut cursor = Cursor::new(&b[..BlockHeader::SIZE]);
        let header =
            BlockHeader::read(&mut cursor).map_err(|e| crate::Error::Corruption(e.to_string()))?;

        if header.count == 0 {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block has count=0 (data_len={}, first_32_bytes=[{}])",
                b.len(),
                hex
            )));
        }

        let p = BlockHeader::SIZE;
        let doc_ids_len = u16::from_le_bytes([b[p], b[p + 1]]) as usize;
        let ordinals_len = u16::from_le_bytes([b[p + 2], b[p + 3]]) as usize;
        let weights_len = u16::from_le_bytes([b[p + 4], b[p + 5]]) as usize;
        let data_start = p + 8;
        let ord_start = data_start + doc_ids_len;
        let wt_start = ord_start + ordinals_len;
        let expected_end = wt_start + weights_len;

        if expected_end > b.len() {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block sub-block overflow: count={} doc_ids={}B ords={}B wts={}B need={}B have={}B (first_32=[{}])",
                header.count,
                doc_ids_len,
                ordinals_len,
                weights_len,
                expected_end,
                b.len(),
                hex
            )));
        }

        Ok(Self {
            header,
            doc_ids_data: data.slice(data_start..ord_start),
            ordinals_data: data.slice(ord_start..wt_start),
            weights_data: data.slice(wt_start..wt_start + weights_len),
        })
    }

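    /// Returns a copy of this block rebased by `doc_offset`. Only the
    /// header's `first_doc_id` changes: the packed deltas are invariant
    /// under a constant shift, so the payload buffers are shared as-is.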
    pub fn with_doc_offset(&self, doc_offset: u32) -> Self {
        Self {
            header: BlockHeader {
                first_doc_id: self.header.first_doc_id + doc_offset,
                ..self.header
            },
            doc_ids_data: self.doc_ids_data.clone(),
            ordinals_data: self.ordinals_data.clone(),
            weights_data: self.weights_data.clone(),
        }
    }
}

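/// A posting list stored as a sequence of independently decodable
/// `SparseBlock`s. `doc_count` counts unique documents, which can be
/// smaller than the number of postings when a document has several
/// ordinals.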
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    pub doc_count: u32,
    pub blocks: Vec<SparseBlock>,
}

impl BlockSparsePostingList {
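    /// Builds a posting list from postings sorted by doc id, chunking
    /// them into blocks of `block_size` entries (clamped to at least 16).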
    pub fn from_postings_with_block_size(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
        block_size: usize,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                doc_count: 0,
                blocks: Vec::new(),
            });
        }

        let block_size = block_size.max(16);
        let mut blocks = Vec::new();
        for chunk in postings.chunks(block_size) {
            blocks.push(SparseBlock::from_postings(chunk, weight_quant)?);
        }

        // Count unique documents: postings are sorted, so postings for the
        // same document are adjacent.
        let mut unique_docs = 1u32;
        for i in 1..postings.len() {
            if postings[i].0 != postings[i - 1].0 {
                unique_docs += 1;
            }
        }

        Ok(Self {
            doc_count: unique_docs,
            blocks,
        })
    }

    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        Self::from_postings_with_block_size(postings, weight_quant, BLOCK_SIZE)
    }

    /// Builds a posting list whose block boundaries follow `partition`,
    /// a list of per-block posting counts.
    pub fn from_postings_with_partition(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
        partition: &[usize],
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                doc_count: 0,
                blocks: Vec::new(),
            });
        }

        let mut blocks = Vec::with_capacity(partition.len());
        let mut offset = 0;
        for &block_size in partition {
            let end = (offset + block_size).min(postings.len());
            blocks.push(SparseBlock::from_postings(
                &postings[offset..end],
                weight_quant,
            )?);
            offset = end;
        }

        let mut unique_docs = 1u32;
        for i in 1..postings.len() {
            if postings[i].0 != postings[i - 1].0 {
                unique_docs += 1;
            }
        }

        Ok(Self {
            doc_count: unique_docs,
            blocks,
        })
    }

    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    pub fn global_max_weight(&self) -> f32 {
        self.blocks
            .iter()
            .map(|b| b.header.max_weight)
            .fold(0.0f32, f32::max)
    }

    pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
        self.blocks.get(block_idx).map(|b| b.header.max_weight)
    }

    /// Approximate in-memory footprint in bytes.
    pub fn size_bytes(&self) -> usize {
        use std::mem::size_of;

        let header_size = size_of::<u32>() * 2;
        let blocks_size: usize = self
            .blocks
            .iter()
            .map(|b| {
                size_of::<BlockHeader>()
                    + b.doc_ids_data.len()
                    + b.ordinals_data.len()
                    + b.weights_data.len()
            })
            .sum();
        header_size + blocks_size
    }

    pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
        BlockSparsePostingIterator::new(self)
    }
    pub fn serialize(&self) -> io::Result<(Vec<u8>, Vec<super::SparseSkipEntry>)> {
        let mut block_data = Vec::new();
        let mut skip_entries = Vec::with_capacity(self.blocks.len());
        let mut offset = 0u64;

        for block in &self.blocks {
            let mut buf = Vec::new();
            block.write(&mut buf)?;
            let length = buf.len() as u32;

            let first_doc = block.header.first_doc_id;
            let doc_ids = block.decode_doc_ids();
            let last_doc = doc_ids.last().copied().unwrap_or(first_doc);

            skip_entries.push(super::SparseSkipEntry::new(
                first_doc,
                last_doc,
                offset,
                length,
                block.header.max_weight,
            ));

            block_data.extend_from_slice(&buf);
            offset += length as u64;
        }

        Ok((block_data, skip_entries))
    }

    #[cfg(test)]
    pub fn from_parts(
        doc_count: u32,
        block_data: &[u8],
        skip_entries: &[super::SparseSkipEntry],
    ) -> io::Result<Self> {
        let mut blocks = Vec::with_capacity(skip_entries.len());
        for entry in skip_entries {
            let start = entry.offset as usize;
            let end = start + entry.length as usize;
            blocks.push(SparseBlock::read(&mut std::io::Cursor::new(
                &block_data[start..end],
            ))?);
        }
        Ok(Self { doc_count, blocks })
    }

    pub fn decode_all(&self) -> Vec<(DocId, u16, f32)> {
        let total_postings: usize = self.blocks.iter().map(|b| b.header.count as usize).sum();
        let mut result = Vec::with_capacity(total_postings);
        for block in &self.blocks {
            let doc_ids = block.decode_doc_ids();
            let ordinals = block.decode_ordinals();
            let weights = block.decode_weights();
            for i in 0..block.header.count as usize {
                result.push((doc_ids[i], ordinals[i], weights[i]));
            }
        }
        result
    }

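    /// Merges posting lists by concatenating their blocks, rebasing each
    /// block via `with_doc_offset`. No payload bytes are re-encoded.
    ///
    /// A minimal sketch (the segment lists and offsets are illustrative):
    ///
    /// ```ignore
    /// let merged = BlockSparsePostingList::merge_with_offsets(&[
    ///     (&seg1_list, 0),    // first segment keeps its doc ids
    ///     (&seg2_list, 1000), // second segment's doc ids shift by 1000
    /// ]);
    /// ```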
    pub fn merge_with_offsets(lists: &[(&BlockSparsePostingList, u32)]) -> Self {
        if lists.is_empty() {
            return Self {
                doc_count: 0,
                blocks: Vec::new(),
            };
        }

        let total_blocks: usize = lists.iter().map(|(pl, _)| pl.blocks.len()).sum();
        let total_docs: u32 = lists.iter().map(|(pl, _)| pl.doc_count).sum();

        let mut merged_blocks = Vec::with_capacity(total_blocks);

        for (posting_list, doc_offset) in lists {
            for block in &posting_list.blocks {
                merged_blocks.push(block.with_doc_offset(*doc_offset));
            }
        }

        Self {
            doc_count: total_docs,
            blocks: merged_blocks,
        }
    }

    /// Locates the block that may contain `target`: the last block whose
    /// `first_doc_id` is <= `target`, falling back to block 0 when the
    /// target precedes every block.
    fn find_block(&self, target: DocId) -> Option<usize> {
        if self.blocks.is_empty() {
            return None;
        }
        let idx = self
            .blocks
            .partition_point(|b| b.header.first_doc_id <= target);
        if idx == 0 {
            Some(0)
        } else {
            Some(idx - 1)
        }
    }
}

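/// Cursor over a `BlockSparsePostingList`. One block is decoded at a
/// time; doc ids and weights are decoded eagerly on block load, while
/// ordinals are decoded lazily on first access since many scoring paths
/// never read them.
///
/// A minimal consumption sketch, assuming `list` is a populated posting
/// list:
///
/// ```ignore
/// let mut it = list.iterator();
/// while it.doc() != TERMINATED {
///     let (doc, w) = (it.doc(), it.weight());
///     // ... score the document ...
///     it.advance();
/// }
/// ```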
pub struct BlockSparsePostingIterator<'a> {
    posting_list: &'a BlockSparsePostingList,
    block_idx: usize,
    in_block_idx: usize,
    current_doc_ids: Vec<DocId>,
    current_ordinals: Vec<u16>,
    current_weights: Vec<f32>,
    // Ordinals are decoded lazily on first access.
    ordinals_decoded: bool,
    exhausted: bool,
}

impl<'a> BlockSparsePostingIterator<'a> {
    fn new(posting_list: &'a BlockSparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            block_idx: 0,
            in_block_idx: 0,
            current_doc_ids: Vec::with_capacity(128),
            current_ordinals: Vec::with_capacity(128),
            current_weights: Vec::with_capacity(128),
            ordinals_decoded: false,
            exhausted: posting_list.blocks.is_empty(),
        };
        if !iter.exhausted {
            iter.load_block(0);
        }
        iter
    }

    fn load_block(&mut self, block_idx: usize) {
        if let Some(block) = self.posting_list.blocks.get(block_idx) {
            block.decode_doc_ids_into(&mut self.current_doc_ids);
            block.decode_weights_into(&mut self.current_weights);
            self.ordinals_decoded = false;
            self.block_idx = block_idx;
            self.in_block_idx = 0;
        }
    }

    #[inline]
    fn ensure_ordinals_decoded(&mut self) {
        if !self.ordinals_decoded {
            if let Some(block) = self.posting_list.blocks.get(self.block_idx) {
                block.decode_ordinals_into(&mut self.current_ordinals);
            }
            self.ordinals_decoded = true;
        }
    }

    #[inline]
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            TERMINATED
        } else {
            self.current_doc_ids[self.in_block_idx]
        }
    }

    #[inline]
    pub fn weight(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.current_weights[self.in_block_idx]
    }

    #[inline]
    pub fn ordinal(&mut self) -> u16 {
        if self.exhausted {
            return 0;
        }
        self.ensure_ordinals_decoded();
        self.current_ordinals[self.in_block_idx]
    }

    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        self.in_block_idx += 1;
        if self.in_block_idx >= self.current_doc_ids.len() {
            self.block_idx += 1;
            if self.block_idx >= self.posting_list.blocks.len() {
                self.exhausted = true;
            } else {
                self.load_block(self.block_idx);
            }
        }
        self.doc()
    }

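    /// Advances to the first document with id >= `target` and returns it
    /// (or `TERMINATED`). Seeking backwards is a no-op. The fast path
    /// scans the current block with a SIMD lower-bound search; otherwise
    /// the target block is located via `find_block`.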
    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        if self.doc() >= target {
            return self.doc();
        }

        if let Some(&last_doc) = self.current_doc_ids.last()
            && last_doc >= target
        {
            let remaining = &self.current_doc_ids[self.in_block_idx..];
            let pos = crate::structures::simd::find_first_ge_u32(remaining, target);
            self.in_block_idx += pos;
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
            return self.doc();
        }

        if let Some(block_idx) = self.posting_list.find_block(target) {
            self.load_block(block_idx);
            let pos = crate::structures::simd::find_first_ge_u32(&self.current_doc_ids, target);
            self.in_block_idx = pos;
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
        } else {
            self.exhausted = true;
        }
        self.doc()
    }

    /// Skips the remainder of the current block and positions the cursor
    /// on the first document of the next block, or exhausts the iterator.
    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        let next = self.block_idx + 1;
        if next >= self.posting_list.blocks.len() {
            self.exhausted = true;
            return TERMINATED;
        }
        self.load_block(next);
        self.doc()
    }

    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }

    pub fn current_block_max_weight(&self) -> f32 {
        self.posting_list
            .blocks
            .get(self.block_idx)
            .map(|b| b.header.max_weight)
            .unwrap_or(0.0)
    }

    pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.current_block_max_weight()
    }
}

/// Returns the number of bits needed to represent the largest value,
/// or 0 for an empty slice.
fn find_optimal_bit_width(values: &[u32]) -> u8 {
    if values.is_empty() {
        return 0;
    }
    let max_val = values.iter().copied().max().unwrap_or(0);
    simd::bits_needed(max_val)
}

fn bits_needed_u16(val: u16) -> u8 {
    if val == 0 {
        0
    } else {
        16 - val.leading_zeros() as u8
    }
}

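/// Encodes weights according to `quant`. `Float32`/`Float16` store raw
/// values. `UInt8`/`UInt4` store `scale` and `min` as two f32s followed
/// by quantized values `q = round((w - min) / scale)`, with
/// `scale = range / 255.0` (or `range / 15.0` for 4-bit). For example,
/// weights spanning [0.0, 2.55] quantized to u8 get scale 0.01, so 1.28
/// encodes as 128. `UInt4` packs two values per byte, low nibble first.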
fn encode_weights(weights: &[f32], quant: WeightQuantization) -> io::Result<Vec<u8>> {
    let mut data = Vec::new();
    match quant {
        WeightQuantization::Float32 => {
            for &w in weights {
                data.write_f32::<LittleEndian>(w)?;
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for &w in weights {
                data.write_u16::<LittleEndian>(f16::from_f32(w).to_bits())?;
            }
        }
        WeightQuantization::UInt8 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 255.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            for &w in weights {
                data.write_u8(((w - min) / scale).round() as u8)?;
            }
        }
        WeightQuantization::UInt4 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 15.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            let mut i = 0;
            while i < weights.len() {
                let q1 = ((weights[i] - min) / scale).round() as u8 & 0x0F;
                let q2 = if i + 1 < weights.len() {
                    ((weights[i + 1] - min) / scale).round() as u8 & 0x0F
                } else {
                    0
                };
                data.write_u8((q2 << 4) | q1)?;
                i += 2;
            }
        }
    }
    Ok(data)
}

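/// Inverse of `encode_weights`: reconstructs `count` f32 weights from
/// `data` as `w = q * scale + min`; the u8 path delegates to a SIMD
/// dequantization kernel.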
fn decode_weights_into(data: &[u8], quant: WeightQuantization, count: usize, out: &mut Vec<f32>) {
    let mut cursor = Cursor::new(data);
    match quant {
        WeightQuantization::Float32 => {
            for _ in 0..count {
                out.push(cursor.read_f32::<LittleEndian>().unwrap_or(0.0));
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for _ in 0..count {
                let bits = cursor.read_u16::<LittleEndian>().unwrap_or(0);
                out.push(f16::from_bits(bits).to_f32());
            }
        }
        WeightQuantization::UInt8 => {
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min_val = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            let offset = cursor.position() as usize;
            out.resize(count, 0.0);
            simd::dequantize_uint8(&data[offset..], out, scale, min_val, count);
        }
        WeightQuantization::UInt4 => {
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            let mut i = 0;
            while i < count {
                let byte = cursor.read_u8().unwrap_or(0);
                out.push((byte & 0x0F) as f32 * scale + min);
                i += 1;
                if i < count {
                    out.push((byte >> 4) as f32 * scale + min);
                    i += 1;
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_block_roundtrip() {
        let postings = vec![
            (10u32, 0u16, 1.5f32),
            (15, 0, 2.0),
            (20, 1, 0.5),
            (100, 0, 3.0),
        ];
        let block = SparseBlock::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(block.decode_doc_ids(), vec![10, 15, 20, 100]);
        assert_eq!(block.decode_ordinals(), vec![0, 0, 1, 0]);
        let weights = block.decode_weights();
        assert!((weights[0] - 1.5).abs() < 0.01);
    }

    #[test]
    fn test_posting_list() {
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, i as f32 * 0.1)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 300);
        assert_eq!(list.num_blocks(), 3);

        let mut iter = list.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 2);
    }

    #[test]
    fn test_serialization() {
        let postings = vec![(1u32, 0u16, 0.5f32), (10, 1, 1.5), (100, 0, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let (block_data, skip_entries) = list.serialize().unwrap();
        let list2 =
            BlockSparsePostingList::from_parts(list.doc_count(), &block_data, &skip_entries)
                .unwrap();

        assert_eq!(list.doc_count(), list2.doc_count());
    }

    #[test]
    fn test_seek() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = list.iterator();
        assert_eq!(iter.seek(300), 300);
        assert_eq!(iter.seek(301), 303);
        assert_eq!(iter.seek(2000), TERMINATED);
    }

    #[test]
    fn test_merge_with_offsets() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert_eq!(merged.doc_count(), 6);

        let decoded = merged.decode_all();
        assert_eq!(decoded.len(), 6);

        // Segment 1 doc ids are unchanged.
        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        // Segment 2 doc ids are shifted by 100.
        assert_eq!(decoded[3].0, 100);
        assert_eq!(decoded[4].0, 103);
        assert_eq!(decoded[5].0, 107);

        // Weights survive the merge.
        assert!((decoded[0].2 - 1.0).abs() < 0.01);
        assert!((decoded[3].2 - 4.0).abs() < 0.01);

        // Ordinals survive the merge.
        assert_eq!(decoded[2].1, 1);
        assert_eq!(decoded[4].1, 1);
    }

    #[test]
    fn test_merge_with_offsets_multi_block() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1, "Should have multiple blocks");

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        assert_eq!(merged.doc_count(), 350);
        assert_eq!(merged.num_blocks(), list1.num_blocks() + list2.num_blocks());

        let mut iter = merged.iterator();

        assert_eq!(iter.doc(), 0);

        // Seek across the segment boundary: 1000 is the first doc of segment 2.
        let doc = iter.seek(1000);
        assert_eq!(doc, 1000);
        iter.advance();
        assert_eq!(iter.doc(), 1003);
    }

    #[test]
    fn test_merge_with_offsets_serialize_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        let decoded = loaded.decode_all();
        assert_eq!(decoded.len(), 6);

        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        assert_eq!(decoded[3].0, 100, "First doc of seg2 should be 0+100=100");
        assert_eq!(decoded[4].0, 103, "Second doc of seg2 should be 3+100=103");
        assert_eq!(decoded[5].0, 107, "Third doc of seg2 should be 7+100=107");

        let mut iter = loaded.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 5);
        iter.advance();
        assert_eq!(iter.doc(), 10);
        iter.advance();
        assert_eq!(iter.doc(), 100);
        iter.advance();
        assert_eq!(iter.doc(), 103);
        iter.advance();
        assert_eq!(iter.doc(), 107);
    }

    #[test]
    fn test_merge_seek_after_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, 1.0)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 0, 2.0)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        let mut iter = loaded.iterator();

        let doc = iter.seek(100);
        assert_eq!(doc, 100, "Seek to 100 in segment 1");

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "Seek to 1000 (first doc of segment 2)");

        let doc = iter.seek(1050);
        assert!(
            doc >= 1050,
            "Seek to 1050 should find doc >= 1050, got {}",
            doc
        );

        let doc = iter.seek(500);
        assert!(
            doc >= 1050,
            "Seek backwards should not go back, got {}",
            doc
        );

        let mut iter2 = loaded.iterator();

        let mut count = 0;
        let mut prev_doc = 0;
        while iter2.doc() != super::TERMINATED {
            let current = iter2.doc();
            if count > 0 {
                assert!(
                    current > prev_doc,
                    "Docs should be monotonically increasing: {} vs {}",
                    prev_doc,
                    current
                );
            }
            prev_doc = current;
            iter2.advance();
            count += 1;
        }
        assert_eq!(count, 350, "Should have 350 total docs");
    }

    #[test]
    fn test_doc_count_multi_value() {
        // Three unique documents, six postings (multi-valued fields).
        let postings: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 1.0),
            (0, 1, 1.5),
            (0, 2, 2.0),
            (5, 0, 3.0),
            (5, 1, 3.5),
            (10, 0, 4.0),
        ];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 3);

        let decoded = list.decode_all();
        assert_eq!(decoded.len(), 6);
    }

    #[test]
    fn test_zero_copy_merge_patches_first_doc_id() {
        use crate::structures::SparseSkipEntry;

        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1);

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let (raw1, skip1) = list1.serialize().unwrap();
        let (raw2, skip2) = list2.serialize().unwrap();

        let doc_offset: u32 = 1000;
        let total_docs = list1.doc_count() + list2.doc_count();

        // Rebuild the skip index: segment 2 entries get shifted doc ids and
        // byte offsets relative to the concatenated block data.
        let mut merged_skip = Vec::new();
        let mut cumulative_offset = 0u64;
        for entry in &skip1 {
            merged_skip.push(SparseSkipEntry::new(
                entry.first_doc,
                entry.last_doc,
                cumulative_offset + entry.offset,
                entry.length,
                entry.max_weight,
            ));
        }
        if let Some(last) = skip1.last() {
            cumulative_offset += last.offset + last.length as u64;
        }
        for entry in &skip2 {
            merged_skip.push(SparseSkipEntry::new(
                entry.first_doc + doc_offset,
                entry.last_doc + doc_offset,
                cumulative_offset + entry.offset,
                entry.length,
                entry.max_weight,
            ));
        }

        let mut merged_block_data = Vec::new();
        merged_block_data.extend_from_slice(&raw1);

        // Patch each block header's first_doc_id in place; the packed
        // deltas need no rewriting.
        const FIRST_DOC_ID_OFFSET: usize = 8;
        let mut buf2 = raw2.to_vec();
        for entry in &skip2 {
            let off = entry.offset as usize + FIRST_DOC_ID_OFFSET;
            if off + 4 <= buf2.len() {
                let old = u32::from_le_bytes(buf2[off..off + 4].try_into().unwrap());
                let patched = (old + doc_offset).to_le_bytes();
                buf2[off..off + 4].copy_from_slice(&patched);
            }
        }
        merged_block_data.extend_from_slice(&buf2);

        let loaded =
            BlockSparsePostingList::from_parts(total_docs, &merged_block_data, &merged_skip)
                .unwrap();
        assert_eq!(loaded.doc_count(), 350);

        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        let doc = iter.seek(100);
        assert_eq!(doc, 100);
        let doc = iter.seek(398);
        assert_eq!(doc, 398);

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "First doc of segment 2 should be 1000");
        iter.advance();
        assert_eq!(iter.doc(), 1003, "Second doc of segment 2 should be 1003");
        let doc = iter.seek(1447);
        assert_eq!(doc, 1447, "Last doc of segment 2 should be 1447");

        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);

        // The zero-copy merge must match the reference merge exactly.
        let reference =
            BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, doc_offset)]);
        let mut ref_iter = reference.iterator();
        let mut zc_iter = loaded.iterator();
        while ref_iter.doc() != super::TERMINATED {
            assert_eq!(
                ref_iter.doc(),
                zc_iter.doc(),
                "Zero-copy and reference merge should produce identical doc_ids"
            );
            assert!(
                (ref_iter.weight() - zc_iter.weight()).abs() < 0.01,
                "Weights should match: {} vs {}",
                ref_iter.weight(),
                zc_iter.weight()
            );
            ref_iter.advance();
            zc_iter.advance();
        }
        assert_eq!(zc_iter.doc(), super::TERMINATED);
    }

    #[test]
    fn test_doc_count_single_value() {
        let postings: Vec<(DocId, u16, f32)> =
            vec![(0, 0, 1.0), (5, 0, 2.0), (10, 0, 3.0), (15, 0, 4.0)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 4);
    }

    #[test]
    fn test_doc_count_multi_value_serialization_roundtrip() {
        let postings: Vec<(DocId, u16, f32)> =
            vec![(0, 0, 1.0), (0, 1, 1.5), (5, 0, 2.0), (5, 1, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
        assert_eq!(list.doc_count(), 2);

        let (block_data, skip_entries) = list.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(list.doc_count(), &block_data, &skip_entries)
                .unwrap();
        assert_eq!(loaded.doc_count(), 2);
    }

    #[test]
    fn test_merge_preserves_weights_and_ordinals() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.5), (5, 1, 2.5), (10, 2, 3.5)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.5), (3, 1, 5.5), (7, 3, 6.5)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        assert!(
            (iter.weight() - 1.5).abs() < 0.01,
            "Weight should be 1.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!(
            (iter.weight() - 2.5).abs() < 0.01,
            "Weight should be 2.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 10);
        assert!(
            (iter.weight() - 3.5).abs() < 0.01,
            "Weight should be 3.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 2);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!(
            (iter.weight() - 4.5).abs() < 0.01,
            "Weight should be 4.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 103);
        assert!(
            (iter.weight() - 5.5).abs() < 0.01,
            "Weight should be 5.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 107);
        assert!(
            (iter.weight() - 6.5).abs() < 0.01,
            "Weight should be 6.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 3);

        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);
    }

    #[test]
    fn test_merge_global_max_weight() {
        let postings1: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 3.0),
            (1, 0, 7.0), // max weight of list1
            (2, 0, 2.0),
        ];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 5.0),
            (1, 0, 4.0),
            (2, 0, 6.0), // max weight of list2
        ];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        assert!((list1.global_max_weight() - 7.0).abs() < 0.01);
        assert!((list2.global_max_weight() - 6.0).abs() < 0.01);

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert!(
            (merged.global_max_weight() - 7.0).abs() < 0.01,
            "Global max should be 7.0, got {}",
            merged.global_max_weight()
        );

        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        assert!(
            (loaded.global_max_weight() - 7.0).abs() < 0.01,
            "After roundtrip, global max should still be 7.0, got {}",
            loaded.global_max_weight()
        );
    }

    #[test]
    fn test_scoring_simulation_after_merge() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 0.5), (5, 0, 0.8)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 0.6), (3, 0, 0.9)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        let query_weight = 2.0f32;
        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.0).abs() < 0.01,
            "Doc 0 score should be 1.0, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 5);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.6).abs() < 0.01,
            "Doc 5 score should be 1.6, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 100);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.2).abs() < 0.01,
            "Doc 100 score should be 1.2, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 103);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.8).abs() < 0.01,
            "Doc 103 score should be 1.8, got {}",
            score
        );
    }
}