mod block;
mod config;

pub use block::{BlockSparsePostingIterator, BlockSparsePostingList};
pub use config::{
    IndexSize, QueryWeighting, SparseEntry, SparseQueryConfig, SparseVector, SparseVectorConfig,
    WeightQuantization,
};

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Read, Write};

use super::posting_common::{read_vint, write_vint};
use crate::DocId;

/// A single posting: a document id paired with its weight for one sparse dimension.
#[derive(Debug, Clone, Copy)]
pub struct SparsePosting {
    pub doc_id: DocId,
    pub weight: f32,
}

/// Number of postings per block in the block-based posting layout.
pub const SPARSE_BLOCK_SIZE: usize = 128;

/// Skip-list entry describing one block of postings: its doc-id range, the
/// byte offset of the block in the encoded data, and the block's maximum
/// weight, which upper-bounds its score contribution.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseSkipEntry {
    /// First document id in the block.
    pub first_doc: DocId,
    /// Last document id in the block.
    pub last_doc: DocId,
    /// Byte offset of the block within the encoded posting data.
    pub offset: u32,
    /// Maximum weight of any posting in the block.
    pub max_weight: f32,
}

impl SparseSkipEntry {
    pub fn new(first_doc: DocId, last_doc: DocId, offset: u32, max_weight: f32) -> Self {
        Self {
            first_doc,
            last_doc,
            offset,
            max_weight,
        }
    }

    /// Upper bound on the score contribution this block can make for a query
    /// term with the given weight: `query_weight * max_weight`.
    #[inline]
    pub fn block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.max_weight
    }

    pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc)?;
        writer.write_u32::<LittleEndian>(self.last_doc)?;
        writer.write_u32::<LittleEndian>(self.offset)?;
        writer.write_f32::<LittleEndian>(self.max_weight)?;
        Ok(())
    }

    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc = reader.read_u32::<LittleEndian>()?;
        let last_doc = reader.read_u32::<LittleEndian>()?;
        let offset = reader.read_u32::<LittleEndian>()?;
        let max_weight = reader.read_f32::<LittleEndian>()?;
        Ok(Self {
            first_doc,
            last_doc,
            offset,
            max_weight,
        })
    }
}

/// Skip list over the blocks of a posting list, tracking each block's entry
/// plus the global maximum weight across all blocks.
#[derive(Debug, Clone, Default)]
pub struct SparseSkipList {
    entries: Vec<SparseSkipEntry>,
    global_max_weight: f32,
}

impl SparseSkipList {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn push(&mut self, first_doc: DocId, last_doc: DocId, offset: u32, max_weight: f32) {
        self.global_max_weight = self.global_max_weight.max(max_weight);
        self.entries.push(SparseSkipEntry::new(
            first_doc, last_doc, offset, max_weight,
        ));
    }

    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    pub fn get(&self, index: usize) -> Option<&SparseSkipEntry> {
        self.entries.get(index)
    }

    /// Maximum weight across all blocks in the list.
    pub fn global_max_weight(&self) -> f32 {
        self.global_max_weight
    }

    /// Index of the first block whose `last_doc` is at least `target`, i.e.
    /// the first block that could contain `target`; `None` if no block can.
    pub fn find_block(&self, target: DocId) -> Option<usize> {
        self.entries.iter().position(|e| e.last_doc >= target)
    }

    pub fn iter(&self) -> impl Iterator<Item = &SparseSkipEntry> {
        self.entries.iter()
    }

    pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.entries.len() as u32)?;
        writer.write_f32::<LittleEndian>(self.global_max_weight)?;
        for entry in &self.entries {
            entry.write(writer)?;
        }
        Ok(())
    }

    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        let count = reader.read_u32::<LittleEndian>()? as usize;
        let global_max_weight = reader.read_f32::<LittleEndian>()?;
        let mut entries = Vec::with_capacity(count);
        for _ in 0..count {
            entries.push(SparseSkipEntry::read(reader)?);
        }
        Ok(Self {
            entries,
            global_max_weight,
        })
    }
}

/// A posting list for one sparse dimension: doc ids are stored as
/// vint-encoded deltas, followed by the weights in the configured
/// quantization format.
#[derive(Debug, Clone)]
pub struct SparsePostingList {
    /// Quantization format used for the stored weights.
    quantization: WeightQuantization,
    /// Scale factor used to dequantize integer-quantized weights.
    scale: f32,
    /// Minimum weight value (the quantization offset).
    min_val: f32,
    /// Number of postings in the list.
    doc_count: u32,
    /// Encoded doc-id deltas followed by the encoded weights.
    data: Vec<u8>,
}

impl SparsePostingList {
    /// Build a posting list from `(doc_id, weight)` pairs in ascending
    /// doc-id order, quantizing the weights with the given scheme.
    pub fn from_postings(
        postings: &[(DocId, f32)],
        quantization: WeightQuantization,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                quantization,
                scale: 1.0,
                min_val: 0.0,
                doc_count: 0,
                data: Vec::new(),
            });
        }

        let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
        let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        // Derive the quantization scale and offset from the weight range.
        let (scale, adjusted_min) = match quantization {
            WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
            WeightQuantization::UInt8 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 255.0, min_val)
                }
            }
            WeightQuantization::UInt4 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 15.0, min_val)
                }
            }
        };

        let mut data = Vec::new();

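        // Worked example: doc ids [0, 5, 10, 100] become the vint-encoded
        // deltas [0, 5, 5, 90]; the weights are appended after all of the
        // doc-id deltas, in the format selected by `quantization`.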
        let mut prev_doc_id = 0u32;
        for (doc_id, _) in postings {
            let delta = doc_id - prev_doc_id;
            write_vint(&mut data, delta as u64)?;
            prev_doc_id = *doc_id;
        }

        // Append the weights after the doc-id deltas, in the chosen format.
        match quantization {
            WeightQuantization::Float32 => {
                for (_, weight) in postings {
                    data.write_f32::<LittleEndian>(*weight)?;
                }
            }
            WeightQuantization::Float16 => {
                use half::slice::HalfFloatSliceExt;

                let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
                let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
                f16_slice.convert_from_f32_slice(&weights);
                for h in f16_slice {
                    data.write_u16::<LittleEndian>(h.to_bits())?;
                }
            }
            WeightQuantization::UInt8 => {
                for (_, weight) in postings {
                    let quantized = ((*weight - adjusted_min) / scale).round() as u8;
                    data.write_u8(quantized)?;
                }
            }
            WeightQuantization::UInt4 => {
                // Pack two 4-bit weights per byte: even index in the low
                // nibble, odd index in the high nibble.
                let mut i = 0;
                while i < postings.len() {
                    let q1 = ((postings[i].1 - adjusted_min) / scale).round() as u8 & 0x0F;
                    let q2 = if i + 1 < postings.len() {
                        ((postings[i + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
                    } else {
                        0
                    };
                    data.write_u8((q2 << 4) | q1)?;
                    i += 2;
                }
            }
        }

        Ok(Self {
            quantization,
            scale,
            min_val: adjusted_min,
            doc_count: postings.len() as u32,
            data,
        })
    }

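    /// Serialized layout (all little-endian): one quantization byte, then
    /// `scale` (f32), `min_val` (f32), `doc_count` (u32), the data length in
    /// bytes (u32), and finally the raw delta/weight bytes.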
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u8(self.quantization as u8)?;
        writer.write_f32::<LittleEndian>(self.scale)?;
        writer.write_f32::<LittleEndian>(self.min_val)?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
        writer.write_all(&self.data)?;
        Ok(())
    }

    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let quant_byte = reader.read_u8()?;
        let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
        })?;
        let scale = reader.read_f32::<LittleEndian>()?;
        let min_val = reader.read_f32::<LittleEndian>()?;
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let data_len = reader.read_u32::<LittleEndian>()? as usize;
        let mut data = vec![0u8; data_len];
        reader.read_exact(&mut data)?;

        Ok(Self {
            quantization,
            scale,
            min_val,
            doc_count,
            data,
        })
    }

    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    pub fn quantization(&self) -> WeightQuantization {
        self.quantization
    }

    /// Create an iterator over the postings in doc-id order.
    pub fn iterator(&self) -> SparsePostingIterator<'_> {
        SparsePostingIterator::new(self)
    }

    /// Decode the entire list into `(doc_id, weight)` pairs.
    pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
        let mut result = Vec::with_capacity(self.doc_count as usize);
        let mut iter = self.iterator();

        while !iter.exhausted {
            result.push((iter.doc_id, iter.weight));
            iter.advance();
        }

        Ok(result)
    }
}

/// Iterator over a [`SparsePostingList`], decoding doc-id deltas lazily and
/// looking up weights by position.
pub struct SparsePostingIterator<'a> {
    posting_list: &'a SparsePostingList,
    /// Byte offset of the next doc-id delta to decode.
    doc_id_offset: usize,
    /// Byte offset where the weight region begins.
    weight_offset: usize,
    /// Index of the current posting.
    index: usize,
    /// Current document id.
    doc_id: DocId,
    /// Current (dequantized) weight.
    weight: f32,
    /// True once the iterator has moved past the last posting.
    exhausted: bool,
}

impl<'a> SparsePostingIterator<'a> {
    fn new(posting_list: &'a SparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            doc_id_offset: 0,
            weight_offset: 0,
            index: 0,
            doc_id: 0,
            weight: 0.0,
            exhausted: posting_list.doc_count == 0,
        };

        if !iter.exhausted {
            iter.weight_offset = iter.calculate_weight_offset();
            iter.load_current();
        }

        iter
    }

    /// Scan past all doc-id vints to find the byte offset where the weight
    /// region begins.
    fn calculate_weight_offset(&self) -> usize {
        let mut offset = 0;
        let mut reader = &self.posting_list.data[..];

        for _ in 0..self.posting_list.doc_count {
            if read_vint(&mut reader).is_ok() {
                offset = self.posting_list.data.len() - reader.len();
            }
        }

        offset
    }

    fn load_current(&mut self) {
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return;
        }

        let mut reader = &self.posting_list.data[self.doc_id_offset..];
        if let Ok(delta) = read_vint(&mut reader) {
            self.doc_id = self.doc_id.wrapping_add(delta as u32);
            self.doc_id_offset = self.posting_list.data.len() - reader.len();
        }

        let weight_idx = self.index;
        let pl = self.posting_list;

        self.weight = match pl.quantization {
            WeightQuantization::Float32 => {
                let offset = self.weight_offset + weight_idx * 4;
                if offset + 4 <= pl.data.len() {
                    let bytes = &pl.data[offset..offset + 4];
                    f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
                } else {
                    0.0
                }
            }
            WeightQuantization::Float16 => {
                let offset = self.weight_offset + weight_idx * 2;
                if offset + 2 <= pl.data.len() {
                    let bits = u16::from_le_bytes([pl.data[offset], pl.data[offset + 1]]);
                    half::f16::from_bits(bits).to_f32()
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt8 => {
                let offset = self.weight_offset + weight_idx;
                if offset < pl.data.len() {
                    let quantized = pl.data[offset];
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt4 => {
                let byte_offset = self.weight_offset + weight_idx / 2;
                if byte_offset < pl.data.len() {
                    let byte = pl.data[byte_offset];
                    let quantized = if weight_idx.is_multiple_of(2) {
                        byte & 0x0F
                    } else {
                        (byte >> 4) & 0x0F
                    };
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
        };
    }

    /// Current document id, or `TERMINATED` once the iterator is exhausted.
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else {
            self.doc_id
        }
    }

    /// Weight of the current posting, or `0.0` once exhausted.
    pub fn weight(&self) -> f32 {
        if self.exhausted { 0.0 } else { self.weight }
    }

    /// Move to the next posting and return its doc id, or `TERMINATED` if
    /// there are no more postings.
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        self.index += 1;
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return super::TERMINATED;
        }

        self.load_current();
        self.doc_id
    }

    /// Advance to the first posting with doc id `>= target` and return its
    /// doc id, or `TERMINATED` if no such posting exists.
    pub fn seek(&mut self, target: DocId) -> DocId {
        while !self.exhausted && self.doc_id < target {
            self.advance();
        }
        self.doc()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_vector_dot_product() {
        let v1 = SparseVector::from_entries(&[0, 2, 5], &[1.0, 2.0, 3.0]);
        let v2 = SparseVector::from_entries(&[1, 2, 5], &[1.0, 4.0, 2.0]);

        // Overlapping indices are 2 and 5: 2.0 * 4.0 + 3.0 * 2.0 = 14.0.
        assert!((v1.dot(&v2) - 14.0).abs() < 1e-6);
    }

    #[test]
    fn test_sparse_posting_list_float32() {
        let postings = vec![(0, 1.5), (5, 2.3), (10, 0.8), (100, 3.15)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 4);

        let mut iter = pl.iterator();
        assert_eq!(iter.doc(), 0);
        assert!((iter.weight() - 1.5).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!((iter.weight() - 2.3).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 10);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!((iter.weight() - 3.15).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }

    #[test]
    fn test_sparse_posting_list_uint8() {
        let postings = vec![(0, 0.0), (5, 0.5), (10, 1.0)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let decoded = pl.decode_all().unwrap();
        assert_eq!(decoded.len(), 3);

        // Quantization preserves the ordering of the weights.
        assert!(decoded[0].1 < decoded[1].1);
        assert!(decoded[1].1 < decoded[2].1);
    }
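
    // Illustrative addition: a UInt4 round trip on the flat posting list.
    // With the weight range [0.0, 1.5] the quantization step is 1.5 / 15 = 0.1,
    // so dequantized weights should land within half a step of the originals.
    #[test]
    fn test_sparse_posting_list_uint4() {
        let postings = vec![(1, 0.0), (2, 0.4), (3, 0.8), (4, 1.5)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt4).unwrap();

        let decoded = pl.decode_all().unwrap();
        assert_eq!(decoded.len(), 4);
        for ((doc, weight), (orig_doc, orig_weight)) in decoded.iter().zip(postings.iter()) {
            assert_eq!(doc, orig_doc);
            assert!((weight - orig_weight).abs() <= 0.051);
        }
    }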

    #[test]
    fn test_block_sparse_posting_list() {
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, (i as f32) * 0.1)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 300);
        assert!(pl.num_blocks() >= 2);

        let mut iter = pl.iterator();
        for (expected_doc, _, expected_weight) in &postings {
            assert_eq!(iter.doc(), *expected_doc);
            assert!((iter.weight() - expected_weight).abs() < 1e-6);
            iter.advance();
        }
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }

    #[test]
    fn test_block_sparse_seek() {
        // Doc ids are multiples of 3: 0, 3, 6, ..., 1497.
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = pl.iterator();

        // Exact hit.
        assert_eq!(iter.seek(300), 300);
        // 301 is absent, so the seek lands on the next doc id.
        assert_eq!(iter.seek(301), 303);
        // Past the last doc id.
        assert_eq!(iter.seek(2000), super::super::TERMINATED);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let postings: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (10, 0, 2.0), (100, 0, 3.0)];

        for quant in [
            WeightQuantization::Float32,
            WeightQuantization::Float16,
            WeightQuantization::UInt8,
        ] {
            let pl = BlockSparsePostingList::from_postings(&postings, quant).unwrap();

            let mut buffer = Vec::new();
            pl.serialize(&mut buffer).unwrap();

            let pl2 = BlockSparsePostingList::deserialize(&mut &buffer[..]).unwrap();

            assert_eq!(pl.doc_count(), pl2.doc_count());

            let mut iter1 = pl.iterator();
            let mut iter2 = pl2.iterator();

            while iter1.doc() != super::super::TERMINATED {
                assert_eq!(iter1.doc(), iter2.doc());
                assert!((iter1.weight() - iter2.weight()).abs() < 0.1);
                iter1.advance();
                iter2.advance();
            }
        }
    }
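
    // Illustrative addition: a serialize/deserialize round trip for the flat
    // SparsePostingList (the test above covers the block-based variant).
    // Assumes WeightQuantization::from_u8 round-trips the byte written by
    // serialize, which deserialize already relies on.
    #[test]
    fn test_sparse_posting_list_serialization_roundtrip() {
        let postings = vec![(0, 1.5), (3, 2.5), (7, 0.25)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut buffer = Vec::new();
        pl.serialize(&mut buffer).unwrap();

        let pl2 = SparsePostingList::deserialize(&mut &buffer[..]).unwrap();
        assert_eq!(pl2.doc_count(), 3);
        assert_eq!(pl2.quantization(), WeightQuantization::Float32);

        let decoded = pl2.decode_all().unwrap();
        for ((doc, weight), (orig_doc, orig_weight)) in decoded.iter().zip(postings.iter()) {
            assert_eq!(doc, orig_doc);
            assert!((weight - orig_weight).abs() < 1e-6);
        }
    }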

    #[test]
    fn test_concatenate() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 1, 2.0)];
        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 3.0), (10, 1, 4.0)];

        let pl1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        let pl2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        // Concatenate by re-encoding, with the second list's doc ids shifted by 100.
        let mut all: Vec<(DocId, u16, f32)> = pl1.decode_all();
        for (doc_id, ord, w) in pl2.decode_all() {
            all.push((doc_id + 100, ord, w));
        }
        let merged =
            BlockSparsePostingList::from_postings(&all, WeightQuantization::Float32).unwrap();

        assert_eq!(merged.doc_count(), 4);

        let decoded = merged.decode_all();
        assert_eq!(decoded[0], (0, 0, 1.0));
        assert_eq!(decoded[1], (5, 1, 2.0));
        assert_eq!(decoded[2], (100, 0, 3.0));
        assert_eq!(decoded[3], (110, 1, 4.0));
    }

    #[test]
    fn test_sparse_vector_config() {
        let default = SparseVectorConfig::default();
        assert_eq!(default.index_size, IndexSize::U32);
        assert_eq!(default.weight_quantization, WeightQuantization::Float32);
        assert_eq!(default.bytes_per_entry(), 8.0);

        let splade = SparseVectorConfig::splade();
        assert_eq!(splade.index_size, IndexSize::U16);
        assert_eq!(splade.weight_quantization, WeightQuantization::UInt8);
        assert_eq!(splade.bytes_per_entry(), 3.0);

        let compact = SparseVectorConfig::compact();
        assert_eq!(compact.index_size, IndexSize::U16);
        assert_eq!(compact.weight_quantization, WeightQuantization::UInt4);
        assert_eq!(compact.bytes_per_entry(), 2.5);

        let byte = splade.to_byte();
        let restored = SparseVectorConfig::from_byte(byte).unwrap();
        assert_eq!(restored, splade);
    }

    #[test]
    fn test_index_size() {
        assert_eq!(IndexSize::U16.bytes(), 2);
        assert_eq!(IndexSize::U32.bytes(), 4);
        assert_eq!(IndexSize::U16.max_value(), 65535);
        assert_eq!(IndexSize::U32.max_value(), u32::MAX);
    }

    #[test]
    fn test_block_max_weight() {
        let postings: Vec<(DocId, u16, f32)> = (0..300)
            .map(|i| (i as DocId, 0, (i as f32) * 0.1))
            .collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert!((pl.global_max_weight() - 29.9).abs() < 0.01);
        assert!(pl.num_blocks() >= 3);

        let block0_max = pl.block_max_weight(0).unwrap();
        assert!((block0_max - 12.7).abs() < 0.01);

        let block1_max = pl.block_max_weight(1).unwrap();
        assert!((block1_max - 25.5).abs() < 0.01);

        let block2_max = pl.block_max_weight(2).unwrap();
        assert!((block2_max - 29.9).abs() < 0.01);

        let query_weight = 2.0;
        let mut iter = pl.iterator();
        assert!((iter.current_block_max_weight() - 12.7).abs() < 0.01);
        assert!((iter.current_block_max_contribution(query_weight) - 25.4).abs() < 0.1);

        iter.seek(128);
        assert!((iter.current_block_max_weight() - 25.5).abs() < 0.01);
    }
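
    // Illustrative sketch of block-max pruning on top of the per-block
    // maxima checked above: with a query weight of 2.0 and an arbitrary
    // score threshold of 30.0, block 0 can contribute at most
    // 2.0 * 12.7 = 25.4 and could be skipped, while blocks 1 and 2 still
    // have to be visited.
    #[test]
    fn test_block_max_pruning_sketch() {
        let postings: Vec<(DocId, u16, f32)> = (0..300)
            .map(|i| (i as DocId, 0, (i as f32) * 0.1))
            .collect();
        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let query_weight = 2.0_f32;
        let threshold = 30.0_f32;
        assert!(pl.block_max_weight(0).unwrap() * query_weight < threshold);
        assert!(pl.block_max_weight(1).unwrap() * query_weight >= threshold);
        assert!(pl.block_max_weight(2).unwrap() * query_weight >= threshold);
    }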

    #[test]
    fn test_sparse_skip_list_serialization() {
        let mut skip_list = SparseSkipList::new();
        skip_list.push(0, 127, 0, 12.7);
        skip_list.push(128, 255, 100, 25.5);
        skip_list.push(256, 299, 200, 29.9);

        assert_eq!(skip_list.len(), 3);
        assert!((skip_list.global_max_weight() - 29.9).abs() < 0.01);

        let mut buffer = Vec::new();
        skip_list.write(&mut buffer).unwrap();

        let restored = SparseSkipList::read(&mut buffer.as_slice()).unwrap();

        assert_eq!(restored.len(), 3);
        assert!((restored.global_max_weight() - 29.9).abs() < 0.01);

        let e0 = restored.get(0).unwrap();
        assert_eq!(e0.first_doc, 0);
        assert_eq!(e0.last_doc, 127);
        assert!((e0.max_weight - 12.7).abs() < 0.01);

        let e1 = restored.get(1).unwrap();
        assert_eq!(e1.first_doc, 128);
        assert!((e1.max_weight - 25.5).abs() < 0.01);
    }
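
    // Illustrative addition: find_block returns the first block whose doc-id
    // range can contain the target, and block_max_contribution gives that
    // block's best possible score under a given query weight.
    #[test]
    fn test_sparse_skip_list_find_block() {
        let mut skip_list = SparseSkipList::new();
        skip_list.push(0, 127, 0, 12.7);
        skip_list.push(128, 255, 100, 25.5);
        skip_list.push(256, 299, 200, 29.9);

        assert_eq!(skip_list.find_block(0), Some(0));
        assert_eq!(skip_list.find_block(130), Some(1));
        assert_eq!(skip_list.find_block(256), Some(2));
        assert_eq!(skip_list.find_block(300), None);

        let entry = skip_list.get(1).unwrap();
        assert!((entry.block_max_contribution(2.0) - 51.0).abs() < 1e-6);
    }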
}