hermes_core/tokenizer/
idf_weights.rs1#[cfg(feature = "native")]
8use std::collections::HashMap;
9#[cfg(feature = "native")]
10use std::sync::Arc;
11
12#[cfg(feature = "native")]
13use log::{debug, warn};
14#[cfg(feature = "native")]
15use parking_lot::RwLock;
16
17#[cfg(feature = "native")]
18use crate::Result;
19#[cfg(feature = "native")]
20use crate::error::Error;
21
22#[cfg(feature = "native")]
/// Dense lookup table of per-token IDF weights.
///
/// The weight for a token is stored at the index equal to its token ID, so a
/// lookup is a single bounds-checked slice access.
#[derive(Debug, Clone)]
pub struct IdfWeights {
    /// `weights[id]` is the weight of token `id`.
    weights: Vec<f32>,
}
30
#[cfg(feature = "native")]
impl IdfWeights {
    /// Largest token ID accepted by [`IdfWeights::from_json`].
    ///
    /// The table is dense, so its length is `max token ID + 1`. Without an
    /// upper bound a single stray key in the parsed file would force a
    /// multi-gigabyte allocation. This bound limits the table to 2^24 entries
    /// (64 MiB of `f32`).
    const MAX_TOKEN_ID: u32 = (1 << 24) - 1;

    /// Returns the weight stored for `token_id`.
    ///
    /// IDs beyond the end of the table yield the neutral weight `1.0`.
    #[inline]
    pub fn get(&self, token_id: u32) -> f32 {
        self.weights.get(token_id as usize).copied().unwrap_or(1.0)
    }

    /// Builds the table from the bytes of a JSON object that maps token-ID
    /// strings to numeric weights, e.g. `{"0": 1.5, "7": 0.25}`.
    ///
    /// Keys that do not parse as `u32` are skipped. IDs that are absent from
    /// the object but below the largest ID keep the default weight `1.0`.
    ///
    /// # Errors
    ///
    /// Returns `Error::Tokenizer` when the bytes are not a JSON object of
    /// numbers, when the object is empty, when no key is a valid token ID, or
    /// when the largest token ID exceeds [`Self::MAX_TOKEN_ID`].
    fn from_json(json_bytes: &[u8]) -> Result<Self> {
        let map: HashMap<String, f64> = serde_json::from_slice(json_bytes)
            .map_err(|e| Error::Tokenizer(format!("Failed to parse idf.json: {}", e)))?;

        if map.is_empty() {
            return Err(Error::Tokenizer("idf.json is empty".to_string()));
        }

        // Parse each key exactly once; the weights are narrowed to f32 here
        // because that is the precision the table stores.
        let entries: Vec<(u32, f32)> = map
            .iter()
            .filter_map(|(key, value)| key.parse::<u32>().ok().map(|id| (id, *value as f32)))
            .collect();

        let max_id = entries
            .iter()
            .map(|&(id, _)| id)
            .max()
            .ok_or_else(|| Error::Tokenizer("idf.json contains no valid token IDs".to_string()))?;

        if max_id > Self::MAX_TOKEN_ID {
            return Err(Error::Tokenizer(format!(
                "idf.json contains token ID {} above the supported maximum {}",
                max_id,
                Self::MAX_TOKEN_ID
            )));
        }

        // `max_id` is bounded above, so the `+ 1` is done in `usize` and
        // cannot overflow (the previous `u32` addition could at `u32::MAX`).
        let mut weights = vec![1.0f32; max_id as usize + 1];
        for &(id, weight) in &entries {
            weights[id as usize] = weight;
        }

        debug!(
            "Loaded {} IDF weights (vec size: {})",
            entries.len(),
            weights.len()
        );

        Ok(Self { weights })
    }
}
72
#[cfg(feature = "native")]
/// Cache of parsed [`IdfWeights`] tables keyed by model name.
///
/// Entries are handed out as `Arc`s, so callers share one table per model
/// instead of copying it.
pub struct IdfWeightsCache {
    /// Model name -> shared weights table, guarded by a read/write lock.
    cache: RwLock<HashMap<String, Arc<IdfWeights>>>,
}
78
#[cfg(feature = "native")]
impl Default for IdfWeightsCache {
    /// Produces an empty cache; identical to calling [`IdfWeightsCache::new`].
    fn default() -> Self {
        IdfWeightsCache::new()
    }
}
85
#[cfg(feature = "native")]
impl IdfWeightsCache {
    /// Creates a cache with no entries.
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the weights for `model_name`, downloading and parsing the
    /// model's `idf.json` on a cache miss.
    ///
    /// Returns `None` (after logging a warning) when the file cannot be
    /// downloaded, read, or parsed. Failures are not cached, so a later call
    /// for the same model attempts the download again.
    ///
    /// No lock is held while downloading. If several threads miss on the same
    /// model at once, each performs the download, but only the first result
    /// is stored and every caller receives that same shared table.
    pub fn get_or_load(&self, model_name: &str) -> Option<Arc<IdfWeights>> {
        // Fast path. The read guard is a temporary of this statement, so it is
        // released before any download starts.
        let cached = self.cache.read().get(model_name).cloned();
        if cached.is_some() {
            return cached;
        }

        let loaded = match self.download_and_parse(model_name) {
            Ok(weights) => weights,
            Err(e) => {
                warn!(
                    "Could not load idf.json for model '{}': {}. Falling back to index-derived IDF.",
                    model_name, e
                );
                return None;
            }
        };

        // Another thread may have populated the entry while this one was
        // downloading; keep the existing table rather than replacing it so
        // that all callers hold the same `Arc`.
        let mut cache = self.cache.write();
        let shared = cache
            .entry(model_name.to_string())
            .or_insert_with(|| Arc::new(loaded));
        Some(Arc::clone(shared))
    }

