1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
20use serde::{Deserialize, Serialize};
21use std::io::{self, Read, Write};
22
23use super::posting_common::{
24 RoundedBitWidth, pack_deltas_fixed, read_vint, unpack_deltas_fixed, write_vint,
25};
26use crate::DocId;
27
/// Storage width for sparse-vector dimension indices.
///
/// The discriminant is packed into the high nibble of
/// `SparseVectorConfig::to_byte`, so it must fit in 4 bits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum IndexSize {
    /// 2-byte indices; dimension ids must fit in a u16.
    U16 = 0,
    /// 4-byte indices; supports the full u32 dimension space.
    #[default]
    U32 = 1,
}
38
39impl IndexSize {
40 pub fn bytes(&self) -> usize {
42 match self {
43 IndexSize::U16 => 2,
44 IndexSize::U32 => 4,
45 }
46 }
47
48 pub fn max_value(&self) -> u32 {
50 match self {
51 IndexSize::U16 => u16::MAX as u32,
52 IndexSize::U32 => u32::MAX,
53 }
54 }
55
56 fn from_u8(v: u8) -> Option<Self> {
57 match v {
58 0 => Some(IndexSize::U16),
59 1 => Some(IndexSize::U32),
60 _ => None,
61 }
62 }
63}
64
/// On-disk encoding used for posting weights.
///
/// The discriminant is stored in the low nibble of
/// `SparseVectorConfig::to_byte` and as the leading byte of serialized
/// posting lists, so it must fit in 4 bits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum WeightQuantization {
    /// Full 32-bit float weights (lossless).
    #[default]
    Float32 = 0,
    /// IEEE half-precision (16-bit) weights.
    Float16 = 1,
    /// 8-bit linear quantization against a per-list scale/offset.
    UInt8 = 2,
    /// 4-bit linear quantization, two weights packed per byte.
    UInt4 = 3,
}
79
80impl WeightQuantization {
81 pub fn bytes_per_weight(&self) -> f32 {
83 match self {
84 WeightQuantization::Float32 => 4.0,
85 WeightQuantization::Float16 => 2.0,
86 WeightQuantization::UInt8 => 1.0,
87 WeightQuantization::UInt4 => 0.5,
88 }
89 }
90
91 fn from_u8(v: u8) -> Option<Self> {
92 match v {
93 0 => Some(WeightQuantization::Float32),
94 1 => Some(WeightQuantization::Float16),
95 2 => Some(WeightQuantization::UInt8),
96 3 => Some(WeightQuantization::UInt4),
97 _ => None,
98 }
99 }
100}
101
/// Storage configuration for a sparse vector field: index width plus weight
/// quantization. Round-trips through a single byte via `to_byte`/`from_byte`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct SparseVectorConfig {
    /// Byte width used for dimension indices.
    pub index_size: IndexSize,
    /// Encoding used for stored weights.
    pub weight_quantization: WeightQuantization,
}
110
111impl Default for SparseVectorConfig {
112 fn default() -> Self {
113 Self {
114 index_size: IndexSize::U32,
115 weight_quantization: WeightQuantization::Float32,
116 }
117 }
118}
119
120impl SparseVectorConfig {
121 pub fn splade() -> Self {
123 Self {
124 index_size: IndexSize::U16,
125 weight_quantization: WeightQuantization::UInt8,
126 }
127 }
128
129 pub fn compact() -> Self {
131 Self {
132 index_size: IndexSize::U16,
133 weight_quantization: WeightQuantization::UInt4,
134 }
135 }
136
137 pub fn full_precision() -> Self {
139 Self {
140 index_size: IndexSize::U32,
141 weight_quantization: WeightQuantization::Float32,
142 }
143 }
144
145 pub fn bytes_per_entry(&self) -> f32 {
147 self.index_size.bytes() as f32 + self.weight_quantization.bytes_per_weight()
148 }
149
150 pub fn to_byte(&self) -> u8 {
152 ((self.index_size as u8) << 4) | (self.weight_quantization as u8)
153 }
154
155 pub fn from_byte(b: u8) -> Option<Self> {
157 let index_size = IndexSize::from_u8(b >> 4)?;
158 let weight_quantization = WeightQuantization::from_u8(b & 0x0F)?;
159 Some(Self {
160 index_size,
161 weight_quantization,
162 })
163 }
164}
165
/// One (dimension, weight) pair of a sparse vector.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseEntry {
    /// Dimension index.
    pub dim_id: u32,
    /// Weight for that dimension.
    pub weight: f32,
}
172
/// A sparse vector stored as entries kept sorted by ascending `dim_id`
/// (the dot-product merge relies on this ordering).
#[derive(Debug, Clone, Default)]
pub struct SparseVector {
    entries: Vec<SparseEntry>,
}
178
179impl SparseVector {
180 pub fn new() -> Self {
181 Self::default()
182 }
183
184 pub fn with_capacity(capacity: usize) -> Self {
185 Self {
186 entries: Vec::with_capacity(capacity),
187 }
188 }
189
190 pub fn from_entries(dim_ids: &[u32], weights: &[f32]) -> Self {
192 assert_eq!(dim_ids.len(), weights.len());
193 let mut entries: Vec<SparseEntry> = dim_ids
194 .iter()
195 .zip(weights.iter())
196 .map(|(&dim_id, &weight)| SparseEntry { dim_id, weight })
197 .collect();
198 entries.sort_by_key(|e| e.dim_id);
200 Self { entries }
201 }
202
203 pub fn push(&mut self, dim_id: u32, weight: f32) {
205 debug_assert!(
206 self.entries.is_empty() || self.entries.last().unwrap().dim_id < dim_id,
207 "Entries must be added in sorted order by dim_id"
208 );
209 self.entries.push(SparseEntry { dim_id, weight });
210 }
211
212 pub fn len(&self) -> usize {
214 self.entries.len()
215 }
216
217 pub fn is_empty(&self) -> bool {
218 self.entries.is_empty()
219 }
220
221 pub fn iter(&self) -> impl Iterator<Item = &SparseEntry> {
222 self.entries.iter()
223 }
224
225 pub fn dot(&self, other: &SparseVector) -> f32 {
227 let mut result = 0.0f32;
228 let mut i = 0;
229 let mut j = 0;
230
231 while i < self.entries.len() && j < other.entries.len() {
232 let a = &self.entries[i];
233 let b = &other.entries[j];
234
235 match a.dim_id.cmp(&b.dim_id) {
236 std::cmp::Ordering::Less => i += 1,
237 std::cmp::Ordering::Greater => j += 1,
238 std::cmp::Ordering::Equal => {
239 result += a.weight * b.weight;
240 i += 1;
241 j += 1;
242 }
243 }
244 }
245
246 result
247 }
248
249 pub fn norm_squared(&self) -> f32 {
251 self.entries.iter().map(|e| e.weight * e.weight).sum()
252 }
253
254 pub fn norm(&self) -> f32 {
256 self.norm_squared().sqrt()
257 }
258}
259
/// A single posting: a document and its weight for one dimension.
#[derive(Debug, Clone, Copy)]
pub struct SparsePosting {
    /// Document containing the dimension.
    pub doc_id: DocId,
    /// Weight of the dimension in that document.
    pub weight: f32,
}
266
/// Maximum number of postings encoded per block in `BlockSparsePostingList`.
pub const SPARSE_BLOCK_SIZE: usize = 128;
269
/// Skip-list entry describing one encoded block of postings.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseSkipEntry {
    /// Doc id of the first posting in the block.
    pub first_doc: DocId,
    /// Doc id of the last posting in the block.
    pub last_doc: DocId,
    /// Byte offset of the block within the posting-list data.
    pub offset: u32,
    /// Largest weight stored in the block (used for score upper bounds).
    pub max_weight: f32,
}
284
285impl SparseSkipEntry {
286 pub fn new(first_doc: DocId, last_doc: DocId, offset: u32, max_weight: f32) -> Self {
287 Self {
288 first_doc,
289 last_doc,
290 offset,
291 max_weight,
292 }
293 }
294
295 #[inline]
300 pub fn block_max_contribution(&self, query_weight: f32) -> f32 {
301 query_weight * self.max_weight
302 }
303
304 pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
306 writer.write_u32::<LittleEndian>(self.first_doc)?;
307 writer.write_u32::<LittleEndian>(self.last_doc)?;
308 writer.write_u32::<LittleEndian>(self.offset)?;
309 writer.write_f32::<LittleEndian>(self.max_weight)?;
310 Ok(())
311 }
312
313 pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
315 let first_doc = reader.read_u32::<LittleEndian>()?;
316 let last_doc = reader.read_u32::<LittleEndian>()?;
317 let offset = reader.read_u32::<LittleEndian>()?;
318 let max_weight = reader.read_f32::<LittleEndian>()?;
319 Ok(Self {
320 first_doc,
321 last_doc,
322 offset,
323 max_weight,
324 })
325 }
326}
327
/// Per-block skip index for a `BlockSparsePostingList`.
#[derive(Debug, Clone, Default)]
pub struct SparseSkipList {
    /// One entry per block, in the order blocks were pushed.
    entries: Vec<SparseSkipEntry>,
    /// Maximum weight observed across all pushed blocks.
    global_max_weight: f32,
}
335
336impl SparseSkipList {
337 pub fn new() -> Self {
338 Self::default()
339 }
340
341 pub fn push(&mut self, first_doc: DocId, last_doc: DocId, offset: u32, max_weight: f32) {
343 self.global_max_weight = self.global_max_weight.max(max_weight);
344 self.entries.push(SparseSkipEntry::new(
345 first_doc, last_doc, offset, max_weight,
346 ));
347 }
348
349 pub fn len(&self) -> usize {
351 self.entries.len()
352 }
353
354 pub fn is_empty(&self) -> bool {
355 self.entries.is_empty()
356 }
357
358 pub fn get(&self, index: usize) -> Option<&SparseSkipEntry> {
360 self.entries.get(index)
361 }
362
363 pub fn global_max_weight(&self) -> f32 {
365 self.global_max_weight
366 }
367
368 pub fn find_block(&self, target: DocId) -> Option<usize> {
370 self.entries.iter().position(|e| e.last_doc >= target)
371 }
372
373 pub fn iter(&self) -> impl Iterator<Item = &SparseSkipEntry> {
375 self.entries.iter()
376 }
377
378 pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
380 writer.write_u32::<LittleEndian>(self.entries.len() as u32)?;
381 writer.write_f32::<LittleEndian>(self.global_max_weight)?;
382 for entry in &self.entries {
383 entry.write(writer)?;
384 }
385 Ok(())
386 }
387
388 pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
390 let count = reader.read_u32::<LittleEndian>()? as usize;
391 let global_max_weight = reader.read_f32::<LittleEndian>()?;
392 let mut entries = Vec::with_capacity(count);
393 for _ in 0..count {
394 entries.push(SparseSkipEntry::read(reader)?);
395 }
396 Ok(Self {
397 entries,
398 global_max_weight,
399 })
400 }
401}
402
/// Compact sparse posting list: vint-delta-encoded doc ids followed by a
/// weight section, with weights optionally quantized.
#[derive(Debug, Clone)]
pub struct SparsePostingList {
    /// Encoding used for the weight section.
    quantization: WeightQuantization,
    /// Linear dequantization factor (1.0 for float encodings).
    scale: f32,
    /// Linear dequantization offset (0.0 for float encodings).
    min_val: f32,
    /// Number of postings stored.
    doc_count: u32,
    /// Encoded bytes: all doc-id deltas (vint), then all weights.
    data: Vec<u8>,
}
421
impl SparsePostingList {
    /// Builds a posting list from `postings`, which must be sorted by
    /// strictly increasing doc id (the delta subtraction would underflow
    /// otherwise).
    ///
    /// Data layout: all doc-id deltas as vints, then all weights encoded per
    /// `quantization`. For UInt8/UInt4 the weights are linearly mapped so
    /// that `weight ~= q * scale + min_val`.
    pub fn from_postings(
        postings: &[(DocId, f32)],
        quantization: WeightQuantization,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                quantization,
                scale: 1.0,
                min_val: 0.0,
                doc_count: 0,
                data: Vec::new(),
            });
        }

        let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
        let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        // Derive linear quantization parameters; a near-zero range would
        // divide by ~0, so constant-weight lists fall back to scale 1.0.
        let (scale, adjusted_min) = match quantization {
            WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
            WeightQuantization::UInt8 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 255.0, min_val)
                }
            }
            WeightQuantization::UInt4 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 15.0, min_val)
                }
            }
        };

        let mut data = Vec::new();

        // Delta-encode doc ids as vints; the first delta is relative to 0.
        let mut prev_doc_id = 0u32;
        for (doc_id, _) in postings {
            let delta = doc_id - prev_doc_id;
            write_vint(&mut data, delta as u64)?;
            prev_doc_id = *doc_id;
        }

        // Append the weight section in the chosen encoding.
        match quantization {
            WeightQuantization::Float32 => {
                for (_, weight) in postings {
                    data.write_f32::<LittleEndian>(*weight)?;
                }
            }
            WeightQuantization::Float16 => {
                use half::slice::HalfFloatSliceExt;
                let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
                let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
                f16_slice.convert_from_f32_slice(&weights);
                for h in f16_slice {
                    data.write_u16::<LittleEndian>(h.to_bits())?;
                }
            }
            WeightQuantization::UInt8 => {
                for (_, weight) in postings {
                    let quantized = ((*weight - adjusted_min) / scale).round() as u8;
                    data.write_u8(quantized)?;
                }
            }
            WeightQuantization::UInt4 => {
                // Pack two 4-bit weights per byte: even index in the low
                // nibble, odd index in the high nibble.
                let mut i = 0;
                while i < postings.len() {
                    let q1 = ((postings[i].1 - adjusted_min) / scale).round() as u8 & 0x0F;
                    let q2 = if i + 1 < postings.len() {
                        ((postings[i + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
                    } else {
                        0
                    };
                    data.write_u8((q2 << 4) | q1)?;
                    i += 2;
                }
            }
        }

        Ok(Self {
            quantization,
            scale,
            min_val: adjusted_min,
            doc_count: postings.len() as u32,
            data,
        })
    }

    /// Writes: quantization byte, scale, min_val, doc_count, data length,
    /// then the raw data (multi-byte integers little-endian).
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u8(self.quantization as u8)?;
        writer.write_f32::<LittleEndian>(self.scale)?;
        writer.write_f32::<LittleEndian>(self.min_val)?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
        writer.write_all(&self.data)?;
        Ok(())
    }

    /// Reads a posting list previously produced by [`Self::serialize`].
    ///
    /// # Errors
    /// Returns `InvalidData` for an unknown quantization byte, or any
    /// underlying I/O error.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let quant_byte = reader.read_u8()?;
        let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
        })?;
        let scale = reader.read_f32::<LittleEndian>()?;
        let min_val = reader.read_f32::<LittleEndian>()?;
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let data_len = reader.read_u32::<LittleEndian>()? as usize;
        let mut data = vec![0u8; data_len];
        reader.read_exact(&mut data)?;

        Ok(Self {
            quantization,
            scale,
            min_val,
            doc_count,
            data,
        })
    }

    /// Number of postings in the list.
    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    /// Weight encoding used by this list.
    pub fn quantization(&self) -> WeightQuantization {
        self.quantization
    }

    /// Creates an iterator positioned on the first posting.
    pub fn iterator(&self) -> SparsePostingIterator<'_> {
        SparsePostingIterator::new(self)
    }

    /// Decodes the entire list back into (doc id, weight) pairs.
    pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
        let mut result = Vec::with_capacity(self.doc_count as usize);
        let mut iter = self.iterator();

        while !iter.exhausted {
            result.push((iter.doc_id, iter.weight));
            iter.advance();
        }

        Ok(result)
    }
}
582
/// Forward-only cursor over a `SparsePostingList`.
pub struct SparsePostingIterator<'a> {
    posting_list: &'a SparsePostingList,
    // Byte offset of the next doc-id vint within `data`.
    doc_id_offset: usize,
    // Byte offset where the weight section begins.
    weight_offset: usize,
    // Zero-based index of the current posting.
    index: usize,
    // Decoded doc id of the current posting.
    doc_id: DocId,
    // Dequantized weight of the current posting.
    weight: f32,
    // Set once the cursor has moved past the last posting.
    exhausted: bool,
}
599
impl<'a> SparsePostingIterator<'a> {
    /// Positions a new iterator on the first posting (or marks it exhausted
    /// for an empty list).
    fn new(posting_list: &'a SparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            doc_id_offset: 0,
            weight_offset: 0,
            index: 0,
            doc_id: 0,
            weight: 0.0,
            exhausted: posting_list.doc_count == 0,
        };

