1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
9use std::io::{self, Cursor, Read, Write};
10
11use super::config::WeightQuantization;
12use crate::DocId;
13use crate::structures::postings::TERMINATED;
14use crate::structures::simd;
15
/// Maximum number of postings stored in a single compressed block.
pub const BLOCK_SIZE: usize = 128;
17
/// Fixed-size descriptor for one compressed posting block.
///
/// Serialized by [`BlockHeader::write`] into exactly [`BlockHeader::SIZE`]
/// little-endian bytes.
#[derive(Debug, Clone, Copy)]
pub struct BlockHeader {
    /// Number of postings stored in the block.
    pub count: u16,
    /// Bit width used to pack the doc-id deltas (0 when no deltas are stored).
    pub doc_id_bits: u8,
    /// Bit width used to pack ordinals (0 when every ordinal is zero).
    pub ordinal_bits: u8,
    /// Quantization scheme used for the weights section.
    pub weight_quant: WeightQuantization,
    /// First doc id of the block, stored uncompressed (deltas are relative to it).
    pub first_doc_id: DocId,
    /// Largest weight in the block; used by the iterator for block-max scoring.
    pub max_weight: f32,
}
27
28impl BlockHeader {
29 pub const SIZE: usize = 16;
30
31 pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
32 w.write_u16::<LittleEndian>(self.count)?;
33 w.write_u8(self.doc_id_bits)?;
34 w.write_u8(self.ordinal_bits)?;
35 w.write_u8(self.weight_quant as u8)?;
36 w.write_u8(0)?;
37 w.write_u16::<LittleEndian>(0)?;
38 w.write_u32::<LittleEndian>(self.first_doc_id)?;
39 w.write_f32::<LittleEndian>(self.max_weight)?;
40 Ok(())
41 }
42
43 pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
44 let count = r.read_u16::<LittleEndian>()?;
45 let doc_id_bits = r.read_u8()?;
46 let ordinal_bits = r.read_u8()?;
47 let weight_quant_byte = r.read_u8()?;
48 let _ = r.read_u8()?;
49 let _ = r.read_u16::<LittleEndian>()?;
50 let first_doc_id = r.read_u32::<LittleEndian>()?;
51 let max_weight = r.read_f32::<LittleEndian>()?;
52
53 let weight_quant = WeightQuantization::from_u8(weight_quant_byte)
54 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid weight quant"))?;
55
56 Ok(Self {
57 count,
58 doc_id_bits,
59 ordinal_bits,
60 weight_quant,
61 first_doc_id,
62 max_weight,
63 })
64 }
65}
66
/// One compressed block of postings: a header plus three independently
/// packed byte sections (delta-encoded doc ids, ordinals, quantized weights).
#[derive(Debug, Clone)]
pub struct SparseBlock {
    pub header: BlockHeader,
    /// Bit-packed doc-id deltas for every posting after the first.
    pub doc_ids_data: Vec<u8>,
    /// Bit-packed ordinals; empty when every ordinal in the block is zero.
    pub ordinals_data: Vec<u8>,
    /// Weights encoded according to `header.weight_quant`.
    pub weights_data: Vec<u8>,
}
74
impl SparseBlock {
    /// Compresses up to `BLOCK_SIZE` postings (sorted by ascending doc id)
    /// into a single block.
    ///
    /// Doc ids are delta-encoded (the first id lives in the header) and
    /// bit-packed; ordinals are bit-packed with just enough bits for the
    /// largest value, or omitted entirely when all are zero; weights are
    /// encoded according to `weight_quant`.
    ///
    /// # Panics
    /// Panics if `postings` is empty or holds more than `BLOCK_SIZE` entries.
    pub fn from_postings(
        postings: &[(DocId, u16, f32)],
        weight_quant: WeightQuantization,
    ) -> io::Result<Self> {
        assert!(!postings.is_empty() && postings.len() <= BLOCK_SIZE);

        let count = postings.len();
        let first_doc_id = postings[0].0;

        // Delta-encode doc ids against the previous entry; saturating_sub
        // avoids underflow panics if input is (incorrectly) unsorted.
        let mut deltas = Vec::with_capacity(count);
        let mut prev = first_doc_id;
        for &(doc_id, _, _) in postings {
            deltas.push(doc_id.saturating_sub(prev));
            prev = doc_id;
        }
        // The first entry's delta is always 0 (its id is in the header);
        // forced here as a safeguard.
        deltas[0] = 0;

        // Only deltas[1..] get packed, so size the bit width for them alone.
        let doc_id_bits = find_optimal_bit_width(&deltas[1..]);
        let ordinals: Vec<u16> = postings.iter().map(|(_, o, _)| *o).collect();
        let max_ordinal = ordinals.iter().copied().max().unwrap_or(0);
        let ordinal_bits = if max_ordinal == 0 {
            0
        } else {
            bits_needed_u16(max_ordinal)
        };

        let weights: Vec<f32> = postings.iter().map(|(_, _, w)| *w).collect();
        let max_weight = weights.iter().copied().fold(0.0f32, f32::max);

        let doc_ids_data = pack_bit_array(&deltas[1..], doc_id_bits);
        // ordinal_bits == 0 means "all ordinals are zero": store nothing.
        let ordinals_data = if ordinal_bits > 0 {
            pack_bit_array_u16(&ordinals, ordinal_bits)
        } else {
            Vec::new()
        };
        let weights_data = encode_weights(&weights, weight_quant)?;

        Ok(Self {
            header: BlockHeader {
                count: count as u16,
                doc_id_bits,
                ordinal_bits,
                weight_quant,
                first_doc_id,
                max_weight,
            },
            doc_ids_data,
            ordinals_data,
            weights_data,
        })
    }

