mod block;
mod config;

pub use block::{BlockSparsePostingIterator, BlockSparsePostingList, SparseBlock};
pub use config::{
    IndexSize, QueryWeighting, SparseEntry, SparseQueryConfig, SparseVector, SparseVectorConfig,
    WeightQuantization,
};

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Read, Write};

use super::posting_common::{read_vint, write_vint};
use crate::DocId;

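/// A single posting: a document id paired with the term's weight in that document.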
#[derive(Debug, Clone, Copy)]
pub struct SparsePosting {
    pub doc_id: DocId,
    pub weight: f32,
}

/// Number of postings per block in block-based sparse posting lists.
pub const SPARSE_BLOCK_SIZE: usize = 128;

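/// Skip-list entry for one block of a sparse posting list: the block's doc-id
/// range, its byte range within the encoded data, and the maximum weight it
/// contains (usable as a block-max bound when pruning).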
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseSkipEntry {
    /// First document id in the block.
    pub first_doc: DocId,
    /// Last document id in the block.
    pub last_doc: DocId,
    /// Byte offset of the block's encoded data.
    pub offset: u32,
    /// Length of the block's encoded data in bytes.
    pub length: u32,
    /// Maximum weight within the block.
    pub max_weight: f32,
}

impl SparseSkipEntry {
    /// Serialized size in bytes: four `u32` fields plus one `f32`.
    pub const SIZE: usize = 20;

    pub fn new(
        first_doc: DocId,
        last_doc: DocId,
        offset: u32,
        length: u32,
        max_weight: f32,
    ) -> Self {
        Self {
            first_doc,
            last_doc,
            offset,
            length,
            max_weight,
        }
    }

    /// Upper bound on this block's contribution to a dot-product score for the
    /// given query weight.
    #[inline]
    pub fn block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.max_weight
    }

    /// Serializes the entry in little-endian order.
    pub fn write<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc)?;
        writer.write_u32::<LittleEndian>(self.last_doc)?;
        writer.write_u32::<LittleEndian>(self.offset)?;
        writer.write_u32::<LittleEndian>(self.length)?;
        writer.write_f32::<LittleEndian>(self.max_weight)?;
        Ok(())
    }

    /// Reads an entry previously written with [`Self::write`].
    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc = reader.read_u32::<LittleEndian>()?;
        let last_doc = reader.read_u32::<LittleEndian>()?;
        let offset = reader.read_u32::<LittleEndian>()?;
        let length = reader.read_u32::<LittleEndian>()?;
        let max_weight = reader.read_f32::<LittleEndian>()?;
        Ok(Self {
            first_doc,
            last_doc,
            offset,
            length,
            max_weight,
        })
    }
}

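/// Skip list over the blocks of a sparse posting list, tracking a per-block
/// maximum weight plus the global maximum weight across all blocks.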
#[derive(Debug, Clone, Default)]
pub struct SparseSkipList {
    entries: Vec<SparseSkipEntry>,
    /// Maximum weight across all blocks.
    global_max_weight: f32,
}

impl SparseSkipList {
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a block entry and updates the global maximum weight.
    pub fn push(
        &mut self,
        first_doc: DocId,
        last_doc: DocId,
        offset: u32,
        length: u32,
        max_weight: f32,
    ) {
        self.global_max_weight = self.global_max_weight.max(max_weight);
        self.entries.push(SparseSkipEntry::new(
            first_doc, last_doc, offset, length, max_weight,
        ));
    }

    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    pub fn get(&self, index: usize) -> Option<&SparseSkipEntry> {
        self.entries.get(index)
    }

    pub fn global_max_weight(&self) -> f32 {
        self.global_max_weight
    }

    /// Returns the index of the first block whose `last_doc` is >= `target`,
    /// i.e. the block that may contain `target`, or `None` if every block ends
    /// before it.
    pub fn find_block(&self, target: DocId) -> Option<usize> {
        if self.entries.is_empty() {
            return None;
        }
        let idx = self.entries.partition_point(|e| e.last_doc < target);
        if idx < self.entries.len() {
            Some(idx)
        } else {
            None
        }
    }

    pub fn iter(&self) -> impl Iterator<Item = &SparseSkipEntry> {
        self.entries.iter()
    }

    /// Writes the entry count, the global maximum weight, then each entry.
    pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.entries.len() as u32)?;
        writer.write_f32::<LittleEndian>(self.global_max_weight)?;
        for entry in &self.entries {
            entry.write(writer)?;
        }
        Ok(())
    }

    /// Reads a skip list previously written with [`Self::write`].
    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        let count = reader.read_u32::<LittleEndian>()? as usize;
        let global_max_weight = reader.read_f32::<LittleEndian>()?;
        let mut entries = Vec::with_capacity(count);
        for _ in 0..count {
            entries.push(SparseSkipEntry::read(reader)?);
        }
        Ok(Self {
            entries,
            global_max_weight,
        })
    }
}

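/// A flat (non-blocked) sparse posting list. The encoded buffer stores all
/// delta-encoded doc ids as varints, followed by all weights encoded according
/// to the chosen [`WeightQuantization`]; `scale` and `min_val` recover
/// approximate weights from quantized values.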
#[derive(Debug, Clone)]
pub struct SparsePostingList {
    /// Weight encoding used for this list.
    quantization: WeightQuantization,
    /// Dequantization scale factor (1.0 for float encodings).
    scale: f32,
    /// Minimum weight value, used as the dequantization offset.
    min_val: f32,
    /// Number of postings in the list.
    doc_count: u32,
    /// Encoded doc-id deltas followed by encoded weights.
    data: Vec<u8>,
}

impl SparsePostingList {
    /// Builds a posting list from `(doc_id, weight)` pairs sorted by ascending
    /// doc id, encoding doc ids as delta varints and weights with the given
    /// quantization.
    pub fn from_postings(
        postings: &[(DocId, f32)],
        quantization: WeightQuantization,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                quantization,
                scale: 1.0,
                min_val: 0.0,
                doc_count: 0,
                data: Vec::new(),
            });
        }

        let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
        let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        // Quantization parameters: float encodings are stored as-is, integer
        // encodings map [min_val, max_val] onto the available levels.
        let (scale, adjusted_min) = match quantization {
            WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
            WeightQuantization::UInt8 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 255.0, min_val)
                }
            }
            WeightQuantization::UInt4 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 15.0, min_val)
                }
            }
        };

        let mut data = Vec::new();

        // Doc ids: delta-encoded varints.
        let mut prev_doc_id = 0u32;
        for (doc_id, _) in postings {
            let delta = doc_id - prev_doc_id;
            write_vint(&mut data, delta as u64)?;
            prev_doc_id = *doc_id;
        }

        // Weights: encoded according to the quantization scheme.
        match quantization {
            WeightQuantization::Float32 => {
                for (_, weight) in postings {
                    data.write_f32::<LittleEndian>(*weight)?;
                }
            }
            WeightQuantization::Float16 => {
                use half::slice::HalfFloatSliceExt;
                let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
                let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
                f16_slice.convert_from_f32_slice(&weights);
                for h in f16_slice {
                    data.write_u16::<LittleEndian>(h.to_bits())?;
                }
            }
            WeightQuantization::UInt8 => {
                for (_, weight) in postings {
                    let quantized = ((*weight - adjusted_min) / scale).round() as u8;
                    data.write_u8(quantized)?;
                }
            }
            WeightQuantization::UInt4 => {
                // Pack two 4-bit values per byte, low nibble first.
                let mut i = 0;
                while i < postings.len() {
                    let q1 = ((postings[i].1 - adjusted_min) / scale).round() as u8 & 0x0F;
                    let q2 = if i + 1 < postings.len() {
                        ((postings[i + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
                    } else {
                        0
                    };
                    data.write_u8((q2 << 4) | q1)?;
                    i += 2;
                }
            }
        }

        Ok(Self {
            quantization,
            scale,
            min_val: adjusted_min,
            doc_count: postings.len() as u32,
            data,
        })
    }

    /// Serializes the header (quantization, scale, minimum value, doc count,
    /// data length) followed by the raw encoded data.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u8(self.quantization as u8)?;
        writer.write_f32::<LittleEndian>(self.scale)?;
        writer.write_f32::<LittleEndian>(self.min_val)?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
        writer.write_all(&self.data)?;
        Ok(())
    }

    /// Deserializes a posting list previously written with [`Self::serialize`].
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let quant_byte = reader.read_u8()?;
        let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
        })?;
        let scale = reader.read_f32::<LittleEndian>()?;
        let min_val = reader.read_f32::<LittleEndian>()?;
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let data_len = reader.read_u32::<LittleEndian>()? as usize;
        let mut data = vec![0u8; data_len];
        reader.read_exact(&mut data)?;

        Ok(Self {
            quantization,
            scale,
            min_val,
            doc_count,
            data,
        })
    }

    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    pub fn quantization(&self) -> WeightQuantization {
        self.quantization
    }

    /// Returns a streaming iterator over the postings.
    pub fn iterator(&self) -> SparsePostingIterator<'_> {
        SparsePostingIterator::new(self)
    }

    /// Decodes the full list into `(doc_id, weight)` pairs.
    pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
        let mut result = Vec::with_capacity(self.doc_count as usize);
        let mut iter = self.iterator();

        while !iter.exhausted {
            result.push((iter.doc_id, iter.weight));
            iter.advance();
        }

        Ok(result)
    }
}

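/// Streaming iterator over a [`SparsePostingList`]: doc-id deltas are decoded
/// lazily from the front of the data buffer, while weights are read by index
/// from the weight section that follows the deltas.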
pub struct SparsePostingIterator<'a> {
    posting_list: &'a SparsePostingList,
    /// Byte offset of the next doc-id delta to decode.
    doc_id_offset: usize,
    /// Byte offset where the weight section begins.
    weight_offset: usize,
    /// Index of the current posting.
    index: usize,
    /// Current document id.
    doc_id: DocId,
    /// Current (dequantized) weight.
    weight: f32,
    /// Set once the iterator has moved past the last posting.
    exhausted: bool,
}

impl<'a> SparsePostingIterator<'a> {
    fn new(posting_list: &'a SparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            doc_id_offset: 0,
            weight_offset: 0,
            index: 0,
            doc_id: 0,
            weight: 0.0,
            exhausted: posting_list.doc_count == 0,
        };

        if !iter.exhausted {
            iter.weight_offset = iter.calculate_weight_offset();
            iter.load_current();
        }

        iter
    }

    /// Finds where the weight section starts by skipping over the
    /// delta-encoded doc ids at the front of the data buffer.
    fn calculate_weight_offset(&self) -> usize {
        let mut offset = 0;
        let mut reader = &self.posting_list.data[..];

        for _ in 0..self.posting_list.doc_count {
            if read_vint(&mut reader).is_ok() {
                offset = self.posting_list.data.len() - reader.len();
            }
        }

        offset
    }

    fn load_current(&mut self) {
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return;
        }

        // Decode the next doc-id delta.
        let mut reader = &self.posting_list.data[self.doc_id_offset..];
        if let Ok(delta) = read_vint(&mut reader) {
            self.doc_id = self.doc_id.wrapping_add(delta as u32);
            self.doc_id_offset = self.posting_list.data.len() - reader.len();
        }

        // Decode the weight for the current index from the weight section.
        let weight_idx = self.index;
        let pl = self.posting_list;

        self.weight = match pl.quantization {
            WeightQuantization::Float32 => {
                let offset = self.weight_offset + weight_idx * 4;
                if offset + 4 <= pl.data.len() {
                    let bytes = &pl.data[offset..offset + 4];
                    f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
                } else {
                    0.0
                }
            }
            WeightQuantization::Float16 => {
                let offset = self.weight_offset + weight_idx * 2;
                if offset + 2 <= pl.data.len() {
                    let bits = u16::from_le_bytes([pl.data[offset], pl.data[offset + 1]]);
                    half::f16::from_bits(bits).to_f32()
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt8 => {
                let offset = self.weight_offset + weight_idx;
                if offset < pl.data.len() {
                    let quantized = pl.data[offset];
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt4 => {
                // Two weights per byte: even indices in the low nibble, odd in the high.
                let byte_offset = self.weight_offset + weight_idx / 2;
                if byte_offset < pl.data.len() {
                    let byte = pl.data[byte_offset];
                    let quantized = if weight_idx.is_multiple_of(2) {
                        byte & 0x0F
                    } else {
                        (byte >> 4) & 0x0F
                    };
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
        };
    }

    /// Current document id, or `super::TERMINATED` once exhausted.
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else {
            self.doc_id
        }
    }

    /// Weight of the current posting, or `0.0` once exhausted.
    pub fn weight(&self) -> f32 {
        if self.exhausted { 0.0 } else { self.weight }
    }

    /// Moves to the next posting and returns its doc id, or
    /// `super::TERMINATED` when the list is exhausted.
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        self.index += 1;
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return super::TERMINATED;
        }

        self.load_current();
        self.doc_id
    }

    /// Advances to the first posting with doc id >= `target` and returns it.
    pub fn seek(&mut self, target: DocId) -> DocId {
        while !self.exhausted && self.doc_id < target {
            self.advance();
        }
        self.doc()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_vector_dot_product() {
        let v1 = SparseVector::from_entries(&[0, 2, 5], &[1.0, 2.0, 3.0]);
        let v2 = SparseVector::from_entries(&[1, 2, 5], &[1.0, 4.0, 2.0]);

        assert!((v1.dot(&v2) - 14.0).abs() < 1e-6);
    }

    #[test]
    fn test_sparse_posting_list_float32() {
        let postings = vec![(0, 1.5), (5, 2.3), (10, 0.8), (100, 3.15)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 4);

        let mut iter = pl.iterator();
        assert_eq!(iter.doc(), 0);
        assert!((iter.weight() - 1.5).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!((iter.weight() - 2.3).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 10);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!((iter.weight() - 3.15).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }

    #[test]
    fn test_sparse_posting_list_uint8() {
        let postings = vec![(0, 0.0), (5, 0.5), (10, 1.0)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let decoded = pl.decode_all().unwrap();
        assert_eq!(decoded.len(), 3);

        assert!(decoded[0].1 < decoded[1].1);
        assert!(decoded[1].1 < decoded[2].1);
    }

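    // A minimal sketch of UInt4 quantization: with weights spanning [0.0, 1.5]
    // the scale becomes 1.5 / 15 = 0.1, so the dequantized values should land
    // close to the originals.
    #[test]
    fn test_sparse_posting_list_uint4() {
        let postings = vec![(1, 0.0), (3, 0.5), (7, 1.0), (9, 1.5)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt4).unwrap();

        let decoded = pl.decode_all().unwrap();
        assert_eq!(decoded.len(), 4);
        for ((expected_doc, expected_weight), (doc, weight)) in
            postings.iter().zip(decoded.iter())
        {
            assert_eq!(doc, expected_doc);
            assert!((weight - expected_weight).abs() < 1e-3);
        }
    }

    // An empty posting list reports zero documents and an already-exhausted iterator.
    #[test]
    fn test_sparse_posting_list_empty() {
        let pl = SparsePostingList::from_postings(&[], WeightQuantization::Float32).unwrap();
        assert_eq!(pl.doc_count(), 0);
        assert_eq!(pl.iterator().doc(), super::super::TERMINATED);
        assert!(pl.decode_all().unwrap().is_empty());
    }
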
    #[test]
    fn test_block_sparse_posting_list() {
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, (i as f32) * 0.1)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 300);
        assert!(pl.num_blocks() >= 2);

        let mut iter = pl.iterator();
        for (expected_doc, _, expected_weight) in &postings {
            assert_eq!(iter.doc(), *expected_doc);
            assert!((iter.weight() - expected_weight).abs() < 1e-6);
            iter.advance();
        }
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }

    #[test]
    fn test_block_sparse_seek() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = pl.iterator();

        // Seek to a doc id that exists.
        assert_eq!(iter.seek(300), 300);

        // Seek between doc ids lands on the next present one.
        assert_eq!(iter.seek(301), 303);

        // Seek past the last doc id terminates the iterator.
        assert_eq!(iter.seek(2000), super::super::TERMINATED);
    }

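    // A small sketch of seek() on the flat (non-blocked) posting list: doc ids
    // are 0, 3, 6, ..., so seeking to a present id returns it, seeking between
    // ids lands on the next one, and seeking past the end terminates.
    #[test]
    fn test_sparse_posting_iterator_seek() {
        let postings: Vec<(DocId, f32)> = (0..200).map(|i| (i * 3, i as f32)).collect();
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = pl.iterator();
        assert_eq!(iter.seek(30), 30);
        assert!((iter.weight() - 10.0).abs() < 1e-6);
        assert_eq!(iter.seek(31), 33);
        assert_eq!(iter.seek(10_000), super::super::TERMINATED);
    }
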
    #[test]
    fn test_serialization_roundtrip() {
        let postings: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (10, 0, 2.0), (100, 0, 3.0)];

        for quant in [
            WeightQuantization::Float32,
            WeightQuantization::Float16,
            WeightQuantization::UInt8,
        ] {
            let pl = BlockSparsePostingList::from_postings(&postings, quant).unwrap();

            let mut buffer = Vec::new();
            pl.serialize(&mut buffer).unwrap();

            let pl2 = BlockSparsePostingList::deserialize(&mut &buffer[..]).unwrap();

            assert_eq!(pl.doc_count(), pl2.doc_count());

            let mut iter1 = pl.iterator();
            let mut iter2 = pl2.iterator();

            while iter1.doc() != super::super::TERMINATED {
                assert_eq!(iter1.doc(), iter2.doc());
                assert!((iter1.weight() - iter2.weight()).abs() < 0.1);
                iter1.advance();
                iter2.advance();
            }
        }
    }

    #[test]
    fn test_concatenate() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 1, 2.0)];
        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 3.0), (10, 1, 4.0)];

        let pl1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        let pl2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        // Concatenate by re-encoding: append pl2's postings with a doc-id offset of 100.
        let mut all: Vec<(DocId, u16, f32)> = pl1.decode_all();
        for (doc_id, ord, w) in pl2.decode_all() {
            all.push((doc_id + 100, ord, w));
        }
        let merged =
            BlockSparsePostingList::from_postings(&all, WeightQuantization::Float32).unwrap();

        assert_eq!(merged.doc_count(), 4);

        let decoded = merged.decode_all();
        assert_eq!(decoded[0], (0, 0, 1.0));
        assert_eq!(decoded[1], (5, 1, 2.0));
        assert_eq!(decoded[2], (100, 0, 3.0));
        assert_eq!(decoded[3], (110, 1, 4.0));
    }

    #[test]
    fn test_sparse_vector_config() {
        let default = SparseVectorConfig::default();
        assert_eq!(default.index_size, IndexSize::U32);
        assert_eq!(default.weight_quantization, WeightQuantization::Float32);
        assert_eq!(default.bytes_per_entry(), 8.0); // 4-byte index + 4-byte weight

        let splade = SparseVectorConfig::splade();
        assert_eq!(splade.index_size, IndexSize::U16);
        assert_eq!(splade.weight_quantization, WeightQuantization::UInt8);
        assert_eq!(splade.bytes_per_entry(), 3.0); // 2-byte index + 1-byte weight
        assert_eq!(splade.weight_threshold, 0.01);
        assert_eq!(splade.posting_list_pruning, Some(0.1));
        assert!(splade.query_config.is_some());
        let query_cfg = splade.query_config.as_ref().unwrap();
        assert_eq!(query_cfg.heap_factor, 0.8);
        assert_eq!(query_cfg.max_query_dims, Some(20));

        let compact = SparseVectorConfig::compact();
        assert_eq!(compact.index_size, IndexSize::U16);
        assert_eq!(compact.weight_quantization, WeightQuantization::UInt4);
        assert_eq!(compact.bytes_per_entry(), 2.5); // 2-byte index + half-byte weight

        let conservative = SparseVectorConfig::conservative();
        assert_eq!(conservative.index_size, IndexSize::U32);
        assert_eq!(
            conservative.weight_quantization,
            WeightQuantization::Float16
        );
        assert_eq!(conservative.weight_threshold, 0.005);
        assert_eq!(conservative.posting_list_pruning, None);

        let byte = splade.to_byte();
        let restored = SparseVectorConfig::from_byte(byte).unwrap();
        assert_eq!(restored.index_size, splade.index_size);
        assert_eq!(restored.weight_quantization, splade.weight_quantization);
    }

    #[test]
    fn test_index_size() {
        assert_eq!(IndexSize::U16.bytes(), 2);
        assert_eq!(IndexSize::U32.bytes(), 4);
        assert_eq!(IndexSize::U16.max_value(), 65535);
        assert_eq!(IndexSize::U32.max_value(), u32::MAX);
    }

    #[test]
    fn test_block_max_weight() {
        let postings: Vec<(DocId, u16, f32)> = (0..300)
            .map(|i| (i as DocId, 0, (i as f32) * 0.1))
            .collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert!((pl.global_max_weight() - 29.9).abs() < 0.01);
        assert!(pl.num_blocks() >= 3);

        let block0_max = pl.block_max_weight(0).unwrap();
        assert!((block0_max - 12.7).abs() < 0.01);

        let block1_max = pl.block_max_weight(1).unwrap();
        assert!((block1_max - 25.5).abs() < 0.01);

        let block2_max = pl.block_max_weight(2).unwrap();
        assert!((block2_max - 29.9).abs() < 0.01);

        let query_weight = 2.0;
        let mut iter = pl.iterator();
        assert!((iter.current_block_max_weight() - 12.7).abs() < 0.01);
        assert!((iter.current_block_max_contribution(query_weight) - 25.4).abs() < 0.1);

        iter.seek(128);
        assert!((iter.current_block_max_weight() - 25.5).abs() < 0.01);
    }

    #[test]
    fn test_sparse_skip_list_serialization() {
        let mut skip_list = SparseSkipList::new();
        skip_list.push(0, 127, 0, 50, 12.7);
        skip_list.push(128, 255, 100, 60, 25.5);
        skip_list.push(256, 299, 200, 40, 29.9);

        assert_eq!(skip_list.len(), 3);
        assert!((skip_list.global_max_weight() - 29.9).abs() < 0.01);

        let mut buffer = Vec::new();
        skip_list.write(&mut buffer).unwrap();

        let restored = SparseSkipList::read(&mut buffer.as_slice()).unwrap();

        assert_eq!(restored.len(), 3);
        assert!((restored.global_max_weight() - 29.9).abs() < 0.01);

        let e0 = restored.get(0).unwrap();
        assert_eq!(e0.first_doc, 0);
        assert_eq!(e0.last_doc, 127);
        assert!((e0.max_weight - 12.7).abs() < 0.01);

        let e1 = restored.get(1).unwrap();
        assert_eq!(e1.first_doc, 128);
        assert!((e1.max_weight - 25.5).abs() < 0.01);
    }
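
    // A minimal sketch of block lookup and block-max scoring on the skip list:
    // find_block() returns the index of the first block whose last_doc is >=
    // the target, and block_max_contribution() is query_weight * max_weight.
    #[test]
    fn test_sparse_skip_list_find_block() {
        let mut skip_list = SparseSkipList::new();
        skip_list.push(0, 127, 0, 50, 12.7);
        skip_list.push(128, 255, 100, 60, 25.5);
        skip_list.push(256, 299, 200, 40, 29.9);

        assert_eq!(skip_list.find_block(0), Some(0));
        assert_eq!(skip_list.find_block(127), Some(0));
        assert_eq!(skip_list.find_block(128), Some(1));
        assert_eq!(skip_list.find_block(299), Some(2));
        assert_eq!(skip_list.find_block(300), None);
        assert!(SparseSkipList::new().find_block(0).is_none());

        let entry = skip_list.get(1).unwrap();
        assert!((entry.block_max_contribution(2.0) - 51.0).abs() < 1e-4);
    }

    // Round-trip of the flat SparsePostingList through serialize()/deserialize();
    // this assumes WeightQuantization::from_u8 (defined in config) inverts the
    // `as u8` cast used when writing the header.
    #[test]
    fn test_sparse_posting_list_roundtrip() {
        let postings = vec![(2, 1.25), (7, 0.5), (42, 3.75)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut buffer = Vec::new();
        pl.serialize(&mut buffer).unwrap();
        let restored = SparsePostingList::deserialize(&mut &buffer[..]).unwrap();

        assert_eq!(restored.doc_count(), 3);
        assert_eq!(restored.quantization(), WeightQuantization::Float32);
        assert_eq!(restored.decode_all().unwrap(), postings);
    }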
}