1mod block;
27mod config;
28
29pub use block::{BlockSparsePostingIterator, BlockSparsePostingList, SparseBlock};
30pub use config::{
31 IndexSize, QueryWeighting, SparseEntry, SparseQueryConfig, SparseVector, SparseVectorConfig,
32 WeightQuantization,
33};
34
35use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
36use std::io::{self, Read, Write};
37
38use super::posting_common::{read_vint, write_vint};
39use crate::DocId;
40
/// A single (document, weight) pair in a sparse posting list.
#[derive(Debug, Clone, Copy)]
pub struct SparsePosting {
    /// Identifier of the document containing this posting.
    pub doc_id: DocId,
    /// Weight of the sparse dimension in this document.
    pub weight: f32,
}
47
/// Number of postings per block in block-compressed sparse posting lists.
pub const SPARSE_BLOCK_SIZE: usize = 128;
50
/// Skip-list entry describing one block of an encoded posting list.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseSkipEntry {
    /// First doc id contained in the block.
    pub first_doc: DocId,
    /// Last doc id contained in the block.
    pub last_doc: DocId,
    /// Offset of the block's encoded data (units set by the index writer).
    pub offset: u32,
    /// Length of the block's encoded data (units set by the index writer).
    pub length: u32,
    /// Maximum posting weight within the block (used for block-max pruning).
    pub max_weight: f32,
}
68
69impl SparseSkipEntry {
70 pub const SIZE: usize = 20; pub fn new(
74 first_doc: DocId,
75 last_doc: DocId,
76 offset: u32,
77 length: u32,
78 max_weight: f32,
79 ) -> Self {
80 Self {
81 first_doc,
82 last_doc,
83 offset,
84 length,
85 max_weight,
86 }
87 }
88
89 #[inline]
94 pub fn block_max_contribution(&self, query_weight: f32) -> f32 {
95 query_weight * self.max_weight
96 }
97
98 pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
100 writer.write_u32::<LittleEndian>(self.first_doc)?;
101 writer.write_u32::<LittleEndian>(self.last_doc)?;
102 writer.write_u32::<LittleEndian>(self.offset)?;
103 writer.write_u32::<LittleEndian>(self.length)?;
104 writer.write_f32::<LittleEndian>(self.max_weight)?;
105 Ok(())
106 }
107
108 pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
110 let first_doc = reader.read_u32::<LittleEndian>()?;
111 let last_doc = reader.read_u32::<LittleEndian>()?;
112 let offset = reader.read_u32::<LittleEndian>()?;
113 let length = reader.read_u32::<LittleEndian>()?;
114 let max_weight = reader.read_f32::<LittleEndian>()?;
115 Ok(Self {
116 first_doc,
117 last_doc,
118 offset,
119 length,
120 max_weight,
121 })
122 }
123}
124
/// Skip list for one posting list: one `SparseSkipEntry` per block plus
/// the maximum weight observed across all blocks.
#[derive(Debug, Clone, Default)]
pub struct SparseSkipList {
    // One entry per block, in push order.
    entries: Vec<SparseSkipEntry>,
    // Largest `max_weight` over all pushed entries.
    global_max_weight: f32,
}
132
133impl SparseSkipList {
134 pub fn new() -> Self {
135 Self::default()
136 }
137
138 pub fn push(
140 &mut self,
141 first_doc: DocId,
142 last_doc: DocId,
143 offset: u32,
144 length: u32,
145 max_weight: f32,
146 ) {
147 self.global_max_weight = self.global_max_weight.max(max_weight);
148 self.entries.push(SparseSkipEntry::new(
149 first_doc, last_doc, offset, length, max_weight,
150 ));
151 }
152
153 pub fn len(&self) -> usize {
155 self.entries.len()
156 }
157
158 pub fn is_empty(&self) -> bool {
159 self.entries.is_empty()
160 }
161
162 pub fn get(&self, index: usize) -> Option<&SparseSkipEntry> {
164 self.entries.get(index)
165 }
166
167 pub fn global_max_weight(&self) -> f32 {
169 self.global_max_weight
170 }
171
172 pub fn find_block(&self, target: DocId) -> Option<usize> {
174 self.entries.iter().position(|e| e.last_doc >= target)
175 }
176
177 pub fn iter(&self) -> impl Iterator<Item = &SparseSkipEntry> {
179 self.entries.iter()
180 }
181
182 pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
184 writer.write_u32::<LittleEndian>(self.entries.len() as u32)?;
185 writer.write_f32::<LittleEndian>(self.global_max_weight)?;
186 for entry in &self.entries {
187 entry.write(writer)?;
188 }
189 Ok(())
190 }
191
192 pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
194 let count = reader.read_u32::<LittleEndian>()? as usize;
195 let global_max_weight = reader.read_f32::<LittleEndian>()?;
196 let mut entries = Vec::with_capacity(count);
197 for _ in 0..count {
198 entries.push(SparseSkipEntry::read(reader)?);
199 }
200 Ok(Self {
201 entries,
202 global_max_weight,
203 })
204 }
205}
206
/// Delta- and quantization-encoded posting list held in a single buffer:
/// vint-encoded doc-id gaps followed by a weights section whose layout
/// depends on the quantization scheme.
#[derive(Debug, Clone)]
pub struct SparsePostingList {
    // Weight encoding scheme used for the weights section.
    quantization: WeightQuantization,
    // Dequantization scale (weight = q * scale + min_val for integer schemes;
    // 1.0 for float schemes).
    scale: f32,
    // Dequantization offset (minimum weight for integer schemes; 0.0 for
    // float schemes).
    min_val: f32,
    // Number of postings encoded in `data`.
    doc_count: u32,
    // Encoded payload: doc-id deltas (vint), then weights.
    data: Vec<u8>,
}
225
impl SparsePostingList {
    /// Encodes `(doc_id, weight)` pairs into a single buffer: vint doc-id
    /// deltas first, then the weights in the chosen quantization.
    ///
    /// NOTE(review): assumes `postings` is sorted by ascending doc id —
    /// the delta subtraction below underflows otherwise. Not verified
    /// here; confirm at call sites.
    pub fn from_postings(
        postings: &[(DocId, f32)],
        quantization: WeightQuantization,
    ) -> io::Result<Self> {
        // Empty input: neutral dequantization parameters, no payload.
        if postings.is_empty() {
            return Ok(Self {
                quantization,
                scale: 1.0,
                min_val: 0.0,
                doc_count: 0,
                data: Vec::new(),
            });
        }

        let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
        let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        // Integer schemes linearly map [min_val, max_val] onto the
        // available levels (255 for u8, 15 for u4); float schemes store
        // weights directly, so scale/offset stay neutral.
        let (scale, adjusted_min) = match quantization {
            WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
            WeightQuantization::UInt8 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    // All weights (nearly) identical: avoid dividing by ~0.
                    (1.0, min_val)
                } else {
                    (range / 255.0, min_val)
                }
            }
            WeightQuantization::UInt4 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 15.0, min_val)
                }
            }
        };

        let mut data = Vec::new();

        // Doc-id section: gap to the previous doc id, vint-encoded.
        let mut prev_doc_id = 0u32;
        for (doc_id, _) in postings {
            let delta = doc_id - prev_doc_id;
            write_vint(&mut data, delta as u64)?;
            prev_doc_id = *doc_id;
        }

        // Weights section, laid out according to `quantization`.
        match quantization {
            WeightQuantization::Float32 => {
                for (_, weight) in postings {
                    data.write_f32::<LittleEndian>(*weight)?;
                }
            }
            WeightQuantization::Float16 => {
                use half::slice::HalfFloatSliceExt;

                // Bulk-convert to f16, then store the raw bit patterns.
                let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
                let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
                f16_slice.convert_from_f32_slice(&weights);
                for h in f16_slice {
                    data.write_u16::<LittleEndian>(h.to_bits())?;
                }
            }
            WeightQuantization::UInt8 => {
                for (_, weight) in postings {
                    let quantized = ((*weight - adjusted_min) / scale).round() as u8;
                    data.write_u8(quantized)?;
                }
            }
            WeightQuantization::UInt4 => {
                // Pack two 4-bit weights per byte: even index in the low
                // nibble, odd index in the high nibble (decoder matches).
                let mut i = 0;
                while i < postings.len() {
                    let q1 = ((postings[i].1 - adjusted_min) / scale).round() as u8 & 0x0F;
                    let q2 = if i + 1 < postings.len() {
                        ((postings[i + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
                    } else {
                        0
                    };
                    data.write_u8((q2 << 4) | q1)?;
                    i += 2;
                }
            }
        }

        Ok(Self {
            quantization,
            scale,
            min_val: adjusted_min,
            doc_count: postings.len() as u32,
            data,
        })
    }

    /// Writes the header (quantization tag, scale, min, doc count,
    /// payload length) followed by the raw encoded payload.
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u8(self.quantization as u8)?;
        writer.write_f32::<LittleEndian>(self.scale)?;
        writer.write_f32::<LittleEndian>(self.min_val)?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
        writer.write_all(&self.data)?;
        Ok(())
    }

    /// Reads a posting list previously written by [`Self::serialize`].
    ///
    /// Fails with `InvalidData` when the quantization tag is unknown.
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let quant_byte = reader.read_u8()?;
        let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
        })?;
        let scale = reader.read_f32::<LittleEndian>()?;
        let min_val = reader.read_f32::<LittleEndian>()?;
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let data_len = reader.read_u32::<LittleEndian>()? as usize;
        let mut data = vec![0u8; data_len];
        reader.read_exact(&mut data)?;

        Ok(Self {
            quantization,
            scale,
            min_val,
            doc_count,
            data,
        })
    }

    /// Number of postings in the list.
    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    /// Weight encoding scheme used by this list.
    pub fn quantization(&self) -> WeightQuantization {
        self.quantization
    }

    /// Returns a streaming iterator positioned on the first posting.
    pub fn iterator(&self) -> SparsePostingIterator<'_> {
        SparsePostingIterator::new(self)
    }

    /// Decodes the whole list into `(doc_id, weight)` pairs. Weights are
    /// dequantized, so lossy schemes return approximations.
    pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
        let mut result = Vec::with_capacity(self.doc_count as usize);
        let mut iter = self.iterator();

        // Reads the iterator's fields directly (same-module access).
        while !iter.exhausted {
            result.push((iter.doc_id, iter.weight));
            iter.advance();
        }

        Ok(result)
    }
}
386
/// Streaming decoder over a [`SparsePostingList`].
pub struct SparsePostingIterator<'a> {
    // List being iterated.
    posting_list: &'a SparsePostingList,
    // Byte offset of the next doc-id delta to decode.
    doc_id_offset: usize,
    // Byte offset where the weights section begins.
    weight_offset: usize,
    // Index of the current posting within the list.
    index: usize,
    // Current decoded doc id.
    doc_id: DocId,
    // Current dequantized weight.
    weight: f32,
    // True once iteration has moved past the last posting.
    exhausted: bool,
}
403
impl<'a> SparsePostingIterator<'a> {
    /// Positions a new iterator on the first posting (if any).
    fn new(posting_list: &'a SparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            doc_id_offset: 0,
            weight_offset: 0,
            index: 0,
            doc_id: 0,
            weight: 0.0,
            exhausted: posting_list.doc_count == 0,
        };

