// hermes_core/structures/postings/sparse/mod.rs
//! Sparse vector posting list with quantized weights
//!
//! Sparse vectors are stored as inverted index posting lists where:
//! - Each dimension ID is a "term"
//! - Each document has a weight for that dimension
//!
//! ## Configurable Components
//!
//! **Index (term/dimension ID) size:**
//! - `IndexSize::U16`: 16-bit indices (0-65535), ideal for SPLADE (~30K vocab)
//! - `IndexSize::U32`: 32-bit indices (0-4B), for large vocabularies
//!
//! **Weight quantization:**
//! - `Float32`: Full precision (4 bytes per weight)
//! - `Float16`: Half precision (2 bytes per weight)
//! - `UInt8`: 8-bit quantization with scale factor (1 byte per weight)
//! - `UInt4`: 4-bit quantization with scale factor (0.5 bytes per weight)
//!
//! ## Block Format (v2)
//!
//! The block-based format separates data into 3 sub-blocks per 128-entry block:
//! - **Doc IDs**: Delta-encoded, bit-packed (SIMD-friendly)
//! - **Ordinals**: Bit-packed small integers (lazy decode, only for results)
//! - **Weights**: Quantized (f32/f16/u8/u4)

26mod block;
27mod config;
28
29pub use block::{BlockSparsePostingIterator, BlockSparsePostingList, SparseBlock};
30pub use config::{
31    IndexSize, QueryWeighting, SparseEntry, SparseQueryConfig, SparseVector, SparseVectorConfig,
32    WeightQuantization,
33};
34
35use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
36use std::io::{self, Read, Write};
37
38use super::posting_common::{read_vint, write_vint};
39use crate::DocId;
40
/// A sparse posting entry: doc_id with quantized weight
///
/// Weights are stored quantized on disk but always surfaced as `f32` in
/// memory; dequantization happens when the posting is decoded.
#[derive(Debug, Clone, Copy)]
pub struct SparsePosting {
    /// Document this weight belongs to
    pub doc_id: DocId,
    /// Dequantized weight for the dimension
    pub weight: f32,
}
47
/// Block size for sparse posting lists (matches OptP4D for SIMD alignment)
///
/// 128 entries per block keeps doc-id delta blocks a fixed, SIMD-friendly
/// size and matches the block size used by the dense postings encoder.
pub const SPARSE_BLOCK_SIZE: usize = 128;
50
/// Skip entry for sparse posting lists with block-max support
///
/// Extends the basic skip entry with `max_weight` for Block-Max WAND optimization.
/// Used for lazy block loading - only skip list is loaded, blocks loaded on-demand.
///
/// Serialized as five little-endian 4-byte fields in declaration order
/// (see [`SparseSkipEntry::SIZE`]).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseSkipEntry {
    /// First doc_id in the block (absolute)
    pub first_doc: DocId,
    /// Last doc_id in the block
    pub last_doc: DocId,
    /// Byte offset to block data (relative to data section start)
    pub offset: u32,
    /// Byte length of block data
    pub length: u32,
    /// Maximum weight in this block (for Block-Max optimization)
    pub max_weight: f32,
}
68
69impl SparseSkipEntry {
70    /// Size in bytes when serialized
71    pub const SIZE: usize = 20; // 4 + 4 + 4 + 4 + 4
72
73    pub fn new(
74        first_doc: DocId,
75        last_doc: DocId,
76        offset: u32,
77        length: u32,
78        max_weight: f32,
79    ) -> Self {
80        Self {
81            first_doc,
82            last_doc,
83            offset,
84            length,
85            max_weight,
86        }
87    }
88
89    /// Compute the maximum possible contribution of this block to a dot product
90    ///
91    /// For a query dimension with weight `query_weight`, the maximum contribution
92    /// from this block is `query_weight * max_weight`.
93    #[inline]
94    pub fn block_max_contribution(&self, query_weight: f32) -> f32 {
95        query_weight * self.max_weight
96    }
97
98    /// Write skip entry to writer
99    pub fn write<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
100        writer.write_u32::<LittleEndian>(self.first_doc)?;
101        writer.write_u32::<LittleEndian>(self.last_doc)?;
102        writer.write_u32::<LittleEndian>(self.offset)?;
103        writer.write_u32::<LittleEndian>(self.length)?;
104        writer.write_f32::<LittleEndian>(self.max_weight)?;
105        Ok(())
106    }
107
108    /// Read skip entry from reader
109    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
110        let first_doc = reader.read_u32::<LittleEndian>()?;
111        let last_doc = reader.read_u32::<LittleEndian>()?;
112        let offset = reader.read_u32::<LittleEndian>()?;
113        let length = reader.read_u32::<LittleEndian>()?;
114        let max_weight = reader.read_f32::<LittleEndian>()?;
115        Ok(Self {
116            first_doc,
117            last_doc,
118            offset,
119            length,
120            max_weight,
121        })
122    }
123}
124
/// Skip list for sparse posting lists with block-max support
///
/// One [`SparseSkipEntry`] per block, in block order; entries are assumed to
/// cover ascending doc-id ranges (required by the binary search in `find_block`).
#[derive(Debug, Clone, Default)]
pub struct SparseSkipList {
    // Per-block entries, in block order.
    entries: Vec<SparseSkipEntry>,
    /// Global maximum weight across all blocks (for MaxScore pruning)
    global_max_weight: f32,
}
132
133impl SparseSkipList {
134    pub fn new() -> Self {
135        Self::default()
136    }
137
138    /// Add a skip entry
139    pub fn push(
140        &mut self,
141        first_doc: DocId,
142        last_doc: DocId,
143        offset: u32,
144        length: u32,
145        max_weight: f32,
146    ) {
147        self.global_max_weight = self.global_max_weight.max(max_weight);
148        self.entries.push(SparseSkipEntry::new(
149            first_doc, last_doc, offset, length, max_weight,
150        ));
151    }
152
153    /// Number of blocks
154    pub fn len(&self) -> usize {
155        self.entries.len()
156    }
157
158    pub fn is_empty(&self) -> bool {
159        self.entries.is_empty()
160    }
161
162    /// Get entry by index
163    pub fn get(&self, index: usize) -> Option<&SparseSkipEntry> {
164        self.entries.get(index)
165    }
166
167    /// Global maximum weight across all blocks
168    pub fn global_max_weight(&self) -> f32 {
169        self.global_max_weight
170    }
171
172    /// Find block index containing doc_id >= target (binary search, O(log n))
173    pub fn find_block(&self, target: DocId) -> Option<usize> {
174        if self.entries.is_empty() {
175            return None;
176        }
177        // Binary search: find first entry where last_doc >= target
178        let idx = self.entries.partition_point(|e| e.last_doc < target);
179        if idx < self.entries.len() {
180            Some(idx)
181        } else {
182            None
183        }
184    }
185
186    /// Iterate over entries
187    pub fn iter(&self) -> impl Iterator<Item = &SparseSkipEntry> {
188        self.entries.iter()
189    }
190
191    /// Write skip list to writer
192    pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
193        writer.write_u32::<LittleEndian>(self.entries.len() as u32)?;
194        writer.write_f32::<LittleEndian>(self.global_max_weight)?;
195        for entry in &self.entries {
196            entry.write(writer)?;
197        }
198        Ok(())
199    }
200
201    /// Read skip list from reader
202    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
203        let count = reader.read_u32::<LittleEndian>()? as usize;
204        let global_max_weight = reader.read_f32::<LittleEndian>()?;
205        let mut entries = Vec::with_capacity(count);
206        for _ in 0..count {
207            entries.push(SparseSkipEntry::read(reader)?);
208        }
209        Ok(Self {
210            entries,
211            global_max_weight,
212        })
213    }
214}
215
/// Sparse posting list for a single dimension
///
/// Stores (doc_id, weight) pairs for all documents that have a non-zero
/// weight for this dimension. Weights are quantized according to the
/// specified quantization format.
///
/// `data` layout: all doc_ids first (vint-encoded deltas), then all weights
/// in the quantized format; see `from_postings` / `SparsePostingIterator`.
#[derive(Debug, Clone)]
pub struct SparsePostingList {
    /// Quantization format
    quantization: WeightQuantization,
    /// Scale factor for UInt8/UInt4 quantization (weight = quantized * scale)
    scale: f32,
    /// Minimum value for UInt8/UInt4 quantization (weight = quantized * scale + min)
    min_val: f32,
    /// Number of postings
    doc_count: u32,
    /// Compressed data: [doc_ids...][weights...]
    data: Vec<u8>,
}
234
235impl SparsePostingList {
236    /// Create from postings with specified quantization
237    pub fn from_postings(
238        postings: &[(DocId, f32)],
239        quantization: WeightQuantization,
240    ) -> io::Result<Self> {
241        if postings.is_empty() {
242            return Ok(Self {
243                quantization,
244                scale: 1.0,
245                min_val: 0.0,
246                doc_count: 0,
247                data: Vec::new(),
248            });
249        }
250
251        // Compute min/max for quantization
252        let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
253        let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
254        let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
255
256        let (scale, adjusted_min) = match quantization {
257            WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
258            WeightQuantization::UInt8 => {
259                let range = max_val - min_val;
260                if range < f32::EPSILON {
261                    (1.0, min_val)
262                } else {
263                    (range / 255.0, min_val)
264                }
265            }
266            WeightQuantization::UInt4 => {
267                let range = max_val - min_val;
268                if range < f32::EPSILON {
269                    (1.0, min_val)
270                } else {
271                    (range / 15.0, min_val)
272                }
273            }
274        };
275
276        let mut data = Vec::new();
277
278        // Write doc IDs with delta encoding
279        let mut prev_doc_id = 0u32;
280        for (doc_id, _) in postings {
281            let delta = doc_id - prev_doc_id;
282            write_vint(&mut data, delta as u64)?;
283            prev_doc_id = *doc_id;
284        }
285
286        // Write weights based on quantization
287        match quantization {
288            WeightQuantization::Float32 => {
289                for (_, weight) in postings {
290                    data.write_f32::<LittleEndian>(*weight)?;
291                }
292            }
293            WeightQuantization::Float16 => {
294                // Use SIMD-accelerated batch conversion via half::slice
295                use half::slice::HalfFloatSliceExt;
296                let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
297                let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
298                f16_slice.convert_from_f32_slice(&weights);
299                for h in f16_slice {
300                    data.write_u16::<LittleEndian>(h.to_bits())?;
301                }
302            }
303            WeightQuantization::UInt8 => {
304                for (_, weight) in postings {
305                    let quantized = ((*weight - adjusted_min) / scale).round() as u8;
306                    data.write_u8(quantized)?;
307                }
308            }
309            WeightQuantization::UInt4 => {
310                // Pack two 4-bit values per byte
311                let mut i = 0;
312                while i < postings.len() {
313                    let q1 = ((postings[i].1 - adjusted_min) / scale).round() as u8 & 0x0F;
314                    let q2 = if i + 1 < postings.len() {
315                        ((postings[i + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
316                    } else {
317                        0
318                    };
319                    data.write_u8((q2 << 4) | q1)?;
320                    i += 2;
321                }
322            }
323        }
324
325        Ok(Self {
326            quantization,
327            scale,
328            min_val: adjusted_min,
329            doc_count: postings.len() as u32,
330            data,
331        })
332    }
333
334    /// Serialize to bytes
335    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
336        writer.write_u8(self.quantization as u8)?;
337        writer.write_f32::<LittleEndian>(self.scale)?;
338        writer.write_f32::<LittleEndian>(self.min_val)?;
339        writer.write_u32::<LittleEndian>(self.doc_count)?;
340        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
341        writer.write_all(&self.data)?;
342        Ok(())
343    }
344
345    /// Deserialize from bytes
346    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
347        let quant_byte = reader.read_u8()?;
348        let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
349            io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
350        })?;
351        let scale = reader.read_f32::<LittleEndian>()?;
352        let min_val = reader.read_f32::<LittleEndian>()?;
353        let doc_count = reader.read_u32::<LittleEndian>()?;
354        let data_len = reader.read_u32::<LittleEndian>()? as usize;
355        let mut data = vec![0u8; data_len];
356        reader.read_exact(&mut data)?;
357
358        Ok(Self {
359            quantization,
360            scale,
361            min_val,
362            doc_count,
363            data,
364        })
365    }
366
367    /// Number of documents in this posting list
368    pub fn doc_count(&self) -> u32 {
369        self.doc_count
370    }
371
372    /// Get quantization format
373    pub fn quantization(&self) -> WeightQuantization {
374        self.quantization
375    }
376
377    /// Create an iterator
378    pub fn iterator(&self) -> SparsePostingIterator<'_> {
379        SparsePostingIterator::new(self)
380    }
381
382    /// Decode all postings (for merge operations)
383    pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
384        let mut result = Vec::with_capacity(self.doc_count as usize);
385        let mut iter = self.iterator();
386
387        while !iter.exhausted {
388            result.push((iter.doc_id, iter.weight));
389            iter.advance();
390        }
391
392        Ok(result)
393    }
394}
395
/// Iterator over sparse posting list
///
/// Walks the doc_id vint stream sequentially while indexing into the
/// fixed-stride weight section by `index`; weights are dequantized lazily
/// as each posting is loaded.
pub struct SparsePostingIterator<'a> {
    posting_list: &'a SparsePostingList,
    /// Current position in doc_id stream
    doc_id_offset: usize,
    /// Current position in weight stream
    weight_offset: usize,
    /// Current index
    index: usize,
    /// Current doc_id
    doc_id: DocId,
    /// Current weight
    weight: f32,
    /// Whether iterator is exhausted
    exhausted: bool,
}
412
impl<'a> SparsePostingIterator<'a> {
    /// Build an iterator positioned on the first posting (if any).
    fn new(posting_list: &'a SparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            doc_id_offset: 0,
            weight_offset: 0,
            index: 0,
            doc_id: 0,
            weight: 0.0,
            exhausted: posting_list.doc_count == 0,
        };

