hermes_core/tokenizer/
idf_weights.rs1#[cfg(feature = "native")]
14use std::collections::HashMap;
15#[cfg(feature = "native")]
16use std::sync::Arc;
17
18#[cfg(feature = "native")]
19use log::{debug, warn};
20#[cfg(feature = "native")]
21use parking_lot::RwLock;
22
23#[cfg(feature = "native")]
24use crate::Result;
25#[cfg(feature = "native")]
26use crate::error::Error;
27
/// Dense per-token IDF weights, indexed by tokenizer token ID.
///
/// Built from a model's `idf.json`. Token IDs without an entry read as `1.0`
/// through [`IdfWeights::get`].
#[cfg(feature = "native")]
#[derive(Debug, Clone)]
pub struct IdfWeights {
    /// `weights[id]` is the IDF value for token `id`; unlisted slots hold `1.0`.
    weights: Vec<f32>,
}
36
#[cfg(feature = "native")]
impl IdfWeights {
    /// Returns the IDF weight for `token_id`.
    ///
    /// Token IDs that had no entry in `idf.json` (including IDs beyond the
    /// largest resolved one) get a neutral weight of `1.0`.
    #[inline]
    pub fn get(&self, token_id: u32) -> f32 {
        self.weights.get(token_id as usize).copied().unwrap_or(1.0)
    }

    /// Builds a dense IDF table from an `idf.json` payload.
    ///
    /// `json_bytes` must be a JSON object mapping token strings to numbers.
    /// Each token is resolved to its ID with `tokenizer.token_to_id`. Tokens
    /// the tokenizer does not know are skipped and counted in a debug log.
    /// The resulting table is indexed by token ID, and IDs with no entry
    /// default to `1.0`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Tokenizer`] in any of these cases:
    /// - `json_bytes` does not deserialize into a string-to-number map;
    /// - the map is empty;
    /// - none of its tokens resolve to an ID.
    fn from_json_with_tokenizer(
        json_bytes: &[u8],
        tokenizer: &tokenizers::Tokenizer,
    ) -> Result<Self> {
        let map: HashMap<String, f64> = serde_json::from_slice(json_bytes)
            .map_err(|e| Error::Tokenizer(format!("Failed to parse idf.json: {}", e)))?;

        if map.is_empty() {
            return Err(Error::Tokenizer("idf.json is empty".to_string()));
        }

        // Keep only tokens the tokenizer can map to an ID; values are stored as f32.
        let resolved: Vec<(u32, f32)> = map
            .iter()
            .filter_map(|(token, &idf)| tokenizer.token_to_id(token).map(|id| (id, idf as f32)))
            .collect();
        let missed = map.len() - resolved.len();

        let Some(max_id) = resolved.iter().map(|&(id, _)| id).max() else {
            return Err(Error::Tokenizer(
                "idf.json: no tokens could be resolved to IDs via tokenizer".to_string(),
            ));
        };

        // Compute the length in usize: `max_id + 1` in u32 would overflow for
        // `max_id == u32::MAX`.
        let mut weights = vec![1.0f32; max_id as usize + 1];
        for &(id, value) in &resolved {
            weights[id as usize] = value;
        }

        debug!(
            "Loaded {} IDF weights via tokenizer (vec size: {}, unresolved: {})",
            resolved.len(),
            weights.len(),
            missed,
        );

        Ok(Self { weights })
    }
}
98
/// Process-wide cache of [`IdfWeights`], keyed by HuggingFace model name.
///
/// Failed loads are cached as well as successful ones. A model with no usable
/// `idf.json` therefore costs one hub request and one warning, not one per
/// lookup. [`IdfWeightsCache::clear`] discards both kinds of entry.
#[cfg(feature = "native")]
pub struct IdfWeightsCache {
    /// `Some(weights)` for a successful load, `None` for a recorded failure.
    cache: RwLock<HashMap<String, Option<Arc<IdfWeights>>>>,
}

#[cfg(feature = "native")]
impl Default for IdfWeightsCache {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "native")]
impl IdfWeightsCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the IDF weights for `model_name`, loading them on first use.
    ///
    /// On a cache miss, this downloads `idf.json` from the model's HuggingFace
    /// repository and resolves its tokens through the model's tokenizer. If
    /// that fails, it logs a warning, records the failure, and returns `None`.
    /// Later calls also return `None` until [`clear`](Self::clear) is called.
    pub fn get_or_load(&self, model_name: &str) -> Option<Arc<IdfWeights>> {
        // Fast path: the shared read guard is dropped at the end of this scope,
        // before any loading happens.
        {
            let cache = self.cache.read();
            if let Some(entry) = cache.get(model_name) {
                return entry.clone();
            }
        }

        // Load without holding the lock so a slow download does not block
        // lookups for other models.
        let loaded = match self.download_and_parse(model_name) {
            Ok(weights) => Some(Arc::new(weights)),
            Err(e) => {
                warn!(
                    "Could not load idf.json for model '{}': {}. Falling back to index-derived IDF.",
                    model_name, e
                );
                None
            }
        };

        let mut cache = self.cache.write();
        let slot = cache.entry(model_name.to_string()).or_default();
        // Another thread may have finished loading this model while we were
        // downloading. Keep its weights so all callers share one `Arc`, but let
        // a successful load replace a recorded failure.
        if slot.is_none() {
            *slot = loaded;
        }
        slot.clone()
    }

    /// Downloads `idf.json` for `model_name` and converts it into
    /// [`IdfWeights`], using the tokenizer from `super::tokenizer_cache()`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Tokenizer`] if the hub API cannot be created, or the
    /// file cannot be downloaded, read, or parsed. Errors from loading the
    /// tokenizer are propagated as-is.
    fn download_and_parse(&self, model_name: &str) -> Result<IdfWeights> {
        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| Error::Tokenizer(format!("Failed to create HF hub API: {}", e)))?;
        let repo = api.model(model_name.to_string());
        let idf_path = repo.get("idf.json").map_err(|e| {
            Error::Tokenizer(format!(
                "Failed to download idf.json from '{}': {}",
                model_name, e
            ))
        })?;

        debug!(
            "Downloaded idf.json from '{}' to {:?}",
            model_name, idf_path
        );

        let json_bytes = std::fs::read(&idf_path).map_err(|e| {
            Error::Tokenizer(format!("Failed to read idf.json at {:?}: {}", idf_path, e))
        })?;

        let tokenizer = super::tokenizer_cache().get_or_load(model_name)?;

        IdfWeights::from_json_with_tokenizer(&json_bytes, &tokenizer.tokenizer)
    }

    /// Removes every cached entry, including recorded failures.
    ///
    /// After this, the next [`get_or_load`](Self::get_or_load) for any model
    /// loads it afresh.
    pub fn clear(&self) {
        self.cache.write().clear();
    }
}
187
/// Backing storage for [`idf_weights_cache`], initialized on first access.
#[cfg(feature = "native")]
static IDF_WEIGHTS_CACHE: std::sync::OnceLock<IdfWeightsCache> = std::sync::OnceLock::new();

