1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
9use std::io::{self, Cursor, Read, Write};
10
11use super::config::WeightQuantization;
12use crate::DocId;
13use crate::directories::OwnedBytes;
14use crate::structures::postings::TERMINATED;
15use crate::structures::simd;
16
/// Maximum number of postings encoded into a single [`SparseBlock`].
pub const BLOCK_SIZE: usize = 128;
18
/// Fixed-size descriptor preceding each encoded sparse block.
///
/// Serialized as exactly [`BlockHeader::SIZE`] (16) bytes; see
/// [`BlockHeader::write`] / [`BlockHeader::read`] for the wire layout.
#[derive(Debug, Clone, Copy)]
pub struct BlockHeader {
    // Number of postings in the block (from_postings enforces 1..=BLOCK_SIZE).
    pub count: u16,
    // Rounded bit width used to pack doc-id deltas (0 => all deltas are zero).
    pub doc_id_bits: u8,
    // Rounded bit width used to pack ordinals (0 => all ordinals are zero).
    pub ordinal_bits: u8,
    // Quantization scheme used for this block's weights section.
    pub weight_quant: WeightQuantization,
    // Absolute doc id of the first posting; the rest are delta-coded from it.
    pub first_doc_id: DocId,
    // Largest absolute weight in the block (used for block-max pruning).
    pub max_weight: f32,
}
28
impl BlockHeader {
    /// Serialized size of the header in bytes.
    pub const SIZE: usize = 16;

    /// Writes the fixed 16-byte little-endian layout:
    /// count(u16) | doc_id_bits(u8) | ordinal_bits(u8) | weight_quant(u8) |
    /// reserved u8 = 0 | reserved u16 = 0 | first_doc_id(u32) | max_weight(f32).
    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        w.write_u16::<LittleEndian>(self.count)?;
        w.write_u8(self.doc_id_bits)?;
        w.write_u8(self.ordinal_bits)?;
        w.write_u8(self.weight_quant as u8)?;
        // Reserved padding bytes; keep the header at exactly SIZE bytes.
        w.write_u8(0)?;
        w.write_u16::<LittleEndian>(0)?;
        w.write_u32::<LittleEndian>(self.first_doc_id)?;
        w.write_f32::<LittleEndian>(self.max_weight)?;
        Ok(())
    }

    /// Reads a header written by [`BlockHeader::write`].
    ///
    /// # Errors
    /// Returns `InvalidData` when the quantization byte does not map to a
    /// known [`WeightQuantization`] variant; propagates underlying I/O errors.
    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let count = r.read_u16::<LittleEndian>()?;
        let doc_id_bits = r.read_u8()?;
        let ordinal_bits = r.read_u8()?;
        let weight_quant_byte = r.read_u8()?;
        // Skip the reserved padding bytes.
        let _ = r.read_u8()?;
        let _ = r.read_u16::<LittleEndian>()?;
        let first_doc_id = r.read_u32::<LittleEndian>()?;
        let max_weight = r.read_f32::<LittleEndian>()?;

        let weight_quant = WeightQuantization::from_u8(weight_quant_byte)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid weight quant"))?;

        Ok(Self {
            count,
            doc_id_bits,
            ordinal_bits,
            weight_quant,
            first_doc_id,
            max_weight,
        })
    }
}
67
/// One encoded block of postings: a header plus three packed byte sections.
#[derive(Debug, Clone)]
pub struct SparseBlock {
    pub header: BlockHeader,
    // Bit-packed doc-id deltas for postings 1..count (the first absolute id
    // lives in the header). Empty when doc_id_bits == 0 or count == 1.
    pub doc_ids_data: OwnedBytes,
    // Bit-packed ordinals; empty when ordinal_bits == 0 (all ordinals zero).
    pub ordinals_data: OwnedBytes,
    // Weights encoded according to `header.weight_quant`.
    pub weights_data: OwnedBytes,
}
78
impl SparseBlock {
    /// Builds a block from 1..=[`BLOCK_SIZE`] postings of `(doc_id, ordinal, weight)`.
    ///
    /// Doc ids are delta-coded against the previous posting (the first delta is
    /// forced to 0; the absolute first id is kept in the header) and bit-packed.
    /// Ordinals are bit-packed only when some ordinal is non-zero. Weights are
    /// encoded per `weight_quant`.
    ///
    /// # Panics
    /// Panics if `postings` is empty or longer than `BLOCK_SIZE`.
    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        assert!(!postings.is_empty() && postings.len() <= BLOCK_SIZE);

        let count = postings.len();
        let first_doc_id = postings[0].0;

        // Delta-encode doc ids. NOTE(review): saturating_sub means unsorted
        // input silently becomes 0-deltas instead of erroring — confirm
        // callers always pass doc-id-sorted postings.
        let mut deltas = Vec::with_capacity(count);
        let mut prev = first_doc_id;
        for &(doc_id, _, _) in postings {
            deltas.push(doc_id.saturating_sub(prev));
            prev = doc_id;
        }
        deltas[0] = 0;

        // Bit widths are rounded up to the widths the SIMD packer supports.
        let doc_id_bits = simd::round_bit_width(find_optimal_bit_width(&deltas[1..]));
        let ordinals: Vec<u16> = postings.iter().map(|(_, o, _)| *o).collect();
        let max_ordinal = ordinals.iter().copied().max().unwrap_or(0);
        let ordinal_bits = if max_ordinal == 0 {
            0
        } else {
            simd::round_bit_width(bits_needed_u16(max_ordinal))
        };

        let weights: Vec<f32> = postings.iter().map(|(_, _, w)| *w).collect();
        // Block-max metadata: the largest absolute weight in the block.
        let max_weight = weights
            .iter()
            .copied()
            .fold(0.0f32, |acc, w| acc.max(w.abs()));

