1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
20use serde::{Deserialize, Serialize};
21use std::io::{self, Read, Write};
22
23use super::posting_common::{
24 RoundedBitWidth, pack_deltas_fixed, read_vint, unpack_deltas_fixed, write_vint,
25};
26use crate::DocId;
27
/// Integer width used for stored indices in a sparse vector configuration.
///
/// The explicit `#[repr(u8)]` discriminants (`U16 = 0`, `U32 = 1`) are the
/// values packed into the high nibble by `SparseVectorConfig::to_byte` and
/// decoded again by `IndexSize::from_u8`, so they must stay stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum IndexSize {
    /// 16-bit indices: 2 bytes each, values up to `u16::MAX`.
    U16 = 0,
    /// 32-bit indices: 4 bytes each, values up to `u32::MAX`. This is the default.
    #[default]
    U32 = 1,
}
38
39impl IndexSize {
40 pub fn bytes(&self) -> usize {
42 match self {
43 IndexSize::U16 => 2,
44 IndexSize::U32 => 4,
45 }
46 }
47
48 pub fn max_value(&self) -> u32 {
50 match self {
51 IndexSize::U16 => u16::MAX as u32,
52 IndexSize::U32 => u32::MAX,
53 }
54 }
55
56 fn from_u8(v: u8) -> Option<Self> {
57 match v {
58 0 => Some(IndexSize::U16),
59 1 => Some(IndexSize::U32),
60 _ => None,
61 }
62 }
63}
64
/// Storage format for posting weights.
///
/// The `#[repr(u8)]` discriminants are written as the first byte of a
/// serialized posting list and packed into the low nibble by
/// `SparseVectorConfig::to_byte`, so they must stay stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum WeightQuantization {
    /// Weights stored as little-endian `f32` (4 bytes each). This is the default.
    #[default]
    Float32 = 0,
    /// Weights stored as IEEE half-precision bit patterns (2 bytes each).
    Float16 = 1,
    /// Weights linearly quantized to one byte and decoded as `q * scale + min_val`.
    UInt8 = 2,
    /// Weights linearly quantized to four bits, two weights per byte.
    UInt4 = 3,
}
79
80impl WeightQuantization {
81 pub fn bytes_per_weight(&self) -> f32 {
83 match self {
84 WeightQuantization::Float32 => 4.0,
85 WeightQuantization::Float16 => 2.0,
86 WeightQuantization::UInt8 => 1.0,
87 WeightQuantization::UInt4 => 0.5,
88 }
89 }
90
91 fn from_u8(v: u8) -> Option<Self> {
92 match v {
93 0 => Some(WeightQuantization::Float32),
94 1 => Some(WeightQuantization::Float16),
95 2 => Some(WeightQuantization::UInt8),
96 3 => Some(WeightQuantization::UInt4),
97 _ => None,
98 }
99 }
100}
101
/// Strategy for assigning weights to query terms.
///
/// NOTE(review): the code that interprets these variants is not in this file;
/// the variant descriptions below are inferred from the names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum QueryWeighting {
    /// Default variant; presumably every query term gets weight 1.0.
    #[default]
    One,
    /// Presumably weights query terms by inverse document frequency.
    Idf,
}
111
/// Query-time settings for sparse vector search.
///
/// Every field carries a serde default, so any of them may be omitted when
/// deserializing.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseQueryConfig {
    /// Optional tokenizer identifier; skipped during serialization when `None`.
    /// NOTE(review): how this string is resolved is not visible in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokenizer: Option<String>,
    /// Query weighting scheme; defaults to `QueryWeighting::One` when absent.
    #[serde(default)]
    pub weighting: QueryWeighting,
    /// Defaults to `default_heap_factor()` (1.0) when absent.
    /// NOTE(review): presumably scales a top-k heap threshold — confirm against the searcher.
    #[serde(default = "default_heap_factor")]
    pub heap_factor: f32,
    /// Skipped during serialization when `None`.
    /// NOTE(review): presumably caps the number of query dimensions used — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_query_dims: Option<usize>,
}
136
/// Serde fallback for `SparseQueryConfig::heap_factor` when the field is absent.
fn default_heap_factor() -> f32 {
    const DEFAULT_HEAP_FACTOR: f32 = 1.0;
    DEFAULT_HEAP_FACTOR
}
140
141impl Default for SparseQueryConfig {
142 fn default() -> Self {
143 Self {
144 tokenizer: None,
145 weighting: QueryWeighting::One,
146 heap_factor: 1.0,
147 max_query_dims: None,
148 }
149 }
150}
151
/// Storage configuration for a sparse vector field.
///
/// `index_size` and `weight_quantization` are required when deserializing;
/// the remaining fields fall back to their serde defaults.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseVectorConfig {
    /// Integer width used for stored indices.
    pub index_size: IndexSize,
    /// Storage format for weights.
    pub weight_quantization: WeightQuantization,
    /// Defaults to 0.0 when absent.
    /// NOTE(review): presumably weights below this value are dropped at index time — confirm.
    #[serde(default)]
    pub weight_threshold: f32,
    /// Optional pruning fraction; `with_pruning` clamps it to `[0.0, 1.0]`.
    /// Skipped during serialization when `None`.
    /// NOTE(review): presumably forwarded to
    /// `BlockSparsePostingList::from_postings_with_pruning` — confirm against the caller.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub posting_list_pruning: Option<f32>,
    /// Optional query-time settings; skipped during serialization when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub query_config: Option<SparseQueryConfig>,
}
176
177impl Default for SparseVectorConfig {
178 fn default() -> Self {
179 Self {
180 index_size: IndexSize::U32,
181 weight_quantization: WeightQuantization::Float32,
182 weight_threshold: 0.0,
183 posting_list_pruning: None,
184 query_config: None,
185 }
186 }
187}
188
189impl SparseVectorConfig {
190 pub fn splade() -> Self {
192 Self {
193 index_size: IndexSize::U16,
194 weight_quantization: WeightQuantization::UInt8,
195 weight_threshold: 0.0,
196 posting_list_pruning: None,
197 query_config: None,
198 }
199 }
200
201 pub fn compact() -> Self {
203 Self {
204 index_size: IndexSize::U16,
205 weight_quantization: WeightQuantization::UInt4,
206 weight_threshold: 0.0,
207 posting_list_pruning: None,
208 query_config: None,
209 }
210 }
211
212 pub fn full_precision() -> Self {
214 Self {
215 index_size: IndexSize::U32,
216 weight_quantization: WeightQuantization::Float32,
217 weight_threshold: 0.0,
218 posting_list_pruning: None,
219 query_config: None,
220 }
221 }
222
223 pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
225 self.weight_threshold = threshold;
226 self
227 }
228
229 pub fn with_pruning(mut self, fraction: f32) -> Self {
232 self.posting_list_pruning = Some(fraction.clamp(0.0, 1.0));
233 self
234 }
235
236 pub fn bytes_per_entry(&self) -> f32 {
238 self.index_size.bytes() as f32 + self.weight_quantization.bytes_per_weight()
239 }
240
241 pub fn to_byte(&self) -> u8 {
243 ((self.index_size as u8) << 4) | (self.weight_quantization as u8)
244 }
245
246 pub fn from_byte(b: u8) -> Option<Self> {
249 let index_size = IndexSize::from_u8(b >> 4)?;
250 let weight_quantization = WeightQuantization::from_u8(b & 0x0F)?;
251 Some(Self {
252 index_size,
253 weight_quantization,
254 weight_threshold: 0.0,
255 posting_list_pruning: None,
256 query_config: None,
257 })
258 }
259
260 pub fn with_query_config(mut self, config: SparseQueryConfig) -> Self {
262 self.query_config = Some(config);
263 self
264 }
265}
266
/// One non-zero component of a sparse vector: a dimension id and its weight.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseEntry {
    /// Dimension identifier; `SparseVector` keeps its entries ordered by this.
    pub dim_id: u32,
    /// Weight stored for this dimension.
    pub weight: f32,
}
273
/// Sparse vector stored as a list of `(dim_id, weight)` entries.
///
/// `from_entries` sorts by `dim_id` and `push` debug-asserts strictly
/// increasing ids; `dot` relies on that ordering for its merge.
#[derive(Debug, Clone, Default)]
pub struct SparseVector {
    // Entries ordered by ascending `dim_id`.
    entries: Vec<SparseEntry>,
}
279
280impl SparseVector {
281 pub fn new() -> Self {
282 Self::default()
283 }
284
285 pub fn with_capacity(capacity: usize) -> Self {
286 Self {
287 entries: Vec::with_capacity(capacity),
288 }
289 }
290
291 pub fn from_entries(dim_ids: &[u32], weights: &[f32]) -> Self {
293 assert_eq!(dim_ids.len(), weights.len());
294 let mut entries: Vec<SparseEntry> = dim_ids
295 .iter()
296 .zip(weights.iter())
297 .map(|(&dim_id, &weight)| SparseEntry { dim_id, weight })
298 .collect();
299 entries.sort_by_key(|e| e.dim_id);
301 Self { entries }
302 }
303
304 pub fn push(&mut self, dim_id: u32, weight: f32) {
306 debug_assert!(
307 self.entries.is_empty() || self.entries.last().unwrap().dim_id < dim_id,
308 "Entries must be added in sorted order by dim_id"
309 );
310 self.entries.push(SparseEntry { dim_id, weight });
311 }
312
313 pub fn len(&self) -> usize {
315 self.entries.len()
316 }
317
318 pub fn is_empty(&self) -> bool {
319 self.entries.is_empty()
320 }
321
322 pub fn iter(&self) -> impl Iterator<Item = &SparseEntry> {
323 self.entries.iter()
324 }
325
326 pub fn dot(&self, other: &SparseVector) -> f32 {
328 let mut result = 0.0f32;
329 let mut i = 0;
330 let mut j = 0;
331
332 while i < self.entries.len() && j < other.entries.len() {
333 let a = &self.entries[i];
334 let b = &other.entries[j];
335
336 match a.dim_id.cmp(&b.dim_id) {
337 std::cmp::Ordering::Less => i += 1,
338 std::cmp::Ordering::Greater => j += 1,
339 std::cmp::Ordering::Equal => {
340 result += a.weight * b.weight;
341 i += 1;
342 j += 1;
343 }
344 }
345 }
346
347 result
348 }
349
350 pub fn norm_squared(&self) -> f32 {
352 self.entries.iter().map(|e| e.weight * e.weight).sum()
353 }
354
355 pub fn norm(&self) -> f32 {
357 self.norm_squared().sqrt()
358 }
359
360 pub fn top_k(&self, k: usize) -> Self {
365 if self.entries.len() <= k {
366 return self.clone();
367 }
368
369 let mut sorted: Vec<SparseEntry> = self.entries.clone();
371 sorted.sort_by(|a, b| {
372 b.weight
373 .abs()
374 .partial_cmp(&a.weight.abs())
375 .unwrap_or(std::cmp::Ordering::Equal)
376 });
377 sorted.truncate(k);
378 sorted.sort_by_key(|e| e.dim_id);
379
380 Self { entries: sorted }
381 }
382
383 pub fn filter_by_weight(&self, min_weight: f32) -> Self {
385 let entries: Vec<SparseEntry> = self
386 .entries
387 .iter()
388 .filter(|e| e.weight.abs() >= min_weight)
389 .cloned()
390 .collect();
391 Self { entries }
392 }
393}
394
/// A single posting: a document id paired with its weight.
#[derive(Debug, Clone, Copy)]
pub struct SparsePosting {
    /// Document identifier.
    pub doc_id: DocId,
    /// Weight associated with the document.
    pub weight: f32,
}
401
/// Maximum number of postings stored per block of a `BlockSparsePostingList`.
pub const SPARSE_BLOCK_SIZE: usize = 128;
404
/// Skip-list entry describing one block of a block-structured posting list.
///
/// Serialized as four little-endian 4-byte values (16 bytes total).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseSkipEntry {
    /// Document id of the first posting in the block.
    pub first_doc: DocId,
    /// Document id of the last posting in the block.
    pub last_doc: DocId,
    /// Byte offset of the block's start within the posting list's data buffer.
    pub offset: u32,
    /// Maximum weight recorded for the block when it was built.
    pub max_weight: f32,
}
419
420impl SparseSkipEntry {
421 pub fn new(first_doc: DocId, last_doc: DocId, offset: u32, max_weight: f32) -> Self {
422 Self {
423 first_doc,
424 last_doc,
425 offset,
426 max_weight,
427 }
428 }
429
430 #[inline]
435 pub fn block_max_contribution(&self, query_weight: f32) -> f32 {
436 query_weight * self.max_weight
437 }
438
439 pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
441 writer.write_u32::<LittleEndian>(self.first_doc)?;
442 writer.write_u32::<LittleEndian>(self.last_doc)?;
443 writer.write_u32::<LittleEndian>(self.offset)?;
444 writer.write_f32::<LittleEndian>(self.max_weight)?;
445 Ok(())
446 }
447
448 pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
450 let first_doc = reader.read_u32::<LittleEndian>()?;
451 let last_doc = reader.read_u32::<LittleEndian>()?;
452 let offset = reader.read_u32::<LittleEndian>()?;
453 let max_weight = reader.read_f32::<LittleEndian>()?;
454 Ok(Self {
455 first_doc,
456 last_doc,
457 offset,
458 max_weight,
459 })
460 }
461}
462
/// Ordered list of per-block skip entries plus a running maximum weight.
#[derive(Debug, Clone, Default)]
pub struct SparseSkipList {
    // One entry per block, in the order the blocks were pushed.
    entries: Vec<SparseSkipEntry>,
    // Running `f32::max` over every pushed `max_weight`; starts at 0.0, so it
    // never drops below zero.
    global_max_weight: f32,
}
470
471impl SparseSkipList {
472 pub fn new() -> Self {
473 Self::default()
474 }
475
476 pub fn push(&mut self, first_doc: DocId, last_doc: DocId, offset: u32, max_weight: f32) {
478 self.global_max_weight = self.global_max_weight.max(max_weight);
479 self.entries.push(SparseSkipEntry::new(
480 first_doc, last_doc, offset, max_weight,
481 ));
482 }
483
484 pub fn len(&self) -> usize {
486 self.entries.len()
487 }
488
489 pub fn is_empty(&self) -> bool {
490 self.entries.is_empty()
491 }
492
493 pub fn get(&self, index: usize) -> Option<&SparseSkipEntry> {
495 self.entries.get(index)
496 }
497
498 pub fn global_max_weight(&self) -> f32 {
500 self.global_max_weight
501 }
502
503 pub fn find_block(&self, target: DocId) -> Option<usize> {
505 self.entries.iter().position(|e| e.last_doc >= target)
506 }
507
508 pub fn iter(&self) -> impl Iterator<Item = &SparseSkipEntry> {
510 self.entries.iter()
511 }
512
513 pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
515 writer.write_u32::<LittleEndian>(self.entries.len() as u32)?;
516 writer.write_f32::<LittleEndian>(self.global_max_weight)?;
517 for entry in &self.entries {
518 entry.write(writer)?;
519 }
520 Ok(())
521 }
522
523 pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
525 let count = reader.read_u32::<LittleEndian>()? as usize;
526 let global_max_weight = reader.read_f32::<LittleEndian>()?;
527 let mut entries = Vec::with_capacity(count);
528 for _ in 0..count {
529 entries.push(SparseSkipEntry::read(reader)?);
530 }
531 Ok(Self {
532 entries,
533 global_max_weight,
534 })
535 }
536}
537
/// Posting list stored as one contiguous buffer: all vint-encoded doc-id
/// deltas first, followed by all weights in the chosen quantization.
#[derive(Debug, Clone)]
pub struct SparsePostingList {
    // Encoding used for the weight section of `data`.
    quantization: WeightQuantization,
    // Dequantization step for UInt8/UInt4 (`weight = q * scale + min_val`);
    // 1.0 for the float encodings.
    scale: f32,
    // Dequantization offset for UInt8/UInt4; 0.0 for the float encodings.
    min_val: f32,
    // Number of postings encoded in `data`.
    doc_count: u32,
    // Doc-id deltas followed by the encoded weights.
    data: Vec<u8>,
}
556
557impl SparsePostingList {
558 pub fn from_postings(
560 postings: &[(DocId, f32)],
561 quantization: WeightQuantization,
562 ) -> io::Result<Self> {
563 if postings.is_empty() {
564 return Ok(Self {
565 quantization,
566 scale: 1.0,
567 min_val: 0.0,
568 doc_count: 0,
569 data: Vec::new(),
570 });
571 }
572
573 let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
575 let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
576 let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
577
578 let (scale, adjusted_min) = match quantization {
579 WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
580 WeightQuantization::UInt8 => {
581 let range = max_val - min_val;
582 if range < f32::EPSILON {
583 (1.0, min_val)
584 } else {
585 (range / 255.0, min_val)
586 }
587 }
588 WeightQuantization::UInt4 => {
589 let range = max_val - min_val;
590 if range < f32::EPSILON {
591 (1.0, min_val)
592 } else {
593 (range / 15.0, min_val)
594 }
595 }
596 };
597
598 let mut data = Vec::new();
599
600 let mut prev_doc_id = 0u32;
602 for (doc_id, _) in postings {
603 let delta = doc_id - prev_doc_id;
604 write_vint(&mut data, delta as u64)?;
605 prev_doc_id = *doc_id;
606 }
607
608 match quantization {
610 WeightQuantization::Float32 => {
611 for (_, weight) in postings {
612 data.write_f32::<LittleEndian>(*weight)?;
613 }
614 }
615 WeightQuantization::Float16 => {
616 use half::slice::HalfFloatSliceExt;
618 let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
619 let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
620 f16_slice.convert_from_f32_slice(&weights);
621 for h in f16_slice {
622 data.write_u16::<LittleEndian>(h.to_bits())?;
623 }
624 }
625 WeightQuantization::UInt8 => {
626 for (_, weight) in postings {
627 let quantized = ((*weight - adjusted_min) / scale).round() as u8;
628 data.write_u8(quantized)?;
629 }
630 }
631 WeightQuantization::UInt4 => {
632 let mut i = 0;
634 while i < postings.len() {
635 let q1 = ((postings[i].1 - adjusted_min) / scale).round() as u8 & 0x0F;
636 let q2 = if i + 1 < postings.len() {
637 ((postings[i + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
638 } else {
639 0
640 };
641 data.write_u8((q2 << 4) | q1)?;
642 i += 2;
643 }
644 }
645 }
646
647 Ok(Self {
648 quantization,
649 scale,
650 min_val: adjusted_min,
651 doc_count: postings.len() as u32,
652 data,
653 })
654 }
655
656 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
658 writer.write_u8(self.quantization as u8)?;
659 writer.write_f32::<LittleEndian>(self.scale)?;
660 writer.write_f32::<LittleEndian>(self.min_val)?;
661 writer.write_u32::<LittleEndian>(self.doc_count)?;
662 writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
663 writer.write_all(&self.data)?;
664 Ok(())
665 }
666
667 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
669 let quant_byte = reader.read_u8()?;
670 let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
671 io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
672 })?;
673 let scale = reader.read_f32::<LittleEndian>()?;
674 let min_val = reader.read_f32::<LittleEndian>()?;
675 let doc_count = reader.read_u32::<LittleEndian>()?;
676 let data_len = reader.read_u32::<LittleEndian>()? as usize;
677 let mut data = vec![0u8; data_len];
678 reader.read_exact(&mut data)?;
679
680 Ok(Self {
681 quantization,
682 scale,
683 min_val,
684 doc_count,
685 data,
686 })
687 }
688
689 pub fn doc_count(&self) -> u32 {
691 self.doc_count
692 }
693
694 pub fn quantization(&self) -> WeightQuantization {
696 self.quantization
697 }
698
699 pub fn iterator(&self) -> SparsePostingIterator<'_> {
701 SparsePostingIterator::new(self)
702 }
703
704 pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
706 let mut result = Vec::with_capacity(self.doc_count as usize);
707 let mut iter = self.iterator();
708
709 while !iter.exhausted {
710 result.push((iter.doc_id, iter.weight));
711 iter.advance();
712 }
713
714 Ok(result)
715 }
716}
717
/// Forward cursor over a `SparsePostingList`.
pub struct SparsePostingIterator<'a> {
    // List being iterated.
    posting_list: &'a SparsePostingList,
    // Byte offset in `data` of the next doc-id delta to read.
    doc_id_offset: usize,
    // Byte offset in `data` where the weight section begins.
    weight_offset: usize,
    // Zero-based index of the current posting.
    index: usize,
    // Doc id of the current posting (running sum of the deltas).
    doc_id: DocId,
    // Decoded weight of the current posting.
    weight: f32,
    // Set once the cursor has moved past the last posting.
    exhausted: bool,
}
734
735impl<'a> SparsePostingIterator<'a> {
736 fn new(posting_list: &'a SparsePostingList) -> Self {
737 let mut iter = Self {
738 posting_list,
739 doc_id_offset: 0,
740 weight_offset: 0,
741 index: 0,
742 doc_id: 0,
743 weight: 0.0,
744 exhausted: posting_list.doc_count == 0,
745 };
746
747 if !iter.exhausted {
748 iter.weight_offset = iter.calculate_weight_offset();
750 iter.load_current();
751 }
752
753 iter
754 }
755
756 fn calculate_weight_offset(&self) -> usize {
757 let mut offset = 0;
759 let mut reader = &self.posting_list.data[..];
760
761 for _ in 0..self.posting_list.doc_count {
762 if read_vint(&mut reader).is_ok() {
763 offset = self.posting_list.data.len() - reader.len();
764 }
765 }
766
767 offset
768 }
769
770 fn load_current(&mut self) {
771 if self.index >= self.posting_list.doc_count as usize {
772 self.exhausted = true;
773 return;
774 }
775
776 let mut reader = &self.posting_list.data[self.doc_id_offset..];
778 if let Ok(delta) = read_vint(&mut reader) {
779 self.doc_id = self.doc_id.wrapping_add(delta as u32);
780 self.doc_id_offset = self.posting_list.data.len() - reader.len();
781 }
782
783 let weight_idx = self.index;
785 let pl = self.posting_list;
786
787 self.weight = match pl.quantization {
788 WeightQuantization::Float32 => {
789 let offset = self.weight_offset + weight_idx * 4;
790 if offset + 4 <= pl.data.len() {
791 let bytes = &pl.data[offset..offset + 4];
792 f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
793 } else {
794 0.0
795 }
796 }
797 WeightQuantization::Float16 => {
798 let offset = self.weight_offset + weight_idx * 2;
799 if offset + 2 <= pl.data.len() {
800 let bits = u16::from_le_bytes([pl.data[offset], pl.data[offset + 1]]);
801 half::f16::from_bits(bits).to_f32()
802 } else {
803 0.0
804 }
805 }
806 WeightQuantization::UInt8 => {
807 let offset = self.weight_offset + weight_idx;
808 if offset < pl.data.len() {
809 let quantized = pl.data[offset];
810 quantized as f32 * pl.scale + pl.min_val
811 } else {
812 0.0
813 }
814 }
815 WeightQuantization::UInt4 => {
816 let byte_offset = self.weight_offset + weight_idx / 2;
817 if byte_offset < pl.data.len() {
818 let byte = pl.data[byte_offset];
819 let quantized = if weight_idx.is_multiple_of(2) {
820 byte & 0x0F
821 } else {
822 (byte >> 4) & 0x0F
823 };
824 quantized as f32 * pl.scale + pl.min_val
825 } else {
826 0.0
827 }
828 }
829 };
830 }
831
832 pub fn doc(&self) -> DocId {
834 if self.exhausted {
835 super::TERMINATED
836 } else {
837 self.doc_id
838 }
839 }
840
841 pub fn weight(&self) -> f32 {
843 if self.exhausted { 0.0 } else { self.weight }
844 }
845
846 pub fn advance(&mut self) -> DocId {
848 if self.exhausted {
849 return super::TERMINATED;
850 }
851
852 self.index += 1;
853 if self.index >= self.posting_list.doc_count as usize {
854 self.exhausted = true;
855 return super::TERMINATED;
856 }
857
858 self.load_current();
859 self.doc_id
860 }
861
862 pub fn seek(&mut self, target: DocId) -> DocId {
864 while !self.exhausted && self.doc_id < target {
865 self.advance();
866 }
867 self.doc()
868 }
869}
870
/// Posting list split into blocks of at most `SPARSE_BLOCK_SIZE` postings,
/// with a skip list holding each block's doc-id range, byte offset and
/// maximum weight.
///
/// Each block in `data` is laid out as: posting count (`u16` LE), doc-id bit
/// width (`u8`), packed doc-id deltas, then the block's encoded weights.
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    // Encoding used for the weights of every block.
    quantization: WeightQuantization,
    // Dequantization step for UInt8/UInt4 (`weight = q * scale + min_val`);
    // 1.0 for the float encodings.
    scale: f32,
    // Dequantization offset for UInt8/UInt4; 0.0 for the float encodings.
    min_val: f32,
    // One entry per block, in block order.
    skip_list: SparseSkipList,
    // Concatenated encoded blocks.
    data: Vec<u8>,
    // Total number of postings across all blocks.
    doc_count: u32,
}
890
891impl BlockSparsePostingList {
892 pub fn from_postings(
894 postings: &[(DocId, f32)],
895 quantization: WeightQuantization,
896 ) -> io::Result<Self> {
897 Self::from_postings_with_pruning(postings, quantization, None)
898 }
899
900 pub fn from_postings_with_pruning(
906 postings: &[(DocId, f32)],
907 quantization: WeightQuantization,
908 pruning_fraction: Option<f32>,
909 ) -> io::Result<Self> {
910 if postings.is_empty() {
911 return Ok(Self {
912 quantization,
913 scale: 1.0,
914 min_val: 0.0,
915 skip_list: SparseSkipList::new(),
916 data: Vec::new(),
917 doc_count: 0,
918 });
919 }
920
921 let postings: std::borrow::Cow<'_, [(DocId, f32)]> = if let Some(fraction) =
924 pruning_fraction
925 {
926 let max = ((postings.len() as f32 * fraction).ceil() as usize).max(1);
927 if postings.len() > max {
928 let mut sorted: Vec<(DocId, f32)> = postings.to_vec();
929 sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
931 sorted.truncate(max);
933 sorted.sort_by_key(|(doc_id, _)| *doc_id);
935 std::borrow::Cow::Owned(sorted)
936 } else {
937 std::borrow::Cow::Borrowed(postings)
938 }
939 } else {
940 std::borrow::Cow::Borrowed(postings)
941 };
942
943 let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
945 let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
946 let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
947
948 let (scale, adjusted_min) = match quantization {
949 WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
950 WeightQuantization::UInt8 => {
951 let range = max_val - min_val;
952 if range < f32::EPSILON {
953 (1.0, min_val)
954 } else {
955 (range / 255.0, min_val)
956 }
957 }
958 WeightQuantization::UInt4 => {
959 let range = max_val - min_val;
960 if range < f32::EPSILON {
961 (1.0, min_val)
962 } else {
963 (range / 15.0, min_val)
964 }
965 }
966 };
967
968 let mut skip_list = SparseSkipList::new();
969 let mut data = Vec::new();
970
971 let mut i = 0;
972 while i < postings.len() {
973 let block_end = (i + SPARSE_BLOCK_SIZE).min(postings.len());
974 let block = &postings[i..block_end];
975
976 let first_doc_id = block.first().unwrap().0;
977 let last_doc_id = block.last().unwrap().0;
978
979 let block_max_weight = block
981 .iter()
982 .map(|(_, w)| *w)
983 .fold(f32::NEG_INFINITY, f32::max);
984
985 let block_doc_ids: Vec<DocId> = block.iter().map(|(d, _)| *d).collect();
987 let (doc_bit_width, packed_doc_ids) = pack_deltas_fixed(&block_doc_ids);
988
989 let block_start = data.len() as u32;
991 skip_list.push(first_doc_id, last_doc_id, block_start, block_max_weight);
992
993 data.write_u16::<LittleEndian>(block.len() as u16)?;
994 data.write_u8(doc_bit_width as u8)?;
995 data.extend_from_slice(&packed_doc_ids);
996
997 match quantization {
999 WeightQuantization::Float32 => {
1000 for (_, weight) in block {
1001 data.write_f32::<LittleEndian>(*weight)?;
1002 }
1003 }
1004 WeightQuantization::Float16 => {
1005 use half::slice::HalfFloatSliceExt;
1007 let weights: Vec<f32> = block.iter().map(|(_, w)| *w).collect();
1008 let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
1009 f16_slice.convert_from_f32_slice(&weights);
1010 for h in f16_slice {
1011 data.write_u16::<LittleEndian>(h.to_bits())?;
1012 }
1013 }
1014 WeightQuantization::UInt8 => {
1015 for (_, weight) in block {
1016 let quantized = ((*weight - adjusted_min) / scale).round() as u8;
1017 data.write_u8(quantized)?;
1018 }
1019 }
1020 WeightQuantization::UInt4 => {
1021 let mut j = 0;
1022 while j < block.len() {
1023 let q1 = ((block[j].1 - adjusted_min) / scale).round() as u8 & 0x0F;
1024 let q2 = if j + 1 < block.len() {
1025 ((block[j + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
1026 } else {
1027 0
1028 };
1029 data.write_u8((q2 << 4) | q1)?;
1030 j += 2;
1031 }
1032 }
1033 }
1034
1035 i = block_end;
1036 }
1037
1038 Ok(Self {
1039 quantization,
1040 scale,
1041 min_val: adjusted_min,
1042 skip_list,
1043 data,
1044 doc_count: postings.len() as u32,
1045 })
1046 }
1047
1048 pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
1050 writer.write_u8(self.quantization as u8)?;
1051 writer.write_f32::<LittleEndian>(self.scale)?;
1052 writer.write_f32::<LittleEndian>(self.min_val)?;
1053 writer.write_u32::<LittleEndian>(self.doc_count)?;
1054
1055 self.skip_list.write(writer)?;
1057
1058 writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
1060 writer.write_all(&self.data)?;
1061
1062 Ok(())
1063 }
1064
1065 pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
1067 let quant_byte = reader.read_u8()?;
1068 let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
1069 io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
1070 })?;
1071 let scale = reader.read_f32::<LittleEndian>()?;
1072 let min_val = reader.read_f32::<LittleEndian>()?;
1073 let doc_count = reader.read_u32::<LittleEndian>()?;
1074
1075 let skip_list = SparseSkipList::read(reader)?;
1077
1078 let data_len = reader.read_u32::<LittleEndian>()? as usize;
1079 let mut data = vec![0u8; data_len];
1080 reader.read_exact(&mut data)?;
1081
1082 Ok(Self {
1083 quantization,
1084 scale,
1085 min_val,
1086 skip_list,
1087 data,
1088 doc_count,
1089 })
1090 }
1091
1092 pub fn doc_count(&self) -> u32 {
1094 self.doc_count
1095 }
1096
1097 pub fn num_blocks(&self) -> usize {
1099 self.skip_list.len()
1100 }
1101
1102 pub fn quantization(&self) -> WeightQuantization {
1104 self.quantization
1105 }
1106
1107 pub fn global_max_weight(&self) -> f32 {
1109 self.skip_list.global_max_weight()
1110 }
1111
1112 pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
1114 self.skip_list.get(block_idx).map(|e| e.max_weight)
1115 }
1116
1117 #[inline]
1122 pub fn max_contribution(&self, query_weight: f32) -> f32 {
1123 query_weight * self.skip_list.global_max_weight()
1124 }
1125
1126 pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
1128 BlockSparsePostingIterator::new(self)
1129 }
1130
1131 pub fn size_bytes(&self) -> usize {
1133 13 + 8 + self.skip_list.len() * 16 + self.data.len()
1137 }
1138
1139 pub fn concatenate(
1141 sources: &[(BlockSparsePostingList, u32)],
1142 target_quantization: WeightQuantization,
1143 ) -> io::Result<Self> {
1144 let mut all_postings: Vec<(DocId, f32)> = Vec::new();
1146
1147 for (source, doc_offset) in sources {
1148 let decoded = source.decode_all()?;
1149 for (doc_id, weight) in decoded {
1150 all_postings.push((doc_id + doc_offset, weight));
1151 }
1152 }
1153
1154 Self::from_postings(&all_postings, target_quantization)
1156 }
1157
1158 pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
1160 let mut result = Vec::with_capacity(self.doc_count as usize);
1161 let mut iter = self.iterator();
1162
1163 while iter.doc() != super::TERMINATED {
1164 result.push((iter.doc(), iter.weight()));
1165 iter.advance();
1166 }
1167
1168 Ok(result)
1169 }
1170}
1171
/// Forward cursor over a `BlockSparsePostingList` that decodes one block at a
/// time.
pub struct BlockSparsePostingIterator<'a> {
    // List being iterated.
    posting_list: &'a BlockSparsePostingList,
    // Index of the block currently decoded into `block_postings`.
    current_block: usize,
    // Decoded `(doc_id, weight)` pairs of the current block.
    block_postings: Vec<(DocId, f32)>,
    // Cursor position within `block_postings`.
    position_in_block: usize,
    // Set once the cursor has moved past the last block.
    exhausted: bool,
}
1180
1181impl<'a> BlockSparsePostingIterator<'a> {
1182 fn new(posting_list: &'a BlockSparsePostingList) -> Self {
1183 let exhausted = posting_list.skip_list.is_empty();
1184 let mut iter = Self {
1185 posting_list,
1186 current_block: 0,
1187 block_postings: Vec::new(),
1188 position_in_block: 0,
1189 exhausted,
1190 };
1191
1192 if !iter.exhausted {
1193 iter.load_block(0);
1194 }
1195
1196 iter
1197 }
1198
1199 fn load_block(&mut self, block_idx: usize) {
1200 let entry = match self.posting_list.skip_list.get(block_idx) {
1201 Some(e) => e,
1202 None => {
1203 self.exhausted = true;
1204 return;
1205 }
1206 };
1207
1208 self.current_block = block_idx;
1209 self.position_in_block = 0;
1210 self.block_postings.clear();
1211
1212 let offset = entry.offset as usize;
1213 let first_doc_id = entry.first_doc;
1214 let data = &self.posting_list.data[offset..];
1215
1216 if data.len() < 3 {
1218 self.exhausted = true;
1219 return;
1220 }
1221 let count = u16::from_le_bytes([data[0], data[1]]) as usize;
1222 let doc_bit_width = RoundedBitWidth::from_u8(data[2]).unwrap_or(RoundedBitWidth::Zero);
1223
1224 let doc_bytes = doc_bit_width.bytes_per_value() * count.saturating_sub(1);
1226 let doc_data = &data[3..3 + doc_bytes];
1227 let mut doc_ids = vec![0u32; count];
1228 unpack_deltas_fixed(doc_data, doc_bit_width, first_doc_id, count, &mut doc_ids);
1229
1230 let weight_offset = 3 + doc_bytes;
1232 let weight_data = &data[weight_offset..];
1233 let pl = self.posting_list;
1234
1235 let weights: Vec<f32> = match pl.quantization {
1237 WeightQuantization::Float32 => {
1238 let mut weights = Vec::with_capacity(count);
1239 let mut reader = weight_data;
1240 for _ in 0..count {
1241 if reader.len() >= 4 {
1242 weights.push((&mut reader).read_f32::<LittleEndian>().unwrap_or(0.0));
1243 } else {
1244 weights.push(0.0);
1245 }
1246 }
1247 weights
1248 }
1249 WeightQuantization::Float16 => {
1250 use half::slice::HalfFloatSliceExt;
1252 let mut f16_slice: Vec<half::f16> = Vec::with_capacity(count);
1253 for i in 0..count {
1254 let offset = i * 2;
1255 if offset + 2 <= weight_data.len() {
1256 let bits =
1257 u16::from_le_bytes([weight_data[offset], weight_data[offset + 1]]);
1258 f16_slice.push(half::f16::from_bits(bits));
1259 } else {
1260 f16_slice.push(half::f16::ZERO);
1261 }
1262 }
1263 let mut weights = vec![0.0f32; count];
1264 f16_slice.convert_to_f32_slice(&mut weights);
1265 weights
1266 }
1267 WeightQuantization::UInt8 => {
1268 let mut weights = Vec::with_capacity(count);
1269 for i in 0..count {
1270 if i < weight_data.len() {
1271 weights.push(weight_data[i] as f32 * pl.scale + pl.min_val);
1272 } else {
1273 weights.push(0.0);
1274 }
1275 }
1276 weights
1277 }
1278 WeightQuantization::UInt4 => {
1279 let mut weights = Vec::with_capacity(count);
1280 for i in 0..count {
1281 let byte_idx = i / 2;
1282 if byte_idx < weight_data.len() {
1283 let byte = weight_data[byte_idx];
1284 let quantized = if i % 2 == 0 {
1285 byte & 0x0F
1286 } else {
1287 (byte >> 4) & 0x0F
1288 };
1289 weights.push(quantized as f32 * pl.scale + pl.min_val);
1290 } else {
1291 weights.push(0.0);
1292 }
1293 }
1294 weights
1295 }
1296 };
1297
1298 for (doc_id, weight) in doc_ids.into_iter().zip(weights.into_iter()) {
1300 self.block_postings.push((doc_id, weight));
1301 }
1302 }
1303
1304 #[inline]
1306 pub fn is_exhausted(&self) -> bool {
1307 self.exhausted
1308 }
1309
1310 pub fn doc(&self) -> DocId {
1312 if self.exhausted {
1313 super::TERMINATED
1314 } else if self.position_in_block < self.block_postings.len() {
1315 self.block_postings[self.position_in_block].0
1316 } else {
1317 super::TERMINATED
1318 }
1319 }
1320
1321 pub fn weight(&self) -> f32 {
1323 if self.exhausted || self.position_in_block >= self.block_postings.len() {
1324 0.0
1325 } else {
1326 self.block_postings[self.position_in_block].1
1327 }
1328 }
1329
1330 #[inline]
1335 pub fn current_block_max_weight(&self) -> f32 {
1336 if self.exhausted {
1337 0.0
1338 } else {
1339 self.posting_list
1340 .skip_list
1341 .get(self.current_block)
1342 .map(|e| e.max_weight)
1343 .unwrap_or(0.0)
1344 }
1345 }
1346
1347 #[inline]
1351 pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
1352 query_weight * self.current_block_max_weight()
1353 }
1354
1355 pub fn advance(&mut self) -> DocId {
1357 if self.exhausted {
1358 return super::TERMINATED;
1359 }
1360
1361 self.position_in_block += 1;
1362 if self.position_in_block >= self.block_postings.len() {
1363 self.load_block(self.current_block + 1);
1364 }
1365
1366 self.doc()
1367 }
1368
1369 pub fn seek(&mut self, target: DocId) -> DocId {
1371 if self.exhausted {
1372 return super::TERMINATED;
1373 }
1374
1375 if let Some(block_idx) = self.posting_list.skip_list.find_block(target) {
1377 if block_idx != self.current_block {
1378 self.load_block(block_idx);
1379 }
1380
1381 while self.position_in_block < self.block_postings.len() {
1383 if self.block_postings[self.position_in_block].0 >= target {
1384 return self.doc();
1385 }
1386 self.position_in_block += 1;
1387 }
1388
1389 self.load_block(self.current_block + 1);
1391 self.seek(target)
1392 } else {
1393 self.exhausted = true;
1394 super::TERMINATED
1395 }
1396 }
1397}
1398
#[cfg(test)]
mod tests {
    use super::super::TERMINATED;
    use super::*;

