use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Cursor, Read, Write};

use super::config::WeightQuantization;
use crate::DocId;
use crate::directories::OwnedBytes;
use crate::structures::postings::TERMINATED;
use crate::structures::simd;

pub const BLOCK_SIZE: usize = 128;

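/// Per-block metadata stored as a fixed 16-byte record:
/// count (u16), doc_id_bits (u8), ordinal_bits (u8), weight_quant (u8),
/// three padding bytes, first_doc_id (u32), max_weight (f32).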
#[derive(Debug, Clone, Copy)]
pub struct BlockHeader {
    pub count: u16,
    pub doc_id_bits: u8,
    pub ordinal_bits: u8,
    pub weight_quant: WeightQuantization,
    pub first_doc_id: DocId,
    pub max_weight: f32,
}

impl BlockHeader {
    pub const SIZE: usize = 16;

    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        w.write_u16::<LittleEndian>(self.count)?;
        w.write_u8(self.doc_id_bits)?;
        w.write_u8(self.ordinal_bits)?;
        w.write_u8(self.weight_quant as u8)?;
        w.write_u8(0)?; // padding
        w.write_u16::<LittleEndian>(0)?; // padding
        w.write_u32::<LittleEndian>(self.first_doc_id)?;
        w.write_f32::<LittleEndian>(self.max_weight)?;
        Ok(())
    }

    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let count = r.read_u16::<LittleEndian>()?;
        let doc_id_bits = r.read_u8()?;
        let ordinal_bits = r.read_u8()?;
        let weight_quant_byte = r.read_u8()?;
        let _ = r.read_u8()?; // padding
        let _ = r.read_u16::<LittleEndian>()?; // padding
        let first_doc_id = r.read_u32::<LittleEndian>()?;
        let max_weight = r.read_f32::<LittleEndian>()?;

        let weight_quant = WeightQuantization::from_u8(weight_quant_byte)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid weight quant"))?;

        Ok(Self {
            count,
            doc_id_bits,
            ordinal_bits,
            weight_quant,
            first_doc_id,
            max_weight,
        })
    }
}

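/// One compressed block of up to [`BLOCK_SIZE`] postings: bit-packed
/// doc-id deltas, bit-packed ordinals, and quantized weights.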
#[derive(Debug, Clone)]
pub struct SparseBlock {
    pub header: BlockHeader,
    /// Bit-packed doc-id deltas (the first doc id lives in the header).
    pub doc_ids_data: OwnedBytes,
    /// Bit-packed ordinals; empty when all ordinals are zero.
    pub ordinals_data: OwnedBytes,
    /// Weights encoded per `header.weight_quant`.
    pub weights_data: OwnedBytes,
}

impl SparseBlock {
    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        assert!(!postings.is_empty() && postings.len() <= BLOCK_SIZE);

        let count = postings.len();
        let first_doc_id = postings[0].0;

        // Delta-encode doc ids against their predecessor.
        let mut deltas = Vec::with_capacity(count);
        let mut prev = first_doc_id;
        for &(doc_id, _, _) in postings {
            deltas.push(doc_id.saturating_sub(prev));
            prev = doc_id;
        }
        // The first entry is stored in the header, not in the packed data.
        deltas[0] = 0;

        let doc_id_bits = simd::round_bit_width(find_optimal_bit_width(&deltas[1..]));
        let ordinals: Vec<u16> = postings.iter().map(|(_, o, _)| *o).collect();
        let max_ordinal = ordinals.iter().copied().max().unwrap_or(0);
        let ordinal_bits = if max_ordinal == 0 {
            0
        } else {
            simd::round_bit_width(bits_needed_u16(max_ordinal))
        };

        let weights: Vec<f32> = postings.iter().map(|(_, _, w)| *w).collect();
        let max_weight = weights.iter().copied().fold(0.0f32, f32::max);

        let doc_ids_data = OwnedBytes::new({
            let rounded = simd::RoundedBitWidth::from_u8(doc_id_bits);
            let num_deltas = count - 1;
            let byte_count = num_deltas * rounded.bytes_per_value();
            let mut data = vec![0u8; byte_count];
            simd::pack_rounded(&deltas[1..], rounded, &mut data);
            data
        });
        let ordinals_data = OwnedBytes::new(if ordinal_bits > 0 {
            let rounded = simd::RoundedBitWidth::from_u8(ordinal_bits);
            let byte_count = count * rounded.bytes_per_value();
            let mut data = vec![0u8; byte_count];
            let ord_u32: Vec<u32> = ordinals.iter().map(|&o| o as u32).collect();
            simd::pack_rounded(&ord_u32, rounded, &mut data);
            data
        } else {
            Vec::new()
        });
        let weights_data = OwnedBytes::new(encode_weights(&weights, weight_quant)?);

        Ok(Self {
            header: BlockHeader {
                count: count as u16,
                doc_id_bits,
                ordinal_bits,
                weight_quant,
                first_doc_id,
                max_weight,
            },
            doc_ids_data,
            ordinals_data,
            weights_data,
        })
    }

    pub fn decode_doc_ids(&self) -> Vec<DocId> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_doc_ids_into(&mut out);
        out
    }

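    /// Decodes absolute doc ids into `out`: the first id comes from the
    /// header, the rest are unpacked deltas restored by a prefix sum.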
    pub fn decode_doc_ids_into(&self, out: &mut Vec<DocId>) {
        let count = self.header.count as usize;
        out.clear();
        out.resize(count, 0);
        out[0] = self.header.first_doc_id;

        if count > 1 {
            let bits = self.header.doc_id_bits;
            if bits == 0 {
                // All deltas are zero: every posting shares the first doc id.
                out[1..].fill(self.header.first_doc_id);
            } else {
                simd::unpack_rounded(
                    &self.doc_ids_data,
                    simd::RoundedBitWidth::from_u8(bits),
                    &mut out[1..],
                    count - 1,
                );
                // Prefix-sum the deltas back into absolute doc ids.
                for i in 1..count {
                    out[i] += out[i - 1];
                }
            }
        }
    }

    pub fn decode_ordinals(&self) -> Vec<u16> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_ordinals_into(&mut out);
        out
    }

    pub fn decode_ordinals_into(&self, out: &mut Vec<u16>) {
        let count = self.header.count as usize;
        out.clear();
        if self.header.ordinal_bits == 0 {
            out.resize(count, 0u16);
        } else {
            // Unpack into a u32 scratch buffer, then narrow to u16.
            let mut temp = [0u32; BLOCK_SIZE];
            simd::unpack_rounded(
                &self.ordinals_data,
                simd::RoundedBitWidth::from_u8(self.header.ordinal_bits),
                &mut temp[..count],
                count,
            );
            out.reserve(count);
            for &v in &temp[..count] {
                out.push(v as u16);
            }
        }
    }

    pub fn decode_weights(&self) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_weights_into(&mut out);
        out
    }

    pub fn decode_weights_into(&self, out: &mut Vec<f32>) {
        out.clear();
        decode_weights_into(
            &self.weights_data,
            self.header.weight_quant,
            self.header.count as usize,
            out,
        );
    }

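    /// Decodes weights already multiplied by `query_weight`. For `UInt8`
    /// blocks the query weight is folded into the dequantization scale and
    /// bias, so scoring needs no separate multiply pass.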
    pub fn decode_scored_weights_into(&self, query_weight: f32, out: &mut Vec<f32>) {
        out.clear();
        let count = self.header.count as usize;
        match self.header.weight_quant {
            WeightQuantization::UInt8 if self.weights_data.len() >= 8 => {
                // Read the scale/min prefix written by `encode_weights`.
                let scale = f32::from_le_bytes([
                    self.weights_data[0],
                    self.weights_data[1],
                    self.weights_data[2],
                    self.weights_data[3],
                ]);
                let min_val = f32::from_le_bytes([
                    self.weights_data[4],
                    self.weights_data[5],
                    self.weights_data[6],
                    self.weights_data[7],
                ]);
                // Fold the query weight into the dequantization scale and bias.
                let eff_scale = query_weight * scale;
                let eff_bias = query_weight * min_val;
                out.resize(count, 0.0);
                simd::dequantize_uint8(&self.weights_data[8..], out, eff_scale, eff_bias, count);
            }
            _ => {
                decode_weights_into(&self.weights_data, self.header.weight_quant, count, out);
                for w in out.iter_mut() {
                    *w *= query_weight;
                }
            }
        }
    }

    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        self.header.write(w)?;
        w.write_u16::<LittleEndian>(self.doc_ids_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.ordinals_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.weights_data.len() as u16)?;
        w.write_u16::<LittleEndian>(0)?;
        w.write_all(&self.doc_ids_data)?;
        w.write_all(&self.ordinals_data)?;
        w.write_all(&self.weights_data)?;
        Ok(())
    }

    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let header = BlockHeader::read(r)?;
        let doc_ids_len = r.read_u16::<LittleEndian>()? as usize;
        let ordinals_len = r.read_u16::<LittleEndian>()? as usize;
        let weights_len = r.read_u16::<LittleEndian>()? as usize;
        let _ = r.read_u16::<LittleEndian>()?;

        let mut doc_ids_vec = vec![0u8; doc_ids_len];
        r.read_exact(&mut doc_ids_vec)?;
        let mut ordinals_vec = vec![0u8; ordinals_len];
        r.read_exact(&mut ordinals_vec)?;
        let mut weights_vec = vec![0u8; weights_len];
        r.read_exact(&mut weights_vec)?;

        Ok(Self {
            header,
            doc_ids_data: OwnedBytes::new(doc_ids_vec),
            ordinals_data: OwnedBytes::new(ordinals_vec),
            weights_data: OwnedBytes::new(weights_vec),
        })
    }

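    /// Zero-copy parse: validates the header and sub-block lengths, then
    /// slices `data` in place instead of copying the three payloads.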
    pub fn from_owned_bytes(data: crate::directories::OwnedBytes) -> crate::Result<Self> {
        let b = data.as_slice();
        if b.len() < BlockHeader::SIZE + 8 {
            return Err(crate::Error::Corruption(
                "sparse block too small".to_string(),
            ));
        }
        let mut cursor = Cursor::new(&b[..BlockHeader::SIZE]);
        let header =
            BlockHeader::read(&mut cursor).map_err(|e| crate::Error::Corruption(e.to_string()))?;

        if header.count == 0 {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block has count=0 (data_len={}, first_32_bytes=[{}])",
                b.len(),
                hex
            )));
        }

        let p = BlockHeader::SIZE;
        let doc_ids_len = u16::from_le_bytes([b[p], b[p + 1]]) as usize;
        let ordinals_len = u16::from_le_bytes([b[p + 2], b[p + 3]]) as usize;
        let weights_len = u16::from_le_bytes([b[p + 4], b[p + 5]]) as usize;
        // Three u16 lengths plus one u16 of padding precede the payloads.
        let data_start = p + 8;
        let ord_start = data_start + doc_ids_len;
        let wt_start = ord_start + ordinals_len;
        let expected_end = wt_start + weights_len;

        if expected_end > b.len() {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block sub-block overflow: count={} doc_ids={}B ords={}B wts={}B need={}B have={}B (first_32=[{}])",
                header.count,
                doc_ids_len,
                ordinals_len,
                weights_len,
                expected_end,
                b.len(),
                hex
            )));
        }

        Ok(Self {
            header,
            doc_ids_data: data.slice(data_start..ord_start),
            ordinals_data: data.slice(ord_start..wt_start),
            weights_data: data.slice(wt_start..wt_start + weights_len),
        })
    }

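    /// Returns a copy of this block with `doc_offset` added to the header's
    /// `first_doc_id`. The packed deltas are relative, so only the header
    /// changes; the payload buffers are cloned unchanged.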
    pub fn with_doc_offset(&self, doc_offset: u32) -> Self {
        Self {
            header: BlockHeader {
                first_doc_id: self.header.first_doc_id + doc_offset,
                ..self.header
            },
            doc_ids_data: self.doc_ids_data.clone(),
            ordinals_data: self.ordinals_data.clone(),
            weights_data: self.weights_data.clone(),
        }
    }
}

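/// A block-compressed sparse posting list: `doc_count` unique documents
/// spread across fixed-size [`SparseBlock`]s, with per-block max weights
/// for block-skipping scorers.
///
/// A minimal usage sketch (illustrative only; paths and the surrounding
/// setup depend on where this module is mounted, so it is not compiled):
///
/// ```ignore
/// let postings = vec![(1u32, 0u16, 0.5f32), (4, 0, 1.25)];
/// let list = BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
/// let mut it = list.iterator();
/// while it.doc() != TERMINATED {
///     let score = 2.0 * it.weight(); // query weight times stored weight
///     it.advance();
/// }
/// ```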
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    pub doc_count: u32,
    pub blocks: Vec<SparseBlock>,
}

impl BlockSparsePostingList {
    pub fn from_postings_with_block_size(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
        block_size: usize,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                doc_count: 0,
                blocks: Vec::new(),
            });
        }

        let block_size = block_size.max(16);
        let mut blocks = Vec::new();
        for chunk in postings.chunks(block_size) {
            blocks.push(SparseBlock::from_postings(chunk, weight_quant)?);
        }

        // Postings may repeat a doc id with different ordinals; count
        // unique documents, not entries.
        let mut unique_docs = 1u32;
        for i in 1..postings.len() {
            if postings[i].0 != postings[i - 1].0 {
                unique_docs += 1;
            }
        }

        Ok(Self {
            doc_count: unique_docs,
            blocks,
        })
    }

    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        Self::from_postings_with_block_size(postings, weight_quant, BLOCK_SIZE)
    }

    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    pub fn global_max_weight(&self) -> f32 {
        self.blocks
            .iter()
            .map(|b| b.header.max_weight)
            .fold(0.0f32, f32::max)
    }

    pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
        self.blocks.get(block_idx).map(|b| b.header.max_weight)
    }

    pub fn size_bytes(&self) -> usize {
        use std::mem::size_of;

        let header_size = size_of::<u32>() * 2;
        let blocks_size: usize = self
            .blocks
            .iter()
            .map(|b| {
                size_of::<BlockHeader>()
                    + b.doc_ids_data.len()
                    + b.ordinals_data.len()
                    + b.weights_data.len()
            })
            .sum();
        header_size + blocks_size
    }

    pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
        BlockSparsePostingIterator::new(self)
    }

    pub fn serialize<W: Write>(&self, w: &mut W) -> io::Result<()> {
        use super::SparseSkipEntry;

        w.write_u32::<LittleEndian>(self.doc_count)?;
        w.write_f32::<LittleEndian>(self.global_max_weight())?;
        w.write_u32::<LittleEndian>(self.blocks.len() as u32)?;

        // First pass: encode every block to know its byte length.
        let mut block_bytes: Vec<Vec<u8>> = Vec::with_capacity(self.blocks.len());
        for block in &self.blocks {
            let mut buf = Vec::new();
            block.write(&mut buf)?;
            block_bytes.push(buf);
        }

        // One skip entry per block, carrying doc range, offset, and max weight.
        let mut offset = 0u32;
        for (block, bytes) in self.blocks.iter().zip(block_bytes.iter()) {
            let first_doc = block.header.first_doc_id;
            let doc_ids = block.decode_doc_ids();
            let last_doc = doc_ids.last().copied().unwrap_or(first_doc);
            let length = bytes.len() as u32;

            let entry =
                SparseSkipEntry::new(first_doc, last_doc, offset, length, block.header.max_weight);
            entry.write(w)?;
            offset += length;
        }

        // Then the concatenated block payloads.
        for bytes in block_bytes {
            w.write_all(&bytes)?;
        }

        Ok(())
    }

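    /// Like [`serialize`](Self::serialize), but returns the concatenated
    /// block bytes and the skip entries separately, so the caller can store
    /// the skip table out of line.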
    pub fn serialize_v3(&self) -> io::Result<(Vec<u8>, Vec<super::SparseSkipEntry>)> {
        let mut block_data = Vec::new();
        let mut skip_entries = Vec::with_capacity(self.blocks.len());
        let mut offset = 0u32;

        for block in &self.blocks {
            let mut buf = Vec::new();
            block.write(&mut buf)?;
            let length = buf.len() as u32;

            let first_doc = block.header.first_doc_id;
            let doc_ids = block.decode_doc_ids();
            let last_doc = doc_ids.last().copied().unwrap_or(first_doc);

            skip_entries.push(super::SparseSkipEntry::new(
                first_doc,
                last_doc,
                offset,
                length,
                block.header.max_weight,
            ));

            block_data.extend_from_slice(&buf);
            offset += length;
        }

        Ok((block_data, skip_entries))
    }

    pub fn deserialize<R: Read>(r: &mut R) -> io::Result<Self> {
        use super::SparseSkipEntry;

        let doc_count = r.read_u32::<LittleEndian>()?;
        let _global_max_weight = r.read_f32::<LittleEndian>()?;
        let num_blocks = r.read_u32::<LittleEndian>()? as usize;

        // The skip table is redundant for the in-memory form; read and discard.
        for _ in 0..num_blocks {
            let _ = SparseSkipEntry::read(r)?;
        }

        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            blocks.push(SparseBlock::read(r)?);
        }
        Ok(Self { doc_count, blocks })
    }

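    /// Reads only the list header: doc count, global max weight, and the
    /// skip table. Also returns the header's size in bytes so callers can
    /// locate the block region that follows.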
    pub fn deserialize_header<R: Read>(
        r: &mut R,
    ) -> io::Result<(u32, f32, Vec<super::SparseSkipEntry>, usize)> {
        use super::SparseSkipEntry;

        let doc_count = r.read_u32::<LittleEndian>()?;
        let global_max_weight = r.read_f32::<LittleEndian>()?;
        let num_blocks = r.read_u32::<LittleEndian>()? as usize;

        let mut entries = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            entries.push(SparseSkipEntry::read(r)?);
        }

        let header_size = 4 + 4 + 4 + num_blocks * SparseSkipEntry::SIZE;

        Ok((doc_count, global_max_weight, entries, header_size))
    }

    pub fn decode_all(&self) -> Vec<(DocId, u16, f32)> {
        let total_postings: usize = self.blocks.iter().map(|b| b.header.count as usize).sum();
        let mut result = Vec::with_capacity(total_postings);
        for block in &self.blocks {
            let doc_ids = block.decode_doc_ids();
            let ordinals = block.decode_ordinals();
            let weights = block.decode_weights();
            for i in 0..block.header.count as usize {
                result.push((doc_ids[i], ordinals[i], weights[i]));
            }
        }
        result
    }

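    /// Concatenates posting lists from multiple segments, shifting each
    /// list's doc ids by its segment's doc-id offset. Blocks are reused
    /// wholesale via [`SparseBlock::with_doc_offset`]; nothing is re-encoded.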
    pub fn merge_with_offsets(lists: &[(&BlockSparsePostingList, u32)]) -> Self {
        if lists.is_empty() {
            return Self {
                doc_count: 0,
                blocks: Vec::new(),
            };
        }

        let total_blocks: usize = lists.iter().map(|(pl, _)| pl.blocks.len()).sum();
        let total_docs: u32 = lists.iter().map(|(pl, _)| pl.doc_count).sum();

        let mut merged_blocks = Vec::with_capacity(total_blocks);

        for (posting_list, doc_offset) in lists {
            for block in &posting_list.blocks {
                merged_blocks.push(block.with_doc_offset(*doc_offset));
            }
        }

        Self {
            doc_count: total_docs,
            blocks: merged_blocks,
        }
    }

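    /// Finds the index of the last block whose `first_doc_id` is <= `target`,
    /// falling back to the first block when `target` precedes all blocks.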
    fn find_block(&self, target: DocId) -> Option<usize> {
        if self.blocks.is_empty() {
            return None;
        }
        // Index of the first block whose first_doc_id exceeds `target`;
        // the block before it is the candidate.
        let idx = self
            .blocks
            .partition_point(|b| b.header.first_doc_id <= target);
        if idx == 0 {
            Some(0)
        } else {
            Some(idx - 1)
        }
    }
}

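/// Cursor over a [`BlockSparsePostingList`] that decodes one block at a
/// time. Doc ids and weights are decoded eagerly per block; ordinals are
/// decoded lazily on first access.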
pub struct BlockSparsePostingIterator<'a> {
    posting_list: &'a BlockSparsePostingList,
    block_idx: usize,
    in_block_idx: usize,
    current_doc_ids: Vec<DocId>,
    current_ordinals: Vec<u16>,
    current_weights: Vec<f32>,
    /// Set lazily by `ensure_ordinals_decoded`.
    ordinals_decoded: bool,
    exhausted: bool,
}

impl<'a> BlockSparsePostingIterator<'a> {
    fn new(posting_list: &'a BlockSparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            block_idx: 0,
            in_block_idx: 0,
            current_doc_ids: Vec::with_capacity(128),
            current_ordinals: Vec::with_capacity(128),
            current_weights: Vec::with_capacity(128),
            ordinals_decoded: false,
            exhausted: posting_list.blocks.is_empty(),
        };
        if !iter.exhausted {
            iter.load_block(0);
        }
        iter
    }

    fn load_block(&mut self, block_idx: usize) {
        if let Some(block) = self.posting_list.blocks.get(block_idx) {
            block.decode_doc_ids_into(&mut self.current_doc_ids);
            block.decode_weights_into(&mut self.current_weights);
            self.ordinals_decoded = false;
            self.block_idx = block_idx;
            self.in_block_idx = 0;
        }
    }

    #[inline]
    fn ensure_ordinals_decoded(&mut self) {
        if !self.ordinals_decoded {
            if let Some(block) = self.posting_list.blocks.get(self.block_idx) {
                block.decode_ordinals_into(&mut self.current_ordinals);
            }
            self.ordinals_decoded = true;
        }
    }

    #[inline]
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            TERMINATED
        } else {
            self.current_doc_ids[self.in_block_idx]
        }
    }

    #[inline]
    pub fn weight(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.current_weights[self.in_block_idx]
    }

    #[inline]
    pub fn ordinal(&mut self) -> u16 {
        if self.exhausted {
            return 0;
        }
        self.ensure_ordinals_decoded();
        self.current_ordinals[self.in_block_idx]
    }

    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        self.in_block_idx += 1;
        if self.in_block_idx >= self.current_doc_ids.len() {
            self.block_idx += 1;
            if self.block_idx >= self.posting_list.blocks.len() {
                self.exhausted = true;
            } else {
                self.load_block(self.block_idx);
            }
        }
        self.doc()
    }

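    /// Advances to the first document >= `target`. Fast path: a SIMD scan
    /// within the current block when its last doc id already covers
    /// `target`; otherwise a binary search over block headers picks the
    /// candidate block before scanning.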
    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        if self.doc() >= target {
            return self.doc();
        }

        // Fast path: the target is still inside the current block.
        if let Some(&last_doc) = self.current_doc_ids.last()
            && last_doc >= target
        {
            let remaining = &self.current_doc_ids[self.in_block_idx..];
            let pos = crate::structures::simd::find_first_ge_u32(remaining, target);
            self.in_block_idx += pos;
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
            return self.doc();
        }

        // Slow path: binary-search the block headers, then scan that block.
        if let Some(block_idx) = self.posting_list.find_block(target) {
            self.load_block(block_idx);
            let pos = crate::structures::simd::find_first_ge_u32(&self.current_doc_ids, target);
            self.in_block_idx = pos;
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
        } else {
            self.exhausted = true;
        }
        self.doc()
    }

    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        let next = self.block_idx + 1;
        if next >= self.posting_list.blocks.len() {
            self.exhausted = true;
            return TERMINATED;
        }
        self.load_block(next);
        self.doc()
    }

    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }

    pub fn current_block_max_weight(&self) -> f32 {
        self.posting_list
            .blocks
            .get(self.block_idx)
            .map(|b| b.header.max_weight)
            .unwrap_or(0.0)
    }

    pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.current_block_max_weight()
    }
}

fn find_optimal_bit_width(values: &[u32]) -> u8 {
    if values.is_empty() {
        return 0;
    }
    let max_val = values.iter().copied().max().unwrap_or(0);
    simd::bits_needed(max_val)
}

fn bits_needed_u16(val: u16) -> u8 {
    if val == 0 {
        0
    } else {
        16 - val.leading_zeros() as u8
    }
}

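/// Encodes weights per the chosen quantization. `UInt8`/`UInt4` emit a
/// `scale`/`min` f32 pair followed by quantized values, where
/// `q = round((w - min) / scale)` and `scale = range / 255.0` (or `/ 15.0`).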
fn encode_weights(weights: &[f32], quant: WeightQuantization) -> io::Result<Vec<u8>> {
    let mut data = Vec::new();
    match quant {
        WeightQuantization::Float32 => {
            for &w in weights {
                data.write_f32::<LittleEndian>(w)?;
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for &w in weights {
                data.write_u16::<LittleEndian>(f16::from_f32(w).to_bits())?;
            }
        }
        WeightQuantization::UInt8 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 255.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            for &w in weights {
                data.write_u8(((w - min) / scale).round() as u8)?;
            }
        }
        WeightQuantization::UInt4 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 15.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            // Pack two 4-bit values per byte, low nibble first.
            let mut i = 0;
            while i < weights.len() {
                let q1 = ((weights[i] - min) / scale).round() as u8 & 0x0F;
                let q2 = if i + 1 < weights.len() {
                    ((weights[i + 1] - min) / scale).round() as u8 & 0x0F
                } else {
                    0
                };
                data.write_u8((q2 << 4) | q1)?;
                i += 2;
            }
        }
    }
    Ok(data)
}

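/// Inverse of [`encode_weights`]: reconstructs `count` f32 weights, reading
/// the `scale`/`min` prefix for the quantized formats (`w ≈ q * scale + min`).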
fn decode_weights_into(data: &[u8], quant: WeightQuantization, count: usize, out: &mut Vec<f32>) {
    let mut cursor = Cursor::new(data);
    match quant {
        WeightQuantization::Float32 => {
            for _ in 0..count {
                out.push(cursor.read_f32::<LittleEndian>().unwrap_or(0.0));
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for _ in 0..count {
                let bits = cursor.read_u16::<LittleEndian>().unwrap_or(0);
                out.push(f16::from_bits(bits).to_f32());
            }
        }
        WeightQuantization::UInt8 => {
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min_val = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            let offset = cursor.position() as usize;
            out.resize(count, 0.0);
            simd::dequantize_uint8(&data[offset..], out, scale, min_val, count);
        }
        WeightQuantization::UInt4 => {
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            let mut i = 0;
            while i < count {
                let byte = cursor.read_u8().unwrap_or(0);
                out.push((byte & 0x0F) as f32 * scale + min);
                i += 1;
                if i < count {
                    out.push((byte >> 4) as f32 * scale + min);
                    i += 1;
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_block_roundtrip() {
        let postings = vec![
            (10u32, 0u16, 1.5f32),
            (15, 0, 2.0),
            (20, 1, 0.5),
            (100, 0, 3.0),
        ];
        let block = SparseBlock::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(block.decode_doc_ids(), vec![10, 15, 20, 100]);
        assert_eq!(block.decode_ordinals(), vec![0, 0, 1, 0]);
        let weights = block.decode_weights();
        assert!((weights[0] - 1.5).abs() < 0.01);
    }

    #[test]
    fn test_posting_list() {
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, i as f32 * 0.1)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 300);
        assert_eq!(list.num_blocks(), 3);

        let mut iter = list.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 2);
    }

    #[test]
    fn test_serialization() {
        let postings = vec![(1u32, 0u16, 0.5f32), (10, 1, 1.5), (100, 0, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let mut buf = Vec::new();
        list.serialize(&mut buf).unwrap();
        let list2 = BlockSparsePostingList::deserialize(&mut Cursor::new(&buf)).unwrap();

        assert_eq!(list.doc_count(), list2.doc_count());
    }

    #[test]
    fn test_seek() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = list.iterator();
        assert_eq!(iter.seek(300), 300);
        assert_eq!(iter.seek(301), 303);
        assert_eq!(iter.seek(2000), TERMINATED);
    }

    #[test]
    fn test_merge_with_offsets() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert_eq!(merged.doc_count(), 6);

        let decoded = merged.decode_all();
        assert_eq!(decoded.len(), 6);

        // Segment 1 doc ids are unchanged.
        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        // Segment 2 doc ids are shifted by the offset of 100.
        assert_eq!(decoded[3].0, 100);
        assert_eq!(decoded[4].0, 103);
        assert_eq!(decoded[5].0, 107);

        assert!((decoded[0].2 - 1.0).abs() < 0.01);
        assert!((decoded[3].2 - 4.0).abs() < 0.01);

        assert_eq!(decoded[2].1, 1);
        assert_eq!(decoded[4].1, 1);
    }

    #[test]
    fn test_merge_with_offsets_multi_block() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1, "Should have multiple blocks");

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        assert_eq!(merged.doc_count(), 350);
        assert_eq!(merged.num_blocks(), list1.num_blocks() + list2.num_blocks());

        let mut iter = merged.iterator();

        assert_eq!(iter.doc(), 0);

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000);
        iter.advance();
        assert_eq!(iter.doc(), 1003);
    }

    #[test]
    fn test_merge_with_offsets_serialize_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();

        let mut cursor = std::io::Cursor::new(&bytes);
        let loaded = BlockSparsePostingList::deserialize(&mut cursor).unwrap();

        let decoded = loaded.decode_all();
        assert_eq!(decoded.len(), 6);

        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        assert_eq!(decoded[3].0, 100, "First doc of seg2 should be 0+100=100");
        assert_eq!(decoded[4].0, 103, "Second doc of seg2 should be 3+100=103");
        assert_eq!(decoded[5].0, 107, "Third doc of seg2 should be 7+100=107");

        let mut iter = loaded.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 5);
        iter.advance();
        assert_eq!(iter.doc(), 10);
        iter.advance();
        assert_eq!(iter.doc(), 100);
        iter.advance();
        assert_eq!(iter.doc(), 103);
        iter.advance();
        assert_eq!(iter.doc(), 107);
    }

    #[test]
    fn test_merge_seek_after_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, 1.0)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 0, 2.0)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        let mut iter = loaded.iterator();

        let doc = iter.seek(100);
        assert_eq!(doc, 100, "Seek to 100 in segment 1");

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "Seek to 1000 (first doc of segment 2)");

        let doc = iter.seek(1050);
        assert!(
            doc >= 1050,
            "Seek to 1050 should find doc >= 1050, got {}",
            doc
        );

        let doc = iter.seek(500);
        assert!(
            doc >= 1050,
            "Seek backwards should not go back, got {}",
            doc
        );

        let mut iter2 = loaded.iterator();

        let mut count = 0;
        let mut prev_doc = 0;
        while iter2.doc() != super::TERMINATED {
            let current = iter2.doc();
            if count > 0 {
                assert!(
                    current > prev_doc,
                    "Docs should be monotonically increasing: {} vs {}",
                    prev_doc,
                    current
                );
            }
            prev_doc = current;
            iter2.advance();
            count += 1;
        }
        assert_eq!(count, 350, "Should have 350 total docs");
    }

    #[test]
    fn test_doc_count_multi_value() {
        // Three unique docs with several ordinals each: doc_count counts
        // documents, decode_all returns every posting.
        let postings: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 1.0),
            (0, 1, 1.5),
            (0, 2, 2.0),
            (5, 0, 3.0),
            (5, 1, 3.5),
            (10, 0, 4.0),
        ];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 3);

        let decoded = list.decode_all();
        assert_eq!(decoded.len(), 6);
    }

    #[test]
    fn test_zero_copy_merge_patches_first_doc_id() {
        use crate::structures::SparseSkipEntry;

        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1);

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let mut bytes1 = Vec::new();
        list1.serialize(&mut bytes1).unwrap();
        let mut bytes2 = Vec::new();
        list2.serialize(&mut bytes2).unwrap();

        fn parse_raw(data: &[u8]) -> (u32, f32, Vec<SparseSkipEntry>, &[u8]) {
            let doc_count = u32::from_le_bytes(data[0..4].try_into().unwrap());
            let global_max = f32::from_le_bytes(data[4..8].try_into().unwrap());
            let num_blocks = u32::from_le_bytes(data[8..12].try_into().unwrap()) as usize;
            let mut pos = 12;
            let mut skip = Vec::new();
            for _ in 0..num_blocks {
                let first_doc = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
                let last_doc = u32::from_le_bytes(data[pos + 4..pos + 8].try_into().unwrap());
                let offset = u32::from_le_bytes(data[pos + 8..pos + 12].try_into().unwrap());
                let length = u32::from_le_bytes(data[pos + 12..pos + 16].try_into().unwrap());
                let max_w = f32::from_le_bytes(data[pos + 16..pos + 20].try_into().unwrap());
                skip.push(SparseSkipEntry::new(
                    first_doc, last_doc, offset, length, max_w,
                ));
                pos += 20;
            }
            (doc_count, global_max, skip, &data[pos..])
        }

        let (dc1, gm1, skip1, raw1) = parse_raw(&bytes1);
        let (dc2, gm2, skip2, raw2) = parse_raw(&bytes2);

        let doc_offset: u32 = 1000;
        let total_docs = dc1 + dc2;
        let global_max = gm1.max(gm2);
        let total_blocks = (skip1.len() + skip2.len()) as u32;

        let mut output = Vec::new();
        output.extend_from_slice(&total_docs.to_le_bytes());
        output.extend_from_slice(&global_max.to_le_bytes());
        output.extend_from_slice(&total_blocks.to_le_bytes());

        // Rewrite the skip table: segment 1 entries keep their doc ids,
        // segment 2 entries are shifted by doc_offset.
        let mut block_data_offset = 0u32;
        for entry in &skip1 {
            let adjusted = SparseSkipEntry::new(
                entry.first_doc,
                entry.last_doc,
                block_data_offset + entry.offset,
                entry.length,
                entry.max_weight,
            );
            adjusted.write(&mut output).unwrap();
        }
        if let Some(last) = skip1.last() {
            block_data_offset += last.offset + last.length;
        }
        for entry in &skip2 {
            let adjusted = SparseSkipEntry::new(
                entry.first_doc + doc_offset,
                entry.last_doc + doc_offset,
                block_data_offset + entry.offset,
                entry.length,
                entry.max_weight,
            );
            adjusted.write(&mut output).unwrap();
        }

        output.extend_from_slice(raw1);

        // Patch first_doc_id inside each of segment 2's block headers.
        const FIRST_DOC_ID_OFFSET: usize = 8;
        let mut buf2 = raw2.to_vec();
        for entry in &skip2 {
            let off = entry.offset as usize + FIRST_DOC_ID_OFFSET;
            if off + 4 <= buf2.len() {
                let old = u32::from_le_bytes(buf2[off..off + 4].try_into().unwrap());
                let patched = (old + doc_offset).to_le_bytes();
                buf2[off..off + 4].copy_from_slice(&patched);
            }
        }
        output.extend_from_slice(&buf2);

        let loaded = BlockSparsePostingList::deserialize(&mut Cursor::new(&output)).unwrap();
        assert_eq!(loaded.doc_count(), 350);

        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        let doc = iter.seek(100);
        assert_eq!(doc, 100);
        let doc = iter.seek(398);
        assert_eq!(doc, 398);

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "First doc of segment 2 should be 1000");
        iter.advance();
        assert_eq!(iter.doc(), 1003, "Second doc of segment 2 should be 1003");
        let doc = iter.seek(1447);
        assert_eq!(doc, 1447, "Last doc of segment 2 should be 1447");

        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);

        // The zero-copy merge must match the reference merge exactly.
        let reference =
            BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, doc_offset)]);
        let mut ref_iter = reference.iterator();
        let mut zc_iter = loaded.iterator();
        while ref_iter.doc() != super::TERMINATED {
            assert_eq!(
                ref_iter.doc(),
                zc_iter.doc(),
                "Zero-copy and reference merge should produce identical doc_ids"
            );
            assert!(
                (ref_iter.weight() - zc_iter.weight()).abs() < 0.01,
                "Weights should match: {} vs {}",
                ref_iter.weight(),
                zc_iter.weight()
            );
            ref_iter.advance();
            zc_iter.advance();
        }
        assert_eq!(zc_iter.doc(), super::TERMINATED);
    }

    #[test]
    fn test_doc_count_single_value() {
        let postings: Vec<(DocId, u16, f32)> =
            vec![(0, 0, 1.0), (5, 0, 2.0), (10, 0, 3.0), (15, 0, 4.0)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 4);
    }

    #[test]
    fn test_doc_count_multi_value_serialization_roundtrip() {
        let postings: Vec<(DocId, u16, f32)> =
            vec![(0, 0, 1.0), (0, 1, 1.5), (5, 0, 2.0), (5, 1, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
        assert_eq!(list.doc_count(), 2);

        let mut buf = Vec::new();
        list.serialize(&mut buf).unwrap();
        let loaded = BlockSparsePostingList::deserialize(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(loaded.doc_count(), 2);
    }

    #[test]
    fn test_merge_preserves_weights_and_ordinals() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.5), (5, 1, 2.5), (10, 2, 3.5)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.5), (3, 1, 5.5), (7, 3, 6.5)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        assert!(
            (iter.weight() - 1.5).abs() < 0.01,
            "Weight should be 1.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!(
            (iter.weight() - 2.5).abs() < 0.01,
            "Weight should be 2.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 10);
        assert!(
            (iter.weight() - 3.5).abs() < 0.01,
            "Weight should be 3.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 2);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!(
            (iter.weight() - 4.5).abs() < 0.01,
            "Weight should be 4.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 103);
        assert!(
            (iter.weight() - 5.5).abs() < 0.01,
            "Weight should be 5.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 107);
        assert!(
            (iter.weight() - 6.5).abs() < 0.01,
            "Weight should be 6.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 3);

        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);
    }

    #[test]
    fn test_merge_global_max_weight() {
        let postings1: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 3.0),
            (1, 0, 7.0), // max of list1
            (2, 0, 2.0),
        ];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 5.0),
            (1, 0, 4.0),
            (2, 0, 6.0), // max of list2
        ];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        assert!((list1.global_max_weight() - 7.0).abs() < 0.01);
        assert!((list2.global_max_weight() - 6.0).abs() < 0.01);

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert!(
            (merged.global_max_weight() - 7.0).abs() < 0.01,
            "Global max should be 7.0, got {}",
            merged.global_max_weight()
        );

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        assert!(
            (loaded.global_max_weight() - 7.0).abs() < 0.01,
            "After roundtrip, global max should still be 7.0, got {}",
            loaded.global_max_weight()
        );
    }

    #[test]
    fn test_scoring_simulation_after_merge() {
        let postings1: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 0.5),
            (5, 0, 0.8),
        ];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 0.6),
            (3, 0, 0.9),
        ];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        let query_weight = 2.0f32;
        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.0).abs() < 0.01,
            "Doc 0 score should be 1.0, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 5);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.6).abs() < 0.01,
            "Doc 5 score should be 1.6, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 100);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.2).abs() < 0.01,
            "Doc 100 score should be 1.2, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 103);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.8).abs() < 0.01,
            "Doc 103 score should be 1.8, got {}",
            score
        );
    }
}