        // Pack the count-1 deltas (slot 0 is implicit via first_doc_id).
        let doc_ids_data = OwnedBytes::new({
            let rounded = simd::RoundedBitWidth::from_u8(doc_id_bits);
            let num_deltas = count - 1;
            let byte_count = num_deltas * rounded.bytes_per_value();
            let mut data = vec![0u8; byte_count];
            simd::pack_rounded(&deltas[1..], rounded, &mut data);
            data
        });
        // All-zero ordinals are represented by an empty section.
        let ordinals_data = OwnedBytes::new(if ordinal_bits > 0 {
            let rounded = simd::RoundedBitWidth::from_u8(ordinal_bits);
            let byte_count = count * rounded.bytes_per_value();
            let mut data = vec![0u8; byte_count];
            let ord_u32: Vec<u32> = ordinals.iter().map(|&o| o as u32).collect();
            simd::pack_rounded(&ord_u32, rounded, &mut data);
            data
        } else {
            Vec::new()
        });
        let weights_data = OwnedBytes::new(encode_weights(&weights, weight_quant)?);

        Ok(Self {
            header: BlockHeader {
                count: count as u16,
                doc_id_bits,
                ordinal_bits,
                weight_quant,
                first_doc_id,
                max_weight,
            },
            doc_ids_data,
            ordinals_data,
            weights_data,
        })
    }

    /// Decodes all absolute doc ids into a fresh vector.
    pub fn decode_doc_ids(&self) -> Vec<DocId> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_doc_ids_into(&mut out);
        out
    }

    /// Decodes absolute doc ids into `out` (cleared first): unpacks the
    /// deltas, then prefix-sums from `header.first_doc_id`.
    ///
    /// Assumes `header.count >= 1` (enforced by `from_postings` and
    /// `from_owned_bytes`); a zero count would panic on the `out[0]` write.
    pub fn decode_doc_ids_into(&self, out: &mut Vec<DocId>) {
        let count = self.header.count as usize;
        out.clear();
        out.resize(count, 0);
        out[0] = self.header.first_doc_id;

        if count > 1 {
            let bits = self.header.doc_id_bits;
            if bits == 0 {
                // All deltas are zero: every posting shares the first doc id
                // (multi-value postings for the same document).
                out[1..].fill(self.header.first_doc_id);
            } else {
                simd::unpack_rounded(
                    &self.doc_ids_data,
                    simd::RoundedBitWidth::from_u8(bits),
                    &mut out[1..],
                    count - 1,
                );
                // Prefix-sum the deltas into absolute doc ids.
                for i in 1..count {
                    out[i] += out[i - 1];
                }
            }
        }
    }

    /// Decodes all ordinals into a fresh vector.
    pub fn decode_ordinals(&self) -> Vec<u16> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_ordinals_into(&mut out);
        out
    }

    /// Decodes ordinals into `out` (cleared first); `ordinal_bits == 0`
    /// means every ordinal is zero and nothing was stored.
    pub fn decode_ordinals_into(&self, out: &mut Vec<u16>) {
        let count = self.header.count as usize;
        out.clear();
        if self.header.ordinal_bits == 0 {
            out.resize(count, 0u16);
        } else {
            // Unpack into a fixed u32 scratch buffer, then narrow to u16.
            // NOTE(review): counts > BLOCK_SIZE (corrupt input) would panic
            // on the slice below — confirm upstream validation is sufficient.
            let mut temp = [0u32; BLOCK_SIZE];
            simd::unpack_rounded(
                &self.ordinals_data,
                simd::RoundedBitWidth::from_u8(self.header.ordinal_bits),
                &mut temp[..count],
                count,
            );
            out.reserve(count);
            for &v in &temp[..count] {
                out.push(v as u16);
            }
        }
    }

    /// Decodes all weights into a fresh vector.
    pub fn decode_weights(&self) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.header.count as usize);
        self.decode_weights_into(&mut out);
        out
    }

    /// Decodes weights into `out` (cleared first) per the block's
    /// quantization scheme.
    pub fn decode_weights_into(&self, out: &mut Vec<f32>) {
        out.clear();
        decode_weights_into(
            &self.weights_data,
            self.header.weight_quant,
            self.header.count as usize,
            out,
        );
    }

    /// Decodes weights already multiplied by `query_weight`.
    ///
    /// For UInt8 the multiplication is folded into the dequantization
    /// scale/bias so the SIMD path computes `query_weight * (q*scale + min)`
    /// directly; all other schemes decode first and then scale in place.
    pub fn decode_scored_weights_into(&self, query_weight: f32, out: &mut Vec<f32>) {
        out.clear();
        let count = self.header.count as usize;
        match self.header.weight_quant {
            // Fast path requires at least the 8-byte scale/min prefix.
            WeightQuantization::UInt8 if self.weights_data.len() >= 8 => {
                let scale = f32::from_le_bytes([
                    self.weights_data[0],
                    self.weights_data[1],
                    self.weights_data[2],
                    self.weights_data[3],
                ]);
                let min_val = f32::from_le_bytes([
                    self.weights_data[4],
                    self.weights_data[5],
                    self.weights_data[6],
                    self.weights_data[7],
                ]);
                // Fold query weight into the affine dequantization parameters.
                let eff_scale = query_weight * scale;
                let eff_bias = query_weight * min_val;
                out.resize(count, 0.0);
                simd::dequantize_uint8(&self.weights_data[8..], out, eff_scale, eff_bias, count);
            }
            _ => {
                decode_weights_into(&self.weights_data, self.header.weight_quant, count, out);
                for w in out.iter_mut() {
                    *w *= query_weight;
                }
            }
        }
    }

    /// Serializes the block: header, then three u16 section lengths plus a
    /// reserved u16 = 0, then the raw doc-id / ordinal / weight sections.
    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        self.header.write(w)?;
        w.write_u16::<LittleEndian>(self.doc_ids_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.ordinals_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.weights_data.len() as u16)?;
        // Reserved; keeps the length table 8 bytes.
        w.write_u16::<LittleEndian>(0)?;
        w.write_all(&self.doc_ids_data)?;
        w.write_all(&self.ordinals_data)?;
        w.write_all(&self.weights_data)?;
        Ok(())
    }

    /// Reads a block written by [`SparseBlock::write`], copying each section
    /// into its own owned buffer.
    ///
    /// NOTE(review): section lengths are trusted here (no cross-check against
    /// `header.count`); corrupt input surfaces later during decoding.
    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let header = BlockHeader::read(r)?;
        let doc_ids_len = r.read_u16::<LittleEndian>()? as usize;
        let ordinals_len = r.read_u16::<LittleEndian>()? as usize;
        let weights_len = r.read_u16::<LittleEndian>()? as usize;
        // Reserved u16.
        let _ = r.read_u16::<LittleEndian>()?;

        let mut doc_ids_vec = vec![0u8; doc_ids_len];
        r.read_exact(&mut doc_ids_vec)?;
        let mut ordinals_vec = vec![0u8; ordinals_len];
        r.read_exact(&mut ordinals_vec)?;
        let mut weights_vec = vec![0u8; weights_len];
        r.read_exact(&mut weights_vec)?;

        Ok(Self {
            header,
            doc_ids_data: OwnedBytes::new(doc_ids_vec),
            ordinals_data: OwnedBytes::new(ordinals_vec),
            weights_data: OwnedBytes::new(weights_vec),
        })
    }

    /// Zero-copy variant of [`SparseBlock::read`]: the three sections are
    /// slices of `data` rather than fresh allocations.
    ///
    /// # Errors
    /// Returns `Corruption` when the buffer is smaller than header + length
    /// table, when `count == 0`, or when the declared section lengths run
    /// past the end of the buffer (error messages include a hex dump of the
    /// first 32 bytes to aid debugging).
    pub fn from_owned_bytes(data: crate::directories::OwnedBytes) -> crate::Result<Self> {
        let b = data.as_slice();
        // Minimum: 16-byte header + 8-byte section length table.
        if b.len() < BlockHeader::SIZE + 8 {
            return Err(crate::Error::Corruption(
                "sparse block too small".to_string(),
            ));
        }
        let mut cursor = Cursor::new(&b[..BlockHeader::SIZE]);
        let header =
            BlockHeader::read(&mut cursor).map_err(|e| crate::Error::Corruption(e.to_string()))?;

        // A valid block always holds at least one posting.
        if header.count == 0 {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block has count=0 (data_len={}, first_32_bytes=[{}])",
                b.len(),
                hex
            )));
        }

        // Section length table sits right after the header.
        let p = BlockHeader::SIZE;
        let doc_ids_len = u16::from_le_bytes([b[p], b[p + 1]]) as usize;
        let ordinals_len = u16::from_le_bytes([b[p + 2], b[p + 3]]) as usize;
        let weights_len = u16::from_le_bytes([b[p + 4], b[p + 5]]) as usize;
        let data_start = p + 8;
        let ord_start = data_start + doc_ids_len;
        let wt_start = ord_start + ordinals_len;
        let expected_end = wt_start + weights_len;

        if expected_end > b.len() {
            let hex: String = b
                .iter()
                .take(32)
                .map(|x| format!("{x:02x}"))
                .collect::<Vec<_>>()
                .join(" ");
            return Err(crate::Error::Corruption(format!(
                "sparse block sub-block overflow: count={} doc_ids={}B ords={}B wts={}B need={}B have={}B (first_32=[{}])",
                header.count,
                doc_ids_len,
                ordinals_len,
                weights_len,
                expected_end,
                b.len(),
                hex
            )));
        }

        Ok(Self {
            header,
            doc_ids_data: data.slice(data_start..ord_start),
            ordinals_data: data.slice(ord_start..wt_start),
            weights_data: data.slice(wt_start..wt_start + weights_len),
        })
    }

    /// Returns a copy of the block rebased by `doc_offset`.
    ///
    /// Only `first_doc_id` changes; the packed deltas are relative and the
    /// byte sections are shared (cheap clones of `OwnedBytes`). Used when
    /// merging segments.
    pub fn with_doc_offset(&self, doc_offset: u32) -> Self {
        Self {
            header: BlockHeader {
                first_doc_id: self.header.first_doc_id + doc_offset,
                ..self.header
            },
            doc_ids_data: self.doc_ids_data.clone(),
            ordinals_data: self.ordinals_data.clone(),
            weights_data: self.weights_data.clone(),
        }
    }
}
391
/// A posting list stored as a sequence of independently decodable blocks.
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    // Number of *unique* documents (multi-value postings count once).
    pub doc_count: u32,
    pub blocks: Vec<SparseBlock>,
}
401
impl BlockSparsePostingList {
    /// Builds a posting list by chunking `postings` into blocks of
    /// `block_size` entries (clamped to a minimum of 16).
    ///
    /// `doc_count` counts unique doc ids, so consecutive postings for the
    /// same document (different ordinals) count once.
    pub fn from_postings_with_block_size(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
        block_size: usize,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                doc_count: 0,
                blocks: Vec::new(),
            });
        }

        let block_size = block_size.max(16);
        let mut blocks = Vec::new();
        for chunk in postings.chunks(block_size) {
            blocks.push(SparseBlock::from_postings(chunk, weight_quant)?);
        }

        // Count unique documents; relies on postings being sorted so equal
        // doc ids are adjacent.
        let mut unique_docs = 1u32;
        for i in 1..postings.len() {
            if postings[i].0 != postings[i - 1].0 {
                unique_docs += 1;
            }
        }

        Ok(Self {
            doc_count: unique_docs,
            blocks,
        })
    }

    /// Builds a posting list using the default [`BLOCK_SIZE`].
    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        Self::from_postings_with_block_size(postings, weight_quant, BLOCK_SIZE)
    }

    /// Number of unique documents in the list.
    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    /// Number of encoded blocks.
    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    /// Largest per-block max weight across the whole list.
    pub fn global_max_weight(&self) -> f32 {
        self.blocks
            .iter()
            .map(|b| b.header.max_weight)
            .fold(0.0f32, f32::max)
    }

    /// Max weight of block `block_idx`, or `None` when out of range.
    pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
        self.blocks.get(block_idx).map(|b| b.header.max_weight)
    }

    /// Approximate in-memory footprint in bytes.
    ///
    /// Uses `size_of::<BlockHeader>()` (the in-memory struct size), which
    /// differs from the 16-byte serialized `BlockHeader::SIZE` — fine for an
    /// estimate.
    pub fn size_bytes(&self) -> usize {
        use std::mem::size_of;

        let header_size = size_of::<u32>() * 2;
        let blocks_size: usize = self
            .blocks
            .iter()
            .map(|b| {
                size_of::<BlockHeader>()
                    + b.doc_ids_data.len()
                    + b.ordinals_data.len()
                    + b.weights_data.len()
            })
            .sum();
        header_size + blocks_size
    }

    /// Returns a streaming cursor positioned on the first posting.
    pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
        BlockSparsePostingIterator::new(self)
    }

    /// Serializes all blocks into one contiguous buffer and returns it with
    /// one skip entry per block (first/last doc id, byte offset, byte length,
    /// block max weight).
    ///
    /// `last_doc` is obtained by fully decoding each block's doc ids.
    pub fn serialize(&self) -> io::Result<(Vec<u8>, Vec<super::SparseSkipEntry>)> {
        let mut block_data = Vec::new();
        let mut skip_entries = Vec::with_capacity(self.blocks.len());
        let mut offset = 0u64;

        for block in &self.blocks {
            let mut buf = Vec::new();
            block.write(&mut buf)?;
            let length = buf.len() as u32;

            let first_doc = block.header.first_doc_id;
            let doc_ids = block.decode_doc_ids();
            let last_doc = doc_ids.last().copied().unwrap_or(first_doc);

            skip_entries.push(super::SparseSkipEntry::new(
                first_doc,
                last_doc,
                offset,
                length,
                block.header.max_weight,
            ));

            block_data.extend_from_slice(&buf);
            offset += length as u64;
        }

        Ok((block_data, skip_entries))
    }

    /// Test helper: reconstructs a posting list from the output of
    /// [`BlockSparsePostingList::serialize`].
    #[cfg(test)]
    pub fn from_parts(
        doc_count: u32,
        block_data: &[u8],
        skip_entries: &[super::SparseSkipEntry],
    ) -> io::Result<Self> {
        let mut blocks = Vec::with_capacity(skip_entries.len());
        for entry in skip_entries {
            let start = entry.offset as usize;
            let end = start + entry.length as usize;
            blocks.push(SparseBlock::read(&mut std::io::Cursor::new(
                &block_data[start..end],
            ))?);
        }
        Ok(Self { doc_count, blocks })
    }

    /// Fully decodes every posting as `(doc_id, ordinal, weight)` tuples.
    pub fn decode_all(&self) -> Vec<(DocId, u16, f32)> {
        let total_postings: usize = self.blocks.iter().map(|b| b.header.count as usize).sum();
        let mut result = Vec::with_capacity(total_postings);
        for block in &self.blocks {
            let doc_ids = block.decode_doc_ids();
            let ordinals = block.decode_ordinals();
            let weights = block.decode_weights();
            for i in 0..block.header.count as usize {
                result.push((doc_ids[i], ordinals[i], weights[i]));
            }
        }
        result
    }

    /// Concatenates several posting lists, rebasing each list's doc ids by
    /// its paired offset (cheap: blocks are shared, only headers change).
    ///
    /// Lists are expected in ascending offset order so the merged blocks stay
    /// sorted by `first_doc_id`.
    pub fn merge_with_offsets(lists: &[(&BlockSparsePostingList, u32)]) -> Self {
        if lists.is_empty() {
            return Self {
                doc_count: 0,
                blocks: Vec::new(),
            };
        }

        let total_blocks: usize = lists.iter().map(|(pl, _)| pl.blocks.len()).sum();
        let total_docs: u32 = lists.iter().map(|(pl, _)| pl.doc_count).sum();

        let mut merged_blocks = Vec::with_capacity(total_blocks);

        for (posting_list, doc_offset) in lists {
            for block in &posting_list.blocks {
                merged_blocks.push(block.with_doc_offset(*doc_offset));
            }
        }

        Self {
            doc_count: total_docs,
            blocks: merged_blocks,
        }
    }

    /// Index of the block that may contain `target`: the last block whose
    /// `first_doc_id <= target`, or block 0 when `target` precedes all
    /// blocks. Returns `None` only for an empty list. The chosen block may
    /// still end before `target`; callers continue into later blocks.
    fn find_block(&self, target: DocId) -> Option<usize> {
        if self.blocks.is_empty() {
            return None;
        }
        let idx = self
            .blocks
            .partition_point(|b| b.header.first_doc_id <= target);
        if idx == 0 {
            Some(0)
        } else {
            Some(idx - 1)
        }
    }
}
612
/// Streaming cursor over a [`BlockSparsePostingList`].
///
/// Decodes one block at a time into reusable buffers; ordinals are decoded
/// lazily on first access since many callers never read them.
pub struct BlockSparsePostingIterator<'a> {
    posting_list: &'a BlockSparsePostingList,
    // Index of the currently loaded block.
    block_idx: usize,
    // Position inside the decoded block buffers.
    in_block_idx: usize,
    current_doc_ids: Vec<DocId>,
    current_ordinals: Vec<u16>,
    current_weights: Vec<f32>,
    // Set once `ordinal()` decodes the current block's ordinals.
    ordinals_decoded: bool,
    exhausted: bool,
}
628
impl<'a> BlockSparsePostingIterator<'a> {
    /// Creates the cursor and eagerly decodes the first block (if any).
    fn new(posting_list: &'a BlockSparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            block_idx: 0,
            in_block_idx: 0,
            current_doc_ids: Vec::with_capacity(128),
            current_ordinals: Vec::with_capacity(128),
            current_weights: Vec::with_capacity(128),
            ordinals_decoded: false,
            exhausted: posting_list.blocks.is_empty(),
        };
        if !iter.exhausted {
            iter.load_block(0);
        }
        iter
    }

    /// Decodes block `block_idx` into the reusable buffers and resets the
    /// in-block position. Ordinals stay undecoded until requested.
    fn load_block(&mut self, block_idx: usize) {
        if let Some(block) = self.posting_list.blocks.get(block_idx) {
            block.decode_doc_ids_into(&mut self.current_doc_ids);
            block.decode_weights_into(&mut self.current_weights);
            self.ordinals_decoded = false;
            self.block_idx = block_idx;
            self.in_block_idx = 0;
        }
    }

    /// Lazily decodes the current block's ordinals exactly once.
    #[inline]
    fn ensure_ordinals_decoded(&mut self) {
        if !self.ordinals_decoded {
            if let Some(block) = self.posting_list.blocks.get(self.block_idx) {
                block.decode_ordinals_into(&mut self.current_ordinals);
            }
            self.ordinals_decoded = true;
        }
    }

    /// Current doc id, or [`TERMINATED`] when exhausted.
    #[inline]
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            TERMINATED
        } else {
            self.current_doc_ids[self.in_block_idx]
        }
    }

    /// Current weight, or 0.0 when exhausted.
    #[inline]
    pub fn weight(&self) -> f32 {
        if self.exhausted {
            return 0.0;
        }
        self.current_weights[self.in_block_idx]
    }

    /// Current ordinal, or 0 when exhausted (decodes ordinals on demand).
    #[inline]
    pub fn ordinal(&mut self) -> u16 {
        if self.exhausted {
            return 0;
        }
        self.ensure_ordinals_decoded();
        self.current_ordinals[self.in_block_idx]
    }

    /// Moves to the next posting, loading the next block when the current
    /// one is exhausted; returns the new current doc (or TERMINATED).
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        self.in_block_idx += 1;
        if self.in_block_idx >= self.current_doc_ids.len() {
            self.block_idx += 1;
            if self.block_idx >= self.posting_list.blocks.len() {
                self.exhausted = true;
            } else {
                self.load_block(self.block_idx);
            }
        }
        self.doc()
    }

    /// Advances to the first doc id >= `target` and returns it (or
    /// TERMINATED). Seeking to a doc at or before the current position is a
    /// no-op that returns the current doc.
    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        if self.doc() >= target {
            return self.doc();
        }

        // Fast path: the target lies within the already-decoded block, so a
        // SIMD scan over the remaining ids suffices.
        if let Some(&last_doc) = self.current_doc_ids.last()
            && last_doc >= target
        {
            let remaining = &self.current_doc_ids[self.in_block_idx..];
            let pos = crate::structures::simd::find_first_ge_u32(remaining, target);
            self.in_block_idx += pos;
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
            return self.doc();
        }

        // Slow path: binary-search block headers, decode that block, then
        // scan within it; spill into the following block if needed.
        if let Some(block_idx) = self.posting_list.find_block(target) {
            self.load_block(block_idx);
            let pos = crate::structures::simd::find_first_ge_u32(&self.current_doc_ids, target);
            self.in_block_idx = pos;
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
        } else {
            self.exhausted = true;
        }
        self.doc()
    }

    /// Skips the rest of the current block and lands on the first posting of
    /// the next one (used by block-max pruning); returns TERMINATED when no
    /// blocks remain.
    pub fn skip_to_next_block(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        let next = self.block_idx + 1;
        if next >= self.posting_list.blocks.len() {
            self.exhausted = true;
            return TERMINATED;
        }
        self.load_block(next);
        self.doc()
    }

    /// True once the cursor has moved past the last posting.
    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }

    /// Max weight of the currently loaded block (0.0 when out of range).
    pub fn current_block_max_weight(&self) -> f32 {
        self.posting_list
            .blocks
            .get(self.block_idx)
            .map(|b| b.header.max_weight)
            .unwrap_or(0.0)
    }

    /// Upper bound on this block's contribution for a query term weight.
    pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.current_block_max_weight()
    }
}
789
790fn find_optimal_bit_width(values: &[u32]) -> u8 {
795 if values.is_empty() {
796 return 0;
797 }
798 let max_val = values.iter().copied().max().unwrap_or(0);
799 simd::bits_needed(max_val)
800}
801
/// Number of bits required to represent `val` (0 for 0).
///
/// Branchless form: for `val == 0`, `leading_zeros()` is 16, so the
/// subtraction yields 0 — identical to the original explicit zero check.
fn bits_needed_u16(val: u16) -> u8 {
    (u16::BITS - val.leading_zeros()) as u8
}
809
810fn encode_weights(weights: &[f32], quant: WeightQuantization) -> io::Result<Vec<u8>> {
815 let mut data = Vec::new();
816 match quant {
817 WeightQuantization::Float32 => {
818 for &w in weights {
819 data.write_f32::<LittleEndian>(w)?;
820 }
821 }
822 WeightQuantization::Float16 => {
823 use half::f16;
824 for &w in weights {
825 data.write_u16::<LittleEndian>(f16::from_f32(w).to_bits())?;
826 }
827 }
828 WeightQuantization::UInt8 => {
829 let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
830 let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
831 let range = max - min;
832 let scale = if range < f32::EPSILON {
833 1.0
834 } else {
835 range / 255.0
836 };
837 data.write_f32::<LittleEndian>(scale)?;
838 data.write_f32::<LittleEndian>(min)?;
839 for &w in weights {
840 data.write_u8(((w - min) / scale).round() as u8)?;
841 }
842 }
843 WeightQuantization::UInt4 => {
844 let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
845 let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
846 let range = max - min;
847 let scale = if range < f32::EPSILON {
848 1.0
849 } else {
850 range / 15.0
851 };
852 data.write_f32::<LittleEndian>(scale)?;
853 data.write_f32::<LittleEndian>(min)?;
854 let mut i = 0;
855 while i < weights.len() {
856 let q1 = ((weights[i] - min) / scale).round() as u8 & 0x0F;
857 let q2 = if i + 1 < weights.len() {
858 ((weights[i + 1] - min) / scale).round() as u8 & 0x0F
859 } else {
860 0
861 };
862 data.write_u8((q2 << 4) | q1)?;
863 i += 2;
864 }
865 }
866 }
867 Ok(data)
868}
869
/// Decodes `count` weights from `data` (layout per [`encode_weights`]) into
/// `out` (appended; callers clear first).
///
/// Truncated input is tolerated: failed reads default to 0.0 (or scale 1.0 /
/// min 0.0 for the quantized headers) rather than erroring.
fn decode_weights_into(data: &[u8], quant: WeightQuantization, count: usize, out: &mut Vec<f32>) {
    let mut cursor = Cursor::new(data);
    match quant {
        WeightQuantization::Float32 => {
            for _ in 0..count {
                out.push(cursor.read_f32::<LittleEndian>().unwrap_or(0.0));
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for _ in 0..count {
                let bits = cursor.read_u16::<LittleEndian>().unwrap_or(0);
                out.push(f16::from_bits(bits).to_f32());
            }
        }
        WeightQuantization::UInt8 => {
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min_val = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            // Quantized bytes start right after the 8-byte scale/min prefix;
            // take the offset from the cursor in case the reads fell short.
            let offset = cursor.position() as usize;
            out.resize(count, 0.0);
            simd::dequantize_uint8(&data[offset..], out, scale, min_val, count);
        }
        WeightQuantization::UInt4 => {
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            // Two 4-bit values per byte, low nibble first (matches encoder).
            let mut i = 0;
            while i < count {
                let byte = cursor.read_u8().unwrap_or(0);
                out.push((byte & 0x0F) as f32 * scale + min);
                i += 1;
                if i < count {
                    out.push((byte >> 4) as f32 * scale + min);
                    i += 1;
                }
            }
        }
    }
}
908
909#[cfg(test)]
910mod tests {
911 use super::*;
912
    // Encoding then decoding a single block reproduces ids, ordinals, weights.
    #[test]
    fn test_block_roundtrip() {
        let postings = vec![
            (10u32, 0u16, 1.5f32),
            (15, 0, 2.0),
            (20, 1, 0.5),
            (100, 0, 3.0),
        ];
        let block = SparseBlock::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(block.decode_doc_ids(), vec![10, 15, 20, 100]);
        assert_eq!(block.decode_ordinals(), vec![0, 0, 1, 0]);
        let weights = block.decode_weights();
        assert!((weights[0] - 1.5).abs() < 0.01);
    }
928
    // 300 postings split into ceil(300/128) = 3 blocks; iteration starts at 0.
    #[test]
    fn test_posting_list() {
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, i as f32 * 0.1)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 300);
        assert_eq!(list.num_blocks(), 3);

        let mut iter = list.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 2);
    }
