1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
20use serde::{Deserialize, Serialize};
21use std::io::{self, Read, Write};
22
23use super::posting_common::{
24 RoundedBitWidth, pack_deltas_fixed, read_vint, unpack_deltas_fixed, write_vint,
25};
26use crate::DocId;
27
/// Storage width for sparse-vector dimension indices.
///
/// The `repr(u8)` discriminants (0/1) are the on-disk tags packed by
/// `SparseVectorConfig::to_byte` and decoded by `from_u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum IndexSize {
    /// 2 bytes per index; dimension ids must fit in a u16 (max 65535).
    U16 = 0,
    /// 4 bytes per index; the full u32 dimension space.
    #[default]
    U32 = 1,
}
38
39impl IndexSize {
40 pub fn bytes(&self) -> usize {
42 match self {
43 IndexSize::U16 => 2,
44 IndexSize::U32 => 4,
45 }
46 }
47
48 pub fn max_value(&self) -> u32 {
50 match self {
51 IndexSize::U16 => u16::MAX as u32,
52 IndexSize::U32 => u32::MAX,
53 }
54 }
55
56 fn from_u8(v: u8) -> Option<Self> {
57 match v {
58 0 => Some(IndexSize::U16),
59 1 => Some(IndexSize::U32),
60 _ => None,
61 }
62 }
63}
64
/// Encoding used for posting weights.
///
/// The `repr(u8)` discriminants are the on-disk tags written by
/// `SparsePostingList::serialize` / `SparseVectorConfig::to_byte`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum WeightQuantization {
    /// Raw little-endian f32: 4 bytes per weight, lossless.
    #[default]
    Float32 = 0,
    /// IEEE-754 half precision: 2 bytes per weight.
    Float16 = 1,
    /// 8-bit affine quantization over the list's [min, max] weight range.
    UInt8 = 2,
    /// 4-bit affine quantization; two weights packed per byte.
    UInt4 = 3,
}
79
80impl WeightQuantization {
81 pub fn bytes_per_weight(&self) -> f32 {
83 match self {
84 WeightQuantization::Float32 => 4.0,
85 WeightQuantization::Float16 => 2.0,
86 WeightQuantization::UInt8 => 1.0,
87 WeightQuantization::UInt4 => 0.5,
88 }
89 }
90
91 fn from_u8(v: u8) -> Option<Self> {
92 match v {
93 0 => Some(WeightQuantization::Float32),
94 1 => Some(WeightQuantization::Float16),
95 2 => Some(WeightQuantization::UInt8),
96 3 => Some(WeightQuantization::UInt4),
97 _ => None,
98 }
99 }
100}
101
/// Storage configuration for sparse vectors: how wide dimension indices are
/// and how weights are quantized. Packs into a single byte for persistence.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct SparseVectorConfig {
    /// Byte width used for dimension indices.
    pub index_size: IndexSize,
    /// Encoding used for weights.
    pub weight_quantization: WeightQuantization,
}
110
111impl Default for SparseVectorConfig {
112 fn default() -> Self {
113 Self {
114 index_size: IndexSize::U32,
115 weight_quantization: WeightQuantization::Float32,
116 }
117 }
118}
119
120impl SparseVectorConfig {
121 pub fn splade() -> Self {
123 Self {
124 index_size: IndexSize::U16,
125 weight_quantization: WeightQuantization::UInt8,
126 }
127 }
128
129 pub fn compact() -> Self {
131 Self {
132 index_size: IndexSize::U16,
133 weight_quantization: WeightQuantization::UInt4,
134 }
135 }
136
137 pub fn full_precision() -> Self {
139 Self {
140 index_size: IndexSize::U32,
141 weight_quantization: WeightQuantization::Float32,
142 }
143 }
144
145 pub fn bytes_per_entry(&self) -> f32 {
147 self.index_size.bytes() as f32 + self.weight_quantization.bytes_per_weight()
148 }
149
150 pub fn to_byte(&self) -> u8 {
152 ((self.index_size as u8) << 4) | (self.weight_quantization as u8)
153 }
154
155 pub fn from_byte(b: u8) -> Option<Self> {
157 let index_size = IndexSize::from_u8(b >> 4)?;
158 let weight_quantization = WeightQuantization::from_u8(b & 0x0F)?;
159 Some(Self {
160 index_size,
161 weight_quantization,
162 })
163 }
164}
165
/// One `(dimension, weight)` component of a sparse vector.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseEntry {
    /// Dimension index.
    pub dim_id: u32,
    /// Weight of that dimension.
    pub weight: f32,
}
172
/// A sparse vector stored as entries sorted by ascending `dim_id`
/// (the sort order is what makes `dot`'s merge-join correct).
#[derive(Debug, Clone, Default)]
pub struct SparseVector {
    entries: Vec<SparseEntry>,
}
178
179impl SparseVector {
180 pub fn new() -> Self {
181 Self::default()
182 }
183
184 pub fn with_capacity(capacity: usize) -> Self {
185 Self {
186 entries: Vec::with_capacity(capacity),
187 }
188 }
189
190 pub fn from_entries(dim_ids: &[u32], weights: &[f32]) -> Self {
192 assert_eq!(dim_ids.len(), weights.len());
193 let mut entries: Vec<SparseEntry> = dim_ids
194 .iter()
195 .zip(weights.iter())
196 .map(|(&dim_id, &weight)| SparseEntry { dim_id, weight })
197 .collect();
198 entries.sort_by_key(|e| e.dim_id);
200 Self { entries }
201 }
202
203 pub fn push(&mut self, dim_id: u32, weight: f32) {
205 debug_assert!(
206 self.entries.is_empty() || self.entries.last().unwrap().dim_id < dim_id,
207 "Entries must be added in sorted order by dim_id"
208 );
209 self.entries.push(SparseEntry { dim_id, weight });
210 }
211
212 pub fn len(&self) -> usize {
214 self.entries.len()
215 }
216
217 pub fn is_empty(&self) -> bool {
218 self.entries.is_empty()
219 }
220
221 pub fn iter(&self) -> impl Iterator<Item = &SparseEntry> {
222 self.entries.iter()
223 }
224
225 pub fn dot(&self, other: &SparseVector) -> f32 {
227 let mut result = 0.0f32;
228 let mut i = 0;
229 let mut j = 0;
230
231 while i < self.entries.len() && j < other.entries.len() {
232 let a = &self.entries[i];
233 let b = &other.entries[j];
234
235 match a.dim_id.cmp(&b.dim_id) {
236 std::cmp::Ordering::Less => i += 1,
237 std::cmp::Ordering::Greater => j += 1,
238 std::cmp::Ordering::Equal => {
239 result += a.weight * b.weight;
240 i += 1;
241 j += 1;
242 }
243 }
244 }
245
246 result
247 }
248
249 pub fn norm_squared(&self) -> f32 {
251 self.entries.iter().map(|e| e.weight * e.weight).sum()
252 }
253
254 pub fn norm(&self) -> f32 {
256 self.norm_squared().sqrt()
257 }
258}
259
/// A single posting: a document and the weight the term carries in it.
#[derive(Debug, Clone, Copy)]
pub struct SparsePosting {
    pub doc_id: DocId,
    pub weight: f32,
}