        if !iter.exhausted {
            // The weight section starts right after the last doc-id vint,
            // so its offset must be computed before any weight is read.
            iter.weight_offset = iter.calculate_weight_offset();
            iter.load_current();
        }

        iter
    }

    /// Scans past all doc-id vints to find where the weight section begins.
    /// O(doc_count); performed once at construction.
    fn calculate_weight_offset(&self) -> usize {
        let mut offset = 0;
        let mut reader = &self.posting_list.data[..];

        for _ in 0..self.posting_list.doc_count {
            if read_vint(&mut reader).is_ok() {
                // Bytes consumed so far = total length - remaining slice.
                offset = self.posting_list.data.len() - reader.len();
            }
        }

        offset
    }

    /// Decodes the posting at `self.index` into `doc_id`/`weight`, or marks
    /// the iterator exhausted when past the end.
    fn load_current(&mut self) {
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return;
        }

        // Doc ids are delta-encoded vints; accumulate onto the running id.
        let mut reader = &self.posting_list.data[self.doc_id_offset..];
        if let Ok(delta) = read_vint(&mut reader) {
            self.doc_id = self.doc_id.wrapping_add(delta as u32);
            self.doc_id_offset = self.posting_list.data.len() - reader.len();
        }

        let weight_idx = self.index;
        let pl = self.posting_list;

        // Decode the weight per quantization mode; out-of-bounds offsets
        // (truncated/corrupt data) fall back to 0.0 instead of panicking.
        self.weight = match pl.quantization {
            WeightQuantization::Float32 => {
                let offset = self.weight_offset + weight_idx * 4;
                if offset + 4 <= pl.data.len() {
                    let bytes = &pl.data[offset..offset + 4];
                    f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
                } else {
                    0.0
                }
            }
            WeightQuantization::Float16 => {
                let offset = self.weight_offset + weight_idx * 2;
                if offset + 2 <= pl.data.len() {
                    let bits = u16::from_le_bytes([pl.data[offset], pl.data[offset + 1]]);
                    half::f16::from_bits(bits).to_f32()
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt8 => {
                let offset = self.weight_offset + weight_idx;
                if offset < pl.data.len() {
                    let quantized = pl.data[offset];
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt4 => {
                // Even indices sit in the low nibble, odd in the high nibble
                // (mirrors the packing in `from_postings`).
                let byte_offset = self.weight_offset + weight_idx / 2;
                if byte_offset < pl.data.len() {
                    let byte = pl.data[byte_offset];
                    let quantized = if weight_idx.is_multiple_of(2) {
                        byte & 0x0F
                    } else {
                        (byte >> 4) & 0x0F
                    };
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
        };
    }

    /// Current doc id, or `TERMINATED` once the iterator is exhausted.
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else {
            self.doc_id
        }
    }

    /// Current (dequantized) weight, or 0.0 once exhausted.
    pub fn weight(&self) -> f32 {
        if self.exhausted { 0.0 } else { self.weight }
    }

    /// Moves to the next posting; returns its doc id or `TERMINATED`.
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        self.index += 1;
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return super::TERMINATED;
        }

        self.load_current();
        self.doc_id
    }

    /// Advances until `doc() >= target` (linear scan; this format carries
    /// no skip data).
    pub fn seek(&mut self, target: DocId) -> DocId {
        while !self.exhausted && self.doc_id < target {
            self.advance();
        }
        self.doc()
    }
}
735
/// Block-compressed sparse posting list: fixed-width delta-packed doc ids in
/// blocks of up to `SPARSE_BLOCK_SIZE` postings, indexed by a skip list that
/// also carries per-block max weights for score upper bounds.
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    /// Encoding used for weights inside each block.
    quantization: WeightQuantization,
    /// Linear dequantization factor (1.0 for float encodings).
    scale: f32,
    /// Linear dequantization offset (0.0 for float encodings).
    min_val: f32,
    /// One entry per block: doc range, data offset, and max weight.
    skip_list: SparseSkipList,
    /// Concatenated encoded blocks.
    data: Vec<u8>,
    /// Total number of postings across all blocks.
    doc_count: u32,
}
755
756impl BlockSparsePostingList {
757 pub fn from_postings(
759 postings: &[(DocId, f32)],
760 quantization: WeightQuantization,
761 ) -> io::Result<Self> {
762 if postings.is_empty() {
763 return Ok(Self {
764 quantization,
765 scale: 1.0,
766 min_val: 0.0,
767 skip_list: SparseSkipList::new(),
768 data: Vec::new(),
769 doc_count: 0,
770 });
771 }
772
773 let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
775 let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
776 let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
777
778 let (scale, adjusted_min) = match quantization {
779 WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
780 WeightQuantization::UInt8 => {
781 let range = max_val - min_val;
782 if range < f32::EPSILON {
783 (1.0, min_val)
784 } else {
785 (range / 255.0, min_val)
786 }
787 }
788 WeightQuantization::UInt4 => {
789 let range = max_val - min_val;
790 if range < f32::EPSILON {
791 (1.0, min_val)
792 } else {
793 (range / 15.0, min_val)
794 }
795 }
796 };
797
798 let mut skip_list = SparseSkipList::new();
799 let mut data = Vec::new();
800
801 let mut i = 0;
802 while i < postings.len() {
803 let block_end = (i + SPARSE_BLOCK_SIZE).min(postings.len());
804 let block = &postings[i..block_end];
805
806 let first_doc_id = block.first().unwrap().0;
807 let last_doc_id = block.last().unwrap().0;
808
809 let block_max_weight = block
811 .iter()
812 .map(|(_, w)| *w)
813 .fold(f32::NEG_INFINITY, f32::max);
814
815 let block_doc_ids: Vec<DocId> = block.iter().map(|(d, _)| *d).collect();
817 let (doc_bit_width, packed_doc_ids) = pack_deltas_fixed(&block_doc_ids);
818
819 let block_start = data.len() as u32;
821 skip_list.push(first_doc_id, last_doc_id, block_start, block_max_weight);
822
823 data.write_u16::<LittleEndian>(block.len() as u16)?;
824 data.write_u8(doc_bit_width as u8)?;
825 data.extend_from_slice(&packed_doc_ids);
826
827 match quantization {
829 WeightQuantization::Float32 => {
830 for (_, weight) in block {
831 data.write_f32::<LittleEndian>(*weight)?;
832 }
833 }
834 WeightQuantization::Float16 => {
835 use half::slice::HalfFloatSliceExt;
837 let weights: Vec<f32> = block.iter().map(|(_, w)| *w).collect();
838 let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
839 f16_slice.convert_from_f32_slice(&weights);
840 for h in f16_slice {
841 data.write_u16::<LittleEndian>(h.to_bits())?;
842 }
843 }
844 WeightQuantization::UInt8 => {
845 for (_, weight) in block {
846 let quantized = ((*weight - adjusted_min) / scale).round() as u8;
847 data.write_u8(quantized)?;
848 }
849 }
850 WeightQuantization::UInt4 => {
851 let mut j = 0;
852 while j < block.len() {
853 let q1 = ((block[j].1 - adjusted_min) / scale).round() as u8 & 0x0F;
854 let q2 = if j + 1 < block.len() {
855 ((block[j + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
856 } else {
857 0
858 };
859 data.write_u8((q2 << 4) | q1)?;
860 j += 2;
861 }
862 }
863 }
864
865 i = block_end;
866 }
867
868 Ok(Self {
869 quantization,
870 scale,
871 min_val: adjusted_min,
872 skip_list,
873 data,
874 doc_count: postings.len() as u32,
875 })
876 }
877
878 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
880 writer.write_u8(self.quantization as u8)?;
881 writer.write_f32::<LittleEndian>(self.scale)?;
882 writer.write_f32::<LittleEndian>(self.min_val)?;
883 writer.write_u32::<LittleEndian>(self.doc_count)?;
884
885 self.skip_list.write(writer)?;
887
888 writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
890 writer.write_all(&self.data)?;
891
892 Ok(())
893 }
894
895 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
897 let quant_byte = reader.read_u8()?;
898 let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
899 io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
900 })?;
901 let scale = reader.read_f32::<LittleEndian>()?;
902 let min_val = reader.read_f32::<LittleEndian>()?;
903 let doc_count = reader.read_u32::<LittleEndian>()?;
904
905 let skip_list = SparseSkipList::read(reader)?;
907
908 let data_len = reader.read_u32::<LittleEndian>()? as usize;
909 let mut data = vec![0u8; data_len];
910 reader.read_exact(&mut data)?;
911
912 Ok(Self {
913 quantization,
914 scale,
915 min_val,
916 skip_list,
917 data,
918 doc_count,
919 })
920 }
921
922 pub fn doc_count(&self) -> u32 {
924 self.doc_count
925 }
926
927 pub fn num_blocks(&self) -> usize {
929 self.skip_list.len()
930 }
931
932 pub fn quantization(&self) -> WeightQuantization {
934 self.quantization
935 }
936
937 pub fn global_max_weight(&self) -> f32 {
939 self.skip_list.global_max_weight()
940 }
941
942 pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
944 self.skip_list.get(block_idx).map(|e| e.max_weight)
945 }
946
947 #[inline]
952 pub fn max_contribution(&self, query_weight: f32) -> f32 {
953 query_weight * self.skip_list.global_max_weight()
954 }
955
956 pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
958 BlockSparsePostingIterator::new(self)
959 }
960
961 pub fn size_bytes(&self) -> usize {
963 13 + 8 + self.skip_list.len() * 16 + self.data.len()
967 }
968
969 pub fn concatenate(
971 sources: &[(BlockSparsePostingList, u32)],
972 target_quantization: WeightQuantization,
973 ) -> io::Result<Self> {
974 let mut all_postings: Vec<(DocId, f32)> = Vec::new();
976
977 for (source, doc_offset) in sources {
978 let decoded = source.decode_all()?;
979 for (doc_id, weight) in decoded {
980 all_postings.push((doc_id + doc_offset, weight));
981 }
982 }
983
984 Self::from_postings(&all_postings, target_quantization)
986 }
987
988 pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
990 let mut result = Vec::with_capacity(self.doc_count as usize);
991 let mut iter = self.iterator();
992
993 while iter.doc() != super::TERMINATED {
994 result.push((iter.doc(), iter.weight()));
995 iter.advance();
996 }
997
998 Ok(result)
999 }
1000}
1001
/// Cursor over a `BlockSparsePostingList` that decodes one block at a time.
pub struct BlockSparsePostingIterator<'a> {
    posting_list: &'a BlockSparsePostingList,
    // Index of the block currently decoded into `block_postings`.
    current_block: usize,
    // Fully decoded (doc id, weight) pairs of the current block.
    block_postings: Vec<(DocId, f32)>,
    // Cursor position within `block_postings`.
    position_in_block: usize,
    // Set once iteration has moved past the final block.
    exhausted: bool,
}
1010
impl<'a> BlockSparsePostingIterator<'a> {
    /// Creates an iterator and eagerly decodes the first block (if any).
    fn new(posting_list: &'a BlockSparsePostingList) -> Self {
        let exhausted = posting_list.skip_list.is_empty();
        let mut iter = Self {
            posting_list,
            current_block: 0,
            block_postings: Vec::new(),
            position_in_block: 0,
            exhausted,
        };

        if !iter.exhausted {
            iter.load_block(0);
        }

        iter
    }