        if !iter.exhausted {
            // Calculate weight offset (after all doc_id deltas)
            iter.weight_offset = iter.calculate_weight_offset();
            iter.load_current();
        }

        iter
    }

    /// Locate where the weight section begins by scanning past every doc_id
    /// vint delta. O(doc_count) — paid once per iterator construction.
    fn calculate_weight_offset(&self) -> usize {
        // Read through all doc_id deltas to find where weights start
        let mut offset = 0;
        let mut reader = &self.posting_list.data[..];

        for _ in 0..self.posting_list.doc_count {
            if read_vint(&mut reader).is_ok() {
                // Bytes consumed so far = total length minus what's left.
                offset = self.posting_list.data.len() - reader.len();
            }
            // NOTE(review): on a vint decode failure the loop keeps the last
            // good offset and keeps spinning; truncated data yields a short
            // weight section instead of an error — confirm this is intended.
        }

        offset
    }

    /// Decode the posting at `self.index` into `doc_id` / `weight`,
    /// or mark the iterator exhausted when past the end.
    fn load_current(&mut self) {
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return;
        }

        // Read doc_id delta and accumulate onto the running doc_id.
        let mut reader = &self.posting_list.data[self.doc_id_offset..];
        if let Ok(delta) = read_vint(&mut reader) {
            self.doc_id = self.doc_id.wrapping_add(delta as u32);
            // Advance the stream position past the consumed vint bytes.
            self.doc_id_offset = self.posting_list.data.len() - reader.len();
        }

        // Read weight based on quantization; the weight section has a fixed
        // per-entry stride, so it is addressed by index rather than scanned.
        let weight_idx = self.index;
        let pl = self.posting_list;

        self.weight = match pl.quantization {
            WeightQuantization::Float32 => {
                // 4 bytes per weight, little-endian f32.
                let offset = self.weight_offset + weight_idx * 4;
                if offset + 4 <= pl.data.len() {
                    let bytes = &pl.data[offset..offset + 4];
                    f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
                } else {
                    // Out-of-range reads degrade to 0.0 rather than panicking.
                    0.0
                }
            }
            WeightQuantization::Float16 => {
                // 2 bytes per weight, little-endian f16 bit pattern.
                let offset = self.weight_offset + weight_idx * 2;
                if offset + 2 <= pl.data.len() {
                    let bits = u16::from_le_bytes([pl.data[offset], pl.data[offset + 1]]);
                    half::f16::from_bits(bits).to_f32()
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt8 => {
                // 1 byte per weight; dequantize with the stored affine params.
                let offset = self.weight_offset + weight_idx;
                if offset < pl.data.len() {
                    let quantized = pl.data[offset];
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt4 => {
                // Two weights per byte: even indices in the low nibble,
                // odd indices in the high nibble (matches the encoder).
                let byte_offset = self.weight_offset + weight_idx / 2;
                if byte_offset < pl.data.len() {
                    let byte = pl.data[byte_offset];
                    let quantized = if weight_idx.is_multiple_of(2) {
                        byte & 0x0F
                    } else {
                        (byte >> 4) & 0x0F
                    };
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
        };
    }

    /// Current document ID, or `TERMINATED` when exhausted
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else {
            self.doc_id
        }
    }

    /// Current weight (0.0 once exhausted)
    pub fn weight(&self) -> f32 {
        if self.exhausted { 0.0 } else { self.weight }
    }

    /// Advance to next posting; returns the new doc, or `TERMINATED` at the end
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        self.index += 1;
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return super::TERMINATED;
        }

        self.load_current();
        self.doc_id
    }

    /// Seek to first doc_id >= target
    ///
    /// Linear scan: this flat format has no skip structure, so seek cost is
    /// O(postings skipped). Block-level seeks live in the v2 block format.
    pub fn seek(&mut self, target: DocId) -> DocId {
        while !self.exhausted && self.doc_id < target {
            self.advance();
        }
        self.doc()
    }
}
548
#[cfg(test)]
mod tests {
    use super::*;

    // Dot product over partially-overlapping dimension sets.
    #[test]
    fn test_sparse_vector_dot_product() {
        let v1 = SparseVector::from_entries(&[0, 2, 5], &[1.0, 2.0, 3.0]);
        let v2 = SparseVector::from_entries(&[1, 2, 5], &[1.0, 4.0, 2.0]);

        // dot = 0 + 2*4 + 3*2 = 14
        assert!((v1.dot(&v2) - 14.0).abs() < 1e-6);
    }

    // Full-precision round trip through the flat posting list + iterator.
    #[test]
    fn test_sparse_posting_list_float32() {
        let postings = vec![(0, 1.5), (5, 2.3), (10, 0.8), (100, 3.15)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 4);

        let mut iter = pl.iterator();
        assert_eq!(iter.doc(), 0);
        assert!((iter.weight() - 1.5).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!((iter.weight() - 2.3).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 10);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!((iter.weight() - 3.15).abs() < 1e-6);

        // Past the last posting the iterator reports TERMINATED.
        iter.advance();
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }

    // Lossy u8 quantization: exact values change, ordering must not.
    #[test]
    fn test_sparse_posting_list_uint8() {
        let postings = vec![(0, 0.0), (5, 0.5), (10, 1.0)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let decoded = pl.decode_all().unwrap();
        assert_eq!(decoded.len(), 3);

        // UInt8 quantization should preserve relative ordering
        assert!(decoded[0].1 < decoded[1].1);
        assert!(decoded[1].1 < decoded[2].1);
    }

    // Block format (v2): iteration across multiple 128-entry blocks.
    #[test]
    fn test_block_sparse_posting_list() {
        // Create enough postings to span multiple blocks
        let postings: Vec<(DocId, u16, f32)> =
            (0..300).map(|i| (i * 2, 0, (i as f32) * 0.1)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 300);
        assert!(pl.num_blocks() >= 2);

        // Test iteration
        let mut iter = pl.iterator();
        for (expected_doc, _, expected_weight) in &postings {
            assert_eq!(iter.doc(), *expected_doc);
            assert!((iter.weight() - expected_weight).abs() < 1e-6);
            iter.advance();
        }
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }

    // Block-format seek: exact hit, next-doc fallback, and past-the-end.
    #[test]
    fn test_block_sparse_seek() {
        let postings: Vec<(DocId, u16, f32)> = (0..500).map(|i| (i * 3, 0, i as f32)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = pl.iterator();

        // Seek to exact match
        assert_eq!(iter.seek(300), 300);

        // Seek to non-exact (should find next)
        assert_eq!(iter.seek(301), 303);

        // Seek beyond end
        assert_eq!(iter.seek(2000), super::super::TERMINATED);
    }

    // serialize -> deserialize must reproduce docs exactly and weights
    // within quantization tolerance for every format tested.
    #[test]
    fn test_serialization_roundtrip() {
        let postings: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (10, 0, 2.0), (100, 0, 3.0)];

        for quant in [
            WeightQuantization::Float32,
            WeightQuantization::Float16,
            WeightQuantization::UInt8,
        ] {
            let pl = BlockSparsePostingList::from_postings(&postings, quant).unwrap();

            let mut buffer = Vec::new();
            pl.serialize(&mut buffer).unwrap();

            let pl2 = BlockSparsePostingList::deserialize(&mut &buffer[..]).unwrap();

            assert_eq!(pl.doc_count(), pl2.doc_count());

            // Verify iteration produces same results
            let mut iter1 = pl.iterator();
            let mut iter2 = pl2.iterator();

            while iter1.doc() != super::super::TERMINATED {
                assert_eq!(iter1.doc(), iter2.doc());
                // Loose tolerance covers the lossy quantization formats.
                assert!((iter1.weight() - iter2.weight()).abs() < 0.1);
                iter1.advance();
                iter2.advance();
            }
        }
    }

    // Simulates segment concatenation: decode, rebase doc ids, re-encode.
    #[test]
    fn test_concatenate() {
        let postings1: Vec<(DocId, u16, f32)> = vec![(0, 0, 1.0), (5, 1, 2.0)];
        let postings2: Vec<(DocId, u16, f32)> = vec![(0, 0, 3.0), (10, 1, 4.0)];

        let pl1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        let pl2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        // Merge manually
        let mut all: Vec<(DocId, u16, f32)> = pl1.decode_all();
        for (doc_id, ord, w) in pl2.decode_all() {
            all.push((doc_id + 100, ord, w));
        }
        let merged =
            BlockSparsePostingList::from_postings(&all, WeightQuantization::Float32).unwrap();

        assert_eq!(merged.doc_count(), 4);

        let decoded = merged.decode_all();
        assert_eq!(decoded[0], (0, 0, 1.0));
        assert_eq!(decoded[1], (5, 1, 2.0));
        assert_eq!(decoded[2], (100, 0, 3.0));
        assert_eq!(decoded[3], (110, 1, 4.0));
    }

    // Exercises each preset config plus the compact byte round trip.
    #[test]
    fn test_sparse_vector_config() {
        // Test default config
        let default = SparseVectorConfig::default();
        assert_eq!(default.index_size, IndexSize::U32);
        assert_eq!(default.weight_quantization, WeightQuantization::Float32);
        assert_eq!(default.bytes_per_entry(), 8.0); // 4 + 4

        // Test SPLADE config (research-validated defaults)
        let splade = SparseVectorConfig::splade();
        assert_eq!(splade.index_size, IndexSize::U16);
        assert_eq!(splade.weight_quantization, WeightQuantization::UInt8);
        assert_eq!(splade.bytes_per_entry(), 3.0); // 2 + 1
        assert_eq!(splade.weight_threshold, 0.01);
        assert_eq!(splade.posting_list_pruning, Some(0.1));
        assert!(splade.query_config.is_some());
        let query_cfg = splade.query_config.as_ref().unwrap();
        assert_eq!(query_cfg.heap_factor, 0.8);
        assert_eq!(query_cfg.max_query_dims, Some(20));

        // Test compact config
        let compact = SparseVectorConfig::compact();
        assert_eq!(compact.index_size, IndexSize::U16);
        assert_eq!(compact.weight_quantization, WeightQuantization::UInt4);
        assert_eq!(compact.bytes_per_entry(), 2.5); // 2 + 0.5

        // Test conservative config
        let conservative = SparseVectorConfig::conservative();
        assert_eq!(conservative.index_size, IndexSize::U32);
        assert_eq!(
            conservative.weight_quantization,
            WeightQuantization::Float16
        );
        assert_eq!(conservative.weight_threshold, 0.005);
        assert_eq!(conservative.posting_list_pruning, None);

        // Test byte serialization roundtrip (only index_size and weight_quantization are serialized)
        let byte = splade.to_byte();
        let restored = SparseVectorConfig::from_byte(byte).unwrap();
        assert_eq!(restored.index_size, splade.index_size);
        assert_eq!(restored.weight_quantization, splade.weight_quantization);
        // Note: Other fields (weight_threshold, posting_list_pruning, query_config) are not
        // serialized in the byte format, so they revert to defaults after deserialization
    }

    // Width / range invariants of the two index-size variants.
    #[test]
    fn test_index_size() {
        assert_eq!(IndexSize::U16.bytes(), 2);
        assert_eq!(IndexSize::U32.bytes(), 4);
        assert_eq!(IndexSize::U16.max_value(), 65535);
        assert_eq!(IndexSize::U32.max_value(), u32::MAX);
    }

    // Per-block and global max weights for Block-Max WAND pruning.
    #[test]
    fn test_block_max_weight() {
        let postings: Vec<(DocId, u16, f32)> = (0..300)
            .map(|i| (i as DocId, 0, (i as f32) * 0.1))
            .collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert!((pl.global_max_weight() - 29.9).abs() < 0.01);
        assert!(pl.num_blocks() >= 3);

        // Weights ascend, so each block's max is its last entry's weight.
        let block0_max = pl.block_max_weight(0).unwrap();
        assert!((block0_max - 12.7).abs() < 0.01);

        let block1_max = pl.block_max_weight(1).unwrap();
        assert!((block1_max - 25.5).abs() < 0.01);

        let block2_max = pl.block_max_weight(2).unwrap();
        assert!((block2_max - 29.9).abs() < 0.01);

        // Test iterator block_max methods
        let query_weight = 2.0;
        let mut iter = pl.iterator();
        assert!((iter.current_block_max_weight() - 12.7).abs() < 0.01);
        assert!((iter.current_block_max_contribution(query_weight) - 25.4).abs() < 0.1);

        iter.seek(128);
        assert!((iter.current_block_max_weight() - 25.5).abs() < 0.01);
    }

    // Skip list write/read round trip preserves entries and global max.
    #[test]
    fn test_sparse_skip_list_serialization() {
        let mut skip_list = SparseSkipList::new();
        skip_list.push(0, 127, 0, 50, 12.7);
        skip_list.push(128, 255, 100, 60, 25.5);
        skip_list.push(256, 299, 200, 40, 29.9);

        assert_eq!(skip_list.len(), 3);
        assert!((skip_list.global_max_weight() - 29.9).abs() < 0.01);

        // Serialize
        let mut buffer = Vec::new();
        skip_list.write(&mut buffer).unwrap();

        // Deserialize
        let restored = SparseSkipList::read(&mut buffer.as_slice()).unwrap();

        assert_eq!(restored.len(), 3);
        assert!((restored.global_max_weight() - 29.9).abs() < 0.01);

        // Verify entries
        let e0 = restored.get(0).unwrap();
        assert_eq!(e0.first_doc, 0);
        assert_eq!(e0.last_doc, 127);
        assert!((e0.max_weight - 12.7).abs() < 0.01);

        let e1 = restored.get(1).unwrap();
        assert_eq!(e1.first_doc, 128);
        assert!((e1.max_weight - 25.5).abs() < 0.01);
    }
}