Skip to main content

llm_tokenizer/
registry.rs

//! Tokenizer Registry for dynamic tokenizer loading
//!
//! Provides thread-safe, deduplicated tokenizer loading for IGW mode where
//! multiple routers (HTTP and gRPC) need to share tokenizers across workers.
//!
//! ## ID vs Name Lookup
//!
//! Tokenizers are stored with two keys:
//! - **ID (UUID)**: Unique identifier, generated by the caller (via `generate_id()`)
//!   before registration; immutable afterwards
//! - **Name**: User-provided identifier, must be unique
//!
//! Lookup and removal:
//! - `get(key)`: Tries name first, then ID (backward compatible)
//! - `get_by_id(id)`: Exact ID match only
//! - `get_by_name(name)`: Exact name match only
//! - `remove(name)`: Removes by name
//! - `remove_by_id(id)`: Removes by ID
18
19use std::sync::Arc;
20
21use dashmap::DashMap;
22use thiserror::Error;
23use tokio::sync::Mutex;
24use tracing::{debug, info};
25use uuid::Uuid;
26
27use crate::traits::Tokenizer;
28
/// Outcome of a tokenizer load operation.
///
/// Both variants carry the ID under which the tokenizer is registered, so callers
/// that only need the ID can use [`LoadOutcome::id`] without matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoadOutcome {
    /// The tokenizer was newly loaded and registered under `id`.
    Loaded { id: String },
    /// A tokenizer with the requested name was already registered; `id` is the
    /// ID of that existing entry, not the ID passed to the load call.
    AlreadyExists { id: String },
}

impl LoadOutcome {
    /// Returns the registered tokenizer's ID, regardless of outcome.
    pub fn id(&self) -> &str {
        match self {
            LoadOutcome::Loaded { id } | LoadOutcome::AlreadyExists { id } => id,
        }
    }

