Skip to main content

hermes_core/tokenizer/
idf_weights.rs

1//! Pre-computed IDF weights from model's `idf.json`
2//!
3//! Neural sparse models (e.g., opensearch-neural-sparse-encoding-multilingual-v1)
4//! ship `idf.json` with IDF values calibrated during training. Using these weights
5//! instead of index-derived IDF produces correct rankings for doc-only models.
6
7#[cfg(feature = "native")]
8use std::collections::HashMap;
9#[cfg(feature = "native")]
10use std::sync::Arc;
11
12#[cfg(feature = "native")]
13use log::{debug, warn};
14#[cfg(feature = "native")]
15use parking_lot::RwLock;
16
17#[cfg(feature = "native")]
18use crate::Result;
19#[cfg(feature = "native")]
20use crate::error::Error;
21
/// Dense lookup table of pre-computed IDF weights, addressed by token id.
///
/// A plain `Vec<f32>` backs the table, so looking up a weight is a single
/// indexed read (O(1)). Its length is one past the largest token id loaded,
/// which for mBERT's ~105K-token vocabulary comes to roughly 420KB.
#[cfg(feature = "native")]
pub struct IdfWeights {
    /// Slot `i` holds the IDF value for token id `i`, or the neutral `1.0`
    /// when the source file had no entry for that id.
    weights: Vec<f32>,
}
30
#[cfg(feature = "native")]
impl IdfWeights {
    /// Largest token id accepted from `idf.json` (2^24 - 1).
    ///
    /// The table is dense and sized by the largest id in the file, so a single
    /// corrupt key such as `"4000000000"` would otherwise demand a
    /// multi-gigabyte allocation (and at `u32::MAX`, `max_id + 1` overflows).
    /// This bound is far above mBERT's ~105K-token vocabulary and caps the
    /// worst-case table at 64 MiB; raise it if a supported model needs more.
    const MAX_TOKEN_ID: u32 = (1 << 24) - 1;

    /// Get the IDF weight for a token_id.
    ///
    /// Returns the neutral weight 1.0 for ids past the end of the table and for
    /// ids that had no entry in `idf.json`.
    #[inline]
    pub fn get(&self, token_id: u32) -> f32 {
        self.weights.get(token_id as usize).copied().unwrap_or(1.0)
    }

    /// Load IDF weights from a JSON object `{"token_id_str": float, ...}`.
    ///
    /// Keys must be decimal `u32` token ids. Keys that do not parse are skipped,
    /// and a warning reports how many were skipped. Ids absent from the file get
    /// the neutral weight 1.0.
    ///
    /// NOTE(review): OpenSearch's reference loader treats `idf.json` keys as
    /// token *strings* and maps each one through the tokenizer vocabulary. If a
    /// model ships that layout, only numeric-looking tokens (e.g. "100") parse
    /// here, and their weights land on the wrong ids. Confirm the key format for
    /// every supported model.
    ///
    /// # Errors
    ///
    /// Returns `Error::Tokenizer` in any of these cases:
    /// - the bytes are not a JSON object of numbers;
    /// - the object is empty;
    /// - no key is a valid token id;
    /// - the largest id exceeds `MAX_TOKEN_ID`.
    fn from_json(json_bytes: &[u8]) -> Result<Self> {
        let map: HashMap<String, f64> = serde_json::from_slice(json_bytes)
            .map_err(|e| Error::Tokenizer(format!("Failed to parse idf.json: {}", e)))?;

        if map.is_empty() {
            return Err(Error::Tokenizer("idf.json is empty".to_string()));
        }

        // Parse every key exactly once; keys that are not u32 ids drop out here.
        let entries: Vec<(u32, f32)> = map
            .iter()
            .filter_map(|(key, &value)| key.parse::<u32>().ok().map(|id| (id, value as f32)))
            .collect();

        let max_id = entries
            .iter()
            .map(|&(id, _)| id)
            .max()
            .ok_or_else(|| Error::Tokenizer("idf.json contains no valid token IDs".to_string()))?;

        if max_id > Self::MAX_TOKEN_ID {
            return Err(Error::Tokenizer(format!(
                "idf.json token id {} exceeds the supported maximum {}",
                max_id,
                Self::MAX_TOKEN_ID
            )));
        }

        // Skipped keys used to vanish silently. A large count usually means the
        // file is keyed by token strings rather than ids.
        let skipped = map.len() - entries.len();
        if skipped > 0 {
            warn!(
                "Ignored {} of {} idf.json keys that are not numeric token ids",
                skipped,
                map.len()
            );
        }

        // Ids absent from the file keep the neutral weight 1.0. The bound check
        // above guarantees `max_id + 1` cannot overflow and that every `id`
        // indexes inside the table.
        let mut weights = vec![1.0f32; max_id as usize + 1];
        for &(id, value) in &entries {
            weights[id as usize] = value;
        }

        debug!(
            "Loaded {} IDF weights (vec size: {})",
            entries.len(),
            weights.len()
        );

        Ok(Self { weights })
    }
}
72
/// Cache of IDF weight tables, keyed by HuggingFace model name.
///
/// Each model is resolved at most once:
/// - a successful load is stored as `Some(weights)`;
/// - a failed load is stored as `None`.
///
/// Later lookups for a model without a usable `idf.json` therefore neither
/// repeat the hub request nor repeat the fallback warning.
/// [`IdfWeightsCache::clear`] forgets both kinds of entry, which is how a caller
/// forces a retry after a transient failure.
#[cfg(feature = "native")]
pub struct IdfWeightsCache {
    /// `Some(weights)` for a model that loaded, `None` for a remembered failure.
    cache: RwLock<HashMap<String, Option<Arc<IdfWeights>>>>,
}

