hermes_core/tokenizer/
idf_weights.rs1#[cfg(feature = "native")]
14use std::collections::HashMap;
15#[cfg(feature = "native")]
16use std::path::Path;
17#[cfg(feature = "native")]
18use std::sync::Arc;
19
20#[cfg(feature = "native")]
21use log::{debug, warn};
22#[cfg(feature = "native")]
23use parking_lot::RwLock;
24
25#[cfg(feature = "native")]
26use crate::Result;
27#[cfg(feature = "native")]
28use crate::error::Error;
29
/// Dense per-token-id IDF (inverse document frequency) weight table,
/// loaded from a model's `idf.json` file.
#[cfg(feature = "native")]
pub struct IdfWeights {
    // Indexed by token id. Ids without an entry — or beyond the vec —
    // fall back to the neutral weight 1.0 (see `IdfWeights::get`).
    weights: Vec<f32>,
}
38
#[cfg(feature = "native")]
impl IdfWeights {
    /// Returns the IDF weight for `token_id`.
    ///
    /// Ids outside the loaded table return the neutral weight `1.0`,
    /// so lookups never panic.
    #[inline]
    pub fn get(&self, token_id: u32) -> f32 {
        self.weights.get(token_id as usize).copied().unwrap_or(1.0)
    }

    /// Builds the weight table from the raw bytes of an `idf.json` file —
    /// a JSON object mapping token strings to IDF values — resolving each
    /// token string to its id via `tokenizer`.
    ///
    /// Tokens the tokenizer cannot resolve are counted and skipped.
    ///
    /// # Errors
    /// Returns `Error::Tokenizer` when the JSON is malformed, the map is
    /// empty, or no token at all resolves to an id.
    fn from_json_with_tokenizer(
        json_bytes: &[u8],
        tokenizer: &tokenizers::Tokenizer,
    ) -> Result<Self> {
        let map: HashMap<String, f64> = serde_json::from_slice(json_bytes)
            .map_err(|e| Error::Tokenizer(format!("Failed to parse idf.json: {}", e)))?;

        if map.is_empty() {
            return Err(Error::Tokenizer("idf.json is empty".to_string()));
        }

        // Resolve token strings to ids; unknown tokens are tallied so the
        // debug log below can report how many entries were dropped.
        let mut resolved: Vec<(u32, f32)> = Vec::with_capacity(map.len());
        let mut missed = 0u32;
        for (token_str, value) in &map {
            if let Some(id) = tokenizer.token_to_id(token_str) {
                resolved.push((id, *value as f32));
            } else {
                missed += 1;
            }
        }

        if resolved.is_empty() {
            return Err(Error::Tokenizer(
                "idf.json: no tokens could be resolved to IDs via tokenizer".to_string(),
            ));
        }

        // Safe unwrap: `resolved` is non-empty (checked above).
        let max_id = resolved.iter().map(|(id, _)| *id).max().unwrap();

        // Dense lookup table; ids with no entry keep the neutral weight 1.0.
        // Widen to usize BEFORE adding 1: `(max_id + 1) as usize` would
        // overflow u32 when max_id == u32::MAX.
        let mut weights = vec![1.0f32; max_id as usize + 1];
        for &(id, value) in &resolved {
            weights[id as usize] = value;
        }

        debug!(
            "Loaded {} IDF weights via tokenizer (vec size: {}, unresolved: {})",
            resolved.len(),
            weights.len(),
            missed,
        );

        Ok(Self { weights })
    }
}
100
/// Per-model cache of [`IdfWeights`], keyed by model name.
#[cfg(feature = "native")]
pub struct IdfWeightsCache {
    // `None` is a negative-cache entry: loading was attempted and failed,
    // so subsequent lookups skip the (expensive) load path entirely.
    cache: RwLock<HashMap<String, Option<Arc<IdfWeights>>>>,
}
107
#[cfg(feature = "native")]
impl Default for IdfWeightsCache {
    /// Equivalent to [`IdfWeightsCache::new`]: an empty cache.
    fn default() -> Self {
        IdfWeightsCache::new()
    }
}
114
#[cfg(feature = "native")]
impl IdfWeightsCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        let cache = RwLock::new(HashMap::new());
        Self { cache }
    }

    /// Returns the IDF weights for `model_name`, loading them on first use.
    ///
    /// Lookup order: in-memory cache, then local `cache_dir` file, then the
    /// HF hub (see [`Self::load_with_local_cache`]). A failed load is
    /// negative-cached as `None` and logged once, so callers can fall back
    /// to index-derived IDF without retrying on every call.
    pub fn get_or_load(
        &self,
        model_name: &str,
        cache_dir: Option<&Path>,
    ) -> Option<Arc<IdfWeights>> {
        // Fast path: read lock only, released at the end of this `if let`.
        if let Some(entry) = self.cache.read().get(model_name) {
            return entry.clone();
        }

        // Slow path: load outside any lock, then record the outcome
        // (including failures) under the write lock.
        let entry = match self.load_with_local_cache(model_name, cache_dir) {
            Ok(weights) => Some(Arc::new(weights)),
            Err(e) => {
                warn!(
                    "Could not load idf.json for model '{}': {}. Falling back to index-derived IDF.",
                    model_name, e
                );
                None
            }
        };
        self.cache
            .write()
            .insert(model_name.to_string(), entry.clone());
        entry
    }

    /// Makes `model_name` safe for use in a filename ('/' → "--").
    fn sanitized_model_name(model_name: &str) -> String {
        model_name.replace('/', "--")
    }

    /// Path of the locally cached idf.json for `model_name` under `cache_dir`.
    fn local_cache_path(cache_dir: &Path, model_name: &str) -> std::path::PathBuf {
        let file_name = format!("idf_{}.json", Self::sanitized_model_name(model_name));
        cache_dir.join(file_name)
    }

    /// Loads weights for `model_name`, preferring a file in `cache_dir` and
    /// otherwise downloading via [`Self::download_idf_json`]. Downloads are
    /// written back to `cache_dir` on a best-effort basis (write failures
    /// are logged, not fatal).
    fn load_with_local_cache(
        &self,
        model_name: &str,
        cache_dir: Option<&Path>,
    ) -> Result<IdfWeights> {
        // The tokenizer is needed in every branch to map token strings to ids.
        let tokenizer = super::tokenizer_cache().get_or_load(model_name)?;

        let local_path = cache_dir.map(|dir| Self::local_cache_path(dir, model_name));

        // Local-cache hit: parse directly from disk.
        if let Some(path) = local_path.as_ref().filter(|p| p.exists()) {
            let json_bytes = std::fs::read(path).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to read cached idf.json at {:?}: {}",
                    path, e
                ))
            })?;
            debug!(
                "Loaded idf.json from local cache: {:?} for model '{}'",
                path, model_name
            );
            return IdfWeights::from_json_with_tokenizer(&json_bytes, &tokenizer.tokenizer);
        }

        let json_bytes = self.download_idf_json(model_name)?;

        // Best-effort write-back so the next load skips the download.
        if let Some(path) = local_path {
            match std::fs::write(&path, &json_bytes) {
                Ok(()) => debug!(
                    "Cached idf.json to {:?} for model '{}'",
                    path, model_name
                ),
                Err(e) => warn!(
                    "Failed to cache idf.json to {:?}: {} (non-fatal)",
                    path, e
                ),
            }
        }

        IdfWeights::from_json_with_tokenizer(&json_bytes, &tokenizer.tokenizer)
    }

    /// Fetches the raw bytes of `idf.json` for `model_name`, checking the
    /// HF hub's on-disk cache before hitting the network.
    fn download_idf_json(&self, model_name: &str) -> Result<Vec<u8>> {
        // HF hub cache first — avoids any network traffic when present.
        let hub_cache = hf_hub::Cache::from_env();
        if let Some(cached_path) = hub_cache.model(model_name.to_string()).get("idf.json") {
            debug!(
                "Loaded idf.json from HF cache: {:?} for model '{}'",
                cached_path, model_name
            );
            return std::fs::read(&cached_path).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to read cached idf.json at {:?}: {}",
                    cached_path, e
                ))
            });
        }

        // Cache miss: download through the synchronous HF hub API.
        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| Error::Tokenizer(format!("Failed to create HF hub API: {}", e)))?;
        let idf_path = api.model(model_name.to_string()).get("idf.json").map_err(|e| {
            Error::Tokenizer(format!(
                "Failed to download idf.json from '{}': {}",
                model_name, e
            ))
        })?;

        debug!(
            "Downloaded idf.json from '{}' to {:?}",
            model_name, idf_path
        );

        std::fs::read(&idf_path).map_err(|e| {
            Error::Tokenizer(format!("Failed to read idf.json at {:?}: {}", idf_path, e))
        })
    }

    /// Drops every cached entry (including negative-cache entries).
    pub fn clear(&self) {
        self.cache.write().clear();
    }
}
274
/// Process-wide IDF weights cache, initialized lazily on first access.
#[cfg(feature = "native")]
static IDF_WEIGHTS_CACHE: std::sync::OnceLock<IdfWeightsCache> = std::sync::OnceLock::new();

