Skip to main content

hermes_core/tokenizer/
idf_weights.rs

1//! Pre-computed IDF weights from model's `idf.json`
2//!
3//! Neural sparse models (e.g., opensearch-neural-sparse-encoding-multilingual-v1)
4//! ship `idf.json` with IDF values calibrated during training. Using these weights
5//! instead of index-derived IDF produces correct rankings for doc-only models.
6//!
7//! The idf.json format maps **token strings** to IDF values:
8//! `{"[PAD]": 0.607, "hemoglobin": 8.12, "##ing": 1.05, ...}`
9//!
10//! At load time, we resolve token strings to numeric IDs via the model's tokenizer,
11//! then store as a flat `Vec<f32>` for O(1) lookup by token_id.
12
13#[cfg(feature = "native")]
14use std::collections::HashMap;
15#[cfg(feature = "native")]
16use std::sync::Arc;
17
18#[cfg(feature = "native")]
19use log::{debug, warn};
20#[cfg(feature = "native")]
21use parking_lot::RwLock;
22
23#[cfg(feature = "native")]
24use crate::Result;
25#[cfg(feature = "native")]
26use crate::error::Error;
27
/// IDF weight table addressed directly by token_id.
///
/// Backed by a dense `Vec<f32>`, so a lookup is a single bounds-checked index
/// rather than a hash probe. Each slot costs 4 bytes; a ~105K-entry vocabulary
/// such as mBERT's therefore needs roughly 420KB.
#[cfg(feature = "native")]
pub struct IdfWeights {
    weights: Vec<f32>,
}
36
#[cfg(feature = "native")]
impl IdfWeights {
    /// Get the IDF weight for a token_id
    ///
    /// Returns 1.0 (neutral weight) for token_ids beyond the end of the table,
    /// i.e. larger than the highest ID that was resolved from `idf.json`.
    #[inline]
    pub fn get(&self, token_id: u32) -> f32 {
        self.weights.get(token_id as usize).copied().unwrap_or(1.0)
    }

    /// Load IDF weights from a JSON object, resolving token strings to IDs
    /// via the provided tokenizer.
    ///
    /// The idf.json maps token strings → IDF values. Each key is converted with
    /// `token_to_id`. Keys the tokenizer does not know are skipped, and their
    /// count is reported in the debug log. Token IDs that fall inside the table
    /// but are absent from idf.json keep the neutral weight 1.0.
    ///
    /// # Errors
    ///
    /// Returns `Error::Tokenizer` in three cases:
    /// - the bytes are not a JSON object mapping strings to numbers;
    /// - the object is empty;
    /// - none of its keys resolve to a token ID.
    fn from_json_with_tokenizer(
        json_bytes: &[u8],
        tokenizer: &tokenizers::Tokenizer,
    ) -> Result<Self> {
        let map: HashMap<String, f64> = serde_json::from_slice(json_bytes)
            .map_err(|e| Error::Tokenizer(format!("Failed to parse idf.json: {}", e)))?;

        if map.is_empty() {
            return Err(Error::Tokenizer("idf.json is empty".to_string()));
        }

        // Resolve token strings to IDs; strings unknown to the tokenizer are dropped.
        let resolved: Vec<(u32, f32)> = map
            .iter()
            .filter_map(|(token_str, &value)| {
                tokenizer
                    .token_to_id(token_str)
                    .map(|id| (id, value as f32))
            })
            .collect();
        let missed = map.len() - resolved.len();

        let Some(max_id) = resolved.iter().map(|&(id, _)| id).max() else {
            return Err(Error::Tokenizer(
                "idf.json: no tokens could be resolved to IDs via tokenizer".to_string(),
            ));
        };

        // Widen to usize *before* adding 1 so that a max_id of u32::MAX cannot
        // overflow u32. Slots not covered by idf.json keep the neutral weight 1.0.
        let mut weights = vec![1.0f32; max_id as usize + 1];
        for &(id, value) in &resolved {
            weights[id as usize] = value;
        }

        debug!(
            "Loaded {} IDF weights via tokenizer (vec size: {}, unresolved: {})",
            resolved.len(),
            weights.len(),
            missed,
        );

        Ok(Self { weights })
    }
}
98
/// Thread-safe cache of [`IdfWeights`], keyed by HuggingFace model name.
///
/// Entries are shared as `Arc<IdfWeights>`, so handing one out only bumps a
/// reference count.
#[cfg(feature = "native")]
pub struct IdfWeightsCache {
    cache: RwLock<HashMap<String, Arc<IdfWeights>>>,
}