/// Number of postings grouped into one block of a `BlockSparsePostingList`.
pub const SPARSE_BLOCK_SIZE: usize = 128;
269
/// Skip-list metadata for one block of postings.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseSkipEntry {
    /// First doc id in the block (seeds the block's delta decoding).
    pub first_doc: DocId,
    /// Last doc id in the block.
    pub last_doc: DocId,
    /// Byte offset of the block within the posting list's data buffer.
    pub offset: u32,
    /// Largest weight in the block; used to bound the block's score
    /// contribution for pruning (see `block_max_contribution`).
    pub max_weight: f32,
}
284
285impl SparseSkipEntry {
286 pub fn new(first_doc: DocId, last_doc: DocId, offset: u32, max_weight: f32) -> Self {
287 Self {
288 first_doc,
289 last_doc,
290 offset,
291 max_weight,
292 }
293 }
294
295 #[inline]
300 pub fn block_max_contribution(&self, query_weight: f32) -> f32 {
301 query_weight * self.max_weight
302 }
303
304 pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
306 writer.write_u32::<LittleEndian>(self.first_doc)?;
307 writer.write_u32::<LittleEndian>(self.last_doc)?;
308 writer.write_u32::<LittleEndian>(self.offset)?;
309 writer.write_f32::<LittleEndian>(self.max_weight)?;
310 Ok(())
311 }
312
313 pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
315 let first_doc = reader.read_u32::<LittleEndian>()?;
316 let last_doc = reader.read_u32::<LittleEndian>()?;
317 let offset = reader.read_u32::<LittleEndian>()?;
318 let max_weight = reader.read_f32::<LittleEndian>()?;
319 Ok(Self {
320 first_doc,
321 last_doc,
322 offset,
323 max_weight,
324 })
325 }
326}
327
/// Skip list over posting blocks, plus the global maximum weight for
/// list-level pruning bounds.
#[derive(Debug, Clone, Default)]
pub struct SparseSkipList {
    /// One entry per block, in block (ascending doc) order.
    entries: Vec<SparseSkipEntry>,
    /// Maximum `max_weight` across all entries (0.0 when empty).
    global_max_weight: f32,
}
335
336impl SparseSkipList {
337 pub fn new() -> Self {
338 Self::default()
339 }
340
341 pub fn push(&mut self, first_doc: DocId, last_doc: DocId, offset: u32, max_weight: f32) {
343 self.global_max_weight = self.global_max_weight.max(max_weight);
344 self.entries.push(SparseSkipEntry::new(
345 first_doc, last_doc, offset, max_weight,
346 ));
347 }
348
349 pub fn len(&self) -> usize {
351 self.entries.len()
352 }
353
354 pub fn is_empty(&self) -> bool {
355 self.entries.is_empty()
356 }
357
358 pub fn get(&self, index: usize) -> Option<&SparseSkipEntry> {
360 self.entries.get(index)
361 }
362
363 pub fn global_max_weight(&self) -> f32 {
365 self.global_max_weight
366 }
367
368 pub fn find_block(&self, target: DocId) -> Option<usize> {
370 self.entries.iter().position(|e| e.last_doc >= target)
371 }
372
373 pub fn iter(&self) -> impl Iterator<Item = &SparseSkipEntry> {
375 self.entries.iter()
376 }
377
378 pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
380 writer.write_u32::<LittleEndian>(self.entries.len() as u32)?;
381 writer.write_f32::<LittleEndian>(self.global_max_weight)?;
382 for entry in &self.entries {
383 entry.write(writer)?;
384 }
385 Ok(())
386 }
387
388 pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
390 let count = reader.read_u32::<LittleEndian>()? as usize;
391 let global_max_weight = reader.read_f32::<LittleEndian>()?;
392 let mut entries = Vec::with_capacity(count);
393 for _ in 0..count {
394 entries.push(SparseSkipEntry::read(reader)?);
395 }
396 Ok(Self {
397 entries,
398 global_max_weight,
399 })
400 }
401}
402
/// Compressed posting list: varint-delta-encoded doc ids followed by a
/// weight payload in `quantization` format, all in one buffer.
#[derive(Debug, Clone)]
pub struct SparsePostingList {
    /// Weight encoding used in `data`.
    quantization: WeightQuantization,
    /// Affine dequantization scale; 1.0 for float formats.
    scale: f32,
    /// Affine dequantization offset; 0.0 for float formats.
    min_val: f32,
    /// Number of postings stored.
    doc_count: u32,
    /// Doc-id varints followed by the packed weight payload.
    data: Vec<u8>,
}
421
422impl SparsePostingList {
423 pub fn from_postings(
425 postings: &[(DocId, f32)],
426 quantization: WeightQuantization,
427 ) -> io::Result<Self> {
428 if postings.is_empty() {
429 return Ok(Self {
430 quantization,
431 scale: 1.0,
432 min_val: 0.0,
433 doc_count: 0,
434 data: Vec::new(),
435 });
436 }
437
438 let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
440 let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
441 let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
442
443 let (scale, adjusted_min) = match quantization {
444 WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
445 WeightQuantization::UInt8 => {
446 let range = max_val - min_val;
447 if range < f32::EPSILON {
448 (1.0, min_val)
449 } else {
450 (range / 255.0, min_val)
451 }
452 }
453 WeightQuantization::UInt4 => {
454 let range = max_val - min_val;
455 if range < f32::EPSILON {
456 (1.0, min_val)
457 } else {
458 (range / 15.0, min_val)
459 }
460 }
461 };
462
463 let mut data = Vec::new();
464
465 let mut prev_doc_id = 0u32;
467 for (doc_id, _) in postings {
468 let delta = doc_id - prev_doc_id;
469 write_vint(&mut data, delta as u64)?;
470 prev_doc_id = *doc_id;
471 }
472
473 match quantization {
475 WeightQuantization::Float32 => {
476 for (_, weight) in postings {
477 data.write_f32::<LittleEndian>(*weight)?;
478 }
479 }
480 WeightQuantization::Float16 => {
481 use half::slice::HalfFloatSliceExt;
483 let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
484 let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
485 f16_slice.convert_from_f32_slice(&weights);
486 for h in f16_slice {
487 data.write_u16::<LittleEndian>(h.to_bits())?;
488 }
489 }
490 WeightQuantization::UInt8 => {
491 for (_, weight) in postings {
492 let quantized = ((*weight - adjusted_min) / scale).round() as u8;
493 data.write_u8(quantized)?;
494 }
495 }
496 WeightQuantization::UInt4 => {
497 let mut i = 0;
499 while i < postings.len() {
500 let q1 = ((postings[i].1 - adjusted_min) / scale).round() as u8 & 0x0F;
501 let q2 = if i + 1 < postings.len() {
502 ((postings[i + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
503 } else {
504 0
505 };
506 data.write_u8((q2 << 4) | q1)?;
507 i += 2;
508 }
509 }
510 }
511
512 Ok(Self {
513 quantization,
514 scale,
515 min_val: adjusted_min,
516 doc_count: postings.len() as u32,
517 data,
518 })
519 }
520
521 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
523 writer.write_u8(self.quantization as u8)?;
524 writer.write_f32::<LittleEndian>(self.scale)?;
525 writer.write_f32::<LittleEndian>(self.min_val)?;
526 writer.write_u32::<LittleEndian>(self.doc_count)?;
527 writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
528 writer.write_all(&self.data)?;
529 Ok(())
530 }
531
532 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
534 let quant_byte = reader.read_u8()?;
535 let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
536 io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
537 })?;
538 let scale = reader.read_f32::<LittleEndian>()?;
539 let min_val = reader.read_f32::<LittleEndian>()?;
540 let doc_count = reader.read_u32::<LittleEndian>()?;
541 let data_len = reader.read_u32::<LittleEndian>()? as usize;
542 let mut data = vec![0u8; data_len];
543 reader.read_exact(&mut data)?;
544
545 Ok(Self {
546 quantization,
547 scale,
548 min_val,
549 doc_count,
550 data,
551 })
552 }
553
554 pub fn doc_count(&self) -> u32 {
556 self.doc_count
557 }
558
559 pub fn quantization(&self) -> WeightQuantization {
561 self.quantization
562 }
563
564 pub fn iterator(&self) -> SparsePostingIterator<'_> {
566 SparsePostingIterator::new(self)
567 }
568
569 pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
571 let mut result = Vec::with_capacity(self.doc_count as usize);
572 let mut iter = self.iterator();
573
574 while !iter.exhausted {
575 result.push((iter.doc_id, iter.weight));
576 iter.advance();
577 }
578
579 Ok(result)
580 }
581}
582
/// Cursor over a `SparsePostingList`, decoding one posting at a time.
pub struct SparsePostingIterator<'a> {
    posting_list: &'a SparsePostingList,
    /// Byte position of the next doc-id varint within `data`.
    doc_id_offset: usize,
    /// Byte position where the weight payload begins.
    weight_offset: usize,
    /// Ordinal (0-based) of the current posting.
    index: usize,
    /// Decoded doc id of the current posting.
    doc_id: DocId,
    /// Decoded weight of the current posting.
    weight: f32,
    /// True once the cursor has moved past the last posting.
    exhausted: bool,
}
599
impl<'a> SparsePostingIterator<'a> {
    /// Create a cursor positioned on the first posting; an empty list starts
    /// exhausted.
    fn new(posting_list: &'a SparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            doc_id_offset: 0,
            weight_offset: 0,
            index: 0,
            doc_id: 0,
            weight: 0.0,
            exhausted: posting_list.doc_count == 0,
        };