    /// Decompresses the block's doc ids: the first comes from the header, the
    /// rest are reconstructed by prefix-summing the packed deltas.
    pub fn decode_doc_ids(&self) -> Vec<DocId> {
        let count = self.header.count as usize;
        let mut doc_ids = Vec::with_capacity(count);
        doc_ids.push(self.header.first_doc_id);

        if count > 1 {
            let deltas = unpack_bit_array(&self.doc_ids_data, self.header.doc_id_bits, count - 1);
            let mut prev = self.header.first_doc_id;
            for delta in deltas {
                prev += delta;
                doc_ids.push(prev);
            }
        }
        doc_ids
    }

    /// Decompresses the block's ordinals. `ordinal_bits == 0` encodes the
    /// common case where every ordinal is zero and no bytes were stored.
    pub fn decode_ordinals(&self) -> Vec<u16> {
        let count = self.header.count as usize;
        if self.header.ordinal_bits == 0 {
            vec![0u16; count]
        } else {
            unpack_bit_array_u16(&self.ordinals_data, self.header.ordinal_bits, count)
        }
    }

    /// Decompresses the block's weights according to the header's
    /// quantization scheme.
    pub fn decode_weights(&self) -> Vec<f32> {
        decode_weights(
            &self.weights_data,
            self.header.weight_quant,
            self.header.count as usize,
        )
    }

