#[cfg(feature = "native")]
use std::collections::HashMap;
#[cfg(feature = "native")]
use std::path::Path;
#[cfg(feature = "native")]
use std::sync::Arc;
#[cfg(feature = "native")]
use log::{debug, warn};
#[cfg(feature = "native")]
use parking_lot::RwLock;
#[cfg(feature = "native")]
use crate::Result;
#[cfg(feature = "native")]
use crate::error::Error;
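/// Per-token-ID IDF weights parsed from a model's `idf.json` file.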
#[cfg(feature = "native")]
pub struct IdfWeights {
    /// Indexed by token ID; IDs without an idf.json entry keep the default 1.0.
    weights: Vec<f32>,
}
#[cfg(feature = "native")]
impl IdfWeights {
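    /// Returns the IDF weight for `token_id`, or the neutral weight 1.0 if
    /// the ID is out of range.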
    #[inline]
    pub fn get(&self, token_id: u32) -> f32 {
        self.weights.get(token_id as usize).copied().unwrap_or(1.0)
    }
    fn from_json_with_tokenizer(
        json_bytes: &[u8],
        tokenizer: &tokenizers::Tokenizer,
    ) -> Result<Self> {
        let map: HashMap<String, f64> = serde_json::from_slice(json_bytes)
            .map_err(|e| Error::Tokenizer(format!("Failed to parse idf.json: {}", e)))?;
        if map.is_empty() {
            return Err(Error::Tokenizer("idf.json is empty".to_string()));
        }
        // Resolve token strings to IDs; unknown tokens are skipped and only
        // counted for the debug log below.
        let mut resolved: Vec<(u32, f32)> = Vec::with_capacity(map.len());
        let mut missed = 0u32;
        for (token_str, value) in &map {
            if let Some(id) = tokenizer.token_to_id(token_str) {
                resolved.push((id, *value as f32));
            } else {
                missed += 1;
            }
        }
        if resolved.is_empty() {
            return Err(Error::Tokenizer(
                "idf.json: no tokens could be resolved to IDs via tokenizer".to_string(),
            ));
        }
        // `resolved` is non-empty here, so `max()` cannot fail.
        let max_id = resolved.iter().map(|(id, _)| *id).max().unwrap();
        let mut weights = vec![1.0f32; (max_id + 1) as usize];
        for &(id, value) in &resolved {
            weights[id as usize] = value;
        }
        debug!(
            "Loaded {} IDF weights via tokenizer (vec size: {}, unresolved: {})",
            resolved.len(),
            weights.len(),
            missed,
        );
        Ok(Self { weights })
    }
}
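/// Process-wide cache of per-model [`IdfWeights`], keyed by model name.
/// Failed loads are cached as `None` so a missing `idf.json` is only
/// fetched once.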
#[cfg(feature = "native")]
pub struct IdfWeightsCache {
    /// `None` records a model whose idf.json could not be loaded.
    cache: RwLock<HashMap<String, Option<Arc<IdfWeights>>>>,
}
#[cfg(feature = "native")]
impl Default for IdfWeightsCache {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(feature = "native")]
impl IdfWeightsCache {
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }
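    /// Returns the IDF weights for `model_name`, loading and caching them on
    /// first use. Returns `None` if `idf.json` is unavailable; callers fall
    /// back to index-derived IDF.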
    pub fn get_or_load(
        &self,
        model_name: &str,
        cache_dir: Option<&Path>,
    ) -> Option<Arc<IdfWeights>> {
        {
            let cache = self.cache.read();
            if let Some(entry) = cache.get(model_name) {
                return entry.as_ref().map(Arc::clone);
            }
        }
        // Cache miss: load outside the lock. Two threads may race and both
        // load; the second insert harmlessly overwrites the first.
        match self.load_with_local_cache(model_name, cache_dir) {
            Ok(weights) => {
                let weights = Arc::new(weights);
                let mut cache = self.cache.write();
                cache.insert(model_name.to_string(), Some(Arc::clone(&weights)));
                Some(weights)
            }
            Err(e) => {
                warn!(
                    "Could not load idf.json for model '{}': {}. Falling back to index-derived IDF.",
                    model_name, e
                );
                // Cache the failure so the download is not retried on every call.
                let mut cache = self.cache.write();
                cache.insert(model_name.to_string(), None);
                None
            }
        }
    }
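    /// Makes a model name safe for use in a file name ("org/model" becomes
    /// "org--model").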
    fn sanitized_model_name(model_name: &str) -> String {
        model_name.replace('/', "--")
    }
    /// Path of the local cache copy, e.g. `<cache_dir>/idf_org--model.json`.
    fn local_cache_path(cache_dir: &Path, model_name: &str) -> std::path::PathBuf {
        cache_dir.join(format!(
            "idf_{}.json",
            Self::sanitized_model_name(model_name)
        ))
    }
    fn load_with_local_cache(
        &self,
        model_name: &str,
        cache_dir: Option<&Path>,
    ) -> Result<IdfWeights> {
        let tokenizer = super::tokenizer_cache().get_or_load(model_name)?;
        // Prefer a previously cached copy in the caller-provided directory.
        if let Some(dir) = cache_dir {
            let local_path = Self::local_cache_path(dir, model_name);
            if local_path.exists() {
                let json_bytes = std::fs::read(&local_path).map_err(|e| {
                    Error::Tokenizer(format!(
                        "Failed to read cached idf.json at {:?}: {}",
                        local_path, e
                    ))
                })?;
                debug!(
                    "Loaded idf.json from local cache: {:?} for model '{}'",
                    local_path, model_name
                );
                return IdfWeights::from_json_with_tokenizer(&json_bytes, &tokenizer.tokenizer);
            }
        }
        let json_bytes = self.download_idf_json(model_name)?;
        // Best-effort write-back to the local cache; failure is non-fatal.
        if let Some(dir) = cache_dir {
            let local_path = Self::local_cache_path(dir, model_name);
            if let Err(e) = std::fs::write(&local_path, &json_bytes) {
                warn!(
                    "Failed to cache idf.json to {:?}: {} (non-fatal)",
                    local_path, e
                );
            } else {
                debug!(
                    "Cached idf.json to {:?} for model '{}'",
                    local_path, model_name
                );
            }
        }
        IdfWeights::from_json_with_tokenizer(&json_bytes, &tokenizer.tokenizer)
    }
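    /// Fetches the raw bytes of `idf.json`, preferring the on-disk Hugging
    /// Face hub cache and downloading from the hub only on a cache miss.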
    fn download_idf_json(&self, model_name: &str) -> Result<Vec<u8>> {
        // Check the shared Hugging Face hub cache before touching the network.
        let cache = hf_hub::Cache::from_env();
        let cache_repo = cache.model(model_name.to_string());
        if let Some(cached_path) = cache_repo.get("idf.json") {
            debug!(
                "Loaded idf.json from HF cache: {:?} for model '{}'",
                cached_path, model_name
            );
            return std::fs::read(&cached_path).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to read cached idf.json at {:?}: {}",
                    cached_path, e
                ))
            });
        }
        let api = hf_hub::api::sync::Api::new()
            .map_err(|e| Error::Tokenizer(format!("Failed to create HF hub API: {}", e)))?;
        let repo = api.model(model_name.to_string());
        let idf_path = repo.get("idf.json").map_err(|e| {
            Error::Tokenizer(format!(
                "Failed to download idf.json from '{}': {}",
                model_name, e
            ))
        })?;
        debug!(
            "Downloaded idf.json from '{}' to {:?}",
            model_name, idf_path
        );
        std::fs::read(&idf_path).map_err(|e| {
            Error::Tokenizer(format!("Failed to read idf.json at {:?}: {}", idf_path, e))
        })
    }
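    /// Drops all cached entries (including negative ones), forcing reloads.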
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        cache.clear();
    }
}
#[cfg(feature = "native")]
static IDF_WEIGHTS_CACHE: std::sync::OnceLock<IdfWeightsCache> = std::sync::OnceLock::new();
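/// Returns the process-wide [`IdfWeightsCache`] singleton.
///
/// ```ignore
/// // Hypothetical usage; the model name is illustrative.
/// let weights = idf_weights_cache().get_or_load("org/model", None);
/// ```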
#[cfg(feature = "native")]
pub fn idf_weights_cache() -> &'static IdfWeightsCache {
    IDF_WEIGHTS_CACHE.get_or_init(IdfWeightsCache::new)
}
#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;
    /// Builds a minimal WordPiece tokenizer with a fixed vocabulary,
    /// including a deliberately sparse ID range (0, 1, 2, 5, 100).
    fn test_tokenizer() -> tokenizers::Tokenizer {
        use tokenizers::models::wordpiece::WordPiece;
        let wp = WordPiece::builder()
            // `vocab` expects a map, not an array, so collect the
            // (token, id) pairs into the vocab type the builder wants.
            .vocab(
                [
                    ("[UNK]".to_string(), 0),
                    ("hello".to_string(), 1),
                    ("world".to_string(), 2),
                    ("foo".to_string(), 5),
                    ("bar".to_string(), 100),
                ]
                .into_iter()
                .collect(),
            )
            .unk_token("[UNK]".into())
            .build()
            .unwrap();
        tokenizers::Tokenizer::new(wp)
    }
    #[test]
    fn test_idf_weights_from_json_with_tokenizer() {
        let json = br#"{"hello": 1.5, "world": 2.0, "foo": 0.5, "bar": 3.0}"#;
        let tokenizer = test_tokenizer();
        let weights = IdfWeights::from_json_with_tokenizer(json, &tokenizer).unwrap();
        assert!((weights.get(1) - 1.5).abs() < f32::EPSILON);
        assert!((weights.get(2) - 2.0).abs() < f32::EPSILON);
        assert!((weights.get(5) - 0.5).abs() < f32::EPSILON);
        assert!((weights.get(100) - 3.0).abs() < f32::EPSILON);
        // IDs with no idf.json entry fall back to the neutral weight 1.0.
        assert!((weights.get(3) - 1.0).abs() < f32::EPSILON);
        assert!((weights.get(50) - 1.0).abs() < f32::EPSILON);
        assert!((weights.get(999) - 1.0).abs() < f32::EPSILON);
    }
    #[test]
    fn test_idf_weights_unresolvable_tokens_skipped() {
        let json = br#"{"hello": 1.5, "unknown_xyz": 9.9}"#;
        let tokenizer = test_tokenizer();
        let weights = IdfWeights::from_json_with_tokenizer(json, &tokenizer).unwrap();
        assert!((weights.get(1) - 1.5).abs() < f32::EPSILON);
    }
    #[test]
    fn test_idf_weights_empty_json() {
        let json = br#"{}"#;
        let tokenizer = test_tokenizer();
        assert!(IdfWeights::from_json_with_tokenizer(json, &tokenizer).is_err());
    }
    #[test]
    fn test_idf_weights_invalid_json() {
        let json = br#"not json"#;
        let tokenizer = test_tokenizer();
        assert!(IdfWeights::from_json_with_tokenizer(json, &tokenizer).is_err());
    }
    #[test]
    fn test_idf_weights_cache_structure() {
        let cache = IdfWeightsCache::new();
        assert!(cache.cache.read().is_empty());
    }
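    // Exercises the pure path helpers; the expected values follow directly
    // from sanitized_model_name's replace('/', "--") and the idf_{}.json
    // file-name pattern in local_cache_path.
    #[test]
    fn test_local_cache_path_sanitizes_model_name() {
        assert_eq!(
            IdfWeightsCache::sanitized_model_name("org/model"),
            "org--model"
        );
        let path = IdfWeightsCache::local_cache_path(Path::new("/tmp"), "org/model");
        assert_eq!(path.file_name().unwrap(), "idf_org--model.json");
    }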
    #[test]
    fn test_idf_weights_cache_miss_graceful() {
        let cache = IdfWeightsCache::new();
        let result = cache.get_or_load("nonexistent-model-xyz-12345", None);
        assert!(result.is_none());
    }
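    // A failed load is stored as a negative entry, so a repeat lookup returns
    // None from the cache instead of retrying the download, and clear()
    // empties the map again.
    #[test]
    fn test_idf_weights_cache_negative_entry_reused() {
        let cache = IdfWeightsCache::new();
        assert!(cache.get_or_load("nonexistent-model-xyz-12345", None).is_none());
        assert!(cache.cache.read().contains_key("nonexistent-model-xyz-12345"));
        assert!(cache.get_or_load("nonexistent-model-xyz-12345", None).is_none());
        cache.clear();
        assert!(cache.cache.read().is_empty());
    }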
}