/// Returns the shared, process-wide [`IdfWeightsCache`].
///
/// The cache is created empty on the first call; every later call returns
/// the same instance.
#[cfg(feature = "native")]
pub fn idf_weights_cache() -> &'static IdfWeightsCache {
    IDF_WEIGHTS_CACHE.get_or_init(IdfWeightsCache::default)
}
197
#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    /// Builds a tiny WordPiece tokenizer with deliberately sparse IDs (gaps at
    /// 3-4 and 6-99), so tests can check the default-filled slots.
    fn test_tokenizer() -> tokenizers::Tokenizer {
        use tokenizers::models::wordpiece::WordPiece;
        let wp = WordPiece::builder()
            .vocab([
                ("[UNK]".to_string(), 0),
                ("hello".to_string(), 1),
                ("world".to_string(), 2),
                ("foo".to_string(), 5),
                ("bar".to_string(), 100),
            ])
            .unk_token("[UNK]".into())
            .build()
            .unwrap();
        tokenizers::Tokenizer::new(wp)
    }

    /// Asserts that `weights` maps `token_id` to `expected`.
    fn assert_weight(weights: &IdfWeights, token_id: u32, expected: f32) {
        let actual = weights.get(token_id);
        assert!(
            (actual - expected).abs() < f32::EPSILON,
            "token {token_id}: expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_idf_weights_from_json_with_tokenizer() {
        let json = br#"{"hello": 1.5, "world": 2.0, "foo": 0.5, "bar": 3.0}"#;
        let tokenizer = test_tokenizer();
        let weights = IdfWeights::from_json_with_tokenizer(json, &tokenizer).unwrap();

        // Resolved tokens carry their IDF values.
        assert_weight(&weights, 1, 1.5);
        assert_weight(&weights, 2, 2.0);
        assert_weight(&weights, 5, 0.5);
        assert_weight(&weights, 100, 3.0);

        // Gaps inside the table default to 1.0.
        assert_weight(&weights, 3, 1.0);
        assert_weight(&weights, 50, 1.0);

        // IDs past the end of the table also default to 1.0.
        assert_weight(&weights, 999, 1.0);
    }

    #[test]
    fn test_idf_weights_unresolvable_tokens_skipped() {
        // "unknown_xyz" is not in the vocab and must be ignored.
        let json = br#"{"hello": 1.5, "unknown_xyz": 9.9}"#;
        let tokenizer = test_tokenizer();
        let weights = IdfWeights::from_json_with_tokenizer(json, &tokenizer).unwrap();

        assert_weight(&weights, 1, 1.5);
        // Table size is driven only by resolved IDs (max resolved ID is 1).
        assert_eq!(weights.weights.len(), 2);
    }

    #[test]
    fn test_idf_weights_no_tokens_resolvable() {
        let json = br#"{"unknown_xyz": 9.9}"#;
        let tokenizer = test_tokenizer();
        assert!(IdfWeights::from_json_with_tokenizer(json, &tokenizer).is_err());
    }

    #[test]
    fn test_idf_weights_empty_json() {
        let json = br#"{}"#;
        let tokenizer = test_tokenizer();
        assert!(IdfWeights::from_json_with_tokenizer(json, &tokenizer).is_err());
    }

    #[test]
    fn test_idf_weights_invalid_json() {
        let json = br#"not json"#;
        let tokenizer = test_tokenizer();
        assert!(IdfWeights::from_json_with_tokenizer(json, &tokenizer).is_err());
    }

    #[test]
    fn test_idf_weights_cache_structure() {
        let cache = IdfWeightsCache::new();
        assert!(cache.cache.read().is_empty());

        // Clearing an empty cache is a harmless no-op.
        cache.clear();
        assert!(cache.cache.read().is_empty());
    }

    #[test]
    fn test_idf_weights_cache_miss_graceful() {
        let cache = IdfWeightsCache::new();
        // A model that does not exist must yield None rather than panic.
        let result = cache.get_or_load("nonexistent-model-xyz-12345");
        assert!(result.is_none());
    }
}