mod block;
mod config;
mod partitioner;

pub use block::{BlockSparsePostingIterator, BlockSparsePostingList, SparseBlock};
pub use config::{
    IndexSize, QueryWeighting, SparseEntry, SparseQueryConfig, SparseVector, SparseVectorConfig,
    WeightQuantization,
};
pub use partitioner::optimal_partition;

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Read, Write};

use super::posting_common::{read_vint, write_vint};
use crate::DocId;

/// A single posting in a sparse vector index: a document id paired with the
/// weight of the active dimension in that document.
#[derive(Debug, Clone, Copy)]
pub struct SparsePosting {
    pub doc_id: DocId,
    pub weight: f32,
}

/// Number of postings stored per block in the block-sparse format.
pub const SPARSE_BLOCK_SIZE: usize = 128;

/// Skip-list entry describing one block of a sparse posting list: the doc id
/// range it covers, where its encoded bytes live, and the maximum weight in
/// the block (used for block-max pruning).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseSkipEntry {
    /// First document id in the block.
    pub first_doc: DocId,
    /// Last document id in the block.
    pub last_doc: DocId,
    /// Byte offset of the encoded block within the data section.
    pub offset: u64,
    /// Length of the encoded block in bytes.
    pub length: u32,
    /// Maximum weight of any posting in the block.
    pub max_weight: f32,
}

impl SparseSkipEntry {
    /// Serialized size of a skip entry in bytes (4 + 4 + 8 + 4 + 4).
    pub const SIZE: usize = 24;

    pub fn new(
        first_doc: DocId,
        last_doc: DocId,
        offset: u64,
        length: u32,
        max_weight: f32,
    ) -> Self {
        Self {
            first_doc,
            last_doc,
            offset,
            length,
            max_weight,
        }
    }

    /// Upper bound on this block's contribution to a query score for a
    /// dimension with the given query weight.
    #[inline]
    pub fn block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.max_weight
    }

    /// Writes the entry in fixed-width little-endian form (`SIZE` bytes).
    pub fn write<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc)?;
        writer.write_u32::<LittleEndian>(self.last_doc)?;
        writer.write_u64::<LittleEndian>(self.offset)?;
        writer.write_u32::<LittleEndian>(self.length)?;
        writer.write_f32::<LittleEndian>(self.max_weight)?;
        Ok(())
    }

    /// Reads an entry previously written with `SparseSkipEntry::write`.
    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc = reader.read_u32::<LittleEndian>()?;
        let last_doc = reader.read_u32::<LittleEndian>()?;
        let offset = reader.read_u64::<LittleEndian>()?;
        let length = reader.read_u32::<LittleEndian>()?;
        let max_weight = reader.read_f32::<LittleEndian>()?;
        Ok(Self {
            first_doc,
            last_doc,
            offset,
            length,
            max_weight,
        })
    }
}

/// Skip list over the blocks of a sparse posting list. Each entry records a
/// block's doc id range, byte extent, and maximum weight; the list also
/// tracks the global maximum weight across all blocks.
#[derive(Debug, Clone, Default)]
pub struct SparseSkipList {
    entries: Vec<SparseSkipEntry>,
    global_max_weight: f32,
}

impl SparseSkipList {
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a skip entry for the next block and updates the global max.
    pub fn push(
        &mut self,
        first_doc: DocId,
        last_doc: DocId,
        offset: u64,
        length: u32,
        max_weight: f32,
    ) {
        self.global_max_weight = self.global_max_weight.max(max_weight);
        self.entries.push(SparseSkipEntry::new(
            first_doc, last_doc, offset, length, max_weight,
        ));
    }

    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    pub fn get(&self, index: usize) -> Option<&SparseSkipEntry> {
        self.entries.get(index)
    }

    pub fn global_max_weight(&self) -> f32 {
        self.global_max_weight
    }

    /// Returns the index of the first block that may contain `target`, i.e.
    /// the first entry whose `last_doc` is >= `target`, or `None` if every
    /// block ends before the target.
    pub fn find_block(&self, target: DocId) -> Option<usize> {
        if self.entries.is_empty() {
            return None;
        }
        let idx = self.entries.partition_point(|e| e.last_doc < target);
        if idx < self.entries.len() {
            Some(idx)
        } else {
            None
        }
    }

    pub fn iter(&self) -> impl Iterator<Item = &SparseSkipEntry> {
        self.entries.iter()
    }

    /// Writes the entry count, the global max weight, and all entries.
    pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.entries.len() as u32)?;
        writer.write_f32::<LittleEndian>(self.global_max_weight)?;
        for entry in &self.entries {
            entry.write(writer)?;
        }
        Ok(())
    }

    /// Reads a skip list previously written with `SparseSkipList::write`.
    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        let count = reader.read_u32::<LittleEndian>()? as usize;
        let global_max_weight = reader.read_f32::<LittleEndian>()?;
        let mut entries = Vec::with_capacity(count);
        for _ in 0..count {
            entries.push(SparseSkipEntry::read(reader)?);
        }
        Ok(Self {
            entries,
            global_max_weight,
        })
    }
}

/// A flat (non-blocked) sparse posting list. Doc ids are delta-encoded as
/// variable-length integers, followed by the weights encoded according to
/// the chosen `WeightQuantization`.
#[derive(Debug, Clone)]
pub struct SparsePostingList {
    /// Weight quantization mode used for the encoded weights.
    quantization: WeightQuantization,
    /// Scale factor used by the integer quantization modes.
    scale: f32,
    /// Minimum weight value (the zero point for integer quantization).
    min_val: f32,
    /// Number of postings in the list.
    doc_count: u32,
    /// Encoded doc id deltas followed by the encoded weights.
    data: Vec<u8>,
}

impl SparsePostingList {
    /// Builds a posting list from `(doc_id, weight)` pairs, which must be
    /// sorted by ascending doc id.
    pub fn from_postings(
        postings: &[(DocId, f32)],
        quantization: WeightQuantization,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                quantization,
                scale: 1.0,
                min_val: 0.0,
                doc_count: 0,
                data: Vec::new(),
            });
        }

        let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
        let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        // Integer quantization maps [min_val, max_val] onto the available
        // integer range; float modes store weights directly.
        let (scale, adjusted_min) = match quantization {
            WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
            WeightQuantization::UInt8 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 255.0, min_val)
                }
            }
            WeightQuantization::UInt4 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 15.0, min_val)
                }
            }
        };

        let mut data = Vec::new();

        // Delta-encode doc ids as variable-length integers.
        let mut prev_doc_id = 0u32;
        for (doc_id, _) in postings {
            let delta = doc_id - prev_doc_id;
            write_vint(&mut data, delta as u64)?;
            prev_doc_id = *doc_id;
        }

        // Encode the weights after the doc id section.
        match quantization {
            WeightQuantization::Float32 => {
                for (_, weight) in postings {
                    data.write_f32::<LittleEndian>(*weight)?;
                }
            }
            WeightQuantization::Float16 => {
                use half::slice::HalfFloatSliceExt;
                let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
                let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
                f16_slice.convert_from_f32_slice(&weights);
                for h in f16_slice {
                    data.write_u16::<LittleEndian>(h.to_bits())?;
                }
            }
            WeightQuantization::UInt8 => {
                for (_, weight) in postings {
                    let quantized = ((*weight - adjusted_min) / scale).round() as u8;
                    data.write_u8(quantized)?;
                }
            }
            WeightQuantization::UInt4 => {
                // Pack two 4-bit values per byte: even index in the low
                // nibble, odd index in the high nibble.
                let mut i = 0;
                while i < postings.len() {
                    let q1 = ((postings[i].1 - adjusted_min) / scale).round() as u8 & 0x0F;
                    let q2 = if i + 1 < postings.len() {
                        ((postings[i + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
                    } else {
                        0
                    };
                    data.write_u8((q2 << 4) | q1)?;
                    i += 2;
                }
            }
        }

        Ok(Self {
            quantization,
            scale,
            min_val: adjusted_min,
            doc_count: postings.len() as u32,
            data,
        })
    }

    /// Serializes the header (quantization byte, scale, min value, doc
    /// count, data length) followed by the encoded data.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u8(self.quantization as u8)?;
        writer.write_f32::<LittleEndian>(self.scale)?;
        writer.write_f32::<LittleEndian>(self.min_val)?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
        writer.write_all(&self.data)?;
        Ok(())
    }

    /// Reads a posting list previously written with `serialize`.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let quant_byte = reader.read_u8()?;
        let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
        })?;
        let scale = reader.read_f32::<LittleEndian>()?;
        let min_val = reader.read_f32::<LittleEndian>()?;
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let data_len = reader.read_u32::<LittleEndian>()? as usize;
        let mut data = vec![0u8; data_len];
        reader.read_exact(&mut data)?;

        Ok(Self {
            quantization,
            scale,
            min_val,
            doc_count,
            data,
        })
    }

    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    pub fn quantization(&self) -> WeightQuantization {
        self.quantization
    }

    /// Returns an iterator positioned at the first posting.
    pub fn iterator(&self) -> SparsePostingIterator<'_> {
        SparsePostingIterator::new(self)
    }

    /// Decodes all postings into `(doc_id, weight)` pairs.
    pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
        let mut result = Vec::with_capacity(self.doc_count as usize);
        let mut iter = self.iterator();

        while !iter.exhausted {
            result.push((iter.doc_id, iter.weight));
            iter.advance();
        }

        Ok(result)
    }
}

/// Forward-only iterator over a `SparsePostingList`.
pub struct SparsePostingIterator<'a> {
    posting_list: &'a SparsePostingList,
    /// Byte offset of the next doc id delta to decode.
    doc_id_offset: usize,
    /// Byte offset where the weight section starts.
    weight_offset: usize,
    /// Index of the current posting.
    index: usize,
    /// Current document id.
    doc_id: DocId,
    /// Current (dequantized) weight.
    weight: f32,
    /// Set once the iterator has moved past the last posting.
    exhausted: bool,
}

impl<'a> SparsePostingIterator<'a> {
    fn new(posting_list: &'a SparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            doc_id_offset: 0,
            weight_offset: 0,
            index: 0,
            doc_id: 0,
            weight: 0.0,
            exhausted: posting_list.doc_count == 0,
        };

        if !iter.exhausted {
            iter.weight_offset = iter.calculate_weight_offset();
            iter.load_current();
        }

        iter
    }

    /// Scans past the doc id deltas to find where the weight section begins.
    fn calculate_weight_offset(&self) -> usize {
        let mut offset = 0;
        let mut reader = &self.posting_list.data[..];

        for _ in 0..self.posting_list.doc_count {
            if read_vint(&mut reader).is_ok() {
                offset = self.posting_list.data.len() - reader.len();
            }
        }

        offset
    }

    /// Decodes the doc id and weight at the current index.
    fn load_current(&mut self) {
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return;
        }

        // Decode the next doc id delta.
        let mut reader = &self.posting_list.data[self.doc_id_offset..];
        if let Ok(delta) = read_vint(&mut reader) {
            self.doc_id = self.doc_id.wrapping_add(delta as u32);
            self.doc_id_offset = self.posting_list.data.len() - reader.len();
        }

        // Decode the weight for the current index according to the
        // quantization mode.
        let weight_idx = self.index;
        let pl = self.posting_list;

        self.weight = match pl.quantization {
            WeightQuantization::Float32 => {
                let offset = self.weight_offset + weight_idx * 4;
                if offset + 4 <= pl.data.len() {
                    let bytes = &pl.data[offset..offset + 4];
                    f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
                } else {
                    0.0
                }
            }
            WeightQuantization::Float16 => {
                let offset = self.weight_offset + weight_idx * 2;
                if offset + 2 <= pl.data.len() {
                    let bits = u16::from_le_bytes([pl.data[offset], pl.data[offset + 1]]);
                    half::f16::from_bits(bits).to_f32()
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt8 => {
                let offset = self.weight_offset + weight_idx;
                if offset < pl.data.len() {
                    let quantized = pl.data[offset];
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt4 => {
                let byte_offset = self.weight_offset + weight_idx / 2;
                if byte_offset < pl.data.len() {
                    let byte = pl.data[byte_offset];
                    let quantized = if weight_idx.is_multiple_of(2) {
                        byte & 0x0F
                    } else {
                        (byte >> 4) & 0x0F
                    };
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
        };
    }

    /// Current document id, or `super::TERMINATED` once exhausted.
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else {
            self.doc_id
        }
    }

    /// Current weight, or `0.0` once exhausted.
    pub fn weight(&self) -> f32 {
        if self.exhausted { 0.0 } else { self.weight }
    }

    /// Moves to the next posting and returns its doc id.
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        self.index += 1;
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return super::TERMINATED;
        }

        self.load_current();
        self.doc_id
    }

    /// Advances until the current doc id is >= `target`.
    pub fn seek(&mut self, target: DocId) -> DocId {
        while !self.exhausted && self.doc_id < target {
            self.advance();
        }
        self.doc()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_vector_dot_product() {
        let v1 = SparseVector::from_entries(&[0, 2, 5], &[1.0, 2.0, 3.0]);
        let v2 = SparseVector::from_entries(&[1, 2, 5], &[1.0, 4.0, 2.0]);

        // Overlapping dimensions 2 and 5: 2.0 * 4.0 + 3.0 * 2.0 = 14.0.
        assert!((v1.dot(&v2) - 14.0).abs() < 1e-6);
    }

    #[test]
    fn test_sparse_posting_list_float32() {
        let postings = vec![(0, 1.5), (5, 2.3), (10, 0.8), (100, 3.15)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 4);

        let mut iter = pl.iterator();
        assert_eq!(iter.doc(), 0);
        assert!((iter.weight() - 1.5).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!((iter.weight() - 2.3).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 10);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!((iter.weight() - 3.15).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }
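
    // A small sketch of the flat posting list's on-disk round trip: the doc
    // ids and weights below are arbitrary illustration values, and the
    // tolerance mirrors the Float32 checks above.
    #[test]
    fn test_sparse_posting_list_serialization_roundtrip() {
        let postings = vec![(1, 0.25), (7, 1.75), (42, 3.5)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut buffer = Vec::new();
        pl.serialize(&mut buffer).unwrap();

        let restored = SparsePostingList::deserialize(&mut buffer.as_slice()).unwrap();
        assert_eq!(restored.doc_count(), pl.doc_count());
        assert_eq!(restored.quantization(), WeightQuantization::Float32);

        let decoded = restored.decode_all().unwrap();
        for ((doc, weight), (orig_doc, orig_weight)) in decoded.iter().zip(postings.iter()) {
            assert_eq!(doc, orig_doc);
            assert!((weight - orig_weight).abs() < 1e-6);
        }
    }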

    #[test]
    fn test_sparse_posting_list_uint8() {
        let postings = vec![(0, 0.0), (5, 0.5), (10, 1.0)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let decoded = pl.decode_all().unwrap();
        assert_eq!(decoded.len(), 3);

        // Quantization preserves the relative ordering of the weights.
        assert!(decoded[0].1 < decoded[1].1);
        assert!(decoded[1].1 < decoded[2].1);
    }
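
    // Sketch of the UInt4 path: two 4-bit weights are packed per byte (low
    // nibble first), so an odd-length list also exercises the padded final
    // byte. With weights spanning 0.0..=1.5 the quantization step is 0.1,
    // so the decoded values should land within one step of the originals.
    #[test]
    fn test_sparse_posting_list_uint4() {
        let postings = vec![(0, 0.0), (3, 0.5), (9, 1.5)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt4).unwrap();

        let decoded = pl.decode_all().unwrap();
        assert_eq!(decoded.len(), 3);
        for ((doc, weight), (orig_doc, orig_weight)) in decoded.iter().zip(postings.iter()) {
            assert_eq!(doc, orig_doc);
            assert!((weight - orig_weight).abs() < 0.1);
        }
    }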

    #[test]
    fn test_block_sparse_posting_list() {
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, (i as f32) * 0.1)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 300);
        assert!(pl.num_blocks() >= 2);

        let mut iter = pl.iterator();
        for (expected_doc, _, expected_weight) in &postings {
            assert_eq!(iter.doc(), *expected_doc);
            assert!((iter.weight() - expected_weight).abs() < 1e-6);
            iter.advance();
        }
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }

    #[test]
    fn test_block_sparse_seek() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = pl.iterator();

        // Doc ids are multiples of 3, so 300 is present...
        assert_eq!(iter.seek(300), 300);
        // ...while 301 is not, and seek lands on the next doc id.
        assert_eq!(iter.seek(301), 303);
        // Past the last doc id (1497) the iterator is exhausted.
        assert_eq!(iter.seek(2000), super::super::TERMINATED);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let postings: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (10, 0, 2.0), (100, 0, 3.0)];

        for quant in [
            WeightQuantization::Float32,
            WeightQuantization::Float16,
            WeightQuantization::UInt8,
        ] {
            let pl = BlockSparsePostingList::from_postings(&postings, quant).unwrap();

            let (block_data, skip_entries) = pl.serialize().unwrap();
            let pl2 =
                BlockSparsePostingList::from_parts(pl.doc_count(), &block_data, &skip_entries)
                    .unwrap();

            assert_eq!(pl.doc_count(), pl2.doc_count());

            let mut iter1 = pl.iterator();
            let mut iter2 = pl2.iterator();

            while iter1.doc() != super::super::TERMINATED {
                assert_eq!(iter1.doc(), iter2.doc());
                assert!((iter1.weight() - iter2.weight()).abs() < 0.1);
                iter1.advance();
                iter2.advance();
            }
        }
    }

    #[test]
    fn test_concatenate() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 1, 2.0)];
        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 3.0), (10, 1, 4.0)];

        let pl1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        let pl2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        // Simulate concatenation by re-basing pl2's doc ids past pl1's range.
        let mut all: Vec<(DocId, u16, f32)> = pl1.decode_all();
        for (doc_id, ord, w) in pl2.decode_all() {
            all.push((doc_id + 100, ord, w));
        }
        let merged =
            BlockSparsePostingList::from_postings(&all, WeightQuantization::Float32).unwrap();

        assert_eq!(merged.doc_count(), 4);

        let decoded = merged.decode_all();
        assert_eq!(decoded[0], (0, 0, 1.0));
        assert_eq!(decoded[1], (5, 1, 2.0));
        assert_eq!(decoded[2], (100, 0, 3.0));
        assert_eq!(decoded[3], (110, 1, 4.0));
    }
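
    // Sketch of the empty-list edge case: a list built from no postings
    // should report zero docs and hand out an already-exhausted iterator.
    #[test]
    fn test_empty_sparse_posting_list() {
        let pl = SparsePostingList::from_postings(&[], WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 0);
        assert!(pl.decode_all().unwrap().is_empty());

        let mut iter = pl.iterator();
        assert_eq!(iter.doc(), super::super::TERMINATED);
        assert_eq!(iter.advance(), super::super::TERMINATED);
    }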

    #[test]
    fn test_sparse_vector_config() {
        let default = SparseVectorConfig::default();
        assert_eq!(default.index_size, IndexSize::U32);
        assert_eq!(default.weight_quantization, WeightQuantization::Float32);
        // 4-byte index + 4-byte float weight.
        assert_eq!(default.bytes_per_entry(), 8.0);

        let splade = SparseVectorConfig::splade();
        assert_eq!(splade.index_size, IndexSize::U16);
        assert_eq!(splade.weight_quantization, WeightQuantization::UInt8);
        // 2-byte index + 1-byte weight.
        assert_eq!(splade.bytes_per_entry(), 3.0);
        assert_eq!(splade.weight_threshold, 0.01);
        assert_eq!(splade.posting_list_pruning, Some(0.1));
        assert!(splade.query_config.is_some());
        let query_cfg = splade.query_config.as_ref().unwrap();
        assert_eq!(query_cfg.heap_factor, 0.8);
        assert_eq!(query_cfg.max_query_dims, Some(20));

        let compact = SparseVectorConfig::compact();
        assert_eq!(compact.index_size, IndexSize::U16);
        assert_eq!(compact.weight_quantization, WeightQuantization::UInt4);
        // 2-byte index + half-byte weight.
        assert_eq!(compact.bytes_per_entry(), 2.5);

        let conservative = SparseVectorConfig::conservative();
        assert_eq!(conservative.index_size, IndexSize::U32);
        assert_eq!(
            conservative.weight_quantization,
            WeightQuantization::Float16
        );
        assert_eq!(conservative.weight_threshold, 0.005);
        assert_eq!(conservative.posting_list_pruning, None);

        // Byte round trip of the packed config.
        let byte = splade.to_byte();
        let restored = SparseVectorConfig::from_byte(byte).unwrap();
        assert_eq!(restored.index_size, splade.index_size);
        assert_eq!(restored.weight_quantization, splade.weight_quantization);
    }

    #[test]
    fn test_index_size() {
        assert_eq!(IndexSize::U16.bytes(), 2);
        assert_eq!(IndexSize::U32.bytes(), 4);
        assert_eq!(IndexSize::U16.max_value(), 65535);
        assert_eq!(IndexSize::U32.max_value(), u32::MAX);
    }
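
    // Sketch of the quantization tag round trip used by the serialized
    // header: `from_u8` must invert the `as u8` cast for every variant.
    #[test]
    fn test_weight_quantization_byte_roundtrip() {
        for quant in [
            WeightQuantization::Float32,
            WeightQuantization::Float16,
            WeightQuantization::UInt8,
            WeightQuantization::UInt4,
        ] {
            assert_eq!(WeightQuantization::from_u8(quant as u8), Some(quant));
        }
    }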

    #[test]
    fn test_block_max_weight() {
        let postings: Vec<(DocId, u16, f32)> = (0..300)
            .map(|i| (i as DocId, 0, (i as f32) * 0.1))
            .collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert!((pl.global_max_weight() - 29.9).abs() < 0.01);
        assert!(pl.num_blocks() >= 3);

        // With 128 postings per block and increasing weights, each block's
        // maximum is the weight of its last posting.
        let block0_max = pl.block_max_weight(0).unwrap();
        assert!((block0_max - 12.7).abs() < 0.01);

        let block1_max = pl.block_max_weight(1).unwrap();
        assert!((block1_max - 25.5).abs() < 0.01);

        let block2_max = pl.block_max_weight(2).unwrap();
        assert!((block2_max - 29.9).abs() < 0.01);

        // Block-max contribution for a query weight of 2.0.
        let query_weight = 2.0;
        let mut iter = pl.iterator();
        assert!((iter.current_block_max_weight() - 12.7).abs() < 0.01);
        assert!((iter.current_block_max_contribution(query_weight) - 25.4).abs() < 0.1);

        iter.seek(128);
        assert!((iter.current_block_max_weight() - 25.5).abs() < 0.01);
    }

    #[test]
    fn test_sparse_skip_list_serialization() {
        let mut skip_list = SparseSkipList::new();
        skip_list.push(0, 127, 0, 50, 12.7);
        skip_list.push(128, 255, 100, 60, 25.5);
        skip_list.push(256, 299, 200, 40, 29.9);

        assert_eq!(skip_list.len(), 3);
        assert!((skip_list.global_max_weight() - 29.9).abs() < 0.01);

        let mut buffer = Vec::new();
        skip_list.write(&mut buffer).unwrap();

        let restored = SparseSkipList::read(&mut buffer.as_slice()).unwrap();

        assert_eq!(restored.len(), 3);
        assert!((restored.global_max_weight() - 29.9).abs() < 0.01);

        let e0 = restored.get(0).unwrap();
        assert_eq!(e0.first_doc, 0);
        assert_eq!(e0.last_doc, 127);
        assert!((e0.max_weight - 12.7).abs() < 0.01);

        let e1 = restored.get(1).unwrap();
        assert_eq!(e1.first_doc, 128);
        assert!((e1.max_weight - 25.5).abs() < 0.01);
    }
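
    // Sketch of skip-list block lookup: find_block should return the first
    // block whose last_doc is >= the target, and block_max_contribution is
    // the query_weight * max_weight upper bound for that block.
    #[test]
    fn test_sparse_skip_list_find_block() {
        let mut skip_list = SparseSkipList::new();
        skip_list.push(0, 127, 0, 50, 1.0);
        skip_list.push(128, 255, 100, 60, 2.0);
        skip_list.push(256, 299, 200, 40, 3.0);

        assert_eq!(skip_list.find_block(0), Some(0));
        assert_eq!(skip_list.find_block(127), Some(0));
        assert_eq!(skip_list.find_block(128), Some(1));
        assert_eq!(skip_list.find_block(299), Some(2));
        assert_eq!(skip_list.find_block(300), None);

        let entry = skip_list.get(1).unwrap();
        assert!((entry.block_max_contribution(2.0) - 4.0).abs() < 1e-6);

        assert!(SparseSkipList::new().find_block(0).is_none());
    }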
}