    /// Serializes the block: 16-byte header, then the three section lengths
    /// as u16 plus one u16 of padding, then the raw data sections.
    /// Lengths fit in u16 because a block holds at most `BLOCK_SIZE` postings.
    pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
        self.header.write(w)?;
        w.write_u16::<LittleEndian>(self.doc_ids_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.ordinals_data.len() as u16)?;
        w.write_u16::<LittleEndian>(self.weights_data.len() as u16)?;
        w.write_u16::<LittleEndian>(0)?;
        w.write_all(&self.doc_ids_data)?;
        w.write_all(&self.ordinals_data)?;
        w.write_all(&self.weights_data)?;
        Ok(())
    }

    /// Reads a block previously produced by [`SparseBlock::write`].
    ///
    /// # Errors
    /// Propagates I/O errors and `InvalidData` from the header parse.
    pub fn read<R: Read>(r: &mut R) -> io::Result<Self> {
        let header = BlockHeader::read(r)?;
        let doc_ids_len = r.read_u16::<LittleEndian>()? as usize;
        let ordinals_len = r.read_u16::<LittleEndian>()? as usize;
        let weights_len = r.read_u16::<LittleEndian>()? as usize;
        let _ = r.read_u16::<LittleEndian>()?;

        let mut doc_ids_data = vec![0u8; doc_ids_len];
        r.read_exact(&mut doc_ids_data)?;
        let mut ordinals_data = vec![0u8; ordinals_len];
        r.read_exact(&mut ordinals_data)?;
        let mut weights_data = vec![0u8; weights_len];
        r.read_exact(&mut weights_data)?;

        Ok(Self {
            header,
            doc_ids_data,
            ordinals_data,
            weights_data,
        })
    }
}
196
/// Block-compressed sparse posting list: `(doc_id, ordinal, weight)` entries
/// grouped into [`SparseBlock`]s of up to `BLOCK_SIZE` postings each.
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    /// Total number of postings across all blocks.
    pub doc_count: u32,
    pub blocks: Vec<SparseBlock>,
}
206
207impl BlockSparsePostingList {
208 pub fn from_postings_with_block_size(
210 postings: &[(DocId, u16, f32)],
211 weight_quant: WeightQuantization,
212 block_size: usize,
213 ) -> io::Result<Self> {
214 if postings.is_empty() {
215 return Ok(Self {
216 doc_count: 0,
217 blocks: Vec::new(),
218 });
219 }
220
221 let block_size = block_size.max(16); let mut blocks = Vec::new();
223 for chunk in postings.chunks(block_size) {
224 blocks.push(SparseBlock::from_postings(chunk, weight_quant)?);
225 }
226
227 Ok(Self {
228 doc_count: postings.len() as u32,
229 blocks,
230 })
231 }
232
233 pub fn from_postings(
235 postings: &[(DocId, u16, f32)],
236 weight_quant: WeightQuantization,
237 ) -> io::Result<Self> {
238 Self::from_postings_with_block_size(postings, weight_quant, BLOCK_SIZE)
239 }
240
241 pub fn doc_count(&self) -> u32 {
242 self.doc_count
243 }
244
245 pub fn num_blocks(&self) -> usize {
246 self.blocks.len()
247 }
248
249 pub fn global_max_weight(&self) -> f32 {
250 self.blocks
251 .iter()
252 .map(|b| b.header.max_weight)
253 .fold(0.0f32, f32::max)
254 }
255
256 pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
257 self.blocks.get(block_idx).map(|b| b.header.max_weight)
258 }
259
260 pub fn size_bytes(&self) -> usize {
262 use std::mem::size_of;
263
264 let header_size = size_of::<u32>() * 2; let blocks_size: usize = self
266 .blocks
267 .iter()
268 .map(|b| {
269 size_of::<BlockHeader>()
270 + b.doc_ids_data.len()
271 + b.ordinals_data.len()
272 + b.weights_data.len()
273 })
274 .sum();
275 header_size + blocks_size
276 }
277
278 pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
279 BlockSparsePostingIterator::new(self)
280 }
281
282 pub fn serialize<W: Write>(&self, w: &mut W) -> io::Result<()> {
283 w.write_u32::<LittleEndian>(self.doc_count)?;
284 w.write_u32::<LittleEndian>(self.blocks.len() as u32)?;
285 for block in &self.blocks {
286 block.write(w)?;
287 }
288 Ok(())
289 }
290
291 pub fn deserialize<R: Read>(r: &mut R) -> io::Result<Self> {
292 let doc_count = r.read_u32::<LittleEndian>()?;
293 let num_blocks = r.read_u32::<LittleEndian>()? as usize;
294 let mut blocks = Vec::with_capacity(num_blocks);
295 for _ in 0..num_blocks {
296 blocks.push(SparseBlock::read(r)?);
297 }
298 Ok(Self { doc_count, blocks })
299 }
300
301 pub fn decode_all(&self) -> Vec<(DocId, u16, f32)> {
302 let mut result = Vec::with_capacity(self.doc_count as usize);
303 for block in &self.blocks {
304 let doc_ids = block.decode_doc_ids();
305 let ordinals = block.decode_ordinals();
306 let weights = block.decode_weights();
307 for i in 0..block.header.count as usize {
308 result.push((doc_ids[i], ordinals[i], weights[i]));
309 }
310 }
311 result
312 }
313
314 fn find_block(&self, target: DocId) -> Option<usize> {
315 let mut lo = 0;
316 let mut hi = self.blocks.len();
317 while lo < hi {
318 let mid = lo + (hi - lo) / 2;
319 let block = &self.blocks[mid];
320 let doc_ids = block.decode_doc_ids();
321 let last_doc = doc_ids.last().copied().unwrap_or(block.header.first_doc_id);
322 if last_doc < target {
323 lo = mid + 1;
324 } else {
325 hi = mid;
326 }
327 }
328 if lo < self.blocks.len() {
329 Some(lo)
330 } else {
331 None
332 }
333 }
334}
335
/// Cursor-style iterator over a [`BlockSparsePostingList`] exposing the
/// classic postings-iterator interface: `doc()`, `advance()`, `seek()`.
pub struct BlockSparsePostingIterator<'a> {
    posting_list: &'a BlockSparsePostingList,
    // Index of the block currently decoded into the caches below.
    block_idx: usize,
    // Cursor position within the current block.
    in_block_idx: usize,
    // Decoded doc ids of the current block.
    current_doc_ids: Vec<DocId>,
    // Decoded weights of the current block.
    current_weights: Vec<f32>,
    // True once the cursor has moved past the last posting.
    exhausted: bool,
}
348
349impl<'a> BlockSparsePostingIterator<'a> {
350 fn new(posting_list: &'a BlockSparsePostingList) -> Self {
351 let mut iter = Self {
352 posting_list,
353 block_idx: 0,
354 in_block_idx: 0,
355 current_doc_ids: Vec::new(),
356 current_weights: Vec::new(),
357 exhausted: posting_list.blocks.is_empty(),
358 };
359 if !iter.exhausted {
360 iter.load_block(0);
361 }
362 iter
363 }
364
365 fn load_block(&mut self, block_idx: usize) {
366 if let Some(block) = self.posting_list.blocks.get(block_idx) {
367 self.current_doc_ids = block.decode_doc_ids();
368 self.current_weights = block.decode_weights();
369 self.block_idx = block_idx;
370 self.in_block_idx = 0;
371 }
372 }
373
374 pub fn doc(&self) -> DocId {
375 if self.exhausted {
376 TERMINATED
377 } else {
378 self.current_doc_ids
379 .get(self.in_block_idx)
380 .copied()
381 .unwrap_or(TERMINATED)
382 }
383 }
384
385 pub fn weight(&self) -> f32 {
386 self.current_weights
387 .get(self.in_block_idx)
388 .copied()
389 .unwrap_or(0.0)
390 }
391
392 pub fn ordinal(&self) -> u16 {
393 if let Some(block) = self.posting_list.blocks.get(self.block_idx) {
394 let ordinals = block.decode_ordinals();
395 ordinals.get(self.in_block_idx).copied().unwrap_or(0)
396 } else {
397 0
398 }
399 }
400
401 pub fn advance(&mut self) -> DocId {
402 if self.exhausted {
403 return TERMINATED;
404 }
405 self.in_block_idx += 1;
406 if self.in_block_idx >= self.current_doc_ids.len() {
407 self.block_idx += 1;
408 if self.block_idx >= self.posting_list.blocks.len() {
409 self.exhausted = true;
410 } else {
411 self.load_block(self.block_idx);
412 }
413 }
414 self.doc()
415 }
416
417 pub fn seek(&mut self, target: DocId) -> DocId {
418 if self.exhausted {
419 return TERMINATED;
420 }
421 if self.doc() >= target {
422 return self.doc();
423 }
424
425 if let Some(&last_doc) = self.current_doc_ids.last()
427 && last_doc >= target
428 {
429 while !self.exhausted && self.doc() < target {
430 self.in_block_idx += 1;
431 if self.in_block_idx >= self.current_doc_ids.len() {
432 self.block_idx += 1;
433 if self.block_idx >= self.posting_list.blocks.len() {
434 self.exhausted = true;
435 } else {
436 self.load_block(self.block_idx);
437 }
438 }
439 }
440 return self.doc();
441 }
442
443 if let Some(block_idx) = self.posting_list.find_block(target) {
445 self.load_block(block_idx);
446 while self.in_block_idx < self.current_doc_ids.len()
447 && self.current_doc_ids[self.in_block_idx] < target
448 {
449 self.in_block_idx += 1;
450 }
451 if self.in_block_idx >= self.current_doc_ids.len() {
452 self.block_idx += 1;
453 if self.block_idx >= self.posting_list.blocks.len() {
454 self.exhausted = true;
455 } else {
456 self.load_block(self.block_idx);
457 }
458 }
459 } else {
460 self.exhausted = true;
461 }
462 self.doc()
463 }
464
465 pub fn is_exhausted(&self) -> bool {
466 self.exhausted
467 }
468
469 pub fn current_block_max_weight(&self) -> f32 {
470 self.posting_list
471 .blocks
472 .get(self.block_idx)
473 .map(|b| b.header.max_weight)
474 .unwrap_or(0.0)
475 }
476
477 pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
478 query_weight * self.current_block_max_weight()
479 }
480}
481
482fn find_optimal_bit_width(values: &[u32]) -> u8 {
487 if values.is_empty() {
488 return 0;
489 }
490 let max_val = values.iter().copied().max().unwrap_or(0);
491 simd::bits_needed(max_val)
492}
493
/// Number of bits required to represent `val` (0 for `val == 0`).
fn bits_needed_u16(val: u16) -> u8 {
    // leading_zeros(0) == 16, so the subtraction naturally yields 0 for 0 —
    // no explicit zero branch needed.
    (u16::BITS - val.leading_zeros()) as u8
}
501
/// Bit-packs `values` LSB-first, `bits` bits per value.
///
/// Returns an empty vector when `bits == 0` or there are no values. Values
/// wider than `bits` bits are truncated to their low `bits` bits.
fn pack_bit_array(values: &[u32], bits: u8) -> Vec<u8> {
    if bits == 0 || values.is_empty() {
        return Vec::new();
    }
    // Build the mask in u64 so bits == 32 yields u32::MAX. The previous
    // `(1u32 << bits) - 1` overflowed the shift for 32-bit values: panic in
    // debug builds, and a zero mask (all deltas written as 0) in release.
    let mask = ((1u64 << bits) - 1) as u32;
    let total_bytes = (values.len() * bits as usize).div_ceil(8);
    let mut result = vec![0u8; total_bytes];
    let mut bit_pos = 0usize;
    for &val in values {
        pack_value(&mut result, bit_pos, val & mask, bits);
        bit_pos += bits as usize;
    }
    result
}