#[cfg(feature = "native")]
impl Default for IdfWeightsCache {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "native")]
impl IdfWeightsCache {
    /// Create an empty IDF weights cache.
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Get or load IDF weights for a model.
    ///
    /// The first call for `model_name` downloads `idf.json` from the HuggingFace
    /// model repo and parses it. The outcome, success or failure, is cached.
    ///
    /// Returns `None` if `idf.json` is not available or cannot be parsed. This is
    /// a graceful fallback: callers should use index-derived IDF instead.
    pub fn get_or_load(&self, model_name: &str) -> Option<Arc<IdfWeights>> {
        // Fast path: this model was already resolved, successfully or not.
        // The read guard is dropped at the end of this scope, before the write
        // lock below is taken.
        {
            let cache = self.cache.read();
            if let Some(entry) = cache.get(model_name) {
                return entry.as_ref().map(Arc::clone);
            }
        }

        // Download outside any lock, since it may block on the network.
        let loaded = match self.download_and_parse(model_name) {
            Ok(weights) => Some(Arc::new(weights)),
            Err(e) => {
                warn!(
                    "Could not load idf.json for model '{}': {}. Falling back to index-derived IDF.",
                    model_name, e
                );
                None
            }
        };

        // Another thread may have resolved the same model while we were
        // downloading. Keep an existing successful entry so every caller shares
        // one `Arc`; otherwise record this attempt's outcome.
        let mut cache = self.cache.write();
        let slot = cache.entry(model_name.to_string()).or_default();
        if slot.is_none() {
            *slot = loaded;
        }
        slot.as_ref().map(Arc::clone)
    }

    /// Download idf.json from the HuggingFace hub and parse it.
    fn download_and_parse(&self, model_name: &str) -> Result<IdfWeights> {
        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| Error::Tokenizer(format!("Failed to create HF hub API: {}", e)))?;
        let repo = api.model(model_name.to_string());
        let idf_path = repo.get("idf.json").map_err(|e| {
            Error::Tokenizer(format!(
                "Failed to download idf.json from '{}': {}",
                model_name, e
            ))
        })?;

        debug!(
            "Downloaded idf.json from '{}' to {:?}",
            model_name, idf_path
        );

        let json_bytes = std::fs::read(&idf_path).map_err(|e| {
            Error::Tokenizer(format!("Failed to read idf.json at {:?}: {}", idf_path, e))
        })?;

        IdfWeights::from_json(&json_bytes)
    }

    /// Clear the cache, including remembered failures.
    ///
    /// After this, the next `get_or_load` for any model attempts the download
    /// again.
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        cache.clear();
    }
}
156
/// Lazily created, process-wide cache instance returned by `idf_weights_cache`.
#[cfg(feature = "native")]
static IDF_WEIGHTS_CACHE: std::sync::OnceLock<IdfWeightsCache> = std::sync::OnceLock::new();

/// Return the shared, process-wide [`IdfWeightsCache`].
///
/// The first call creates an empty cache; every later call returns that same
/// instance.
#[cfg(feature = "native")]
pub fn idf_weights_cache() -> &'static IdfWeightsCache {
    IDF_WEIGHTS_CACHE.get_or_init(IdfWeightsCache::default)
}
166
#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    /// Assert that `token_id` maps to `expected`.
    ///
    /// All expected values below are exactly representable as f32, so an
    /// epsilon comparison suffices.
    fn assert_weight(weights: &IdfWeights, token_id: u32, expected: f32) {
        let actual = weights.get(token_id);
        assert!(
            (actual - expected).abs() < f32::EPSILON,
            "token {}: expected {}, got {}",
            token_id,
            expected,
            actual
        );
    }

    #[test]
    fn test_idf_weights_from_json() {
        let json = br#"{"0": 1.5, "1": 2.0, "5": 0.5, "100": 3.0}"#;
        let weights = IdfWeights::from_json(json).unwrap();

        assert_weight(&weights, 0, 1.5);
        assert_weight(&weights, 1, 2.0);
        assert_weight(&weights, 5, 0.5);
        assert_weight(&weights, 100, 3.0);

        // Unmapped tokens get 1.0
        assert_weight(&weights, 2, 1.0);
        assert_weight(&weights, 50, 1.0);

        // Out-of-range tokens get 1.0
        assert_weight(&weights, 999, 1.0);
    }

    #[test]
    fn test_idf_weights_empty_json() {
        let json = br#"{}"#;
        assert!(IdfWeights::from_json(json).is_err());
    }

    #[test]
    fn test_idf_weights_invalid_json() {
        let json = br#"not json"#;
        assert!(IdfWeights::from_json(json).is_err());
    }

    #[test]
    fn test_idf_weights_non_numeric_keys_skipped() {
        let json = br#"{"hello": 9.0, "3": 2.5}"#;
        let weights = IdfWeights::from_json(json).unwrap();

        assert_weight(&weights, 3, 2.5);
        // The non-numeric key must not leak into any slot.
        assert_weight(&weights, 0, 1.0);
        assert_weight(&weights, 4, 1.0);
    }

    #[test]
    fn test_idf_weights_no_numeric_keys() {
        let json = br#"{"hello": 9.0, "world": 1.0}"#;
        assert!(IdfWeights::from_json(json).is_err());
    }

    #[test]
    fn test_idf_weights_cache_structure() {
        let cache = IdfWeightsCache::new();
        assert!(cache.cache.read().is_empty());
    }

    #[test]
    #[ignore = "performs a real request against the HuggingFace hub"]
    fn test_idf_weights_cache_miss_graceful() {
        let cache = IdfWeightsCache::new();
        // Non-existent model should return None gracefully
        let result = cache.get_or_load("nonexistent-model-xyz-12345");
        assert!(result.is_none());
    }
}