944
    // serialize() + from_parts() round-trips the list (UInt8 quantization).
    #[test]
    fn test_serialization() {
        let postings = vec![(1u32, 0u16, 0.5f32), (10, 1, 1.5), (100, 0, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let (block_data, skip_entries) = list.serialize().unwrap();
        let list2 =
            BlockSparsePostingList::from_parts(list.doc_count(), &block_data, &skip_entries)
                .unwrap();

        assert_eq!(list.doc_count(), list2.doc_count());
    }
958
    // seek() finds exact matches, rounds up to the next doc, and terminates
    // past the end (docs are multiples of 3, so 301 -> 303).
    #[test]
    fn test_seek() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = list.iterator();
        assert_eq!(iter.seek(300), 300);
        assert_eq!(iter.seek(301), 303);
        assert_eq!(iter.seek(2000), TERMINATED);
    }
970
    // Merging two small lists rebases the second segment's doc ids by 100
    // while preserving weights and ordinals.
    #[test]
    fn test_merge_with_offsets() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert_eq!(merged.doc_count(), 6);

        let decoded = merged.decode_all();
        assert_eq!(decoded.len(), 6);

        // Segment 1 keeps its doc ids.
        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        // Segment 2 is rebased by 100; weights carry over unchanged.
        assert_eq!(decoded[3].0, 100);
        assert_eq!(decoded[4].0, 103);
        assert_eq!(decoded[5].0, 107);
        assert!((decoded[0].2 - 1.0).abs() < 0.01);
        assert!((decoded[3].2 - 4.0).abs() < 0.01);

        // Ordinals survive the merge.
        assert_eq!(decoded[2].1, 1);
        assert_eq!(decoded[4].1, 1);
    }
