hermes_core/structures/postings/sparse/config.rs

//! Configuration types for sparse vector posting lists

use serde::{Deserialize, Serialize};

/// Size of the index (term/dimension ID) in sparse vectors
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum IndexSize {
    /// 16-bit index (0-65535), ideal for SPLADE vocabularies
    U16 = 0,
    /// 32-bit index (0-4B), for large vocabularies
    #[default]
    U32 = 1,
}

impl IndexSize {
    /// Bytes per index
    pub fn bytes(&self) -> usize {
        match self {
            IndexSize::U16 => 2,
            IndexSize::U32 => 4,
        }
    }

    /// Maximum value representable
    pub fn max_value(&self) -> u32 {
        match self {
            IndexSize::U16 => u16::MAX as u32,
            IndexSize::U32 => u32::MAX,
        }
    }

    pub(crate) fn from_u8(v: u8) -> Option<Self> {
        match v {
            0 => Some(IndexSize::U16),
            1 => Some(IndexSize::U32),
            _ => None,
        }
    }
}

/// Quantization format for sparse vector weights
///
/// Research-validated compression/effectiveness trade-offs (Pati, 2025):
/// - **UInt8**: 4x compression, ~1-2% nDCG@10 loss (RECOMMENDED for production)
/// - **Float16**: 2x compression, <1% nDCG@10 loss
/// - **Float32**: No compression, baseline effectiveness
/// - **UInt4**: 8x compression, ~3-5% nDCG@10 loss (experimental)
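///
/// # Quantization sketch
///
/// A minimal illustration of the scale-factor scheme described for `UInt8`
/// above (the actual codec lives elsewhere; the per-vector-max scaling and
/// the numbers here are illustrative assumptions, not this crate's API):
///
/// ```
/// let max_w = 2.5f32;                      // max abs weight in the vector
/// let scale = max_w / 255.0;               // one scale factor per vector
/// let q = (1.3f32 / scale).round() as u8;  // quantize: 133
/// let w = q as f32 * scale;                // dequantize: ~1.304
/// assert!((w - 1.3).abs() <= scale);       // error bounded by one step
/// ```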
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum WeightQuantization {
    /// Full 32-bit float precision
    #[default]
    Float32 = 0,
    /// 16-bit float (half precision) - 2x compression, <1% effectiveness loss
    Float16 = 1,
    /// 8-bit unsigned integer with scale factor - 4x compression, ~1-2% effectiveness loss (RECOMMENDED)
    UInt8 = 2,
    /// 4-bit unsigned integer with scale factor (packed, 2 per byte) - 8x compression, ~3-5% effectiveness loss
    UInt4 = 3,
}

impl WeightQuantization {
    /// Bytes per weight (approximate for UInt4)
    pub fn bytes_per_weight(&self) -> f32 {
        match self {
            WeightQuantization::Float32 => 4.0,
            WeightQuantization::Float16 => 2.0,
            WeightQuantization::UInt8 => 1.0,
            WeightQuantization::UInt4 => 0.5,
        }
    }

    pub(crate) fn from_u8(v: u8) -> Option<Self> {
        match v {
            0 => Some(WeightQuantization::Float32),
            1 => Some(WeightQuantization::Float16),
            2 => Some(WeightQuantization::UInt8),
            3 => Some(WeightQuantization::UInt4),
            _ => None,
        }
    }
}

/// Query-time weighting strategy for sparse vector queries
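///
/// The `Idf` variant weights each dimension by `ln(N/df)` (see below), so
/// rarer dimensions contribute more. A toy calculation:
///
/// ```
/// let (n, df) = (1_000_000f64, 1_000f64); // corpus size, docs containing dim
/// let idf = (n / df).ln();                // ln(1000)
/// assert!((idf - 6.9078).abs() < 1e-3);
/// ```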
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QueryWeighting {
    /// All terms get weight 1.0
    #[default]
    One,
    /// Terms weighted by IDF (inverse document frequency) from global index statistics
    /// Uses ln(N/df) where N = total docs, df = docs containing dimension
    Idf,
    /// Terms weighted by pre-computed IDF from model's idf.json file
    /// Loaded from HuggingFace model repo. No fallback to global stats.
    IdfFile,
}

/// Query-time configuration for sparse vectors
///
/// Research-validated query optimization strategies:
/// - **weight_threshold (0.01-0.05)**: Drop query dimensions with weight below threshold
///   - Filters low-IDF tokens that add latency without improving relevance
/// - **max_query_dims (10-20)**: Process only top-k dimensions by weight
///   - 30-50% latency reduction with <2% nDCG loss (Qiao et al., 2023)
/// - **heap_factor (0.8)**: Skip blocks with low max score contribution
///   - ~20% speedup with minor recall loss (SEISMIC-style)
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseQueryConfig {
    /// HuggingFace tokenizer path/name for query-time tokenization
    /// Example: "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokenizer: Option<String>,
    /// Weighting strategy for tokenized query terms
    #[serde(default)]
    pub weighting: QueryWeighting,
    /// Heap factor for approximate search (SEISMIC-style optimization)
    /// A block is skipped if its max possible score < heap_factor * threshold
    ///
    /// Research recommendation:
    /// - 1.0 = exact search (default)
    /// - 0.8 = approximate, ~20% faster with minor recall loss (RECOMMENDED for production)
    /// - 0.5 = very approximate, much faster but higher recall loss
    #[serde(default = "default_heap_factor")]
    pub heap_factor: f32,
    /// Minimum weight for query dimensions (query-time pruning)
    /// Dimensions with abs(weight) below this threshold are dropped before search.
    /// Useful for filtering low-IDF tokens that add latency without improving relevance.
    ///
    /// - 0.0 = no filtering (default)
    /// - 0.01-0.05 = recommended for SPLADE/learned sparse models
    #[serde(default)]
    pub weight_threshold: f32,
    /// Maximum number of query dimensions to process (query pruning)
    /// Processes only the top-k dimensions by weight
    ///
    /// Research recommendation (multiple papers, 2022-2024):
    /// - None = process all dimensions (default, exact)
    /// - Some(10-20) = process top 10-20 dimensions only (RECOMMENDED for SPLADE)
    ///   - 30-50% latency reduction
    ///   - <2% nDCG@10 loss
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_query_dims: Option<usize>,
    /// Fraction of query dimensions to keep (0.0-1.0), same semantics as
    /// indexing-time `pruning`: sort by abs(weight) descending,
    /// keep top fraction. None or 1.0 = no pruning.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pruning: Option<f32>,
}

fn default_heap_factor() -> f32 {
    1.0
}

impl Default for SparseQueryConfig {
    fn default() -> Self {
        Self {
            tokenizer: None,
            weighting: QueryWeighting::One,
            heap_factor: 1.0,
            weight_threshold: 0.0,
            max_query_dims: None,
            pruning: None,
        }
    }
}

/// Configuration for sparse vector storage
///
/// Research-validated optimizations for learned sparse retrieval (SPLADE, uniCOIL, etc.):
/// - **Weight threshold (0.01-0.05)**: Removes ~30-50% of postings with minimal nDCG impact
/// - **Posting list pruning (0.1)**: Keeps top 10% per dimension, 50-70% index reduction, <1% nDCG loss
/// - **Query pruning (top 10-20 dims)**: 30-50% latency reduction, <2% nDCG loss
/// - **UInt8 quantization**: 4x compression, 1-2% nDCG loss (optimal trade-off)
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseVectorConfig {
    /// Size of dimension/term indices
    pub index_size: IndexSize,
    /// Quantization for weights (see WeightQuantization docs for trade-offs)
    pub weight_quantization: WeightQuantization,
    /// Minimum weight threshold - weights below this value are not indexed
    ///
    /// Research recommendation (Guo et al., 2022; SPLADE v2):
    /// - 0.01-0.05 for SPLADE models removes ~30-50% of postings
    /// - Minimal impact on nDCG@10 (<1% loss)
    /// - Major reduction in index size and query latency
    #[serde(default)]
    pub weight_threshold: f32,
    /// Block size for posting lists (must be a power of 2; default 128 for SIMD)
    /// Larger blocks = better compression, smaller blocks = faster seeks
    #[serde(default = "default_block_size")]
    pub block_size: usize,
    /// Static pruning: fraction of postings to keep per inverted list (SEISMIC-style)
    /// Lists are sorted by weight descending and truncated to the top fraction.
    ///
    /// Research recommendation (SPLADE v2, Formal et al., 2021):
    /// - None = keep all postings (default, exact)
    /// - Some(0.1) = keep top 10% of postings per dimension
    ///   - 50-70% index size reduction
    ///   - <1% nDCG@10 loss
    ///   - Exploits "concentration of importance" in learned representations
    ///
    /// Applied only during initial segment build, not during merge.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pruning: Option<f32>,
    /// Query-time configuration (tokenizer, weighting)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub query_config: Option<SparseQueryConfig>,
}

fn default_block_size() -> usize {
    128
}

impl Default for SparseVectorConfig {
    fn default() -> Self {
        Self {
            index_size: IndexSize::U32,
            weight_quantization: WeightQuantization::Float32,
            weight_threshold: 0.0,
            block_size: 128,
            pruning: None,
            query_config: None,
        }
    }
}

impl SparseVectorConfig {
    /// SPLADE-optimized config with research-validated defaults
    ///
    /// Optimized for SPLADE, uniCOIL, and similar learned sparse retrieval models.
    /// Based on research findings from:
    /// - Pati (2025): UInt8 quantization = 4x compression, 1-2% nDCG loss
    /// - Formal et al. (2021): SPLADE v2 posting list pruning
    /// - Qiao et al. (2023): Query dimension pruning and approximate search
    /// - Guo et al. (2022): Weight thresholding for efficiency
    ///
    /// Expected performance vs. the full-precision baseline:
    /// - Index size: ~15-25% of original (combined effect of all optimizations)
    /// - Query latency: 40-60% faster
    /// - Effectiveness: 2-4% nDCG@10 loss (typically acceptable for production)
    ///
    /// Vocabulary: ~30K dimensions (fits in u16)
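    ///
    /// # Example
    ///
    /// A minimal usage sketch (crate path assumed from this file's
    /// location; adjust to the crate's actual re-exports):
    ///
    /// ```ignore
    /// use hermes_core::structures::postings::sparse::config::SparseVectorConfig;
    ///
    /// let config = SparseVectorConfig::splade();
    /// assert_eq!(config.pruning, Some(0.1));  // top 10% per dimension
    /// assert!(config.query_config.is_some()); // query-time pruning enabled
    /// ```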
    pub fn splade() -> Self {
        Self {
            index_size: IndexSize::U16,
            weight_quantization: WeightQuantization::UInt8,
            weight_threshold: 0.01, // Remove ~30-50% of low-weight postings
            block_size: 128,
            pruning: Some(0.1), // Keep top 10% per dimension
            query_config: Some(SparseQueryConfig {
                tokenizer: None,
                weighting: QueryWeighting::One,
                heap_factor: 0.8,         // 20% faster approximate search
                weight_threshold: 0.01,   // Drop low-IDF query tokens
                max_query_dims: Some(20), // Process top 20 query dimensions
                pruning: Some(0.1),       // Keep top 10% of query dims
            }),
        }
    }

    /// Compact config: Maximum compression (experimental)
    ///
    /// Uses aggressive UInt4 quantization for the smallest possible index size.
    /// Expected trade-offs:
    /// - Index size: ~10-15% of Float32 baseline
    /// - Effectiveness: ~3-5% nDCG@10 loss
    ///
    /// Recommended for: memory-constrained environments, cache-heavy workloads
    pub fn compact() -> Self {
        Self {
            index_size: IndexSize::U16,
            weight_quantization: WeightQuantization::UInt4,
            weight_threshold: 0.02, // Slightly higher threshold for UInt4
            block_size: 128,
            pruning: Some(0.15), // Keep top 15% per dimension
            query_config: Some(SparseQueryConfig {
                tokenizer: None,
                weighting: QueryWeighting::One,
                heap_factor: 0.7,         // More aggressive approximate search
                weight_threshold: 0.02,   // Drop low-IDF query tokens
                max_query_dims: Some(15), // Fewer query dimensions
                pruning: Some(0.15),      // Keep top 15% of query dims
            }),
        }
    }

    /// Full-precision config: No compression, baseline effectiveness
    ///
    /// Use for: research baselines, when effectiveness is critical
    pub fn full_precision() -> Self {
        Self {
            index_size: IndexSize::U32,
            weight_quantization: WeightQuantization::Float32,
            weight_threshold: 0.0,
            block_size: 128,
            pruning: None,
            query_config: None,
        }
    }

    /// Conservative config: Mild optimizations, minimal effectiveness loss
    ///
    /// Balances compression and effectiveness with conservative defaults.
    /// Expected trade-offs:
    /// - Index size: ~40-50% of Float32 baseline
    /// - Query latency: ~20-30% faster
    /// - Effectiveness: <1% nDCG@10 loss
    ///
    /// Recommended for: production deployments prioritizing effectiveness
    pub fn conservative() -> Self {
        Self {
            index_size: IndexSize::U32,
            weight_quantization: WeightQuantization::Float16,
            weight_threshold: 0.005, // Minimal pruning
            block_size: 128,
            pruning: None, // No posting list pruning
            query_config: Some(SparseQueryConfig {
                tokenizer: None,
                weighting: QueryWeighting::One,
                heap_factor: 0.9,         // Nearly exact search
                weight_threshold: 0.005,  // Minimal query pruning
                max_query_dims: Some(50), // Process more dimensions
                pruning: None,            // No fraction-based pruning
            }),
        }
    }

    /// Set weight threshold (builder pattern)
    pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
        self.weight_threshold = threshold;
        self
    }

    /// Set posting list pruning fraction (builder pattern)
    /// e.g., 0.1 = keep top 10% of postings per dimension
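    ///
    /// # Example
    ///
    /// Builders compose, e.g. to tighten a preset (path assumed, as above):
    ///
    /// ```ignore
    /// use hermes_core::structures::postings::sparse::config::SparseVectorConfig;
    ///
    /// let config = SparseVectorConfig::conservative()
    ///     .with_weight_threshold(0.01)
    ///     .with_pruning(0.2); // clamped to [0.0, 1.0]; keeps the top 20%
    /// ```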
    pub fn with_pruning(mut self, fraction: f32) -> Self {
        self.pruning = Some(fraction.clamp(0.0, 1.0));
        self
    }

    /// Bytes per entry (index + weight)
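    ///
    /// For example, `splade()` (u16 index + u8 weight) costs 2 + 1 = 3.0
    /// bytes per entry, while `full_precision()` (u32 + f32) costs 8.0:
    ///
    /// ```ignore
    /// // Path assumed from this file's location.
    /// use hermes_core::structures::postings::sparse::config::SparseVectorConfig;
    ///
    /// assert_eq!(SparseVectorConfig::splade().bytes_per_entry(), 3.0);
    /// assert_eq!(SparseVectorConfig::full_precision().bytes_per_entry(), 8.0);
    /// ```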
    pub fn bytes_per_entry(&self) -> f32 {
        self.index_size.bytes() as f32 + self.weight_quantization.bytes_per_weight()
    }

    /// Serialize config to a single byte
    pub fn to_byte(&self) -> u8 {
        ((self.index_size as u8) << 4) | (self.weight_quantization as u8)
    }

    /// Deserialize config from a single byte
    /// Note: weight_threshold, block_size, pruning, and query_config are not
    /// stored in the byte and revert to their defaults
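    ///
    /// # Example
    ///
    /// Round-trip sketch (path assumed): only `index_size` and
    /// `weight_quantization` survive; the other fields revert to defaults.
    ///
    /// ```ignore
    /// use hermes_core::structures::postings::sparse::config::SparseVectorConfig;
    ///
    /// let config = SparseVectorConfig::splade();
    /// let restored = SparseVectorConfig::from_byte(config.to_byte()).unwrap();
    /// assert_eq!(restored.index_size, config.index_size);
    /// assert_eq!(restored.weight_quantization, config.weight_quantization);
    /// ```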
    pub fn from_byte(b: u8) -> Option<Self> {
        let index_size = IndexSize::from_u8(b >> 4)?;
        let weight_quantization = WeightQuantization::from_u8(b & 0x0F)?;
        Some(Self {
            index_size,
            weight_quantization,
            weight_threshold: 0.0,
            block_size: 128,
            pruning: None,
            query_config: None,
        })
    }

    /// Set block size (builder pattern)
    /// Must be a power of 2 (other values are rounded up); recommended: 64, 128, 256
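    ///
    /// ```ignore
    /// // Non-powers of two are rounded up (path assumed).
    /// use hermes_core::structures::postings::sparse::config::SparseVectorConfig;
    ///
    /// let config = SparseVectorConfig::default().with_block_size(100);
    /// assert_eq!(config.block_size, 128);
    /// ```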
    pub fn with_block_size(mut self, size: usize) -> Self {
        self.block_size = size.next_power_of_two();
        self
    }

    /// Set query configuration (builder pattern)
    pub fn with_query_config(mut self, config: SparseQueryConfig) -> Self {
        self.query_config = Some(config);
        self
    }
}

/// A sparse vector entry: (dimension_id, weight)
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseEntry {
    pub dim_id: u32,
    pub weight: f32,
}

/// Sparse vector representation
#[derive(Debug, Clone, Default)]
pub struct SparseVector {
    pub(super) entries: Vec<SparseEntry>,
}

impl SparseVector {
    /// Create a new sparse vector
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
        }
    }

    /// Create with pre-allocated capacity
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            entries: Vec::with_capacity(capacity),
        }
    }

    /// Create from dimension IDs and weights
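    ///
    /// Input need not be sorted; entries are sorted by dimension ID on
    /// construction. Sketch (path assumed):
    ///
    /// ```ignore
    /// use hermes_core::structures::postings::sparse::config::SparseVector;
    ///
    /// let v = SparseVector::from_entries(&[42, 7], &[0.5, 1.5]);
    /// let pairs: Vec<(u32, f32)> = v.into();
    /// assert_eq!(pairs, vec![(7, 1.5), (42, 0.5)]);
    /// ```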
    pub fn from_entries(dim_ids: &[u32], weights: &[f32]) -> Self {
        assert_eq!(dim_ids.len(), weights.len());
        let mut entries: Vec<SparseEntry> = dim_ids
            .iter()
            .zip(weights.iter())
            .map(|(&dim_id, &weight)| SparseEntry { dim_id, weight })
            .collect();
        // Sort by dimension ID for efficient intersection
        entries.sort_by_key(|e| e.dim_id);
        Self { entries }
    }

    /// Add an entry (must maintain sorted order by dim_id)
    pub fn push(&mut self, dim_id: u32, weight: f32) {
        debug_assert!(
            self.entries.is_empty() || self.entries.last().unwrap().dim_id < dim_id,
            "Entries must be added in sorted order by dim_id"
        );
        self.entries.push(SparseEntry { dim_id, weight });
    }

    /// Number of non-zero entries
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Check if empty
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Iterate over entries
    pub fn iter(&self) -> impl Iterator<Item = &SparseEntry> {
        self.entries.iter()
    }

    /// Sort by dimension ID (required for posting list encoding)
    pub fn sort_by_dim(&mut self) {
        self.entries.sort_by_key(|e| e.dim_id);
    }

    /// Sort by weight descending
    pub fn sort_by_weight_desc(&mut self) {
        self.entries.sort_by(|a, b| {
            b.weight
                .partial_cmp(&a.weight)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
    }

    /// Get top-k entries by weight
    pub fn top_k(&self, k: usize) -> Vec<SparseEntry> {
        let mut sorted = self.entries.clone();
        sorted.sort_by(|a, b| {
            b.weight
                .partial_cmp(&a.weight)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        sorted.truncate(k);
        sorted
    }

    /// Compute dot product with another sparse vector
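    ///
    /// Both vectors must be sorted by `dim_id` (as `from_entries`
    /// guarantees); the sorted merge below then visits each entry once,
    /// in O(len(a) + len(b)). Sketch (path assumed):
    ///
    /// ```ignore
    /// use hermes_core::structures::postings::sparse::config::SparseVector;
    ///
    /// let a = SparseVector::from_entries(&[1, 3, 5], &[1.0, 2.0, 3.0]);
    /// let b = SparseVector::from_entries(&[3, 5, 9], &[4.0, 5.0, 6.0]);
    /// // Only dims 3 and 5 overlap: 2.0 * 4.0 + 3.0 * 5.0 = 23.0
    /// assert_eq!(a.dot(&b), 23.0);
    /// ```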
    pub fn dot(&self, other: &SparseVector) -> f32 {
        let mut result = 0.0f32;
        let mut i = 0;
        let mut j = 0;

        // Sorted-merge intersection: advance the list with the smaller
        // dim_id; multiply weights when the dimensions match.
        while i < self.entries.len() && j < other.entries.len() {
            let a = &self.entries[i];
            let b = &other.entries[j];

            match a.dim_id.cmp(&b.dim_id) {
                std::cmp::Ordering::Less => i += 1,
                std::cmp::Ordering::Greater => j += 1,
                std::cmp::Ordering::Equal => {
                    result += a.weight * b.weight;
                    i += 1;
                    j += 1;
                }
            }
        }

        result
    }

    /// L2 norm squared
    pub fn norm_squared(&self) -> f32 {
        self.entries.iter().map(|e| e.weight * e.weight).sum()
    }

    /// L2 norm
    pub fn norm(&self) -> f32 {
        self.norm_squared().sqrt()
    }

    /// Prune dimensions below a weight threshold
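    ///
    /// Keeps entries with `abs(weight) >= min_weight`. Sketch (path assumed):
    ///
    /// ```ignore
    /// use hermes_core::structures::postings::sparse::config::SparseVector;
    ///
    /// let v = SparseVector::from_entries(&[1, 2, 3], &[0.005, -0.5, 0.02]);
    /// assert_eq!(v.filter_by_weight(0.01).len(), 2);
    /// ```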
    pub fn filter_by_weight(&self, min_weight: f32) -> Self {
        let entries: Vec<SparseEntry> = self
            .entries
            .iter()
            .filter(|e| e.weight.abs() >= min_weight)
            .copied()
            .collect();
        Self { entries }
    }
}

impl From<Vec<(u32, f32)>> for SparseVector {
    fn from(pairs: Vec<(u32, f32)>) -> Self {
        let mut entries: Vec<SparseEntry> = pairs
            .into_iter()
            .map(|(dim_id, weight)| SparseEntry { dim_id, weight })
            .collect();
        // Sort by dimension ID to uphold the sorted-order invariant that
        // `dot` and posting list encoding rely on (matches `from_entries`).
        entries.sort_by_key(|e| e.dim_id);
        Self { entries }
    }
}

impl From<SparseVector> for Vec<(u32, f32)> {
    fn from(vec: SparseVector) -> Self {
        vec.entries
            .into_iter()
            .map(|e| (e.dim_id, e.weight))
            .collect()
    }
}