1use std::sync::Arc;
20
21use dashmap::DashMap;
22use thiserror::Error;
23use tokio::sync::Mutex;
24use tracing::{debug, info};
25use uuid::Uuid;
26
27use crate::traits::Tokenizer;
28
/// Result of a successful [`TokenizerRegistry::load`] call.
///
/// Both variants carry the id under which the tokenizer for the requested name is
/// registered; they differ only in whether this particular call performed the load.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoadOutcome {
    /// This call ran the loader and registered the tokenizer under `id`.
    Loaded { id: String },
    /// The name was already registered (possibly by a concurrent caller); `id` is the
    /// existing registration's id and the loader was not run by this call.
    AlreadyExists { id: String },
}
37
38impl LoadOutcome {
39 pub fn id(&self) -> &str {
41 match self {
42 LoadOutcome::Loaded { id } => id,
43 LoadOutcome::AlreadyExists { id } => id,
44 }
45 }
46
47 pub fn is_newly_loaded(&self) -> bool {
49 matches!(self, LoadOutcome::Loaded { .. })
50 }
51}
52
/// Errors returned by [`TokenizerRegistry::load`].
#[derive(Debug, Error)]
pub enum LoadError {
    /// The `name` argument was empty; nothing was loaded or registered.
    #[error("tokenizer name cannot be empty")]
    EmptyName,
    /// The `source` argument was empty; nothing was loaded or registered.
    #[error("tokenizer source cannot be empty")]
    EmptySource,
    /// The loader callback returned an error; the loader's message is carried (and
    /// displayed) verbatim and nothing was registered.
    #[error("{0}")]
    LoadFailed(String),
}
66
/// A registered tokenizer together with the metadata it was registered under.
///
/// Cloning an entry copies the strings but shares the tokenizer via its `Arc`.
#[derive(Clone)]
pub struct TokenizerEntry {
    /// Registry id; the key of this entry in the registry's id map.
    pub id: String,
    /// Lookup name; the registry maps each name to at most one id.
    pub name: String,
    /// Source string given at registration (e.g. a path); the registry only stores and
    /// logs it.
    pub source: String,
    /// The loaded tokenizer, shared with every clone of this entry.
    pub tokenizer: Arc<dyn Tokenizer>,
}
79
/// Concurrent registry of loaded tokenizers, addressable by id or by name.
///
/// All methods take `&self`, so a single registry can be shared (e.g. behind an `Arc`)
/// between tasks. Loads for the same name are deduplicated through per-name async mutexes.
pub struct TokenizerRegistry {
    /// id -> entry. `len`, `is_empty` and `list` are computed from this map.
    tokenizers: DashMap<String, TokenizerEntry>,
    /// name -> id index into `tokenizers`.
    name_to_id: DashMap<String, String>,
    /// name -> async mutex that `load` holds while loading that name; the entry is
    /// removed again when the load finishes.
    loading_locks: DashMap<String, Arc<Mutex<()>>>,
}
94
/// Scope guard that removes `key` from `locks` when dropped (see its `Drop` impl).
///
/// `TokenizerRegistry::load` keeps one alive while it holds a name's loading lock, so the
/// lock's map entry is cleaned up on every exit path: success, error return, a panic in the
/// loader, or the load future being dropped.
struct LoadingLockGuard<'a> {
    /// The registry's loading-lock map.
    locks: &'a DashMap<String, Arc<Mutex<()>>>,
    /// Tokenizer name whose loading-lock entry is removed on drop.
    key: String,
}
101
102impl Drop for LoadingLockGuard<'_> {
103 fn drop(&mut self) {
104 self.locks.remove(&self.key);
105 }
106}
107
108impl TokenizerRegistry {
109 pub fn new() -> Self {
111 Self {
112 tokenizers: DashMap::new(),
113 name_to_id: DashMap::new(),
114 loading_locks: DashMap::new(),
115 }
116 }
117
118 pub fn generate_id() -> String {
120 Uuid::now_v7().to_string()
121 }
122
123 pub async fn load<F, Fut>(
139 &self,
140 id: &str,
141 name: &str,
142 source: &str,
143 loader: F,
144 ) -> Result<LoadOutcome, LoadError>
145 where
146 F: FnOnce() -> Fut,
147 Fut: std::future::Future<Output = Result<Arc<dyn Tokenizer>, String>>,
148 {
149 if name.is_empty() {
151 return Err(LoadError::EmptyName);
152 }
153 if source.is_empty() {
154 return Err(LoadError::EmptySource);
155 }
156
157 if let Some(existing_id) = self.name_to_id.get(name) {
159 debug!("Tokenizer already registered for name: {}", name);
160 return Ok(LoadOutcome::AlreadyExists {
161 id: existing_id.clone(),
162 });
163 }
164
165 debug!("Tokenizer cache miss for name: {}", name);
166
167 let lock = self
169 .loading_locks
170 .entry(name.to_string())
171 .or_insert_with(|| Arc::new(Mutex::new(())))
172 .clone();
173
174 let _mutex_guard = lock.lock().await;
175 let _lock_cleanup = LoadingLockGuard {
176 locks: &self.loading_locks,
177 key: name.to_string(),
178 };
179
180 if let Some(existing_id) = self.name_to_id.get(name) {
182 debug!("Tokenizer loaded by another thread for name: {}", name);
183 return Ok(LoadOutcome::AlreadyExists {
184 id: existing_id.clone(),
185 });
186 }
187
188 info!("Loading tokenizer '{}' from source: {}", name, source);
190 let result = loader().await;
191
192 let tokenizer = result.map_err(LoadError::LoadFailed)?;
193
194 let entry = TokenizerEntry {
196 id: id.to_string(),
197 name: name.to_string(),
198 source: source.to_string(),
199 tokenizer,
200 };
201
202 self.tokenizers.insert(id.to_string(), entry);
204 self.name_to_id.insert(name.to_string(), id.to_string());
205
206 info!(
207 "Successfully registered tokenizer '{}' with id: {}",
208 name, id
209 );
210
211 Ok(LoadOutcome::Loaded { id: id.to_string() })
212 }
213
214 #[cfg(test)]
225 pub fn register(
226 &self,
227 id: &str,
228 name: &str,
229 source: &str,
230 tokenizer: Arc<dyn Tokenizer>,
231 ) -> Option<String> {
232 use dashmap::mapref::entry::Entry;
233
234 match self.name_to_id.entry(name.to_string()) {
236 Entry::Occupied(_) => {
237 debug!(
238 "Tokenizer already exists for name: {}, skipping registration",
239 name
240 );
241 None
242 }
243 Entry::Vacant(name_entry) => {
244 let entry = TokenizerEntry {
245 id: id.to_string(),
246 name: name.to_string(),
247 source: source.to_string(),
248 tokenizer,
249 };
250
251 info!("Registering tokenizer '{}' with id: {}", name, id);
252 self.tokenizers.insert(id.to_string(), entry);
253 name_entry.insert(id.to_string());
254 Some(id.to_string())
255 }
256 }
257 }
258
259 pub fn get_by_id(&self, id: &str) -> Option<TokenizerEntry> {
261 self.tokenizers.get(id).map(|e| e.clone())
262 }
263
264 pub fn get_by_name(&self, name: &str) -> Option<TokenizerEntry> {
266 self.name_to_id
267 .get(name)
268 .and_then(|id| self.tokenizers.get(id.as_str()).map(|e| e.clone()))
269 }
270
271 pub fn get(&self, name_or_id: &str) -> Option<Arc<dyn Tokenizer>> {
273 self.get_by_name(name_or_id)
274 .or_else(|| self.get_by_id(name_or_id))
275 .map(|e| e.tokenizer)
276 }
277
278 pub fn contains(&self, name: &str) -> bool {
280 self.name_to_id.contains_key(name)
281 }
282
283 pub fn contains_id(&self, id: &str) -> bool {
285 self.tokenizers.contains_key(id)
286 }
287
288 pub fn len(&self) -> usize {
290 self.tokenizers.len()
291 }
292
293 pub fn is_empty(&self) -> bool {
295 self.tokenizers.is_empty()
296 }
297
298 pub fn list(&self) -> Vec<TokenizerEntry> {
300 let mut entries: Vec<TokenizerEntry> =
301 self.tokenizers.iter().map(|e| e.value().clone()).collect();
302 entries.sort_by(|a, b| a.name.cmp(&b.name));
303 entries
304 }
305
306 pub fn remove_by_id(&self, id: &str) -> Option<TokenizerEntry> {
310 if let Some((_, entry)) = self.tokenizers.remove(id) {
311 self.name_to_id.remove(&entry.name);
312 Some(entry)
313 } else {
314 None
315 }
316 }
317
318 pub fn remove(&self, name: &str) -> Option<TokenizerEntry> {
322 if let Some((_, id)) = self.name_to_id.remove(name) {
323 self.tokenizers.remove(&id).map(|(_, e)| e)
324 } else {
325 None
326 }
327 }
328
329 pub fn clear(&self) {
331 self.tokenizers.clear();
332 self.name_to_id.clear();
333 self.loading_locks.clear();
334 }
335}
336
337impl Default for TokenizerRegistry {
338 fn default() -> Self {
339 Self::new()
340 }
341}
342
#[cfg(test)]
#[expect(
    clippy::disallowed_methods,
    reason = "tokio::spawn is fine in unit tests that await all handles"
)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    use tokio::time::sleep;

    use crate::{mock::MockTokenizer, traits::Tokenizer, LoadError, TokenizerRegistry};

    #[tokio::test]
    async fn test_basic_operations() {
        let registry = TokenizerRegistry::new();

        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(!registry.contains("model1"));

        let id = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id, "model1", "path/to/model", || async {
                Ok(Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>)
            })
            .await
            .unwrap();

        assert!(outcome.is_newly_loaded());
        assert_eq!(outcome.id(), id);
        // A completed load must not leave its per-name loading lock behind.
        assert!(registry.loading_locks.is_empty());

        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("model1"));
        assert!(registry.contains_id(&id));

        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.id, id);
        assert_eq!(entry.name, "model1");
        assert_eq!(entry.source, "path/to/model");

        let removed = registry.remove_by_id(&id);
        assert!(removed.is_some());
        assert!(registry.is_empty());
        assert!(!registry.contains("model1"));
    }

    #[tokio::test]
    async fn test_load_returns_already_exists() {
        let registry = TokenizerRegistry::new();
        let id1 = TokenizerRegistry::generate_id();
        let id2 = TokenizerRegistry::generate_id();

        let outcome1 = registry
            .load(&id1, "model1", "source1", || async {
                Ok(Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>)
            })
            .await
            .unwrap();
        assert!(outcome1.is_newly_loaded());
        assert_eq!(outcome1.id(), id1);

        let outcome2 = registry
            .load(&id2, "model1", "source2", || async {
                panic!("Loader should not be called for duplicate name");
            })
            .await
            .unwrap();
        assert!(!outcome2.is_newly_loaded());
        assert_eq!(outcome2.id(), id1);
        assert_eq!(registry.len(), 1);

        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.source, "source1");
    }

    #[tokio::test]
    async fn test_load_validation() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        let result = registry
            .load(&id, "", "source", || async {
                panic!("Loader should not be called for invalid input");
            })
            .await;
        assert!(matches!(result, Err(LoadError::EmptyName)));

        let result = registry
            .load(&id, "model", "", || async {
                panic!("Loader should not be called for invalid input");
            })
            .await;
        assert!(matches!(result, Err(LoadError::EmptySource)));

        assert!(registry.is_empty());
        assert!(registry.loading_locks.is_empty());
    }

    #[tokio::test]
    async fn test_load_prevents_duplicate_loading() {
        let registry = Arc::new(TokenizerRegistry::new());
        let load_count = Arc::new(AtomicUsize::new(0));

        let mut handles = vec![];
        for i in 0..10 {
            let registry = registry.clone();
            let load_count = load_count.clone();
            let id = format!("id-{i}");
            let handle = tokio::spawn(async move {
                registry
                    .load(&id, "model1", "source", || async {
                        sleep(Duration::from_millis(10)).await;
                        load_count.fetch_add(1, Ordering::SeqCst);
                        Ok(Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>)
                    })
                    .await
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap().unwrap();
        }

        assert_eq!(
            load_count.load(Ordering::SeqCst),
            1,
            "Tokenizer should be loaded exactly once despite concurrent requests"
        );
        assert_eq!(registry.len(), 1);
        assert!(
            registry.loading_locks.is_empty(),
            "All loading locks should be cleaned up once every load has finished"
        );
    }

    #[tokio::test]
    async fn test_multiple_models() {
        let registry = TokenizerRegistry::new();

        for i in 1..=5 {
            let model_name = format!("model{i}");
            let id = TokenizerRegistry::generate_id();
            registry
                .load(&id, &model_name, "source", || async {
                    Ok(Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>)
                })
                .await
                .unwrap();
        }

        assert_eq!(registry.len(), 5);
        assert!(registry.contains("model1"));
        assert!(registry.contains("model5"));
        assert!(!registry.contains("model6"));

        let entries = registry.list();
        assert_eq!(entries.len(), 5);
        assert!(entries.iter().any(|e| e.name == "model1"));

        registry.clear();
        assert!(registry.is_empty());
        assert!(!registry.contains("model1"));
    }

    #[tokio::test]
    async fn test_load_failure() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        let result = registry
            .load(&id, "failing_model", "source", || async {
                Err("Load failed".to_string())
            })
            .await;

        assert!(
            matches!(&result, Err(LoadError::LoadFailed(msg)) if msg == "Load failed"),
            "Loader error message should be propagated verbatim"
        );
        assert!(!registry.contains("failing_model"));
        assert!(registry.is_empty());
        assert!(
            registry.loading_locks.is_empty(),
            "A failed load must not leak its loading lock"
        );
    }

    #[tokio::test]
    async fn test_get_by_name_and_id() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        registry
            .load(&id, "my-model", "hf/model", || async {
                Ok(Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>)
            })
            .await
            .unwrap();

        let by_name = registry.get_by_name("my-model");
        assert!(by_name.is_some());
        assert_eq!(by_name.as_ref().unwrap().id, id);

        let by_id = registry.get_by_id(&id);
        assert!(by_id.is_some());
        assert_eq!(by_id.as_ref().unwrap().name, "my-model");

        assert!(registry.get("my-model").is_some());
        assert!(registry.get(&id).is_some());
    }

    #[tokio::test]
    async fn test_register_only_if_absent() {
        let registry = TokenizerRegistry::new();
        let id1 = TokenizerRegistry::generate_id();
        let id2 = TokenizerRegistry::generate_id();
        let tokenizer1 = Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>;
        let tokenizer2 = Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>;

        let result1 = registry.register(&id1, "model1", "source1", tokenizer1.clone());
        assert!(result1.is_some());
        assert_eq!(registry.len(), 1);

        let result2 = registry.register(&id2, "model1", "source2", tokenizer2.clone());
        assert!(result2.is_none());
        assert_eq!(registry.len(), 1);

        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.id, id1);
        assert_eq!(entry.source, "source1");

        let id3 = TokenizerRegistry::generate_id();
        let result3 = registry.register(&id3, "model2", "source2", tokenizer2);
        assert!(result3.is_some());
        assert_eq!(registry.len(), 2);
    }

    #[tokio::test]
    async fn test_loading_lock_cleanup_on_panic() {
        let registry = Arc::new(TokenizerRegistry::new());

        let registry_clone = registry.clone();
        let handle = tokio::spawn(async move {
            registry_clone
                .load(
                    &TokenizerRegistry::generate_id(),
                    "panic-model",
                    "source",
                    || async {
                        panic!("Simulated panic during tokenizer loading");
                    },
                )
                .await
        });

        let result = handle.await;
        assert!(result.is_err(), "Task should have panicked");
        assert!(
            registry.loading_locks.is_empty(),
            "A panicking loader must not leak its loading lock"
        );

        let id = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id, "panic-model", "source", || async {
                Ok(Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>)
            })
            .await;

        assert!(outcome.is_ok(), "Load should succeed after panic cleanup");
        assert!(outcome.unwrap().is_newly_loaded());
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("panic-model"));
        assert!(registry.loading_locks.is_empty());
    }

    #[tokio::test]
    async fn test_loading_lock_cleanup_on_early_return() {
        let registry = Arc::new(TokenizerRegistry::new());

        let id1 = TokenizerRegistry::generate_id();
        registry
            .load(&id1, "model1", "source1", || async {
                Ok(Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>)
            })
            .await
            .unwrap();
        assert!(registry.loading_locks.is_empty());

        let id2 = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id2, "model2", "source2", || async {
                Ok(Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>)
            })
            .await
            .unwrap();

        assert!(outcome.is_newly_loaded());
        assert_eq!(registry.len(), 2);

        let id3 = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id3, "model1", "source1", || async {
                panic!("Loader should not be called for existing model");
            })
            .await
            .unwrap();

        assert!(!outcome.is_newly_loaded());
        assert_eq!(outcome.id(), id1);
        assert!(
            registry.loading_locks.is_empty(),
            "An early AlreadyExists return must not leave a loading lock behind"
        );
    }
}