use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Cursor, Read, Write};

use super::config::WeightQuantization;
use crate::DocId;
use crate::directories::OwnedBytes;
use crate::structures::postings::TERMINATED;
use crate::structures::simd;

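/// Maximum number of postings stored in a single block; the final block of a
/// posting list may hold fewer.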
pub const BLOCK_SIZE: usize = 128;

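/// Fixed-size header written in front of every block. The on-disk layout is
/// 16 bytes, little-endian: count (u16), doc_id_bits (u8), ordinal_bits (u8),
/// weight_quant (u8), 3 bytes of padding, first_doc_id (u32), max_weight (f32).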
#[derive(Debug, Clone, Copy)]
pub struct BlockHeader {
    pub count: u16,
    pub doc_id_bits: u8,
    pub ordinal_bits: u8,
    pub weight_quant: WeightQuantization,
    pub first_doc_id: DocId,
    pub max_weight: f32,
}

impl BlockHeader {
    pub const SIZE: usize = 16;

    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        w.write_u16::<LittleEndian>(self.count)?;
        w.write_u8(self.doc_id_bits)?;
        w.write_u8(self.ordinal_bits)?;
        w.write_u8(self.weight_quant as u8)?;
        // Three bytes of padding keep the header at a fixed 16 bytes.
        w.write_u8(0)?;
        w.write_u16::<LittleEndian>(0)?;
        w.write_u32::<LittleEndian>(self.first_doc_id)?;
        w.write_f32::<LittleEndian>(self.max_weight)?;
        Ok(())
    }

    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let count = r.read_u16::<LittleEndian>()?;
        let doc_id_bits = r.read_u8()?;
        let ordinal_bits = r.read_u8()?;
        let weight_quant_byte = r.read_u8()?;
        // Skip the padding bytes.
        let _ = r.read_u8()?;
        let _ = r.read_u16::<LittleEndian>()?;
        let first_doc_id = r.read_u32::<LittleEndian>()?;
        let max_weight = r.read_f32::<LittleEndian>()?;

        let weight_quant = WeightQuantization::from_u8(weight_quant_byte)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid weight quant"))?;

        Ok(Self {
            count,
            doc_id_bits,
            ordinal_bits,
            weight_quant,
            first_doc_id,
            max_weight,
        })
    }
}

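/// A compressed block of up to `BLOCK_SIZE` postings: the first doc id lives
/// in the header and the remaining doc ids are stored as bit-packed deltas,
/// ordinals are bit-packed, and weights are quantized per `header.weight_quant`.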
#[derive(Debug, Clone)]
pub struct SparseBlock {
    pub header: BlockHeader,
    /// Bit-packed doc-id deltas (the first doc id is in the header).
    pub doc_ids_data: OwnedBytes,
    /// Bit-packed ordinals; empty when every ordinal is zero.
    pub ordinals_data: OwnedBytes,
    /// Weights encoded according to `header.weight_quant`.
    pub weights_data: OwnedBytes,
}

impl SparseBlock {
    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        assert!(!postings.is_empty() && postings.len() <= BLOCK_SIZE);

        let count = postings.len();
        let first_doc_id = postings[0].0;

        // Delta-encode doc ids against the previous posting.
        let mut deltas = Vec::with_capacity(count);
        let mut prev = first_doc_id;
        for &(doc_id, _, _) in postings {
            deltas.push(doc_id.saturating_sub(prev));
            prev = doc_id;
        }
        deltas[0] = 0;

        let doc_id_bits = simd::round_bit_width(find_optimal_bit_width(&deltas[1..]));
        let ordinals: Vec<u16> = postings.iter().map(|(_, o, _)| *o).collect();
        let max_ordinal = ordinals.iter().copied().max().unwrap_or(0);
        let ordinal_bits = if max_ordinal == 0 {
            0
        } else {
            simd::round_bit_width(bits_needed_u16(max_ordinal))
        };

        let weights: Vec<f32> = postings.iter().map(|(_, _, w)| *w).collect();
        let max_weight = weights.iter().copied().fold(0.0f32, f32::max);

        let doc_ids_data = OwnedBytes::new({
            let rounded = simd::RoundedBitWidth::from_u8(doc_id_bits);
            let num_deltas = count - 1;
            let byte_count = num_deltas * rounded.bytes_per_value();
            let mut data = vec![0u8; byte_count];
            simd::pack_rounded(&deltas[1..], rounded, &mut data);
            data
        });
        let ordinals_data = OwnedBytes::new(if ordinal_bits > 0 {
            let rounded = simd::RoundedBitWidth::from_u8(ordinal_bits);
            let byte_count = count * rounded.bytes_per_value();
            let mut data = vec![0u8; byte_count];
            let ord_u32: Vec<u32> = ordinals.iter().map(|&o| o as u32).collect();
            simd::pack_rounded(&ord_u32, rounded, &mut data);
            data
        } else {
            Vec::new()
        });
        let weights_data = OwnedBytes::new(encode_weights(&weights, weight_quant)?);

        Ok(Self {
            header: BlockHeader {
                count: count as u16,
                doc_id_bits,
                ordinal_bits,
                weight_quant,
                first_doc_id,
                max_weight,
            },
            doc_ids_data,
            ordinals_data,
            weights_data,
        })
    }

    pub fn decode_doc_ids(&self) -> Vec<DocId> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_doc_ids_into(&mut out);
        out
    }

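    /// Decodes this block's doc ids into `out`: the header's `first_doc_id`
    /// followed by a prefix sum over the bit-packed deltas.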
    pub fn decode_doc_ids_into(&self, out: &mut Vec<DocId>) {
        let count = self.header.count as usize;
        out.clear();
        out.resize(count, 0);
        out[0] = self.header.first_doc_id;

        if count > 1 {
            let bits = self.header.doc_id_bits;
            if bits == 0 {
                // All deltas are zero: every posting shares the first doc id.
                out[1..].fill(self.header.first_doc_id);
            } else {
                simd::unpack_rounded(
                    &self.doc_ids_data,
                    simd::RoundedBitWidth::from_u8(bits),
                    &mut out[1..],
                    count - 1,
                );
                // Prefix sum turns deltas back into absolute doc ids.
                for i in 1..count {
                    out[i] += out[i - 1];
                }
            }
        }
    }

    pub fn decode_ordinals(&self) -> Vec<u16> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_ordinals_into(&mut out);
        out
    }

    pub fn decode_ordinals_into(&self, out: &mut Vec<u16>) {
        let count = self.header.count as usize;
        out.clear();
        if self.header.ordinal_bits == 0 {
            out.resize(count, 0u16);
        } else {
            // Unpack into a u32 scratch buffer, then narrow to u16.
            let mut temp = [0u32; BLOCK_SIZE];
            simd::unpack_rounded(
                &self.ordinals_data,
                simd::RoundedBitWidth::from_u8(self.header.ordinal_bits),
                &mut temp[..count],
                count,
            );
            out.reserve(count);
            for &v in &temp[..count] {
                out.push(v as u16);
            }
        }
    }

    pub fn decode_weights(&self) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_weights_into(&mut out);
        out
    }

    pub fn decode_weights_into(&self, out: &mut Vec<f32>) {
        out.clear();
        decode_weights_into(
            &self.weights_data,
            self.header.weight_quant,
            self.header.count as usize,
            out,
        );
    }

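    /// Decodes weights multiplied by `query_weight` in one pass. For `UInt8`
    /// blocks the query weight is folded into the dequantization constants:
    /// since w = q * scale + min, the score is q * (qw * scale) + qw * min,
    /// so no second pass over the output is needed.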
    pub fn decode_scored_weights_into(&self, query_weight: f32, out: &mut Vec<f32>) {
        out.clear();
        let count = self.header.count as usize;
        match self.header.weight_quant {
            WeightQuantization::UInt8 if self.weights_data.len() >= 8 => {
                let scale = f32::from_le_bytes([
                    self.weights_data[0],
                    self.weights_data[1],
                    self.weights_data[2],
                    self.weights_data[3],
                ]);
                let min_val = f32::from_le_bytes([
                    self.weights_data[4],
                    self.weights_data[5],
                    self.weights_data[6],
                    self.weights_data[7],
                ]);
                // Fold the query weight into scale and bias.
                let eff_scale = query_weight * scale;
                let eff_bias = query_weight * min_val;
                out.resize(count, 0.0);
                simd::dequantize_uint8(&self.weights_data[8..], out, eff_scale, eff_bias, count);
            }
            _ => {
                decode_weights_into(&self.weights_data, self.header.weight_quant, count, out);
                for w in out.iter_mut() {
                    *w *= query_weight;
                }
            }
        }
    }

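    /// Serializes this block: the 16-byte header, three u16 sub-block lengths
    /// plus one u16 of padding, then the doc id, ordinal, and weight bytes.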
    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        self.header.write(w)?;
        w.write_u16::<LittleEndian>(self.doc_ids_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.ordinals_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.weights_data.len() as u16)?;
        w.write_u16::<LittleEndian>(0)?; // padding
        w.write_all(&self.doc_ids_data)?;
        w.write_all(&self.ordinals_data)?;
        w.write_all(&self.weights_data)?;
        Ok(())
    }

    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let header = BlockHeader::read(r)?;
        let doc_ids_len = r.read_u16::<LittleEndian>()? as usize;
        let ordinals_len = r.read_u16::<LittleEndian>()? as usize;
        let weights_len = r.read_u16::<LittleEndian>()? as usize;
        let _ = r.read_u16::<LittleEndian>()?; // padding

        let mut doc_ids_vec = vec![0u8; doc_ids_len];
        r.read_exact(&mut doc_ids_vec)?;
        let mut ordinals_vec = vec![0u8; ordinals_len];
        r.read_exact(&mut ordinals_vec)?;
        let mut weights_vec = vec![0u8; weights_len];
        r.read_exact(&mut weights_vec)?;

        Ok(Self {
            header,
            doc_ids_data: OwnedBytes::new(doc_ids_vec),
            ordinals_data: OwnedBytes::new(ordinals_vec),
            weights_data: OwnedBytes::new(weights_vec),
        })
    }

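    /// Zero-copy counterpart of `read`: parses the same layout produced by
    /// `write` directly from `OwnedBytes`, slicing out the three sub-blocks
    /// instead of copying them, and validating lengths against the input.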
    pub fn from_owned_bytes(data: crate::directories::OwnedBytes) -> crate::Result<Self> {
        let b = data.as_slice();
        if b.len() < BlockHeader::SIZE + 8 {
            return Err(crate::Error::Corruption(
                "sparse block too small".to_string(),
            ));
        }
        let mut cursor = Cursor::new(&b[..BlockHeader::SIZE]);
        let header =
            BlockHeader::read(&mut cursor).map_err(|e| crate::Error::Corruption(e.to_string()))?;

        if header.count == 0 {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block has count=0 (data_len={}, first_32_bytes=[{}])",
                b.len(),
                hex
            )));
        }

        let p = BlockHeader::SIZE;
        let doc_ids_len = u16::from_le_bytes([b[p], b[p + 1]]) as usize;
        let ordinals_len = u16::from_le_bytes([b[p + 2], b[p + 3]]) as usize;
        let weights_len = u16::from_le_bytes([b[p + 4], b[p + 5]]) as usize;
        // Two padding bytes follow the three lengths.
        let data_start = p + 8;
        let ord_start = data_start + doc_ids_len;
        let wt_start = ord_start + ordinals_len;
        let expected_end = wt_start + weights_len;

        if expected_end > b.len() {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block sub-block overflow: count={} doc_ids={}B ords={}B wts={}B need={}B have={}B (first_32=[{}])",
                header.count,
                doc_ids_len,
                ordinals_len,
                weights_len,
                expected_end,
                b.len(),
                hex
            )));
        }

        Ok(Self {
            header,
            doc_ids_data: data.slice(data_start..ord_start),
            ordinals_data: data.slice(ord_start..wt_start),
            weights_data: data.slice(wt_start..wt_start + weights_len),
        })
    }

    /// Returns a copy of this block with every doc id shifted by `doc_offset`.
    /// Only the header's `first_doc_id` changes; the packed deltas are reused.
    pub fn with_doc_offset(&self, doc_offset: u32) -> Self {
        Self {
            header: BlockHeader {
                first_doc_id: self.header.first_doc_id + doc_offset,
                ..self.header
            },
            doc_ids_data: self.doc_ids_data.clone(),
            ordinals_data: self.ordinals_data.clone(),
            weights_data: self.weights_data.clone(),
        }
    }
}

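/// A posting list stored as a sequence of independently decodable blocks.
/// Per-block metadata (first doc id, max weight) supports block skipping and
/// block-max style pruning.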
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    /// Number of unique documents; postings may repeat a doc id across ordinals.
    pub doc_count: u32,
    pub blocks: Vec<SparseBlock>,
}

impl BlockSparsePostingList {
    pub fn from_postings_with_block_size(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
        block_size: usize,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                doc_count: 0,
                blocks: Vec::new(),
            });
        }

        // Guard against degenerate block sizes.
        let block_size = block_size.max(16);
        let mut blocks = Vec::new();
        for chunk in postings.chunks(block_size) {
            blocks.push(SparseBlock::from_postings(chunk, weight_quant)?);
        }

        // Count unique documents: consecutive postings may share a doc id
        // when a document has multiple ordinals.
        let mut unique_docs = 1u32;
        for i in 1..postings.len() {
            if postings[i].0 != postings[i - 1].0 {
                unique_docs += 1;
            }
        }

        Ok(Self {
            doc_count: unique_docs,
            blocks,
        })
    }

    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        Self::from_postings_with_block_size(postings, weight_quant, BLOCK_SIZE)
    }

    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    pub fn global_max_weight(&self) -> f32 {
        self.blocks
            .iter()
            .map(|b| b.header.max_weight)
            .fold(0.0f32, f32::max)
    }

    pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
        self.blocks.get(block_idx).map(|b| b.header.max_weight)
    }

    /// Approximate size of this posting list in bytes (header fields plus
    /// packed block data).
    pub fn size_bytes(&self) -> usize {
        use std::mem::size_of;

        let header_size = size_of::<u32>() * 2;
        let blocks_size: usize = self
            .blocks
            .iter()
            .map(|b| {
                size_of::<BlockHeader>()
                    + b.doc_ids_data.len()
                    + b.ordinals_data.len()
                    + b.weights_data.len()
            })
            .sum();
        header_size + blocks_size
    }

    pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
        BlockSparsePostingIterator::new(self)
    }

    pub fn serialize<W: Write>(&self, w: &mut W) -> io::Result<()> {
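        // Layout: doc_count (u32) | global_max_weight (f32) | num_blocks (u32),
        // then one skip entry per block, then the concatenated block bytes.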
        use super::SparseSkipEntry;

        w.write_u32::<LittleEndian>(self.doc_count)?;
        w.write_f32::<LittleEndian>(self.global_max_weight())?;
        w.write_u32::<LittleEndian>(self.blocks.len() as u32)?;

        // Serialize every block up front so the skip entries can record lengths.
        let mut block_bytes: Vec<Vec<u8>> = Vec::with_capacity(self.blocks.len());
        for block in &self.blocks {
            let mut buf = Vec::new();
            block.write(&mut buf)?;
            block_bytes.push(buf);
        }

        let mut offset = 0u32;
        for (block, bytes) in self.blocks.iter().zip(block_bytes.iter()) {
            let first_doc = block.header.first_doc_id;
            let doc_ids = block.decode_doc_ids();
            let last_doc = doc_ids.last().copied().unwrap_or(first_doc);
            let length = bytes.len() as u32;

            let entry =
                SparseSkipEntry::new(first_doc, last_doc, offset, length, block.header.max_weight);
            entry.write(w)?;
            offset += length;
        }

        for bytes in block_bytes {
            w.write_all(&bytes)?;
        }

        Ok(())
    }

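    /// Like `serialize`, but returns the concatenated block bytes and the skip
    /// entries separately so the caller can place them independently.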
    pub fn serialize_v3(&self) -> io::Result<(Vec<u8>, Vec<super::SparseSkipEntry>)> {
        let mut block_data = Vec::new();
        let mut skip_entries = Vec::with_capacity(self.blocks.len());
        let mut offset = 0u32;

        for block in &self.blocks {
            let mut buf = Vec::new();
            block.write(&mut buf)?;
            let length = buf.len() as u32;

            let first_doc = block.header.first_doc_id;
            let doc_ids = block.decode_doc_ids();
            let last_doc = doc_ids.last().copied().unwrap_or(first_doc);

            skip_entries.push(super::SparseSkipEntry::new(
                first_doc,
                last_doc,
                offset,
                length,
                block.header.max_weight,
            ));

            block_data.extend_from_slice(&buf);
            offset += length;
        }

        Ok((block_data, skip_entries))
    }

    pub fn deserialize<R: Read>(r: &mut R) -> io::Result<Self> {
        use super::SparseSkipEntry;

        let doc_count = r.read_u32::<LittleEndian>()?;
        let _global_max_weight = r.read_f32::<LittleEndian>()?;
        let num_blocks = r.read_u32::<LittleEndian>()? as usize;

        // Skip entries are only needed for random access; drain them here.
        for _ in 0..num_blocks {
            let _ = SparseSkipEntry::read(r)?;
        }

        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            blocks.push(SparseBlock::read(r)?);
        }
        Ok(Self { doc_count, blocks })
    }

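    /// Reads only the list header and skip table, returning `(doc_count,
    /// global_max_weight, skip_entries, header_size_in_bytes)` and leaving the
    /// reader positioned at the start of the block data.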
    pub fn deserialize_header<R: Read>(
        r: &mut R,
    ) -> io::Result<(u32, f32, Vec<super::SparseSkipEntry>, usize)> {
        use super::SparseSkipEntry;

        let doc_count = r.read_u32::<LittleEndian>()?;
        let global_max_weight = r.read_f32::<LittleEndian>()?;
        let num_blocks = r.read_u32::<LittleEndian>()? as usize;

        let mut entries = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            entries.push(SparseSkipEntry::read(r)?);
        }

        // doc_count + global max weight + block count, then the skip table.
        let header_size = 4 + 4 + 4 + num_blocks * SparseSkipEntry::SIZE;

        Ok((doc_count, global_max_weight, entries, header_size))
    }

    pub fn decode_all(&self) -> Vec<(DocId, u16, f32)> {
        let total_postings: usize = self.blocks.iter().map(|b| b.header.count as usize).sum();
        let mut result = Vec::with_capacity(total_postings);
        for block in &self.blocks {
            let doc_ids = block.decode_doc_ids();
            let ordinals = block.decode_ordinals();
            let weights = block.decode_weights();
            for i in 0..block.header.count as usize {
                result.push((doc_ids[i], ordinals[i], weights[i]));
            }
        }
        result
    }

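    /// Concatenates posting lists from consecutive segments, shifting each
    /// list's doc ids by its segment's base offset. Blocks are reused as-is:
    /// a uniform shift only changes `first_doc_id` in the block header, never
    /// the packed deltas.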
    pub fn merge_with_offsets(lists: &[(&BlockSparsePostingList, u32)]) -> Self {
        if lists.is_empty() {
            return Self {
                doc_count: 0,
                blocks: Vec::new(),
            };
        }

        let total_blocks: usize = lists.iter().map(|(pl, _)| pl.blocks.len()).sum();
        let total_docs: u32 = lists.iter().map(|(pl, _)| pl.doc_count).sum();

        let mut merged_blocks = Vec::with_capacity(total_blocks);

        for (posting_list, doc_offset) in lists {
            for block in &posting_list.blocks {
                merged_blocks.push(block.with_doc_offset(*doc_offset));
            }
        }

        Self {
            doc_count: total_docs,
            blocks: merged_blocks,
        }
    }

    /// Returns the index of the last block whose `first_doc_id` is <= `target`
    /// (the only block that can contain it), or block 0 when `target` precedes
    /// the first block.
    fn find_block(&self, target: DocId) -> Option<usize> {
        if self.blocks.is_empty() {
            return None;
        }
        let idx = self
            .blocks
            .partition_point(|b| b.header.first_doc_id <= target);
        if idx == 0 {
            Some(0)
        } else {
            Some(idx - 1)
        }
    }
}

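/// Cursor over a `BlockSparsePostingList`. Doc ids and weights of the current
/// block are decoded when the block is loaded; ordinals are decoded lazily on
/// first access via `ordinal()`.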
pub struct BlockSparsePostingIterator<'a> {
    posting_list: &'a BlockSparsePostingList,
    block_idx: usize,
    in_block_idx: usize,
    current_doc_ids: Vec<DocId>,
    current_ordinals: Vec<u16>,
    current_weights: Vec<f32>,
    /// Ordinals are decoded lazily; tracks whether `current_ordinals` matches
    /// the current block.
    ordinals_decoded: bool,
    exhausted: bool,
}

impl<'a> BlockSparsePostingIterator<'a> {
    fn new(posting_list: &'a BlockSparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            block_idx: 0,
            in_block_idx: 0,
            current_doc_ids: Vec::with_capacity(BLOCK_SIZE),
            current_ordinals: Vec::with_capacity(BLOCK_SIZE),
            current_weights: Vec::with_capacity(BLOCK_SIZE),
            ordinals_decoded: false,
            exhausted: posting_list.blocks.is_empty(),
        };
        if !iter.exhausted {
            iter.load_block(0);
        }
        iter
    }

    fn load_block(&mut self, block_idx: usize) {
        if let Some(block) = self.posting_list.blocks.get(block_idx) {
            block.decode_doc_ids_into(&mut self.current_doc_ids);
            block.decode_weights_into(&mut self.current_weights);
            self.ordinals_decoded = false;
            self.block_idx = block_idx;
            self.in_block_idx = 0;
        }
    }

    #[inline]
    fn ensure_ordinals_decoded(&mut self) {
        if !self.ordinals_decoded {
            if let Some(block) = self.posting_list.blocks.get(self.block_idx) {
                block.decode_ordinals_into(&mut self.current_ordinals);
            }
            self.ordinals_decoded = true;
        }
    }

    #[inline]
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            TERMINATED
        } else {
            self.current_doc_ids[self.in_block_idx]
        }
    }

    #[inline]
    pub fn weight(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.current_weights[self.in_block_idx]
    }

    pub fn ordinal(&mut self) -> u16 {
        self.ensure_ordinals_decoded();
        self.current_ordinals
            .get(self.in_block_idx)
            .copied()
            .unwrap_or(0)
    }

    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        self.in_block_idx += 1;
        if self.in_block_idx >= self.current_doc_ids.len() {
            self.block_idx += 1;
            if self.block_idx >= self.posting_list.blocks.len() {
                self.exhausted = true;
            } else {
                self.load_block(self.block_idx);
            }
        }
        self.doc()
    }

    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        if self.doc() >= target {
            return self.doc();
        }

        // Fast path: the target is still inside the current block.
        if let Some(&last_doc) = self.current_doc_ids.last()
            && last_doc >= target
        {
            let remaining = &self.current_doc_ids[self.in_block_idx..];
            let pos = crate::structures::simd::find_first_ge_u32(remaining, target);
            self.in_block_idx += pos;
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
            return self.doc();
        }

        // Slow path: binary-search the block headers, then scan within the block.
        if let Some(block_idx) = self.posting_list.find_block(target) {
            self.load_block(block_idx);
            let pos = crate::structures::simd::find_first_ge_u32(&self.current_doc_ids, target);
            self.in_block_idx = pos;
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
        } else {
            self.exhausted = true;
        }
        self.doc()
    }

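    /// Jumps directly to the first posting of the next block, exhausting the
    /// iterator if the current block is the last one. Combined with
    /// `current_block_max_contribution`, this lets a scorer skip blocks that
    /// cannot beat the current threshold.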
    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        let next = self.block_idx + 1;
        if next >= self.posting_list.blocks.len() {
            self.exhausted = true;
            return TERMINATED;
        }
        self.load_block(next);
        self.doc()
    }

    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }

    pub fn current_block_max_weight(&self) -> f32 {
        self.posting_list
            .blocks
            .get(self.block_idx)
            .map(|b| b.header.max_weight)
            .unwrap_or(0.0)
    }

    /// Upper bound on the current block's contribution to a score for a query
    /// term with weight `query_weight`.
    pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.current_block_max_weight()
    }
}


/// Smallest bit width able to represent every value in `values`.
fn find_optimal_bit_width(values: &[u32]) -> u8 {
    if values.is_empty() {
        return 0;
    }
    let max_val = values.iter().copied().max().unwrap_or(0);
    simd::bits_needed(max_val)
}

fn bits_needed_u16(val: u16) -> u8 {
    if val == 0 {
        0
    } else {
        16 - val.leading_zeros() as u8
    }
}

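/// Encodes weights under the given quantization scheme. The quantized schemes
/// (`UInt8`, `UInt4`) store two f32 constants (`scale`, `min`) followed by the
/// quantized values; `UInt4` packs two values per byte, low nibble first.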
fn encode_weights(weights: &[f32], quant: WeightQuantization) -> io::Result<Vec<u8>> {
    let mut data = Vec::new();
    match quant {
        WeightQuantization::Float32 => {
            for &w in weights {
                data.write_f32::<LittleEndian>(w)?;
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for &w in weights {
                data.write_u16::<LittleEndian>(f16::from_f32(w).to_bits())?;
            }
        }
        WeightQuantization::UInt8 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 255.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            for &w in weights {
                data.write_u8(((w - min) / scale).round() as u8)?;
            }
        }
        WeightQuantization::UInt4 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 15.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            // Pack two 4-bit values per byte, low nibble first.
            let mut i = 0;
            while i < weights.len() {
                let q1 = ((weights[i] - min) / scale).round() as u8 & 0x0F;
                let q2 = if i + 1 < weights.len() {
                    ((weights[i + 1] - min) / scale).round() as u8 & 0x0F
                } else {
                    0
                };
                data.write_u8((q2 << 4) | q1)?;
                i += 2;
            }
        }
    }
    Ok(data)
}

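/// Inverse of `encode_weights`: decodes exactly `count` weights from `data`
/// into `out` according to `quant`.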
fn decode_weights_into(data: &[u8], quant: WeightQuantization, count: usize, out: &mut Vec<f32>) {
    let mut cursor = Cursor::new(data);
    match quant {
        WeightQuantization::Float32 => {
            for _ in 0..count {
                out.push(cursor.read_f32::<LittleEndian>().unwrap_or(0.0));
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for _ in 0..count {
                let bits = cursor.read_u16::<LittleEndian>().unwrap_or(0);
                out.push(f16::from_bits(bits).to_f32());
            }
        }
        WeightQuantization::UInt8 => {
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min_val = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            let offset = cursor.position() as usize;
            out.resize(count, 0.0);
            simd::dequantize_uint8(&data[offset..], out, scale, min_val, count);
        }
        WeightQuantization::UInt4 => {
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            let mut i = 0;
            while i < count {
                let byte = cursor.read_u8().unwrap_or(0);
                out.push((byte & 0x0F) as f32 * scale + min);
                i += 1;
                if i < count {
                    out.push((byte >> 4) as f32 * scale + min);
                    i += 1;
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_block_roundtrip() {
        let postings = vec![
            (10u32, 0u16, 1.5f32),
            (15, 0, 2.0),
            (20, 1, 0.5),
            (100, 0, 3.0),
        ];
        let block = SparseBlock::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(block.decode_doc_ids(), vec![10, 15, 20, 100]);
        assert_eq!(block.decode_ordinals(), vec![0, 0, 1, 0]);
        let weights = block.decode_weights();
        assert!((weights[0] - 1.5).abs() < 0.01);
    }

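    // Sketch of scored decoding on a Float32 block: `decode_scored_weights_into`
    // should multiply every stored weight by the query weight.
    #[test]
    fn test_decode_scored_weights_float32() {
        let postings = vec![(1u32, 0u16, 1.0f32), (2, 0, 2.0)];
        let block = SparseBlock::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut out = Vec::new();
        block.decode_scored_weights_into(3.0, &mut out);
        assert!((out[0] - 3.0).abs() < 0.01);
        assert!((out[1] - 6.0).abs() < 0.01);
    }
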
    #[test]
    fn test_posting_list() {
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, i as f32 * 0.1)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 300);
        assert_eq!(list.num_blocks(), 3);

        let mut iter = list.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 2);
    }

    #[test]
    fn test_serialization() {
        let postings = vec![(1u32, 0u16, 0.5f32), (10, 1, 1.5), (100, 0, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let mut buf = Vec::new();
        list.serialize(&mut buf).unwrap();
        let list2 = BlockSparsePostingList::deserialize(&mut Cursor::new(&buf)).unwrap();

        assert_eq!(list.doc_count(), list2.doc_count());
    }

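    // A block written via `SparseBlock::write` should reopen through the
    // zero-copy `from_owned_bytes` path with identical contents.
    #[test]
    fn test_block_from_owned_bytes_roundtrip() {
        let postings = vec![(10u32, 0u16, 1.5f32), (15, 1, 2.0), (20, 0, 0.5)];
        let block = SparseBlock::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut buf = Vec::new();
        block.write(&mut buf).unwrap();

        let reopened = SparseBlock::from_owned_bytes(OwnedBytes::new(buf)).unwrap();
        assert_eq!(reopened.decode_doc_ids(), vec![10, 15, 20]);
        assert_eq!(reopened.decode_ordinals(), vec![0, 1, 0]);
    }
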
    #[test]
    fn test_seek() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = list.iterator();
        assert_eq!(iter.seek(300), 300);
        assert_eq!(iter.seek(301), 303);
        assert_eq!(iter.seek(2000), TERMINATED);
    }

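    // With BLOCK_SIZE = 128, postings (i * 3, _, i as f32) for i in 0..500 put
    // doc 384 (i = 128) at the start of the second block, whose max weight is
    // 255.0 (i = 255).
    #[test]
    fn test_skip_to_next_block() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = list.iterator();
        assert_eq!(iter.doc(), 0);
        assert!((iter.current_block_max_weight() - 127.0).abs() < 0.01);

        assert_eq!(iter.skip_to_next_block(), 384);
        assert!((iter.current_block_max_weight() - 255.0).abs() < 0.01);
    }
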
    #[test]
    fn test_merge_with_offsets() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert_eq!(merged.doc_count(), 6);

        let decoded = merged.decode_all();
        assert_eq!(decoded.len(), 6);

        // Segment 1 doc ids are unchanged.
        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        // Segment 2 doc ids are shifted by the offset of 100.
        assert_eq!(decoded[3].0, 100);
        assert_eq!(decoded[4].0, 103);
        assert_eq!(decoded[5].0, 107);

        // Weights and ordinals survive the merge.
        assert!((decoded[0].2 - 1.0).abs() < 0.01);
        assert!((decoded[3].2 - 4.0).abs() < 0.01);

        assert_eq!(decoded[2].1, 1);
        assert_eq!(decoded[4].1, 1);
    }

    #[test]
    fn test_merge_with_offsets_multi_block() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1, "Should have multiple blocks");

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        assert_eq!(merged.doc_count(), 350);
        assert_eq!(merged.num_blocks(), list1.num_blocks() + list2.num_blocks());

        let mut iter = merged.iterator();

        assert_eq!(iter.doc(), 0);

        // Seek across the segment boundary into the offset segment.
        let doc = iter.seek(1000);
        assert_eq!(doc, 1000);
        iter.advance();
        assert_eq!(iter.doc(), 1003);
    }

    #[test]
    fn test_merge_with_offsets_serialize_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();

        let mut cursor = std::io::Cursor::new(&bytes);
        let loaded = BlockSparsePostingList::deserialize(&mut cursor).unwrap();

        let decoded = loaded.decode_all();
        assert_eq!(decoded.len(), 6);

        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        assert_eq!(decoded[3].0, 100, "First doc of seg2 should be 0+100=100");
        assert_eq!(decoded[4].0, 103, "Second doc of seg2 should be 3+100=103");
        assert_eq!(decoded[5].0, 107, "Third doc of seg2 should be 7+100=107");

        let mut iter = loaded.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 5);
        iter.advance();
        assert_eq!(iter.doc(), 10);
        iter.advance();
        assert_eq!(iter.doc(), 100);
        iter.advance();
        assert_eq!(iter.doc(), 103);
        iter.advance();
        assert_eq!(iter.doc(), 107);
    }

    #[test]
    fn test_merge_seek_after_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, 1.0)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 0, 2.0)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        let mut iter = loaded.iterator();

        let doc = iter.seek(100);
        assert_eq!(doc, 100, "Seek to 100 in segment 1");

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "Seek to 1000 (first doc of segment 2)");

        let doc = iter.seek(1050);
        assert!(
            doc >= 1050,
            "Seek to 1050 should find doc >= 1050, got {}",
            doc
        );

        let doc = iter.seek(500);
        assert!(
            doc >= 1050,
            "Seek backwards should not go back, got {}",
            doc
        );

        // Full scan: doc ids must stay strictly increasing across segments.
        let mut iter2 = loaded.iterator();

        let mut count = 0;
        let mut prev_doc = 0;
        while iter2.doc() != super::TERMINATED {
            let current = iter2.doc();
            if count > 0 {
                assert!(
                    current > prev_doc,
                    "Docs should be monotonically increasing: {} vs {}",
                    prev_doc,
                    current
                );
            }
            prev_doc = current;
            iter2.advance();
            count += 1;
        }
        assert_eq!(count, 350, "Should have 350 total docs");
    }

    #[test]
    fn test_doc_count_multi_value() {
        // Three docs, some with multiple ordinals per doc.
        let postings: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 1.0),
            (0, 1, 1.5),
            (0, 2, 2.0),
            (5, 0, 3.0),
            (5, 1, 3.5),
            (10, 0, 4.0),
        ];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        // doc_count counts unique documents, not postings.
        assert_eq!(list.doc_count(), 3);

        let decoded = list.decode_all();
        assert_eq!(decoded.len(), 6);
    }

    // Zero-copy merge: instead of decoding blocks, adjust the skip entries and
    // patch each block's first_doc_id field in place, then verify the result
    // matches the reference `merge_with_offsets` implementation.
    #[test]
    fn test_zero_copy_merge_patches_first_doc_id() {
        use crate::structures::SparseSkipEntry;

        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1);

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let mut bytes1 = Vec::new();
        list1.serialize(&mut bytes1).unwrap();
        let mut bytes2 = Vec::new();
        list2.serialize(&mut bytes2).unwrap();

        // Split a serialized list into (doc_count, global_max, skip entries, block bytes).
        fn parse_raw(data: &[u8]) -> (u32, f32, Vec<SparseSkipEntry>, &[u8]) {
            let doc_count = u32::from_le_bytes(data[0..4].try_into().unwrap());
            let global_max = f32::from_le_bytes(data[4..8].try_into().unwrap());
            let num_blocks = u32::from_le_bytes(data[8..12].try_into().unwrap()) as usize;
            let mut pos = 12;
            let mut skip = Vec::new();
            for _ in 0..num_blocks {
                let first_doc = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
                let last_doc = u32::from_le_bytes(data[pos + 4..pos + 8].try_into().unwrap());
                let offset = u32::from_le_bytes(data[pos + 8..pos + 12].try_into().unwrap());
                let length = u32::from_le_bytes(data[pos + 12..pos + 16].try_into().unwrap());
                let max_w = f32::from_le_bytes(data[pos + 16..pos + 20].try_into().unwrap());
                skip.push(SparseSkipEntry::new(
                    first_doc, last_doc, offset, length, max_w,
                ));
                pos += 20;
            }
            (doc_count, global_max, skip, &data[pos..])
        }

        let (dc1, gm1, skip1, raw1) = parse_raw(&bytes1);
        let (dc2, gm2, skip2, raw2) = parse_raw(&bytes2);

        let doc_offset: u32 = 1000;
        let total_docs = dc1 + dc2;
        let global_max = gm1.max(gm2);
        let total_blocks = (skip1.len() + skip2.len()) as u32;

        let mut output = Vec::new();
        output.extend_from_slice(&total_docs.to_le_bytes());
        output.extend_from_slice(&global_max.to_le_bytes());
        output.extend_from_slice(&total_blocks.to_le_bytes());

        // Rewrite the skip table: shift segment 2's doc ids and block offsets.
        let mut block_data_offset = 0u32;
        for entry in &skip1 {
            let adjusted = SparseSkipEntry::new(
                entry.first_doc,
                entry.last_doc,
                block_data_offset + entry.offset,
                entry.length,
                entry.max_weight,
            );
            adjusted.write(&mut output).unwrap();
        }
        if let Some(last) = skip1.last() {
            block_data_offset += last.offset + last.length;
        }
        for entry in &skip2 {
            let adjusted = SparseSkipEntry::new(
                entry.first_doc + doc_offset,
                entry.last_doc + doc_offset,
                block_data_offset + entry.offset,
                entry.length,
                entry.max_weight,
            );
            adjusted.write(&mut output).unwrap();
        }

        output.extend_from_slice(raw1);

        // Patch first_doc_id (byte 8 of each block header) in segment 2's blocks.
        const FIRST_DOC_ID_OFFSET: usize = 8;
        let mut buf2 = raw2.to_vec();
        for entry in &skip2 {
            let off = entry.offset as usize + FIRST_DOC_ID_OFFSET;
            if off + 4 <= buf2.len() {
                let old = u32::from_le_bytes(buf2[off..off + 4].try_into().unwrap());
                let patched = (old + doc_offset).to_le_bytes();
                buf2[off..off + 4].copy_from_slice(&patched);
            }
        }
        output.extend_from_slice(&buf2);

        let loaded = BlockSparsePostingList::deserialize(&mut Cursor::new(&output)).unwrap();
        assert_eq!(loaded.doc_count(), 350);

        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        let doc = iter.seek(100);
        assert_eq!(doc, 100);
        let doc = iter.seek(398);
        assert_eq!(doc, 398);

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "First doc of segment 2 should be 1000");
        iter.advance();
        assert_eq!(iter.doc(), 1003, "Second doc of segment 2 should be 1003");
        let doc = iter.seek(1447);
        assert_eq!(doc, 1447, "Last doc of segment 2 should be 1447");

        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);

        // Cross-check against the decode-and-rebuild merge path.
        let reference =
            BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, doc_offset)]);
        let mut ref_iter = reference.iterator();
        let mut zc_iter = loaded.iterator();
        while ref_iter.doc() != super::TERMINATED {
            assert_eq!(
                ref_iter.doc(),
                zc_iter.doc(),
                "Zero-copy and reference merge should produce identical doc_ids"
            );
            assert!(
                (ref_iter.weight() - zc_iter.weight()).abs() < 0.01,
                "Weights should match: {} vs {}",
                ref_iter.weight(),
                zc_iter.weight()
            );
            ref_iter.advance();
            zc_iter.advance();
        }
        assert_eq!(zc_iter.doc(), super::TERMINATED);
    }

    #[test]
    fn test_doc_count_single_value() {
        let postings: Vec<(DocId, u16, f32)> =
            vec![(0, 0, 1.0), (5, 0, 2.0), (10, 0, 3.0), (15, 0, 4.0)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 4);
    }

    #[test]
    fn test_doc_count_multi_value_serialization_roundtrip() {
        let postings: Vec<(DocId, u16, f32)> =
            vec![(0, 0, 1.0), (0, 1, 1.5), (5, 0, 2.0), (5, 1, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
        assert_eq!(list.doc_count(), 2);

        let mut buf = Vec::new();
        list.serialize(&mut buf).unwrap();
        let loaded = BlockSparsePostingList::deserialize(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(loaded.doc_count(), 2);
    }

    #[test]
    fn test_merge_preserves_weights_and_ordinals() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.5), (5, 1, 2.5), (10, 2, 3.5)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.5), (3, 1, 5.5), (7, 3, 6.5)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        assert!(
            (iter.weight() - 1.5).abs() < 0.01,
            "Weight should be 1.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!(
            (iter.weight() - 2.5).abs() < 0.01,
            "Weight should be 2.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 10);
        assert!(
            (iter.weight() - 3.5).abs() < 0.01,
            "Weight should be 3.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 2);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!(
            (iter.weight() - 4.5).abs() < 0.01,
            "Weight should be 4.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 103);
        assert!(
            (iter.weight() - 5.5).abs() < 0.01,
            "Weight should be 5.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 107);
        assert!(
            (iter.weight() - 6.5).abs() < 0.01,
            "Weight should be 6.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 3);

        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);
    }

    #[test]
    fn test_merge_global_max_weight() {
        let postings1: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 3.0),
            (1, 0, 7.0), // segment 1 max
            (2, 0, 2.0),
        ];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 5.0),
            (1, 0, 4.0),
            (2, 0, 6.0), // segment 2 max
        ];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        assert!((list1.global_max_weight() - 7.0).abs() < 0.01);
        assert!((list2.global_max_weight() - 6.0).abs() < 0.01);

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert!(
            (merged.global_max_weight() - 7.0).abs() < 0.01,
            "Global max should be 7.0, got {}",
            merged.global_max_weight()
        );

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        assert!(
            (loaded.global_max_weight() - 7.0).abs() < 0.01,
            "After roundtrip, global max should still be 7.0, got {}",
            loaded.global_max_weight()
        );
    }

    #[test]
    fn test_scoring_simulation_after_merge() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 0.5), (5, 0, 0.8)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 0.6), (3, 0, 0.9)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        // Simulate scoring: score = query_weight * stored weight.
        let query_weight = 2.0f32;
        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.0).abs() < 0.01,
            "Doc 0 score should be 1.0, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 5);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.6).abs() < 0.01,
            "Doc 5 score should be 1.6, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 100);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.2).abs() < 0.01,
            "Doc 100 score should be 1.2, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 103);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.8).abs() < 0.01,
            "Doc 103 score should be 1.8, got {}",
            score
        );
    }
}