/// Returns the global [`IdfWeightsCache`], creating it on first use.
#[cfg(feature = "native")]
pub fn idf_weights_cache() -> &'static IdfWeightsCache {
    IDF_WEIGHTS_CACHE.get_or_init(IdfWeightsCache::default)
}
284
#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    /// Builds a tiny WordPiece tokenizer with a fixed, sparse vocab
    /// (ids 0, 1, 2, 5, 100 — gaps are intentional for default-weight tests).
    fn make_tokenizer() -> tokenizers::Tokenizer {
        use tokenizers::models::wordpiece::WordPiece;
        let model = WordPiece::builder()
            .vocab([
                ("[UNK]".to_string(), 0),
                ("hello".to_string(), 1),
                ("world".to_string(), 2),
                ("foo".to_string(), 5),
                ("bar".to_string(), 100),
            ])
            .unk_token("[UNK]".into())
            .build()
            .unwrap();
        tokenizers::Tokenizer::new(model)
    }

    /// Asserts `weights.get(id)` equals `expected` within f32 epsilon.
    fn assert_weight(weights: &IdfWeights, id: u32, expected: f32) {
        assert!((weights.get(id) - expected).abs() < f32::EPSILON);
    }

    #[test]
    fn test_idf_weights_from_json_with_tokenizer() {
        let json = br#"{"hello": 1.5, "world": 2.0, "foo": 0.5, "bar": 3.0}"#;
        let weights = IdfWeights::from_json_with_tokenizer(json, &make_tokenizer()).unwrap();

        // Resolved tokens carry their JSON values.
        assert_weight(&weights, 1, 1.5);
        assert_weight(&weights, 2, 2.0);
        assert_weight(&weights, 5, 0.5);
        assert_weight(&weights, 100, 3.0);

        // In-range ids with no entry default to 1.0...
        assert_weight(&weights, 3, 1.0);
        assert_weight(&weights, 50, 1.0);

        // ...as do ids past the end of the table.
        assert_weight(&weights, 999, 1.0);
    }

    #[test]
    fn test_idf_weights_unresolvable_tokens_skipped() {
        // "unknown_xyz" is not in the vocab; its entry must be dropped.
        let json = br#"{"hello": 1.5, "unknown_xyz": 9.9}"#;
        let weights = IdfWeights::from_json_with_tokenizer(json, &make_tokenizer()).unwrap();

        assert_weight(&weights, 1, 1.5);
    }

    #[test]
    fn test_idf_weights_empty_json() {
        let tokenizer = make_tokenizer();
        assert!(IdfWeights::from_json_with_tokenizer(br#"{}"#, &tokenizer).is_err());
    }

    #[test]
    fn test_idf_weights_invalid_json() {
        let tokenizer = make_tokenizer();
        assert!(IdfWeights::from_json_with_tokenizer(br#"not json"#, &tokenizer).is_err());
    }

    #[test]
    fn test_idf_weights_cache_structure() {
        // A fresh cache starts with no entries.
        let cache = IdfWeightsCache::new();
        assert!(cache.cache.read().is_empty());
    }

    #[test]
    fn test_idf_weights_cache_miss_graceful() {
        // A model with no idf.json anywhere must yield None, not panic.
        let cache = IdfWeightsCache::new();
        assert!(cache.get_or_load("nonexistent-model-xyz-12345", None).is_none());
    }
}