1010
    // Merging multi-block lists keeps all blocks and seeks across the
    // segment boundary (second segment rebased to start at 1000).
    #[test]
    fn test_merge_with_offsets_multi_block() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1, "Should have multiple blocks");

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        assert_eq!(merged.doc_count(), 350);
        assert_eq!(merged.num_blocks(), list1.num_blocks() + list2.num_blocks());

        let mut iter = merged.iterator();

        assert_eq!(iter.doc(), 0);

        // First doc of the rebased second segment.
        let doc = iter.seek(1000);
        assert_eq!(doc, 1000);
        iter.advance();
        assert_eq!(iter.doc(), 1003);
    }
1043
    // A merged list survives serialize() + from_parts() with rebased doc ids
    // intact, both via decode_all() and via the iterator.
    #[test]
    fn test_merge_with_offsets_serialize_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        let decoded = loaded.decode_all();
        assert_eq!(decoded.len(), 6);

        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        assert_eq!(decoded[3].0, 100, "First doc of seg2 should be 0+100=100");
        assert_eq!(decoded[4].0, 103, "Second doc of seg2 should be 3+100=103");
        assert_eq!(decoded[5].0, 107, "Third doc of seg2 should be 7+100=107");

        let mut iter = loaded.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 5);
        iter.advance();
        assert_eq!(iter.doc(), 10);
        iter.advance();
        assert_eq!(iter.doc(), 100);
        iter.advance();
        assert_eq!(iter.doc(), 103);
        iter.advance();
        assert_eq!(iter.doc(), 107);
    }