        if !iter.exhausted {
            // The weight payload sits after all doc-id varints; locate it once.
            iter.weight_offset = iter.calculate_weight_offset();
            iter.load_current();
        }

        iter
    }

    /// Byte offset where the weight payload starts: the combined length of
    /// the `doc_count` doc-id varints at the front of `data`.
    fn calculate_weight_offset(&self) -> usize {
        let mut offset = 0;
        let mut reader = &self.posting_list.data[..];

        for _ in 0..self.posting_list.doc_count {
            if read_vint(&mut reader).is_ok() {
                // `reader` shrinks as varints are consumed; bytes read so far
                // equal the original length minus what remains.
                offset = self.posting_list.data.len() - reader.len();
            }
        }

        offset
    }

    /// Decode the posting at `self.index` into `doc_id` / `weight`.
    fn load_current(&mut self) {
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return;
        }

        // Doc ids are delta-encoded: add the next varint to the running id.
        // A failed read (truncated data) silently leaves `doc_id` unchanged.
        let mut reader = &self.posting_list.data[self.doc_id_offset..];
        if let Ok(delta) = read_vint(&mut reader) {
            self.doc_id = self.doc_id.wrapping_add(delta as u32);
            self.doc_id_offset = self.posting_list.data.len() - reader.len();
        }

        let weight_idx = self.index;
        let pl = self.posting_list;

        // Decode weight `weight_idx` from the payload starting at
        // `weight_offset`; any out-of-range access (truncated data) yields
        // 0.0 instead of panicking.
        self.weight = match pl.quantization {
            WeightQuantization::Float32 => {
                let offset = self.weight_offset + weight_idx * 4;
                if offset + 4 <= pl.data.len() {
                    let bytes = &pl.data[offset..offset + 4];
                    f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
                } else {
                    0.0
                }
            }
            WeightQuantization::Float16 => {
                let offset = self.weight_offset + weight_idx * 2;
                if offset + 2 <= pl.data.len() {
                    let bits = u16::from_le_bytes([pl.data[offset], pl.data[offset + 1]]);
                    half::f16::from_bits(bits).to_f32()
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt8 => {
                let offset = self.weight_offset + weight_idx;
                if offset < pl.data.len() {
                    let quantized = pl.data[offset];
                    // Affine dequantization: q * scale + min.
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt4 => {
                // Two 4-bit codes per byte: even index = low nibble,
                // odd index = high nibble (matches the encoder's packing).
                let byte_offset = self.weight_offset + weight_idx / 2;
                if byte_offset < pl.data.len() {
                    let byte = pl.data[byte_offset];
                    let quantized = if weight_idx.is_multiple_of(2) {
                        byte & 0x0F
                    } else {
                        (byte >> 4) & 0x0F
                    };
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
        };
    }

    /// Current doc id, or `TERMINATED` once the cursor is exhausted.
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else {
            self.doc_id
        }
    }

    /// Current weight, or 0.0 once the cursor is exhausted.
    pub fn weight(&self) -> f32 {
        if self.exhausted { 0.0 } else { self.weight }
    }

    /// Move to the next posting; returns the new doc id or `TERMINATED`.
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        self.index += 1;
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return super::TERMINATED;
        }

        self.load_current();
        self.doc_id
    }

    /// Advance until `doc() >= target` (linear scan — this format carries no
    /// skip data; see `BlockSparsePostingList` for skip-based seeking).
    pub fn seek(&mut self, target: DocId) -> DocId {
        while !self.exhausted && self.doc_id < target {
            self.advance();
        }
        self.doc()
    }
}
735
/// Block-compressed posting list: postings grouped into fixed-size blocks,
/// each with fixed-width-packed doc-id deltas and a quantized weight
/// payload, plus a skip list holding per-block doc ranges and max weights.
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    /// Weight encoding used inside every block.
    quantization: WeightQuantization,
    /// Affine dequantization scale shared by all blocks.
    scale: f32,
    /// Affine dequantization offset shared by all blocks.
    min_val: f32,
    /// Per-block metadata: doc range, byte offset into `data`, max weight.
    skip_list: SparseSkipList,
    /// Concatenated encoded blocks.
    data: Vec<u8>,
    /// Total number of postings across all blocks.
    doc_count: u32,
}
755
impl BlockSparsePostingList {
    /// Build a block-compressed list from `(doc_id, weight)` pairs given in
    /// ascending doc-id order.
    ///
    /// Postings are grouped into blocks of `SPARSE_BLOCK_SIZE`. Each block is
    /// laid out as: count (u16 LE), delta bit width (u8), packed doc-id
    /// deltas, then the quantized weight payload; one skip-list entry is
    /// recorded per block.
    pub fn from_postings(
        postings: &[(DocId, f32)],
        quantization: WeightQuantization,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                quantization,
                scale: 1.0,
                min_val: 0.0,
                skip_list: SparseSkipList::new(),
                data: Vec::new(),
                doc_count: 0,
            });
        }

        let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
        let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        // Affine dequantization parameters shared by every block: identity
        // for float formats, [min, max] mapped onto the integer level range
        // otherwise. A degenerate (constant) range keeps scale 1.0 so
        // q * scale + min still reconstructs the value.
        let (scale, adjusted_min) = match quantization {
            WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
            WeightQuantization::UInt8 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 255.0, min_val)
                }
            }
            WeightQuantization::UInt4 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 15.0, min_val)
                }
            }
        };

        let mut skip_list = SparseSkipList::new();
        let mut data = Vec::new();

        let mut i = 0;
        while i < postings.len() {
            let block_end = (i + SPARSE_BLOCK_SIZE).min(postings.len());
            let block = &postings[i..block_end];

            let first_doc_id = block.first().unwrap().0;
            let last_doc_id = block.last().unwrap().0;

            // Block-max weight, kept in the skip entry for pruning bounds.
            let block_max_weight = block
                .iter()
                .map(|(_, w)| *w)
                .fold(f32::NEG_INFINITY, f32::max);

            let block_doc_ids: Vec<DocId> = block.iter().map(|(d, _)| *d).collect();
            let (doc_bit_width, packed_doc_ids) = pack_deltas_fixed(&block_doc_ids);

            // The skip entry records where this block starts in `data`.
            let block_start = data.len() as u32;
            skip_list.push(first_doc_id, last_doc_id, block_start, block_max_weight);

            // Block header: count (u16 LE) + delta bit width (u8), then the
            // packed doc-id deltas.
            data.write_u16::<LittleEndian>(block.len() as u16)?;
            data.write_u8(doc_bit_width as u8)?;
            data.extend_from_slice(&packed_doc_ids);

            // Weight payload, encoded per the chosen quantization.
            match quantization {
                WeightQuantization::Float32 => {
                    for (_, weight) in block {
                        data.write_f32::<LittleEndian>(*weight)?;
                    }
                }
                WeightQuantization::Float16 => {
                    use half::slice::HalfFloatSliceExt;
                    // Bulk-convert via the half crate, then write raw bits.
                    let weights: Vec<f32> = block.iter().map(|(_, w)| *w).collect();
                    let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
                    f16_slice.convert_from_f32_slice(&weights);
                    for h in f16_slice {
                        data.write_u16::<LittleEndian>(h.to_bits())?;
                    }
                }
                WeightQuantization::UInt8 => {
                    for (_, weight) in block {
                        let quantized = ((*weight - adjusted_min) / scale).round() as u8;
                        data.write_u8(quantized)?;
                    }
                }
                WeightQuantization::UInt4 => {
                    // Two 4-bit codes per byte: even index in the low nibble,
                    // odd index in the high nibble; trailing odd slot is 0.
                    let mut j = 0;
                    while j < block.len() {
                        let q1 = ((block[j].1 - adjusted_min) / scale).round() as u8 & 0x0F;
                        let q2 = if j + 1 < block.len() {
                            ((block[j + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
                        } else {
                            0
                        };
                        data.write_u8((q2 << 4) | q1)?;
                        j += 2;
                    }
                }
            }

            i = block_end;
        }

        Ok(Self {
            quantization,
            scale,
            min_val: adjusted_min,
            skip_list,
            data,
            doc_count: postings.len() as u32,
        })
    }

    /// Serialize the header (quantization tag, scale, min, doc count), then
    /// the skip list, then the block payload (length-prefixed).
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u8(self.quantization as u8)?;
        writer.write_f32::<LittleEndian>(self.scale)?;
        writer.write_f32::<LittleEndian>(self.min_val)?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;

        self.skip_list.write(writer)?;

        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
        writer.write_all(&self.data)?;

        Ok(())
    }

    /// Inverse of [`BlockSparsePostingList::serialize`].
    ///
    /// # Errors
    /// Returns `InvalidData` for an unknown quantization tag, or any
    /// underlying I/O error.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let quant_byte = reader.read_u8()?;
        let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
        })?;
        let scale = reader.read_f32::<LittleEndian>()?;
        let min_val = reader.read_f32::<LittleEndian>()?;
        let doc_count = reader.read_u32::<LittleEndian>()?;

        let skip_list = SparseSkipList::read(reader)?;

        let data_len = reader.read_u32::<LittleEndian>()? as usize;
        let mut data = vec![0u8; data_len];
        reader.read_exact(&mut data)?;

        Ok(Self {
            quantization,
            scale,
            min_val,
            skip_list,
            data,
            doc_count,
        })
    }

    /// Total number of postings.
    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    /// Number of blocks (one skip entry per block).
    pub fn num_blocks(&self) -> usize {
        self.skip_list.len()
    }

    /// Weight encoding used by this list.
    pub fn quantization(&self) -> WeightQuantization {
        self.quantization
    }

    /// Largest weight anywhere in the list.
    pub fn global_max_weight(&self) -> f32 {
        self.skip_list.global_max_weight()
    }

    /// Largest weight within block `block_idx`, if that block exists.
    pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
        self.skip_list.get(block_idx).map(|e| e.max_weight)
    }

    /// Upper bound on this whole list's contribution for a query term with
    /// the given weight.
    #[inline]
    pub fn max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.skip_list.global_max_weight()
    }

    /// Block-at-a-time iterator over the postings.
    pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
        BlockSparsePostingIterator::new(self)
    }

    /// Merge several posting lists into one, shifting each list's doc ids by
    /// its paired offset, and re-quantize to `target_quantization`.
    ///
    /// NOTE(review): sources appear to be expected in ascending shifted
    /// doc-id order so the combined postings stay sorted for
    /// `from_postings`'s delta encoding — confirm with callers.
    pub fn concatenate(
        sources: &[(BlockSparsePostingList, u32)],
        target_quantization: WeightQuantization,
    ) -> io::Result<Self> {
        let mut all_postings: Vec<(DocId, f32)> = Vec::new();

        for (source, doc_offset) in sources {
            let decoded = source.decode_all()?;
            for (doc_id, weight) in decoded {
                all_postings.push((doc_id + doc_offset, weight));
            }
        }

        Self::from_postings(&all_postings, target_quantization)
    }

    /// Decode the full list into `(doc_id, weight)` pairs.
    pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
        let mut result = Vec::with_capacity(self.doc_count as usize);
        let mut iter = self.iterator();

        while iter.doc() != super::TERMINATED {
            result.push((iter.doc(), iter.weight()));
            iter.advance();
        }

        Ok(result)
    }
}
993
/// Cursor over a `BlockSparsePostingList`; decodes one whole block at a
/// time into memory and walks it.
pub struct BlockSparsePostingIterator<'a> {
    posting_list: &'a BlockSparsePostingList,
    /// Index of the block currently decoded into `block_postings`.
    current_block: usize,
    /// Fully decoded postings of the current block.
    block_postings: Vec<(DocId, f32)>,
    /// Cursor position within `block_postings`.
    position_in_block: usize,
    /// True once all blocks have been consumed.
    exhausted: bool,
}
1002
impl<'a> BlockSparsePostingIterator<'a> {
    /// Create a cursor positioned at the first posting of the first block;
    /// an empty list starts exhausted.
    fn new(posting_list: &'a BlockSparsePostingList) -> Self {
        let exhausted = posting_list.skip_list.is_empty();
        let mut iter = Self {
            posting_list,
            current_block: 0,
            block_postings: Vec::new(),
            position_in_block: 0,
            exhausted,
        };

        if !iter.exhausted {
            iter.load_block(0);
        }

        iter
    }

