use std::sync::Arc;
use dashmap::DashMap;
use thiserror::Error;
use tokio::sync::Mutex;
use tracing::{debug, info};
use uuid::Uuid;
use crate::traits::Tokenizer;
/// Outcome of a successful [`TokenizerRegistry::load`] call.
///
/// Both variants carry the id under which the tokenizer is registered; they
/// differ only in whether this call performed the load.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoadOutcome {
    /// The loader ran and the tokenizer was registered under `id`.
    Loaded { id: String },
    /// A tokenizer with the requested name was already registered under `id`.
    AlreadyExists { id: String },
}

impl LoadOutcome {
    /// Returns the id carried by the outcome, regardless of variant.
    pub fn id(&self) -> &str {
        match self {
            LoadOutcome::Loaded { id } | LoadOutcome::AlreadyExists { id } => id,
        }
    }

    /// Returns `true` only for [`LoadOutcome::Loaded`].
    pub fn is_newly_loaded(&self) -> bool {
        matches!(self, LoadOutcome::Loaded { .. })
    }
}
/// Errors returned by [`TokenizerRegistry::load`].
#[derive(Debug, Error)]
pub enum LoadError {
/// The `name` argument was an empty string.
#[error("tokenizer name cannot be empty")]
EmptyName,
/// The `source` argument was an empty string.
#[error("tokenizer source cannot be empty")]
EmptySource,
/// The loader returned an error; the payload is the loader's message,
/// which is also used verbatim as the display text.
#[error("{0}")]
LoadFailed(String),
}
/// A registered tokenizer together with the metadata it was registered with.
///
/// Cloning copies the three strings and bumps the `Arc` reference count; the
/// tokenizer itself is shared, not duplicated.
#[derive(Clone)]
pub struct TokenizerEntry {
/// Identifier supplied at registration; the key in the registry's id map.
pub id: String,
/// Name supplied at registration; the key in the registry's name index.
pub name: String,
/// Source string supplied at registration. The registry only stores and
/// logs it.
pub source: String,
/// Shared handle to the tokenizer implementation.
pub tokenizer: Arc<dyn Tokenizer>,
}
/// Registry of tokenizers addressable by id or by name.
///
/// All state lives in `DashMap`s, so every method takes `&self`.
pub struct TokenizerRegistry {
// Entries keyed by id.
tokenizers: DashMap<String, TokenizerEntry>,
// Name -> id index used for lookups and for duplicate detection in `load`.
name_to_id: DashMap<String, String>,
// One async mutex per name, created on demand by `load` so that concurrent
// loads of the same name run one at a time.
loading_locks: DashMap<String, Arc<Mutex<()>>>,
}
/// Scope guard created by `TokenizerRegistry::load` after it has acquired the
/// per-name loading lock.
///
/// Its `Drop` impl cleans up the `key` entry in `locks`, so cleanup also runs
/// on early return and during unwinding.
struct LoadingLockGuard<'a> {
// The registry's map of per-name loading locks.
locks: &'a DashMap<String, Arc<Mutex<()>>>,
// Name whose lock entry this guard is responsible for.
key: String,
}
impl Drop for LoadingLockGuard<'_> {
fn drop(&mut self) {
self.locks.remove(&self.key);
}
}
impl TokenizerRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            tokenizers: DashMap::new(),
            name_to_id: DashMap::new(),
            loading_locks: DashMap::new(),
        }
    }

    /// Returns a new identifier: a freshly generated UUID v7 rendered as a string.
    pub fn generate_id() -> String {
        Uuid::now_v7().to_string()
    }

    /// Registers the tokenizer produced by `loader` under `name`, unless a
    /// tokenizer with that name is already registered.
    ///
    /// Concurrent calls for the same `name` are serialized through a per-name
    /// async mutex and the name index is re-checked after the mutex is
    /// acquired, so callers that waited on a successful load return
    /// [`LoadOutcome::AlreadyExists`] without invoking their own loader.
    ///
    /// `id` is used as given: an entry already stored under the same `id` is
    /// replaced.
    ///
    /// # Errors
    ///
    /// * [`LoadError::EmptyName`] if `name` is empty.
    /// * [`LoadError::EmptySource`] if `source` is empty.
    /// * [`LoadError::LoadFailed`] carrying the loader's message if `loader`
    ///   fails; nothing is registered in that case.
    pub async fn load<F, Fut>(
        &self,
        id: &str,
        name: &str,
        source: &str,
        loader: F,
    ) -> Result<LoadOutcome, LoadError>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<Arc<dyn Tokenizer>, String>>,
    {
        if name.is_empty() {
            return Err(LoadError::EmptyName);
        }
        if source.is_empty() {
            return Err(LoadError::EmptySource);
        }

        // Fast path: the name is already registered, no locking needed.
        if let Some(existing_id) = self.name_to_id.get(name) {
            debug!("Tokenizer already registered for name: {}", name);
            return Ok(LoadOutcome::AlreadyExists {
                id: existing_id.clone(),
            });
        }
        debug!("Tokenizer cache miss for name: {}", name);

        // Declaration order matters: on exit the cleanup guard is dropped
        // first, then the mutex guard, then this task's clone of the lock.
        let lock = self
            .loading_locks
            .entry(name.to_string())
            .or_insert_with(|| Arc::new(Mutex::new(())))
            .clone();
        let _mutex_guard = lock.lock().await;
        let _lock_cleanup = LoadingLockGuard {
            locks: &self.loading_locks,
            key: name.to_string(),
        };

        // Re-check under the lock: another caller may have finished loading
        // this name while we were waiting.
        if let Some(existing_id) = self.name_to_id.get(name) {
            debug!("Tokenizer loaded by another thread for name: {}", name);
            return Ok(LoadOutcome::AlreadyExists {
                id: existing_id.clone(),
            });
        }

        info!("Loading tokenizer '{}' from source: {}", name, source);
        let tokenizer = loader().await.map_err(LoadError::LoadFailed)?;

        let entry = TokenizerEntry {
            id: id.to_string(),
            name: name.to_string(),
            source: source.to_string(),
            tokenizer,
        };
        // Store the entry before publishing the name so that a name lookup
        // that finds the id also finds the entry.
        self.tokenizers.insert(id.to_string(), entry);
        self.name_to_id.insert(name.to_string(), id.to_string());
        info!(
            "Successfully registered tokenizer '{}' with id: {}",
            name, id
        );
        Ok(LoadOutcome::Loaded { id: id.to_string() })
    }

    /// Test-only synchronous registration of an already constructed tokenizer.
    ///
    /// Returns `Some(id)` when the name was free and the entry was stored, or
    /// `None` (leaving the registry untouched) when the name is already taken.
    #[cfg(test)]
    pub fn register(
        &self,
        id: &str,
        name: &str,
        source: &str,
        tokenizer: Arc<dyn Tokenizer>,
    ) -> Option<String> {
        use dashmap::mapref::entry::Entry;

        match self.name_to_id.entry(name.to_string()) {
            Entry::Occupied(_) => {
                debug!(
                    "Tokenizer already exists for name: {}, skipping registration",
                    name
                );
                None
            }
            Entry::Vacant(name_entry) => {
                let entry = TokenizerEntry {
                    id: id.to_string(),
                    name: name.to_string(),
                    source: source.to_string(),
                    tokenizer,
                };
                info!("Registering tokenizer '{}' with id: {}", name, id);
                self.tokenizers.insert(id.to_string(), entry);
                name_entry.insert(id.to_string());
                Some(id.to_string())
            }
        }
    }

    /// Returns a clone of the entry stored under `id`, if any.
    pub fn get_by_id(&self, id: &str) -> Option<TokenizerEntry> {
        self.tokenizers.get(id).map(|entry| entry.value().clone())
    }

    /// Returns a clone of the entry registered under `name`, if any.
    pub fn get_by_name(&self, name: &str) -> Option<TokenizerEntry> {
        let id = self.name_to_id.get(name)?;
        self.tokenizers
            .get(id.as_str())
            .map(|entry| entry.value().clone())
    }

    /// Looks up a tokenizer, treating the argument first as a name and then,
    /// if that finds nothing, as an id.
    ///
    /// Only the `Arc` handle is cloned; the entry's metadata is not copied.
    pub fn get(&self, name_or_id: &str) -> Option<Arc<dyn Tokenizer>> {
        let by_name = self.name_to_id.get(name_or_id).and_then(|id| {
            self.tokenizers
                .get(id.as_str())
                .map(|entry| Arc::clone(&entry.tokenizer))
        });
        by_name.or_else(|| {
            self.tokenizers
                .get(name_or_id)
                .map(|entry| Arc::clone(&entry.tokenizer))
        })
    }

    /// Returns `true` if a tokenizer is registered under `name`.
    pub fn contains(&self, name: &str) -> bool {
        self.name_to_id.contains_key(name)
    }

    /// Returns `true` if an entry is stored under `id`.
    pub fn contains_id(&self, id: &str) -> bool {
        self.tokenizers.contains_key(id)
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.tokenizers.len()
    }

    /// Returns `true` if no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.tokenizers.is_empty()
    }

    /// Returns clones of all entries, sorted by name.
    pub fn list(&self) -> Vec<TokenizerEntry> {
        let mut entries: Vec<TokenizerEntry> = self
            .tokenizers
            .iter()
            .map(|entry| entry.value().clone())
            .collect();
        entries.sort_by(|a, b| a.name.cmp(&b.name));
        entries
    }

    /// Removes and returns the entry stored under `id`.
    ///
    /// The name index is cleaned up only if the entry's name still maps to
    /// this `id`, so a mapping that points at a different entry is left alone.
    pub fn remove_by_id(&self, id: &str) -> Option<TokenizerEntry> {
        let (_, entry) = self.tokenizers.remove(id)?;
        self.name_to_id
            .remove_if(&entry.name, |_, mapped_id| mapped_id.as_str() == id);
        Some(entry)
    }

    /// Removes and returns the entry registered under `name`.
    pub fn remove(&self, name: &str) -> Option<TokenizerEntry> {
        let (_, id) = self.name_to_id.remove(name)?;
        self.tokenizers.remove(&id).map(|(_, entry)| entry)
    }

    /// Removes every entry and name mapping.
    ///
    /// Loading locks that a task still holds a handle to are kept: dropping
    /// them would let a new `load` for the same name create a second mutex and
    /// run concurrently with the load already in flight. Locks nobody
    /// references any more are discarded.
    pub fn clear(&self) {
        self.tokenizers.clear();
        self.name_to_id.clear();
        // A count above one means some task besides the map owns a clone.
        self.loading_locks
            .retain(|_, lock| Arc::strong_count(&*lock) > 1);
    }
}
impl Default for TokenizerRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
#[expect(
    clippy::disallowed_methods,
    reason = "tokio::spawn is fine in unit tests that await all handles"
)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    use tokio::time::sleep;

    use crate::{mock::MockTokenizer, traits::Tokenizer, LoadError, TokenizerRegistry};

    /// Successful loader result used throughout: a default mock tokenizer
    /// behind the trait object.
    fn mock_tokenizer() -> Result<Arc<dyn Tokenizer>, String> {
        Ok(Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>)
    }

    /// Load, lookup and removal of a single tokenizer.
    #[tokio::test]
    async fn test_basic_operations() {
        let registry = TokenizerRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(!registry.contains("model1"));

        let id = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id, "model1", "path/to/model", || async { mock_tokenizer() })
            .await
            .unwrap();
        assert!(outcome.is_newly_loaded());
        assert_eq!(outcome.id(), id);
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("model1"));
        assert!(registry.contains_id(&id));

        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.id, id);
        assert_eq!(entry.name, "model1");
        assert_eq!(entry.source, "path/to/model");

        let removed = registry.remove_by_id(&id);
        assert!(removed.is_some());
        assert!(registry.is_empty());
    }

    /// A second load for a registered name reports the first id and never
    /// calls its loader.
    #[tokio::test]
    async fn test_load_returns_already_exists() {
        let registry = TokenizerRegistry::new();
        let id1 = TokenizerRegistry::generate_id();
        let id2 = TokenizerRegistry::generate_id();

        let outcome1 = registry
            .load(&id1, "model1", "source1", || async { mock_tokenizer() })
            .await
            .unwrap();
        assert!(outcome1.is_newly_loaded());
        assert_eq!(outcome1.id(), id1);

        let outcome2 = registry
            .load(&id2, "model1", "source2", || async {
                panic!("Loader should not be called for duplicate name");
            })
            .await
            .unwrap();
        assert!(!outcome2.is_newly_loaded());
        assert_eq!(outcome2.id(), id1);
        assert_eq!(registry.len(), 1);

        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.source, "source1");
    }

    /// Empty name or source is rejected before the loader runs.
    #[tokio::test]
    async fn test_load_validation() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        let result = registry
            .load(&id, "", "source", || async {
                panic!("Loader should not be called for invalid input");
            })
            .await;
        assert!(matches!(result, Err(LoadError::EmptyName)));

        let result = registry
            .load(&id, "model", "", || async {
                panic!("Loader should not be called for invalid input");
            })
            .await;
        assert!(matches!(result, Err(LoadError::EmptySource)));
        assert!(registry.is_empty());
    }

    /// Ten concurrent loads of one name run the loader exactly once.
    #[tokio::test]
    async fn test_load_prevents_duplicate_loading() {
        let registry = Arc::new(TokenizerRegistry::new());
        let load_count = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];

        for i in 0..10 {
            let registry = registry.clone();
            let load_count = load_count.clone();
            let id = format!("id-{i}");
            let handle = tokio::spawn(async move {
                registry
                    .load(&id, "model1", "source", || async {
                        sleep(Duration::from_millis(10)).await;
                        load_count.fetch_add(1, Ordering::SeqCst);
                        mock_tokenizer()
                    })
                    .await
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap().unwrap();
        }

        assert_eq!(
            load_count.load(Ordering::SeqCst),
            1,
            "Tokenizer should be loaded exactly once despite concurrent requests"
        );
        assert_eq!(registry.len(), 1);
    }

    /// Several distinct names coexist; `list` and `clear` see all of them.
    #[tokio::test]
    async fn test_multiple_models() {
        let registry = TokenizerRegistry::new();

        for i in 1..=5 {
            let model_name = format!("model{i}");
            let id = TokenizerRegistry::generate_id();
            registry
                .load(&id, &model_name, "source", || async { mock_tokenizer() })
                .await
                .unwrap();
        }

        assert_eq!(registry.len(), 5);
        assert!(registry.contains("model1"));
        assert!(registry.contains("model5"));
        assert!(!registry.contains("model6"));

        let entries = registry.list();
        assert_eq!(entries.len(), 5);
        assert!(entries.iter().any(|e| e.name == "model1"));

        registry.clear();
        assert!(registry.is_empty());
    }

    /// A failing loader surfaces as `LoadFailed` with the loader's message
    /// and registers nothing.
    #[tokio::test]
    async fn test_load_failure() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        let result = registry
            .load(&id, "failing_model", "source", || async {
                Err("Load failed".to_string())
            })
            .await;

        assert!(result.is_err());
        assert!(matches!(
            &result,
            Err(LoadError::LoadFailed(msg)) if msg == "Load failed"
        ));
        assert!(!registry.contains("failing_model"));
        assert!(registry.is_empty());
    }

    /// Lookups by name, by id, and through the combined `get`.
    #[tokio::test]
    async fn test_get_by_name_and_id() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();
        registry
            .load(&id, "my-model", "hf/model", || async { mock_tokenizer() })
            .await
            .unwrap();

        let by_name = registry.get_by_name("my-model");
        assert!(by_name.is_some());
        assert_eq!(by_name.as_ref().unwrap().id, id);

        let by_id = registry.get_by_id(&id);
        assert!(by_id.is_some());
        assert_eq!(by_id.as_ref().unwrap().name, "my-model");

        assert!(registry.get("my-model").is_some());
        assert!(registry.get(&id).is_some());
    }

    /// `register` stores an entry only when the name is free.
    #[tokio::test]
    async fn test_register_only_if_absent() {
        let registry = TokenizerRegistry::new();
        let id1 = TokenizerRegistry::generate_id();
        let id2 = TokenizerRegistry::generate_id();
        let tokenizer1 = Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>;
        let tokenizer2 = Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>;

        let result1 = registry.register(&id1, "model1", "source1", tokenizer1.clone());
        assert!(result1.is_some());
        assert_eq!(registry.len(), 1);

        let result2 = registry.register(&id2, "model1", "source2", tokenizer2.clone());
        assert!(result2.is_none());
        assert_eq!(registry.len(), 1);

        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.id, id1);
        assert_eq!(entry.source, "source1");

        let id3 = TokenizerRegistry::generate_id();
        let result3 = registry.register(&id3, "model2", "source2", tokenizer2);
        assert!(result3.is_some());
        assert_eq!(registry.len(), 2);
    }

    /// A loader panic must not leave the name permanently locked.
    #[tokio::test]
    async fn test_loading_lock_cleanup_on_panic() {
        let registry = Arc::new(TokenizerRegistry::new());
        let registry_clone = registry.clone();

        let handle = tokio::spawn(async move {
            registry_clone
                .load(
                    &TokenizerRegistry::generate_id(),
                    "panic-model",
                    "source",
                    || async {
                        panic!("Simulated panic during tokenizer loading");
                    },
                )
                .await
        });
        let result = handle.await;
        assert!(result.is_err(), "Task should have panicked");

        let id = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id, "panic-model", "source", || async { mock_tokenizer() })
            .await;
        assert!(outcome.is_ok(), "Load should succeed after panic cleanup");
        assert!(outcome.unwrap().is_newly_loaded());
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("panic-model"));
        // With no load in flight, no per-name lock may be left behind.
        assert!(
            registry.loading_locks.is_empty(),
            "Loading locks should be cleaned up once all loads have finished"
        );
    }

    /// Sequential loads, including one that returns early for an existing
    /// name, leave no per-name lock behind.
    #[tokio::test]
    async fn test_loading_lock_cleanup_on_early_return() {
        let registry = Arc::new(TokenizerRegistry::new());

        let id1 = TokenizerRegistry::generate_id();
        registry
            .load(&id1, "model1", "source1", || async { mock_tokenizer() })
            .await
            .unwrap();

        let id2 = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id2, "model2", "source2", || async { mock_tokenizer() })
            .await
            .unwrap();
        assert!(outcome.is_newly_loaded());
        assert_eq!(registry.len(), 2);

        let id3 = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id3, "model1", "source1", || async {
                panic!("Loader should not be called for existing model");
            })
            .await
            .unwrap();
        assert!(!outcome.is_newly_loaded());
        assert_eq!(outcome.id(), id1);
        assert!(
            registry.loading_locks.is_empty(),
            "Loading locks should be cleaned up once all loads have finished"
        );
    }
}