1092
    // After merge + serialize + reload: seeks work within and across
    // segments, backward seeks don't rewind, and a full scan yields 350
    // monotonically increasing docs.
    #[test]
    fn test_merge_seek_after_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, 1.0)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 0, 2.0)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        let (block_data, skip_entries) = merged.serialize().unwrap();
        let loaded =
            BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
                .unwrap();

        let mut iter = loaded.iterator();

        let doc = iter.seek(100);
        assert_eq!(doc, 100, "Seek to 100 in segment 1");

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "Seek to 1000 (first doc of segment 2)");

        let doc = iter.seek(1050);
        assert!(
            doc >= 1050,
            "Seek to 1050 should find doc >= 1050, got {}",
            doc
        );

        // Backward seek must be a no-op.
        let doc = iter.seek(500);
        assert!(
            doc >= 1050,
            "Seek backwards should not go back, got {}",
            doc
        );

        let mut iter2 = loaded.iterator();

        let mut count = 0;
        let mut prev_doc = 0;
        while iter2.doc() != super::TERMINATED {
            let current = iter2.doc();
            if count > 0 {
                assert!(
                    current > prev_doc,
                    "Docs should be monotonically increasing: {} vs {}",
                    prev_doc,
                    current
                );
            }
            prev_doc = current;
            iter2.advance();
            count += 1;
        }
        assert_eq!(count, 350, "Should have 350 total docs");
    }