/// u16 variant of [`pack_bit_array`]; callers guarantee `bits <= 16`.
fn pack_bit_array_u16(values: &[u16], bits: u8) -> Vec<u8> {
    if bits == 0 || values.is_empty() {
        return Vec::new();
    }
    // bits <= 16 here, so this u32 shift cannot overflow. Hoisted out of the
    // loop (it was previously recomputed per element).
    let mask = (1u32 << bits) - 1;
    let total_bytes = (values.len() * bits as usize).div_ceil(8);
    let mut result = vec![0u8; total_bytes];
    let mut bit_pos = 0usize;
    for &val in values {
        pack_value(&mut result, bit_pos, (val as u32) & mask, bits);
        bit_pos += bits as usize;
    }
    result
}

/// Writes the low `bits` bits of `val` into `data` starting at absolute bit
/// offset `bit_pos` (LSB-first within each byte). `data` must already be
/// zeroed over the target range and large enough to hold the bits.
#[inline]
fn pack_value(data: &mut [u8], bit_pos: usize, val: u32, bits: u8) {
    let mut remaining = bits as usize;
    let mut val = val;
    let mut byte = bit_pos / 8;
    let mut offset = bit_pos % 8;
    while remaining > 0 {
        // Bits that still fit into the current byte.
        let space = 8 - offset;
        let to_write = remaining.min(space);
        // to_write <= 8, so this u32 shift is always in range.
        let mask = (1u32 << to_write) - 1;
        data[byte] |= ((val & mask) as u8) << offset;
        val >>= to_write;
        remaining -= to_write;
        byte += 1;
        offset = 0;
    }
}
552
/// Unpacks `count` values of `bits` bits each (LSB-first) from `data`.
///
/// `bits == 0` means every value is 0. Bytes missing from `data` are read
/// as zero rather than panicking.
fn unpack_bit_array(data: &[u8], bits: u8, count: usize) -> Vec<u32> {
    if bits == 0 || count == 0 {
        return vec![0; count];
    }
    let mut result = Vec::with_capacity(count);
    let mut bit_pos = 0usize;
    for _ in 0..count {
        result.push(unpack_value(data, bit_pos, bits));
        bit_pos += bits as usize;
    }
    result
}

