llm_tokenizer/
registry.rs

1//! Tokenizer Registry for dynamic tokenizer loading
2//!
3//! Provides thread-safe, deduplicated tokenizer loading for IGW mode where
4//! multiple routers (HTTP and gRPC) need to share tokenizers across workers.
5//!
6//! ## ID vs Name Lookup
7//!
8//! Tokenizers are stored with two keys:
//! - **ID (UUID)**: Unique identifier supplied by the caller at registration (see `generate_id()`), immutable
10//! - **Name**: User-provided identifier, must be unique
11//!
12//! Lookup behavior:
13//! - `get(key)`: Tries name first, then ID (backward compatible)
14//! - `get_by_id(id)`: Exact ID match only
15//! - `get_by_name(name)`: Exact name match only
16//! - `remove(name)`: Removes by name
17//! - `remove_by_id(id)`: Removes by ID
18
19use std::sync::Arc;
20
21use dashmap::DashMap;
22use thiserror::Error;
23use tokio::sync::Mutex;
24use tracing::{debug, info};
25use uuid::Uuid;
26
27use crate::traits::Tokenizer;
28
/// Result of a successful tokenizer load request.
///
/// Both variants carry the ID under which the tokenizer is registered, so
/// callers that only need the ID can use [`LoadOutcome::id`] without matching.
/// The type is comparable (`PartialEq`/`Eq`) so outcomes can be checked with
/// `assert_eq!` or `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoadOutcome {
    /// The tokenizer was loaded by this request and registered under `id`.
    Loaded { id: String },
    /// A tokenizer was already registered; `id` is the ID of that existing
    /// registration.
    AlreadyExists { id: String },
}

impl LoadOutcome {
    /// Returns the registered tokenizer ID, whichever variant this is.
    pub fn id(&self) -> &str {
        match self {
            Self::Loaded { id } | Self::AlreadyExists { id } => id,
        }
    }