    /// Decodes block `block_idx` into `block_postings` and resets the
    /// in-block cursor; marks the iterator exhausted past the last block.
    ///
    /// Block layout: [count: u16 LE][delta bit width: u8][packed doc-id
    /// deltas][weight section]. The block's first doc id comes from the
    /// skip entry, so only `count - 1` deltas are packed.
    fn load_block(&mut self, block_idx: usize) {
        let entry = match self.posting_list.skip_list.get(block_idx) {
            Some(e) => e,
            None => {
                self.exhausted = true;
                return;
            }
        };

        self.current_block = block_idx;
        self.position_in_block = 0;
        self.block_postings.clear();

        let offset = entry.offset as usize;
        let first_doc_id = entry.first_doc;
        let data = &self.posting_list.data[offset..];

        // Need at least the 3-byte block header (count + bit width).
        if data.len() < 3 {
            self.exhausted = true;
            return;
        }
        let count = u16::from_le_bytes([data[0], data[1]]) as usize;
        let doc_bit_width = RoundedBitWidth::from_u8(data[2]).unwrap_or(RoundedBitWidth::Zero);

        let doc_bytes = doc_bit_width.bytes_per_value() * count.saturating_sub(1);
        let doc_data = &data[3..3 + doc_bytes];
        let mut doc_ids = vec![0u32; count];
        unpack_deltas_fixed(doc_data, doc_bit_width, first_doc_id, count, &mut doc_ids);

        let weight_offset = 3 + doc_bytes;
        let weight_data = &data[weight_offset..];
        let pl = self.posting_list;

        // Decode the weight section; truncated data yields 0.0 weights
        // rather than a panic.
        let weights: Vec<f32> = match pl.quantization {
            WeightQuantization::Float32 => {
                let mut weights = Vec::with_capacity(count);
                let mut reader = weight_data;
                for _ in 0..count {
                    if reader.len() >= 4 {
                        weights.push((&mut reader).read_f32::<LittleEndian>().unwrap_or(0.0));
                    } else {
                        weights.push(0.0);
                    }
                }
                weights
            }
            WeightQuantization::Float16 => {
                use half::slice::HalfFloatSliceExt;
                let mut f16_slice: Vec<half::f16> = Vec::with_capacity(count);
                for i in 0..count {
                    let offset = i * 2;
                    if offset + 2 <= weight_data.len() {
                        let bits =
                            u16::from_le_bytes([weight_data[offset], weight_data[offset + 1]]);
                        f16_slice.push(half::f16::from_bits(bits));
                    } else {
                        f16_slice.push(half::f16::ZERO);
                    }
                }
                let mut weights = vec![0.0f32; count];
                f16_slice.convert_to_f32_slice(&mut weights);
                weights
            }
            WeightQuantization::UInt8 => {
                let mut weights = Vec::with_capacity(count);
                for i in 0..count {
                    if i < weight_data.len() {
                        weights.push(weight_data[i] as f32 * pl.scale + pl.min_val);
                    } else {
                        weights.push(0.0);
                    }
                }
                weights
            }
            WeightQuantization::UInt4 => {
                // Even indices sit in the low nibble, odd in the high nibble
                // (mirrors the packing in `from_postings`).
                let mut weights = Vec::with_capacity(count);
                for i in 0..count {
                    let byte_idx = i / 2;
                    if byte_idx < weight_data.len() {
                        let byte = weight_data[byte_idx];
                        let quantized = if i % 2 == 0 {
                            byte & 0x0F
                        } else {
                            (byte >> 4) & 0x0F
                        };
                        weights.push(quantized as f32 * pl.scale + pl.min_val);
                    } else {
                        weights.push(0.0);
                    }
                }
                weights
            }
        };

        for (doc_id, weight) in doc_ids.into_iter().zip(weights.into_iter()) {
            self.block_postings.push((doc_id, weight));
        }
    }

