//! Sparse vector representation and posting lists for learned sparse retrieval
//! (e.g. SPLADE-style models): configurable dimension-index width and weight
//! quantization, delta-encoded document ids, and a block-structured variant
//! that stores per-block maximum weights for score-based block skipping.

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use serde::{Deserialize, Serialize};
use std::io::{self, Read, Write};

use super::posting_common::{
    RoundedBitWidth, pack_deltas_fixed, read_vint, unpack_deltas_fixed, write_vint,
};
use crate::DocId;

/// Width of the dimension index stored for each sparse entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum IndexSize {
    /// 16-bit dimension ids (up to 65,535 dimensions).
    U16 = 0,
    /// 32-bit dimension ids (default).
    #[default]
    U32 = 1,
}

impl IndexSize {
    /// Number of bytes used per dimension index.
    pub fn bytes(&self) -> usize {
        match self {
            IndexSize::U16 => 2,
            IndexSize::U32 => 4,
        }
    }

    /// Largest dimension id representable at this width.
    pub fn max_value(&self) -> u32 {
        match self {
            IndexSize::U16 => u16::MAX as u32,
            IndexSize::U32 => u32::MAX,
        }
    }

    fn from_u8(v: u8) -> Option<Self> {
        match v {
            0 => Some(IndexSize::U16),
            1 => Some(IndexSize::U32),
            _ => None,
        }
    }
}

/// How per-document weights are stored on disk.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum WeightQuantization {
    /// Full-precision 32-bit floats (default).
    #[default]
    Float32 = 0,
    /// IEEE 754 half-precision floats.
    Float16 = 1,
    /// 8-bit linear quantization between the list's min and max weight.
    UInt8 = 2,
    /// 4-bit linear quantization, two weights packed per byte.
    UInt4 = 3,
}

impl WeightQuantization {
    /// Average number of bytes used per stored weight.
    pub fn bytes_per_weight(&self) -> f32 {
        match self {
            WeightQuantization::Float32 => 4.0,
            WeightQuantization::Float16 => 2.0,
            WeightQuantization::UInt8 => 1.0,
            WeightQuantization::UInt4 => 0.5,
        }
    }

    fn from_u8(v: u8) -> Option<Self> {
        match v {
            0 => Some(WeightQuantization::Float32),
            1 => Some(WeightQuantization::Float16),
            2 => Some(WeightQuantization::UInt8),
            3 => Some(WeightQuantization::UInt4),
            _ => None,
        }
    }
}

/// Storage configuration for a sparse vector field.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SparseVectorConfig {
    /// Width of the stored dimension indices.
    pub index_size: IndexSize,
    /// Weight quantization scheme.
    pub weight_quantization: WeightQuantization,
    /// Threshold below which weights are treated as negligible; 0.0 disables it.
    #[serde(default)]
    pub weight_threshold: f32,
}

impl Default for SparseVectorConfig {
    fn default() -> Self {
        Self {
            index_size: IndexSize::U32,
            weight_quantization: WeightQuantization::Float32,
            weight_threshold: 0.0,
        }
    }
}

impl SparseVectorConfig {
    /// Preset for SPLADE-style vocabularies: 16-bit dimension ids, 8-bit weights.
    pub fn splade() -> Self {
        Self {
            index_size: IndexSize::U16,
            weight_quantization: WeightQuantization::UInt8,
            weight_threshold: 0.0,
        }
    }

    /// Most compact preset: 16-bit dimension ids, 4-bit weights.
    pub fn compact() -> Self {
        Self {
            index_size: IndexSize::U16,
            weight_quantization: WeightQuantization::UInt4,
            weight_threshold: 0.0,
        }
    }

    /// Full-precision preset: 32-bit dimension ids, 32-bit float weights.
    pub fn full_precision() -> Self {
        Self {
            index_size: IndexSize::U32,
            weight_quantization: WeightQuantization::Float32,
            weight_threshold: 0.0,
        }
    }

    /// Returns a copy of this config with the given weight threshold.
    pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
        self.weight_threshold = threshold;
        self
    }

    /// Average storage cost of one (dimension, weight) entry in bytes.
    pub fn bytes_per_entry(&self) -> f32 {
        self.index_size.bytes() as f32 + self.weight_quantization.bytes_per_weight()
    }

    /// Packs the index size (high nibble) and quantization (low nibble) into one byte.
    pub fn to_byte(&self) -> u8 {
        ((self.index_size as u8) << 4) | (self.weight_quantization as u8)
    }

    /// Inverse of [`Self::to_byte`]; the weight threshold is not encoded and resets to 0.0.
    pub fn from_byte(b: u8) -> Option<Self> {
        let index_size = IndexSize::from_u8(b >> 4)?;
        let weight_quantization = WeightQuantization::from_u8(b & 0x0F)?;
        Some(Self {
            index_size,
            weight_quantization,
            weight_threshold: 0.0,
        })
    }
}

/// A single (dimension, weight) pair of a sparse vector.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseEntry {
    pub dim_id: u32,
    pub weight: f32,
}

/// A sparse vector stored as entries sorted by dimension id.
#[derive(Debug, Clone, Default)]
pub struct SparseVector {
    entries: Vec<SparseEntry>,
}

impl SparseVector {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            entries: Vec::with_capacity(capacity),
        }
    }

    /// Builds a sparse vector from parallel slices of dimension ids and weights,
    /// sorting the entries by dimension id.
    pub fn from_entries(dim_ids: &[u32], weights: &[f32]) -> Self {
        assert_eq!(dim_ids.len(), weights.len());
        let mut entries: Vec<SparseEntry> = dim_ids
            .iter()
            .zip(weights.iter())
            .map(|(&dim_id, &weight)| SparseEntry { dim_id, weight })
            .collect();
        entries.sort_by_key(|e| e.dim_id);
        Self { entries }
    }

    /// Appends an entry; dimension ids must be pushed in strictly increasing order.
    pub fn push(&mut self, dim_id: u32, weight: f32) {
        debug_assert!(
            self.entries.is_empty() || self.entries.last().unwrap().dim_id < dim_id,
            "Entries must be added in sorted order by dim_id"
        );
        self.entries.push(SparseEntry { dim_id, weight });
    }

    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    pub fn iter(&self) -> impl Iterator<Item = &SparseEntry> {
        self.entries.iter()
    }

    /// Dot product with another sparse vector, computed as a linear merge over
    /// the two sorted entry lists.
    pub fn dot(&self, other: &SparseVector) -> f32 {
        let mut result = 0.0f32;
        let mut i = 0;
        let mut j = 0;

        while i < self.entries.len() && j < other.entries.len() {
            let a = &self.entries[i];
            let b = &other.entries[j];

            match a.dim_id.cmp(&b.dim_id) {
                std::cmp::Ordering::Less => i += 1,
                std::cmp::Ordering::Greater => j += 1,
                std::cmp::Ordering::Equal => {
                    result += a.weight * b.weight;
                    i += 1;
                    j += 1;
                }
            }
        }

        result
    }

    /// Squared L2 norm of the vector.
    pub fn norm_squared(&self) -> f32 {
        self.entries.iter().map(|e| e.weight * e.weight).sum()
    }

    /// L2 norm of the vector.
    pub fn norm(&self) -> f32 {
        self.norm_squared().sqrt()
    }
}

/// A single posting: a document id and its weight for one dimension.
#[derive(Debug, Clone, Copy)]
pub struct SparsePosting {
    pub doc_id: DocId,
    pub weight: f32,
}

/// Number of postings per block in [`BlockSparsePostingList`].
pub const SPARSE_BLOCK_SIZE: usize = 128;

/// Skip-list entry describing one block of a block-compressed sparse posting
/// list, including the block's maximum weight for score-based skipping.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseSkipEntry {
    /// First document id in the block.
    pub first_doc: DocId,
    /// Last document id in the block.
    pub last_doc: DocId,
    /// Byte offset of the block within the posting data.
    pub offset: u32,
    /// Maximum weight stored in the block.
    pub max_weight: f32,
}

impl SparseSkipEntry {
    pub fn new(first_doc: DocId, last_doc: DocId, offset: u32, max_weight: f32) -> Self {
        Self {
            first_doc,
            last_doc,
            offset,
            max_weight,
        }
    }

    /// Upper bound on this block's contribution to a query score, given the
    /// query's weight for this dimension.
    #[inline]
    pub fn block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.max_weight
    }

    pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc)?;
        writer.write_u32::<LittleEndian>(self.last_doc)?;
        writer.write_u32::<LittleEndian>(self.offset)?;
        writer.write_f32::<LittleEndian>(self.max_weight)?;
        Ok(())
    }

    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc = reader.read_u32::<LittleEndian>()?;
        let last_doc = reader.read_u32::<LittleEndian>()?;
        let offset = reader.read_u32::<LittleEndian>()?;
        let max_weight = reader.read_f32::<LittleEndian>()?;
        Ok(Self {
            first_doc,
            last_doc,
            offset,
            max_weight,
        })
    }
}

/// Skip list over the blocks of a [`BlockSparsePostingList`].
#[derive(Debug, Clone, Default)]
pub struct SparseSkipList {
    entries: Vec<SparseSkipEntry>,
    /// Maximum weight across all blocks.
    global_max_weight: f32,
}

impl SparseSkipList {
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a block entry and updates the global maximum weight.
    pub fn push(&mut self, first_doc: DocId, last_doc: DocId, offset: u32, max_weight: f32) {
        self.global_max_weight = self.global_max_weight.max(max_weight);
        self.entries.push(SparseSkipEntry::new(
            first_doc, last_doc, offset, max_weight,
        ));
    }

    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    pub fn get(&self, index: usize) -> Option<&SparseSkipEntry> {
        self.entries.get(index)
    }

    /// Maximum weight across all blocks in the list.
    pub fn global_max_weight(&self) -> f32 {
        self.global_max_weight
    }

    /// Index of the first block whose last document id is >= `target`.
    pub fn find_block(&self, target: DocId) -> Option<usize> {
        self.entries.iter().position(|e| e.last_doc >= target)
    }

    pub fn iter(&self) -> impl Iterator<Item = &SparseSkipEntry> {
        self.entries.iter()
    }

    pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.entries.len() as u32)?;
        writer.write_f32::<LittleEndian>(self.global_max_weight)?;
        for entry in &self.entries {
            entry.write(writer)?;
        }
        Ok(())
    }

    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        let count = reader.read_u32::<LittleEndian>()? as usize;
        let global_max_weight = reader.read_f32::<LittleEndian>()?;
        let mut entries = Vec::with_capacity(count);
        for _ in 0..count {
            entries.push(SparseSkipEntry::read(reader)?);
        }
        Ok(Self {
            entries,
            global_max_weight,
        })
    }
}

/// Compressed posting list for a single sparse dimension: delta-encoded,
/// vint-compressed document ids followed by quantized weights.
#[derive(Debug, Clone)]
pub struct SparsePostingList {
    /// Weight quantization scheme used for the stored weights.
    quantization: WeightQuantization,
    /// Scale factor for dequantizing UInt8/UInt4 weights.
    scale: f32,
    /// Minimum weight, used as the dequantization offset.
    min_val: f32,
    /// Number of documents in the list.
    doc_count: u32,
    /// Encoded doc ids followed by encoded weights.
    data: Vec<u8>,
}

impl SparsePostingList {
    /// Builds a posting list from `(doc_id, weight)` pairs sorted by doc id.
    pub fn from_postings(
        postings: &[(DocId, f32)],
        quantization: WeightQuantization,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                quantization,
                scale: 1.0,
                min_val: 0.0,
                doc_count: 0,
                data: Vec::new(),
            });
        }

        let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
        let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        // Linear quantization parameters: value ~= quantized * scale + min_val.
        let (scale, adjusted_min) = match quantization {
            WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
            WeightQuantization::UInt8 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 255.0, min_val)
                }
            }
            WeightQuantization::UInt4 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 15.0, min_val)
                }
            }
        };

        let mut data = Vec::new();

        // Doc ids are delta-encoded and written as variable-length integers.
        let mut prev_doc_id = 0u32;
        for (doc_id, _) in postings {
            let delta = doc_id - prev_doc_id;
            write_vint(&mut data, delta as u64)?;
            prev_doc_id = *doc_id;
        }

        // Weights follow the doc ids, encoded according to the quantization mode.
        match quantization {
            WeightQuantization::Float32 => {
                for (_, weight) in postings {
                    data.write_f32::<LittleEndian>(*weight)?;
                }
            }
            WeightQuantization::Float16 => {
                use half::slice::HalfFloatSliceExt;

                let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
                let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
                f16_slice.convert_from_f32_slice(&weights);
                for h in f16_slice {
                    data.write_u16::<LittleEndian>(h.to_bits())?;
                }
            }
            WeightQuantization::UInt8 => {
                for (_, weight) in postings {
                    let quantized = ((*weight - adjusted_min) / scale).round() as u8;
                    data.write_u8(quantized)?;
                }
            }
            WeightQuantization::UInt4 => {
                // Two 4-bit values per byte: even index in the low nibble,
                // odd index in the high nibble.
                let mut i = 0;
                while i < postings.len() {
                    let q1 = ((postings[i].1 - adjusted_min) / scale).round() as u8 & 0x0F;
                    let q2 = if i + 1 < postings.len() {
                        ((postings[i + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
                    } else {
                        0
                    };
                    data.write_u8((q2 << 4) | q1)?;
                    i += 2;
                }
            }
        }

        Ok(Self {
            quantization,
            scale,
            min_val: adjusted_min,
            doc_count: postings.len() as u32,
            data,
        })
    }

    /// Serializes the posting list: header (quantization, scale, min, doc count)
    /// followed by the length-prefixed encoded data.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u8(self.quantization as u8)?;
        writer.write_f32::<LittleEndian>(self.scale)?;
        writer.write_f32::<LittleEndian>(self.min_val)?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
        writer.write_all(&self.data)?;
        Ok(())
    }

    /// Reads a posting list previously written by [`Self::serialize`].
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let quant_byte = reader.read_u8()?;
        let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
        })?;
        let scale = reader.read_f32::<LittleEndian>()?;
        let min_val = reader.read_f32::<LittleEndian>()?;
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let data_len = reader.read_u32::<LittleEndian>()? as usize;
        let mut data = vec![0u8; data_len];
        reader.read_exact(&mut data)?;

        Ok(Self {
            quantization,
            scale,
            min_val,
            doc_count,
            data,
        })
    }

    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    pub fn quantization(&self) -> WeightQuantization {
        self.quantization
    }

    /// Returns an iterator over the postings.
    pub fn iterator(&self) -> SparsePostingIterator<'_> {
        SparsePostingIterator::new(self)
    }

    /// Decodes the entire list into `(doc_id, weight)` pairs.
    pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
        let mut result = Vec::with_capacity(self.doc_count as usize);
        let mut iter = self.iterator();

        while !iter.exhausted {
            result.push((iter.doc_id, iter.weight));
            iter.advance();
        }

        Ok(result)
    }
}

/// Iterator over a [`SparsePostingList`] that decodes postings lazily.
pub struct SparsePostingIterator<'a> {
    posting_list: &'a SparsePostingList,
    /// Byte offset of the next doc-id delta to decode.
    doc_id_offset: usize,
    /// Byte offset where the weight section starts.
    weight_offset: usize,
    /// Index of the current posting.
    index: usize,
    doc_id: DocId,
    weight: f32,
    exhausted: bool,
}

impl<'a> SparsePostingIterator<'a> {
    fn new(posting_list: &'a SparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            doc_id_offset: 0,
            weight_offset: 0,
            index: 0,
            doc_id: 0,
            weight: 0.0,
            exhausted: posting_list.doc_count == 0,
        };

        if !iter.exhausted {
            iter.weight_offset = iter.calculate_weight_offset();
            iter.load_current();
        }

        iter
    }

    /// Scans past the vint-encoded doc-id deltas to find where the weight
    /// section begins.
    fn calculate_weight_offset(&self) -> usize {
        let mut offset = 0;
        let mut reader = &self.posting_list.data[..];

        for _ in 0..self.posting_list.doc_count {
            if read_vint(&mut reader).is_ok() {
                offset = self.posting_list.data.len() - reader.len();
            }
        }

        offset
    }

    /// Decodes the doc id and weight at the current index.
    fn load_current(&mut self) {
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return;
        }

        let mut reader = &self.posting_list.data[self.doc_id_offset..];
        if let Ok(delta) = read_vint(&mut reader) {
            self.doc_id = self.doc_id.wrapping_add(delta as u32);
            self.doc_id_offset = self.posting_list.data.len() - reader.len();
        }

        let weight_idx = self.index;
        let pl = self.posting_list;

        self.weight = match pl.quantization {
            WeightQuantization::Float32 => {
                let offset = self.weight_offset + weight_idx * 4;
                if offset + 4 <= pl.data.len() {
                    let bytes = &pl.data[offset..offset + 4];
                    f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
                } else {
                    0.0
                }
            }
            WeightQuantization::Float16 => {
                let offset = self.weight_offset + weight_idx * 2;
                if offset + 2 <= pl.data.len() {
                    let bits = u16::from_le_bytes([pl.data[offset], pl.data[offset + 1]]);
                    half::f16::from_bits(bits).to_f32()
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt8 => {
                let offset = self.weight_offset + weight_idx;
                if offset < pl.data.len() {
                    let quantized = pl.data[offset];
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt4 => {
                // Two weights per byte: even index in the low nibble, odd in the high nibble.
                let byte_offset = self.weight_offset + weight_idx / 2;
                if byte_offset < pl.data.len() {
                    let byte = pl.data[byte_offset];
                    let quantized = if weight_idx.is_multiple_of(2) {
                        byte & 0x0F
                    } else {
                        (byte >> 4) & 0x0F
                    };
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
        };
    }

    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else {
            self.doc_id
        }
    }

    pub fn weight(&self) -> f32 {
        if self.exhausted { 0.0 } else { self.weight }
    }

    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        self.index += 1;
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return super::TERMINATED;
        }

        self.load_current();
        self.doc_id
    }

    /// Linear-scan seek to the first document with id >= `target`.
    pub fn seek(&mut self, target: DocId) -> DocId {
        while !self.exhausted && self.doc_id < target {
            self.advance();
        }
        self.doc()
    }
}

/// Block-compressed posting list for a sparse dimension: postings are grouped
/// into blocks of [`SPARSE_BLOCK_SIZE`] documents, each with bit-packed doc-id
/// deltas, quantized weights, and a skip entry carrying the block's maximum
/// weight.
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    /// Weight quantization scheme.
    quantization: WeightQuantization,
    /// Scale factor for dequantizing UInt8/UInt4 weights.
    scale: f32,
    /// Minimum weight, used as the dequantization offset.
    min_val: f32,
    /// One skip entry per block.
    skip_list: SparseSkipList,
    /// Concatenated encoded blocks.
    data: Vec<u8>,
    /// Total number of documents.
    doc_count: u32,
}

impl BlockSparsePostingList {
    /// Builds a block-compressed posting list from `(doc_id, weight)` pairs
    /// sorted by doc id.
    pub fn from_postings(
        postings: &[(DocId, f32)],
        quantization: WeightQuantization,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                quantization,
                scale: 1.0,
                min_val: 0.0,
                skip_list: SparseSkipList::new(),
                data: Vec::new(),
                doc_count: 0,
            });
        }

        let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
        let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        // Linear quantization parameters: value ~= quantized * scale + min_val.
        let (scale, adjusted_min) = match quantization {
            WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
            WeightQuantization::UInt8 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 255.0, min_val)
                }
            }
            WeightQuantization::UInt4 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 15.0, min_val)
                }
            }
        };

        let mut skip_list = SparseSkipList::new();
        let mut data = Vec::new();

        let mut i = 0;
        while i < postings.len() {
            let block_end = (i + SPARSE_BLOCK_SIZE).min(postings.len());
            let block = &postings[i..block_end];

            let first_doc_id = block.first().unwrap().0;
            let last_doc_id = block.last().unwrap().0;

            let block_max_weight = block
                .iter()
                .map(|(_, w)| *w)
                .fold(f32::NEG_INFINITY, f32::max);

            let block_doc_ids: Vec<DocId> = block.iter().map(|(d, _)| *d).collect();
            let (doc_bit_width, packed_doc_ids) = pack_deltas_fixed(&block_doc_ids);

            let block_start = data.len() as u32;
            skip_list.push(first_doc_id, last_doc_id, block_start, block_max_weight);

            // Block header: posting count and delta bit width, then the packed doc ids.
            data.write_u16::<LittleEndian>(block.len() as u16)?;
            data.write_u8(doc_bit_width as u8)?;
            data.extend_from_slice(&packed_doc_ids);

            // Weights follow the doc ids, encoded according to the quantization mode.
            match quantization {
                WeightQuantization::Float32 => {
                    for (_, weight) in block {
                        data.write_f32::<LittleEndian>(*weight)?;
                    }
                }
                WeightQuantization::Float16 => {
                    use half::slice::HalfFloatSliceExt;

                    let weights: Vec<f32> = block.iter().map(|(_, w)| *w).collect();
                    let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
                    f16_slice.convert_from_f32_slice(&weights);
                    for h in f16_slice {
                        data.write_u16::<LittleEndian>(h.to_bits())?;
                    }
                }
                WeightQuantization::UInt8 => {
                    for (_, weight) in block {
                        let quantized = ((*weight - adjusted_min) / scale).round() as u8;
                        data.write_u8(quantized)?;
                    }
                }
                WeightQuantization::UInt4 => {
                    let mut j = 0;
                    while j < block.len() {
                        let q1 = ((block[j].1 - adjusted_min) / scale).round() as u8 & 0x0F;
                        let q2 = if j + 1 < block.len() {
                            ((block[j + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
                        } else {
                            0
                        };
                        data.write_u8((q2 << 4) | q1)?;
                        j += 2;
                    }
                }
            }

            i = block_end;
        }

        Ok(Self {
            quantization,
            scale,
            min_val: adjusted_min,
            skip_list,
            data,
            doc_count: postings.len() as u32,
        })
    }

    /// Serializes the posting list: header, skip list, then length-prefixed block data.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u8(self.quantization as u8)?;
        writer.write_f32::<LittleEndian>(self.scale)?;
        writer.write_f32::<LittleEndian>(self.min_val)?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;

        self.skip_list.write(writer)?;

        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
        writer.write_all(&self.data)?;

        Ok(())
    }

    /// Reads a posting list previously written by [`Self::serialize`].
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let quant_byte = reader.read_u8()?;
        let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
        })?;
        let scale = reader.read_f32::<LittleEndian>()?;
        let min_val = reader.read_f32::<LittleEndian>()?;
        let doc_count = reader.read_u32::<LittleEndian>()?;

        let skip_list = SparseSkipList::read(reader)?;

        let data_len = reader.read_u32::<LittleEndian>()? as usize;
        let mut data = vec![0u8; data_len];
        reader.read_exact(&mut data)?;

        Ok(Self {
            quantization,
            scale,
            min_val,
            skip_list,
            data,
            doc_count,
        })
    }

    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    /// Number of blocks in the posting list.
    pub fn num_blocks(&self) -> usize {
        self.skip_list.len()
    }

    pub fn quantization(&self) -> WeightQuantization {
        self.quantization
    }

    /// Maximum weight across the whole posting list.
    pub fn global_max_weight(&self) -> f32 {
        self.skip_list.global_max_weight()
    }

    /// Maximum weight within a single block, if the block exists.
    pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
        self.skip_list.get(block_idx).map(|e| e.max_weight)
    }

    /// Upper bound on this list's contribution to a query score:
    /// `query_weight * global_max_weight`.
    #[inline]
    pub fn max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.skip_list.global_max_weight()
    }

    /// Returns an iterator over the postings.
    pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
        BlockSparsePostingIterator::new(self)
    }

    /// Serialized size in bytes: 13-byte header, skip list (8-byte header plus
    /// 16 bytes per entry), 4-byte data length prefix, and the block data.
    pub fn size_bytes(&self) -> usize {
        13 + 8 + self.skip_list.len() * 16 + 4 + self.data.len()
    }

    /// Concatenates several posting lists, remapping each source's doc ids by
    /// its offset, and re-encodes the result with the target quantization.
    pub fn concatenate(
        sources: &[(BlockSparsePostingList, u32)],
        target_quantization: WeightQuantization,
    ) -> io::Result<Self> {
        let mut all_postings: Vec<(DocId, f32)> = Vec::new();

        for (source, doc_offset) in sources {
            let decoded = source.decode_all()?;
            for (doc_id, weight) in decoded {
                all_postings.push((doc_id + doc_offset, weight));
            }
        }

        Self::from_postings(&all_postings, target_quantization)
    }

    /// Decodes the entire list into `(doc_id, weight)` pairs.
    pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
        let mut result = Vec::with_capacity(self.doc_count as usize);
        let mut iter = self.iterator();

        while iter.doc() != super::TERMINATED {
            result.push((iter.doc(), iter.weight()));
            iter.advance();
        }

        Ok(result)
    }
}

/// Iterator over a [`BlockSparsePostingList`] that decodes one block at a time.
pub struct BlockSparsePostingIterator<'a> {
    posting_list: &'a BlockSparsePostingList,
    current_block: usize,
    block_postings: Vec<(DocId, f32)>,
    position_in_block: usize,
    exhausted: bool,
}

impl<'a> BlockSparsePostingIterator<'a> {
    fn new(posting_list: &'a BlockSparsePostingList) -> Self {
        let exhausted = posting_list.skip_list.is_empty();
        let mut iter = Self {
            posting_list,
            current_block: 0,
            block_postings: Vec::new(),
            position_in_block: 0,
            exhausted,
        };

        if !iter.exhausted {
            iter.load_block(0);
        }

        iter
    }

    /// Decodes the block at `block_idx` into `block_postings`.
    fn load_block(&mut self, block_idx: usize) {
        let entry = match self.posting_list.skip_list.get(block_idx) {
            Some(e) => e,
            None => {
                self.exhausted = true;
                return;
            }
        };

        self.current_block = block_idx;
        self.position_in_block = 0;
        self.block_postings.clear();

        let offset = entry.offset as usize;
        let first_doc_id = entry.first_doc;
        let data = &self.posting_list.data[offset..];

        // Block header: u16 posting count followed by the doc-id delta bit width.
        if data.len() < 3 {
            self.exhausted = true;
            return;
        }
        let count = u16::from_le_bytes([data[0], data[1]]) as usize;
        let doc_bit_width = RoundedBitWidth::from_u8(data[2]).unwrap_or(RoundedBitWidth::Zero);

        // The first doc id comes from the skip entry; only count - 1 deltas are packed.
        let doc_bytes = doc_bit_width.bytes_per_value() * count.saturating_sub(1);
        let doc_data = &data[3..3 + doc_bytes];
        let mut doc_ids = vec![0u32; count];
        unpack_deltas_fixed(doc_data, doc_bit_width, first_doc_id, count, &mut doc_ids);

        // Weights follow the packed doc ids.
        let weight_offset = 3 + doc_bytes;
        let weight_data = &data[weight_offset..];
        let pl = self.posting_list;

        let weights: Vec<f32> = match pl.quantization {
            WeightQuantization::Float32 => {
                let mut weights = Vec::with_capacity(count);
                let mut reader = weight_data;
                for _ in 0..count {
                    if reader.len() >= 4 {
                        weights.push((&mut reader).read_f32::<LittleEndian>().unwrap_or(0.0));
                    } else {
                        weights.push(0.0);
                    }
                }
                weights
            }
            WeightQuantization::Float16 => {
                use half::slice::HalfFloatSliceExt;

                let mut f16_slice: Vec<half::f16> = Vec::with_capacity(count);
                for i in 0..count {
                    let offset = i * 2;
                    if offset + 2 <= weight_data.len() {
                        let bits =
                            u16::from_le_bytes([weight_data[offset], weight_data[offset + 1]]);
                        f16_slice.push(half::f16::from_bits(bits));
                    } else {
                        f16_slice.push(half::f16::ZERO);
                    }
                }
                let mut weights = vec![0.0f32; count];
                f16_slice.convert_to_f32_slice(&mut weights);
                weights
            }
            WeightQuantization::UInt8 => {
                let mut weights = Vec::with_capacity(count);
                for i in 0..count {
                    if i < weight_data.len() {
                        weights.push(weight_data[i] as f32 * pl.scale + pl.min_val);
                    } else {
                        weights.push(0.0);
                    }
                }
                weights
            }
            WeightQuantization::UInt4 => {
                let mut weights = Vec::with_capacity(count);
                for i in 0..count {
                    let byte_idx = i / 2;
                    if byte_idx < weight_data.len() {
                        let byte = weight_data[byte_idx];
                        let quantized = if i % 2 == 0 {
                            byte & 0x0F
                        } else {
                            (byte >> 4) & 0x0F
                        };
                        weights.push(quantized as f32 * pl.scale + pl.min_val);
                    } else {
                        weights.push(0.0);
                    }
                }
                weights
            }
        };

        for (doc_id, weight) in doc_ids.into_iter().zip(weights.into_iter()) {
            self.block_postings.push((doc_id, weight));
        }
    }

    #[inline]
    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }

    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else if self.position_in_block < self.block_postings.len() {
            self.block_postings[self.position_in_block].0
        } else {
            super::TERMINATED
        }
    }

    pub fn weight(&self) -> f32 {
        if self.exhausted || self.position_in_block >= self.block_postings.len() {
            0.0
        } else {
            self.block_postings[self.position_in_block].1
        }
    }

    /// Maximum weight of the block the iterator is currently positioned in.
    #[inline]
    pub fn current_block_max_weight(&self) -> f32 {
        if self.exhausted {
            0.0
        } else {
            self.posting_list
                .skip_list
                .get(self.current_block)
                .map(|e| e.max_weight)
                .unwrap_or(0.0)
        }
    }

    /// Upper bound on the current block's contribution to a query score.
    #[inline]
    pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.current_block_max_weight()
    }

    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        self.position_in_block += 1;
        if self.position_in_block >= self.block_postings.len() {
            self.load_block(self.current_block + 1);
        }

        self.doc()
    }

    /// Advances to the first document with id >= `target`, using the skip list
    /// to jump directly to the candidate block.
    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        if let Some(block_idx) = self.posting_list.skip_list.find_block(target) {
            if block_idx != self.current_block {
                self.load_block(block_idx);
            }

            while self.position_in_block < self.block_postings.len() {
                if self.block_postings[self.position_in_block].0 >= target {
                    return self.doc();
                }
                self.position_in_block += 1;
            }

            self.load_block(self.current_block + 1);
            self.seek(target)
        } else {
            self.exhausted = true;
            super::TERMINATED
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_vector_dot_product() {
        let v1 = SparseVector::from_entries(&[0, 2, 5], &[1.0, 2.0, 3.0]);
        let v2 = SparseVector::from_entries(&[1, 2, 5], &[1.0, 4.0, 2.0]);

        // Overlapping dims 2 and 5: 2.0 * 4.0 + 3.0 * 2.0 = 14.0.
        assert!((v1.dot(&v2) - 14.0).abs() < 1e-6);
    }
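
    // Additional sketch: norm_squared()/norm() on the same entries as the
    // dot-product test above (1 + 4 + 9 = 14).
    #[test]
    fn test_sparse_vector_norm() {
        let v = SparseVector::from_entries(&[0, 2, 5], &[1.0, 2.0, 3.0]);
        assert!((v.norm_squared() - 14.0).abs() < 1e-6);
        assert!((v.norm() - 14.0f32.sqrt()).abs() < 1e-6);
    }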

    #[test]
    fn test_sparse_posting_list_float32() {
        let postings = vec![(0, 1.5), (5, 2.3), (10, 0.8), (100, 3.15)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 4);

        let mut iter = pl.iterator();
        assert_eq!(iter.doc(), 0);
        assert!((iter.weight() - 1.5).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!((iter.weight() - 2.3).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 10);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!((iter.weight() - 3.15).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }

    #[test]
    fn test_sparse_posting_list_uint8() {
        let postings = vec![(0, 0.0), (5, 0.5), (10, 1.0)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let decoded = pl.decode_all().unwrap();
        assert_eq!(decoded.len(), 3);

        // Quantization must preserve the relative ordering of the weights.
        assert!(decoded[0].1 < decoded[1].1);
        assert!(decoded[1].1 < decoded[2].1);
    }
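
    // Additional sketch: 4-bit quantization stores two weights per byte; with
    // weights spanning [0.0, 1.5] the decoded values should stay monotonically
    // increasing after the round trip.
    #[test]
    fn test_sparse_posting_list_uint4() {
        let postings = vec![(1, 0.0), (4, 0.5), (9, 1.0), (12, 1.5)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt4).unwrap();

        let decoded = pl.decode_all().unwrap();
        assert_eq!(decoded.len(), 4);
        assert_eq!(decoded[0].0, 1);
        assert_eq!(decoded[3].0, 12);
        for pair in decoded.windows(2) {
            assert!(pair[0].1 < pair[1].1);
        }
    }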

    #[test]
    fn test_block_sparse_posting_list() {
        let postings: Vec<(DocId, f32)> = (0..300).map(|i| (i * 2, (i as f32) * 0.1)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 300);
        assert!(pl.num_blocks() >= 2);

        let mut iter = pl.iterator();
        for (expected_doc, expected_weight) in &postings {
            assert_eq!(iter.doc(), *expected_doc);
            assert!((iter.weight() - expected_weight).abs() < 1e-6);
            iter.advance();
        }
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }

    #[test]
    fn test_block_sparse_seek() {
        let postings: Vec<(DocId, f32)> = (0..500).map(|i| (i * 3, i as f32)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = pl.iterator();

        // Exact match.
        assert_eq!(iter.seek(300), 300);

        // 301 is absent; the next stored doc is 303.
        assert_eq!(iter.seek(301), 303);

        // Past the last doc (499 * 3 = 1497).
        assert_eq!(iter.seek(2000), super::super::TERMINATED);
    }
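
    // Additional sketch: empty posting lists should report zero docs and an
    // immediately exhausted iterator.
    #[test]
    fn test_empty_sparse_posting_lists() {
        let pl = SparsePostingList::from_postings(&[], WeightQuantization::Float32).unwrap();
        assert_eq!(pl.doc_count(), 0);
        assert_eq!(pl.iterator().doc(), super::super::TERMINATED);
        assert!(pl.decode_all().unwrap().is_empty());

        let bpl =
            BlockSparsePostingList::from_postings(&[], WeightQuantization::Float32).unwrap();
        assert_eq!(bpl.doc_count(), 0);
        assert_eq!(bpl.iterator().doc(), super::super::TERMINATED);
    }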

    #[test]
    fn test_serialization_roundtrip() {
        let postings: Vec<(DocId, f32)> = vec![(0, 1.0), (10, 2.0), (100, 3.0)];

        for quant in [
            WeightQuantization::Float32,
            WeightQuantization::Float16,
            WeightQuantization::UInt8,
        ] {
            let pl = BlockSparsePostingList::from_postings(&postings, quant).unwrap();

            let mut buffer = Vec::new();
            pl.serialize(&mut buffer).unwrap();

            let pl2 = BlockSparsePostingList::deserialize(&mut &buffer[..]).unwrap();

            assert_eq!(pl.doc_count(), pl2.doc_count());
            assert_eq!(pl.quantization(), pl2.quantization());

            let mut iter1 = pl.iterator();
            let mut iter2 = pl2.iterator();

            while iter1.doc() != super::super::TERMINATED {
                assert_eq!(iter1.doc(), iter2.doc());
                // Weights should survive the round trip (within float tolerance).
                assert!((iter1.weight() - iter2.weight()).abs() < 0.1);
                iter1.advance();
                iter2.advance();
            }
        }
    }

    #[test]
    fn test_concatenate() {
        let postings1: Vec<(DocId, f32)> = vec![(0, 1.0), (5, 2.0)];
        let postings2: Vec<(DocId, f32)> = vec![(0, 3.0), (10, 4.0)];

        let pl1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        let pl2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::concatenate(
            &[(pl1, 0), (pl2, 100)],
            WeightQuantization::Float32,
        )
        .unwrap();

        assert_eq!(merged.doc_count(), 4);

        let decoded = merged.decode_all().unwrap();
        assert_eq!(decoded[0], (0, 1.0));
        assert_eq!(decoded[1], (5, 2.0));
        assert_eq!(decoded[2], (100, 3.0));
        assert_eq!(decoded[3], (110, 4.0));
    }
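
    // Additional sketch: the non-blocked SparsePostingList should also
    // round-trip through serialize/deserialize; Float32 keeps weights exact,
    // so the decoded postings can be compared directly.
    #[test]
    fn test_sparse_posting_list_serialization_roundtrip() {
        let postings: Vec<(DocId, f32)> = vec![(2, 0.25), (7, 1.5), (19, 3.0)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut buffer = Vec::new();
        pl.serialize(&mut buffer).unwrap();
        let restored = SparsePostingList::deserialize(&mut &buffer[..]).unwrap();

        assert_eq!(restored.doc_count(), 3);
        assert_eq!(restored.decode_all().unwrap(), postings);
    }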

    #[test]
    fn test_sparse_vector_config() {
        let default = SparseVectorConfig::default();
        assert_eq!(default.index_size, IndexSize::U32);
        assert_eq!(default.weight_quantization, WeightQuantization::Float32);
        assert_eq!(default.bytes_per_entry(), 8.0);

        let splade = SparseVectorConfig::splade();
        assert_eq!(splade.index_size, IndexSize::U16);
        assert_eq!(splade.weight_quantization, WeightQuantization::UInt8);
        assert_eq!(splade.bytes_per_entry(), 3.0);

        let compact = SparseVectorConfig::compact();
        assert_eq!(compact.index_size, IndexSize::U16);
        assert_eq!(compact.weight_quantization, WeightQuantization::UInt4);
        assert_eq!(compact.bytes_per_entry(), 2.5);

        let byte = splade.to_byte();
        let restored = SparseVectorConfig::from_byte(byte).unwrap();
        assert_eq!(restored, splade);
    }
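
    // Additional sketch: from_byte rejects nibbles that do not map to a known
    // IndexSize or WeightQuantization variant.
    #[test]
    fn test_sparse_vector_config_invalid_byte() {
        // High nibble 0x2 is not a valid IndexSize.
        assert!(SparseVectorConfig::from_byte(0x20).is_none());
        // Low nibble 0x4 is not a valid WeightQuantization.
        assert!(SparseVectorConfig::from_byte(0x04).is_none());
    }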

    #[test]
    fn test_index_size() {
        assert_eq!(IndexSize::U16.bytes(), 2);
        assert_eq!(IndexSize::U32.bytes(), 4);
        assert_eq!(IndexSize::U16.max_value(), 65535);
        assert_eq!(IndexSize::U32.max_value(), u32::MAX);
    }
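
    // Additional sketch: a skip entry's block-max contribution is simply
    // query_weight * max_weight.
    #[test]
    fn test_skip_entry_block_max_contribution() {
        let entry = SparseSkipEntry::new(0, 127, 0, 2.5);
        assert!((entry.block_max_contribution(2.0) - 5.0).abs() < 1e-6);
        assert!(entry.block_max_contribution(0.0).abs() < 1e-6);
    }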

    #[test]
    fn test_block_max_weight() {
        // 300 docs with weight = doc_id * 0.1, split into blocks of 128:
        // block 0 covers docs 0..=127, block 1 covers 128..=255, block 2 covers 256..=299.
        let postings: Vec<(DocId, f32)> =
            (0..300).map(|i| (i as DocId, (i as f32) * 0.1)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert!((pl.global_max_weight() - 29.9).abs() < 0.01);

        assert!(pl.num_blocks() >= 3);

        let block0_max = pl.block_max_weight(0).unwrap();
        assert!((block0_max - 12.7).abs() < 0.01);

        let block1_max = pl.block_max_weight(1).unwrap();
        assert!((block1_max - 25.5).abs() < 0.01);

        let block2_max = pl.block_max_weight(2).unwrap();
        assert!((block2_max - 29.9).abs() < 0.01);

        let query_weight = 2.0;
        assert!((pl.max_contribution(query_weight) - 59.8).abs() < 0.1);

        // A fresh iterator starts in block 0; seeking to 128 moves it into block 1.
        let mut iter = pl.iterator();
        assert!((iter.current_block_max_weight() - 12.7).abs() < 0.01);
        assert!((iter.current_block_max_contribution(query_weight) - 25.4).abs() < 0.1);

        iter.seek(128);
        assert!((iter.current_block_max_weight() - 25.5).abs() < 0.01);
    }

    #[test]
    fn test_sparse_skip_list_serialization() {
        let mut skip_list = SparseSkipList::new();
        skip_list.push(0, 127, 0, 12.7);
        skip_list.push(128, 255, 100, 25.5);
        skip_list.push(256, 299, 200, 29.9);

        assert_eq!(skip_list.len(), 3);
        assert!((skip_list.global_max_weight() - 29.9).abs() < 0.01);

        let mut buffer = Vec::new();
        skip_list.write(&mut buffer).unwrap();

        let restored = SparseSkipList::read(&mut buffer.as_slice()).unwrap();

        assert_eq!(restored.len(), 3);
        assert!((restored.global_max_weight() - 29.9).abs() < 0.01);

        let e0 = restored.get(0).unwrap();
        assert_eq!(e0.first_doc, 0);
        assert_eq!(e0.last_doc, 127);
        assert!((e0.max_weight - 12.7).abs() < 0.01);

        let e1 = restored.get(1).unwrap();
        assert_eq!(e1.first_doc, 128);
        assert!((e1.max_weight - 25.5).abs() < 0.01);
    }
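
    // Additional sketch: find_block returns the first block whose last_doc is
    // >= the target, and None once the target is past every block.
    #[test]
    fn test_sparse_skip_list_find_block() {
        let mut skip_list = SparseSkipList::new();
        skip_list.push(0, 127, 0, 1.0);
        skip_list.push(128, 255, 100, 2.0);
        skip_list.push(256, 299, 200, 3.0);

        assert_eq!(skip_list.find_block(0), Some(0));
        assert_eq!(skip_list.find_block(130), Some(1));
        assert_eq!(skip_list.find_block(256), Some(2));
        assert_eq!(skip_list.find_block(300), None);
    }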
}