    /// Returns `true` if this call loaded and registered the tokenizer, and
    /// `false` if an existing registration was reused.
    pub fn is_newly_loaded(&self) -> bool {
        matches!(self, LoadOutcome::Loaded { .. })
    }
}
52
53/// Error type for tokenizer loading operations
54#[derive(Debug, Error)]
55pub enum LoadError {
56    /// Name cannot be empty
57    #[error("tokenizer name cannot be empty")]
58    EmptyName,
59    /// Source cannot be empty
60    #[error("tokenizer source cannot be empty")]
61    EmptySource,
62    /// Loading failed
63    #[error("{0}")]
64    LoadFailed(String),
65}
66
67/// Metadata and tokenizer instance for a registered tokenizer
68#[derive(Clone)]
69pub struct TokenizerEntry {
70    /// Unique identifier (UUID)
71    pub id: String,
72    /// User-provided name
73    pub name: String,
74    /// Source path or HuggingFace model ID
75    pub source: String,
76    /// The tokenizer instance
77    pub tokenizer: Arc<dyn Tokenizer>,
78}
79
80/// Registry for managing tokenizers keyed by UUID
81///
82/// Features:
83/// - Thread-safe concurrent access using DashMap
84/// - Per-key locking to prevent duplicate loading
85/// - Lookup by UUID (primary) or name (secondary index)
86pub struct TokenizerRegistry {
87    /// Storage for loaded tokenizers, keyed by UUID
88    tokenizers: DashMap<String, TokenizerEntry>,
89    /// Secondary index: name -> UUID for lookup
90    name_to_id: DashMap<String, String>,
91    /// Per-key locks to prevent duplicate loading
92    loading_locks: DashMap<String, Arc<Mutex<()>>>,
93}
94
95/// RAII guard that removes the loading lock entry on drop.
96/// Ensures cleanup happens on normal completion, early return, or panic.
97struct LoadingLockGuard<'a> {
98    locks: &'a DashMap<String, Arc<Mutex<()>>>,
99    key: String,
100}
101
102impl Drop for LoadingLockGuard<'_> {
103    fn drop(&mut self) {
104        self.locks.remove(&self.key);
105    }
106}
107
108impl TokenizerRegistry {
109    /// Create a new empty registry
110    pub fn new() -> Self {
111        Self {
112            tokenizers: DashMap::new(),
113            name_to_id: DashMap::new(),
114            loading_locks: DashMap::new(),
115        }
116    }
117
118    /// Generate a new UUID for a tokenizer
119    pub fn generate_id() -> String {
120        Uuid::now_v7().to_string()
121    }
122
123    /// Load and register a tokenizer
124    ///
125    /// Validates inputs, handles deduplication, and loads the tokenizer if needed.
126    /// Per-key locking ensures only one load happens per name, preventing race conditions.
127    ///
128    /// # Arguments
129    /// * `id` - Pre-generated UUID (use `generate_id()` to create one)
130    /// * `name` - User-provided name (used for deduplication, must not be empty)
131    /// * `source` - Source path or HuggingFace model ID (must not be empty)
132    /// * `loader` - Async function that loads the tokenizer
133    ///
134    /// # Returns
135    /// * `Ok(LoadOutcome::Loaded { id })` - Tokenizer was newly loaded
136    /// * `Ok(LoadOutcome::AlreadyExists { id })` - Tokenizer already existed
137    /// * `Err(LoadError)` - Validation failed or loading failed
138    pub async fn load<F, Fut>(
139        &self,
140        id: &str,
141        name: &str,
142        source: &str,
143        loader: F,
144    ) -> Result<LoadOutcome, LoadError>
145    where
146        F: FnOnce() -> Fut,
147        Fut: std::future::Future<Output = Result<Arc<dyn Tokenizer>, String>>,
148    {
149        // Validate inputs
150        if name.is_empty() {
151            return Err(LoadError::EmptyName);
152        }
153        if source.is_empty() {
154            return Err(LoadError::EmptySource);
155        }
156
157        // Fast path: already loaded by name
158        if let Some(existing_id) = self.name_to_id.get(name) {
159            debug!("Tokenizer already registered for name: {}", name);
160            return Ok(LoadOutcome::AlreadyExists {
161                id: existing_id.clone(),
162            });
163        }
164
165        debug!("Tokenizer cache miss for name: {}", name);
166
167        // Acquire per-name lock to prevent duplicate loading
168        let lock = self
169            .loading_locks
170            .entry(name.to_string())
171            .or_insert_with(|| Arc::new(Mutex::new(())))
172            .clone();
173
174        let _mutex_guard = lock.lock().await;
175        let _lock_cleanup = LoadingLockGuard {
176            locks: &self.loading_locks,
177            key: name.to_string(),
178        };
179
180        // Double-check after acquiring lock (another thread may have loaded it)
181        if let Some(existing_id) = self.name_to_id.get(name) {
182            debug!("Tokenizer loaded by another thread for name: {}", name);
183            return Ok(LoadOutcome::AlreadyExists {
184                id: existing_id.clone(),
185            });
186        }
187
188        // Load tokenizer
189        info!("Loading tokenizer '{}' from source: {}", name, source);
190        let result = loader().await;
191
192        let tokenizer = result.map_err(LoadError::LoadFailed)?;
193
194        // Create entry with provided ID
195        let entry = TokenizerEntry {
196            id: id.to_string(),
197            name: name.to_string(),
198            source: source.to_string(),
199            tokenizer,
200        };
201
202        // Store in registry
203        self.tokenizers.insert(id.to_string(), entry);
204        self.name_to_id.insert(name.to_string(), id.to_string());
205
206        info!(
207            "Successfully registered tokenizer '{}' with id: {}",
208            name, id
209        );
210
211        Ok(LoadOutcome::Loaded { id: id.to_string() })
212    }
213
214    /// Register a preloaded tokenizer with a pre-generated ID
215    ///
216    /// Atomically inserts a tokenizer into the registry only if no tokenizer
217    /// with the same name exists. Returns the ID if successful.
218    ///
219    /// This is primarily used for testing. Production code should use `load()`.
220    ///
221    /// # Returns
222    /// * `Some(id)` - If the tokenizer was successfully registered
223    /// * `None` - If a tokenizer with this name already existed
224    #[cfg(test)]
225    pub fn register(
226        &self,
227        id: &str,
228        name: &str,
229        source: &str,
230        tokenizer: Arc<dyn Tokenizer>,
231    ) -> Option<String> {
232        use dashmap::mapref::entry::Entry;
233
234        // Check if name already exists
235        match self.name_to_id.entry(name.to_string()) {
236            Entry::Occupied(_) => {
237                debug!(
238                    "Tokenizer already exists for name: {}, skipping registration",
239                    name
240                );
241                None
242            }
243            Entry::Vacant(name_entry) => {
244                let entry = TokenizerEntry {
245                    id: id.to_string(),
246                    name: name.to_string(),
247                    source: source.to_string(),
248                    tokenizer,
249                };
250
251                info!("Registering tokenizer '{}' with id: {}", name, id);
252                self.tokenizers.insert(id.to_string(), entry);
253                name_entry.insert(id.to_string());
254                Some(id.to_string())
255            }
256        }
257    }
258
259    /// Get a tokenizer by UUID
260    pub fn get_by_id(&self, id: &str) -> Option<TokenizerEntry> {
261        self.tokenizers.get(id).map(|e| e.clone())
262    }
263
264    /// Get a tokenizer by name
265    pub fn get_by_name(&self, name: &str) -> Option<TokenizerEntry> {
266        self.name_to_id
267            .get(name)
268            .and_then(|id| self.tokenizers.get(id.as_str()).map(|e| e.clone()))
269    }
270
271    /// Get a tokenizer (for backward compatibility, tries name first then ID)
272    pub fn get(&self, name_or_id: &str) -> Option<Arc<dyn Tokenizer>> {
273        self.get_by_name(name_or_id)
274            .or_else(|| self.get_by_id(name_or_id))
275            .map(|e| e.tokenizer)
276    }
277
278    /// Check if a tokenizer is registered by name
279    pub fn contains(&self, name: &str) -> bool {
280        self.name_to_id.contains_key(name)
281    }
282
283    /// Check if a tokenizer is registered by ID
284    pub fn contains_id(&self, id: &str) -> bool {
285        self.tokenizers.contains_key(id)
286    }
287
288    /// Get the number of loaded tokenizers
289    pub fn len(&self) -> usize {
290        self.tokenizers.len()
291    }
292
293    /// Check if the registry is empty
294    pub fn is_empty(&self) -> bool {
295        self.tokenizers.is_empty()
296    }
297
298    /// List all registered tokenizers
299    pub fn list(&self) -> Vec<TokenizerEntry> {
300        let mut entries: Vec<TokenizerEntry> =
301            self.tokenizers.iter().map(|e| e.value().clone()).collect();
302        entries.sort_by(|a, b| a.name.cmp(&b.name));
303        entries
304    }
305
306    /// Remove a tokenizer by ID
307    ///
308    /// Returns the entry if it was present.
309    pub fn remove_by_id(&self, id: &str) -> Option<TokenizerEntry> {
310        if let Some((_, entry)) = self.tokenizers.remove(id) {
311            self.name_to_id.remove(&entry.name);
312            Some(entry)
313        } else {
314            None
315        }
316    }
317
318    /// Remove a tokenizer by name
319    ///
320    /// Returns the entry if it was present.
321    pub fn remove(&self, name: &str) -> Option<TokenizerEntry> {
322        if let Some((_, id)) = self.name_to_id.remove(name) {
323            self.tokenizers.remove(&id).map(|(_, e)| e)
324        } else {
325            None
326        }
327    }
328
329    /// Clear all tokenizers from the registry
330    pub fn clear(&self) {
331        self.tokenizers.clear();
332        self.name_to_id.clear();
333        self.loading_locks.clear();
334    }
335}
336
337impl Default for TokenizerRegistry {
338    fn default() -> Self {
339        Self::new()
340    }
341}
342
#[cfg(test)]
#[expect(
    clippy::disallowed_methods,
    reason = "tokio::spawn is fine in unit tests that await all handles"
)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    use tokio::time::sleep;

    use crate::{mock::MockTokenizer, traits::Tokenizer, LoadError, TokenizerRegistry};

    /// A fresh mock tokenizer, as the trait object the registry stores.
    fn mock_tokenizer() -> Arc<dyn Tokenizer> {
        Arc::new(MockTokenizer::default())
    }

    #[tokio::test]
    async fn test_basic_operations() {
        let registry = TokenizerRegistry::new();

        // Registry starts empty
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(!registry.contains("model1"));

        // Load and register a tokenizer
        let id = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id, "model1", "path/to/model", || async {
                Ok(mock_tokenizer())
            })
            .await
            .unwrap();

        // Verify LoadOutcome::Loaded
        assert!(outcome.is_newly_loaded());
        assert_eq!(outcome.id(), id);

        // Verify it's loaded
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("model1"));
        assert!(registry.contains_id(&id));

        // Get returns the tokenizer
        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.id, id);
        assert_eq!(entry.name, "model1");
        assert_eq!(entry.source, "path/to/model");

        // Remove works and also drops the name index entry
        let removed = registry.remove_by_id(&id);
        assert!(removed.is_some());
        assert!(registry.is_empty());
        assert!(!registry.contains("model1"));
    }

    #[tokio::test]
    async fn test_load_returns_already_exists() {
        let registry = TokenizerRegistry::new();
        let id1 = TokenizerRegistry::generate_id();
        let id2 = TokenizerRegistry::generate_id();

        // First load should return Loaded
        let outcome1 = registry
            .load(&id1, "model1", "source1", || async { Ok(mock_tokenizer()) })
            .await
            .unwrap();
        assert!(outcome1.is_newly_loaded());
        assert_eq!(outcome1.id(), id1);

        // Second load with same name should return AlreadyExists with ORIGINAL id
        let outcome2 = registry
            .load(&id2, "model1", "source2", || async {
                panic!("Loader should not be called for duplicate name");
            })
            .await
            .unwrap();
        assert!(!outcome2.is_newly_loaded());
        assert_eq!(outcome2.id(), id1); // Returns the original ID, not id2

        // Registry still has only one entry
        assert_eq!(registry.len(), 1);

        // Original source is preserved
        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.source, "source1");
    }

    #[tokio::test]
    async fn test_load_validation() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        // Empty name should fail
        let result = registry
            .load(&id, "", "source", || async {
                panic!("Loader should not be called for invalid input");
            })
            .await;
        assert!(matches!(result, Err(LoadError::EmptyName)));

        // Empty source should fail
        let result = registry
            .load(&id, "model", "", || async {
                panic!("Loader should not be called for invalid input");
            })
            .await;
        assert!(matches!(result, Err(LoadError::EmptySource)));

        // Registry should be empty (nothing was loaded)
        assert!(registry.is_empty());
        assert!(registry.loading_locks.is_empty());
    }

    #[tokio::test]
    async fn test_load_prevents_duplicate_loading() {
        let registry = Arc::new(TokenizerRegistry::new());
        let load_count = Arc::new(AtomicUsize::new(0));

        // Spawn multiple tasks trying to load the same tokenizer
        let mut handles = vec![];
        for i in 0..10 {
            let registry = registry.clone();
            let load_count = load_count.clone();
            let id = format!("id-{i}");
            let handle = tokio::spawn(async move {
                registry
                    .load(&id, "model1", "source", || async {
                        // Simulate slow loading
                        sleep(Duration::from_millis(10)).await;
                        load_count.fetch_add(1, Ordering::SeqCst);
                        Ok(mock_tokenizer())
                    })
                    .await
            });
            handles.push(handle);
        }

        // Wait for all tasks
        for handle in handles {
            handle.await.unwrap().unwrap();
        }

        // Verify tokenizer was loaded only once
        assert_eq!(
            load_count.load(Ordering::SeqCst),
            1,
            "Tokenizer should be loaded exactly once despite concurrent requests"
        );
        assert_eq!(registry.len(), 1);
        assert!(
            registry.loading_locks.is_empty(),
            "Per-name loading locks should be cleaned up once all loads finish"
        );
    }

    #[tokio::test]
    async fn test_multiple_models() {
        let registry = TokenizerRegistry::new();

        // Load multiple tokenizers
        for i in 1..=5 {
            let model_name = format!("model{i}");
            let id = TokenizerRegistry::generate_id();
            registry
                .load(&id, &model_name, "source", || async { Ok(mock_tokenizer()) })
                .await
                .unwrap();
        }

        assert_eq!(registry.len(), 5);
        assert!(registry.contains("model1"));
        assert!(registry.contains("model5"));
        assert!(!registry.contains("model6"));

        // List returns all with metadata
        let entries = registry.list();
        assert_eq!(entries.len(), 5);
        assert!(entries.iter().any(|e| e.name == "model1"));

        // Clear all
        registry.clear();
        assert!(registry.is_empty());
        assert!(!registry.contains("model1"));
    }

    #[tokio::test]
    async fn test_list_is_sorted_by_name() {
        let registry = TokenizerRegistry::new();
        for name in ["charlie", "alpha", "bravo"] {
            registry
                .load(&TokenizerRegistry::generate_id(), name, "source", || async {
                    Ok(mock_tokenizer())
                })
                .await
                .unwrap();
        }

        let names: Vec<String> = registry.list().into_iter().map(|e| e.name).collect();
        assert_eq!(names, ["alpha", "bravo", "charlie"]);
    }

    #[tokio::test]
    async fn test_load_failure() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        // Try to load with a failing loader
        let result = registry
            .load(&id, "failing_model", "source", || async {
                Err("Load failed".to_string())
            })
            .await;

        assert!(result.is_err());
        assert!(!registry.contains("failing_model"));
        assert!(registry.is_empty());
        assert!(registry.loading_locks.is_empty());
    }

    #[tokio::test]
    async fn test_load_after_failure_can_retry() {
        let registry = TokenizerRegistry::new();

        // A failed load surfaces the loader's message and registers nothing
        let failed = registry
            .load(&TokenizerRegistry::generate_id(), "flaky", "source", || async {
                Err("transient failure".to_string())
            })
            .await;
        assert!(matches!(
            failed,
            Err(LoadError::LoadFailed(ref msg)) if msg == "transient failure"
        ));
        assert!(registry.is_empty());
        assert!(registry.loading_locks.is_empty());

        // A later load of the same name runs the loader again and succeeds
        let id = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id, "flaky", "source", || async { Ok(mock_tokenizer()) })
            .await
            .unwrap();
        assert!(outcome.is_newly_loaded());
        assert_eq!(outcome.id(), id);
        assert_eq!(registry.len(), 1);
        assert!(registry.loading_locks.is_empty());
    }

    #[tokio::test]
    async fn test_get_by_name_and_id() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        registry
            .load(&id, "my-model", "hf/model", || async { Ok(mock_tokenizer()) })
            .await
            .unwrap();

        // Get by name
        let by_name = registry.get_by_name("my-model");
        assert!(by_name.is_some());
        assert_eq!(by_name.as_ref().unwrap().id, id);

        // Get by ID
        let by_id = registry.get_by_id(&id);
        assert!(by_id.is_some());
        assert_eq!(by_id.as_ref().unwrap().name, "my-model");

        // Generic get works with both
        assert!(registry.get("my-model").is_some());
        assert!(registry.get(&id).is_some());

        // Unknown keys resolve to nothing
        assert!(registry.get("unknown").is_none());
        assert!(registry.get_by_id("unknown").is_none());
        assert!(registry.get_by_name("unknown").is_none());
    }

    #[tokio::test]
    async fn test_remove_by_name_and_id() {
        let registry = TokenizerRegistry::new();
        let id1 = TokenizerRegistry::generate_id();
        let id2 = TokenizerRegistry::generate_id();
        for (id, name) in [(&id1, "model1"), (&id2, "model2")] {
            registry
                .load(id, name, "source", || async { Ok(mock_tokenizer()) })
                .await
                .unwrap();
        }

        // Removing by name drops both the entry and the name index
        let removed = registry.remove("model1").unwrap();
        assert_eq!(removed.id, id1);
        assert!(!registry.contains("model1"));
        assert!(!registry.contains_id(&id1));
        assert!(registry.remove("model1").is_none());

        // Removing by ID drops both the entry and the name index
        let removed = registry.remove_by_id(&id2).unwrap();
        assert_eq!(removed.name, "model2");
        assert!(!registry.contains("model2"));
        assert!(registry.get("model2").is_none());
        assert!(registry.remove_by_id(&id2).is_none());

        assert!(registry.is_empty());
    }

    #[tokio::test]
    async fn test_register_only_if_absent() {
        let registry = TokenizerRegistry::new();
        let id1 = TokenizerRegistry::generate_id();
        let id2 = TokenizerRegistry::generate_id();
        let tokenizer1 = mock_tokenizer();
        let tokenizer2 = mock_tokenizer();

        // First registration should succeed
        let result1 = registry.register(&id1, "model1", "source1", tokenizer1.clone());
        assert!(result1.is_some());
        assert_eq!(registry.len(), 1);

        // Second registration with same name should fail
        let result2 = registry.register(&id2, "model1", "source2", tokenizer2.clone());
        assert!(result2.is_none());
        assert_eq!(registry.len(), 1);

        // Original tokenizer should still be there
        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.id, id1);
        assert_eq!(entry.source, "source1");

        // Registration with different name should succeed
        let id3 = TokenizerRegistry::generate_id();
        let result3 = registry.register(&id3, "model2", "source2", tokenizer2);
        assert!(result3.is_some());
        assert_eq!(registry.len(), 2);
    }

    #[tokio::test]
    async fn test_loading_lock_cleanup_on_panic() {
        let registry = Arc::new(TokenizerRegistry::new());

        // Spawn a task that will panic during loading
        let registry_clone = registry.clone();
        let handle = tokio::spawn(async move {
            registry_clone
                .load(
                    &TokenizerRegistry::generate_id(),
                    "panic-model",
                    "source",
                    || async {
                        panic!("Simulated panic during tokenizer loading");
                    },
                )
                .await
        });

        // Wait for the task - it should panic
        let result = handle.await;
        assert!(result.is_err(), "Task should have panicked");

        // The RAII guard should have removed the loading lock during unwinding
        assert!(
            registry.loading_locks.is_empty(),
            "Loading lock should be cleaned up after a panicking load"
        );

        // Another load with the same name should work, not deadlock on a stale lock
        let id = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id, "panic-model", "source", || async { Ok(mock_tokenizer()) })
            .await;

        assert!(outcome.is_ok(), "Load should succeed after panic cleanup");
        assert!(outcome.unwrap().is_newly_loaded());
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("panic-model"));
        assert!(registry.loading_locks.is_empty());
    }

    #[tokio::test]
    async fn test_loading_lock_cleanup_on_early_return() {
        let registry = TokenizerRegistry::new();

        // A completed load must not leave its per-name lock behind
        let id1 = TokenizerRegistry::generate_id();
        registry
            .load(&id1, "model1", "source1", || async { Ok(mock_tokenizer()) })
            .await
            .unwrap();
        assert!(registry.loading_locks.is_empty());

        let id2 = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id2, "model2", "source2", || async { Ok(mock_tokenizer()) })
            .await
            .unwrap();
        assert!(outcome.is_newly_loaded());
        assert_eq!(registry.len(), 2);
        assert!(registry.loading_locks.is_empty());

        // Loading model1 again returns early with AlreadyExists (via the fast path)
        // and must not create or leak a lock entry either.
        let id3 = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id3, "model1", "source1", || async {
                panic!("Loader should not be called for existing model");
            })
            .await
            .unwrap();

        assert!(!outcome.is_newly_loaded());
        assert_eq!(outcome.id(), id1); // Returns original ID
        assert!(registry.loading_locks.is_empty());
    }
}