1162
    // Multi-value postings (same doc, several ordinals) count as one doc in
    // doc_count() but all appear in decode_all().
    #[test]
    fn test_doc_count_multi_value() {
        let postings: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 1.0),
            (0, 1, 1.5),
            (0, 2, 2.0),
            (5, 0, 3.0),
            (5, 1, 3.5),
            (10, 0, 4.0),
        ];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 3);

        let decoded = list.decode_all();
        assert_eq!(decoded.len(), 6);
    }
1185
    // Zero-copy merge: instead of re-encoding, patch each serialized block's
    // first_doc_id field (at byte offset 8 in the header) in place and rebase
    // the skip entries. The result must be identical to the reference
    // merge_with_offsets() merge.
    #[test]
    fn test_zero_copy_merge_patches_first_doc_id() {
        use crate::structures::SparseSkipEntry;

        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1);

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let (raw1, skip1) = list1.serialize().unwrap();
        let (raw2, skip2) = list2.serialize().unwrap();

        let doc_offset: u32 = 1000;
        let total_docs = list1.doc_count() + list2.doc_count();

        // Rebase segment 2's skip entries by doc_offset and shift its byte
        // offsets past segment 1's data.
        let mut merged_skip = Vec::new();
        let mut cumulative_offset = 0u64;
        for entry in &skip1 {
            merged_skip.push(SparseSkipEntry::new(
                entry.first_doc,
                entry.last_doc,
                cumulative_offset + entry.offset,
                entry.length,
                entry.max_weight,
            ));
        }
        if let Some(last) = skip1.last() {
            cumulative_offset += last.offset + last.length as u64;
        }
        for entry in &skip2 {
            merged_skip.push(SparseSkipEntry::new(
                entry.first_doc + doc_offset,
                entry.last_doc + doc_offset,
                cumulative_offset + entry.offset,
                entry.length,
                entry.max_weight,
            ));
        }

        let mut merged_block_data = Vec::new();
        merged_block_data.extend_from_slice(&raw1);

        // first_doc_id lives 8 bytes into each serialized block header.
        const FIRST_DOC_ID_OFFSET: usize = 8;
        let mut buf2 = raw2.to_vec();
        for entry in &skip2 {
            let off = entry.offset as usize + FIRST_DOC_ID_OFFSET;
            if off + 4 <= buf2.len() {
                let old = u32::from_le_bytes(buf2[off..off + 4].try_into().unwrap());
                let patched = (old + doc_offset).to_le_bytes();
                buf2[off..off + 4].copy_from_slice(&patched);
            }
        }
        merged_block_data.extend_from_slice(&buf2);

        let loaded =
            BlockSparsePostingList::from_parts(total_docs, &merged_block_data, &merged_skip)
                .unwrap();
        assert_eq!(loaded.doc_count(), 350);

        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        let doc = iter.seek(100);
        assert_eq!(doc, 100);
        let doc = iter.seek(398);
        assert_eq!(doc, 398);

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "First doc of segment 2 should be 1000");
        iter.advance();
        assert_eq!(iter.doc(), 1003, "Second doc of segment 2 should be 1003");
        let doc = iter.seek(1447);
        assert_eq!(doc, 1447, "Last doc of segment 2 should be 1447");

        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);

        // Cross-check against the reference (re-encoding) merge.
        let reference =
            BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, doc_offset)]);
        let mut ref_iter = reference.iterator();
        let mut zc_iter = loaded.iterator();
        while ref_iter.doc() != super::TERMINATED {
            assert_eq!(
                ref_iter.doc(),
                zc_iter.doc(),
                "Zero-copy and reference merge should produce identical doc_ids"
            );
            assert!(
                (ref_iter.weight() - zc_iter.weight()).abs() < 0.01,
                "Weights should match: {} vs {}",
                ref_iter.weight(),
                zc_iter.weight()
            );
            ref_iter.advance();
            zc_iter.advance();
        }
        assert_eq!(zc_iter.doc(), super::TERMINATED);
    }
