// hermes_core/structures/postings/sparse/config.rs

1//! Configuration types for sparse vector posting lists
2
3use serde::{Deserialize, Serialize};
4
/// Size of the index (term/dimension ID) in sparse vectors.
///
/// `repr(u8)` with explicit discriminants: the values 0/1 are packed into the
/// high nibble of the config byte (see `SparseVectorConfig::to_byte`), so they
/// are part of the serialized format and must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum IndexSize {
    /// 16-bit index (0-65535), ideal for SPLADE vocabularies
    U16 = 0,
    /// 32-bit index (0-4B), for large vocabularies
    #[default]
    U32 = 1,
}
15
16impl IndexSize {
17    /// Bytes per index
18    pub fn bytes(&self) -> usize {
19        match self {
20            IndexSize::U16 => 2,
21            IndexSize::U32 => 4,
22        }
23    }
24
25    /// Maximum value representable
26    pub fn max_value(&self) -> u32 {
27        match self {
28            IndexSize::U16 => u16::MAX as u32,
29            IndexSize::U32 => u32::MAX,
30        }
31    }
32
33    pub(crate) fn from_u8(v: u8) -> Option<Self> {
34        match v {
35            0 => Some(IndexSize::U16),
36            1 => Some(IndexSize::U32),
37            _ => None,
38        }
39    }
40}
41
/// Quantization format for sparse vector weights.
///
/// `repr(u8)` with explicit discriminants: the values 0-3 are packed into the
/// low nibble of the config byte (see `SparseVectorConfig::to_byte`), so they
/// are part of the serialized format and must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum WeightQuantization {
    /// Full 32-bit float precision
    #[default]
    Float32 = 0,
    /// 16-bit float (half precision)
    Float16 = 1,
    /// 8-bit unsigned integer with scale factor
    UInt8 = 2,
    /// 4-bit unsigned integer with scale factor (packed, 2 per byte)
    UInt4 = 3,
}
56
57impl WeightQuantization {
58    /// Bytes per weight (approximate for UInt4)
59    pub fn bytes_per_weight(&self) -> f32 {
60        match self {
61            WeightQuantization::Float32 => 4.0,
62            WeightQuantization::Float16 => 2.0,
63            WeightQuantization::UInt8 => 1.0,
64            WeightQuantization::UInt4 => 0.5,
65        }
66    }
67
68    pub(crate) fn from_u8(v: u8) -> Option<Self> {
69        match v {
70            0 => Some(WeightQuantization::Float32),
71            1 => Some(WeightQuantization::Float16),
72            2 => Some(WeightQuantization::UInt8),
73            3 => Some(WeightQuantization::UInt4),
74            _ => None,
75        }
76    }
77}
78
/// Query-time weighting strategy for sparse vector queries.
///
/// Selected via `SparseQueryConfig::weighting`; applies to terms produced by
/// query-time tokenization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum QueryWeighting {
    /// All terms get weight 1.0
    #[default]
    One,
    /// Terms weighted by IDF (inverse document frequency) from the index
    Idf,
}
88
/// Query-time configuration for sparse vectors.
///
/// All fields have serde defaults, so a missing or empty JSON object
/// deserializes to the same values as `Default::default()`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseQueryConfig {
    /// HuggingFace tokenizer path/name for query-time tokenization
    /// Example: "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokenizer: Option<String>,
    /// Weighting strategy for tokenized query terms
    #[serde(default)]
    pub weighting: QueryWeighting,
    /// Heap factor for approximate search (SEISMIC-style optimization)
    /// A block is skipped if its max possible score < heap_factor * threshold
    /// - 1.0 = exact search (default)
    /// - 0.8 = approximate, ~20% faster with minor recall loss
    /// - 0.5 = very approximate, much faster
    #[serde(default = "default_heap_factor")]
    pub heap_factor: f32,
    /// Maximum number of query dimensions to process (query pruning)
    /// Processes only the top-k dimensions by weight
    /// - None = process all dimensions (default)
    /// - Some(10) = process top 10 dimensions only
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_query_dims: Option<usize>,
}
113
/// Serde default for `SparseQueryConfig::heap_factor`: 1.0 (exact search).
fn default_heap_factor() -> f32 {
    1.0_f32
}
117
118impl Default for SparseQueryConfig {
119    fn default() -> Self {
120        Self {
121            tokenizer: None,
122            weighting: QueryWeighting::One,
123            heap_factor: 1.0,
124            max_query_dims: None,
125        }
126    }
127}
128
/// Configuration for sparse vector storage.
///
/// Note: only `index_size` and `weight_quantization` are preserved by the
/// single-byte `to_byte`/`from_byte` round trip; the remaining fields are
/// serialized via serde only.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseVectorConfig {
    /// Size of dimension/term indices
    pub index_size: IndexSize,
    /// Quantization for weights
    pub weight_quantization: WeightQuantization,
    /// Minimum weight threshold - weights below this value are not indexed
    /// This reduces index size and can improve query speed at the cost of recall
    #[serde(default)]
    pub weight_threshold: f32,
    /// Block size for posting lists (must be power of 2, default 128 for SIMD)
    /// Larger blocks = better compression, smaller blocks = faster seeks
    #[serde(default = "default_block_size")]
    pub block_size: usize,
    /// Static pruning: fraction of postings to keep per inverted list (SEISMIC-style)
    /// Lists are sorted by weight descending and truncated to top fraction.
    /// - None = keep all postings (default, exact)
    /// - Some(0.1) = keep top 10% of postings per dimension
    ///
    /// Applied only during initial segment build, not during merge.
    /// This exploits "concentration of importance" - top entries preserve most of inner product.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub posting_list_pruning: Option<f32>,
    /// Query-time configuration (tokenizer, weighting)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub query_config: Option<SparseQueryConfig>,
}
157
/// Serde default for `SparseVectorConfig::block_size`: 128 (SIMD-friendly).
fn default_block_size() -> usize {
    128_usize
}
161
162impl Default for SparseVectorConfig {
163    fn default() -> Self {
164        Self {
165            index_size: IndexSize::U32,
166            weight_quantization: WeightQuantization::Float32,
167            weight_threshold: 0.0,
168            block_size: 128,
169            posting_list_pruning: None,
170            query_config: None,
171        }
172    }
173}
174
175impl SparseVectorConfig {
176    /// SPLADE-optimized config: u16 indices, int8 weights
177    pub fn splade() -> Self {
178        Self {
179            index_size: IndexSize::U16,
180            weight_quantization: WeightQuantization::UInt8,
181            weight_threshold: 0.0,
182            block_size: 128,
183            posting_list_pruning: None,
184            query_config: None,
185        }
186    }
187
188    /// Compact config: u16 indices, 4-bit weights
189    pub fn compact() -> Self {
190        Self {
191            index_size: IndexSize::U16,
192            weight_quantization: WeightQuantization::UInt4,
193            weight_threshold: 0.0,
194            block_size: 128,
195            posting_list_pruning: None,
196            query_config: None,
197        }
198    }
199
200    /// Full precision config
201    pub fn full_precision() -> Self {
202        Self {
203            index_size: IndexSize::U32,
204            weight_quantization: WeightQuantization::Float32,
205            weight_threshold: 0.0,
206            block_size: 128,
207            posting_list_pruning: None,
208            query_config: None,
209        }
210    }
211
212    /// Set weight threshold (builder pattern)
213    pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
214        self.weight_threshold = threshold;
215        self
216    }
217
218    /// Set posting list pruning fraction (builder pattern)
219    /// e.g., 0.1 = keep top 10% of postings per dimension
220    pub fn with_pruning(mut self, fraction: f32) -> Self {
221        self.posting_list_pruning = Some(fraction.clamp(0.0, 1.0));
222        self
223    }
224
225    /// Bytes per entry (index + weight)
226    pub fn bytes_per_entry(&self) -> f32 {
227        self.index_size.bytes() as f32 + self.weight_quantization.bytes_per_weight()
228    }
229
230    /// Serialize config to a single byte
231    pub fn to_byte(&self) -> u8 {
232        ((self.index_size as u8) << 4) | (self.weight_quantization as u8)
233    }
234
235    /// Deserialize config from a single byte
236    /// Note: weight_threshold, block_size and query_config are not serialized in the byte
237    pub fn from_byte(b: u8) -> Option<Self> {
238        let index_size = IndexSize::from_u8(b >> 4)?;
239        let weight_quantization = WeightQuantization::from_u8(b & 0x0F)?;
240        Some(Self {
241            index_size,
242            weight_quantization,
243            weight_threshold: 0.0,
244            block_size: 128,
245            posting_list_pruning: None,
246            query_config: None,
247        })
248    }
249
250    /// Set block size (builder pattern)
251    /// Must be power of 2, recommended: 64, 128, 256
252    pub fn with_block_size(mut self, size: usize) -> Self {
253        self.block_size = size.next_power_of_two();
254        self
255    }
256
257    /// Set query configuration (builder pattern)
258    pub fn with_query_config(mut self, config: SparseQueryConfig) -> Self {
259        self.query_config = Some(config);
260        self
261    }
262}
263
/// A sparse vector entry: (dimension_id, weight).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseEntry {
    // Dimension/term ID; must fit the configured IndexSize when encoded.
    pub dim_id: u32,
    // Weight value; may be quantized on storage per WeightQuantization.
    pub weight: f32,
}
270
/// Sparse vector representation.
///
/// `entries` is expected to be sorted by `dim_id` ascending — `from_entries`,
/// `push`, and `sort_by_dim` maintain that order, and `dot` relies on it.
/// `pub(super)` so sibling posting-list modules can access the raw entries.
#[derive(Debug, Clone, Default)]
pub struct SparseVector {
    pub(super) entries: Vec<SparseEntry>,
}
276
277impl SparseVector {
278    /// Create a new sparse vector
279    pub fn new() -> Self {
280        Self {
281            entries: Vec::new(),
282        }
283    }
284
285    /// Create with pre-allocated capacity
286    pub fn with_capacity(capacity: usize) -> Self {
287        Self {
288            entries: Vec::with_capacity(capacity),
289        }
290    }
291
292    /// Create from dimension IDs and weights
293    pub fn from_entries(dim_ids: &[u32], weights: &[f32]) -> Self {
294        assert_eq!(dim_ids.len(), weights.len());
295        let mut entries: Vec<SparseEntry> = dim_ids
296            .iter()
297            .zip(weights.iter())
298            .map(|(&dim_id, &weight)| SparseEntry { dim_id, weight })
299            .collect();
300        // Sort by dimension ID for efficient intersection
301        entries.sort_by_key(|e| e.dim_id);
302        Self { entries }
303    }
304
305    /// Add an entry (must maintain sorted order by dim_id)
306    pub fn push(&mut self, dim_id: u32, weight: f32) {
307        debug_assert!(
308            self.entries.is_empty() || self.entries.last().unwrap().dim_id < dim_id,
309            "Entries must be added in sorted order by dim_id"
310        );
311        self.entries.push(SparseEntry { dim_id, weight });
312    }
313
314    /// Number of non-zero entries
315    pub fn len(&self) -> usize {
316        self.entries.len()
317    }
318
319    /// Check if empty
320    pub fn is_empty(&self) -> bool {
321        self.entries.is_empty()
322    }
323
324    /// Iterate over entries
325    pub fn iter(&self) -> impl Iterator<Item = &SparseEntry> {
326        self.entries.iter()
327    }
328
329    /// Sort by dimension ID (required for posting list encoding)
330    pub fn sort_by_dim(&mut self) {
331        self.entries.sort_by_key(|e| e.dim_id);
332    }
333
334    /// Sort by weight descending
335    pub fn sort_by_weight_desc(&mut self) {
336        self.entries.sort_by(|a, b| {
337            b.weight
338                .partial_cmp(&a.weight)
339                .unwrap_or(std::cmp::Ordering::Equal)
340        });
341    }
342
343    /// Get top-k entries by weight
344    pub fn top_k(&self, k: usize) -> Vec<SparseEntry> {
345        let mut sorted = self.entries.clone();
346        sorted.sort_by(|a, b| {
347            b.weight
348                .partial_cmp(&a.weight)
349                .unwrap_or(std::cmp::Ordering::Equal)
350        });
351        sorted.truncate(k);
352        sorted
353    }
354
355    /// Compute dot product with another sparse vector
356    pub fn dot(&self, other: &SparseVector) -> f32 {
357        let mut result = 0.0f32;
358        let mut i = 0;
359        let mut j = 0;
360
361        while i < self.entries.len() && j < other.entries.len() {
362            let a = &self.entries[i];
363            let b = &other.entries[j];
364
365            match a.dim_id.cmp(&b.dim_id) {
366                std::cmp::Ordering::Less => i += 1,
367                std::cmp::Ordering::Greater => j += 1,
368                std::cmp::Ordering::Equal => {
369                    result += a.weight * b.weight;
370                    i += 1;
371                    j += 1;
372                }
373            }
374        }
375
376        result
377    }
378
379    /// L2 norm squared
380    pub fn norm_squared(&self) -> f32 {
381        self.entries.iter().map(|e| e.weight * e.weight).sum()
382    }
383
384    /// L2 norm
385    pub fn norm(&self) -> f32 {
386        self.norm_squared().sqrt()
387    }
388
389    /// Prune dimensions below a weight threshold
390    pub fn filter_by_weight(&self, min_weight: f32) -> Self {
391        let entries: Vec<SparseEntry> = self
392            .entries
393            .iter()
394            .filter(|e| e.weight.abs() >= min_weight)
395            .cloned()
396            .collect();
397        Self { entries }
398    }
399}
400
401impl From<Vec<(u32, f32)>> for SparseVector {
402    fn from(pairs: Vec<(u32, f32)>) -> Self {
403        Self {
404            entries: pairs
405                .into_iter()
406                .map(|(dim_id, weight)| SparseEntry { dim_id, weight })
407                .collect(),
408        }
409    }
410}
411
412impl From<SparseVector> for Vec<(u32, f32)> {
413    fn from(vec: SparseVector) -> Self {
414        vec.entries
415            .into_iter()
416            .map(|e| (e.dim_id, e.weight))
417            .collect()
418    }
419}