/// u16 variant of [`unpack_bit_array`]; callers guarantee `bits <= 16`.
fn unpack_bit_array_u16(data: &[u8], bits: u8, count: usize) -> Vec<u16> {
    if bits == 0 || count == 0 {
        return vec![0; count];
    }
    let mut result = Vec::with_capacity(count);
    let mut bit_pos = 0usize;
    for _ in 0..count {
        result.push(unpack_value(data, bit_pos, bits) as u16);
        bit_pos += bits as usize;
    }
    result
}

/// Reads `bits` bits starting at absolute bit offset `bit_pos`, mirroring
/// the layout written by `pack_value`. Out-of-range bytes read as 0.
#[inline]
fn unpack_value(data: &[u8], bit_pos: usize, bits: u8) -> u32 {
    let mut val = 0u32;
    let mut remaining = bits as usize;
    let mut byte = bit_pos / 8;
    let mut offset = bit_pos % 8;
    let mut shift = 0;
    while remaining > 0 {
        let space = 8 - offset;
        let to_read = remaining.min(space);
        // Build the mask in u16: to_read can be a full 8 bits, and the old
        // `(1u8 << to_read) - 1` overflowed the shift in that case — panic in
        // debug builds, and a zero mask (value bits silently dropped) in
        // release — for any field of 8+ bits read at a byte boundary.
        let mask = ((1u16 << to_read) - 1) as u8;
        val |= (((data.get(byte).copied().unwrap_or(0) >> offset) & mask) as u32) << shift;
        remaining -= to_read;
        shift += to_read;
        byte += 1;
        offset = 0;
    }
    val
}
598
/// Encodes `weights` according to `quant`.
///
/// * `Float32` / `Float16`: raw little-endian values, 4 / 2 bytes each.
/// * `UInt8` / `UInt4`: linear quantization — an 8-byte prefix stores the
///   f32 `scale` and `min`, followed by `round((w - min) / scale)` per
///   weight; `UInt4` packs two 4-bit values per byte, low nibble first.
fn encode_weights(weights: &[f32], quant: WeightQuantization) -> io::Result<Vec<u8>> {
    let mut data = Vec::new();
    match quant {
        WeightQuantization::Float32 => {
            for &w in weights {
                data.write_f32::<LittleEndian>(w)?;
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for &w in weights {
                data.write_u16::<LittleEndian>(f16::from_f32(w).to_bits())?;
            }
        }
        WeightQuantization::UInt8 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            // Degenerate range (all weights equal): scale 1.0 makes every
            // quantized value 0, which decodes back to `min`.
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 255.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            for &w in weights {
                data.write_u8(((w - min) / scale).round() as u8)?;
            }
        }
        WeightQuantization::UInt4 => {
            let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let range = max - min;
            let scale = if range < f32::EPSILON {
                1.0
            } else {
                range / 15.0
            };
            data.write_f32::<LittleEndian>(scale)?;
            data.write_f32::<LittleEndian>(min)?;
            // Pack two 4-bit values per byte: even index in the low nibble,
            // odd index in the high nibble (decode_weights reads low first).
            let mut i = 0;
            while i < weights.len() {
                let q1 = ((weights[i] - min) / scale).round() as u8 & 0x0F;
                let q2 = if i + 1 < weights.len() {
                    ((weights[i + 1] - min) / scale).round() as u8 & 0x0F
                } else {
                    0
                };
                data.write_u8((q2 << 4) | q1)?;
                i += 2;
            }
        }
    }
    Ok(data)
}
658
/// Decodes `count` weights produced by `encode_weights` with the same
/// `quant`.
///
/// Truncated input does not error: the `unwrap_or` defaults make missing
/// values decode as 0.0 (Float32/Float16) or `min` (quantized formats).
fn decode_weights(data: &[u8], quant: WeightQuantization, count: usize) -> Vec<f32> {
    let mut cursor = Cursor::new(data);
    let mut weights = Vec::with_capacity(count);
    match quant {
        WeightQuantization::Float32 => {
            for _ in 0..count {
                weights.push(cursor.read_f32::<LittleEndian>().unwrap_or(0.0));
            }
        }
        WeightQuantization::Float16 => {
            use half::f16;
            for _ in 0..count {
                let bits = cursor.read_u16::<LittleEndian>().unwrap_or(0);
                weights.push(f16::from_bits(bits).to_f32());
            }
        }
        WeightQuantization::UInt8 => {
            // 8-byte prefix: scale then min (see encode_weights).
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            for _ in 0..count {
                let q = cursor.read_u8().unwrap_or(0);
                weights.push(q as f32 * scale + min);
            }
        }
        WeightQuantization::UInt4 => {
            let scale = cursor.read_f32::<LittleEndian>().unwrap_or(1.0);
            let min = cursor.read_f32::<LittleEndian>().unwrap_or(0.0);
            // Two values per byte: low nibble first, matching encode_weights.
            let mut i = 0;
            while i < count {
                let byte = cursor.read_u8().unwrap_or(0);
                weights.push((byte & 0x0F) as f32 * scale + min);
                i += 1;
                if i < count {
                    weights.push((byte >> 4) as f32 * scale + min);
                    i += 1;
                }
            }
        }
    }
    weights
}
700
#[cfg(test)]
mod tests {
    use super::*;