1301
1302 #[test]
1303 fn test_doc_count_single_value() {
1304 let postings: Vec<(DocId, u16, f32)> =
1306 vec![(0, 0, 1.0), (5, 0, 2.0), (10, 0, 3.0), (15, 0, 4.0)];
1307 let list =
1308 BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
1309
1310 assert_eq!(list.doc_count(), 4);
1312 }
1313
1314 #[test]
1315 fn test_doc_count_multi_value_serialization_roundtrip() {
1316 let postings: Vec<(DocId, u16, f32)> =
1318 vec![(0, 0, 1.0), (0, 1, 1.5), (5, 0, 2.0), (5, 1, 2.5)];
1319 let list =
1320 BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();
1321 assert_eq!(list.doc_count(), 2);
1322
1323 let (block_data, skip_entries) = list.serialize().unwrap();
1324 let loaded =
1325 BlockSparsePostingList::from_parts(list.doc_count(), &block_data, &skip_entries)
1326 .unwrap();
1327 assert_eq!(loaded.doc_count(), 2);
1328 }
1329
1330 #[test]
1331 fn test_merge_preserves_weights_and_ordinals() {
1332 let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.5), (5, 1, 2.5), (10, 2, 3.5)];
1334 let list1 =
1335 BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
1336
1337 let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.5), (3, 1, 5.5), (7, 3, 6.5)];
1338 let list2 =
1339 BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
1340
1341 let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);
1343
1344 let (block_data, skip_entries) = merged.serialize().unwrap();
1346 let loaded =
1347 BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
1348 .unwrap();
1349
1350 let mut iter = loaded.iterator();
1352
1353 assert_eq!(iter.doc(), 0);
1355 assert!(
1356 (iter.weight() - 1.5).abs() < 0.01,
1357 "Weight should be 1.5, got {}",
1358 iter.weight()
1359 );
1360 assert_eq!(iter.ordinal(), 0);
1361
1362 iter.advance();
1363 assert_eq!(iter.doc(), 5);
1364 assert!(
1365 (iter.weight() - 2.5).abs() < 0.01,
1366 "Weight should be 2.5, got {}",
1367 iter.weight()
1368 );
1369 assert_eq!(iter.ordinal(), 1);
1370
1371 iter.advance();
1372 assert_eq!(iter.doc(), 10);
1373 assert!(
1374 (iter.weight() - 3.5).abs() < 0.01,
1375 "Weight should be 3.5, got {}",
1376 iter.weight()
1377 );
1378 assert_eq!(iter.ordinal(), 2);
1379
1380 iter.advance();
1382 assert_eq!(iter.doc(), 100);
1383 assert!(
1384 (iter.weight() - 4.5).abs() < 0.01,
1385 "Weight should be 4.5, got {}",
1386 iter.weight()
1387 );
1388 assert_eq!(iter.ordinal(), 0);
1389
1390 iter.advance();
1391 assert_eq!(iter.doc(), 103);
1392 assert!(
1393 (iter.weight() - 5.5).abs() < 0.01,
1394 "Weight should be 5.5, got {}",
1395 iter.weight()
1396 );
1397 assert_eq!(iter.ordinal(), 1);
1398
1399 iter.advance();
1400 assert_eq!(iter.doc(), 107);
1401 assert!(
1402 (iter.weight() - 6.5).abs() < 0.01,
1403 "Weight should be 6.5, got {}",
1404 iter.weight()
1405 );
1406 assert_eq!(iter.ordinal(), 3);
1407
1408 iter.advance();
1410 assert_eq!(iter.doc(), super::TERMINATED);
1411 }
1412
1413 #[test]
1414 fn test_merge_global_max_weight() {
1415 let postings1: Vec<(DocId, u16, f32)> = vec![
1417 (0, 0, 3.0),
1418 (1, 0, 7.0), (2, 0, 2.0),
1420 ];
1421 let list1 =
1422 BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
1423
1424 let postings2: Vec<(DocId, u16, f32)> = vec![
1425 (0, 0, 5.0),
1426 (1, 0, 4.0),
1427 (2, 0, 6.0), ];
1429 let list2 =
1430 BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
1431
1432 assert!((list1.global_max_weight() - 7.0).abs() < 0.01);
1434 assert!((list2.global_max_weight() - 6.0).abs() < 0.01);
1435
1436 let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);
1438
1439 assert!(
1441 (merged.global_max_weight() - 7.0).abs() < 0.01,
1442 "Global max should be 7.0, got {}",
1443 merged.global_max_weight()
1444 );
1445
1446 let (block_data, skip_entries) = merged.serialize().unwrap();
1448 let loaded =
1449 BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
1450 .unwrap();
1451
1452 assert!(
1453 (loaded.global_max_weight() - 7.0).abs() < 0.01,
1454 "After roundtrip, global max should still be 7.0, got {}",
1455 loaded.global_max_weight()
1456 );
1457 }
1458
1459 #[test]
1460 fn test_scoring_simulation_after_merge() {
1461 let postings1: Vec<(DocId, u16, f32)> = vec![
1463 (0, 0, 0.5), (5, 0, 0.8), ];
1466 let list1 =
1467 BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
1468
1469 let postings2: Vec<(DocId, u16, f32)> = vec![
1470 (0, 0, 0.6), (3, 0, 0.9), ];
1473 let list2 =
1474 BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();
1475
1476 let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);
1478
1479 let (block_data, skip_entries) = merged.serialize().unwrap();
1481 let loaded =
1482 BlockSparsePostingList::from_parts(merged.doc_count(), &block_data, &skip_entries)
1483 .unwrap();
1484
1485 let query_weight = 2.0f32;
1487 let mut iter = loaded.iterator();
1488
1489 assert_eq!(iter.doc(), 0);
1492 let score = query_weight * iter.weight();
1493 assert!(
1494 (score - 1.0).abs() < 0.01,
1495 "Doc 0 score should be 1.0, got {}",
1496 score
1497 );
1498
1499 iter.advance();
1500 assert_eq!(iter.doc(), 5);
1502 let score = query_weight * iter.weight();
1503 assert!(
1504 (score - 1.6).abs() < 0.01,
1505 "Doc 5 score should be 1.6, got {}",
1506 score
1507 );
1508
1509 iter.advance();
1510 assert_eq!(iter.doc(), 100);
1512 let score = query_weight * iter.weight();
1513 assert!(
1514 (score - 1.2).abs() < 0.01,
1515 "Doc 100 score should be 1.2, got {}",
1516 score
1517 );
1518
1519 iter.advance();
1520 assert_eq!(iter.doc(), 103);
1522 let score = query_weight * iter.weight();
1523 assert!(
1524 (score - 1.8).abs() < 0.01,
1525 "Doc 103 score should be 1.8, got {}",
1526 score
1527 );
1528 }
1529}