// hermes_core/structures/postings/sparse_vector.rs

//! Sparse vector posting list with quantized weights
//!
//! Sparse vectors are stored as inverted index posting lists where:
//! - Each dimension ID is a "term"
//! - Each document has a weight for that dimension
//!
//! ## Configurable Components
//!
//! **Index (term/dimension ID) size:**
//! - `IndexSize::U16`: 16-bit indices (0-65535), ideal for SPLADE (~30K vocab)
//! - `IndexSize::U32`: 32-bit indices (0-4B), for large vocabularies
//!
//! **Weight quantization:**
//! - `Float32`: Full precision (4 bytes per weight)
//! - `Float16`: Half precision (2 bytes per weight)
//! - `UInt8`: 8-bit quantization with scale factor (1 byte per weight)
//! - `UInt4`: 4-bit quantization with scale factor (0.5 bytes per weight)
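//!
//! For example, a SPLADE-style field can pair `U16` indices with `UInt8`
//! weights. A minimal sketch using the presets defined further down in this
//! module:
//!
//! ```ignore
//! // 2 bytes per index + 1 byte per weight = 3 bytes per entry
//! let config = SparseVectorConfig::splade();
//! assert_eq!(config.bytes_per_entry(), 3.0);
//! ```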

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use serde::{Deserialize, Serialize};
use std::io::{self, Read, Write};

use super::posting_common::{
    RoundedBitWidth, pack_deltas_fixed, read_vint, unpack_deltas_fixed, write_vint,
};
use crate::DocId;

/// Size of the index (term/dimension ID) in sparse vectors
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum IndexSize {
    /// 16-bit index (0-65535), ideal for SPLADE vocabularies
    U16 = 0,
    /// 32-bit index (0-4B), for large vocabularies
    #[default]
    U32 = 1,
}

impl IndexSize {
    /// Bytes per index
    pub fn bytes(&self) -> usize {
        match self {
            IndexSize::U16 => 2,
            IndexSize::U32 => 4,
        }
    }

    /// Maximum value representable
    pub fn max_value(&self) -> u32 {
        match self {
            IndexSize::U16 => u16::MAX as u32,
            IndexSize::U32 => u32::MAX,
        }
    }

    fn from_u8(v: u8) -> Option<Self> {
        match v {
            0 => Some(IndexSize::U16),
            1 => Some(IndexSize::U32),
            _ => None,
        }
    }
}

/// Quantization format for sparse vector weights
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum WeightQuantization {
    /// Full 32-bit float precision
    #[default]
    Float32 = 0,
    /// 16-bit float (half precision)
    Float16 = 1,
    /// 8-bit unsigned integer with scale factor
    UInt8 = 2,
    /// 4-bit unsigned integer with scale factor (packed, 2 per byte)
    UInt4 = 3,
}

impl WeightQuantization {
    /// Bytes per weight (approximate for UInt4)
    pub fn bytes_per_weight(&self) -> f32 {
        match self {
            WeightQuantization::Float32 => 4.0,
            WeightQuantization::Float16 => 2.0,
            WeightQuantization::UInt8 => 1.0,
            WeightQuantization::UInt4 => 0.5,
        }
    }

    fn from_u8(v: u8) -> Option<Self> {
        match v {
            0 => Some(WeightQuantization::Float32),
            1 => Some(WeightQuantization::Float16),
            2 => Some(WeightQuantization::UInt8),
            3 => Some(WeightQuantization::UInt4),
            _ => None,
        }
    }
}

/// Query-time weighting strategy for sparse vector queries
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum QueryWeighting {
    /// All terms get weight 1.0
    #[default]
    One,
    /// Terms weighted by IDF (inverse document frequency) from the index
    Idf,
}

/// Query-time configuration for sparse vectors
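///
/// A minimal sketch of an approximate-search setup (all fields are defined on
/// this struct; the values are illustrative):
///
/// ```ignore
/// let query_config = SparseQueryConfig {
///     tokenizer: None,
///     weighting: QueryWeighting::Idf,
///     heap_factor: 0.8,         // approximate: skip more blocks
///     max_query_dims: Some(10), // keep only the top-10 query dimensions
/// };
/// ```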
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseQueryConfig {
    /// HuggingFace tokenizer path/name for query-time tokenization
    /// Example: "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokenizer: Option<String>,
    /// Weighting strategy for tokenized query terms
    #[serde(default)]
    pub weighting: QueryWeighting,
    /// Heap factor for approximate search (SEISMIC-style optimization)
    /// A block is skipped if `heap_factor * max_possible_score < threshold`,
    /// so factors below 1.0 skip more blocks
    /// - 1.0 = exact search (default)
    /// - 0.8 = approximate, ~20% faster with minor recall loss
    /// - 0.5 = very approximate, much faster
    #[serde(default = "default_heap_factor")]
    pub heap_factor: f32,
    /// Maximum number of query dimensions to process (query pruning)
    /// Processes only the top-k dimensions by weight
    /// - None = process all dimensions (default)
    /// - Some(10) = process top 10 dimensions only
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_query_dims: Option<usize>,
}

fn default_heap_factor() -> f32 {
    1.0
}

impl Default for SparseQueryConfig {
    fn default() -> Self {
        Self {
            tokenizer: None,
            weighting: QueryWeighting::One,
            heap_factor: 1.0,
            max_query_dims: None,
        }
    }
}

/// Configuration for sparse vector storage
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseVectorConfig {
    /// Size of dimension/term indices
    pub index_size: IndexSize,
    /// Quantization for weights
    pub weight_quantization: WeightQuantization,
    /// Minimum weight threshold - weights below this value are not indexed
    /// This reduces index size and can improve query speed at the cost of recall
    #[serde(default)]
    pub weight_threshold: f32,
    /// Static pruning: fraction of postings to keep per inverted list (SEISMIC-style)
    /// Lists are sorted by weight descending and truncated to top fraction.
    /// - None = keep all postings (default, exact)
    /// - Some(0.1) = keep top 10% of postings per dimension
    ///
    /// Applied only during initial segment build, not during merge.
    /// This exploits "concentration of importance": the top entries preserve most of the inner product.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub posting_list_pruning: Option<f32>,
    /// Query-time configuration (tokenizer, weighting)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub query_config: Option<SparseQueryConfig>,
}

impl Default for SparseVectorConfig {
    fn default() -> Self {
        Self {
            index_size: IndexSize::U32,
            weight_quantization: WeightQuantization::Float32,
            weight_threshold: 0.0,
            posting_list_pruning: None,
            query_config: None,
        }
    }
}

impl SparseVectorConfig {
    /// SPLADE-optimized config: u16 indices, u8 weights
    pub fn splade() -> Self {
        Self {
            index_size: IndexSize::U16,
            weight_quantization: WeightQuantization::UInt8,
            weight_threshold: 0.0,
            posting_list_pruning: None,
            query_config: None,
        }
    }

    /// Compact config: u16 indices, 4-bit weights
    pub fn compact() -> Self {
        Self {
            index_size: IndexSize::U16,
            weight_quantization: WeightQuantization::UInt4,
            weight_threshold: 0.0,
            posting_list_pruning: None,
            query_config: None,
        }
    }

    /// Full precision config
    pub fn full_precision() -> Self {
        Self {
            index_size: IndexSize::U32,
            weight_quantization: WeightQuantization::Float32,
            weight_threshold: 0.0,
            posting_list_pruning: None,
            query_config: None,
        }
    }

    /// Set weight threshold (builder pattern)
    pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
        self.weight_threshold = threshold;
        self
    }

    /// Set posting list pruning fraction (builder pattern)
    /// e.g., 0.1 = keep top 10% of postings per dimension
    pub fn with_pruning(mut self, fraction: f32) -> Self {
        self.posting_list_pruning = Some(fraction.clamp(0.0, 1.0));
        self
    }

    /// Bytes per entry (index + weight)
    pub fn bytes_per_entry(&self) -> f32 {
        self.index_size.bytes() as f32 + self.weight_quantization.bytes_per_weight()
    }

    /// Serialize config to a single byte
    pub fn to_byte(&self) -> u8 {
        ((self.index_size as u8) << 4) | (self.weight_quantization as u8)
    }

    /// Deserialize config from a single byte
    /// Note: weight_threshold, posting_list_pruning, and query_config are not
    /// serialized in the byte; they take their default values on read
    pub fn from_byte(b: u8) -> Option<Self> {
        let index_size = IndexSize::from_u8(b >> 4)?;
        let weight_quantization = WeightQuantization::from_u8(b & 0x0F)?;
        Some(Self {
            index_size,
            weight_quantization,
            weight_threshold: 0.0,
            posting_list_pruning: None,
            query_config: None,
        })
    }

    /// Set query configuration (builder pattern)
    pub fn with_query_config(mut self, config: SparseQueryConfig) -> Self {
        self.query_config = Some(config);
        self
    }
}

/// A sparse vector entry: (dimension_id, weight)
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseEntry {
    pub dim_id: u32,
    pub weight: f32,
}

/// Sparse vector representation
#[derive(Debug, Clone, Default)]
pub struct SparseVector {
    entries: Vec<SparseEntry>,
}

impl SparseVector {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            entries: Vec::with_capacity(capacity),
        }
    }

    /// Create from dimension IDs and weights
    pub fn from_entries(dim_ids: &[u32], weights: &[f32]) -> Self {
        assert_eq!(dim_ids.len(), weights.len());
        let mut entries: Vec<SparseEntry> = dim_ids
            .iter()
            .zip(weights.iter())
            .map(|(&dim_id, &weight)| SparseEntry { dim_id, weight })
            .collect();
        // Sort by dimension ID for efficient intersection
        entries.sort_by_key(|e| e.dim_id);
        Self { entries }
    }

    /// Add an entry (must maintain sorted order by dim_id)
    pub fn push(&mut self, dim_id: u32, weight: f32) {
        debug_assert!(
            self.entries.is_empty() || self.entries.last().unwrap().dim_id < dim_id,
            "Entries must be added in sorted order by dim_id"
        );
        self.entries.push(SparseEntry { dim_id, weight });
    }

    /// Number of non-zero dimensions
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    pub fn iter(&self) -> impl Iterator<Item = &SparseEntry> {
        self.entries.iter()
    }

    /// Compute dot product with another sparse vector
    pub fn dot(&self, other: &SparseVector) -> f32 {
        let mut result = 0.0f32;
        let mut i = 0;
        let mut j = 0;

        while i < self.entries.len() && j < other.entries.len() {
            let a = &self.entries[i];
            let b = &other.entries[j];

            match a.dim_id.cmp(&b.dim_id) {
                std::cmp::Ordering::Less => i += 1,
                std::cmp::Ordering::Greater => j += 1,
                std::cmp::Ordering::Equal => {
                    result += a.weight * b.weight;
                    i += 1;
                    j += 1;
                }
            }
        }

        result
    }

    /// L2 norm squared
    pub fn norm_squared(&self) -> f32 {
        self.entries.iter().map(|e| e.weight * e.weight).sum()
    }

    /// L2 norm
    pub fn norm(&self) -> f32 {
        self.norm_squared().sqrt()
    }

    /// Prune to top-k dimensions by weight (SEISMIC query_cut optimization)
    ///
    /// Returns a new SparseVector with only the top-k dimensions by absolute weight.
    /// This exploits "concentration of importance": the top dimensions preserve most of the inner product.
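    ///
    /// A small illustration (hypothetical weights):
    ///
    /// ```ignore
    /// let v = SparseVector::from_entries(&[1, 2, 3], &[0.9, 0.1, -0.5]);
    /// let pruned = v.top_k(2); // keeps dims 1 and 3 (|0.9| and |-0.5|)
    /// assert_eq!(pruned.len(), 2);
    /// ```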
    pub fn top_k(&self, k: usize) -> Self {
        if self.entries.len() <= k {
            return self.clone();
        }

        // Sort by weight descending, take top-k, re-sort by dim_id
        let mut sorted: Vec<SparseEntry> = self.entries.clone();
        sorted.sort_by(|a, b| {
            b.weight
                .abs()
                .partial_cmp(&a.weight.abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        sorted.truncate(k);
        sorted.sort_by_key(|e| e.dim_id);

        Self { entries: sorted }
    }

    /// Prune dimensions below a weight threshold
    pub fn filter_by_weight(&self, min_weight: f32) -> Self {
        let entries: Vec<SparseEntry> = self
            .entries
            .iter()
            .filter(|e| e.weight.abs() >= min_weight)
            .cloned()
            .collect();
        Self { entries }
    }
}

/// A sparse posting entry: doc_id with quantized weight
#[derive(Debug, Clone, Copy)]
pub struct SparsePosting {
    pub doc_id: DocId,
    pub weight: f32,
}

/// Block size for sparse posting lists (matches OptP4D for SIMD alignment)
pub const SPARSE_BLOCK_SIZE: usize = 128;

/// Skip entry for sparse posting lists with block-max support
///
/// Extends the basic skip entry with `max_weight` for Block-Max WAND optimization.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseSkipEntry {
    /// First doc_id in the block (absolute)
    pub first_doc: DocId,
    /// Last doc_id in the block
    pub last_doc: DocId,
    /// Byte offset to block data
    pub offset: u32,
    /// Maximum weight in this block (for Block-Max optimization)
    pub max_weight: f32,
}

impl SparseSkipEntry {
    pub fn new(first_doc: DocId, last_doc: DocId, offset: u32, max_weight: f32) -> Self {
        Self {
            first_doc,
            last_doc,
            offset,
            max_weight,
        }
    }

    /// Compute the maximum possible contribution of this block to a dot product
    ///
    /// For a query dimension with weight `query_weight`, the maximum contribution
    /// from this block is `query_weight * max_weight`.
    #[inline]
    pub fn block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.max_weight
    }

    /// Write skip entry to writer
    pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.first_doc)?;
        writer.write_u32::<LittleEndian>(self.last_doc)?;
        writer.write_u32::<LittleEndian>(self.offset)?;
        writer.write_f32::<LittleEndian>(self.max_weight)?;
        Ok(())
    }

    /// Read skip entry from reader
    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        let first_doc = reader.read_u32::<LittleEndian>()?;
        let last_doc = reader.read_u32::<LittleEndian>()?;
        let offset = reader.read_u32::<LittleEndian>()?;
        let max_weight = reader.read_f32::<LittleEndian>()?;
        Ok(Self {
            first_doc,
            last_doc,
            offset,
            max_weight,
        })
    }
}

/// Skip list for sparse posting lists with block-max support
#[derive(Debug, Clone, Default)]
pub struct SparseSkipList {
    entries: Vec<SparseSkipEntry>,
    /// Global maximum weight across all blocks (for MaxScore pruning)
    global_max_weight: f32,
}

impl SparseSkipList {
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a skip entry
    pub fn push(&mut self, first_doc: DocId, last_doc: DocId, offset: u32, max_weight: f32) {
        self.global_max_weight = self.global_max_weight.max(max_weight);
        self.entries.push(SparseSkipEntry::new(
            first_doc, last_doc, offset, max_weight,
        ));
    }

    /// Number of blocks
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Get entry by index
    pub fn get(&self, index: usize) -> Option<&SparseSkipEntry> {
        self.entries.get(index)
    }

    /// Global maximum weight across all blocks
    pub fn global_max_weight(&self) -> f32 {
        self.global_max_weight
    }

    /// Find block index containing doc_id >= target
    pub fn find_block(&self, target: DocId) -> Option<usize> {
        self.entries.iter().position(|e| e.last_doc >= target)
    }

    /// Iterate over entries
    pub fn iter(&self) -> impl Iterator<Item = &SparseSkipEntry> {
        self.entries.iter()
    }

    /// Write skip list to writer
    pub fn write<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u32::<LittleEndian>(self.entries.len() as u32)?;
        writer.write_f32::<LittleEndian>(self.global_max_weight)?;
        for entry in &self.entries {
            entry.write(writer)?;
        }
        Ok(())
    }

    /// Read skip list from reader
    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        let count = reader.read_u32::<LittleEndian>()? as usize;
        let global_max_weight = reader.read_f32::<LittleEndian>()?;
        let mut entries = Vec::with_capacity(count);
        for _ in 0..count {
            entries.push(SparseSkipEntry::read(reader)?);
        }
        Ok(Self {
            entries,
            global_max_weight,
        })
    }
}

/// Sparse posting list for a single dimension
///
/// Stores (doc_id, weight) pairs for all documents that have a non-zero
/// weight for this dimension. Weights are quantized according to the
/// specified quantization format.
#[derive(Debug, Clone)]
pub struct SparsePostingList {
    /// Quantization format
    quantization: WeightQuantization,
    /// Scale factor for UInt8/UInt4 quantization (weight = quantized * scale + min_val)
    scale: f32,
    /// Minimum value for UInt8/UInt4 quantization (weight = quantized * scale + min)
    min_val: f32,
    /// Number of postings
    doc_count: u32,
    /// Compressed data: [doc_ids...][weights...]
    data: Vec<u8>,
}

impl SparsePostingList {
    /// Create from postings with specified quantization
    pub fn from_postings(
        postings: &[(DocId, f32)],
        quantization: WeightQuantization,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                quantization,
                scale: 1.0,
                min_val: 0.0,
                doc_count: 0,
                data: Vec::new(),
            });
        }

        // Compute min/max for quantization
        let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
        let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        let (scale, adjusted_min) = match quantization {
            WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
            WeightQuantization::UInt8 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 255.0, min_val)
                }
            }
            WeightQuantization::UInt4 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 15.0, min_val)
                }
            }
        };

        let mut data = Vec::new();

        // Write doc IDs with delta encoding
        let mut prev_doc_id = 0u32;
        for (doc_id, _) in postings {
            let delta = doc_id - prev_doc_id;
            write_vint(&mut data, delta as u64)?;
            prev_doc_id = *doc_id;
        }

        // Write weights based on quantization
        match quantization {
            WeightQuantization::Float32 => {
                for (_, weight) in postings {
                    data.write_f32::<LittleEndian>(*weight)?;
                }
            }
            WeightQuantization::Float16 => {
                // Use SIMD-accelerated batch conversion via half::slice
                use half::slice::HalfFloatSliceExt;
                let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
                let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
                f16_slice.convert_from_f32_slice(&weights);
                for h in f16_slice {
                    data.write_u16::<LittleEndian>(h.to_bits())?;
                }
            }
            WeightQuantization::UInt8 => {
                for (_, weight) in postings {
                    let quantized = ((*weight - adjusted_min) / scale).round() as u8;
                    data.write_u8(quantized)?;
                }
            }
            WeightQuantization::UInt4 => {
                // Pack two 4-bit values per byte
                let mut i = 0;
                while i < postings.len() {
                    let q1 = ((postings[i].1 - adjusted_min) / scale).round() as u8 & 0x0F;
                    let q2 = if i + 1 < postings.len() {
                        ((postings[i + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
                    } else {
                        0
                    };
                    data.write_u8((q2 << 4) | q1)?;
                    i += 2;
                }
            }
        }

        Ok(Self {
            quantization,
            scale,
            min_val: adjusted_min,
            doc_count: postings.len() as u32,
            data,
        })
    }

    /// Serialize to bytes
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u8(self.quantization as u8)?;
        writer.write_f32::<LittleEndian>(self.scale)?;
        writer.write_f32::<LittleEndian>(self.min_val)?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;
        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
        writer.write_all(&self.data)?;
        Ok(())
    }

    /// Deserialize from bytes
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let quant_byte = reader.read_u8()?;
        let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
        })?;
        let scale = reader.read_f32::<LittleEndian>()?;
        let min_val = reader.read_f32::<LittleEndian>()?;
        let doc_count = reader.read_u32::<LittleEndian>()?;
        let data_len = reader.read_u32::<LittleEndian>()? as usize;
        let mut data = vec![0u8; data_len];
        reader.read_exact(&mut data)?;

        Ok(Self {
            quantization,
            scale,
            min_val,
            doc_count,
            data,
        })
    }

    /// Number of documents in this posting list
    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    /// Get quantization format
    pub fn quantization(&self) -> WeightQuantization {
        self.quantization
    }

    /// Create an iterator
    pub fn iterator(&self) -> SparsePostingIterator<'_> {
        SparsePostingIterator::new(self)
    }

    /// Decode all postings (for merge operations)
    pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
        let mut result = Vec::with_capacity(self.doc_count as usize);
        let mut iter = self.iterator();

        while !iter.exhausted {
            result.push((iter.doc_id, iter.weight));
            iter.advance();
        }

        Ok(result)
    }
}

/// Iterator over sparse posting list
pub struct SparsePostingIterator<'a> {
    posting_list: &'a SparsePostingList,
    /// Current position in doc_id stream
    doc_id_offset: usize,
    /// Current position in weight stream
    weight_offset: usize,
    /// Current index
    index: usize,
    /// Current doc_id
    doc_id: DocId,
    /// Current weight
    weight: f32,
    /// Whether iterator is exhausted
    exhausted: bool,
}

impl<'a> SparsePostingIterator<'a> {
    fn new(posting_list: &'a SparsePostingList) -> Self {
        let mut iter = Self {
            posting_list,
            doc_id_offset: 0,
            weight_offset: 0,
            index: 0,
            doc_id: 0,
            weight: 0.0,
            exhausted: posting_list.doc_count == 0,
        };

        if !iter.exhausted {
            // Calculate weight offset (after all doc_id deltas)
            iter.weight_offset = iter.calculate_weight_offset();
            iter.load_current();
        }

        iter
    }

    fn calculate_weight_offset(&self) -> usize {
        // Read through all doc_id deltas to find where weights start
        let mut offset = 0;
        let mut reader = &self.posting_list.data[..];

        for _ in 0..self.posting_list.doc_count {
            if read_vint(&mut reader).is_ok() {
                offset = self.posting_list.data.len() - reader.len();
            }
        }

        offset
    }

    fn load_current(&mut self) {
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return;
        }

        // Read doc_id delta
        let mut reader = &self.posting_list.data[self.doc_id_offset..];
        if let Ok(delta) = read_vint(&mut reader) {
            self.doc_id = self.doc_id.wrapping_add(delta as u32);
            self.doc_id_offset = self.posting_list.data.len() - reader.len();
        }

        // Read weight based on quantization
        let weight_idx = self.index;
        let pl = self.posting_list;

        self.weight = match pl.quantization {
            WeightQuantization::Float32 => {
                let offset = self.weight_offset + weight_idx * 4;
                if offset + 4 <= pl.data.len() {
                    let bytes = &pl.data[offset..offset + 4];
                    f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
                } else {
                    0.0
                }
            }
            WeightQuantization::Float16 => {
                let offset = self.weight_offset + weight_idx * 2;
                if offset + 2 <= pl.data.len() {
                    let bits = u16::from_le_bytes([pl.data[offset], pl.data[offset + 1]]);
                    half::f16::from_bits(bits).to_f32()
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt8 => {
                let offset = self.weight_offset + weight_idx;
                if offset < pl.data.len() {
                    let quantized = pl.data[offset];
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
            WeightQuantization::UInt4 => {
                let byte_offset = self.weight_offset + weight_idx / 2;
                if byte_offset < pl.data.len() {
                    let byte = pl.data[byte_offset];
                    let quantized = if weight_idx.is_multiple_of(2) {
                        byte & 0x0F
                    } else {
                        (byte >> 4) & 0x0F
                    };
                    quantized as f32 * pl.scale + pl.min_val
                } else {
                    0.0
                }
            }
        };
    }

    /// Current document ID
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else {
            self.doc_id
        }
    }

    /// Current weight
    pub fn weight(&self) -> f32 {
        if self.exhausted { 0.0 } else { self.weight }
    }

    /// Advance to next posting
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        self.index += 1;
        if self.index >= self.posting_list.doc_count as usize {
            self.exhausted = true;
            return super::TERMINATED;
        }

        self.load_current();
        self.doc_id
    }

    /// Seek to first doc_id >= target
    pub fn seek(&mut self, target: DocId) -> DocId {
        while !self.exhausted && self.doc_id < target {
            self.advance();
        }
        self.doc()
    }
}

/// Block-based sparse posting list for skip-list style access
///
/// Similar to BlockPostingList but stores quantized weights.
/// Includes block-max metadata for Block-Max WAND optimization.
#[derive(Debug, Clone)]
pub struct BlockSparsePostingList {
    /// Quantization format
    quantization: WeightQuantization,
    /// Global scale factor for UInt8/UInt4
    scale: f32,
    /// Global minimum value for UInt8/UInt4
    min_val: f32,
    /// Skip list with block-max support
    skip_list: SparseSkipList,
    /// Compressed block data
    data: Vec<u8>,
    /// Total number of postings
    doc_count: u32,
}

impl BlockSparsePostingList {
    /// Build from postings with specified quantization
    pub fn from_postings(
        postings: &[(DocId, f32)],
        quantization: WeightQuantization,
    ) -> io::Result<Self> {
        Self::from_postings_with_pruning(postings, quantization, None)
    }

    /// Build from postings with static pruning (SEISMIC-style optimization)
    ///
    /// If `pruning_fraction` is Some(f), only the top f*len postings by weight are kept.
    /// This exploits "concentration of importance": the top entries preserve most of the inner product.
    /// Applied only during initial segment build, not during merge.
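    ///
    /// A sketch of the call (here `postings` stands for any doc-sorted list):
    ///
    /// ```ignore
    /// // Keep only the top 10% of postings by weight for this dimension.
    /// let pl = BlockSparsePostingList::from_postings_with_pruning(
    ///     &postings,
    ///     WeightQuantization::UInt8,
    ///     Some(0.1),
    /// )?;
    /// ```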
    pub fn from_postings_with_pruning(
        postings: &[(DocId, f32)],
        quantization: WeightQuantization,
        pruning_fraction: Option<f32>,
    ) -> io::Result<Self> {
        if postings.is_empty() {
            return Ok(Self {
                quantization,
                scale: 1.0,
                min_val: 0.0,
                skip_list: SparseSkipList::new(),
                data: Vec::new(),
                doc_count: 0,
            });
        }

        // Apply static pruning if fraction is set (SEISMIC-style)
        // Sort by weight descending, take top fraction, then re-sort by doc_id
        let postings: std::borrow::Cow<'_, [(DocId, f32)]> = if let Some(fraction) =
            pruning_fraction
        {
            let max = ((postings.len() as f32 * fraction).ceil() as usize).max(1);
            if postings.len() > max {
                let mut sorted: Vec<(DocId, f32)> = postings.to_vec();
                // Sort by weight descending
                sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                // Take top fraction
                sorted.truncate(max);
                // Re-sort by doc_id for efficient iteration
                sorted.sort_by_key(|(doc_id, _)| *doc_id);
                std::borrow::Cow::Owned(sorted)
            } else {
                std::borrow::Cow::Borrowed(postings)
            }
        } else {
            std::borrow::Cow::Borrowed(postings)
        };

        // Compute global min/max for quantization
        let weights: Vec<f32> = postings.iter().map(|(_, w)| *w).collect();
        let min_val = weights.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        let (scale, adjusted_min) = match quantization {
            WeightQuantization::Float32 | WeightQuantization::Float16 => (1.0, 0.0),
            WeightQuantization::UInt8 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 255.0, min_val)
                }
            }
            WeightQuantization::UInt4 => {
                let range = max_val - min_val;
                if range < f32::EPSILON {
                    (1.0, min_val)
                } else {
                    (range / 15.0, min_val)
                }
            }
        };

        let mut skip_list = SparseSkipList::new();
        let mut data = Vec::new();

        let mut i = 0;
        while i < postings.len() {
            let block_end = (i + SPARSE_BLOCK_SIZE).min(postings.len());
            let block = &postings[i..block_end];

            let first_doc_id = block.first().unwrap().0;
            let last_doc_id = block.last().unwrap().0;

            // Compute max weight in this block for Block-Max optimization
            let block_max_weight = block
                .iter()
                .map(|(_, w)| *w)
                .fold(f32::NEG_INFINITY, f32::max);

            // Pack doc IDs with fixed-width delta encoding (SIMD-friendly)
            let block_doc_ids: Vec<DocId> = block.iter().map(|(d, _)| *d).collect();
            let (doc_bit_width, packed_doc_ids) = pack_deltas_fixed(&block_doc_ids);

            // Block header: [count: u16][doc_bit_width: u8][packed_doc_ids...][weights...]
            let block_start = data.len() as u32;
            skip_list.push(first_doc_id, last_doc_id, block_start, block_max_weight);

            data.write_u16::<LittleEndian>(block.len() as u16)?;
            data.write_u8(doc_bit_width as u8)?;
            data.extend_from_slice(&packed_doc_ids);

            // Write weights based on quantization
            match quantization {
                WeightQuantization::Float32 => {
                    for (_, weight) in block {
                        data.write_f32::<LittleEndian>(*weight)?;
                    }
                }
                WeightQuantization::Float16 => {
                    // Use SIMD-accelerated batch conversion via half::slice
                    use half::slice::HalfFloatSliceExt;
                    let weights: Vec<f32> = block.iter().map(|(_, w)| *w).collect();
                    let mut f16_slice: Vec<half::f16> = vec![half::f16::ZERO; weights.len()];
                    f16_slice.convert_from_f32_slice(&weights);
                    for h in f16_slice {
                        data.write_u16::<LittleEndian>(h.to_bits())?;
                    }
                }
                WeightQuantization::UInt8 => {
                    for (_, weight) in block {
                        let quantized = ((*weight - adjusted_min) / scale).round() as u8;
                        data.write_u8(quantized)?;
                    }
                }
                WeightQuantization::UInt4 => {
                    let mut j = 0;
                    while j < block.len() {
                        let q1 = ((block[j].1 - adjusted_min) / scale).round() as u8 & 0x0F;
                        let q2 = if j + 1 < block.len() {
                            ((block[j + 1].1 - adjusted_min) / scale).round() as u8 & 0x0F
                        } else {
                            0
                        };
                        data.write_u8((q2 << 4) | q1)?;
                        j += 2;
                    }
                }
            }

            i = block_end;
        }

        Ok(Self {
            quantization,
            scale,
            min_val: adjusted_min,
            skip_list,
            data,
            doc_count: postings.len() as u32,
        })
    }

    /// Serialize to bytes
    pub fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_u8(self.quantization as u8)?;
        writer.write_f32::<LittleEndian>(self.scale)?;
        writer.write_f32::<LittleEndian>(self.min_val)?;
        writer.write_u32::<LittleEndian>(self.doc_count)?;

        // Write skip list with block-max support
        self.skip_list.write(writer)?;

        // Write data
        writer.write_u32::<LittleEndian>(self.data.len() as u32)?;
        writer.write_all(&self.data)?;

        Ok(())
    }

    /// Deserialize from bytes
    pub fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let quant_byte = reader.read_u8()?;
        let quantization = WeightQuantization::from_u8(quant_byte).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "Invalid quantization type")
        })?;
        let scale = reader.read_f32::<LittleEndian>()?;
        let min_val = reader.read_f32::<LittleEndian>()?;
        let doc_count = reader.read_u32::<LittleEndian>()?;

        // Read skip list with block-max support
        let skip_list = SparseSkipList::read(reader)?;

        let data_len = reader.read_u32::<LittleEndian>()? as usize;
        let mut data = vec![0u8; data_len];
        reader.read_exact(&mut data)?;

        Ok(Self {
            quantization,
            scale,
            min_val,
            skip_list,
            data,
            doc_count,
        })
    }

    /// Number of documents
    pub fn doc_count(&self) -> u32 {
        self.doc_count
    }

    /// Number of blocks
    pub fn num_blocks(&self) -> usize {
        self.skip_list.len()
    }

    /// Get quantization format
    pub fn quantization(&self) -> WeightQuantization {
        self.quantization
    }

    /// Global maximum weight across all blocks (for MaxScore pruning)
    pub fn global_max_weight(&self) -> f32 {
        self.skip_list.global_max_weight()
    }

    /// Get block-max weight for a specific block
    pub fn block_max_weight(&self, block_idx: usize) -> Option<f32> {
        self.skip_list.get(block_idx).map(|e| e.max_weight)
    }

    /// Compute maximum possible contribution to dot product with given query weight
    ///
    /// This is used for MaxScore pruning: if `query_weight * global_max_weight < threshold`,
    /// this entire dimension can be skipped.
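    ///
    /// A sketch of dimension-level pruning (`query` and `threshold` are
    /// illustrative names, not part of this API):
    ///
    /// ```ignore
    /// for (query_weight, posting_list) in &query {
    ///     if posting_list.max_contribution(*query_weight) < threshold {
    ///         continue; // this dimension cannot change the top-k results
    ///     }
    ///     // ... score documents from this posting list ...
    /// }
    /// ```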
    #[inline]
    pub fn max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.skip_list.global_max_weight()
    }

    /// Create an iterator
    pub fn iterator(&self) -> BlockSparsePostingIterator<'_> {
        BlockSparsePostingIterator::new(self)
    }

    /// Approximate size in bytes
    pub fn size_bytes(&self) -> usize {
        // Header: quantization (1) + scale (4) + min_val (4) + doc_count (4) = 13
        // Skip list: count (4) + global_max (4) + entries * (first_doc + last_doc + offset + max_weight) = 4 + 4 + n * 16
        // Data: data.len()
        13 + 8 + self.skip_list.len() * 16 + self.data.len()
    }

    /// Concatenate multiple posting lists with doc_id remapping
    pub fn concatenate(
        sources: &[(BlockSparsePostingList, u32)],
        target_quantization: WeightQuantization,
    ) -> io::Result<Self> {
        // Decode all postings and merge
        let mut all_postings: Vec<(DocId, f32)> = Vec::new();

        for (source, doc_offset) in sources {
            let decoded = source.decode_all()?;
            for (doc_id, weight) in decoded {
                all_postings.push((doc_id + doc_offset, weight));
            }
        }

        // Re-encode with target quantization
        Self::from_postings(&all_postings, target_quantization)
    }

    /// Decode all postings
    pub fn decode_all(&self) -> io::Result<Vec<(DocId, f32)>> {
        let mut result = Vec::with_capacity(self.doc_count as usize);
        let mut iter = self.iterator();

        while iter.doc() != super::TERMINATED {
            result.push((iter.doc(), iter.weight()));
            iter.advance();
        }

        Ok(result)
    }
}

/// Iterator over block sparse posting list
pub struct BlockSparsePostingIterator<'a> {
    posting_list: &'a BlockSparsePostingList,
    current_block: usize,
    block_postings: Vec<(DocId, f32)>,
    position_in_block: usize,
    exhausted: bool,
}

impl<'a> BlockSparsePostingIterator<'a> {
    fn new(posting_list: &'a BlockSparsePostingList) -> Self {
        let exhausted = posting_list.skip_list.is_empty();
        let mut iter = Self {
            posting_list,
            current_block: 0,
            block_postings: Vec::new(),
            position_in_block: 0,
            exhausted,
        };

        if !iter.exhausted {
            iter.load_block(0);
        }

        iter
    }

    fn load_block(&mut self, block_idx: usize) {
        let entry = match self.posting_list.skip_list.get(block_idx) {
            Some(e) => e,
            None => {
                self.exhausted = true;
                return;
            }
        };

        self.current_block = block_idx;
        self.position_in_block = 0;
        self.block_postings.clear();

        let offset = entry.offset as usize;
        let first_doc_id = entry.first_doc;
        let data = &self.posting_list.data[offset..];

        // Read block header: [count: u16][doc_bit_width: u8]
        if data.len() < 3 {
            self.exhausted = true;
            return;
        }
        let count = u16::from_le_bytes([data[0], data[1]]) as usize;
        let doc_bit_width = RoundedBitWidth::from_u8(data[2]).unwrap_or(RoundedBitWidth::Zero);

        // Unpack doc IDs with SIMD-accelerated delta decoding
        let doc_bytes = doc_bit_width.bytes_per_value() * count.saturating_sub(1);
        let doc_data = &data[3..3 + doc_bytes];
        let mut doc_ids = vec![0u32; count];
        unpack_deltas_fixed(doc_data, doc_bit_width, first_doc_id, count, &mut doc_ids);

        // Weight data starts after doc IDs
        let weight_offset = 3 + doc_bytes;
        let weight_data = &data[weight_offset..];
        let pl = self.posting_list;

        // Decode weights based on quantization (batch SIMD where possible)
        let weights: Vec<f32> = match pl.quantization {
            WeightQuantization::Float32 => {
                let mut weights = Vec::with_capacity(count);
                let mut reader = weight_data;
                for _ in 0..count {
                    if reader.len() >= 4 {
                        weights.push((&mut reader).read_f32::<LittleEndian>().unwrap_or(0.0));
                    } else {
                        weights.push(0.0);
                    }
                }
                weights
            }
            WeightQuantization::Float16 => {
                // Use SIMD-accelerated batch conversion via half::slice
                use half::slice::HalfFloatSliceExt;
                let mut f16_slice: Vec<half::f16> = Vec::with_capacity(count);
                for i in 0..count {
                    let offset = i * 2;
                    if offset + 2 <= weight_data.len() {
                        let bits =
                            u16::from_le_bytes([weight_data[offset], weight_data[offset + 1]]);
                        f16_slice.push(half::f16::from_bits(bits));
                    } else {
                        f16_slice.push(half::f16::ZERO);
                    }
                }
                let mut weights = vec![0.0f32; count];
                f16_slice.convert_to_f32_slice(&mut weights);
                weights
            }
            WeightQuantization::UInt8 => {
                let mut weights = Vec::with_capacity(count);
                for i in 0..count {
                    if i < weight_data.len() {
                        weights.push(weight_data[i] as f32 * pl.scale + pl.min_val);
                    } else {
                        weights.push(0.0);
                    }
                }
                weights
            }
            WeightQuantization::UInt4 => {
                let mut weights = Vec::with_capacity(count);
                for i in 0..count {
                    let byte_idx = i / 2;
                    if byte_idx < weight_data.len() {
                        let byte = weight_data[byte_idx];
                        let quantized = if i % 2 == 0 {
                            byte & 0x0F
                        } else {
                            (byte >> 4) & 0x0F
                        };
                        weights.push(quantized as f32 * pl.scale + pl.min_val);
                    } else {
                        weights.push(0.0);
                    }
                }
                weights
            }
        };

        // Combine doc_ids and weights into block_postings
        for (doc_id, weight) in doc_ids.into_iter().zip(weights.into_iter()) {
            self.block_postings.push((doc_id, weight));
        }
    }

    /// Check if iterator is exhausted
    #[inline]
    pub fn is_exhausted(&self) -> bool {
        self.exhausted
    }

    /// Current document ID
    pub fn doc(&self) -> DocId {
        if self.exhausted {
            super::TERMINATED
        } else if self.position_in_block < self.block_postings.len() {
            self.block_postings[self.position_in_block].0
        } else {
            super::TERMINATED
        }
    }

    /// Current weight
    pub fn weight(&self) -> f32 {
        if self.exhausted || self.position_in_block >= self.block_postings.len() {
            0.0
        } else {
            self.block_postings[self.position_in_block].1
        }
    }

    /// Get current block's maximum weight (for Block-Max optimization)
    ///
    /// Returns the maximum weight of any posting in the current block.
    /// Used to compute upper bound contribution: `query_weight * current_block_max_weight()`.
    #[inline]
    pub fn current_block_max_weight(&self) -> f32 {
        if self.exhausted {
            0.0
        } else {
            self.posting_list
                .skip_list
                .get(self.current_block)
                .map(|e| e.max_weight)
                .unwrap_or(0.0)
        }
    }

    /// Compute maximum possible contribution from current block
    ///
    /// For Block-Max WAND: if this value < threshold, skip the entire block.
    #[inline]
    pub fn current_block_max_contribution(&self, query_weight: f32) -> f32 {
        query_weight * self.current_block_max_weight()
    }

    /// Advance to next posting
    pub fn advance(&mut self) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        self.position_in_block += 1;
        if self.position_in_block >= self.block_postings.len() {
            self.load_block(self.current_block + 1);
        }

        self.doc()
    }

    /// Seek to first doc_id >= target
    pub fn seek(&mut self, target: DocId) -> DocId {
        if self.exhausted {
            return super::TERMINATED;
        }

        // Find target block using shared SkipList
        if let Some(block_idx) = self.posting_list.skip_list.find_block(target) {
            if block_idx != self.current_block {
                self.load_block(block_idx);
            }

            // Linear search within block
            while self.position_in_block < self.block_postings.len() {
                if self.block_postings[self.position_in_block].0 >= target {
                    return self.doc();
                }
                self.position_in_block += 1;
            }

            // Try next block
            self.load_block(self.current_block + 1);
            self.seek(target)
        } else {
            self.exhausted = true;
            super::TERMINATED
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_vector_dot_product() {
        let v1 = SparseVector::from_entries(&[0, 2, 5], &[1.0, 2.0, 3.0]);
        let v2 = SparseVector::from_entries(&[1, 2, 5], &[1.0, 4.0, 2.0]);

        // dot = 0 + 2*4 + 3*2 = 14
        assert!((v1.dot(&v2) - 14.0).abs() < 1e-6);
    }

    #[test]
    fn test_sparse_posting_list_float32() {
        let postings = vec![(0, 1.5), (5, 2.3), (10, 0.8), (100, 3.15)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 4);

        let mut iter = pl.iterator();
        assert_eq!(iter.doc(), 0);
        assert!((iter.weight() - 1.5).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 5);
        assert!((iter.weight() - 2.3).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), 10);

        iter.advance();
        assert_eq!(iter.doc(), 100);
        assert!((iter.weight() - 3.15).abs() < 1e-6);

        iter.advance();
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }

    #[test]
    fn test_sparse_posting_list_uint8() {
        let postings = vec![(0, 0.0), (5, 0.5), (10, 1.0)];
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        let decoded = pl.decode_all().unwrap();
        assert_eq!(decoded.len(), 3);

        // UInt8 quantization should preserve relative ordering
        assert!(decoded[0].1 < decoded[1].1);
        assert!(decoded[1].1 < decoded[2].1);
    }
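
    // A minimal sketch of the UInt8 dequantization contract used above: with
    // `weight = quantized * scale + min_val`, every decoded weight should land
    // within half a quantization step (scale / 2) of the original.
    #[test]
    fn test_uint8_quantization_error_bound() {
        let postings: Vec<(DocId, f32)> = (0..64).map(|i| (i, 0.3 + i as f32 * 0.017)).collect();
        let pl = SparsePostingList::from_postings(&postings, WeightQuantization::UInt8).unwrap();

        // For this list: scale = (max - min) / 255 = (63 * 0.017) / 255
        let scale = (63.0 * 0.017) / 255.0;
        for ((_, original), (_, decoded)) in postings.iter().zip(pl.decode_all().unwrap().iter()) {
            assert!((original - decoded).abs() <= scale * 0.5 + 1e-6);
        }
    }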

    #[test]
    fn test_block_sparse_posting_list() {
        // Create enough postings to span multiple blocks
        let postings: Vec<(DocId, f32)> = (0..300).map(|i| (i * 2, (i as f32) * 0.1)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        assert_eq!(pl.doc_count(), 300);
        assert!(pl.num_blocks() >= 2);

        // Test iteration
        let mut iter = pl.iterator();
        for (expected_doc, expected_weight) in &postings {
            assert_eq!(iter.doc(), *expected_doc);
            assert!((iter.weight() - expected_weight).abs() < 1e-6);
            iter.advance();
        }
        assert_eq!(iter.doc(), super::super::TERMINATED);
    }

    #[test]
    fn test_block_sparse_seek() {
        let postings: Vec<(DocId, f32)> = (0..500).map(|i| (i * 3, i as f32)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        let mut iter = pl.iterator();

        // Seek to exact match
        assert_eq!(iter.seek(300), 300);

        // Seek to non-exact (should find next)
        assert_eq!(iter.seek(301), 303);

        // Seek beyond end
        assert_eq!(iter.seek(2000), super::super::TERMINATED);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let postings: Vec<(DocId, f32)> = vec![(0, 1.0), (10, 2.0), (100, 3.0)];

        for quant in [
            WeightQuantization::Float32,
            WeightQuantization::Float16,
            WeightQuantization::UInt8,
        ] {
            let pl = BlockSparsePostingList::from_postings(&postings, quant).unwrap();

            let mut buffer = Vec::new();
            pl.serialize(&mut buffer).unwrap();

            let pl2 = BlockSparsePostingList::deserialize(&mut &buffer[..]).unwrap();

            assert_eq!(pl.doc_count(), pl2.doc_count());
            assert_eq!(pl.quantization(), pl2.quantization());

            // Verify iteration produces same results
            let mut iter1 = pl.iterator();
            let mut iter2 = pl2.iterator();

            while iter1.doc() != super::super::TERMINATED {
                assert_eq!(iter1.doc(), iter2.doc());
                // Allow some tolerance for quantization
                assert!((iter1.weight() - iter2.weight()).abs() < 0.1);
                iter1.advance();
                iter2.advance();
            }
        }
    }

    #[test]
    fn test_concatenate() {
        let postings1: Vec<(DocId, f32)> = vec![(0, 1.0), (5, 2.0)];
        let postings2: Vec<(DocId, f32)> = vec![(0, 3.0), (10, 4.0)];

        let pl1 =
            BlockSparsePostingList::from_postings(&postings1, WeightQuantization::Float32).unwrap();
        let pl2 =
            BlockSparsePostingList::from_postings(&postings2, WeightQuantization::Float32).unwrap();

        // Merge with doc_offset for second list
        let merged = BlockSparsePostingList::concatenate(
            &[(pl1, 0), (pl2, 100)],
            WeightQuantization::Float32,
        )
        .unwrap();

        assert_eq!(merged.doc_count(), 4);

        let decoded = merged.decode_all().unwrap();
        assert_eq!(decoded[0], (0, 1.0));
        assert_eq!(decoded[1], (5, 2.0));
        assert_eq!(decoded[2], (100, 3.0)); // 0 + 100 offset
        assert_eq!(decoded[3], (110, 4.0)); // 10 + 100 offset
    }
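
    // A small sketch of SEISMIC-style static pruning: with fraction 0.5, only
    // the top ceil(len * 0.5) postings by weight survive, re-sorted by doc_id.
    #[test]
    fn test_posting_list_pruning() {
        let postings: Vec<(DocId, f32)> = vec![(0, 0.1), (1, 0.9), (2, 0.2), (3, 0.8)];
        let pl = BlockSparsePostingList::from_postings_with_pruning(
            &postings,
            WeightQuantization::Float32,
            Some(0.5),
        )
        .unwrap();

        // ceil(4 * 0.5) = 2 postings kept: the two largest weights
        assert_eq!(pl.doc_count(), 2);
        let decoded = pl.decode_all().unwrap();
        assert_eq!(decoded[0], (1, 0.9));
        assert_eq!(decoded[1], (3, 0.8));
    }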

    #[test]
    fn test_sparse_vector_config() {
        // Test default config
        let default = SparseVectorConfig::default();
        assert_eq!(default.index_size, IndexSize::U32);
        assert_eq!(default.weight_quantization, WeightQuantization::Float32);
        assert_eq!(default.bytes_per_entry(), 8.0); // 4 + 4

        // Test SPLADE config
        let splade = SparseVectorConfig::splade();
        assert_eq!(splade.index_size, IndexSize::U16);
        assert_eq!(splade.weight_quantization, WeightQuantization::UInt8);
        assert_eq!(splade.bytes_per_entry(), 3.0); // 2 + 1

        // Test compact config
        let compact = SparseVectorConfig::compact();
        assert_eq!(compact.index_size, IndexSize::U16);
        assert_eq!(compact.weight_quantization, WeightQuantization::UInt4);
        assert_eq!(compact.bytes_per_entry(), 2.5); // 2 + 0.5

        // Test serialization roundtrip
        let byte = splade.to_byte();
        let restored = SparseVectorConfig::from_byte(byte).unwrap();
        assert_eq!(restored, splade);
    }
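
    // A minimal sketch of attaching query-time settings via the builder;
    // heap_factor < 1.0 marks the query as approximate (SEISMIC-style).
    #[test]
    fn test_query_config_builder() {
        let config = SparseVectorConfig::splade().with_query_config(SparseQueryConfig {
            heap_factor: 0.8,
            max_query_dims: Some(10),
            ..SparseQueryConfig::default()
        });

        let qc = config.query_config.unwrap();
        assert!((qc.heap_factor - 0.8).abs() < 1e-6);
        assert_eq!(qc.max_query_dims, Some(10));
        assert_eq!(qc.weighting, QueryWeighting::One);
    }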

    #[test]
    fn test_index_size() {
        assert_eq!(IndexSize::U16.bytes(), 2);
        assert_eq!(IndexSize::U32.bytes(), 4);
        assert_eq!(IndexSize::U16.max_value(), 65535);
        assert_eq!(IndexSize::U32.max_value(), u32::MAX);
    }
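
    // Sketch of weight-threshold filtering as it applies at index time:
    // entries with |weight| below the threshold are dropped.
    #[test]
    fn test_filter_by_weight() {
        let v = SparseVector::from_entries(&[0, 1, 2], &[0.05, -0.5, 0.3]);
        let filtered = v.filter_by_weight(0.1);
        assert_eq!(filtered.len(), 2); // keeps dims 1 and 2
    }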

    #[test]
    fn test_block_max_weight() {
        // Create postings with known max weights per block
        // Block 0: docs 0-127, weights 0.0-12.7, max = 12.7
        // Block 1: docs 128-255, weights 12.8-25.5, max = 25.5
        // Block 2: docs 256-299, weights 25.6-29.9, max = 29.9
        let postings: Vec<(DocId, f32)> =
            (0..300).map(|i| (i as DocId, (i as f32) * 0.1)).collect();

        let pl =
            BlockSparsePostingList::from_postings(&postings, WeightQuantization::Float32).unwrap();

        // Verify global max weight
        assert!((pl.global_max_weight() - 29.9).abs() < 0.01);

        // Verify block max weights
        assert!(pl.num_blocks() >= 3);

        // Block 0: max should be around 12.7 (index 127 * 0.1)
        let block0_max = pl.block_max_weight(0).unwrap();
        assert!((block0_max - 12.7).abs() < 0.01);

        // Block 1: max should be around 25.5 (index 255 * 0.1)
        let block1_max = pl.block_max_weight(1).unwrap();
        assert!((block1_max - 25.5).abs() < 0.01);

        // Block 2: max should be around 29.9 (index 299 * 0.1)
        let block2_max = pl.block_max_weight(2).unwrap();
        assert!((block2_max - 29.9).abs() < 0.01);

        // Test max_contribution
        let query_weight = 2.0;
        assert!((pl.max_contribution(query_weight) - 59.8).abs() < 0.1);

        // Test iterator block_max methods
        let mut iter = pl.iterator();
        assert!((iter.current_block_max_weight() - 12.7).abs() < 0.01);
        assert!((iter.current_block_max_contribution(query_weight) - 25.4).abs() < 0.1);

        // Seek to block 1
        iter.seek(128);
        assert!((iter.current_block_max_weight() - 25.5).abs() < 0.01);
    }

    #[test]
    fn test_sparse_skip_list_serialization() {
        let mut skip_list = SparseSkipList::new();
        skip_list.push(0, 127, 0, 12.7);
        skip_list.push(128, 255, 100, 25.5);
        skip_list.push(256, 299, 200, 29.9);

        assert_eq!(skip_list.len(), 3);
        assert!((skip_list.global_max_weight() - 29.9).abs() < 0.01);

        // Serialize
        let mut buffer = Vec::new();
        skip_list.write(&mut buffer).unwrap();

        // Deserialize
        let restored = SparseSkipList::read(&mut buffer.as_slice()).unwrap();

        assert_eq!(restored.len(), 3);
        assert!((restored.global_max_weight() - 29.9).abs() < 0.01);

        // Verify entries
        let e0 = restored.get(0).unwrap();
        assert_eq!(e0.first_doc, 0);
        assert_eq!(e0.last_doc, 127);
        assert!((e0.max_weight - 12.7).abs() < 0.01);

        let e1 = restored.get(1).unwrap();
        assert_eq!(e1.first_doc, 128);
        assert!((e1.max_weight - 25.5).abs() < 0.01);
    }
}