    /// Returns `true` only for [`LoadOutcome::Loaded`], i.e. when this request
    /// performed the load.
    pub fn is_newly_loaded(&self) -> bool {
        matches!(self, Self::Loaded { .. })
    }
}
52
53/// Error type for tokenizer loading operations
54#[derive(Debug, Error)]
55pub enum LoadError {
56    /// Name cannot be empty
57    #[error("tokenizer name cannot be empty")]
58    EmptyName,
59    /// Source cannot be empty
60    #[error("tokenizer source cannot be empty")]
61    EmptySource,
62    /// Loading failed
63    #[error("{0}")]
64    LoadFailed(String),
65}
66
67/// Metadata and tokenizer instance for a registered tokenizer
68#[derive(Clone)]
69pub struct TokenizerEntry {
70    /// Unique identifier (UUID)
71    pub id: String,
72    /// User-provided name
73    pub name: String,
74    /// Source path or HuggingFace model ID
75    pub source: String,
76    /// The tokenizer instance
77    pub tokenizer: Arc<dyn Tokenizer>,
78}
79
80/// Registry for managing tokenizers keyed by UUID
81///
82/// Features:
83/// - Thread-safe concurrent access using DashMap
84/// - Per-key locking to prevent duplicate loading
85/// - Lookup by UUID (primary) or name (secondary index)
86pub struct TokenizerRegistry {
87    /// Storage for loaded tokenizers, keyed by UUID
88    tokenizers: DashMap<String, TokenizerEntry>,
89    /// Secondary index: name -> UUID for lookup
90    name_to_id: DashMap<String, String>,
91    /// Per-key locks to prevent duplicate loading
92    loading_locks: DashMap<String, Arc<Mutex<()>>>,
93}
94
95/// RAII guard that removes the loading lock entry on drop.
96/// Ensures cleanup happens on normal completion, early return, or panic.
97struct LoadingLockGuard<'a> {
98    locks: &'a DashMap<String, Arc<Mutex<()>>>,
99    key: String,
100}
101
102impl Drop for LoadingLockGuard<'_> {
103    fn drop(&mut self) {
104        self.locks.remove(&self.key);
105    }
106}
107
108impl TokenizerRegistry {
109    /// Create a new empty registry
110    pub fn new() -> Self {
111        Self {
112            tokenizers: DashMap::new(),
113            name_to_id: DashMap::new(),
114            loading_locks: DashMap::new(),
115        }
116    }
117
118    /// Generate a new UUID for a tokenizer
119    pub fn generate_id() -> String {
120        Uuid::new_v4().to_string()
121    }
122
123    /// Load and register a tokenizer
124    ///
125    /// Validates inputs, handles deduplication, and loads the tokenizer if needed.
126    /// Per-key locking ensures only one load happens per name, preventing race conditions.
127    ///
128    /// # Arguments
129    /// * `id` - Pre-generated UUID (use `generate_id()` to create one)
130    /// * `name` - User-provided name (used for deduplication, must not be empty)
131    /// * `source` - Source path or HuggingFace model ID (must not be empty)
132    /// * `loader` - Async function that loads the tokenizer
133    ///
134    /// # Returns
135    /// * `Ok(LoadOutcome::Loaded { id })` - Tokenizer was newly loaded
136    /// * `Ok(LoadOutcome::AlreadyExists { id })` - Tokenizer already existed
137    /// * `Err(LoadError)` - Validation failed or loading failed
138    pub async fn load<F, Fut>(
139        &self,
140        id: &str,
141        name: &str,
142        source: &str,
143        loader: F,
144    ) -> Result<LoadOutcome, LoadError>
145    where
146        F: FnOnce() -> Fut,
147        Fut: std::future::Future<Output = Result<Arc<dyn Tokenizer>, String>>,
148    {
149        // Validate inputs
150        if name.is_empty() {
151            return Err(LoadError::EmptyName);
152        }
153        if source.is_empty() {
154            return Err(LoadError::EmptySource);
155        }
156
157        // Fast path: already loaded by name
158        if let Some(existing_id) = self.name_to_id.get(name) {
159            debug!("Tokenizer already registered for name: {}", name);
160            return Ok(LoadOutcome::AlreadyExists {
161                id: existing_id.clone(),
162            });
163        }
164
165        debug!("Tokenizer cache miss for name: {}", name);
166
167        // Acquire per-name lock to prevent duplicate loading
168        let lock = self
169            .loading_locks
170            .entry(name.to_string())
171            .or_insert_with(|| Arc::new(Mutex::new(())))
172            .clone();
173
174        let _mutex_guard = lock.lock().await;
175        let _lock_cleanup = LoadingLockGuard {
176            locks: &self.loading_locks,
177            key: name.to_string(),
178        };
179
180        // Double-check after acquiring lock (another thread may have loaded it)
181        if let Some(existing_id) = self.name_to_id.get(name) {
182            debug!("Tokenizer loaded by another thread for name: {}", name);
183            return Ok(LoadOutcome::AlreadyExists {
184                id: existing_id.clone(),
185            });
186        }
187
188        // Load tokenizer
189        info!("Loading tokenizer '{}' from source: {}", name, source);
190        let result = loader().await;
191
192        let tokenizer = result.map_err(LoadError::LoadFailed)?;
193
194        // Create entry with provided ID
195        let entry = TokenizerEntry {
196            id: id.to_string(),
197            name: name.to_string(),
198            source: source.to_string(),
199            tokenizer,
200        };
201
202        // Store in registry
203        self.tokenizers.insert(id.to_string(), entry);
204        self.name_to_id.insert(name.to_string(), id.to_string());
205
206        info!(
207            "Successfully registered tokenizer '{}' with id: {}",
208            name, id
209        );
210
211        Ok(LoadOutcome::Loaded { id: id.to_string() })
212    }
213
214    /// Register a preloaded tokenizer with a pre-generated ID
215    ///
216    /// Atomically inserts a tokenizer into the registry only if no tokenizer
217    /// with the same name exists. Returns the ID if successful.
218    ///
219    /// This is primarily used for testing. Production code should use `load()`.
220    ///
221    /// # Returns
222    /// * `Some(id)` - If the tokenizer was successfully registered
223    /// * `None` - If a tokenizer with this name already existed
224    #[cfg(test)]
225    pub fn register(
226        &self,
227        id: &str,
228        name: &str,
229        source: &str,
230        tokenizer: Arc<dyn Tokenizer>,
231    ) -> Option<String> {
232        use dashmap::mapref::entry::Entry;
233
234        // Check if name already exists
235        match self.name_to_id.entry(name.to_string()) {
236            Entry::Occupied(_) => {
237                debug!(
238                    "Tokenizer already exists for name: {}, skipping registration",
239                    name
240                );
241                None
242            }
243            Entry::Vacant(name_entry) => {
244                let entry = TokenizerEntry {
245                    id: id.to_string(),
246                    name: name.to_string(),
247                    source: source.to_string(),
248                    tokenizer,
249                };
250
251                info!("Registering tokenizer '{}' with id: {}", name, id);
252                self.tokenizers.insert(id.to_string(), entry);
253                name_entry.insert(id.to_string());
254                Some(id.to_string())
255            }
256        }
257    }
258
259    /// Get a tokenizer by UUID
260    pub fn get_by_id(&self, id: &str) -> Option<TokenizerEntry> {
261        self.tokenizers.get(id).map(|e| e.clone())
262    }
263
264    /// Get a tokenizer by name
265    pub fn get_by_name(&self, name: &str) -> Option<TokenizerEntry> {
266        self.name_to_id
267            .get(name)
268            .and_then(|id| self.tokenizers.get(id.as_str()).map(|e| e.clone()))
269    }
270
271    /// Get a tokenizer (for backward compatibility, tries name first then ID)
272    pub fn get(&self, name_or_id: &str) -> Option<Arc<dyn Tokenizer>> {
273        self.get_by_name(name_or_id)
274            .or_else(|| self.get_by_id(name_or_id))
275            .map(|e| e.tokenizer)
276    }
277
278    /// Check if a tokenizer is registered by name
279    pub fn contains(&self, name: &str) -> bool {
280        self.name_to_id.contains_key(name)
281    }
282
283    /// Check if a tokenizer is registered by ID
284    pub fn contains_id(&self, id: &str) -> bool {
285        self.tokenizers.contains_key(id)
286    }
287
288    /// Get the number of loaded tokenizers
289    pub fn len(&self) -> usize {
290        self.tokenizers.len()
291    }
292
293    /// Check if the registry is empty
294    pub fn is_empty(&self) -> bool {
295        self.tokenizers.is_empty()
296    }
297
298    /// List all registered tokenizers
299    pub fn list(&self) -> Vec<TokenizerEntry> {
300        let mut entries: Vec<TokenizerEntry> =
301            self.tokenizers.iter().map(|e| e.value().clone()).collect();
302        entries.sort_by(|a, b| a.name.cmp(&b.name));
303        entries
304    }
305
306    /// Remove a tokenizer by ID
307    ///
308    /// Returns the entry if it was present.
309    pub fn remove_by_id(&self, id: &str) -> Option<TokenizerEntry> {
310        if let Some((_, entry)) = self.tokenizers.remove(id) {
311            self.name_to_id.remove(&entry.name);
312            Some(entry)
313        } else {
314            None
315        }
316    }
317
318    /// Remove a tokenizer by name
319    ///
320    /// Returns the entry if it was present.
321    pub fn remove(&self, name: &str) -> Option<TokenizerEntry> {
322        if let Some((_, id)) = self.name_to_id.remove(name) {
323            self.tokenizers.remove(&id).map(|(_, e)| e)
324        } else {
325            None
326        }
327    }
328
329    /// Clear all tokenizers from the registry
330    pub fn clear(&self) {
331        self.tokenizers.clear();
332        self.name_to_id.clear();
333        self.loading_locks.clear();
334    }
335}
336
337impl Default for TokenizerRegistry {
338    fn default() -> Self {
339        Self::new()
340    }
341}
342
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    use tokio::time::sleep;

    use crate::{mock::MockTokenizer, traits::Tokenizer, LoadError, TokenizerRegistry};

    /// Result type a loader passed to `TokenizerRegistry::load` must produce.
    type LoaderResult = Result<Arc<dyn Tokenizer>, String>;

    /// Builds a fresh mock tokenizer behind the trait object the registry stores.
    fn mock_tokenizer() -> Arc<dyn Tokenizer> {
        Arc::new(MockTokenizer::default())
    }

    /// Loader that succeeds immediately with a mock tokenizer.
    async fn mock_loader() -> LoaderResult {
        Ok(mock_tokenizer())
    }

    /// Loader for paths where the registry must not invoke the loader at all.
    async fn forbidden_loader() -> LoaderResult {
        panic!("Loader should not be called");
    }

    /// Loader that panics, to exercise cleanup during unwinding.
    async fn panicking_loader() -> LoaderResult {
        panic!("Simulated panic during tokenizer loading");
    }

    #[tokio::test]
    async fn test_basic_operations() {
        let registry = TokenizerRegistry::new();

        // Registry starts empty
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(!registry.contains("model1"));

        // Load and register a tokenizer
        let id = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id, "model1", "path/to/model", mock_loader)
            .await
            .unwrap();

        // A first load reports `Loaded` with the caller's ID
        assert!(outcome.is_newly_loaded());
        assert_eq!(outcome.id(), id);

        // It is visible through both keys
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("model1"));
        assert!(registry.contains_id(&id));

        // Metadata round-trips
        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.id, id);
        assert_eq!(entry.name, "model1");
        assert_eq!(entry.source, "path/to/model");

        // Removing by ID drops both the entry and its name mapping
        let removed = registry.remove_by_id(&id);
        assert!(removed.is_some());
        assert!(registry.is_empty());
        assert!(!registry.contains("model1"));
    }

    #[tokio::test]
    async fn test_load_returns_already_exists() {
        let registry = TokenizerRegistry::new();
        let id1 = TokenizerRegistry::generate_id();
        let id2 = TokenizerRegistry::generate_id();

        // First load should return Loaded
        let outcome1 = registry
            .load(&id1, "model1", "source1", mock_loader)
            .await
            .unwrap();
        assert!(outcome1.is_newly_loaded());
        assert_eq!(outcome1.id(), id1);

        // Second load with same name should return AlreadyExists with ORIGINAL id
        let outcome2 = registry
            .load(&id2, "model1", "source2", forbidden_loader)
            .await
            .unwrap();
        assert!(!outcome2.is_newly_loaded());
        assert_eq!(outcome2.id(), id1); // Returns the original ID, not id2

        // Registry still has only one entry, and id2 was never stored
        assert_eq!(registry.len(), 1);
        assert!(!registry.contains_id(&id2));

        // Original source is preserved
        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.source, "source1");
    }

    #[tokio::test]
    async fn test_load_validation() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        // Empty name should fail
        let result = registry.load(&id, "", "source", forbidden_loader).await;
        assert!(matches!(result, Err(LoadError::EmptyName)));

        // Empty source should fail
        let result = registry.load(&id, "model", "", forbidden_loader).await;
        assert!(matches!(result, Err(LoadError::EmptySource)));

        // Registry should be empty (nothing was loaded) and no lock entry
        // should have been created for the rejected requests
        assert!(registry.is_empty());
        assert!(registry.loading_locks.is_empty());
    }

    #[tokio::test]
    async fn test_load_prevents_duplicate_loading() {
        let registry = Arc::new(TokenizerRegistry::new());
        let load_count = Arc::new(AtomicUsize::new(0));

        // Spawn multiple tasks trying to load the same tokenizer
        let mut handles = Vec::with_capacity(10);
        for i in 0..10 {
            let registry = Arc::clone(&registry);
            let load_count = Arc::clone(&load_count);
            let id = format!("id-{}", i);
            handles.push(tokio::spawn(async move {
                registry
                    .load(&id, "model1", "source", || async {
                        // Simulate slow loading
                        sleep(Duration::from_millis(10)).await;
                        load_count.fetch_add(1, Ordering::SeqCst);
                        Ok(mock_tokenizer())
                    })
                    .await
            }));
        }

        // Wait for all tasks
        let mut outcomes = Vec::with_capacity(handles.len());
        for handle in handles {
            outcomes.push(handle.await.unwrap().unwrap());
        }

        // Verify tokenizer was loaded only once
        assert_eq!(
            load_count.load(Ordering::SeqCst),
            1,
            "Tokenizer should be loaded exactly once despite concurrent requests"
        );
        assert_eq!(registry.len(), 1);

        // Exactly one caller won, and everyone was handed the winner's ID
        let winners = outcomes.iter().filter(|o| o.is_newly_loaded()).count();
        assert_eq!(winners, 1);
        let registered_id = registry.get_by_name("model1").unwrap().id;
        assert!(outcomes.iter().all(|o| o.id() == registered_id));

        // Waiters that returned from the post-lock re-check left nothing behind
        assert!(registry.loading_locks.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_models() {
        let registry = TokenizerRegistry::new();

        // Load multiple tokenizers
        for i in 1..=5 {
            let model_name = format!("model{}", i);
            let id = TokenizerRegistry::generate_id();
            registry
                .load(&id, &model_name, "source", mock_loader)
                .await
                .unwrap();
        }

        assert_eq!(registry.len(), 5);
        assert!(registry.contains("model1"));
        assert!(registry.contains("model5"));
        assert!(!registry.contains("model6"));

        // List returns all entries, ordered by name
        let entries = registry.list();
        assert_eq!(entries.len(), 5);
        assert!(entries.iter().any(|e| e.name == "model1"));
        assert!(entries.windows(2).all(|pair| pair[0].name <= pair[1].name));

        // Clear all
        registry.clear();
        assert!(registry.is_empty());
        assert!(!registry.contains("model1"));
    }

    #[tokio::test]
    async fn test_load_failure() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        // Try to load with a failing loader
        let result = registry
            .load(&id, "failing_model", "source", || async {
                Err("Load failed".to_string())
            })
            .await;

        // The loader's message is surfaced unchanged and nothing is registered
        assert!(matches!(
            result,
            Err(LoadError::LoadFailed(ref msg)) if msg.as_str() == "Load failed"
        ));
        assert!(!registry.contains("failing_model"));
        assert!(registry.is_empty());
        assert!(registry.loading_locks.is_empty());

        // A failed load must not poison the name: a retry succeeds
        let retry = registry
            .load(&id, "failing_model", "source", mock_loader)
            .await
            .unwrap();
        assert!(retry.is_newly_loaded());
        assert!(registry.contains("failing_model"));
    }

    #[tokio::test]
    async fn test_get_by_name_and_id() {
        let registry = TokenizerRegistry::new();
        let id = TokenizerRegistry::generate_id();

        registry
            .load(&id, "my-model", "hf/model", mock_loader)
            .await
            .unwrap();

        // Get by name
        let by_name = registry.get_by_name("my-model");
        assert!(by_name.is_some());
        assert_eq!(by_name.as_ref().unwrap().id, id);

        // Get by ID
        let by_id = registry.get_by_id(&id);
        assert!(by_id.is_some());
        assert_eq!(by_id.as_ref().unwrap().name, "my-model");

        // The typed getters do not cross over between key spaces
        assert!(registry.get_by_name(&id).is_none());
        assert!(registry.get_by_id("my-model").is_none());

        // Generic get works with both, and rejects unknown keys
        assert!(registry.get("my-model").is_some());
        assert!(registry.get(&id).is_some());
        assert!(registry.get("unknown").is_none());
    }

    #[tokio::test]
    async fn test_register_only_if_absent() {
        let registry = TokenizerRegistry::new();
        let id1 = TokenizerRegistry::generate_id();
        let id2 = TokenizerRegistry::generate_id();
        let tokenizer1 = mock_tokenizer();
        let tokenizer2 = mock_tokenizer();

        // First registration should succeed
        let result1 = registry.register(&id1, "model1", "source1", tokenizer1.clone());
        assert_eq!(result1.as_deref(), Some(id1.as_str()));
        assert_eq!(registry.len(), 1);

        // Second registration with same name should fail
        let result2 = registry.register(&id2, "model1", "source2", tokenizer2.clone());
        assert!(result2.is_none());
        assert_eq!(registry.len(), 1);

        // Original tokenizer should still be there
        let entry = registry.get_by_name("model1").unwrap();
        assert_eq!(entry.id, id1);
        assert_eq!(entry.source, "source1");

        // Registration with different name should succeed
        let id3 = TokenizerRegistry::generate_id();
        let result3 = registry.register(&id3, "model2", "source2", tokenizer2);
        assert!(result3.is_some());
        assert_eq!(registry.len(), 2);
    }

    #[tokio::test]
    async fn test_loading_lock_cleanup_on_panic() {
        let registry = Arc::new(TokenizerRegistry::new());

        // Spawn a task that will panic during loading
        let registry_clone = Arc::clone(&registry);
        let handle = tokio::spawn(async move {
            let id = TokenizerRegistry::generate_id();
            registry_clone
                .load(&id, "panic-model", "source", panicking_loader)
                .await
        });

        // Wait for the task - it should panic
        let result = handle.await;
        assert!(result.is_err(), "Task should have panicked");

        // Unwinding must have removed the lock entry and registered nothing
        assert!(registry.loading_locks.is_empty());
        assert!(registry.is_empty());

        // A later load with the same name must work rather than deadlock or
        // fail because of a stale lock.
        let id = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id, "panic-model", "source", mock_loader)
            .await;

        assert!(outcome.is_ok(), "Load should succeed after panic cleanup");
        assert!(outcome.unwrap().is_newly_loaded());
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("panic-model"));
        assert!(registry.loading_locks.is_empty());
    }

    #[tokio::test]
    async fn test_loading_lock_cleanup_on_early_return() {
        let registry = Arc::new(TokenizerRegistry::new());

        // A completed load leaves no lock entry behind
        let id1 = TokenizerRegistry::generate_id();
        registry
            .load(&id1, "model1", "source1", mock_loader)
            .await
            .unwrap();
        assert!(registry.loading_locks.is_empty());

        // Loading a different model is unaffected by the first one
        let id2 = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id2, "model2", "source2", mock_loader)
            .await
            .unwrap();

        assert!(outcome.is_newly_loaded());
        assert_eq!(registry.len(), 2);
        assert!(registry.loading_locks.is_empty());

        // Loading model1 again returns early with the original ID and must
        // not leave a lock entry behind either
        let id3 = TokenizerRegistry::generate_id();
        let outcome = registry
            .load(&id3, "model1", "source1", forbidden_loader)
            .await
            .unwrap();

        assert!(!outcome.is_newly_loaded());
        assert_eq!(outcome.id(), id1); // Returns original ID
        assert_eq!(registry.len(), 2);
        assert!(registry.loading_locks.is_empty());
    }
}
674        assert_eq!(outcome.id(), id1); // Returns original ID
675    }
676}