#[cfg(feature = "native")]
impl Default for IdfWeightsCache {
    /// Equivalent to [`IdfWeightsCache::new`]: an empty cache.
    fn default() -> Self {
        IdfWeightsCache::new()
    }
}
111
#[cfg(feature = "native")]
impl IdfWeightsCache {
    /// Create a new, empty IDF weights cache
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Get or load IDF weights for a model
    ///
    /// Downloads `idf.json` from the HuggingFace model repo if not cached.
    /// Returns `None` if `idf.json` is not available (graceful fallback). In
    /// that case a warning is logged and nothing is cached, so a later call
    /// retries the download.
    ///
    /// If several threads miss the cache for the same model at once, each may
    /// download. Only the first result is kept, so all callers receive the
    /// same `Arc`.
    pub fn get_or_load(&self, model_name: &str) -> Option<Arc<IdfWeights>> {
        // Fast path: shared read lock only.
        {
            let cache = self.cache.read();
            if let Some(weights) = cache.get(model_name) {
                return Some(Arc::clone(weights));
            }
        }

        // Download without holding any lock: it can be slow and must not block
        // lookups for other models.
        match self.download_and_parse(model_name) {
            Ok(weights) => {
                let mut cache = self.cache.write();
                // Another thread may have inserted this model while we were
                // downloading; keep its entry instead of overwriting it.
                let entry = cache
                    .entry(model_name.to_string())
                    .or_insert_with(|| Arc::new(weights));
                Some(Arc::clone(entry))
            }
            Err(e) => {
                warn!(
                    "Could not load idf.json for model '{}': {}. Falling back to index-derived IDF.",
                    model_name, e
                );
                None
            }
        }
    }

    /// Download idf.json from HuggingFace hub and parse it
    ///
    /// Also loads the model's tokenizer (through the shared tokenizer cache)
    /// to resolve token strings → IDs.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the HF hub API cannot be created;
    /// - the file cannot be downloaded or read;
    /// - the tokenizer cannot be loaded;
    /// - the JSON cannot be turned into weights.
    fn download_and_parse(&self, model_name: &str) -> Result<IdfWeights> {
        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| Error::Tokenizer(format!("Failed to create HF hub API: {}", e)))?;
        let repo = api.model(model_name.to_string());
        let idf_path = repo.get("idf.json").map_err(|e| {
            Error::Tokenizer(format!(
                "Failed to download idf.json from '{}': {}",
                model_name, e
            ))
        })?;

        debug!(
            "Downloaded idf.json from '{}' to {:?}",
            model_name, idf_path
        );

        let json_bytes = std::fs::read(&idf_path).map_err(|e| {
            Error::Tokenizer(format!("Failed to read idf.json at {:?}: {}", idf_path, e))
        })?;

        // Load tokenizer to resolve token strings → numeric IDs
        let tokenizer = super::tokenizer_cache().get_or_load(model_name)?;

        IdfWeights::from_json_with_tokenizer(&json_bytes, &tokenizer.tokenizer)
    }

    /// Clear the cache
    ///
    /// Existing `Arc<IdfWeights>` handles held by callers remain valid.
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        cache.clear();
    }
}
187
/// Process-wide IDF weights cache, created lazily on first access.
#[cfg(feature = "native")]
static IDF_WEIGHTS_CACHE: std::sync::OnceLock<IdfWeightsCache> = std::sync::OnceLock::new();

