use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Cursor, Read, Write};

use super::config::WeightQuantization;
use crate::DocId;
use crate::directories::OwnedBytes;
use crate::structures::postings::TERMINATED;
use crate::structures::simd;

pub const BLOCK_SIZE: usize = 128;
pub const MAX_BLOCK_SIZE: usize = 256;

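/// Fixed 16-byte header for one sparse block, stored little-endian:
/// count (u16), doc_id_bits (u8), ordinal_bits (u8), weight-quantization
/// tag (u8), three padding bytes, first_doc_id (u32), max_weight (f32).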
#[derive(Debug, Clone, Copy)]
pub struct BlockHeader {
    pub count: u16,
    pub doc_id_bits: u8,
    pub ordinal_bits: u8,
    pub weight_quant: WeightQuantization,
    pub first_doc_id: DocId,
    pub max_weight: f32,
}

impl BlockHeader {
    pub const SIZE: usize = 16;

    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        w.write_u16::<LittleEndian>(self.count)?;
        w.write_u8(self.doc_id_bits)?;
        w.write_u8(self.ordinal_bits)?;
        w.write_u8(self.weight_quant as u8)?;
        // Three padding bytes keep the header at exactly `SIZE` bytes.
        w.write_u8(0)?;
        w.write_u16::<LittleEndian>(0)?;
        w.write_u32::<LittleEndian>(self.first_doc_id)?;
        w.write_f32::<LittleEndian>(self.max_weight)?;
        Ok(())
    }

    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let count = r.read_u16::<LittleEndian>()?;
        let doc_id_bits = r.read_u8()?;
        let ordinal_bits = r.read_u8()?;
        let weight_quant_byte = r.read_u8()?;
        // Skip the padding bytes emitted by `write`.
        let _ = r.read_u8()?;
        let _ = r.read_u16::<LittleEndian>()?;
        let first_doc_id = r.read_u32::<LittleEndian>()?;
        let max_weight = r.read_f32::<LittleEndian>()?;

        let weight_quant = WeightQuantization::from_u8(weight_quant_byte)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid weight quant"))?;

        Ok(Self {
            count,
            doc_id_bits,
            ordinal_bits,
            weight_quant,
            first_doc_id,
            max_weight,
        })
    }
}

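/// One compressed posting block: delta-encoded, bit-packed doc ids;
/// bit-packed ordinals; and weights encoded per the block's quantization.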
#[derive(Debug, Clone)]
pub struct SparseBlock {
    pub header: BlockHeader,
    /// Bit-packed doc-id deltas; the first doc id lives in the header.
    pub doc_ids_data: OwnedBytes,
    /// Bit-packed ordinals; empty when every ordinal is zero.
    pub ordinals_data: OwnedBytes,
    /// Weights encoded according to `header.weight_quant`.
    pub weights_data: OwnedBytes,
    /// Cached largest doc id, so skipping never has to decode the block.
    last_doc_id: DocId,
}

impl SparseBlock {
    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        assert!(!postings.is_empty() && postings.len() <= MAX_BLOCK_SIZE);

        let count = postings.len();
        let first_doc_id = postings[0].0;

        // Delta-encode doc ids; the first entry carries no delta.
        let mut deltas = Vec::with_capacity(count);
        let mut prev = first_doc_id;
        for &(doc_id, _, _) in postings {
            deltas.push(doc_id.saturating_sub(prev));
            prev = doc_id;
        }
        deltas[0] = 0;

        let doc_id_bits = simd::round_bit_width(find_optimal_bit_width(&deltas[1..]));
        let ordinals: Vec<u16> = postings.iter().map(|(_, o, _)| *o).collect();
        let max_ordinal = ordinals.iter().copied().max().unwrap_or(0);
        let ordinal_bits = if max_ordinal == 0 {
            0
        } else {
            simd::round_bit_width(bits_needed_u16(max_ordinal))
        };

        let weights: Vec<f32> = postings.iter().map(|(_, _, w)| *w).collect();
        let max_weight = weights
            .iter()
            .copied()
            .fold(0.0f32, |acc, w| acc.max(w.abs()));

        let doc_ids_data = OwnedBytes::new({
            let rounded = simd::RoundedBitWidth::from_u8(doc_id_bits);
            let num_deltas = count - 1;
            let byte_count = num_deltas * rounded.bytes_per_value();
            let mut data = vec![0u8; byte_count];
            simd::pack_rounded(&deltas[1..], rounded, &mut data);
            data
        });
        let ordinals_data = OwnedBytes::new(if ordinal_bits > 0 {
            let rounded = simd::RoundedBitWidth::from_u8(ordinal_bits);
            let byte_count = count * rounded.bytes_per_value();
            let mut data = vec![0u8; byte_count];
            let ord_u32: Vec<u32> = ordinals.iter().map(|&o| o as u32).collect();
            simd::pack_rounded(&ord_u32, rounded, &mut data);
            data
        } else {
            Vec::new()
        });
        let weights_data = OwnedBytes::new(encode_weights(&weights, weight_quant)?);

        let last_doc_id = postings.last().unwrap().0;

        Ok(Self {
            header: BlockHeader {
                count: count as u16,
                doc_id_bits,
                ordinal_bits,
                weight_quant,
                first_doc_id,
                max_weight,
            },
            doc_ids_data,
            ordinals_data,
            weights_data,
            last_doc_id,
        })
    }

    #[inline]
    pub fn last_doc_id(&self) -> DocId {
        self.last_doc_id
    }

    pub fn decode_doc_ids(&self) -> Vec<DocId> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_doc_ids_into(&mut out);
        out
    }

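    /// Decodes doc ids into `out`: unpacks the stored deltas and applies a
    /// prefix sum starting from `first_doc_id`.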
    pub fn decode_doc_ids_into(&self, out: &mut Vec<DocId>) {
        let count = self.header.count as usize;
        out.clear();
        out.resize(count, 0);
        out[0] = self.header.first_doc_id;

        if count > 1 {
            let bits = self.header.doc_id_bits;
            if bits == 0 {
                // All deltas are zero: every entry repeats the first doc id.
                out[1..].fill(self.header.first_doc_id);
            } else {
                simd::unpack_rounded(
                    &self.doc_ids_data,
                    simd::RoundedBitWidth::from_u8(bits),
                    &mut out[1..],
                    count - 1,
                );
                // Prefix-sum the deltas back into absolute doc ids.
                for i in 1..count {
                    out[i] += out[i - 1];
                }
            }
        }
    }

    pub fn decode_ordinals(&self) -> Vec<u16> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_ordinals_into(&mut out);
        out
    }

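    /// Decodes ordinals into `out`. Ordinals are unpacked as u32 and then
    /// narrowed back to u16.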
    pub fn decode_ordinals_into(&self, out: &mut Vec<u16>) {
        let count = self.header.count as usize;
        out.clear();
        if self.header.ordinal_bits == 0 {
            out.resize(count, 0u16);
        } else {
            // Blocks may hold up to MAX_BLOCK_SIZE entries (see the assertion
            // in `from_postings`), so size the scratch buffer accordingly.
            let mut temp = [0u32; MAX_BLOCK_SIZE];
            simd::unpack_rounded(
                &self.ordinals_data,
                simd::RoundedBitWidth::from_u8(self.header.ordinal_bits),
                &mut temp[..count],
                count,
            );
            out.reserve(count);
            for &v in &temp[..count] {
                out.push(v as u16);
            }
        }
    }

    pub fn decode_weights(&self) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_weights_into(&mut out);
        out
    }

    pub fn decode_weights_into(&self, out: &mut Vec<f32>) {
        out.clear();
        decode_weights_into(
            &self.weights_data,
            self.header.weight_quant,
            self.header.count as usize,
            out,
        );
    }

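    /// Decodes weights already multiplied by `query_weight`. For UInt8 the
    /// scale and bias are folded into the dequantization constants, so scoring
    /// needs no second pass over the block.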
    pub fn decode_scored_weights_into(&self, query_weight: f32, out: &mut Vec<f32>) {
        out.clear();
        let count = self.header.count as usize;
        match self.header.weight_quant {
            WeightQuantization::UInt8 if self.weights_data.len() >= 8 => {
                // The first 8 bytes hold the quantization scale and minimum.
                let scale = f32::from_le_bytes([
                    self.weights_data[0],
                    self.weights_data[1],
                    self.weights_data[2],
                    self.weights_data[3],
                ]);
                let min_val = f32::from_le_bytes([
                    self.weights_data[4],
                    self.weights_data[5],
                    self.weights_data[6],
                    self.weights_data[7],
                ]);
                // Fold the query weight into the dequantization constants.
                let eff_scale = query_weight * scale;
                let eff_bias = query_weight * min_val;
                out.resize(count, 0.0);
                simd::dequantize_uint8(&self.weights_data[8..], out, eff_scale, eff_bias, count);
            }
            _ => {
                decode_weights_into(&self.weights_data, self.header.weight_quant, count, out);
                for w in out.iter_mut() {
                    *w *= query_weight;
                }
            }
        }
    }

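    /// Scores this block directly into a flat accumulator: for each posting,
    /// adds `query_weight * weight` at slot `doc_id - base_doc`, recording
    /// newly touched slots in `dirty`. Returns the block's posting count.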
    #[inline]
    pub fn accumulate_scored_weights(
        &self,
        query_weight: f32,
        doc_ids: &[u32],
        flat_scores: &mut [f32],
        base_doc: u32,
        dirty: &mut Vec<u32>,
    ) -> usize {
        let count = self.header.count as usize;
        match self.header.weight_quant {
            WeightQuantization::UInt8 if self.weights_data.len() >= 8 => {
                // The first 8 bytes hold the quantization scale and minimum.
                let scale = f32::from_le_bytes([
                    self.weights_data[0],
                    self.weights_data[1],
                    self.weights_data[2],
                    self.weights_data[3],
                ]);
                let min_val = f32::from_le_bytes([
                    self.weights_data[4],
                    self.weights_data[5],
                    self.weights_data[6],
                    self.weights_data[7],
                ]);
                let eff_scale = query_weight * scale;
                let eff_bias = query_weight * min_val;
                let quant_data = &self.weights_data[8..];

                for i in 0..count.min(quant_data.len()).min(doc_ids.len()) {
                    let w = quant_data[i] as f32 * eff_scale + eff_bias;
                    let off = (doc_ids[i] - base_doc) as usize;
                    if off >= flat_scores.len() {
                        continue;
                    }
                    if flat_scores[off] == 0.0 {
                        dirty.push(doc_ids[i]);
                    }
                    flat_scores[off] += w;
                }
                count
            }
            _ => {
                let mut weights_buf = Vec::with_capacity(count);
                decode_weights_into(
                    &self.weights_data,
                    self.header.weight_quant,
                    count,
                    &mut weights_buf,
                );
                for i in 0..count.min(weights_buf.len()).min(doc_ids.len()) {
                    let w = weights_buf[i] * query_weight;
                    let off = (doc_ids[i] - base_doc) as usize;
                    if off >= flat_scores.len() {
                        continue;
                    }
                    if flat_scores[off] == 0.0 {
                        dirty.push(doc_ids[i]);
                    }
                    flat_scores[off] += w;
                }
                count
            }
        }
    }

    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        self.header.write(w)?;
        if self.doc_ids_data.len() > u16::MAX as usize
            || self.ordinals_data.len() > u16::MAX as usize
            || self.weights_data.len() > u16::MAX as usize
        {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "sparse sub-block too large for u16 length: doc_ids={}B ords={}B wts={}B",
                    self.doc_ids_data.len(),
                    self.ordinals_data.len(),
                    self.weights_data.len()
                ),
            ));
        }
        w.write_u16::<LittleEndian>(self.doc_ids_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.ordinals_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.weights_data.len() as u16)?;
        w.write_u16::<LittleEndian>(0)?;
        w.write_all(&self.doc_ids_data)?;
        w.write_all(&self.ordinals_data)?;
        w.write_all(&self.weights_data)?;
        Ok(())
    }

    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let header = BlockHeader::read(r)?;
        let doc_ids_len = r.read_u16::<LittleEndian>()? as usize;
        let ordinals_len = r.read_u16::<LittleEndian>()? as usize;
        let weights_len = r.read_u16::<LittleEndian>()? as usize;
        let _ = r.read_u16::<LittleEndian>()?;

        let mut doc_ids_vec = vec![0u8; doc_ids_len];
        r.read_exact(&mut doc_ids_vec)?;
        let mut ordinals_vec = vec![0u8; ordinals_len];
        r.read_exact(&mut ordinals_vec)?;
        let mut weights_vec = vec![0u8; weights_len];
        r.read_exact(&mut weights_vec)?;

        let last_doc_id = compute_last_doc(&header, &doc_ids_vec);

        Ok(Self {
            header,
            doc_ids_data: OwnedBytes::new(doc_ids_vec),
            ordinals_data: OwnedBytes::new(ordinals_vec),
            weights_data: OwnedBytes::new(weights_vec),
            last_doc_id,
        })
    }

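    /// Zero-copy construction from a serialized block: validates lengths, then
    /// slices `data` into the three sub-buffers instead of copying them.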
    pub fn from_owned_bytes(data: crate::directories::OwnedBytes) -> crate::Result<Self> {
        let b = data.as_slice();
        if b.len() < BlockHeader::SIZE + 8 {
            return Err(crate::Error::Corruption(
                "sparse block too small".to_string(),
            ));
        }
        let mut cursor = Cursor::new(&b[..BlockHeader::SIZE]);
        let header =
            BlockHeader::read(&mut cursor).map_err(|e| crate::Error::Corruption(e.to_string()))?;

        if header.count == 0 {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block has count=0 (data_len={}, first_32_bytes=[{}])",
                b.len(),
                hex
            )));
        }

        let p = BlockHeader::SIZE;
        let doc_ids_len = u16::from_le_bytes([b[p], b[p + 1]]) as usize;
        let ordinals_len = u16::from_le_bytes([b[p + 2], b[p + 3]]) as usize;
        let weights_len = u16::from_le_bytes([b[p + 4], b[p + 5]]) as usize;
        let data_start = p + 8;
        let ord_start = data_start + doc_ids_len;
        let wt_start = ord_start + ordinals_len;
        let expected_end = wt_start + weights_len;

        if expected_end > b.len() {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block sub-block overflow: count={} doc_ids={}B ords={}B wts={}B need={}B have={}B (first_32=[{}])",
                header.count,
                doc_ids_len,
                ordinals_len,
                weights_len,
                expected_end,
                b.len(),
                hex
            )));
        }

        let doc_ids_slice = data.slice(data_start..ord_start);
        let last_doc_id = compute_last_doc(&header, &doc_ids_slice);

        Ok(Self {
            header,
            doc_ids_data: doc_ids_slice,
            ordinals_data: data.slice(ord_start..wt_start),
            weights_data: data.slice(wt_start..wt_start + weights_len),
            last_doc_id,
        })
    }

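    /// Returns a copy of this block with all doc ids shifted by `doc_offset`.
    /// Only the header's `first_doc_id` and the cached `last_doc_id` change;
    /// the packed deltas are offset-invariant and are reused as-is.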
    pub fn with_doc_offset(&self, doc_offset: u32) -> Self {
        Self {
            header: BlockHeader {
                first_doc_id: self.header.first_doc_id + doc_offset,
                ..self.header
            },
            doc_ids_data: self.doc_ids_data.clone(),
            ordinals_data: self.ordinals_data.clone(),
            weights_data: self.weights_data.clone(),
            last_doc_id: self.last_doc_id + doc_offset,
        }
    }
}

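/// A posting list of (doc_id, ordinal, weight) entries, split into
/// compressed blocks with per-block max weights so scorers can skip
/// blocks that cannot beat the current threshold.
///
/// Minimal usage sketch (illustrative, mirrors the tests below):
///
/// ```ignore
/// let postings = vec![(1u32, 0u16, 0.5f32), (10, 0, 1.5), (100, 1, 2.5)];
/// let list = BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32)?;
/// let mut it = list.iterator();
/// while it.doc() != TERMINATED {
///     // score it.doc() with it.weight() ...
///     it.advance();
/// }
/// ```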
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    pub doc_count: u32,
    pub blocks: Vec<SparseBlock>,
}

impl BlockSparsePostingList {
    pub fn from_postings_with_block_size(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
        block_size: usize,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                doc_count: 0,
                blocks: Vec::new(),
            });
        }

        // Enforce a minimum block size of 16 postings.
        let block_size = block_size.max(16);
        let mut blocks = Vec::new();
        for chunk in postings.chunks(block_size) {
            blocks.push(SparseBlock::from_postings(chunk, weight_quant)?);
        }

        // Count unique doc ids: multi-value postings repeat a doc id with
        // different ordinals, and those must not inflate `doc_count`.
        let mut unique_docs = 1u32;
        for i in 1..postings.len() {
            if postings[i].0 != postings[i - 1].0 {
                unique_docs += 1;
            }
        }

        Ok(Self {
            doc_count: unique_docs,
            blocks,
        })
    }

    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        Self::from_postings_with_block_size(postings, weight_quant, BLOCK_SIZE)
    }

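    /// Builds a posting list from an explicit partition: `partition[i]` is the
    /// number of postings placed in block `i`.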
    pub fn from_postings_with_partition(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
        partition: &[usize],
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                doc_count: 0,
                blocks: Vec::new(),
            });
        }

        let mut blocks = Vec::with_capacity(partition.len());
        let mut offset = 0;
        for &block_size in partition {
            let end = (offset + block_size).min(postings.len());
            blocks.push(SparseBlock::from_postings(
                &postings[offset..end],
                weight_quant,
            )?);
            offset = end;
        }

        let mut unique_docs = 1u32;
        for i in 1..postings.len() {
            if postings[i].0 != postings[i - 1].0 {
                unique_docs += 1;
            }
        }

        Ok(Self {
            doc_count: unique_docs,
            blocks,
        })
    }

    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    pub fn global_max_weight(&self) -> f32 {
        self.blocks
            .iter()
            .map(|b| b.header.max_weight)
            .fold(0.0f32, f32::max)
    }

    pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
        self.blocks.get(block_idx).map(|b| b.header.max_weight)
    }

    /// Approximate in-memory footprint of this posting list in bytes.
    pub fn size_bytes(&self) -> usize {
        use std::mem::size_of;

        let header_size = size_of::<u32>() * 2;
        let blocks_size: usize = self
            .blocks
            .iter()
            .map(|b| {
                size_of::<BlockHeader>()
                    + b.doc_ids_data.len()
                    + b.ordinals_data.len()
                    + b.weights_data.len()
            })
            .sum();
        header_size + blocks_size
    }

    pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
        BlockSparsePostingIterator::new(self)
    }

    /// Serializes every block and returns the concatenated block bytes plus
    /// one skip entry (doc range, byte range, max weight) per block.
    pub fn serialize(&self) -> io::Result<(Vec<u8>, Vec<super::SparseSkipEntry>)> {
        let mut block_data = Vec::new();
        let mut skip_entries = Vec::with_capacity(self.blocks.len());
        let mut offset = 0u64;

        for block in &self.blocks {
            let mut buf = Vec::new();
            block.write(&mut buf)?;
            let length = buf.len() as u32;

            let first_doc = block.header.first_doc_id;
            let last_doc = block.last_doc_id;

            skip_entries.push(super::SparseSkipEntry::new(
                first_doc,
                last_doc,
                offset,
                length,
                block.header.max_weight,
            ));

            block_data.extend_from_slice(&buf);
            offset += length as u64;
        }

        Ok((block_data, skip_entries))
    }

    /// Test helper: reassembles a posting list from serialized block data and
    /// its skip entries.
    #[cfg(test)]
    pub fn from_parts(
        doc_count: u32,
        block_data: &[u8],
        skip_entries: &[super::SparseSkipEntry],
    ) -> io::Result<Self> {
        let mut blocks = Vec::with_capacity(skip_entries.len());
        for entry in skip_entries {
            let start = entry.offset as usize;
            let end = start + entry.length as usize;
            blocks.push(SparseBlock::read(&mut std::io::Cursor::new(
                &block_data[start..end],
            ))?);
        }
        Ok(Self { doc_count, blocks })
    }

    pub fn decode_all(&self) -> Vec<(DocId, u16, f32)> {
        let total_postings: usize = self.blocks.iter().map(|b| b.header.count as usize).sum();
        let mut result = Vec::with_capacity(total_postings);
        for block in &self.blocks {
            let doc_ids = block.decode_doc_ids();
            let ordinals = block.decode_ordinals();
            let weights = block.decode_weights();
            for i in 0..block.header.count as usize {
                result.push((doc_ids[i], ordinals[i], weights[i]));
            }
        }
        result
    }

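    /// Merges several posting lists by rebasing each list's doc ids with its
    /// segment offset. Blocks are reused via `with_doc_offset`, so no postings
    /// are re-encoded; callers are expected to pass segments in ascending
    /// doc-id order.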
    pub fn merge_with_offsets(lists: &[(&BlockSparsePostingList, u32)]) -> Self {
        if lists.is_empty() {
            return Self {
                doc_count: 0,
                blocks: Vec::new(),
            };
        }

        let total_blocks: usize = lists.iter().map(|(pl, _)| pl.blocks.len()).sum();
        let total_docs: u32 = lists.iter().map(|(pl, _)| pl.doc_count).sum();

        let mut merged_blocks = Vec::with_capacity(total_blocks);

        for (posting_list, doc_offset) in lists {
            for block in &posting_list.blocks {
                merged_blocks.push(block.with_doc_offset(*doc_offset));
            }
        }

        Self {
            doc_count: total_docs,
            blocks: merged_blocks,
        }
    }

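    /// Finds the index of the block that may contain `target`: the last block
    /// whose `first_doc_id` is <= `target`, or block 0 when `target` precedes
    /// the first block.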
    fn find_block(&self, target: DocId) -> Option<usize> {
        if self.blocks.is_empty() {
            return None;
        }
        let idx = self
            .blocks
            .partition_point(|b| b.header.first_doc_id <= target);
        if idx == 0 {
            Some(0)
        } else {
            Some(idx - 1)
        }
    }
}

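/// Cursor over a [`BlockSparsePostingList`]. Doc ids and weights are decoded
/// one block at a time; ordinals are decoded lazily on first access.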
pub struct BlockSparsePostingIterator<'a> {
    posting_list: &'a BlockSparsePostingList,
    block_idx: usize,
    in_block_idx: usize,
    current_doc_ids: Vec<DocId>,
    current_ordinals: Vec<u16>,
    current_weights: Vec<f32>,
    ordinals_decoded: bool,
    exhausted: bool,
}

impl<'a> BlockSparsePostingIterator<'a> {
    fn new(posting_list: &'a BlockSparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            block_idx: 0,
            in_block_idx: 0,
            current_doc_ids: Vec::with_capacity(128),
            current_ordinals: Vec::with_capacity(128),
            current_weights: Vec::with_capacity(128),
            ordinals_decoded: false,
            exhausted: posting_list.blocks.is_empty(),
        };
        if !iter.exhausted {
            iter.load_block(0);
        }
        iter
    }

    fn load_block(&mut self, block_idx: usize) {
        if let Some(block) = self.posting_list.blocks.get(block_idx) {
            block.decode_doc_ids_into(&mut self.current_doc_ids);
            block.decode_weights_into(&mut self.current_weights);
            // Ordinals are decoded lazily in `ensure_ordinals_decoded`.
            self.ordinals_decoded = false;
            self.block_idx = block_idx;
            self.in_block_idx = 0;
        }
    }

    #[inline]
    fn ensure_ordinals_decoded(&mut self) {
        if !self.ordinals_decoded {
            if let Some(block) = self.posting_list.blocks.get(self.block_idx) {
                block.decode_ordinals_into(&mut self.current_ordinals);
            }
            self.ordinals_decoded = true;
        }
    }

    #[inline]
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            TERMINATED
        } else {
            self.current_doc_ids[self.in_block_idx]
        }
    }

    #[inline]
    pub fn weight(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.current_weights[self.in_block_idx]
    }

    #[inline]
    pub fn ordinal(&mut self) -> u16 {
        if self.exhausted {
            return 0;
        }
        self.ensure_ordinals_decoded();
        self.current_ordinals[self.in_block_idx]
    }

    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        self.in_block_idx += 1;
        if self.in_block_idx >= self.current_doc_ids.len() {
            self.block_idx += 1;
            if self.block_idx >= self.posting_list.blocks.len() {
                self.exhausted = true;
            } else {
                self.load_block(self.block_idx);
            }
        }
        self.doc()
    }

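    /// Advances to the first doc id >= `target` and returns it (TERMINATED if
    /// none exists). Seeking backwards is a no-op: the cursor never moves back.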
    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        if self.doc() >= target {
            return self.doc();
        }

        // Fast path: the target still falls inside the current block.
        if let Some(&last_doc) = self.current_doc_ids.last()
            && last_doc >= target
        {
            let remaining = &self.current_doc_ids[self.in_block_idx..];
            let pos = crate::structures::simd::find_first_ge_u32(remaining, target);
            self.in_block_idx += pos;
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
            return self.doc();
        }

        // Slow path: binary-search the block headers, then scan within.
        if let Some(block_idx) = self.posting_list.find_block(target) {
            self.load_block(block_idx);
            let pos = crate::structures::simd::find_first_ge_u32(&self.current_doc_ids, target);
            self.in_block_idx = pos;
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
        } else {
            self.exhausted = true;
        }
        self.doc()
    }

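    /// Jumps to the first posting of the next block, for block-skipping
    /// scorers that have ruled out the rest of the current block.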
    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        let next = self.block_idx + 1;
        if next >= self.posting_list.blocks.len() {
            self.exhausted = true;
            return TERMINATED;
        }
        self.load_block(next);
        self.doc()
    }

    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }

    pub fn current_block_max_weight(&self) -> f32 {
        self.posting_list
            .blocks
            .get(self.block_idx)
            .map(|b| b.header.max_weight)
            .unwrap_or(0.0)
    }

    pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.current_block_max_weight()
    }
}

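/// Recovers the last doc id of a block from its packed deltas without fully
/// materializing the doc ids: last = first_doc_id + sum(deltas).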
fn compute_last_doc(header: &BlockHeader, doc_ids_data: &[u8]) -> DocId {
    let count = header.count as usize;
    if count <= 1 {
        return header.first_doc_id;
    }
    let bits = header.doc_id_bits;
    if bits == 0 {
        // All deltas are zero, so every doc id equals the first.
        return header.first_doc_id;
    }
    let rounded = simd::RoundedBitWidth::from_u8(bits);
    let num_deltas = count - 1;
    let mut deltas = [0u32; MAX_BLOCK_SIZE];
    simd::unpack_rounded(doc_ids_data, rounded, &mut deltas[..num_deltas], num_deltas);
    let sum: u32 = deltas[..num_deltas].iter().sum();
    header.first_doc_id + sum
}

fn find_optimal_bit_width(values: &[u32]) -> u8 {
    if values.is_empty() {
        return 0;
    }
    let max_val = values.iter().copied().max().unwrap_or(0);
    simd::bits_needed(max_val)
}

fn bits_needed_u16(val: u16) -> u8 {
    if val == 0 {
        0
    } else {
        16 - val.leading_zeros() as u8
    }
}

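/// Encodes weights per the chosen quantization. Float32/Float16 store raw
/// little-endian values; UInt8 and UInt4 store an 8-byte prefix (scale f32,
/// min f32) followed by the quantized values, with UInt4 packing two values
/// per byte, low nibble first. `decode_weights_into` below is the inverse.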
fn encode_weights(weights: &[f32], quant: WeightQuantization) -> io::Result<Vec<u8>> {
    let mut data = Vec::new();
    match quant {
        WeightQuantization::Float32 => {
            for &w in weights {
                data.write_f32::<LittleEndian>(w)?;
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for &w in weights {
                data.write_u16::<LittleEndian>(f16::from_f32(w).to_bits())?;
            }
        }
        WeightQuantization::UInt8 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 255.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            for &w in weights {
                data.write_u8(((w - min) / scale).round() as u8)?;
            }
        }
        WeightQuantization::UInt4 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 15.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            // Pack two 4-bit values per byte, low nibble first.
            let mut i = 0;
            while i < weights.len() {
                let q1 = ((weights[i] - min) / scale).round() as u8 & 0x0F;
                let q2 = if i + 1 < weights.len() {
                    ((weights[i + 1] - min) / scale).round() as u8 & 0x0F
                } else {
                    0
                };
                data.write_u8((q2 << 4) | q1)?;
                i += 2;
            }
        }
    }
    Ok(data)
}

fn decode_weights_into(data: &[u8], quant: WeightQuantization, count: usize, out: &mut Vec<f32>) {
    match quant {
        WeightQuantization::Float32 => {
            out.reserve(count);
            for chunk in data[..count * 4].chunks_exact(4) {
                out.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            use half::slice::HalfFloatSliceExt;
            let byte_count = count * 2;
            let src = &data[..byte_count];
            let mut f16_buf: Vec<f16> = Vec::with_capacity(count);
            for chunk in src.chunks_exact(2) {
                f16_buf.push(f16::from_bits(u16::from_le_bytes([chunk[0], chunk[1]])));
            }
            let start = out.len();
            out.resize(start + count, 0.0);
            f16_buf.convert_to_f32_slice(&mut out[start..start + count]);
        }
        WeightQuantization::UInt8 => {
            let mut cursor = Cursor::new(data);
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min_val = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            let offset = cursor.position() as usize;
            out.resize(count, 0.0);
            simd::dequantize_uint8(&data[offset..], out, scale, min_val, count);
        }
        WeightQuantization::UInt4 => {
            let mut cursor = Cursor::new(data);
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            let mut i = 0;
            while i < count {
                let byte = cursor.read_u8().unwrap_or(0);
                out.push((byte & 0x0F) as f32 * scale + min);
                i += 1;
                if i < count {
                    out.push((byte >> 4) as f32 * scale + min);
                    i += 1;
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_block_roundtrip() {
        let postings = vec![
            (10u32, 0u16, 1.5f32),
            (15, 0, 2.0),
            (20, 1, 0.5),
            (100, 0, 3.0),
        ];
        let block = SparseBlock::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(block.decode_doc_ids(), vec![10, 15, 20, 100]);
        assert_eq!(block.decode_ordinals(), vec![0, 0, 1, 0]);
        let weights = block.decode_weights();
        assert!((weights[0] - 1.5).abs() < 0.01);
    }

    #[test]
    fn test_posting_list() {
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, i as f32 * 0.1)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 300);
        assert_eq!(list.num_blocks(), 3);

        let mut iter = list.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 2);
    }

    #[test]
    fn test_serialization() {
        let postings = vec![(1u32, 0u16, 0.5f32), (10, 1, 1.5), (100, 0, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let (block_data, skip_entries) = list.serialize().unwrap();
        let list2 =
            BlockSparsePostingList::from_parts(list.doc_count(), &block_data, &skip_entries)
                .unwrap();

        assert_eq!(list.doc_count(), list2.doc_count());
    }

    #[test]
    fn test_seek() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = list.iterator();
        assert_eq!(iter.seek(300), 300);
        assert_eq!(iter.seek(301), 303);
        assert_eq!(iter.seek(2000), TERMINATED);
    }

    #[test]
    fn test_merge_with_offsets() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert_eq!(merged.doc_count(), 6);

        let decoded = merged.decode_all();
        assert_eq!(decoded.len(), 6);

        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        // Segment 2 doc ids are shifted by its offset of 100.
        assert_eq!(decoded[3].0, 100);
        assert_eq!(decoded[4].0, 103);
        assert_eq!(decoded[5].0, 107);

        assert!((decoded[0].2 - 1.0).abs() < 0.01);
        assert!((decoded[3].2 - 4.0).abs() < 0.01);

        // Ordinals survive the merge unchanged.
        assert_eq!(decoded[2].1, 1);
        assert_eq!(decoded[4].1, 1);
    }

    #[test]
    fn test_merge_with_offsets_multi_block() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1, "Should have multiple blocks");

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        assert_eq!(merged.doc_count(), 350);
        assert_eq!(merged.num_blocks(), list1.num_blocks() + list2.num_blocks());

        let mut iter = merged.iterator();

        assert_eq!(iter.doc(), 0);

        // Seek across the segment boundary into the rebased range.
        let doc = iter.seek(1000);
        assert_eq!(doc, 1000);
        iter.advance();
        assert_eq!(iter.doc(), 1003);
    }

    #[test]
    fn test_merge_with_offsets_serialize_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        let decoded = loaded.decode_all();
        assert_eq!(decoded.len(), 6);

        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        assert_eq!(decoded[3].0, 100, "First doc of seg2 should be 0+100=100");
        assert_eq!(decoded[4].0, 103, "Second doc of seg2 should be 3+100=103");
        assert_eq!(decoded[5].0, 107, "Third doc of seg2 should be 7+100=107");

        let mut iter = loaded.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 5);
        iter.advance();
        assert_eq!(iter.doc(), 10);
        iter.advance();
        assert_eq!(iter.doc(), 100);
        iter.advance();
        assert_eq!(iter.doc(), 103);
        iter.advance();
        assert_eq!(iter.doc(), 107);
    }

    #[test]
    fn test_merge_seek_after_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, 1.0)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 0, 2.0)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        let mut iter = loaded.iterator();

        let doc = iter.seek(100);
        assert_eq!(doc, 100, "Seek to 100 in segment 1");

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "Seek to 1000 (first doc of segment 2)");

        let doc = iter.seek(1050);
        assert!(
            doc >= 1050,
            "Seek to 1050 should find doc >= 1050, got {}",
            doc
        );

        let doc = iter.seek(500);
        assert!(
            doc >= 1050,
            "Seek backwards should not go back, got {}",
            doc
        );

        let mut iter2 = loaded.iterator();

        let mut count = 0;
        let mut prev_doc = 0;
        while iter2.doc() != super::TERMINATED {
            let current = iter2.doc();
            if count > 0 {
                assert!(
                    current > prev_doc,
                    "Docs should be monotonically increasing: {} vs {}",
                    prev_doc,
                    current
                );
            }
            prev_doc = current;
            iter2.advance();
            count += 1;
        }
        assert_eq!(count, 350, "Should have 350 total docs");
    }

    #[test]
    fn test_doc_count_multi_value() {
        // Three unique docs, each repeated once per ordinal.
        let postings: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 1.0),
            (0, 1, 1.5),
            (0, 2, 2.0),
            (5, 0, 3.0),
            (5, 1, 3.5),
            (10, 0, 4.0),
        ];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 3);

        let decoded = list.decode_all();
        assert_eq!(decoded.len(), 6);
    }

    // Zero-copy merge: concatenate serialized segments, patch each block
    // header's first_doc_id in place, and verify the result matches the
    // reference merge.
    #[test]
    fn test_zero_copy_merge_patches_first_doc_id() {
        use crate::structures::SparseSkipEntry;

        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1);

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let (raw1, skip1) = list1.serialize().unwrap();
        let (raw2, skip2) = list2.serialize().unwrap();

        let doc_offset: u32 = 1000;
        let total_docs = list1.doc_count() + list2.doc_count();

        // Rebase the second segment's skip entries behind the first.
        let mut merged_skip = Vec::new();
        let mut cumulative_offset = 0u64;
        for entry in &skip1 {
            merged_skip.push(SparseSkipEntry::new(
                entry.first_doc,
                entry.last_doc,
                cumulative_offset + entry.offset,
                entry.length,
                entry.max_weight,
            ));
        }
        if let Some(last) = skip1.last() {
            cumulative_offset += last.offset + last.length as u64;
        }
        for entry in &skip2 {
            merged_skip.push(SparseSkipEntry::new(
                entry.first_doc + doc_offset,
                entry.last_doc + doc_offset,
                cumulative_offset + entry.offset,
                entry.length,
                entry.max_weight,
            ));
        }

        let mut merged_block_data = Vec::new();
        merged_block_data.extend_from_slice(&raw1);

        // first_doc_id sits at byte offset 8 within the 16-byte block header.
        const FIRST_DOC_ID_OFFSET: usize = 8;
        let mut buf2 = raw2.to_vec();
        for entry in &skip2 {
            let off = entry.offset as usize + FIRST_DOC_ID_OFFSET;
            if off + 4 <= buf2.len() {
                let old = u32::from_le_bytes(buf2[off..off + 4].try_into().unwrap());
                let patched = (old + doc_offset).to_le_bytes();
                buf2[off..off + 4].copy_from_slice(&patched);
            }
        }
        merged_block_data.extend_from_slice(&buf2);

        let loaded =
            BlockSparsePostingList::from_parts(total_docs, &merged_block_data, &merged_skip)
                .unwrap();
        assert_eq!(loaded.doc_count(), 350);

        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        let doc = iter.seek(100);
        assert_eq!(doc, 100);
        let doc = iter.seek(398);
        assert_eq!(doc, 398);

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "First doc of segment 2 should be 1000");
        iter.advance();
        assert_eq!(iter.doc(), 1003, "Second doc of segment 2 should be 1003");
        let doc = iter.seek(1447);
        assert_eq!(doc, 1447, "Last doc of segment 2 should be 1447");

        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);

        // The patched zero-copy merge must match the reference merge exactly.
        let reference =
            BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, doc_offset)]);
        let mut ref_iter = reference.iterator();
        let mut zc_iter = loaded.iterator();
        while ref_iter.doc() != super::TERMINATED {
            assert_eq!(
                ref_iter.doc(),
                zc_iter.doc(),
                "Zero-copy and reference merge should produce identical doc_ids"
            );
            assert!(
                (ref_iter.weight() - zc_iter.weight()).abs() < 0.01,
                "Weights should match: {} vs {}",
                ref_iter.weight(),
                zc_iter.weight()
            );
            ref_iter.advance();
            zc_iter.advance();
        }
        assert_eq!(zc_iter.doc(), super::TERMINATED);
    }

    #[test]
    fn test_doc_count_single_value() {
        let postings: Vec<(DocId, u16, f32)> =
            vec![(0, 0, 1.0), (5, 0, 2.0), (10, 0, 3.0), (15, 0, 4.0)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 4);
    }

    #[test]
    fn test_doc_count_multi_value_serialization_roundtrip() {
        let postings: Vec<(DocId, u16, f32)> =
            vec![(0, 0, 1.0), (0, 1, 1.5), (5, 0, 2.0), (5, 1, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
        assert_eq!(list.doc_count(), 2);

        let (block_data, skip_entries) = list.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(list.doc_count(), &block_data, &skip_entries)
                .unwrap();
        assert_eq!(loaded.doc_count(), 2);
    }

    #[test]
    fn test_merge_preserves_weights_and_ordinals() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.5), (5, 1, 2.5), (10, 2, 3.5)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.5), (3, 1, 5.5), (7, 3, 6.5)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        assert!(
            (iter.weight() - 1.5).abs() < 0.01,
            "Weight should be 1.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!(
            (iter.weight() - 2.5).abs() < 0.01,
            "Weight should be 2.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 10);
        assert!(
            (iter.weight() - 3.5).abs() < 0.01,
            "Weight should be 3.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 2);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!(
            (iter.weight() - 4.5).abs() < 0.01,
            "Weight should be 4.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 103);
        assert!(
            (iter.weight() - 5.5).abs() < 0.01,
            "Weight should be 5.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 107);
        assert!(
            (iter.weight() - 6.5).abs() < 0.01,
            "Weight should be 6.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 3);

        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);
    }

    #[test]
    fn test_merge_global_max_weight() {
        let postings1: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 3.0),
            (1, 0, 7.0), // global max across both lists
            (2, 0, 2.0),
        ];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 5.0),
            (1, 0, 4.0),
            (2, 0, 6.0), // max of list2
        ];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        assert!((list1.global_max_weight() - 7.0).abs() < 0.01);
        assert!((list2.global_max_weight() - 6.0).abs() < 0.01);

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert!(
            (merged.global_max_weight() - 7.0).abs() < 0.01,
            "Global max should be 7.0, got {}",
            merged.global_max_weight()
        );

        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        assert!(
            (loaded.global_max_weight() - 7.0).abs() < 0.01,
            "After roundtrip, global max should still be 7.0, got {}",
            loaded.global_max_weight()
        );
    }

    #[test]
    fn test_scoring_simulation_after_merge() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 0.5), (5, 0, 0.8)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 0.6), (3, 0, 0.9)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        let query_weight = 2.0f32;
        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.0).abs() < 0.01,
            "Doc 0 score should be 1.0, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 5);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.6).abs() < 0.01,
            "Doc 5 score should be 1.6, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 100);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.2).abs() < 0.01,
            "Doc 100 score should be 1.2, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 103);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.8).abs() < 0.01,
            "Doc 103 score should be 1.8, got {}",
            score
        );
    }
}