    /// Decode block `block_idx` into `block_postings`, resetting the
    /// in-block cursor. Marks the iterator exhausted when the index is past
    /// the last block or the block header is truncated.
    fn load_block(&mut self, block_idx: usize) {
        let entry = match self.posting_list.skip_list.get(block_idx) {
            Some(e) => e,
            None => {
                self.exhausted = true;
                return;
            }
        };

        self.current_block = block_idx;
        self.position_in_block = 0;
        self.block_postings.clear();

        let offset = entry.offset as usize;
        let first_doc_id = entry.first_doc;
        let data = &self.posting_list.data[offset..];

        // Block header is 3 bytes: count (u16 LE) + delta bit width (u8).
        if data.len() < 3 {
            self.exhausted = true;
            return;
        }
        let count = u16::from_le_bytes([data[0], data[1]]) as usize;
        let doc_bit_width = RoundedBitWidth::from_u8(data[2]).unwrap_or(RoundedBitWidth::Zero);

        // Only count-1 deltas are packed: the first doc id comes from the
        // skip entry and seeds the delta decoding.
        let doc_bytes = doc_bit_width.bytes_per_value() * count.saturating_sub(1);
        let doc_data = &data[3..3 + doc_bytes];
        let mut doc_ids = vec![0u32; count];
        unpack_deltas_fixed(doc_data, doc_bit_width, first_doc_id, count, &mut doc_ids);

        // Weights follow the packed doc ids.
        let weight_offset = 3 + doc_bytes;
        let weight_data = &data[weight_offset..];
        let pl = self.posting_list;

        // Decode the block's whole weight payload up front; truncated data
        // yields 0.0 weights rather than a panic.
        let weights: Vec<f32> = match pl.quantization {
            WeightQuantization::Float32 => {
                let mut weights = Vec::with_capacity(count);
                let mut reader = weight_data;
                for _ in 0..count {
                    if reader.len() >= 4 {
                        weights.push((&mut reader).read_f32::<LittleEndian>().unwrap_or(0.0));
                    } else {
                        weights.push(0.0);
                    }
                }
                weights
            }
            WeightQuantization::Float16 => {
                use half::slice::HalfFloatSliceExt;
                // Collect raw f16 bit patterns, then bulk-convert to f32.
                let mut f16_slice: Vec<half::f16> = Vec::with_capacity(count);
                for i in 0..count {
                    let offset = i * 2;
                    if offset + 2 <= weight_data.len() {
                        let bits =
                            u16::from_le_bytes([weight_data[offset], weight_data[offset + 1]]);
                        f16_slice.push(half::f16::from_bits(bits));
                    } else {
                        f16_slice.push(half::f16::ZERO);
                    }
                }
                let mut weights = vec![0.0f32; count];
                f16_slice.convert_to_f32_slice(&mut weights);
                weights
            }
            WeightQuantization::UInt8 => {
                let mut weights = Vec::with_capacity(count);
                for i in 0..count {
                    if i < weight_data.len() {
                        // Affine dequantization: q * scale + min.
                        weights.push(weight_data[i] as f32 * pl.scale + pl.min_val);
                    } else {
                        weights.push(0.0);
                    }
                }
                weights
            }
            WeightQuantization::UInt4 => {
                let mut weights = Vec::with_capacity(count);
                for i in 0..count {
                    // Two 4-bit codes per byte: even index = low nibble,
                    // odd index = high nibble (matches the encoder).
                    let byte_idx = i / 2;
                    if byte_idx < weight_data.len() {
                        let byte = weight_data[byte_idx];
                        let quantized = if i % 2 == 0 {
                            byte & 0x0F
                        } else {
                            (byte >> 4) & 0x0F
                        };
                        weights.push(quantized as f32 * pl.scale + pl.min_val);
                    } else {
                        weights.push(0.0);
                    }
                }
                weights
            }
        };

        for (doc_id, weight) in doc_ids.into_iter().zip(weights.into_iter()) {
            self.block_postings.push((doc_id, weight));
        }
    }

