//! Block-compressed sparse posting lists: doc IDs are delta-encoded and
//! bit-packed per block, ordinals are bit-packed, and weights are stored
//! according to the configured `WeightQuantization`.

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Cursor, Read, Write};

use super::config::WeightQuantization;
use crate::DocId;
use crate::structures::postings::TERMINATED;
use crate::structures::simd;

pub const BLOCK_SIZE: usize = 128;

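/// Fixed-size, 16-byte header describing one compressed block of postings.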
#[derive(Debug, Clone, Copy)]
pub struct BlockHeader {
    pub count: u16,
    pub doc_id_bits: u8,
    pub ordinal_bits: u8,
    pub weight_quant: WeightQuantization,
    pub first_doc_id: DocId,
    pub max_weight: f32,
}

impl BlockHeader {
    pub const SIZE: usize = 16;

    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        w.write_u16::<LittleEndian>(self.count)?;
        w.write_u8(self.doc_id_bits)?;
        w.write_u8(self.ordinal_bits)?;
        w.write_u8(self.weight_quant as u8)?;
        w.write_u8(0)?;
        w.write_u16::<LittleEndian>(0)?;
        w.write_u32::<LittleEndian>(self.first_doc_id)?;
        w.write_f32::<LittleEndian>(self.max_weight)?;
        Ok(())
    }

    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let count = r.read_u16::<LittleEndian>()?;
        let doc_id_bits = r.read_u8()?;
        let ordinal_bits = r.read_u8()?;
        let weight_quant_byte = r.read_u8()?;
        let _ = r.read_u8()?;
        let _ = r.read_u16::<LittleEndian>()?;
        let first_doc_id = r.read_u32::<LittleEndian>()?;
        let max_weight = r.read_f32::<LittleEndian>()?;

        let weight_quant = WeightQuantization::from_u8(weight_quant_byte)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid weight quant"))?;

        Ok(Self {
            count,
            doc_id_bits,
            ordinal_bits,
            weight_quant,
            first_doc_id,
            max_weight,
        })
    }
}

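/// A single block of up to `BLOCK_SIZE` postings: delta-encoded, bit-packed
/// doc IDs, bit-packed ordinals, and weights encoded per `WeightQuantization`.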
#[derive(Debug, Clone)]
pub struct SparseBlock {
    pub header: BlockHeader,
    pub doc_ids_data: Vec<u8>,
    pub ordinals_data: Vec<u8>,
    pub weights_data: Vec<u8>,
}

impl SparseBlock {
    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        assert!(!postings.is_empty() && postings.len() <= BLOCK_SIZE);

        let count = postings.len();
        let first_doc_id = postings[0].0;

        let mut deltas = Vec::with_capacity(count);
        let mut prev = first_doc_id;
        for &(doc_id, _, _) in postings {
            deltas.push(doc_id.saturating_sub(prev));
            prev = doc_id;
        }
        deltas[0] = 0;

        let doc_id_bits = find_optimal_bit_width(&deltas[1..]);
        let ordinals: Vec<u16> = postings.iter().map(|(_, o, _)| *o).collect();
        let max_ordinal = ordinals.iter().copied().max().unwrap_or(0);
        let ordinal_bits = if max_ordinal == 0 {
            0
        } else {
            bits_needed_u16(max_ordinal)
        };

        let weights: Vec<f32> = postings.iter().map(|(_, _, w)| *w).collect();
        let max_weight = weights.iter().copied().fold(0.0f32, f32::max);

        let doc_ids_data = pack_bit_array(&deltas[1..], doc_id_bits);
        let ordinals_data = if ordinal_bits > 0 {
            pack_bit_array_u16(&ordinals, ordinal_bits)
        } else {
            Vec::new()
        };
        let weights_data = encode_weights(&weights, weight_quant)?;

        Ok(Self {
            header: BlockHeader {
                count: count as u16,
                doc_id_bits,
                ordinal_bits,
                weight_quant,
                first_doc_id,
                max_weight,
            },
            doc_ids_data,
            ordinals_data,
            weights_data,
        })
    }

    pub fn decode_doc_ids(&self) -> Vec<DocId> {
        let count = self.header.count as usize;
        let mut doc_ids = Vec::with_capacity(count);
        doc_ids.push(self.header.first_doc_id);

        if count > 1 {
            let deltas = unpack_bit_array(&self.doc_ids_data, self.header.doc_id_bits, count - 1);
            let mut prev = self.header.first_doc_id;
            for delta in deltas {
                prev += delta;
                doc_ids.push(prev);
            }
        }
        doc_ids
    }

    pub fn decode_ordinals(&self) -> Vec<u16> {
        let count = self.header.count as usize;
        if self.header.ordinal_bits == 0 {
            vec![0u16; count]
        } else {
            unpack_bit_array_u16(&self.ordinals_data, self.header.ordinal_bits, count)
        }
    }

    pub fn decode_weights(&self) -> Vec<f32> {
        decode_weights(
            &self.weights_data,
            self.header.weight_quant,
            self.header.count as usize,
        )
    }

    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        self.header.write(w)?;
        w.write_u16::<LittleEndian>(self.doc_ids_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.ordinals_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.weights_data.len() as u16)?;
        w.write_u16::<LittleEndian>(0)?;
        w.write_all(&self.doc_ids_data)?;
        w.write_all(&self.ordinals_data)?;
        w.write_all(&self.weights_data)?;
        Ok(())
    }

    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let header = BlockHeader::read(r)?;
        let doc_ids_len = r.read_u16::<LittleEndian>()? as usize;
        let ordinals_len = r.read_u16::<LittleEndian>()? as usize;
        let weights_len = r.read_u16::<LittleEndian>()? as usize;
        let _ = r.read_u16::<LittleEndian>()?;

        let mut doc_ids_data = vec![0u8; doc_ids_len];
        r.read_exact(&mut doc_ids_data)?;
        let mut ordinals_data = vec![0u8; ordinals_len];
        r.read_exact(&mut ordinals_data)?;
        let mut weights_data = vec![0u8; weights_len];
        r.read_exact(&mut weights_data)?;

        Ok(Self {
            header,
            doc_ids_data,
            ordinals_data,
            weights_data,
        })
    }

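    /// Returns a copy of this block with `doc_offset` added to its starting doc ID.
    /// The packed delta data is reused unchanged, so the shift applies to every posting.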
    pub fn with_doc_offset(&self, doc_offset: u32) -> Self {
        Self {
            header: BlockHeader {
                first_doc_id: self.header.first_doc_id + doc_offset,
                ..self.header
            },
            doc_ids_data: self.doc_ids_data.clone(),
            ordinals_data: self.ordinals_data.clone(),
            weights_data: self.weights_data.clone(),
        }
    }
}

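/// A posting list stored as a sequence of compressed `SparseBlock`s, with
/// per-block maximum weights available for block-level pruning.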
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    pub doc_count: u32,
    pub blocks: Vec<SparseBlock>,
}

impl BlockSparsePostingList {
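    /// Builds a posting list from `(doc_id, ordinal, weight)` postings sorted by doc ID,
    /// splitting them into blocks of `block_size` (clamped to a minimum of 16).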
    pub fn from_postings_with_block_size(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
        block_size: usize,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                doc_count: 0,
                blocks: Vec::new(),
            });
        }

        let block_size = block_size.max(16);
        let mut blocks = Vec::new();
        for chunk in postings.chunks(block_size) {
            blocks.push(SparseBlock::from_postings(chunk, weight_quant)?);
        }

        Ok(Self {
            doc_count: postings.len() as u32,
            blocks,
        })
    }

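    /// Builds a posting list using the default `BLOCK_SIZE`.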
    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        Self::from_postings_with_block_size(postings, weight_quant, BLOCK_SIZE)
    }

    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    pub fn global_max_weight(&self) -> f32 {
        self.blocks
            .iter()
            .map(|b| b.header.max_weight)
            .fold(0.0f32, f32::max)
    }

    pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
        self.blocks.get(block_idx).map(|b| b.header.max_weight)
    }

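    /// Approximate size of this posting list in bytes.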
    pub fn size_bytes(&self) -> usize {
        use std::mem::size_of;

        let header_size = size_of::<u32>() * 2;
        let blocks_size: usize = self
            .blocks
            .iter()
            .map(|b| {
                size_of::<BlockHeader>()
                    + b.doc_ids_data.len()
                    + b.ordinals_data.len()
                    + b.weights_data.len()
            })
            .sum();
        header_size + blocks_size
    }

    pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
        BlockSparsePostingIterator::new(self)
    }

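    /// Serializes the list as: doc_count (u32), global max weight (f32), block count (u32),
    /// one `SparseSkipEntry` per block, then the encoded blocks back to back.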
    pub fn serialize<W: Write>(&self, w: &mut W) -> io::Result<()> {
        use super::SparseSkipEntry;

        w.write_u32::<LittleEndian>(self.doc_count)?;
        w.write_f32::<LittleEndian>(self.global_max_weight())?;
        w.write_u32::<LittleEndian>(self.blocks.len() as u32)?;

        let mut block_bytes: Vec<Vec<u8>> = Vec::with_capacity(self.blocks.len());
        for block in &self.blocks {
            let mut buf = Vec::new();
            block.write(&mut buf)?;
            block_bytes.push(buf);
        }

        let mut offset = 0u32;
        for (block, bytes) in self.blocks.iter().zip(block_bytes.iter()) {
            let doc_ids = block.decode_doc_ids();
            let first_doc = doc_ids.first().copied().unwrap_or(0);
            let last_doc = doc_ids.last().copied().unwrap_or(0);
            let length = bytes.len() as u32;

            let entry =
                SparseSkipEntry::new(first_doc, last_doc, offset, length, block.header.max_weight);
            entry.write(w)?;
            offset += length;
        }

        for bytes in block_bytes {
            w.write_all(&bytes)?;
        }

        Ok(())
    }

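    /// Reads back a full posting list written by `serialize`, loading every block into memory.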
    pub fn deserialize<R: Read>(r: &mut R) -> io::Result<Self> {
        use super::SparseSkipEntry;

        let doc_count = r.read_u32::<LittleEndian>()?;
        let _global_max_weight = r.read_f32::<LittleEndian>()?;
        let num_blocks = r.read_u32::<LittleEndian>()? as usize;

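        // The skip entries are only needed for block-level random access; when
        // reading the whole list we just consume and discard them.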
        for _ in 0..num_blocks {
            let _ = SparseSkipEntry::read(r)?;
        }

        let mut blocks = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            blocks.push(SparseBlock::read(r)?);
        }
        Ok(Self { doc_count, blocks })
    }

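    /// Reads only the list-level header and skip entries, returning
    /// `(doc_count, global_max_weight, skip_entries, header_size_in_bytes)`.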
    pub fn deserialize_header<R: Read>(
        r: &mut R,
    ) -> io::Result<(u32, f32, Vec<super::SparseSkipEntry>, usize)> {
        use super::SparseSkipEntry;

        let doc_count = r.read_u32::<LittleEndian>()?;
        let global_max_weight = r.read_f32::<LittleEndian>()?;
        let num_blocks = r.read_u32::<LittleEndian>()? as usize;

        let mut entries = Vec::with_capacity(num_blocks);
        for _ in 0..num_blocks {
            entries.push(SparseSkipEntry::read(r)?);
        }

        let header_size = 4 + 4 + 4 + num_blocks * SparseSkipEntry::SIZE;

        Ok((doc_count, global_max_weight, entries, header_size))
    }

    pub fn decode_all(&self) -> Vec<(DocId, u16, f32)> {
        let mut result = Vec::with_capacity(self.doc_count as usize);
        for block in &self.blocks {
            let doc_ids = block.decode_doc_ids();
            let ordinals = block.decode_ordinals();
            let weights = block.decode_weights();
            for i in 0..block.header.count as usize {
                result.push((doc_ids[i], ordinals[i], weights[i]));
            }
        }
        result
    }

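    /// Concatenates posting lists from multiple segments, shifting each list's doc IDs
    /// by its segment offset. Blocks are reused as-is, so the inputs should be given in
    /// ascending offset order to keep the merged list sorted.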
    pub fn merge_with_offsets(lists: &[(&BlockSparsePostingList, u32)]) -> Self {
        if lists.is_empty() {
            return Self {
                doc_count: 0,
                blocks: Vec::new(),
            };
        }

        let total_blocks: usize = lists.iter().map(|(pl, _)| pl.blocks.len()).sum();
        let total_docs: u32 = lists.iter().map(|(pl, _)| pl.doc_count).sum();

        let mut merged_blocks = Vec::with_capacity(total_blocks);

        for (posting_list, doc_offset) in lists {
            for block in &posting_list.blocks {
                merged_blocks.push(block.with_doc_offset(*doc_offset));
            }
        }

        Self {
            doc_count: total_docs,
            blocks: merged_blocks,
        }
    }

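    /// Binary search for the first block whose last doc ID is >= `target`.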
    fn find_block(&self, target: DocId) -> Option<usize> {
        let mut lo = 0;
        let mut hi = self.blocks.len();
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            let block = &self.blocks[mid];
            let doc_ids = block.decode_doc_ids();
            let last_doc = doc_ids.last().copied().unwrap_or(block.header.first_doc_id);
            if last_doc < target {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        if lo < self.blocks.len() {
            Some(lo)
        } else {
            None
        }
    }
}

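/// Forward iterator over a `BlockSparsePostingList` that decodes one block at a time
/// and exposes per-block maximum weights for scoring.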
pub struct BlockSparsePostingIterator<'a> {
    posting_list: &'a BlockSparsePostingList,
    block_idx: usize,
    in_block_idx: usize,
    current_doc_ids: Vec<DocId>,
    current_weights: Vec<f32>,
    exhausted: bool,
}

impl<'a> BlockSparsePostingIterator<'a> {
    fn new(posting_list: &'a BlockSparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            block_idx: 0,
            in_block_idx: 0,
            current_doc_ids: Vec::new(),
            current_weights: Vec::new(),
            exhausted: posting_list.blocks.is_empty(),
        };
        if !iter.exhausted {
            iter.load_block(0);
        }
        iter
    }

    fn load_block(&mut self, block_idx: usize) {
        if let Some(block) = self.posting_list.blocks.get(block_idx) {
            self.current_doc_ids = block.decode_doc_ids();
            self.current_weights = block.decode_weights();
            self.block_idx = block_idx;
            self.in_block_idx = 0;
        }
    }

    pub fn doc(&self) -> DocId {
        if self.exhausted {
            TERMINATED
        } else {
            self.current_doc_ids
                .get(self.in_block_idx)
                .copied()
                .unwrap_or(TERMINATED)
        }
    }

    pub fn weight(&self) -> f32 {
        self.current_weights
            .get(self.in_block_idx)
            .copied()
            .unwrap_or(0.0)
    }

    pub fn ordinal(&self) -> u16 {
        if let Some(block) = self.posting_list.blocks.get(self.block_idx) {
            let ordinals = block.decode_ordinals();
            ordinals.get(self.in_block_idx).copied().unwrap_or(0)
        } else {
            0
        }
    }

    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        self.in_block_idx += 1;
        if self.in_block_idx >= self.current_doc_ids.len() {
            self.block_idx += 1;
            if self.block_idx >= self.posting_list.blocks.len() {
                self.exhausted = true;
            } else {
                self.load_block(self.block_idx);
            }
        }
        self.doc()
    }

    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return TERMINATED;
        }
        if self.doc() >= target {
            return self.doc();
        }

        if let Some(&last_doc) = self.current_doc_ids.last()
            && last_doc >= target
        {
            while !self.exhausted && self.doc() < target {
                self.in_block_idx += 1;
                if self.in_block_idx >= self.current_doc_ids.len() {
                    self.block_idx += 1;
                    if self.block_idx >= self.posting_list.blocks.len() {
                        self.exhausted = true;
                    } else {
                        self.load_block(self.block_idx);
                    }
                }
            }
            return self.doc();
        }

        if let Some(block_idx) = self.posting_list.find_block(target) {
            self.load_block(block_idx);
            while self.in_block_idx < self.current_doc_ids.len()
                && self.current_doc_ids[self.in_block_idx] < target
            {
                self.in_block_idx += 1;
            }
            if self.in_block_idx >= self.current_doc_ids.len() {
                self.block_idx += 1;
                if self.block_idx >= self.posting_list.blocks.len() {
                    self.exhausted = true;
                } else {
                    self.load_block(self.block_idx);
                }
            }
        } else {
            self.exhausted = true;
        }
        self.doc()
    }

    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }

    pub fn current_block_max_weight(&self) -> f32 {
        self.posting_list
            .blocks
            .get(self.block_idx)
            .map(|b| b.header.max_weight)
            .unwrap_or(0.0)
    }

    pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.current_block_max_weight()
    }
}

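/// Number of bits required to represent the largest value in `values` (0 if empty).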
fn find_optimal_bit_width(values: &[u32]) -> u8 {
    if values.is_empty() {
        return 0;
    }
    let max_val = values.iter().copied().max().unwrap_or(0);
    simd::bits_needed(max_val)
}

fn bits_needed_u16(val: u16) -> u8 {
    if val == 0 {
        0
    } else {
        16 - val.leading_zeros() as u8
    }
}

fn pack_bit_array(values: &[u32], bits: u8) -> Vec<u8> {
    if bits == 0 || values.is_empty() {
        return Vec::new();
    }
    let total_bytes = (values.len() * bits as usize).div_ceil(8);
    let mut result = vec![0u8; total_bytes];
    let mut bit_pos = 0usize;
    for &val in values {
        // Mask defensively; `1u32 << 32` would overflow when bits == 32, so skip
        // the (redundant) mask in that case.
        let masked = if bits >= 32 { val } else { val & ((1u32 << bits) - 1) };
        pack_value(&mut result, bit_pos, masked, bits);
        bit_pos += bits as usize;
    }
    result
}

fn pack_bit_array_u16(values: &[u16], bits: u8) -> Vec<u8> {
    if bits == 0 || values.is_empty() {
        return Vec::new();
    }
    let total_bytes = (values.len() * bits as usize).div_ceil(8);
    let mut result = vec![0u8; total_bytes];
    let mut bit_pos = 0usize;
    for &val in values {
        pack_value(
            &mut result,
            bit_pos,
            (val as u32) & ((1u32 << bits) - 1),
            bits,
        );
        bit_pos += bits as usize;
    }
    result
}

#[inline]
fn pack_value(data: &mut [u8], bit_pos: usize, val: u32, bits: u8) {
    let mut remaining = bits as usize;
    let mut val = val;
    let mut byte = bit_pos / 8;
    let mut offset = bit_pos % 8;
    while remaining > 0 {
        let space = 8 - offset;
        let to_write = remaining.min(space);
        let mask = (1u32 << to_write) - 1;
        data[byte] |= ((val & mask) as u8) << offset;
        val >>= to_write;
        remaining -= to_write;
        byte += 1;
        offset = 0;
    }
}

fn unpack_bit_array(data: &[u8], bits: u8, count: usize) -> Vec<u32> {
    if bits == 0 || count == 0 {
        return vec![0; count];
    }
    let mut result = Vec::with_capacity(count);
    let mut bit_pos = 0usize;
    for _ in 0..count {
        result.push(unpack_value(data, bit_pos, bits));
        bit_pos += bits as usize;
    }
    result
}

fn unpack_bit_array_u16(data: &[u8], bits: u8, count: usize) -> Vec<u16> {
    if bits == 0 || count == 0 {
        return vec![0; count];
    }
    let mut result = Vec::with_capacity(count);
    let mut bit_pos = 0usize;
    for _ in 0..count {
        result.push(unpack_value(data, bit_pos, bits) as u16);
        bit_pos += bits as usize;
    }
    result
}

#[inline]
fn unpack_value(data: &[u8], bit_pos: usize, bits: u8) -> u32 {
    let mut val = 0u32;
    let mut remaining = bits as usize;
    let mut byte = bit_pos / 8;
    let mut offset = bit_pos % 8;
    let mut shift = 0;
    while remaining > 0 {
        let space = 8 - offset;
        let to_read = remaining.min(space);
        // Compute the mask in u32: `1u8 << 8` would overflow when reading a full byte.
        let mask = ((1u32 << to_read) - 1) as u8;
        val |= (((data.get(byte).copied().unwrap_or(0) >> offset) & mask) as u32) << shift;
        remaining -= to_read;
        shift += to_read;
        byte += 1;
        offset = 0;
    }
    val
}

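/// Encodes weights according to `quant`. The scalar-quantized variants (UInt8, UInt4)
/// store a per-block `scale` and `min` (two f32s) ahead of the packed values.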
fn encode_weights(weights: &[f32], quant: WeightQuantization) -> io::Result<Vec<u8>> {
    let mut data = Vec::new();
    match quant {
        WeightQuantization::Float32 => {
            for &w in weights {
                data.write_f32::<LittleEndian>(w)?;
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for &w in weights {
                data.write_u16::<LittleEndian>(f16::from_f32(w).to_bits())?;
            }
        }
        WeightQuantization::UInt8 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 255.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            for &w in weights {
                data.write_u8(((w - min) / scale).round() as u8)?;
            }
        }
        WeightQuantization::UInt4 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 15.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            let mut i = 0;
            while i < weights.len() {
                let q1 = ((weights[i] - min) / scale).round() as u8 & 0x0F;
                let q2 = if i + 1 < weights.len() {
                    ((weights[i + 1] - min) / scale).round() as u8 & 0x0F
                } else {
                    0
                };
                data.write_u8((q2 << 4) | q1)?;
                i += 2;
            }
        }
    }
    Ok(data)
}

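/// Decodes `count` weights previously written by `encode_weights`.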
fn decode_weights(data: &[u8], quant: WeightQuantization, count: usize) -> Vec<f32> {
    let mut cursor = Cursor::new(data);
    let mut weights = Vec::with_capacity(count);
    match quant {
        WeightQuantization::Float32 => {
            for _ in 0..count {
                weights.push(cursor.read_f32::<LittleEndian>().unwrap_or(0.0));
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for _ in 0..count {
                let bits = cursor.read_u16::<LittleEndian>().unwrap_or(0);
                weights.push(f16::from_bits(bits).to_f32());
            }
        }
        WeightQuantization::UInt8 => {
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            for _ in 0..count {
                let q = cursor.read_u8().unwrap_or(0);
                weights.push(q as f32 * scale + min);
            }
        }
        WeightQuantization::UInt4 => {
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            let mut i = 0;
            while i < count {
                let byte = cursor.read_u8().unwrap_or(0);
                weights.push((byte & 0x0F) as f32 * scale + min);
                i += 1;
                if i < count {
                    weights.push((byte >> 4) as f32 * scale + min);
                    i += 1;
                }
            }
        }
    }
    weights
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_block_roundtrip() {
        let postings = vec![
            (10u32, 0u16, 1.5f32),
            (15, 0, 2.0),
            (20, 1, 0.5),
            (100, 0, 3.0),
        ];
        let block = SparseBlock::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(block.decode_doc_ids(), vec![10, 15, 20, 100]);
        assert_eq!(block.decode_ordinals(), vec![0, 0, 1, 0]);
        let weights = block.decode_weights();
        assert!((weights[0] - 1.5).abs() < 0.01);
    }

    #[test]
    fn test_posting_list() {
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, i as f32 * 0.1)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 300);
        assert_eq!(list.num_blocks(), 3);

        let mut iter = list.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 2);
    }

    #[test]
    fn test_serialization() {
        let postings = vec![(1u32, 0u16, 0.5f32), (10, 1, 1.5), (100, 0, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let mut buf = Vec::new();
        list.serialize(&mut buf).unwrap();
        let list2 = BlockSparsePostingList::deserialize(&mut Cursor::new(&buf)).unwrap();

        assert_eq!(list.doc_count(), list2.doc_count());
    }

    #[test]
    fn test_seek() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = list.iterator();
        assert_eq!(iter.seek(300), 300);
        assert_eq!(iter.seek(301), 303);
        assert_eq!(iter.seek(2000), TERMINATED);
    }

    #[test]
    fn test_merge_with_offsets() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert_eq!(merged.doc_count(), 6);

        let decoded = merged.decode_all();
        assert_eq!(decoded.len(), 6);

        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        assert_eq!(decoded[3].0, 100);
        assert_eq!(decoded[4].0, 103);
        assert_eq!(decoded[5].0, 107);

        assert!((decoded[0].2 - 1.0).abs() < 0.01);
        assert!((decoded[3].2 - 4.0).abs() < 0.01);

        assert_eq!(decoded[2].1, 1);
        assert_eq!(decoded[4].1, 1);
    }

    #[test]
    fn test_merge_with_offsets_multi_block() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, i as f32)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        assert!(list1.num_blocks() > 1, "Should have multiple blocks");

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 1, i as f32)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        assert_eq!(merged.doc_count(), 350);
        assert_eq!(merged.num_blocks(), list1.num_blocks() + list2.num_blocks());

        let mut iter = merged.iterator();

        assert_eq!(iter.doc(), 0);

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000);
        iter.advance();
        assert_eq!(iter.doc(), 1003);
    }

    #[test]
    fn test_merge_with_offsets_serialize_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 0, 2.0), (10, 1, 3.0)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.0), (3, 1, 5.0), (7, 0, 6.0)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();

        let mut cursor = std::io::Cursor::new(&bytes);
        let loaded = BlockSparsePostingList::deserialize(&mut cursor).unwrap();

        let decoded = loaded.decode_all();
        assert_eq!(decoded.len(), 6);

        assert_eq!(decoded[0].0, 0);
        assert_eq!(decoded[1].0, 5);
        assert_eq!(decoded[2].0, 10);

        assert_eq!(decoded[3].0, 100, "First doc of seg2 should be 0+100=100");
        assert_eq!(decoded[4].0, 103, "Second doc of seg2 should be 3+100=103");
        assert_eq!(decoded[5].0, 107, "Third doc of seg2 should be 7+100=107");

        let mut iter = loaded.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 5);
        iter.advance();
        assert_eq!(iter.doc(), 10);
        iter.advance();
        assert_eq!(iter.doc(), 100);
        iter.advance();
        assert_eq!(iter.doc(), 103);
        iter.advance();
        assert_eq!(iter.doc(), 107);
    }

    #[test]
    fn test_merge_seek_after_roundtrip() {
        let postings1: Vec<(DocId, u16, f32)> = (0..200).map(|i| (i * 2, 0, 1.0)).collect();
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = (0..150).map(|i| (i * 3, 0, 2.0)).collect();
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 1000)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        let mut iter = loaded.iterator();

        let doc = iter.seek(100);
        assert_eq!(doc, 100, "Seek to 100 in segment 1");

        let doc = iter.seek(1000);
        assert_eq!(doc, 1000, "Seek to 1000 (first doc of segment 2)");

        let doc = iter.seek(1050);
        assert!(
            doc >= 1050,
            "Seek to 1050 should find doc >= 1050, got {}",
            doc
        );

        let doc = iter.seek(500);
        assert!(
            doc >= 1050,
            "Seek backwards should not go back, got {}",
            doc
        );

        let mut iter2 = loaded.iterator();

        let mut count = 0;
        let mut prev_doc = 0;
        while iter2.doc() != super::TERMINATED {
            let current = iter2.doc();
            if count > 0 {
                assert!(
                    current > prev_doc,
                    "Docs should be monotonically increasing: {} vs {}",
                    prev_doc,
                    current
                );
            }
            prev_doc = current;
            iter2.advance();
            count += 1;
        }
        assert_eq!(count, 350, "Should have 350 total docs");
    }

    #[test]
    fn test_merge_preserves_weights_and_ordinals() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.5), (5, 1, 2.5), (10, 2, 3.5)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 4.5), (3, 1, 5.5), (7, 3, 6.5)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        assert!(
            (iter.weight() - 1.5).abs() < 0.01,
            "Weight should be 1.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!(
            (iter.weight() - 2.5).abs() < 0.01,
            "Weight should be 2.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 10);
        assert!(
            (iter.weight() - 3.5).abs() < 0.01,
            "Weight should be 3.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 2);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!(
            (iter.weight() - 4.5).abs() < 0.01,
            "Weight should be 4.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 0);

        iter.advance();
        assert_eq!(iter.doc(), 103);
        assert!(
            (iter.weight() - 5.5).abs() < 0.01,
            "Weight should be 5.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 1);

        iter.advance();
        assert_eq!(iter.doc(), 107);
        assert!(
            (iter.weight() - 6.5).abs() < 0.01,
            "Weight should be 6.5, got {}",
            iter.weight()
        );
        assert_eq!(iter.ordinal(), 3);

        iter.advance();
        assert_eq!(iter.doc(), super::TERMINATED);
    }

    #[test]
    fn test_merge_global_max_weight() {
        let postings1: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 3.0),
            (1, 0, 7.0),
            (2, 0, 2.0),
        ];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![
            (0, 0, 5.0),
            (1, 0, 4.0),
            (2, 0, 6.0),
        ];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        assert!((list1.global_max_weight() - 7.0).abs() < 0.01);
        assert!((list2.global_max_weight() - 6.0).abs() < 0.01);

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        assert!(
            (merged.global_max_weight() - 7.0).abs() < 0.01,
            "Global max should be 7.0, got {}",
            merged.global_max_weight()
        );

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        assert!(
            (loaded.global_max_weight() - 7.0).abs() < 0.01,
            "After roundtrip, global max should still be 7.0, got {}",
            loaded.global_max_weight()
        );
    }

    #[test]
    fn test_scoring_simulation_after_merge() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 0.5), (5, 0, 0.8)];
        let list1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();

        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 0.6), (3, 0, 0.9)];
        let list2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::merge_with_offsets(&[(&list1, 0), (&list2, 100)]);

        let mut bytes = Vec::new();
        merged.serialize(&mut bytes).unwrap();
        let loaded =
            BlockSparsePostingList::deserialize(&mut std::io::Cursor::new(&bytes)).unwrap();

        let query_weight = 2.0f32;
        let mut iter = loaded.iterator();

        assert_eq!(iter.doc(), 0);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.0).abs() < 0.01,
            "Doc 0 score should be 1.0, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 5);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.6).abs() < 0.01,
            "Doc 5 score should be 1.6, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 100);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.2).abs() < 0.01,
            "Doc 100 score should be 1.2, got {}",
            score
        );

        iter.advance();
        assert_eq!(iter.doc(), 103);
        let score = query_weight * iter.weight();
        assert!(
            (score - 1.8).abs() < 0.01,
            "Doc 103 score should be 1.8, got {}",
            score
        );
    }
}