    /// Dot product only accumulates over dimensions present in both vectors.
    #[test]
    fn test_sparse_vector_dot_product() {
        let v1 = SparseVector::from_entries(&[0, 2, 5], &[1.0, 2.0, 3.0]);
        let v2 = SparseVector::from_entries(&[1, 2, 5], &[1.0, 4.0, 2.0]);

        // Shared dimensions are 2 and 5: 2.0 * 4.0 + 3.0 * 2.0 = 14.0.
        assert!((v1.dot(&v2) - 14.0).abs() < 1e-6);
    }

    /// Walks a small Float32 posting list from the first doc to termination.
    #[test]
    fn test_sparse_posting_list_float32() {
        let postings = vec![(0, 1.5), (5, 2.3), (10, 0.8), (100, 3.15)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 4);

        let mut iter = pl.iterator();
        assert_eq!(iter.doc(), 0);
        assert!((iter.weight() - 1.5).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!((iter.weight() - 2.3).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 10);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!((iter.weight() - 3.15).abs() < 1e-6);

        // Advancing past the last posting must report termination.
        iter.advance();
        assert_eq!(iter.doc(), TERMINATED);
    }

    /// UInt8 round-trip: only the relative order of the weights is checked,
    /// not their exact values.
    #[test]
    fn test_sparse_posting_list_uint8() {
        let postings = vec![(0, 0.0), (5, 0.5), (10, 1.0)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let decoded = pl.decode_all().unwrap();
        assert_eq!(decoded.len(), 3);

        // Input weights are strictly increasing; the decoded ones must be too.
        assert!(decoded[0].1 < decoded[1].1);
        assert!(decoded[1].1 < decoded[2].1);
    }

    /// Iterating a multi-block posting list yields every posting in order.
    #[test]
    fn test_block_sparse_posting_list() {
        let postings: Vec<(DocId, f32)> = (0..300).map(|i| (i * 2, (i as f32) * 0.1)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 300);
        // 300 postings are expected to span more than one block.
        assert!(pl.num_blocks() >= 2);

        let mut iter = pl.iterator();
        for (expected_doc, expected_weight) in &postings {
            assert_eq!(iter.doc(), *expected_doc);
            assert!((iter.weight() - expected_weight).abs() < 1e-6);
            iter.advance();
        }
        assert_eq!(iter.doc(), TERMINATED);
    }

    /// `seek` lands on the first doc id that is `>=` the requested target.
    #[test]
    fn test_block_sparse_seek() {
        // Doc ids are the multiples of 3 from 0 to 1497.
        let postings: Vec<(DocId, f32)> = (0..500).map(|i| (i * 3, i as f32)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = pl.iterator();

        // Exact hit.
        assert_eq!(iter.seek(300), 300);

        // Target between two postings: the next larger doc id is returned.
        assert_eq!(iter.seek(301), 303);

        // Target beyond the last doc id (1497).
        assert_eq!(iter.seek(2000), TERMINATED);
    }

    /// Serialize then deserialize under several quantizations and compare
    /// the two lists posting by posting.
    #[test]
    fn test_serialization_roundtrip() {
        let postings: Vec<(DocId, f32)> = vec![(0, 1.0), (10, 2.0), (100, 3.0)];

        for quant in [
            WeightQuantization::Float32,
            WeightQuantization::Float16,
            WeightQuantization::UInt8,
        ] {
            let pl = BlockSparsePostingList::from_postings(&postings, quant).unwrap();

            let mut buffer = Vec::new();
            pl.serialize(&mut buffer).unwrap();

            let pl2 = BlockSparsePostingList::deserialize(&mut &buffer[..]).unwrap();

            assert_eq!(pl.doc_count(), pl2.doc_count());
            assert_eq!(pl.quantization(), pl2.quantization());

            let mut iter1 = pl.iterator();
            let mut iter2 = pl2.iterator();

            while iter1.doc() != TERMINATED {
                assert_eq!(iter1.doc(), iter2.doc());
                // Loose tolerance so the same check works for every
                // quantization in the list above.
                assert!((iter1.weight() - iter2.weight()).abs() < 0.1);
                iter1.advance();
                iter2.advance();
            }

            // The restored list must not carry extra trailing postings.
            assert_eq!(iter2.doc(), TERMINATED);
        }
    }

    /// Concatenation shifts each input list's doc ids by its paired offset.
    #[test]
    fn test_concatenate() {
        let postings1: Vec<(DocId, f32)> = vec![(0, 1.0), (5, 2.0)];
        let postings2: Vec<(DocId, f32)> = vec![(0, 3.0), (10, 4.0)];

        let pl1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        let pl2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        let merged = BlockSparsePostingList::concatenate(
            &[(pl1, 0), (pl2, 100)],
            WeightQuantization::Float32,
        )
        .unwrap();

        assert_eq!(merged.doc_count(), 4);

        let decoded = merged.decode_all().unwrap();
        // First list keeps its ids (offset 0).
        assert_eq!(decoded[0], (0, 1.0));
        assert_eq!(decoded[1], (5, 2.0));
        // Second list is shifted by 100: 0 -> 100, 10 -> 110.
        assert_eq!(decoded[2], (100, 3.0));
        assert_eq!(decoded[3], (110, 4.0));
    }

    /// Checks the preset configurations and their single-byte encoding.
    #[test]
    fn test_sparse_vector_config() {
        // Default: U32 indices with Float32 weights, 8 bytes per entry.
        let default = SparseVectorConfig::default();
        assert_eq!(default.index_size, IndexSize::U32);
        assert_eq!(default.weight_quantization, WeightQuantization::Float32);
        assert_eq!(default.bytes_per_entry(), 8.0);

        // SPLADE preset: U16 indices with UInt8 weights, 3 bytes per entry.
        let splade = SparseVectorConfig::splade();
        assert_eq!(splade.index_size, IndexSize::U16);
        assert_eq!(splade.weight_quantization, WeightQuantization::UInt8);
        assert_eq!(splade.bytes_per_entry(), 3.0);

        // Compact preset: U16 indices with UInt4 weights, 2.5 bytes per entry.
        let compact = SparseVectorConfig::compact();
        assert_eq!(compact.index_size, IndexSize::U16);
        assert_eq!(compact.weight_quantization, WeightQuantization::UInt4);
        assert_eq!(compact.bytes_per_entry(), 2.5);

        // Encoding to a byte and back must reproduce the configuration.
        let byte = splade.to_byte();
        let restored = SparseVectorConfig::from_byte(byte).unwrap();
        assert_eq!(restored, splade);
    }

    /// Byte widths and maximum representable index for each `IndexSize`.
    #[test]
    fn test_index_size() {
        assert_eq!(IndexSize::U16.bytes(), 2);
        assert_eq!(IndexSize::U32.bytes(), 4);
        assert_eq!(IndexSize::U16.max_value(), 65535);
        assert_eq!(IndexSize::U32.max_value(), u32::MAX);
    }

    /// Per-block and global maximum weights, as used for score upper bounds.
    #[test]
    fn test_block_max_weight() {
        // Doc i carries weight i * 0.1, so a block's max weight is that of
        // its last doc.
        let postings: Vec<(DocId, f32)> =
            (0..300).map(|i| (i as DocId, (i as f32) * 0.1)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        // Global maximum is the weight of doc 299.
        assert!((pl.global_max_weight() - 29.9).abs() < 0.01);

        assert!(pl.num_blocks() >= 3);

        // The expected values below correspond to blocks ending at docs 127,
        // 255 and 299 respectively.
        let block0_max = pl.block_max_weight(0).unwrap();
        assert!((block0_max - 12.7).abs() < 0.01);

        let block1_max = pl.block_max_weight(1).unwrap();
        assert!((block1_max - 25.5).abs() < 0.01);

        let block2_max = pl.block_max_weight(2).unwrap();
        assert!((block2_max - 29.9).abs() < 0.01);

        // Max contribution = query weight * global max weight = 2.0 * 29.9.
        let query_weight = 2.0;
        assert!((pl.max_contribution(query_weight) - 59.8).abs() < 0.1);

        // A fresh iterator starts on block 0.
        let mut iter = pl.iterator();
        assert!((iter.current_block_max_weight() - 12.7).abs() < 0.01);
        assert!((iter.current_block_max_contribution(query_weight) - 25.4).abs() < 0.1);

        // Seeking to doc 128 moves the iterator onto the second block.
        iter.seek(128);
        assert!((iter.current_block_max_weight() - 25.5).abs() < 0.01);
    }

    /// A skip list written to a buffer and read back keeps its entries.
    #[test]
    fn test_sparse_skip_list_serialization() {
        let mut skip_list = SparseSkipList::new();
        skip_list.push(0, 127, 0, 12.7);
        skip_list.push(128, 255, 100, 25.5);
        skip_list.push(256, 299, 200, 29.9);

        assert_eq!(skip_list.len(), 3);
        assert!((skip_list.global_max_weight() - 29.9).abs() < 0.01);

        let mut buffer = Vec::new();
        skip_list.write(&mut buffer).unwrap();

        let restored = SparseSkipList::read(&mut buffer.as_slice()).unwrap();

        assert_eq!(restored.len(), 3);
        assert!((restored.global_max_weight() - 29.9).abs() < 0.01);

        let e0 = restored.get(0).unwrap();
        assert_eq!(e0.first_doc, 0);
        assert_eq!(e0.last_doc, 127);
        assert!((e0.max_weight - 12.7).abs() < 0.01);

        let e1 = restored.get(1).unwrap();
        assert_eq!(e1.first_doc, 128);
        assert!((e1.max_weight - 25.5).abs() < 0.01);
    }
}