/// Returns the shared [`IdfWeightsCache`], initializing it on the first call.
#[cfg(feature = "native")]
pub fn idf_weights_cache() -> &'static IdfWeightsCache {
    IDF_WEIGHTS_CACHE.get_or_init(IdfWeightsCache::default)
}
197
#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    /// Build a minimal tokenizer with a known vocab for testing
    fn test_tokenizer() -> tokenizers::Tokenizer {
        use tokenizers::models::wordpiece::WordPiece;
        let wp = WordPiece::builder()
            .vocab([
                ("[UNK]".to_string(), 0),
                ("hello".to_string(), 1),
                ("world".to_string(), 2),
                ("foo".to_string(), 5),
                ("bar".to_string(), 100),
            ])
            .unk_token("[UNK]".into())
            .build()
            .unwrap();
        tokenizers::Tokenizer::new(wp)
    }

    /// Assert that `token_id` maps to `expected` (within f32 epsilon).
    fn assert_weight(weights: &IdfWeights, token_id: u32, expected: f32) {
        let actual = weights.get(token_id);
        assert!(
            (actual - expected).abs() < f32::EPSILON,
            "token_id {}: expected {}, got {}",
            token_id,
            expected,
            actual
        );
    }

    #[test]
    fn test_idf_weights_from_json_with_tokenizer() {
        let json = br#"{"hello": 1.5, "world": 2.0, "foo": 0.5, "bar": 3.0}"#;
        let tokenizer = test_tokenizer();
        let weights = IdfWeights::from_json_with_tokenizer(json, &tokenizer).unwrap();

        // hello=1, world=2, foo=5, bar=100
        assert_weight(&weights, 1, 1.5);
        assert_weight(&weights, 2, 2.0);
        assert_weight(&weights, 5, 0.5);
        assert_weight(&weights, 100, 3.0);

        // Unmapped tokens inside the table get 1.0
        assert_weight(&weights, 3, 1.0);
        assert_weight(&weights, 50, 1.0);

        // Out-of-range tokens get 1.0
        assert_weight(&weights, 999, 1.0);
    }

    #[test]
    fn test_idf_weights_unresolvable_tokens_skipped() {
        // "unknown_xyz" is not in the vocab, should be skipped
        let json = br#"{"hello": 1.5, "unknown_xyz": 9.9}"#;
        let tokenizer = test_tokenizer();
        let weights = IdfWeights::from_json_with_tokenizer(json, &tokenizer).unwrap();

        assert_weight(&weights, 1, 1.5); // hello resolved
        // The skipped token must not widen the table: only ids 0..=1 exist.
        assert_eq!(weights.weights.len(), 2);
    }

    #[test]
    fn test_idf_weights_all_unresolvable_is_error() {
        let json = br#"{"nope": 1.0, "also_nope": 2.0}"#;
        let tokenizer = test_tokenizer();
        assert!(IdfWeights::from_json_with_tokenizer(json, &tokenizer).is_err());
    }

    #[test]
    fn test_idf_weights_empty_json() {
        let json = br#"{}"#;
        let tokenizer = test_tokenizer();
        assert!(IdfWeights::from_json_with_tokenizer(json, &tokenizer).is_err());
    }

    #[test]
    fn test_idf_weights_invalid_json() {
        let json = br#"not json"#;
        let tokenizer = test_tokenizer();
        assert!(IdfWeights::from_json_with_tokenizer(json, &tokenizer).is_err());
    }

    #[test]
    fn test_idf_weights_cache_structure() {
        let cache = IdfWeightsCache::new();
        assert!(cache.cache.read().is_empty());
    }

    #[test]
    fn test_idf_weights_cache_hit_and_clear() {
        // Seed the cache directly so the hit path is exercised without network.
        let cache = IdfWeightsCache::new();
        let seeded = Arc::new(IdfWeights {
            weights: vec![1.0, 2.5],
        });
        cache
            .cache
            .write()
            .insert("seeded-model".to_string(), Arc::clone(&seeded));

        let hit = cache.get_or_load("seeded-model").expect("cached entry");
        assert!(Arc::ptr_eq(&hit, &seeded));
        assert_weight(&hit, 1, 2.5);

        cache.clear();
        assert!(cache.cache.read().is_empty());
    }

    #[test]
    fn test_idf_weights_cache_miss_graceful() {
        let cache = IdfWeightsCache::new();
        // Non-existent model should return None gracefully
        let result = cache.get_or_load("nonexistent-model-xyz-12345");
        assert!(result.is_none());
    }
}