
hermes_core/tokenizer/idf_weights.rs

//! Pre-computed IDF weights from model's `idf.json`
//!
//! Neural sparse models (e.g., opensearch-neural-sparse-encoding-multilingual-v1)
//! ship `idf.json` with IDF values calibrated during training. Using these weights
//! instead of index-derived IDF produces correct rankings for doc-only models.
//!
//! The idf.json format maps **token strings** to IDF values:
//! `{"[PAD]": 0.607, "hemoglobin": 8.12, "##ing": 1.05, ...}`
//!
//! At load time, we resolve token strings to numeric IDs via the model's tokenizer,
//! then store them as a flat `Vec<f32>` for O(1) lookup by token_id.
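//!
//! A minimal usage sketch (illustrative: items are referred to unqualified and the
//! block is marked `ignore`, so it is not compiled as a doctest):
//!
//! ```ignore
//! // Checks the in-memory cache, then a local idf_<model>.json, then the HF hub.
//! if let Some(weights) = idf_weights_cache()
//!     .get_or_load("opensearch-neural-sparse-encoding-multilingual-v1", None)
//! {
//!     // Unmapped or out-of-range token ids fall back to the neutral weight 1.0.
//!     let idf = weights.get(42);
//!     assert!(idf > 0.0);
//! }
//! ```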

#[cfg(feature = "native")]
use std::collections::HashMap;
#[cfg(feature = "native")]
use std::path::Path;
#[cfg(feature = "native")]
use std::sync::Arc;

#[cfg(feature = "native")]
use log::{debug, warn};
#[cfg(feature = "native")]
use parking_lot::RwLock;

#[cfg(feature = "native")]
use crate::Result;
#[cfg(feature = "native")]
use crate::error::Error;

/// Pre-computed IDF weights indexed by token_id
///
/// Stored as a flat `Vec<f32>` for O(1) lookup by token_id.
/// For mBERT's 105K vocab this uses ~420KB of memory.
#[cfg(feature = "native")]
pub struct IdfWeights {
    weights: Vec<f32>,
}

#[cfg(feature = "native")]
impl IdfWeights {
    /// Get the IDF weight for a token_id
    ///
    /// Returns 1.0 for out-of-range token_ids (neutral weight).
    #[inline]
    pub fn get(&self, token_id: u32) -> f32 {
        self.weights.get(token_id as usize).copied().unwrap_or(1.0)
    }

    /// Load IDF weights from a JSON object, resolving token strings to IDs
    /// via the provided tokenizer.
    ///
    /// The idf.json maps token strings → IDF values. We use `token_to_id`
    /// to convert each key to a numeric token ID for O(1) lookup.
    fn from_json_with_tokenizer(
        json_bytes: &[u8],
        tokenizer: &tokenizers::Tokenizer,
    ) -> Result<Self> {
        let map: HashMap<String, f64> = serde_json::from_slice(json_bytes)
            .map_err(|e| Error::Tokenizer(format!("Failed to parse idf.json: {}", e)))?;

        if map.is_empty() {
            return Err(Error::Tokenizer("idf.json is empty".to_string()));
        }

        // Resolve token strings to IDs and find max ID
        let mut resolved: Vec<(u32, f32)> = Vec::with_capacity(map.len());
        let mut missed = 0u32;
        for (token_str, value) in &map {
            if let Some(id) = tokenizer.token_to_id(token_str) {
                resolved.push((id, *value as f32));
            } else {
                missed += 1;
            }
        }

        if resolved.is_empty() {
            return Err(Error::Tokenizer(
                "idf.json: no tokens could be resolved to IDs via tokenizer".to_string(),
            ));
        }

        let max_id = resolved.iter().map(|(id, _)| *id).max().unwrap();

        // Initialize with 1.0 (neutral weight) for unmapped tokens
        let mut weights = vec![1.0f32; (max_id + 1) as usize];
        for &(id, value) in &resolved {
            weights[id as usize] = value;
        }

        debug!(
            "Loaded {} IDF weights via tokenizer (vec size: {}, unresolved: {})",
            resolved.len(),
            weights.len(),
            missed,
        );

        Ok(Self { weights })
    }
}

/// Global cache for IDF weights, keyed by model name.
/// Caches both successful loads and failures to avoid repeated download attempts.
#[cfg(feature = "native")]
pub struct IdfWeightsCache {
    cache: RwLock<HashMap<String, Option<Arc<IdfWeights>>>>,
}

#[cfg(feature = "native")]
impl Default for IdfWeightsCache {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "native")]
impl IdfWeightsCache {
    /// Create a new IDF weights cache
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Get or load IDF weights for a model
    ///
    /// Lookup order:
    /// 1. In-memory cache
    /// 2. Local file in `cache_dir` (e.g. index directory): `idf_<sanitized_model>.json`
    /// 3. HuggingFace hub download (saved to `cache_dir` on success)
    ///
    /// Returns `None` if `idf.json` is not available (graceful fallback).
    /// Both successes and failures are cached to avoid repeated attempts.
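    ///
    /// A sketch of typical use (the model name and cache directory here are
    /// illustrative, so the block is marked `ignore`):
    ///
    /// ```ignore
    /// let cache = idf_weights_cache();
    /// match cache.get_or_load("my-org/my-sparse-model", Some(Path::new("/data/index"))) {
    ///     Some(w) => println!("idf(100) = {}", w.get(100)),
    ///     None => println!("no idf.json available; using index-derived IDF"),
    /// }
    /// ```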
    pub fn get_or_load(
        &self,
        model_name: &str,
        cache_dir: Option<&Path>,
    ) -> Option<Arc<IdfWeights>> {
        // Check in-memory cache first (covers both success and cached failure)
        {
            let cache = self.cache.read();
            if let Some(entry) = cache.get(model_name) {
                return entry.as_ref().map(Arc::clone);
            }
        }

        // Try local cache file, then HF hub
        match self.load_with_local_cache(model_name, cache_dir) {
            Ok(weights) => {
                let weights = Arc::new(weights);
                let mut cache = self.cache.write();
                cache.insert(model_name.to_string(), Some(Arc::clone(&weights)));
                Some(weights)
            }
            Err(e) => {
                warn!(
                    "Could not load idf.json for model '{}': {}. Falling back to index-derived IDF.",
                    model_name, e
                );
                let mut cache = self.cache.write();
                cache.insert(model_name.to_string(), None);
                None
            }
        }
    }

    /// Sanitize model name for use as a filename component
    fn sanitized_model_name(model_name: &str) -> String {
        model_name.replace('/', "--")
    }

    /// Local cache filename for a model's idf.json
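    ///
    /// Illustrative mapping (paths are hypothetical):
    ///
    /// ```ignore
    /// // "org/model" cached under "/idx" becomes "/idx/idf_org--model.json"
    /// let p = IdfWeightsCache::local_cache_path(Path::new("/idx"), "org/model");
    /// assert_eq!(p, Path::new("/idx/idf_org--model.json"));
    /// ```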
    fn local_cache_path(cache_dir: &Path, model_name: &str) -> std::path::PathBuf {
        cache_dir.join(format!(
            "idf_{}.json",
            Self::sanitized_model_name(model_name)
        ))
    }

    /// Try loading from local cache file first, then fall back to HF hub download.
    /// On successful HF download, saves a copy to the local cache directory.
    fn load_with_local_cache(
        &self,
        model_name: &str,
        cache_dir: Option<&Path>,
    ) -> Result<IdfWeights> {
        let tokenizer = super::tokenizer_cache().get_or_load(model_name)?;

        // Try local cache first
        if let Some(dir) = cache_dir {
            let local_path = Self::local_cache_path(dir, model_name);
            if local_path.exists() {
                let json_bytes = std::fs::read(&local_path).map_err(|e| {
                    Error::Tokenizer(format!(
                        "Failed to read cached idf.json at {:?}: {}",
                        local_path, e
                    ))
                })?;
                debug!(
                    "Loaded idf.json from local cache: {:?} for model '{}'",
                    local_path, model_name
                );
                return IdfWeights::from_json_with_tokenizer(&json_bytes, &tokenizer.tokenizer);
            }
        }

        // Download from HF hub
        let json_bytes = self.download_idf_json(model_name)?;

        // Save to local cache for next time
        if let Some(dir) = cache_dir {
            let local_path = Self::local_cache_path(dir, model_name);
            if let Err(e) = std::fs::write(&local_path, &json_bytes) {
                warn!(
                    "Failed to cache idf.json to {:?}: {} (non-fatal)",
                    local_path, e
                );
            } else {
                debug!(
                    "Cached idf.json to {:?} for model '{}'",
                    local_path, model_name
                );
            }
        }

        IdfWeights::from_json_with_tokenizer(&json_bytes, &tokenizer.tokenizer)
    }

    /// Download raw idf.json bytes from HuggingFace hub
    fn download_idf_json(&self, model_name: &str) -> Result<Vec<u8>> {
        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| Error::Tokenizer(format!("Failed to create HF hub API: {}", e)))?;
        let repo = api.model(model_name.to_string());
        let idf_path = repo.get("idf.json").map_err(|e| {
            Error::Tokenizer(format!(
                "Failed to download idf.json from '{}': {}",
                model_name, e
            ))
        })?;

        debug!(
            "Downloaded idf.json from '{}' to {:?}",
            model_name, idf_path
        );

        std::fs::read(&idf_path).map_err(|e| {
            Error::Tokenizer(format!("Failed to read idf.json at {:?}: {}", idf_path, e))
        })
    }

    /// Clear the cache
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        cache.clear();
    }
}

/// Global IDF weights cache instance
#[cfg(feature = "native")]
static IDF_WEIGHTS_CACHE: std::sync::OnceLock<IdfWeightsCache> = std::sync::OnceLock::new();

/// Get the global IDF weights cache
#[cfg(feature = "native")]
pub fn idf_weights_cache() -> &'static IdfWeightsCache {
    IDF_WEIGHTS_CACHE.get_or_init(IdfWeightsCache::new)
}

#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    /// Build a minimal tokenizer with a known vocab for testing
    fn test_tokenizer() -> tokenizers::Tokenizer {
        use tokenizers::models::wordpiece::WordPiece;
        // Collect the vocab into the map type the WordPiece builder expects.
        let vocab = [
            ("[UNK]", 0u32),
            ("hello", 1),
            ("world", 2),
            ("foo", 5),
            ("bar", 100),
        ]
        .into_iter()
        .map(|(token, id)| (token.to_string(), id))
        .collect();
        let wp = WordPiece::builder()
            .vocab(vocab)
            .unk_token("[UNK]".into())
            .build()
            .unwrap();
        tokenizers::Tokenizer::new(wp)
    }

    #[test]
    fn test_idf_weights_from_json_with_tokenizer() {
        let json = br#"{"hello": 1.5, "world": 2.0, "foo": 0.5, "bar": 3.0}"#;
        let tokenizer = test_tokenizer();
        let weights = IdfWeights::from_json_with_tokenizer(json, &tokenizer).unwrap();

        // hello=1, world=2, foo=5, bar=100
        assert!((weights.get(1) - 1.5).abs() < f32::EPSILON);
        assert!((weights.get(2) - 2.0).abs() < f32::EPSILON);
        assert!((weights.get(5) - 0.5).abs() < f32::EPSILON);
        assert!((weights.get(100) - 3.0).abs() < f32::EPSILON);

        // Unmapped tokens get 1.0
        assert!((weights.get(3) - 1.0).abs() < f32::EPSILON);
        assert!((weights.get(50) - 1.0).abs() < f32::EPSILON);

        // Out-of-range tokens get 1.0
        assert!((weights.get(999) - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_idf_weights_unresolvable_tokens_skipped() {
        // "unknown_xyz" is not in the vocab, should be skipped
        let json = br#"{"hello": 1.5, "unknown_xyz": 9.9}"#;
        let tokenizer = test_tokenizer();
        let weights = IdfWeights::from_json_with_tokenizer(json, &tokenizer).unwrap();

        assert!((weights.get(1) - 1.5).abs() < f32::EPSILON); // hello resolved
    }

    #[test]
    fn test_idf_weights_empty_json() {
        let json = br#"{}"#;
        let tokenizer = test_tokenizer();
        assert!(IdfWeights::from_json_with_tokenizer(json, &tokenizer).is_err());
    }

    #[test]
    fn test_idf_weights_invalid_json() {
        let json = br#"not json"#;
        let tokenizer = test_tokenizer();
        assert!(IdfWeights::from_json_with_tokenizer(json, &tokenizer).is_err());
    }

    #[test]
    fn test_idf_weights_cache_structure() {
        let cache = IdfWeightsCache::new();
        assert!(cache.cache.read().is_empty());
    }

    #[test]
    fn test_idf_weights_cache_miss_graceful() {
        let cache = IdfWeightsCache::new();
        // Non-existent model should return None gracefully
        let result = cache.get_or_load("nonexistent-model-xyz-12345", None);
        assert!(result.is_none());
    }
}