use ahash::AHashMap;
use std::sync::{Arc, RwLock};
use once_cell::sync::Lazy;
use crate::KreuzbergError;
/// Process-wide cache mapping a model identifier (e.g. a Hugging Face model
/// name) to its loaded tokenizer. Entries are `Arc`-shared so callers can
/// hold a tokenizer without keeping the lock; `AHashMap` is used for faster
/// hashing of short string keys. Guarded by an `RwLock` because lookups
/// (reads) vastly outnumber insertions (writes).
static TOKENIZER_CACHE: Lazy<RwLock<AHashMap<String, Arc<tokenizers::Tokenizer>>>> =
Lazy::new(|| RwLock::new(AHashMap::new()));
/// Returns the cached tokenizer for `model`, loading and caching it on first
/// use.
///
/// Fast path takes only a shared read lock; on a miss, an exclusive write
/// lock is taken and the cache is re-checked (double-checked locking) so two
/// threads racing on the same miss do not both load the model.
///
/// # Errors
/// - `KreuzbergError::Other` if either lock is poisoned.
/// - a validation error if the tokenizer cannot be loaded.
pub(crate) fn get_or_init_tokenizer(model: &str) -> crate::Result<Arc<tokenizers::Tokenizer>> {
    // Fast path: most calls find an already-loaded tokenizer under the
    // read lock. The guard is a temporary and drops at the end of the if-let.
    if let Some(tok) = TOKENIZER_CACHE
        .read()
        .map_err(|e| KreuzbergError::Other(format!("Tokenizer cache read lock poisoned: {}", e)))?
        .get(model)
    {
        return Ok(Arc::clone(tok));
    }
    // Slow path: take the exclusive lock, then re-check, because another
    // thread may have inserted this model between our read and write locks.
    let mut cache = TOKENIZER_CACHE
        .write()
        .map_err(|e| KreuzbergError::Other(format!("Tokenizer cache write lock poisoned: {}", e)))?;
    match cache.get(model) {
        Some(existing) => Ok(Arc::clone(existing)),
        None => {
            // NOTE(review): from_pretrained may fetch the model over the
            // network; the write lock is held for the whole download,
            // blocking all other cache access — confirm this is acceptable.
            let loaded = tokenizers::Tokenizer::from_pretrained(model, None)
                .map_err(|e| KreuzbergError::validation(format!("Failed to load tokenizer '{}': {}", model, e)))?;
            let shared = Arc::new(loaded);
            cache.insert(model.to_string(), Arc::clone(&shared));
            Ok(shared)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Requesting the same model twice must yield the identical cached
    /// instance (pointer-equal `Arc`s), not a second load.
    #[test]
    fn test_cache_returns_same_instance() {
        // Loading the tokenizer needs network access; skip under CI.
        if std::env::var("CI").is_ok() {
            return;
        }
        let model = "bert-base-uncased";
        let first = get_or_init_tokenizer(model).unwrap();
        let second = get_or_init_tokenizer(model).unwrap();
        assert!(
            Arc::ptr_eq(&first, &second),
            "cache handed out two distinct tokenizer instances"
        );
    }
}