use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Cursor, Read, Write};

use super::config::WeightQuantization;
use crate::DocId;
use crate::directories::OwnedBytes;
use crate::structures::postings::TERMINATED;
use crate::structures::simd;

pub const BLOCK_SIZE: usize = 128;

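/// Fixed-size header stored at the front of every sparse block.
///
/// On-disk layout (16 bytes, little-endian): `count: u16`, `doc_id_bits: u8`,
/// `ordinal_bits: u8`, `weight_quant: u8`, three padding bytes,
/// `first_doc_id: u32`, `max_weight: f32`.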
#[derive(Debug, Clone, Copy)]
pub struct BlockHeader {
    pub count: u16,
    pub doc_id_bits: u8,
    pub ordinal_bits: u8,
    pub weight_quant: WeightQuantization,
    pub first_doc_id: DocId,
    pub max_weight: f32,
}

impl BlockHeader {
    /// Serialized size of the header in bytes.
    pub const SIZE: usize = 16;

    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        w.write_u16::<LittleEndian>(self.count)?;
        w.write_u8(self.doc_id_bits)?;
        w.write_u8(self.ordinal_bits)?;
        w.write_u8(self.weight_quant as u8)?;
        // Three padding bytes keep `first_doc_id` and `max_weight` 4-byte aligned.
        w.write_u8(0)?;
        w.write_u16::<LittleEndian>(0)?;
        w.write_u32::<LittleEndian>(self.first_doc_id)?;
        w.write_f32::<LittleEndian>(self.max_weight)?;
        Ok(())
    }

    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let count = r.read_u16::<LittleEndian>()?;
        let doc_id_bits = r.read_u8()?;
        let ordinal_bits = r.read_u8()?;
        let weight_quant_byte = r.read_u8()?;
        // Skip the three padding bytes written by `write`.
        let _ = r.read_u8()?;
        let _ = r.read_u16::<LittleEndian>()?;
        let first_doc_id = r.read_u32::<LittleEndian>()?;
        let max_weight = r.read_f32::<LittleEndian>()?;

        let weight_quant = WeightQuantization::from_u8(weight_quant_byte)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid weight quant"))?;

        Ok(Self {
            count,
            doc_id_bits,
            ordinal_bits,
            weight_quant,
            first_doc_id,
            max_weight,
        })
    }
}

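/// One compressed block of up to [`BLOCK_SIZE`] postings.
///
/// Doc ids are delta-encoded: the first doc id lives in the header, and the
/// remaining `count - 1` deltas are bit-packed into `doc_ids_data`. Ordinals
/// are bit-packed into `ordinals_data` (left empty when every ordinal is
/// zero), and weights are encoded into `weights_data` according to the
/// block's quantization scheme.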
#[derive(Debug, Clone)]
pub struct SparseBlock {
    pub header: BlockHeader,
    pub doc_ids_data: OwnedBytes,
    pub ordinals_data: OwnedBytes,
    pub weights_data: OwnedBytes,
}

impl SparseBlock {
    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        assert!(!postings.is_empty() && postings.len() <= BLOCK_SIZE);

        let count = postings.len();
        let first_doc_id = postings[0].0;

        // Delta-encode doc ids against the previous entry; the first delta is
        // always zero because the first doc id is stored in the header.
        let mut deltas = Vec::with_capacity(count);
        let mut prev = first_doc_id;
        for &(doc_id, _, _) in postings {
            deltas.push(doc_id.saturating_sub(prev));
            prev = doc_id;
        }
        deltas[0] = 0;

        let doc_id_bits = simd::round_bit_width(find_optimal_bit_width(&deltas[1..]));
        let ordinals: Vec<u16> = postings.iter().map(|(_, o, _)| *o).collect();
        let max_ordinal = ordinals.iter().copied().max().unwrap_or(0);
        let ordinal_bits = if max_ordinal == 0 {
            0
        } else {
            simd::round_bit_width(bits_needed_u16(max_ordinal))
        };

        let weights: Vec<f32> = postings.iter().map(|(_, _, w)| *w).collect();
        let max_weight = weights.iter().copied().fold(0.0f32, f32::max);

        let doc_ids_data = OwnedBytes::new({
            let rounded = simd::RoundedBitWidth::from_u8(doc_id_bits);
            let num_deltas = count - 1;
            let byte_count = num_deltas * rounded.bytes_per_value();
            let mut data = vec![0u8; byte_count];
            simd::pack_rounded(&deltas[1..], rounded, &mut data);
            data
        });
        let ordinals_data = OwnedBytes::new(if ordinal_bits > 0 {
            let rounded = simd::RoundedBitWidth::from_u8(ordinal_bits);
            let byte_count = count * rounded.bytes_per_value();
            let mut data = vec![0u8; byte_count];
            let ord_u32: Vec<u32> = ordinals.iter().map(|&o| o as u32).collect();
            simd::pack_rounded(&ord_u32, rounded, &mut data);
            data
        } else {
            Vec::new()
        });
        let weights_data = OwnedBytes::new(encode_weights(&weights, weight_quant)?);

        Ok(Self {
            header: BlockHeader {
                count: count as u16,
                doc_id_bits,
                ordinal_bits,
                weight_quant,
                first_doc_id,
                max_weight,
            },
            doc_ids_data,
            ordinals_data,
            weights_data,
        })
    }

    pub fn decode_doc_ids(&self) -> Vec<DocId> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_doc_ids_into(&mut out);
        out
    }

    pub fn decode_doc_ids_into(&self, out: &mut Vec<DocId>) {
        let count = self.header.count as usize;
        out.clear();
        out.resize(count, 0);
        out[0] = self.header.first_doc_id;

        if count > 1 {
            let bits = self.header.doc_id_bits;
            if bits == 0 {
                // All deltas are zero: every posting shares the first doc id
                // (multi-value postings that differ only by ordinal).
                out[1..].fill(self.header.first_doc_id);
            } else {
                simd::unpack_rounded(
                    &self.doc_ids_data,
                    simd::RoundedBitWidth::from_u8(bits),
                    &mut out[1..],
                    count - 1,
                );
                // Prefix-sum the deltas back into absolute doc ids.
                for i in 1..count {
                    out[i] += out[i - 1];
                }
            }
        }
    }

    pub fn decode_ordinals(&self) -> Vec<u16> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_ordinals_into(&mut out);
        out
    }

    pub fn decode_ordinals_into(&self, out: &mut Vec<u16>) {
        let count = self.header.count as usize;
        out.clear();
        if self.header.ordinal_bits == 0 {
            out.resize(count, 0u16);
        } else {
            // Unpack into a u32 scratch buffer, then narrow back to u16.
            let mut temp = [0u32; BLOCK_SIZE];
            simd::unpack_rounded(
                &self.ordinals_data,
                simd::RoundedBitWidth::from_u8(self.header.ordinal_bits),
                &mut temp[..count],
                count,
            );
            out.reserve(count);
            for &v in &temp[..count] {
                out.push(v as u16);
            }
        }
    }

    pub fn decode_weights(&self) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_weights_into(&mut out);
        out
    }

    pub fn decode_weights_into(&self, out: &mut Vec<f32>) {
        out.clear();
        decode_weights_into(
            &self.weights_data,
            self.header.weight_quant,
            self.header.count as usize,
            out,
        );
    }

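    /// Decode weights with `query_weight` already multiplied in.
    ///
    /// For `UInt8` blocks the query weight is folded into the dequantization
    /// constants (`eff_scale = query_weight * scale`,
    /// `eff_bias = query_weight * min`), so no per-value multiply is needed
    /// afterwards; all other quantizations decode first and scale in place.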
    pub fn decode_scored_weights_into(&self, query_weight: f32, out: &mut Vec<f32>) {
        out.clear();
        let count = self.header.count as usize;
        match self.header.weight_quant {
            WeightQuantization::UInt8 if self.weights_data.len() >= 8 => {
                // The first 8 bytes hold the dequantization constants.
                let scale = f32::from_le_bytes([
                    self.weights_data[0],
                    self.weights_data[1],
                    self.weights_data[2],
                    self.weights_data[3],
                ]);
                let min_val = f32::from_le_bytes([
                    self.weights_data[4],
                    self.weights_data[5],
                    self.weights_data[6],
                    self.weights_data[7],
                ]);
                let eff_scale = query_weight * scale;
                let eff_bias = query_weight * min_val;
                out.resize(count, 0.0);
                simd::dequantize_uint8(&self.weights_data[8..], out, eff_scale, eff_bias, count);
            }
            _ => {
                decode_weights_into(&self.weights_data, self.header.weight_quant, count, out);
                for w in out.iter_mut() {
                    *w *= query_weight;
                }
            }
        }
    }

    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        self.header.write(w)?;
        // Three sub-block lengths plus two bytes of padding (8 bytes total).
        w.write_u16::<LittleEndian>(self.doc_ids_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.ordinals_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.weights_data.len() as u16)?;
        w.write_u16::<LittleEndian>(0)?;
        w.write_all(&self.doc_ids_data)?;
        w.write_all(&self.ordinals_data)?;
        w.write_all(&self.weights_data)?;
        Ok(())
    }

    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let header = BlockHeader::read(r)?;
        let doc_ids_len = r.read_u16::<LittleEndian>()? as usize;
        let ordinals_len = r.read_u16::<LittleEndian>()? as usize;
        let weights_len = r.read_u16::<LittleEndian>()? as usize;
        let _ = r.read_u16::<LittleEndian>()?;

        let mut doc_ids_vec = vec![0u8; doc_ids_len];
        r.read_exact(&mut doc_ids_vec)?;
        let mut ordinals_vec = vec![0u8; ordinals_len];
        r.read_exact(&mut ordinals_vec)?;
        let mut weights_vec = vec![0u8; weights_len];
        r.read_exact(&mut weights_vec)?;

        Ok(Self {
            header,
            doc_ids_data: OwnedBytes::new(doc_ids_vec),
            ordinals_data: OwnedBytes::new(ordinals_vec),
            weights_data: OwnedBytes::new(weights_vec),
        })
    }

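    /// Zero-copy parse of a serialized block: after validating the header and
    /// sub-block lengths, the three data sections are sliced out of `data`
    /// without copying.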
    pub fn from_owned_bytes(data: crate::directories::OwnedBytes) -> crate::Result<Self> {
        let b = data.as_slice();
        if b.len() < BlockHeader::SIZE + 8 {
            return Err(crate::Error::Corruption(
                "sparse block too small".to_string(),
            ));
        }
        let mut cursor = Cursor::new(&b[..BlockHeader::SIZE]);
        let header =
            BlockHeader::read(&mut cursor).map_err(|e| crate::Error::Corruption(e.to_string()))?;

        if header.count == 0 {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block has count=0 (data_len={}, first_32_bytes=[{}])",
                b.len(),
                hex
            )));
        }

        let p = BlockHeader::SIZE;
        let doc_ids_len = u16::from_le_bytes([b[p], b[p + 1]]) as usize;
        let ordinals_len = u16::from_le_bytes([b[p + 2], b[p + 3]]) as usize;
        let weights_len = u16::from_le_bytes([b[p + 4], b[p + 5]]) as usize;
        let data_start = p + 8;
        let ord_start = data_start + doc_ids_len;
        let wt_start = ord_start + ordinals_len;
        let expected_end = wt_start + weights_len;

        if expected_end > b.len() {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block sub-block overflow: count={} doc_ids={}B ords={}B wts={}B need={}B have={}B (first_32=[{}])",
                header.count,
                doc_ids_len,
                ordinals_len,
                weights_len,
                expected_end,
                b.len(),
                hex
            )));
        }

        Ok(Self {
            header,
            doc_ids_data: data.slice(data_start..ord_start),
            ordinals_data: data.slice(ord_start..wt_start),
            weights_data: data.slice(wt_start..wt_start + weights_len),
        })
    }

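    /// Cheap rebase onto a merged doc-id space: only the header's
    /// `first_doc_id` changes. The packed deltas are offset-invariant, so the
    /// data buffers are shared as-is.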
    pub fn with_doc_offset(&self, doc_offset: u32) -> Self {
        Self {
            header: BlockHeader {
                first_doc_id: self.header.first_doc_id + doc_offset,
                ..self.header
            },
            doc_ids_data: self.doc_ids_data.clone(),
            ordinals_data: self.ordinals_data.clone(),
            weights_data: self.weights_data.clone(),
        }
    }
}

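/// A posting list stored as a sequence of compressed [`SparseBlock`]s.
///
/// `doc_count` counts unique documents, which can be smaller than the number
/// of postings when a document carries several ordinals. A minimal usage
/// sketch (illustrative only, not compiled as a doc-test):
///
/// ```ignore
/// let postings = vec![(0u32, 0u16, 1.0f32), (7, 0, 2.0)];
/// let list = BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32)?;
/// let mut buf = Vec::new();
/// list.serialize(&mut buf)?;
/// ```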
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    pub doc_count: u32,
    pub blocks: Vec<SparseBlock>,
}

impl BlockSparsePostingList {
    pub fn from_postings_with_block_size(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
        block_size: usize,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                doc_count: 0,
                blocks: Vec::new(),
            });
        }

        // Keep the block size within what the decode paths support: the fixed
        // scratch buffers (and `SparseBlock::from_postings`) assume at most
        // `BLOCK_SIZE` postings per block.
        let block_size = block_size.clamp(16, BLOCK_SIZE);
        let mut blocks = Vec::new();
        for chunk in postings.chunks(block_size) {
            blocks.push(SparseBlock::from_postings(chunk, weight_quant)?);
        }

        // Count unique documents: consecutive postings may share a doc id
        // (multi-value fields), and `postings` is sorted by doc id.
        let mut unique_docs = 1u32;
        for i in 1..postings.len() {
            if postings[i].0 != postings[i - 1].0 {
                unique_docs += 1;
            }
        }

        Ok(Self {
            doc_count: unique_docs,
            blocks,
        })
    }

    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        Self::from_postings_with_block_size(postings, weight_quant, BLOCK_SIZE)
    }

    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    pub fn global_max_weight(&self) -> f32 {
        self.blocks
            .iter()
            .map(|b| b.header.max_weight)
            .fold(0.0f32, f32::max)
    }

    pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
        self.blocks.get(block_idx).map(|b| b.header.max_weight)
    }

    /// Approximate in-memory footprint of this posting list in bytes.
    pub fn size_bytes(&self) -> usize {
        use std::mem::size_of;

        let header_size = size_of::<u32>() * 2;
        let blocks_size: usize = self
            .blocks
            .iter()
            .map(|b| {
                size_of::<BlockHeader>()
                    + b.doc_ids_data.len()
                    + b.ordinals_data.len()
                    + b.weights_data.len()
            })
            .sum();
        header_size + blocks_size
    }

    pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
        BlockSparsePostingIterator::new(self)
    }

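    /// Serialize the whole posting list: a 12-byte header (`doc_count: u32`,
    /// `global_max_weight: f32`, `num_blocks: u32`), one skip entry per block
    /// (first/last doc id, byte offset, byte length, max weight), then the
    /// concatenated block payloads.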
    pub fn serialize<W: Write>(&self, w: &mut W) -> io::Result<()> {
        use super::SparseSkipEntry;

        w.write_u32::<LittleEndian>(self.doc_count)?;
        w.write_f32::<LittleEndian>(self.global_max_weight())?;
        w.write_u32::<LittleEndian>(self.blocks.len() as u32)?;

        // Encode every block up front so the skip entries can record exact
        // byte offsets and lengths.
        let mut block_bytes: Vec<Vec<u8>> = Vec::with_capacity(self.blocks.len());
        for block in &self.blocks {
            let mut buf = Vec::new();
            block.write(&mut buf)?;
            block_bytes.push(buf);
        }

        let mut offset = 0u32;
        for (block, bytes) in self.blocks.iter().zip(block_bytes.iter()) {
            let first_doc = block.header.first_doc_id;
            let doc_ids = block.decode_doc_ids();
            let last_doc = doc_ids.last().copied().unwrap_or(first_doc);
            let length = bytes.len() as u32;

            let entry =
                SparseSkipEntry::new(first_doc, last_doc, offset, length, block.header.max_weight);
            entry.write(w)?;
            offset += length;
        }

        for bytes in block_bytes {
            w.write_all(&bytes)?;
        }

        Ok(())
    }

    /// Like [`serialize`](Self::serialize), but returns the raw block payload
    /// and the skip entries separately so the caller can lay them out itself.
    pub fn serialize_v3(&self) -> io::Result<(Vec<u8>, Vec<super::SparseSkipEntry>)> {
        let mut block_data = Vec::new();
        let mut skip_entries = Vec::with_capacity(self.blocks.len());
        let mut offset = 0u32;

        for block in &self.blocks {
            let mut buf = Vec::new();
            block.write(&mut buf)?;
            let length = buf.len() as u32;

            let first_doc = block.header.first_doc_id;
            let doc_ids = block.decode_doc_ids();
            let last_doc = doc_ids.last().copied().unwrap_or(first_doc);

            skip_entries.push(super::SparseSkipEntry::new(
                first_doc,
                last_doc,
                offset,
                length,
                block.header.max_weight,
            ));

            block_data.extend_from_slice(&buf);
            offset += length;
        }

        Ok((block_data, skip_entries))
    }

    pub fn deserialize<R: Read>(r: &mut R) -> io::Result<Self> {
        use super::SparseSkipEntry;

        let doc_count = r.read_u32::<LittleEndian>()?;
        let _global_max_weight = r.read_f32::<LittleEndian>()?;
        let num_blocks = r.read_u32::<LittleEndian>()? as usize;

        // Skip entries only matter for random access; consume and drop them.
        for _ in 0..num_blocks {
            let _ = SparseSkipEntry::read(r)?;
        }

        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            blocks.push(SparseBlock::read(r)?);
        }
        Ok(Self { doc_count, blocks })
    }

    /// Read just the list header and skip entries, returning
    /// `(doc_count, global_max_weight, skip_entries, header_size_in_bytes)`.
    pub fn deserialize_header<R: Read>(
        r: &mut R,
    ) -> io::Result<(u32, f32, Vec<super::SparseSkipEntry>, usize)> {
        use super::SparseSkipEntry;

        let doc_count = r.read_u32::<LittleEndian>()?;
        let global_max_weight = r.read_f32::<LittleEndian>()?;
        let num_blocks = r.read_u32::<LittleEndian>()? as usize;

        let mut entries = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            entries.push(SparseSkipEntry::read(r)?);
        }

        let header_size = 4 + 4 + 4 + num_blocks * SparseSkipEntry::SIZE;

        Ok((doc_count, global_max_weight, entries, header_size))
    }

    pub fn decode_all(&self) -> Vec<(DocId, u16, f32)> {
        let total_postings: usize = self.blocks.iter().map(|b| b.header.count as usize).sum();
        let mut result = Vec::with_capacity(total_postings);
        for block in &self.blocks {
            let doc_ids = block.decode_doc_ids();
            let ordinals = block.decode_ordinals();
            let weights = block.decode_weights();
            for i in 0..block.header.count as usize {
                result.push((doc_ids[i], ordinals[i], weights[i]));
            }
        }
        result
    }

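    /// Concatenate several posting lists, shifting each list's doc ids by its
    /// offset. Blocks are reused via [`SparseBlock::with_doc_offset`] rather
    /// than re-encoded; callers are expected to pass lists whose shifted
    /// doc-id ranges are disjoint and ascending, e.g.
    /// `merge_with_offsets(&[(&a, 0), (&b, 100)])`.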
    pub fn merge_with_offsets(lists: &[(&BlockSparsePostingList, u32)]) -> Self {
        if lists.is_empty() {
            return Self {
                doc_count: 0,
                blocks: Vec::new(),
            };
        }

        let total_blocks: usize = lists.iter().map(|(pl, _)| pl.blocks.len()).sum();
        let total_docs: u32 = lists.iter().map(|(pl, _)| pl.doc_count).sum();

        let mut merged_blocks = Vec::with_capacity(total_blocks);

        for (posting_list, doc_offset) in lists {
            for block in &posting_list.blocks {
                merged_blocks.push(block.with_doc_offset(*doc_offset));
            }
        }

        Self {
            doc_count: total_docs,
            blocks: merged_blocks,
        }
    }

    /// Index of the block that may contain `target`: the last block whose
    /// `first_doc_id` is `<= target`, or block 0 when `target` precedes all
    /// blocks. `None` only for an empty list.
    fn find_block(&self, target: DocId) -> Option<usize> {
        if self.blocks.is_empty() {
            return None;
        }
        let idx = self
            .blocks
            .partition_point(|b| b.header.first_doc_id <= target);
        if idx == 0 {
            Some(0)
        } else {
            Some(idx - 1)
        }
    }
}

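/// Cursor over a [`BlockSparsePostingList`] that decodes one block at a time.
///
/// `doc()` returns the current doc id (or `TERMINATED` once exhausted),
/// `advance()` steps to the next posting, and `seek(target)` jumps forward to
/// the first doc id `>= target`. A minimal scoring-loop sketch (`score`
/// stands in for caller logic; not compiled as a doc-test):
///
/// ```ignore
/// let mut it = list.iterator();
/// while it.doc() != TERMINATED {
///     score(it.doc(), query_weight * it.weight());
///     it.advance();
/// }
/// ```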
pub struct BlockSparsePostingIterator<'a> {
    posting_list: &'a BlockSparsePostingList,
    block_idx: usize,
    in_block_idx: usize,
    current_doc_ids: Vec<DocId>,
    current_ordinals: Vec<u16>,
    current_weights: Vec<f32>,
    exhausted: bool,
}

impl<'a> BlockSparsePostingIterator<'a> {
    fn new(posting_list: &'a BlockSparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            block_idx: 0,
            in_block_idx: 0,
            current_doc_ids: Vec::new(),
            current_ordinals: Vec::new(),
            current_weights: Vec::new(),
            exhausted: posting_list.blocks.is_empty(),
        };
        if !iter.exhausted {
            iter.load_block(0);
        }
        iter
    }

    fn load_block(&mut self, block_idx: usize) {
        if let Some(block) = self.posting_list.blocks.get(block_idx) {
            block.decode_doc_ids_into(&mut self.current_doc_ids);
            block.decode_ordinals_into(&mut self.current_ordinals);
            block.decode_weights_into(&mut self.current_weights);
            self.block_idx = block_idx;
            self.in_block_idx = 0;
        }
    }

    pub fn doc(&self) -> DocId {
        if self.exhausted {
            TERMINATED
        } else {
            self.current_doc_ids
                .get(self.in_block_idx)
                .copied()
                .unwrap_or(TERMINATED)
        }
    }

    pub fn weight(&self) -> f32 {
        self.current_weights
            .get(self.in_block_idx)
            .copied()
            .unwrap_or(0.0)
    }

    pub fn ordinal(&self) -> u16 {
        self.current_ordinals
            .get(self.in_block_idx)
            .copied()
            .unwrap_or(0)
    }

    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        self.in_block_idx += 1;
        if self.in_block_idx >= self.current_doc_ids.len() {
            self.block_idx += 1;
            if self.block_idx >= self.posting_list.blocks.len() {
                self.exhausted = true;
            } else {
                self.load_block(self.block_idx);
            }
        }
        self.doc()
    }

    /// Advance to the first doc id `>= target`. Seeking backwards is a no-op:
    /// if the current doc already satisfies the target, it is returned as-is.
    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        if self.doc() >= target {
            return self.doc();
        }

        // Fast path: the target still falls inside the currently decoded block.
        if let Some(&last_doc) = self.current_doc_ids.last()
            && last_doc >= target
        {
            let remaining = &self.current_doc_ids[self.in_block_idx..];
            let pos = crate::structures::simd::find_first_ge_u32(remaining, target);
            self.in_block_idx += pos;
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
            return self.doc();
        }

        // Slow path: jump via the block index, then scan within that block.
        if let Some(block_idx) = self.posting_list.find_block(target) {
            self.load_block(block_idx);
            let pos = crate::structures::simd::find_first_ge_u32(&self.current_doc_ids, target);
            self.in_block_idx = pos;
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
        } else {
            self.exhausted = true;
        }
        self.doc()
    }

    /// Skip the remainder of the current block and position at the first
    /// posting of the next block, returning its doc id (or `TERMINATED`).
    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        let next = self.block_idx + 1;
        if next >= self.posting_list.blocks.len() {
            self.exhausted = true;
            return TERMINATED;
        }
        self.load_block(next);
        self.doc()
    }

    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }

    pub fn current_block_max_weight(&self) -> f32 {
        self.posting_list
            .blocks
            .get(self.block_idx)
            .map(|b| b.header.max_weight)
            .unwrap_or(0.0)
    }

    pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.current_block_max_weight()
    }
}

/// Bit width needed to represent the largest value in `values` (0 for an
/// empty slice).
fn find_optimal_bit_width(values: &[u32]) -> u8 {
    if values.is_empty() {
        return 0;
    }
    let max_val = values.iter().copied().max().unwrap_or(0);
    simd::bits_needed(max_val)
}

fn bits_needed_u16(val: u16) -> u8 {
    if val == 0 {
        0
    } else {
        16 - val.leading_zeros() as u8
    }
}

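/// Encode `weights` under the given quantization scheme.
///
/// `Float32` and `Float16` store raw values. `UInt8` and `UInt4` store two
/// f32 constants (`scale`, `min`) followed by the quantized values, where a
/// quantized value `q` decodes as `q * scale + min`. As a worked example,
/// weights `[0.0, 1.0, 2.0]` under `UInt8` give `min = 0.0`,
/// `scale = 2.0 / 255.0`, and quantized bytes `[0, ~128, 255]`. `UInt4` packs
/// two values per byte, low nibble first.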
fn encode_weights(weights: &[f32], quant: WeightQuantization) -> io::Result<Vec<u8>> {
    let mut data = Vec::new();
    match quant {
        WeightQuantization::Float32 => {
            for &w in weights {
                data.write_f32::<LittleEndian>(w)?;
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for &w in weights {
                data.write_u16::<LittleEndian>(f16::from_f32(w).to_bits())?;
            }
        }
        WeightQuantization::UInt8 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            // Guard against a zero range so the division below stays finite.
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 255.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            for &w in weights {
                data.write_u8(((w - min) / scale).round() as u8)?;
            }
        }
        WeightQuantization::UInt4 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 15.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            let mut i = 0;
            while i < weights.len() {
                let q1 = ((weights[i] - min) / scale).round() as u8 & 0x0F;
                let q2 = if i + 1 < weights.len() {
                    ((weights[i + 1] - min) / scale).round() as u8 & 0x0F
                } else {
                    0
                };
                data.write_u8((q2 << 4) | q1)?;
                i += 2;
            }
        }
    }
    Ok(data)
}

fn decode_weights_into(data: &[u8], quant: WeightQuantization, count: usize, out: &mut Vec<f32>) {
    let mut cursor = Cursor::new(data);
    match quant {
        WeightQuantization::Float32 => {
            for _ in 0..count {
                out.push(cursor.read_f32::<LittleEndian>().unwrap_or(0.0));
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for _ in 0..count {
                let bits = cursor.read_u16::<LittleEndian>().unwrap_or(0);
                out.push(f16::from_bits(bits).to_f32());
            }
        }
        WeightQuantization::UInt8 => {
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min_val = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            let offset = cursor.position() as usize;
            out.resize(count, 0.0);
            simd::dequantize_uint8(&data[offset..], out, scale, min_val, count);
        }
        WeightQuantization::UInt4 => {
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            let mut i = 0;
            while i < count {
                let byte = cursor.read_u8().unwrap_or(0);
                // Low nibble first, matching `encode_weights`.
                out.push((byte & 0x0F) as f32 * scale + min);
                i += 1;
                if i < count {
                    out.push((byte >> 4) as f32 * scale + min);
                    i += 1;
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_block_roundtrip() {
        let postings = vec![
            (10u32, 0u16, 1.5f32),
            (15, 0, 2.0),
            (20, 1, 0.5),
            (100, 0, 3.0),
        ];
        let block = SparseBlock::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(block.decode_doc_ids(), vec![10, 15, 20, 100]);
        assert_eq!(block.decode_ordinals(), vec![0, 0, 1, 0]);
        let weights = block.decode_weights();
        assert!((weights[0] - 1.5).abs() < 0.01);
    }

    #[test]
    fn test_posting_list() {
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, i as f32 * 0.1)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 300);
        assert_eq!(list.num_blocks(), 3);

        let mut iter = list.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 2);
    }

    #[test]
    fn test_serialization() {
        let postings = vec![(1u32, 0u16, 0.5f32), (10, 1, 1.5), (100, 0, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let mut buf = Vec::new();
        list.serialize(&mut buf).unwrap();
        let list2 = BlockSparsePostingList::deserialize(&mut Cursor::new(&buf)).unwrap();

        assert_eq!(list.doc_count(), list2.doc_count());
    }

    #[test]
    fn test_seek() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = list.iterator();
        assert_eq!(iter.seek(300), 300);
        assert_eq!(iter.seek(301), 303);
        assert_eq!(iter.seek(2000), TERMINATED);
    }

    #[test]
    fn test_merge_with_offsets() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert_eq!(merged.doc_count(), 6);

        let decoded = merged.decode_all();
        assert_eq!(decoded.len(), 6);

        // Segment 1 doc ids are unchanged.
        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        // Segment 2 doc ids are shifted by 100.
        assert_eq!(decoded[3].0, 100);
        assert_eq!(decoded[4].0, 103);
        assert_eq!(decoded[5].0, 107);

        // Weights and ordinals survive the merge.
        assert!((decoded[0].2 - 1.0).abs() < 0.01);
        assert!((decoded[3].2 - 4.0).abs() < 0.01);

        assert_eq!(decoded[2].1, 1);
        assert_eq!(decoded[4].1, 1);
    }

    #[test]
    fn test_merge_with_offsets_multi_block() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1, "Should have multiple blocks");

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        assert_eq!(merged.doc_count(), 350);
        assert_eq!(merged.num_blocks(), list1.num_blocks() + list2.num_blocks());

        let mut iter = merged.iterator();
        assert_eq!(iter.doc(), 0);

        // Jump into segment 2: doc 0 of segment 2 maps to 1000.
        let doc = iter.seek(1000);
        assert_eq!(doc, 1000);
        iter.advance();
        assert_eq!(iter.doc(), 1003);
    }

    #[test]
    fn test_merge_with_offsets_serialize_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();

        let mut cursor = std::io::Cursor::new(&bytes);
        let loaded = BlockSparsePostingList::deserialize(&mut cursor).unwrap();

        let decoded = loaded.decode_all();
        assert_eq!(decoded.len(), 6);

        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        assert_eq!(decoded[3].0, 100, "First doc of seg2 should be 0+100=100");
        assert_eq!(decoded[4].0, 103, "Second doc of seg2 should be 3+100=103");
        assert_eq!(decoded[5].0, 107, "Third doc of seg2 should be 7+100=107");

        let mut iter = loaded.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 5);
        iter.advance();
        assert_eq!(iter.doc(), 10);
        iter.advance();
        assert_eq!(iter.doc(), 100);
        iter.advance();
        assert_eq!(iter.doc(), 103);
        iter.advance();
        assert_eq!(iter.doc(), 107);
    }

    #[test]
    fn test_merge_seek_after_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, 1.0)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 0, 2.0)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        let mut iter = loaded.iterator();

        let doc = iter.seek(100);
        assert_eq!(doc, 100, "Seek to 100 in segment 1");

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "Seek to 1000 (first doc of segment 2)");

        let doc = iter.seek(1050);
        assert!(
            doc >= 1050,
            "Seek to 1050 should find doc >= 1050, got {}",
            doc
        );

        let doc = iter.seek(500);
        assert!(
            doc >= 1050,
            "Seek backwards should not go back, got {}",
            doc
        );

        // Full scan: doc ids must be strictly increasing and total 350.
        let mut iter2 = loaded.iterator();
        let mut count = 0;
        let mut prev_doc = 0;
        while iter2.doc() != super::TERMINATED {
            let current = iter2.doc();
            if count > 0 {
                assert!(
                    current > prev_doc,
                    "Docs should be monotonically increasing: {} vs {}",
                    prev_doc,
                    current
                );
            }
            prev_doc = current;
            iter2.advance();
            count += 1;
        }
        assert_eq!(count, 350, "Should have 350 total docs");
    }

    #[test]
    fn test_doc_count_multi_value() {
        // Three unique docs (0, 5, 10) across six postings with ordinals.
        let postings: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 1.0),
            (0, 1, 1.5),
            (0, 2, 2.0),
            (5, 0, 3.0),
            (5, 1, 3.5),
            (10, 0, 4.0),
        ];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 3);

        let decoded = list.decode_all();
        assert_eq!(decoded.len(), 6);
    }

    /// Simulates a zero-copy merge: concatenate the serialized block payloads
    /// of two lists and patch each block's `first_doc_id` in place, then check
    /// the result against the reference `merge_with_offsets` implementation.
    #[test]
    fn test_zero_copy_merge_patches_first_doc_id() {
        use crate::structures::SparseSkipEntry;

        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1);

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let mut bytes1 = Vec::new();
        list1.serialize(&mut bytes1).unwrap();
        let mut bytes2 = Vec::new();
        list2.serialize(&mut bytes2).unwrap();

        // Parse the serialized layout by hand: header, skip entries, payload.
        fn parse_raw(data: &[u8]) -> (u32, f32, Vec<SparseSkipEntry>, &[u8]) {
            let doc_count = u32::from_le_bytes(data[0..4].try_into().unwrap());
            let global_max = f32::from_le_bytes(data[4..8].try_into().unwrap());
            let num_blocks = u32::from_le_bytes(data[8..12].try_into().unwrap()) as usize;
            let mut pos = 12;
            let mut skip = Vec::new();
            for _ in 0..num_blocks {
                let first_doc = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
                let last_doc = u32::from_le_bytes(data[pos + 4..pos + 8].try_into().unwrap());
                let offset = u32::from_le_bytes(data[pos + 8..pos + 12].try_into().unwrap());
                let length = u32::from_le_bytes(data[pos + 12..pos + 16].try_into().unwrap());
                let max_w = f32::from_le_bytes(data[pos + 16..pos + 20].try_into().unwrap());
                skip.push(SparseSkipEntry::new(
                    first_doc, last_doc, offset, length, max_w,
                ));
                pos += 20;
            }
            (doc_count, global_max, skip, &data[pos..])
        }

        let (dc1, gm1, skip1, raw1) = parse_raw(&bytes1);
        let (dc2, gm2, skip2, raw2) = parse_raw(&bytes2);

        let doc_offset: u32 = 1000;
        let total_docs = dc1 + dc2;
        let global_max = gm1.max(gm2);
        let total_blocks = (skip1.len() + skip2.len()) as u32;

        let mut output = Vec::new();
        output.extend_from_slice(&total_docs.to_le_bytes());
        output.extend_from_slice(&global_max.to_le_bytes());
        output.extend_from_slice(&total_blocks.to_le_bytes());

        // Rewrite the skip entries: shift byte offsets for both segments and
        // doc ids for the second segment.
        let mut block_data_offset = 0u32;
        for entry in &skip1 {
            let adjusted = SparseSkipEntry::new(
                entry.first_doc,
                entry.last_doc,
                block_data_offset + entry.offset,
                entry.length,
                entry.max_weight,
            );
            adjusted.write(&mut output).unwrap();
        }
        if let Some(last) = skip1.last() {
            block_data_offset += last.offset + last.length;
        }
        for entry in &skip2 {
            let adjusted = SparseSkipEntry::new(
                entry.first_doc + doc_offset,
                entry.last_doc + doc_offset,
                block_data_offset + entry.offset,
                entry.length,
                entry.max_weight,
            );
            adjusted.write(&mut output).unwrap();
        }

        output.extend_from_slice(raw1);

        // Patch `first_doc_id` (at byte 8 of each block header) in segment 2's
        // raw block bytes; the packed deltas need no rewrite.
        const FIRST_DOC_ID_OFFSET: usize = 8;
        let mut buf2 = raw2.to_vec();
        for entry in &skip2 {
            let off = entry.offset as usize + FIRST_DOC_ID_OFFSET;
            if off + 4 <= buf2.len() {
                let old = u32::from_le_bytes(buf2[off..off + 4].try_into().unwrap());
                let patched = (old + doc_offset).to_le_bytes();
                buf2[off..off + 4].copy_from_slice(&patched);
            }
        }
        output.extend_from_slice(&buf2);

        let loaded = BlockSparsePostingList::deserialize(&mut Cursor::new(&output)).unwrap();
        assert_eq!(loaded.doc_count(), 350);

        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        let doc = iter.seek(100);
        assert_eq!(doc, 100);
        let doc = iter.seek(398);
        assert_eq!(doc, 398);

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "First doc of segment 2 should be 1000");
        iter.advance();
        assert_eq!(iter.doc(), 1003, "Second doc of segment 2 should be 1003");
        let doc = iter.seek(1447);
        assert_eq!(doc, 1447, "Last doc of segment 2 should be 1447");

        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);

        // The zero-copy result must match the reference merge exactly.
        let reference =
            BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, doc_offset)]);
        let mut ref_iter = reference.iterator();
        let mut zc_iter = loaded.iterator();
        while ref_iter.doc() != super::TERMINATED {
            assert_eq!(
                ref_iter.doc(),
                zc_iter.doc(),
                "Zero-copy and reference merge should produce identical doc_ids"
            );
            assert!(
                (ref_iter.weight() - zc_iter.weight()).abs() < 0.01,
                "Weights should match: {} vs {}",
                ref_iter.weight(),
                zc_iter.weight()
            );
            ref_iter.advance();
            zc_iter.advance();
        }
        assert_eq!(zc_iter.doc(), super::TERMINATED);
    }

    #[test]
    fn test_doc_count_single_value() {
        let postings: Vec<(DocId, u16, f32)> =
            vec![(0, 0, 1.0), (5, 0, 2.0), (10, 0, 3.0), (15, 0, 4.0)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 4);
    }

    #[test]
    fn test_doc_count_multi_value_serialization_roundtrip() {
        let postings: Vec<(DocId, u16, f32)> =
            vec![(0, 0, 1.0), (0, 1, 1.5), (5, 0, 2.0), (5, 1, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
        assert_eq!(list.doc_count(), 2);

        let mut buf = Vec::new();
        list.serialize(&mut buf).unwrap();
        let loaded = BlockSparsePostingList::deserialize(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(loaded.doc_count(), 2);
    }

    #[test]
    fn test_merge_preserves_weights_and_ordinals() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.5), (5, 1, 2.5), (10, 2, 3.5)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.5), (3, 1, 5.5), (7, 3, 6.5)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        assert!(
            (iter.weight() - 1.5).abs() < 0.01,
            "Weight should be 1.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!(
            (iter.weight() - 2.5).abs() < 0.01,
            "Weight should be 2.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 10);
        assert!(
            (iter.weight() - 3.5).abs() < 0.01,
            "Weight should be 3.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 2);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!(
            (iter.weight() - 4.5).abs() < 0.01,
            "Weight should be 4.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 103);
        assert!(
            (iter.weight() - 5.5).abs() < 0.01,
            "Weight should be 5.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 107);
        assert!(
            (iter.weight() - 6.5).abs() < 0.01,
            "Weight should be 6.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 3);

        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);
    }

    #[test]
    fn test_merge_global_max_weight() {
        let postings1: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 3.0),
            (1, 0, 7.0),
            (2, 0, 2.0),
        ];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 5.0),
            (1, 0, 4.0),
            (2, 0, 6.0),
        ];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        assert!((list1.global_max_weight() - 7.0).abs() < 0.01);
        assert!((list2.global_max_weight() - 6.0).abs() < 0.01);

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert!(
            (merged.global_max_weight() - 7.0).abs() < 0.01,
            "Global max should be 7.0, got {}",
            merged.global_max_weight()
        );

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        assert!(
            (loaded.global_max_weight() - 7.0).abs() < 0.01,
            "After roundtrip, global max should still be 7.0, got {}",
            loaded.global_max_weight()
        );
    }

    #[test]
    fn test_scoring_simulation_after_merge() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 0.5), (5, 0, 0.8)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 0.6), (3, 0, 0.9)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        let query_weight = 2.0f32;
        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.0).abs() < 0.01,
            "Doc 0 score should be 1.0, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 5);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.6).abs() < 0.01,
            "Doc 5 score should be 1.6, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 100);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.2).abs() < 0.01,
            "Doc 100 score should be 1.2, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 103);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.8).abs() < 0.01,
            "Doc 103 score should be 1.8, got {}",
            score
        );
    }
}