    /// True once iteration has passed the last posting.
    #[inline]
    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }

    /// Current doc id, or `TERMINATED` when exhausted or past the decoded
    /// block's end.
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else if self.position_in_block < self.block_postings.len() {
            self.block_postings[self.position_in_block].0
        } else {
            super::TERMINATED
        }
    }

    /// Current weight, or 0.0 when exhausted.
    pub fn weight(&self) -> f32 {
        if self.exhausted || self.position_in_block >= self.block_postings.len() {
            0.0
        } else {
            self.block_postings[self.position_in_block].1
        }
    }

    /// Max weight of the block the cursor is currently in (0.0 when
    /// exhausted).
    #[inline]
    pub fn current_block_max_weight(&self) -> f32 {
        if self.exhausted {
            0.0
        } else {
            self.posting_list
                .skip_list
                .get(self.current_block)
                .map(|e| e.max_weight)
                .unwrap_or(0.0)
        }
    }

    /// Upper bound on the current block's score contribution for a query
    /// term with weight `query_weight`.
    #[inline]
    pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.current_block_max_weight()
    }

    /// Moves to the next posting, decoding the next block when the current
    /// one is finished; returns the new doc id or `TERMINATED`.
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        self.position_in_block += 1;
        if self.position_in_block >= self.block_postings.len() {
            self.load_block(self.current_block + 1);
        }

        self.doc()
    }

    /// Positions the cursor on the first posting with doc id >= `target`,
    /// using the skip list to jump to the candidate block and then scanning
    /// within it (recursing into the next block if the cursor has already
    /// passed all candidates).
    ///
    /// NOTE(review): a backward seek within the current block will not
    /// rewind the in-block cursor — callers appear to seek forward only;
    /// confirm against call sites.
    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        if let Some(block_idx) = self.posting_list.skip_list.find_block(target) {
            if block_idx != self.current_block {
                self.load_block(block_idx);
            }

            while self.position_in_block < self.block_postings.len() {
                if self.block_postings[self.position_in_block].0 >= target {
                    return self.doc();
                }
                self.position_in_block += 1;
            }

            self.load_block(self.current_block + 1);
            self.seek(target)
        } else {
            self.exhausted = true;
            super::TERMINATED
        }
    }
}
1228
#[cfg(test)]
mod tests {
    use super::*;