    /// A single block round-trips doc ids, ordinals, and Float32 weights.
    #[test]
    fn test_block_roundtrip() {
        let postings = vec![
            (10u32, 0u16, 1.5f32),
            (15, 0, 2.0),
            (20, 1, 0.5),
            (100, 0, 3.0),
        ];
        let block = SparseBlock::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(block.decode_doc_ids(), vec![10, 15, 20, 100]);
        assert_eq!(block.decode_ordinals(), vec![0, 0, 1, 0]);
        let weights = block.decode_weights();
        assert!((weights[0] - 1.5).abs() < 0.01);
    }

    /// 300 postings split into ceil(300 / BLOCK_SIZE) = 3 blocks and the
    /// iterator walks them in order.
    #[test]
    fn test_posting_list() {
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, i as f32 * 0.1)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 300);
        assert_eq!(list.num_blocks(), 3);

        let mut iter = list.iterator();
        assert_eq!(iter.doc(), 0);
        iter.advance();
        assert_eq!(iter.doc(), 2);
    }

    /// serialize/deserialize round-trip with lossy UInt8 quantization.
    #[test]
    fn test_serialization() {
        let postings = vec![(1u32, 0u16, 0.5f32), (10, 1, 1.5), (100, 0, 2.5)];
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let mut buf = Vec::new();
        list.serialize(&mut buf).unwrap();
        let list2 = BlockSparsePostingList::deserialize(&mut Cursor::new(&buf)).unwrap();

        assert_eq!(list.doc_count(), list2.doc_count());
    }

    /// seek lands on exact matches, the next larger doc id, and TERMINATED
    /// past the end (doc ids are 0, 3, 6, ..., 1497).
    #[test]
    fn test_seek() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();
        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = list.iterator();
        assert_eq!(iter.seek(300), 300);
        assert_eq!(iter.seek(301), 303);
        assert_eq!(iter.seek(2000), TERMINATED);
    }
}