1use std::sync::Arc;
20
21use dashmap::DashMap;
22use thiserror::Error;
23use tokio::sync::Mutex;
24use tracing::{debug, info};
25use uuid::Uuid;
26
27use crate::traits::Tokenizer;
28
/// Result of a successful [`TokenizerRegistry::load`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoadOutcome {
    /// The loader ran and the tokenizer was registered under `id`.
    Loaded { id: String },
    /// A tokenizer with the requested name was already registered; `id` is the
    /// id of that existing entry and the loader was not invoked.
    AlreadyExists { id: String },
}

impl LoadOutcome {
    /// Id of the tokenizer this outcome refers to, whether it was just loaded
    /// or was already present.
    pub fn id(&self) -> &str {
        match self {
            LoadOutcome::Loaded { id } | LoadOutcome::AlreadyExists { id } => id,
        }
    }

    /// `true` only when this call ran the loader and registered a new entry.
    pub fn is_newly_loaded(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide its answer here.
        match self {
            LoadOutcome::Loaded { .. } => true,
            LoadOutcome::AlreadyExists { .. } => false,
        }
    }
}
52
53#[derive(Debug, Error)]
55pub enum LoadError {
56 #[error("tokenizer name cannot be empty")]
58 EmptyName,
59 #[error("tokenizer source cannot be empty")]
61 EmptySource,
62 #[error("{0}")]
64 LoadFailed(String),
65}
66
67#[derive(Clone)]
69pub struct TokenizerEntry {
70 pub id: String,
72 pub name: String,
74 pub source: String,
76 pub tokenizer: Arc<dyn Tokenizer>,
78}
79
80pub struct TokenizerRegistry {
87 tokenizers: DashMap<String, TokenizerEntry>,
89 name_to_id: DashMap<String, String>,
91 loading_locks: DashMap<String, Arc<Mutex<()>>>,
93}
94
95struct LoadingLockGuard<'a> {
98 locks: &'a DashMap<String, Arc<Mutex<()>>>,
99 key: String,
100}
101
102impl Drop for LoadingLockGuard<'_> {
103 fn drop(&mut self) {
104 self.locks.remove(&self.key);
105 }
106}
107
108impl TokenizerRegistry {
109 pub fn new() -> Self {
111 Self {
112 tokenizers: DashMap::new(),
113 name_to_id: DashMap::new(),
114 loading_locks: DashMap::new(),
115 }
116 }
117
118 pub fn generate_id() -> String {
120 Uuid::new_v4().to_string()
121 }
122
123 pub async fn load<F, Fut>(
139 &self,
140 id: &str,
141 name: &str,
142 source: &str,
143 loader: F,
144 ) -> Result<LoadOutcome, LoadError>
145 where
146 F: FnOnce() -> Fut,
147 Fut: std::future::Future<Output = Result<Arc<dyn Tokenizer>, String>>,
148 {
149 if name.is_empty() {
151 return Err(LoadError::EmptyName);
152 }
153 if source.is_empty() {
154 return Err(LoadError::EmptySource);
155 }
156
157 if let Some(existing_id) = self.name_to_id.get(name) {
159 debug!("Tokenizer already registered for name: {}", name);
160 return Ok(LoadOutcome::AlreadyExists {
161 id: existing_id.clone(),
162 });
163 }
164
165 debug!("Tokenizer cache miss for name: {}", name);
166
167 let lock = self
169 .loading_locks
170 .entry(name.to_string())
171 .or_insert_with(|| Arc::new(Mutex::new(())))
172 .clone();
173
174 let _mutex_guard = lock.lock().await;
175 let _lock_cleanup = LoadingLockGuard {
176 locks: &self.loading_locks,
177 key: name.to_string(),
178 };
179
180 if let Some(existing_id) = self.name_to_id.get(name) {
182 debug!("Tokenizer loaded by another thread for name: {}", name);
183 return Ok(LoadOutcome::AlreadyExists {
184 id: existing_id.clone(),
185 });
186 }
187
188 info!("Loading tokenizer '{}' from source: {}", name, source);
190 let result = loader().await;
191
192 let tokenizer = result.map_err(LoadError::LoadFailed)?;
193
194 let entry = TokenizerEntry {
196 id: id.to_string(),
197 name: name.to_string(),
198 source: source.to_string(),
199 tokenizer,
200 };
201
202 self.tokenizers.insert(id.to_string(), entry);
204 self.name_to_id.insert(name.to_string(), id.to_string());
205
206 info!(
207 "Successfully registered tokenizer '{}' with id: {}",
208 name, id
209 );
210
211 Ok(LoadOutcome::Loaded { id: id.to_string() })
212 }
213
214 #[cfg(test)]
225 pub fn register(
226 &self,
227 id: &str,
228 name: &str,
229 source: &str,
230 tokenizer: Arc<dyn Tokenizer>,
231 ) -> Option<String> {
232 use dashmap::mapref::entry::Entry;
233
234 match self.name_to_id.entry(name.to_string()) {
236 Entry::Occupied(_) => {
237 debug!(
238 "Tokenizer already exists for name: {}, skipping registration",
239 name
240 );
241 None
242 }
243 Entry::Vacant(name_entry) => {
244 let entry = TokenizerEntry {
245 id: id.to_string(),
246 name: name.to_string(),
247 source: source.to_string(),
248 tokenizer,
249 };
250
251 info!("Registering tokenizer '{}' with id: {}", name, id);
252 self.tokenizers.insert(id.to_string(), entry);
253 name_entry.insert(id.to_string());
254 Some(id.to_string())
255 }
256 }
257 }
258
259 pub fn get_by_id(&self, id: &str) -> Option<TokenizerEntry> {
261 self.tokenizers.get(id).map(|e| e.clone())
262 }
263
264 pub fn get_by_name(&self, name: &str) -> Option<TokenizerEntry> {
266 self.name_to_id
267 .get(name)
268 .and_then(|id| self.tokenizers.get(id.as_str()).map(|e| e.clone()))
269 }
270
271 pub fn get(&self, name_or_id: &str) -> Option<Arc<dyn Tokenizer>> {
273 self.get_by_name(name_or_id)
274 .or_else(|| self.get_by_id(name_or_id))
275 .map(|e| e.tokenizer)
276 }
277
278 pub fn contains(&self, name: &str) -> bool {
280 self.name_to_id.contains_key(name)
281 }
282
283 pub fn contains_id(&self, id: &str) -> bool {
285 self.tokenizers.contains_key(id)
286 }
287
288 pub fn len(&self) -> usize {
290 self.tokenizers.len()
291 }
292
293 pub fn is_empty(&self) -> bool {
295 self.tokenizers.is_empty()
296 }
297
298 pub fn list(&self) -> Vec<TokenizerEntry> {
300 let mut entries: Vec<TokenizerEntry> =
301 self.tokenizers.iter().map(|e| e.value().clone()).collect();
302 entries.sort_by(|a, b| a.name.cmp(&b.name));
303 entries
304 }
305
306 pub fn remove_by_id(&self, id: &str) -> Option<TokenizerEntry> {
310 if let Some((_, entry)) = self.tokenizers.remove(id) {
311 self.name_to_id.remove(&entry.name);
312 Some(entry)
313 } else {
314 None
315 }
316 }
317
318 pub fn remove(&self, name: &str) -> Option<TokenizerEntry> {
322 if let Some((_, id)) = self.name_to_id.remove(name) {
323 self.tokenizers.remove(&id).map(|(_, e)| e)
324 } else {
325 None
326 }
327 }
328
329 pub fn clear(&self) {
331 self.tokenizers.clear();
332 self.name_to_id.clear();
333 self.loading_locks.clear();
334 }
335}
336
337impl Default for TokenizerRegistry {
338 fn default() -> Self {
339 Self::new()
340 }
341}
342
#[cfg(test)]
mod tests {
    use std::{
        collections::HashSet,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    use tokio::time::sleep;

    use crate::{mock::MockTokenizer, traits::Tokenizer, LoadError, TokenizerRegistry};

    // Loader that always succeeds with a fresh mock tokenizer.
    async fn mock_loader() -> Result<Arc<dyn Tokenizer>, String> {
        Ok(Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>)
    }

    #[tokio::test]
    async fn test_basic_operations() {
        let registry = TokenizerRegistry::new();

        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(!registry.contains("model1"));

        let id = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id, "model1", "path/to/model", mock_loader)
            .await
            .unwrap();

        assert!(outcome.is_newly_loaded());
        assert_eq!(outcome.id(), id);

        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("model1"));
        assert!(registry.contains_id(&id));

        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.id, id);
        assert_eq!(entry.name, "model1");
        assert_eq!(entry.source, "path/to/model");

        let removed = registry.remove_by_id(&id);
        assert!(removed.is_some());
        assert!(registry.is_empty());
        // Removing by id must also drop the name mapping.
        assert!(!registry.contains("model1"));
    }

    #[tokio::test]
    async fn test_load_returns_already_exists() {
        let registry = TokenizerRegistry::new();
        let id1 = TokenizerRegistry::generate_id();
        let id2 = TokenizerRegistry::generate_id();

        let outcome1 = registry
            .load(&id1, "model1", "source1", mock_loader)
            .await
            .unwrap();
        assert!(outcome1.is_newly_loaded());
        assert_eq!(outcome1.id(), id1);

        let outcome2 = registry
            .load(&id2, "model1", "source2", || async {
                panic!("Loader should not be called for duplicate name");
            })
            .await
            .unwrap();
        assert!(!outcome2.is_newly_loaded());
        assert_eq!(outcome2.id(), id1);
        assert_eq!(registry.len(), 1);

        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.source, "source1");
    }

    #[tokio::test]
    async fn test_load_validation() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        let result = registry
            .load(&id, "", "source", || async {
                panic!("Loader should not be called for invalid input");
            })
            .await;
        assert!(matches!(result, Err(LoadError::EmptyName)));

        let result = registry
            .load(&id, "model", "", || async {
                panic!("Loader should not be called for invalid input");
            })
            .await;
        assert!(matches!(result, Err(LoadError::EmptySource)));

        assert!(registry.is_empty());
        // Validation happens before any per-name lock is created.
        assert!(registry.loading_locks.is_empty());
    }

    #[tokio::test]
    async fn test_load_prevents_duplicate_loading() {
        let registry = Arc::new(TokenizerRegistry::new());
        let load_count = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let registry = Arc::clone(&registry);
                let load_count = Arc::clone(&load_count);
                let id = format!("id-{}", i);
                tokio::spawn(async move {
                    registry
                        .load(&id, "model1", "source", || async {
                            // Widen the window so the callers actually overlap.
                            sleep(Duration::from_millis(10)).await;
                            load_count.fetch_add(1, Ordering::SeqCst);
                            Ok(Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>)
                        })
                        .await
                })
            })
            .collect();

        let mut newly_loaded = 0;
        let mut ids = HashSet::new();
        for handle in handles {
            let outcome = handle.await.unwrap().unwrap();
            if outcome.is_newly_loaded() {
                newly_loaded += 1;
            }
            ids.insert(outcome.id().to_string());
        }

        assert_eq!(
            load_count.load(Ordering::SeqCst),
            1,
            "Tokenizer should be loaded exactly once despite concurrent requests"
        );
        assert_eq!(newly_loaded, 1, "exactly one caller should report Loaded");
        assert_eq!(ids.len(), 1, "every caller should observe the same id");
        assert_eq!(registry.len(), 1);
        assert!(
            registry.loading_locks.is_empty(),
            "per-name loading locks should be cleaned up once all loads finish"
        );
    }

    #[tokio::test]
    async fn test_multiple_models() {
        let registry = TokenizerRegistry::new();

        for i in 1..=5 {
            let model_name = format!("model{}", i);
            let id = TokenizerRegistry::generate_id();
            registry
                .load(&id, &model_name, "source", mock_loader)
                .await
                .unwrap();
        }

        assert_eq!(registry.len(), 5);
        assert!(registry.contains("model1"));
        assert!(registry.contains("model5"));
        assert!(!registry.contains("model6"));

        let entries = registry.list();
        assert_eq!(entries.len(), 5);
        assert!(entries.iter().any(|e| e.name == "model1"));

        registry.clear();
        assert!(registry.is_empty());

        // A cleared name can be loaded again from scratch.
        let outcome = registry
            .load(&TokenizerRegistry::generate_id(), "model1", "source", mock_loader)
            .await
            .unwrap();
        assert!(outcome.is_newly_loaded());
        assert_eq!(registry.len(), 1);
    }

    #[tokio::test]
    async fn test_load_failure() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        let result = registry
            .load(&id, "failing_model", "source", || async {
                Err("Load failed".to_string())
            })
            .await;

        assert!(matches!(
            result,
            Err(LoadError::LoadFailed(ref msg)) if msg == "Load failed"
        ));
        assert!(!registry.contains("failing_model"));
        assert!(registry.is_empty());
        assert!(registry.loading_locks.is_empty());

        // A failed load must not block a later retry.
        let outcome = registry
            .load(&id, "failing_model", "source", mock_loader)
            .await
            .unwrap();
        assert!(outcome.is_newly_loaded());
        assert!(registry.contains("failing_model"));
    }

    #[tokio::test]
    async fn test_get_by_name_and_id() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        registry
            .load(&id, "my-model", "hf/model", mock_loader)
            .await
            .unwrap();

        let by_name = registry.get_by_name("my-model");
        assert!(by_name.is_some());
        assert_eq!(by_name.as_ref().unwrap().id, id);

        let by_id = registry.get_by_id(&id);
        assert!(by_id.is_some());
        assert_eq!(by_id.as_ref().unwrap().name, "my-model");

        assert!(registry.get("my-model").is_some());
        assert!(registry.get(&id).is_some());
    }

    #[tokio::test]
    async fn test_register_only_if_absent() {
        let registry = TokenizerRegistry::new();
        let id1 = TokenizerRegistry::generate_id();
        let id2 = TokenizerRegistry::generate_id();
        let tokenizer1 = Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>;
        let tokenizer2 = Arc::new(MockTokenizer::default()) as Arc<dyn Tokenizer>;

        let result1 = registry.register(&id1, "model1", "source1", tokenizer1.clone());
        assert!(result1.is_some());
        assert_eq!(registry.len(), 1);

        let result2 = registry.register(&id2, "model1", "source2", tokenizer2.clone());
        assert!(result2.is_none());
        assert_eq!(registry.len(), 1);

        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.id, id1);
        assert_eq!(entry.source, "source1");

        let id3 = TokenizerRegistry::generate_id();
        let result3 = registry.register(&id3, "model2", "source2", tokenizer2);
        assert!(result3.is_some());
        assert_eq!(registry.len(), 2);
    }

    #[tokio::test]
    async fn test_loading_lock_cleanup_on_panic() {
        let registry = Arc::new(TokenizerRegistry::new());

        let registry_clone = registry.clone();
        let handle = tokio::spawn(async move {
            registry_clone
                .load(
                    &TokenizerRegistry::generate_id(),
                    "panic-model",
                    "source",
                    || async {
                        panic!("Simulated panic during tokenizer loading");
                    },
                )
                .await
        });

        let result = handle.await;
        assert!(result.is_err(), "Task should have panicked");
        // Unwinding must have retired the per-name lock.
        assert!(registry.loading_locks.is_empty());

        let id = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id, "panic-model", "source", mock_loader)
            .await;

        assert!(outcome.is_ok(), "Load should succeed after panic cleanup");
        assert!(outcome.unwrap().is_newly_loaded());
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("panic-model"));
        assert!(registry.loading_locks.is_empty());
    }

    #[tokio::test]
    async fn test_loading_lock_cleanup_on_early_return() {
        let registry = Arc::new(TokenizerRegistry::new());

        let id1 = TokenizerRegistry::generate_id();
        registry
            .load(&id1, "model1", "source1", mock_loader)
            .await
            .unwrap();
        assert!(registry.loading_locks.is_empty());

        let id2 = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id2, "model2", "source2", mock_loader)
            .await
            .unwrap();

        assert!(outcome.is_newly_loaded());
        assert_eq!(registry.len(), 2);
        assert!(registry.loading_locks.is_empty());

        let id3 = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id3, "model1", "source1", || async {
                panic!("Loader should not be called for existing model");
            })
            .await
            .unwrap();

        assert!(!outcome.is_newly_loaded());
        assert_eq!(outcome.id(), id1);
        assert!(registry.loading_locks.is_empty());
    }
}