        if !iter.exhausted {
            // The weights section starts right after the doc-id deltas.
            iter.weight_offset = iter.calculate_weight_offset();
            iter.load_current();
        }

        iter
    }

    /// Scans past all `doc_count` vint deltas to find the byte offset at
    /// which the weights section begins. O(doc_count), done once in `new`.
    fn calculate_weight_offset(&self) -> usize {
        let mut offset = 0;
        let mut reader = &self.posting_list.data[..];

        for _ in 0..self.posting_list.doc_count {
            // NOTE(review): a vint decode error is silently ignored here,
            // leaving `offset` at the last successful position.
            if read_vint(&mut reader).is_ok() {
                offset = self.posting_list.data.len() - reader.len();
            }
        }

        offset
    }

    /// Decodes the doc id and weight at `self.index` into the cached
    /// `doc_id` / `weight` fields, or marks the iterator exhausted.
    fn load_current(&mut self) {
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return;
        }

        // Doc ids are stored as vint gaps; accumulate onto the previously
        // decoded id. A decode failure silently keeps the old doc id.
        let mut reader = &self.posting_list.data[self.doc_id_offset..];
        if let Ok(delta) = read_vint(&mut reader) {
            self.doc_id = self.doc_id.wrapping_add(delta as u32);
            self.doc_id_offset = self.posting_list.data.len() - reader.len();
        }

        let weight_idx = self.index;
        let pl = self.posting_list;

        // Decode the weight per the list's quantization scheme.
        // Out-of-range offsets yield 0.0 instead of panicking.
        self.weight = match pl.quantization {
            WeightQuantization::Float32 => {
                // 4 bytes per weight, little-endian f32.
                let offset = self.weight_offset + weight_idx * 4;
                if offset + 4 <= pl.data.len() {
                    let bytes = &pl.data[offset..offset + 4];
                    f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
                } else {
                    0.0
                }
            }
            WeightQuantization::Float16 => {
                // 2 bytes per weight, IEEE half-precision bit pattern.
                let offset = self.weight_offset + weight_idx * 2;
                if offset + 2 <= pl.data.len() {
                    let bits = u16::from_le_bytes([pl.data[offset], pl.data[offset + 1]]);
                    half::f16::from_bits(bits).to_f32()
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt8 => {
                // 1 byte per weight; dequantize via scale/min_val.
                let offset = self.weight_offset + weight_idx;
                if offset < pl.data.len() {
                    let quantized = pl.data[offset];
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt4 => {
                // Two weights per byte: even index -> low nibble, odd
                // index -> high nibble (matches the encoder's packing).
                let byte_offset = self.weight_offset + weight_idx / 2;
                if byte_offset < pl.data.len() {
                    let byte = pl.data[byte_offset];
                    let quantized = if weight_idx.is_multiple_of(2) {
                        byte & 0x0F
                    } else {
                        (byte >> 4) & 0x0F
                    };
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
        };
    }

    /// Current doc id, or `TERMINATED` once the iterator is exhausted.
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else {
            self.doc_id
        }
    }

    /// Current dequantized weight; 0.0 once exhausted.
    pub fn weight(&self) -> f32 {
        if self.exhausted { 0.0 } else { self.weight }
    }

    /// Moves to the next posting and returns its doc id, or `TERMINATED`
    /// at the end of the list.
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        self.index += 1;
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return super::TERMINATED;
        }

        self.load_current();
        self.doc_id
    }

    /// Advances until `doc() >= target` and returns the landing doc id.
    /// Linear scan — this non-block list has no skip structure.
    pub fn seek(&mut self, target: DocId) -> DocId {
        while !self.exhausted && self.doc_id < target {
            self.advance();
        }
        self.doc()
    }
}
539
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_vector_dot_product() {
        let a = SparseVector::from_entries(&[0, 2, 5], &[1.0, 2.0, 3.0]);
        let b = SparseVector::from_entries(&[1, 2, 5], &[1.0, 4.0, 2.0]);

        // Overlapping dims: 2 (2.0 * 4.0) and 5 (3.0 * 2.0) -> 14.0.
        assert!((a.dot(&b) - 14.0).abs() < 1e-6);
    }

    #[test]
    fn test_sparse_posting_list_float32() {
        let postings = vec![(0, 1.5), (5, 2.3), (10, 0.8), (100, 3.15)];
        let list =
            SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 4);

        // Float32 storage round-trips doc ids and weights exactly.
        let mut it = list.iterator();
        for &(doc, weight) in &postings {
            assert_eq!(it.doc(), doc);
            assert!((it.weight() - weight).abs() < 1e-6);
            it.advance();
        }
        assert_eq!(it.doc(), super::super::TERMINATED);
    }

    #[test]
    fn test_sparse_posting_list_uint8() {
        let postings = vec![(0, 0.0), (5, 0.5), (10, 1.0)];
        let list =
            SparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let decoded = list.decode_all().unwrap();
        assert_eq!(decoded.len(), 3);

        // Quantization is lossy, but the ordering of weights must survive.
        assert!(decoded[0].1 < decoded[1].1);
        assert!(decoded[1].1 < decoded[2].1);
    }

    #[test]
    fn test_block_sparse_posting_list() {
        // 300 postings forces multiple blocks.
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, (i as f32) * 0.1)).collect();

        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(list.doc_count(), 300);
        assert!(list.num_blocks() >= 2);

        // Walk the iterator in lockstep with the input.
        let mut it = list.iterator();
        for &(doc, _, weight) in &postings {
            assert_eq!(it.doc(), doc);
            assert!((it.weight() - weight).abs() < 1e-6);
            it.advance();
        }
        assert_eq!(it.doc(), super::super::TERMINATED);
    }

    #[test]
    fn test_block_sparse_seek() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();

        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut it = list.iterator();

        // Exact hit.
        assert_eq!(it.seek(300), 300);
        // Between stored docs: lands on the next one.
        assert_eq!(it.seek(301), 303);
        // Past the end.
        assert_eq!(it.seek(2000), super::super::TERMINATED);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let postings: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (10, 0, 2.0), (100, 0, 3.0)];

        let schemes = [
            WeightQuantization::Float32,
            WeightQuantization::Float16,
            WeightQuantization::UInt8,
        ];
        for quant in schemes {
            let original = BlockSparsePostingList::from_postings(&postings, quant).unwrap();

            let mut bytes = Vec::new();
            original.serialize(&mut bytes).unwrap();
            let restored = BlockSparsePostingList::deserialize(&mut &bytes[..]).unwrap();

            assert_eq!(original.doc_count(), restored.doc_count());

            // Walk both lists in lockstep; lossy schemes may perturb the
            // weights slightly, hence the loose tolerance.
            let mut left = original.iterator();
            let mut right = restored.iterator();
            while left.doc() != super::super::TERMINATED {
                assert_eq!(left.doc(), right.doc());
                assert!((left.weight() - right.weight()).abs() < 0.1);
                left.advance();
                right.advance();
            }
        }
    }

    #[test]
    fn test_concatenate() {
        let first: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 1, 2.0)];
        let second: Vec<(DocId, u16, f32)> = vec![(0, 0, 3.0), (10, 1, 4.0)];

        let list1 =
            BlockSparsePostingList::from_postings(&first, WeightQuantization::Float32).unwrap();
        let list2 =
            BlockSparsePostingList::from_postings(&second, WeightQuantization::Float32).unwrap();

        // Emulate concatenation by rebasing the second list's doc ids.
        let mut combined: Vec<(DocId, u16, f32)> = list1.decode_all();
        for (doc_id, ord, weight) in list2.decode_all() {
            combined.push((doc_id + 100, ord, weight));
        }
        let merged =
            BlockSparsePostingList::from_postings(&combined, WeightQuantization::Float32).unwrap();

        assert_eq!(merged.doc_count(), 4);

        let decoded = merged.decode_all();
        assert_eq!(decoded[0], (0, 0, 1.0));
        assert_eq!(decoded[1], (5, 1, 2.0));
        assert_eq!(decoded[2], (100, 0, 3.0));
        assert_eq!(decoded[3], (110, 1, 4.0));
    }

    #[test]
    fn test_sparse_vector_config() {
        let default = SparseVectorConfig::default();
        assert_eq!(default.index_size, IndexSize::U32);
        assert_eq!(default.weight_quantization, WeightQuantization::Float32);
        assert_eq!(default.bytes_per_entry(), 8.0);

        let splade = SparseVectorConfig::splade();
        assert_eq!(splade.index_size, IndexSize::U16);
        assert_eq!(splade.weight_quantization, WeightQuantization::UInt8);
        assert_eq!(splade.bytes_per_entry(), 3.0);

        let compact = SparseVectorConfig::compact();
        assert_eq!(compact.index_size, IndexSize::U16);
        assert_eq!(compact.weight_quantization, WeightQuantization::UInt4);
        assert_eq!(compact.bytes_per_entry(), 2.5);

        // Config must survive a byte round-trip.
        let encoded = splade.to_byte();
        let decoded = SparseVectorConfig::from_byte(encoded).unwrap();
        assert_eq!(decoded, splade);
    }

    #[test]
    fn test_index_size() {
        assert_eq!(IndexSize::U16.bytes(), 2);
        assert_eq!(IndexSize::U16.max_value(), 65535);
        assert_eq!(IndexSize::U32.bytes(), 4);
        assert_eq!(IndexSize::U32.max_value(), u32::MAX);
    }

    #[test]
    fn test_block_max_weight() {
        let postings: Vec<(DocId, u16, f32)> = (0..300)
            .map(|i| (i as DocId, 0, (i as f32) * 0.1))
            .collect();

        let list =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        // Global max is the last posting's weight: 299 * 0.1.
        assert!((list.global_max_weight() - 29.9).abs() < 0.01);
        assert!(list.num_blocks() >= 3);

        // Per-block maxima: the last weight within each 128-entry block.
        for (block, expected) in [(0usize, 12.7f32), (1, 25.5), (2, 29.9)] {
            let block_max = list.block_max_weight(block).unwrap();
            assert!((block_max - expected).abs() < 0.01);
        }

        let query_weight = 2.0;
        let mut it = list.iterator();
        assert!((it.current_block_max_weight() - 12.7).abs() < 0.01);
        assert!((it.current_block_max_contribution(query_weight) - 25.4).abs() < 0.1);

        it.seek(128);
        assert!((it.current_block_max_weight() - 25.5).abs() < 0.01);
    }

    #[test]
    fn test_sparse_skip_list_serialization() {
        let mut skip_list = SparseSkipList::new();
        skip_list.push(0, 127, 0, 50, 12.7);
        skip_list.push(128, 255, 100, 60, 25.5);
        skip_list.push(256, 299, 200, 40, 29.9);

        assert_eq!(skip_list.len(), 3);
        assert!((skip_list.global_max_weight() - 29.9).abs() < 0.01);

        // Byte round-trip through write/read.
        let mut bytes = Vec::new();
        skip_list.write(&mut bytes).unwrap();
        let restored = SparseSkipList::read(&mut bytes.as_slice()).unwrap();

        assert_eq!(restored.len(), 3);
        assert!((restored.global_max_weight() - 29.9).abs() < 0.01);

        let first = restored.get(0).unwrap();
        assert_eq!(first.first_doc, 0);
        assert_eq!(first.last_doc, 127);
        assert!((first.max_weight - 12.7).abs() < 0.01);

        let second = restored.get(1).unwrap();
        assert_eq!(second.first_doc, 128);
        assert!((second.max_weight - 25.5).abs() < 0.01);
    }
}