    // Overlapping dims are 2 and 5: 2*4 + 3*2 = 14.
    #[test]
    fn test_sparse_vector_dot_product() {
        let v1 = SparseVector::from_entries(&[0, 2, 5], &[1.0, 2.0, 3.0]);
        let v2 = SparseVector::from_entries(&[1, 2, 5], &[1.0, 4.0, 2.0]);

        assert!((v1.dot(&v2) - 14.0).abs() < 1e-6);
    }

    // Float32 encoding is lossless: doc ids and weights round-trip exactly.
    #[test]
    fn test_sparse_posting_list_float32() {
        let postings = vec![(0, 1.5), (5, 2.3), (10, 0.8), (100, 3.15)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 4);

        let mut iter = pl.iterator();
        assert_eq!(iter.doc(), 0);
        assert!((iter.weight() - 1.5).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!((iter.weight() - 2.3).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 10);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!((iter.weight() - 3.15).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }

    // UInt8 quantization is lossy, but must preserve the relative order of
    // distinct weights.
    #[test]
    fn test_sparse_posting_list_uint8() {
        let postings = vec![(0, 0.0), (5, 0.5), (10, 1.0)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let decoded = pl.decode_all().unwrap();
        assert_eq!(decoded.len(), 3);

        assert!(decoded[0].1 < decoded[1].1);
        assert!(decoded[1].1 < decoded[2].1);
    }

    // 300 postings spans at least three 128-entry blocks; full iteration
    // must reproduce every posting in order.
    #[test]
    fn test_block_sparse_posting_list() {
        let postings: Vec<(DocId, f32)> = (0..300).map(|i| (i * 2, (i as f32) * 0.1)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 300);
        assert!(pl.num_blocks() >= 2);

        let mut iter = pl.iterator();
        for (expected_doc, expected_weight) in &postings {
            assert_eq!(iter.doc(), *expected_doc);
            assert!((iter.weight() - expected_weight).abs() < 1e-6);
            iter.advance();
        }
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }

    // Seek must land on the exact doc when present, the next greater doc
    // otherwise, and TERMINATED past the end.
    #[test]
    fn test_block_sparse_seek() {
        let postings: Vec<(DocId, f32)> = (0..500).map(|i| (i * 3, i as f32)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = pl.iterator();

        assert_eq!(iter.seek(300), 300);

        assert_eq!(iter.seek(301), 303);

        assert_eq!(iter.seek(2000), super::super::TERMINATED);
    }

    // Serialize/deserialize round trip for each quantization mode; the 0.1
    // tolerance absorbs UInt8 quantization error.
    #[test]
    fn test_serialization_roundtrip() {
        let postings: Vec<(DocId, f32)> = vec![(0, 1.0), (10, 2.0), (100, 3.0)];

        for quant in [
            WeightQuantization::Float32,
            WeightQuantization::Float16,
            WeightQuantization::UInt8,
        ] {
            let pl = BlockSparsePostingList::from_postings(&postings, quant).unwrap();

            let mut buffer = Vec::new();
            pl.serialize(&mut buffer).unwrap();

            let pl2 = BlockSparsePostingList::deserialize(&mut &buffer[..]).unwrap();

            assert_eq!(pl.doc_count(), pl2.doc_count());
            assert_eq!(pl.quantization(), pl2.quantization());

            let mut iter1 = pl.iterator();
            let mut iter2 = pl2.iterator();

            while iter1.doc() != super::super::TERMINATED {
                assert_eq!(iter1.doc(), iter2.doc());
                assert!((iter1.weight() - iter2.weight()).abs() < 0.1);
                iter1.advance();
                iter2.advance();
            }
        }
    }

    // Concatenation shifts the second source's doc ids by its offset (100).
    #[test]
    fn test_concatenate() {
        let postings1: Vec<(DocId, f32)> = vec![(0, 1.0), (5, 2.0)];
        let postings2: Vec<(DocId, f32)> = vec![(0, 3.0), (10, 4.0)];

        let pl1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        let pl2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::concatenate(
            &[(pl1, 0), (pl2, 100)],
            WeightQuantization::Float32,
        )
        .unwrap();

        assert_eq!(merged.doc_count(), 4);

        let decoded = merged.decode_all().unwrap();
        assert_eq!(decoded[0], (0, 1.0));
        assert_eq!(decoded[1], (5, 2.0));
        assert_eq!(decoded[2], (100, 3.0)); assert_eq!(decoded[3], (110, 4.0)); }

    // Presets and the single-byte config round trip.
    #[test]
    fn test_sparse_vector_config() {
        let default = SparseVectorConfig::default();
        assert_eq!(default.index_size, IndexSize::U32);
        assert_eq!(default.weight_quantization, WeightQuantization::Float32);
        assert_eq!(default.bytes_per_entry(), 8.0); let splade = SparseVectorConfig::splade();
        assert_eq!(splade.index_size, IndexSize::U16);
        assert_eq!(splade.weight_quantization, WeightQuantization::UInt8);
        assert_eq!(splade.bytes_per_entry(), 3.0); let compact = SparseVectorConfig::compact();
        assert_eq!(compact.index_size, IndexSize::U16);
        assert_eq!(compact.weight_quantization, WeightQuantization::UInt4);
        assert_eq!(compact.bytes_per_entry(), 2.5); let byte = splade.to_byte();
        let restored = SparseVectorConfig::from_byte(byte).unwrap();
        assert_eq!(restored, splade);
    }

    #[test]
    fn test_index_size() {
        assert_eq!(IndexSize::U16.bytes(), 2);
        assert_eq!(IndexSize::U32.bytes(), 4);
        assert_eq!(IndexSize::U16.max_value(), 65535);
        assert_eq!(IndexSize::U32.max_value(), u32::MAX);
    }

    // With weights 0.0..29.9 over 3 blocks of 128/128/44 postings, the
    // per-block maxima are 12.7, 25.5, and 29.9.
    #[test]
    fn test_block_max_weight() {
        let postings: Vec<(DocId, f32)> =
            (0..300).map(|i| (i as DocId, (i as f32) * 0.1)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert!((pl.global_max_weight() - 29.9).abs() < 0.01);

        assert!(pl.num_blocks() >= 3);

        let block0_max = pl.block_max_weight(0).unwrap();
        assert!((block0_max - 12.7).abs() < 0.01);

        let block1_max = pl.block_max_weight(1).unwrap();
        assert!((block1_max - 25.5).abs() < 0.01);

        let block2_max = pl.block_max_weight(2).unwrap();
        assert!((block2_max - 29.9).abs() < 0.01);

        let query_weight = 2.0;
        assert!((pl.max_contribution(query_weight) - 59.8).abs() < 0.1);

        let mut iter = pl.iterator();
        assert!((iter.current_block_max_weight() - 12.7).abs() < 0.01);
        assert!((iter.current_block_max_contribution(query_weight) - 25.4).abs() < 0.1);

        iter.seek(128);
        assert!((iter.current_block_max_weight() - 25.5).abs() < 0.01);
    }

    // Skip-list write/read round trip preserves entries and global max.
    #[test]
    fn test_sparse_skip_list_serialization() {
        let mut skip_list = SparseSkipList::new();
        skip_list.push(0, 127, 0, 12.7);
        skip_list.push(128, 255, 100, 25.5);
        skip_list.push(256, 299, 200, 29.9);

        assert_eq!(skip_list.len(), 3);
        assert!((skip_list.global_max_weight() - 29.9).abs() < 0.01);

        let mut buffer = Vec::new();
        skip_list.write(&mut buffer).unwrap();

        let restored = SparseSkipList::read(&mut buffer.as_slice()).unwrap();

        assert_eq!(restored.len(), 3);
        assert!((restored.global_max_weight() - 29.9).abs() < 0.01);

        let e0 = restored.get(0).unwrap();
        assert_eq!(e0.first_doc, 0);
        assert_eq!(e0.last_doc, 127);
        assert!((e0.max_weight - 12.7).abs() < 0.01);

        let e1 = restored.get(1).unwrap();
        assert_eq!(e1.first_doc, 128);
        assert!((e1.max_weight - 25.5).abs() < 0.01);
    }
}