    /// Current doc id, or `TERMINATED` once the cursor is exhausted or past
    /// the end of the decoded block.
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else if self.position_in_block < self.block_postings.len() {
            self.block_postings[self.position_in_block].0
        } else {
            super::TERMINATED
        }
    }

    /// Current weight, or 0.0 when out of range/exhausted.
    pub fn weight(&self) -> f32 {
        if self.exhausted || self.position_in_block >= self.block_postings.len() {
            0.0
        } else {
            self.block_postings[self.position_in_block].1
        }
    }

    /// Max weight of the block the cursor is currently inside (0.0 when
    /// exhausted); used as a per-block pruning bound.
    #[inline]
    pub fn current_block_max_weight(&self) -> f32 {
        if self.exhausted {
            0.0
        } else {
            self.posting_list
                .skip_list
                .get(self.current_block)
                .map(|e| e.max_weight)
                .unwrap_or(0.0)
        }
    }

    /// Upper bound on the current block's contribution for a query term
    /// with the given weight.
    #[inline]
    pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.current_block_max_weight()
    }

    /// Move to the next posting, loading the next block when the current
    /// one is spent; returns the new doc id or `TERMINATED`.
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        self.position_in_block += 1;
        if self.position_in_block >= self.block_postings.len() {
            self.load_block(self.current_block + 1);
        }

        self.doc()
    }

    /// Forward seek: jump via the skip list to the first block whose
    /// `last_doc >= target`, then scan inside it. If the scan runs off the
    /// end of that block (target falls in a doc-id gap), recurse into the
    /// next block. Returns the first doc id `>= target`, or `TERMINATED`.
    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        if let Some(block_idx) = self.posting_list.skip_list.find_block(target) {
            if block_idx != self.current_block {
                self.load_block(block_idx);
            }

            while self.position_in_block < self.block_postings.len() {
                if self.block_postings[self.position_in_block].0 >= target {
                    return self.doc();
                }
                self.position_in_block += 1;
            }

            self.load_block(self.current_block + 1);
            self.seek(target)
        } else {
            self.exhausted = true;
            super::TERMINATED
        }
    }
}
1214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_vector_dot_product() {
        let a = SparseVector::from_entries(&[0, 2, 5], &[1.0, 2.0, 3.0]);
        let b = SparseVector::from_entries(&[1, 2, 5], &[1.0, 4.0, 2.0]);

        // Shared dims 2 and 5: 2*4 + 3*2 = 14.
        assert!((a.dot(&b) - 14.0).abs() < 1e-6);
    }

    #[test]
    fn test_sparse_posting_list_float32() {
        let postings = vec![(0, 1.5), (5, 2.3), (10, 0.8), (100, 3.15)];
        let list =
            SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 4);

        let mut cursor = list.iterator();
        assert_eq!(cursor.doc(), 0);
        assert!((cursor.weight() - 1.5).abs() < 1e-6);

        cursor.advance();
        assert_eq!(cursor.doc(), 5);
        assert!((cursor.weight() - 2.3).abs() < 1e-6);

        cursor.advance();
        assert_eq!(cursor.doc(), 10);

        cursor.advance();
        assert_eq!(cursor.doc(), 100);
        assert!((cursor.weight() - 3.15).abs() < 1e-6);

        cursor.advance();
        assert_eq!(cursor.doc(), super::super::TERMINATED);
    }

    #[test]
    fn test_sparse_posting_list_uint8() {
        let postings = vec![(0, 0.0), (5, 0.5), (10, 1.0)];
        let list = SparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let decoded = list.decode_all().unwrap();
        assert_eq!(decoded.len(), 3);

        // Lossy quantization must still preserve the relative order.
        assert!(decoded[0].1 < decoded[1].1);
        assert!(decoded[1].1 < decoded[2].1);
    }

    #[test]
    fn test_block_sparse_posting_list() {
        let postings: Vec<(DocId, f32)> = (0..300).map(|i| (i * 2, (i as f32) * 0.1)).collect();

        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 300);
        assert!(list.num_blocks() >= 2);

        // Iterating must reproduce every posting exactly, in order.
        let mut cursor = list.iterator();
        for &(want_doc, want_weight) in &postings {
            assert_eq!(cursor.doc(), want_doc);
            assert!((cursor.weight() - want_weight).abs() < 1e-6);
            cursor.advance();
        }
        assert_eq!(cursor.doc(), super::super::TERMINATED);
    }

    #[test]
    fn test_block_sparse_seek() {
        let postings: Vec<(DocId, f32)> = (0..500).map(|i| (i * 3, i as f32)).collect();

        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut cursor = list.iterator();

        // Exact hit.
        assert_eq!(cursor.seek(300), 300);
        // Between postings: lands on the next one.
        assert_eq!(cursor.seek(301), 303);
        // Past the end.
        assert_eq!(cursor.seek(2000), super::super::TERMINATED);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let postings: Vec<(DocId, f32)> = vec![(0, 1.0), (10, 2.0), (100, 3.0)];

        for quant in [
            WeightQuantization::Float32,
            WeightQuantization::Float16,
            WeightQuantization::UInt8,
        ] {
            let original = BlockSparsePostingList::from_postings(&postings, quant).unwrap();

            let mut bytes = Vec::new();
            original.serialize(&mut bytes).unwrap();

            let restored = BlockSparsePostingList::deserialize(&mut &bytes[..]).unwrap();

            assert_eq!(original.doc_count(), restored.doc_count());
            assert_eq!(original.quantization(), restored.quantization());

            let mut lhs = original.iterator();
            let mut rhs = restored.iterator();

            while lhs.doc() != super::super::TERMINATED {
                assert_eq!(lhs.doc(), rhs.doc());
                // Lossy formats only need to agree approximately.
                assert!((lhs.weight() - rhs.weight()).abs() < 0.1);
                lhs.advance();
                rhs.advance();
            }
        }
    }

    #[test]
    fn test_concatenate() {
        let first: Vec<(DocId, f32)> = vec![(0, 1.0), (5, 2.0)];
        let second: Vec<(DocId, f32)> = vec![(0, 3.0), (10, 4.0)];

        let list_a =
            BlockSparsePostingList::from_postings(&first, WeightQuantization::Float32).unwrap();
        let list_b =
            BlockSparsePostingList::from_postings(&second, WeightQuantization::Float32).unwrap();

        // Second list's doc ids get shifted by 100.
        let merged = BlockSparsePostingList::concatenate(
            &[(list_a, 0), (list_b, 100)],
            WeightQuantization::Float32,
        )
        .unwrap();

        assert_eq!(merged.doc_count(), 4);

        let decoded = merged.decode_all().unwrap();
        assert_eq!(decoded[0], (0, 1.0));
        assert_eq!(decoded[1], (5, 2.0));
        assert_eq!(decoded[2], (100, 3.0));
        assert_eq!(decoded[3], (110, 4.0));
    }

    #[test]
    fn test_sparse_vector_config() {
        let default_cfg = SparseVectorConfig::default();
        assert_eq!(default_cfg.index_size, IndexSize::U32);
        assert_eq!(default_cfg.weight_quantization, WeightQuantization::Float32);
        assert_eq!(default_cfg.bytes_per_entry(), 8.0);

        let splade = SparseVectorConfig::splade();
        assert_eq!(splade.index_size, IndexSize::U16);
        assert_eq!(splade.weight_quantization, WeightQuantization::UInt8);
        assert_eq!(splade.bytes_per_entry(), 3.0);

        let compact = SparseVectorConfig::compact();
        assert_eq!(compact.index_size, IndexSize::U16);
        assert_eq!(compact.weight_quantization, WeightQuantization::UInt4);
        assert_eq!(compact.bytes_per_entry(), 2.5);

        // Config survives a round trip through its packed-byte form.
        let packed = splade.to_byte();
        let restored = SparseVectorConfig::from_byte(packed).unwrap();
        assert_eq!(restored, splade);
    }

    #[test]
    fn test_index_size() {
        assert_eq!(IndexSize::U16.bytes(), 2);
        assert_eq!(IndexSize::U32.bytes(), 4);
        assert_eq!(IndexSize::U16.max_value(), 65535);
        assert_eq!(IndexSize::U32.max_value(), u32::MAX);
    }

    #[test]
    fn test_block_max_weight() {
        // 300 docs -> blocks of 128 + 128 + 44, with weight = 0.1 * doc id.
        let postings: Vec<(DocId, f32)> =
            (0..300).map(|i| (i as DocId, (i as f32) * 0.1)).collect();

        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert!((list.global_max_weight() - 29.9).abs() < 0.01);
        assert!(list.num_blocks() >= 3);

        // Block maxima are the last (largest) weight in each block.
        let block0_max = list.block_max_weight(0).unwrap();
        assert!((block0_max - 12.7).abs() < 0.01);

        let block1_max = list.block_max_weight(1).unwrap();
        assert!((block1_max - 25.5).abs() < 0.01);

        let block2_max = list.block_max_weight(2).unwrap();
        assert!((block2_max - 29.9).abs() < 0.01);

        let query_weight = 2.0;
        assert!((list.max_contribution(query_weight) - 59.8).abs() < 0.1);

        let mut cursor = list.iterator();
        assert!((cursor.current_block_max_weight() - 12.7).abs() < 0.01);
        assert!((cursor.current_block_max_contribution(query_weight) - 25.4).abs() < 0.1);

        // Seeking into the second block updates the block-local bound.
        cursor.seek(128);
        assert!((cursor.current_block_max_weight() - 25.5).abs() < 0.01);
    }

    #[test]
    fn test_sparse_skip_list_serialization() {
        let mut skips = SparseSkipList::new();
        skips.push(0, 127, 0, 12.7);
        skips.push(128, 255, 100, 25.5);
        skips.push(256, 299, 200, 29.9);

        assert_eq!(skips.len(), 3);
        assert!((skips.global_max_weight() - 29.9).abs() < 0.01);

        let mut bytes = Vec::new();
        skips.write(&mut bytes).unwrap();

        let restored = SparseSkipList::read(&mut bytes.as_slice()).unwrap();

        assert_eq!(restored.len(), 3);
        assert!((restored.global_max_weight() - 29.9).abs() < 0.01);

        let entry0 = restored.get(0).unwrap();
        assert_eq!(entry0.first_doc, 0);
        assert_eq!(entry0.last_doc, 127);
        assert!((entry0.max_weight - 12.7).abs() < 0.01);

        let entry1 = restored.get(1).unwrap();
        assert_eq!(entry1.first_doc, 128);
        assert!((entry1.max_weight - 25.5).abs() < 0.01);
    }
}