hermes_core/structures/postings/sparse/
config.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
7#[repr(u8)]
8pub enum IndexSize {
9 U16 = 0,
11 #[default]
13 U32 = 1,
14}
15
16impl IndexSize {
17 pub fn bytes(&self) -> usize {
19 match self {
20 IndexSize::U16 => 2,
21 IndexSize::U32 => 4,
22 }
23 }
24
25 pub fn max_value(&self) -> u32 {
27 match self {
28 IndexSize::U16 => u16::MAX as u32,
29 IndexSize::U32 => u32::MAX,
30 }
31 }
32
33 pub(crate) fn from_u8(v: u8) -> Option<Self> {
34 match v {
35 0 => Some(IndexSize::U16),
36 1 => Some(IndexSize::U32),
37 _ => None,
38 }
39 }
40}
41
42#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
44#[repr(u8)]
45pub enum WeightQuantization {
46 #[default]
48 Float32 = 0,
49 Float16 = 1,
51 UInt8 = 2,
53 UInt4 = 3,
55}
56
57impl WeightQuantization {
58 pub fn bytes_per_weight(&self) -> f32 {
60 match self {
61 WeightQuantization::Float32 => 4.0,
62 WeightQuantization::Float16 => 2.0,
63 WeightQuantization::UInt8 => 1.0,
64 WeightQuantization::UInt4 => 0.5,
65 }
66 }
67
68 pub(crate) fn from_u8(v: u8) -> Option<Self> {
69 match v {
70 0 => Some(WeightQuantization::Float32),
71 1 => Some(WeightQuantization::Float16),
72 2 => Some(WeightQuantization::UInt8),
73 3 => Some(WeightQuantization::UInt4),
74 _ => None,
75 }
76 }
77}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
81pub enum QueryWeighting {
82 #[default]
84 One,
85 Idf,
87}
88
89#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
91pub struct SparseQueryConfig {
92 #[serde(default, skip_serializing_if = "Option::is_none")]
95 pub tokenizer: Option<String>,
96 #[serde(default)]
98 pub weighting: QueryWeighting,
99 #[serde(default = "default_heap_factor")]
105 pub heap_factor: f32,
106 #[serde(default, skip_serializing_if = "Option::is_none")]
111 pub max_query_dims: Option<usize>,
112}
113
/// Serde fallback for `SparseQueryConfig::heap_factor` when the field is absent.
fn default_heap_factor() -> f32 {
    const DEFAULT_HEAP_FACTOR: f32 = 1.0;
    DEFAULT_HEAP_FACTOR
}
117
118impl Default for SparseQueryConfig {
119 fn default() -> Self {
120 Self {
121 tokenizer: None,
122 weighting: QueryWeighting::One,
123 heap_factor: 1.0,
124 max_query_dims: None,
125 }
126 }
127}
128
129#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
131pub struct SparseVectorConfig {
132 pub index_size: IndexSize,
134 pub weight_quantization: WeightQuantization,
136 #[serde(default)]
139 pub weight_threshold: f32,
140 #[serde(default = "default_block_size")]
143 pub block_size: usize,
144 #[serde(default, skip_serializing_if = "Option::is_none")]
152 pub posting_list_pruning: Option<f32>,
153 #[serde(default, skip_serializing_if = "Option::is_none")]
155 pub query_config: Option<SparseQueryConfig>,
156}
157
/// Serde fallback for `SparseVectorConfig::block_size` when the field is absent.
fn default_block_size() -> usize {
    const DEFAULT_BLOCK_SIZE: usize = 128;
    DEFAULT_BLOCK_SIZE
}
161
162impl Default for SparseVectorConfig {
163 fn default() -> Self {
164 Self {
165 index_size: IndexSize::U32,
166 weight_quantization: WeightQuantization::Float32,
167 weight_threshold: 0.0,
168 block_size: 128,
169 posting_list_pruning: None,
170 query_config: None,
171 }
172 }
173}
174
175impl SparseVectorConfig {
176 pub fn splade() -> Self {
178 Self {
179 index_size: IndexSize::U16,
180 weight_quantization: WeightQuantization::UInt8,
181 weight_threshold: 0.0,
182 block_size: 128,
183 posting_list_pruning: None,
184 query_config: None,
185 }
186 }
187
188 pub fn compact() -> Self {
190 Self {
191 index_size: IndexSize::U16,
192 weight_quantization: WeightQuantization::UInt4,
193 weight_threshold: 0.0,
194 block_size: 128,
195 posting_list_pruning: None,
196 query_config: None,
197 }
198 }
199
200 pub fn full_precision() -> Self {
202 Self {
203 index_size: IndexSize::U32,
204 weight_quantization: WeightQuantization::Float32,
205 weight_threshold: 0.0,
206 block_size: 128,
207 posting_list_pruning: None,
208 query_config: None,
209 }
210 }
211
212 pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
214 self.weight_threshold = threshold;
215 self
216 }
217
218 pub fn with_pruning(mut self, fraction: f32) -> Self {
221 self.posting_list_pruning = Some(fraction.clamp(0.0, 1.0));
222 self
223 }
224
225 pub fn bytes_per_entry(&self) -> f32 {
227 self.index_size.bytes() as f32 + self.weight_quantization.bytes_per_weight()
228 }
229
230 pub fn to_byte(&self) -> u8 {
232 ((self.index_size as u8) << 4) | (self.weight_quantization as u8)
233 }
234
235 pub fn from_byte(b: u8) -> Option<Self> {
238 let index_size = IndexSize::from_u8(b >> 4)?;
239 let weight_quantization = WeightQuantization::from_u8(b & 0x0F)?;
240 Some(Self {
241 index_size,
242 weight_quantization,
243 weight_threshold: 0.0,
244 block_size: 128,
245 posting_list_pruning: None,
246 query_config: None,
247 })
248 }
249
250 pub fn with_block_size(mut self, size: usize) -> Self {
253 self.block_size = size.next_power_of_two();
254 self
255 }
256
257 pub fn with_query_config(mut self, config: SparseQueryConfig) -> Self {
259 self.query_config = Some(config);
260 self
261 }
262}
263
/// One stored component of a sparse vector: a dimension id and its weight.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SparseEntry {
    /// Dimension index.
    pub dim_id: u32,
    /// Value stored for `dim_id`.
    pub weight: f32,
}
270
271#[derive(Debug, Clone, Default)]
273pub struct SparseVector {
274 pub(super) entries: Vec<SparseEntry>,
275}
276
277impl SparseVector {
278 pub fn new() -> Self {
280 Self {
281 entries: Vec::new(),
282 }
283 }
284
285 pub fn with_capacity(capacity: usize) -> Self {
287 Self {
288 entries: Vec::with_capacity(capacity),
289 }
290 }
291
292 pub fn from_entries(dim_ids: &[u32], weights: &[f32]) -> Self {
294 assert_eq!(dim_ids.len(), weights.len());
295 let mut entries: Vec<SparseEntry> = dim_ids
296 .iter()
297 .zip(weights.iter())
298 .map(|(&dim_id, &weight)| SparseEntry { dim_id, weight })
299 .collect();
300 entries.sort_by_key(|e| e.dim_id);
302 Self { entries }
303 }
304
305 pub fn push(&mut self, dim_id: u32, weight: f32) {
307 debug_assert!(
308 self.entries.is_empty() || self.entries.last().unwrap().dim_id < dim_id,
309 "Entries must be added in sorted order by dim_id"
310 );
311 self.entries.push(SparseEntry { dim_id, weight });
312 }
313
314 pub fn len(&self) -> usize {
316 self.entries.len()
317 }
318
319 pub fn is_empty(&self) -> bool {
321 self.entries.is_empty()
322 }
323
324 pub fn iter(&self) -> impl Iterator<Item = &SparseEntry> {
326 self.entries.iter()
327 }
328
329 pub fn sort_by_dim(&mut self) {
331 self.entries.sort_by_key(|e| e.dim_id);
332 }
333
334 pub fn sort_by_weight_desc(&mut self) {
336 self.entries.sort_by(|a, b| {
337 b.weight
338 .partial_cmp(&a.weight)
339 .unwrap_or(std::cmp::Ordering::Equal)
340 });
341 }
342
343 pub fn top_k(&self, k: usize) -> Vec<SparseEntry> {
345 let mut sorted = self.entries.clone();
346 sorted.sort_by(|a, b| {
347 b.weight
348 .partial_cmp(&a.weight)
349 .unwrap_or(std::cmp::Ordering::Equal)
350 });
351 sorted.truncate(k);
352 sorted
353 }
354
355 pub fn dot(&self, other: &SparseVector) -> f32 {
357 let mut result = 0.0f32;
358 let mut i = 0;
359 let mut j = 0;
360
361 while i < self.entries.len() && j < other.entries.len() {
362 let a = &self.entries[i];
363 let b = &other.entries[j];
364
365 match a.dim_id.cmp(&b.dim_id) {
366 std::cmp::Ordering::Less => i += 1,
367 std::cmp::Ordering::Greater => j += 1,
368 std::cmp::Ordering::Equal => {
369 result += a.weight * b.weight;
370 i += 1;
371 j += 1;
372 }
373 }
374 }
375
376 result
377 }
378
379 pub fn norm_squared(&self) -> f32 {
381 self.entries.iter().map(|e| e.weight * e.weight).sum()
382 }
383
384 pub fn norm(&self) -> f32 {
386 self.norm_squared().sqrt()
387 }
388
389 pub fn filter_by_weight(&self, min_weight: f32) -> Self {
391 let entries: Vec<SparseEntry> = self
392 .entries
393 .iter()
394 .filter(|e| e.weight.abs() >= min_weight)
395 .cloned()
396 .collect();
397 Self { entries }
398 }
399}
400
401impl From<Vec<(u32, f32)>> for SparseVector {
402 fn from(pairs: Vec<(u32, f32)>) -> Self {
403 Self {
404 entries: pairs
405 .into_iter()
406 .map(|(dim_id, weight)| SparseEntry { dim_id, weight })
407 .collect(),
408 }
409 }
410}
411
412impl From<SparseVector> for Vec<(u32, f32)> {
413 fn from(vec: SparseVector) -> Self {
414 vec.entries
415 .into_iter()
416 .map(|e| (e.dim_id, e.weight))
417 .collect()
418 }
419}