    /// Fetches `idf.json` for `model_name` through the `hf_hub` sync API,
    /// reads the resulting local file, and parses it into an [`IdfWeights`].
    ///
    /// # Errors
    ///
    /// Returns `Error::Tokenizer` if the hub API cannot be created, the file
    /// cannot be downloaded or read, or its contents fail to parse.
    fn download_and_parse(&self, model_name: &str) -> Result<IdfWeights> {
        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| Error::Tokenizer(format!("Failed to create HF hub API: {}", e)))?;
        let repo = api.model(model_name.to_string());
        let idf_path = repo.get("idf.json").map_err(|e| {
            Error::Tokenizer(format!(
                "Failed to download idf.json from '{}': {}",
                model_name, e
            ))
        })?;

        debug!(
            "Downloaded idf.json from '{}' to {:?}",
            model_name, idf_path
        );

        let json_bytes = std::fs::read(&idf_path).map_err(|e| {
            Error::Tokenizer(format!("Failed to read idf.json at {:?}: {}", idf_path, e))
        })?;

        IdfWeights::from_json(&json_bytes)
    }

    /// Removes every cached entry. Tables already handed out stay alive for
    /// as long as their `Arc`s are held.
    pub fn clear(&self) {
        self.cache.write().clear();
    }
}
156
#[cfg(feature = "native")]
/// Storage for the process-wide [`IdfWeightsCache`]; it starts unset and is
/// filled in by the first call to [`idf_weights_cache`].
static IDF_WEIGHTS_CACHE: std::sync::OnceLock<IdfWeightsCache> = std::sync::OnceLock::new();
160
#[cfg(feature = "native")]
/// Returns the process-wide [`IdfWeightsCache`], creating an empty one on the
/// first call.
pub fn idf_weights_cache() -> &'static IdfWeightsCache {
    IDF_WEIGHTS_CACHE.get_or_init(IdfWeightsCache::default)
}
166
#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    /// Asserts that the weight stored for `token_id` equals `expected` up to
    /// `f32::EPSILON`, reporting the offending token on failure.
    fn assert_weight(weights: &IdfWeights, token_id: u32, expected: f32) {
        let actual = weights.get(token_id);
        assert!(
            (actual - expected).abs() < f32::EPSILON,
            "token {}: expected weight {}, got {}",
            token_id,
            expected,
            actual
        );
    }

    #[test]
    fn test_idf_weights_from_json() {
        let json = br#"{"0": 1.5, "1": 2.0, "5": 0.5, "100": 3.0}"#;
        let weights = IdfWeights::from_json(json).unwrap();

        // IDs present in the file carry their stored weight.
        assert_weight(&weights, 0, 1.5);
        assert_weight(&weights, 1, 2.0);
        assert_weight(&weights, 5, 0.5);
        assert_weight(&weights, 100, 3.0);

        // Gaps below the largest ID keep the default weight.
        assert_weight(&weights, 2, 1.0);
        assert_weight(&weights, 50, 1.0);

        // IDs past the end of the table fall back to the default weight.
        assert_weight(&weights, 999, 1.0);
    }

    #[test]
    fn test_idf_weights_empty_json() {
        let json = br#"{}"#;
        assert!(IdfWeights::from_json(json).is_err());
    }

    #[test]
    fn test_idf_weights_invalid_json() {
        let json = br#"not json"#;
        assert!(IdfWeights::from_json(json).is_err());
    }

    #[test]
    fn test_idf_weights_no_valid_token_ids() {
        // A well-formed, non-empty object whose keys are not token IDs.
        let json = br#"{"abc": 1.0, "-3": 2.0}"#;
        assert!(IdfWeights::from_json(json).is_err());
    }

    #[test]
    fn test_idf_weights_skips_invalid_keys() {
        // Non-numeric keys are ignored; the valid entry is still loaded.
        let json = br#"{"abc": 9.0, "3": 2.5}"#;
        let weights = IdfWeights::from_json(json).unwrap();

        assert_weight(&weights, 3, 2.5);
        assert_weight(&weights, 0, 1.0);
        assert_weight(&weights, 4, 1.0);
    }

    #[test]
    fn test_idf_weights_cache_structure() {
        let cache = IdfWeightsCache::new();
        assert!(cache.cache.read().is_empty());
    }

    #[test]
    fn test_idf_weights_cache_miss_graceful() {
        let cache = IdfWeightsCache::new();
        let result = cache.get_or_load("nonexistent-model-xyz-12345");
        // A model that cannot be loaded yields `None` rather than a panic.
        assert!(result.is_none());
    }
}
215}