dynamo_runtime/pipeline/
registry.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::any::Any;
5use std::collections::HashMap;
6use std::sync::Arc;
7
/// Type-erased object registry with two independent storage pools.
///
/// * **Shared** objects are stored behind an [`Arc`]. Every call to
///   [`Registry::get_shared`] returns another handle to the same value.
/// * **Unique** objects are stored behind a [`Box`].
///   [`Registry::take_unique`] moves the value out and removes it from the
///   registry. [`Registry::clone_unique`] returns a copy and leaves the stored
///   value in place.
///
/// Shared and unique keys live in separate maps, so the same key may be used
/// in both pools without conflict. Values are recovered by downcasting to the
/// requested type; asking for the wrong type yields an `Err`.
///
/// # Examples
///
/// ```
/// use dynamo_runtime::pipeline::registry::Registry;
///
/// let mut registry = Registry::new();
///
/// // Insert and retrieve shared objects
/// registry.insert_shared("shared1", 42);
/// assert_eq!(*registry.get_shared::<i32>("shared1").unwrap(), 42);
///
/// // Insert and take unique objects
/// registry.insert_unique("unique1", "Hello".to_string());
/// assert_eq!(registry.take_unique::<String>("unique1").unwrap(), "Hello");
///
/// // Taking the same unique again fails because the first take removed it
/// assert!(registry.take_unique::<String>("unique1").is_err());
///
/// // Insert and clone unique objects
/// registry.insert_unique("unique2", "World".to_string());
/// assert_eq!(registry.clone_unique::<String>("unique2").unwrap(), "World");
///
/// // Cloning leaves the object in place, so it can still be taken afterwards
/// assert!(registry.take_unique::<String>("unique2").is_ok());
///
/// // Taking with the wrong type fails but does not discard the stored object
/// registry.insert_unique("unique3", 7u64);
/// assert!(registry.take_unique::<String>("unique3").is_err());
/// assert_eq!(registry.take_unique::<u64>("unique3").unwrap(), 7);
/// ```
#[derive(Debug, Default)]
pub struct Registry {
    /// Objects handed out as additional `Arc` handles.
    shared_storage: HashMap<String, Arc<dyn Any + Send + Sync>>,
    /// Objects that can be moved out (taken) or cloned.
    unique_storage: HashMap<String, Box<dyn Any + Send + Sync>>,
}

impl Registry {
    /// Create a new, empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Return `true` if a shared object is stored under `key`.
    pub fn contains_shared(&self, key: &str) -> bool {
        self.shared_storage.contains_key(key)
    }

    /// Store `value` as a shared object under `key`.
    ///
    /// Any shared object previously stored under the same key is replaced.
    pub fn insert_shared<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
        self.shared_storage.insert(
            key.to_string(),
            Arc::new(value) as Arc<dyn Any + Send + Sync>,
        );
    }

    /// Return a new `Arc` handle to the shared object stored under `key`.
    ///
    /// # Errors
    ///
    /// Returns `Err` if no shared object exists for `key`, or if the stored
    /// object is not of type `V`.
    pub fn get_shared<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
        match self.shared_storage.get(key) {
            Some(stored) => Arc::clone(stored).downcast::<V>().map_err(|_| {
                format!(
                    "Failed to downcast to the requested type for shared key: {}",
                    key
                )
            }),
            None => Err(format!("Shared key not found: {}", key)),
        }
    }

    /// Return `true` if a unique object is stored under `key`.
    pub fn contains_unique(&self, key: &str) -> bool {
        self.unique_storage.contains_key(key)
    }

    /// Store `value` as a unique object under `key`.
    ///
    /// Any unique object previously stored under the same key is replaced.
    pub fn insert_unique<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
        self.unique_storage.insert(
            key.to_string(),
            Box::new(value) as Box<dyn Any + Send + Sync>,
        );
    }

    /// Move the unique object stored under `key` out of the registry.
    ///
    /// On success the entry is removed. If the stored object is not of type
    /// `V`, the entry is left in the registry unchanged.
    ///
    /// # Errors
    ///
    /// Returns `Err` if no unique object exists for `key`, or if the stored
    /// object is not of type `V`.
    pub fn take_unique<V: Send + Sync + 'static>(&mut self, key: &str) -> Result<V, String> {
        let (owned_key, boxed) = match self.unique_storage.remove_entry(key) {
            Some(entry) => entry,
            None => return Err(format!("Takable key not found: {}", key)),
        };
        match boxed.downcast::<V>() {
            Ok(value) => Ok(*value),
            Err(original) => {
                // Put the object back: a type mismatch must not silently destroy it.
                self.unique_storage.insert(owned_key, original);
                Err(format!(
                    "Failed to downcast to the requested type for unique key: {}",
                    key
                ))
            }
        }
    }

    /// Return a clone of the unique object stored under `key`, leaving it in place.
    ///
    /// # Errors
    ///
    /// Returns `Err` if no unique object exists for `key`, or if the stored
    /// object is not of type `V`.
    pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
        match self.unique_storage.get(key) {
            Some(boxed) => boxed.downcast_ref::<V>().cloned().ok_or_else(|| {
                format!(
                    "Failed to downcast to the requested type for unique key: {}",
                    key
                )
            }),
            None => Err(format!("Takable key not found: {}", key)),
        }
    }
}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_and_get_shared() {
        let mut registry = Registry::new();
        registry.insert_shared("shared1", 42);
        assert_eq!(*registry.get_shared::<i32>("shared1").unwrap(), 42);
        assert!(registry.get_shared::<f64>("shared1").is_err()); // Testing a downcast failure
    }

    #[test]
    fn test_insert_and_take_unique() {
        let mut registry = Registry::new();
        registry.insert_unique("unique1", "Hello".to_string());
        assert_eq!(registry.take_unique::<String>("unique1").unwrap(), "Hello");
        assert!(registry.take_unique::<String>("unique1").is_err()); // Key is now missing
    }

    #[test]
    fn test_insert_and_clone_then_take_unique() {
        let mut registry = Registry::new();

        registry.insert_unique("unique2", "World".to_string());

        assert_eq!(registry.clone_unique::<String>("unique2").unwrap(), "World");

        // When cloned, the object should still be available for taking
        assert!(registry.take_unique::<String>("unique2").is_ok());
    }

    #[test]
    fn test_failed_take_after_cloning() {
        let mut registry = Registry::new();

        registry.insert_unique("unique3", "Another".to_string());
        assert_eq!(
            registry.clone_unique::<String>("unique3").unwrap(),
            "Another"
        );

        // Cloned, then Take is OK
        assert_eq!(
            registry.take_unique::<String>("unique3").unwrap(),
            "Another"
        );

        // Take, then Take again should fail
        assert!(registry.take_unique::<String>("unique3").is_err());
    }

    #[test]
    fn test_contains_tracks_each_pool_separately() {
        let mut registry = Registry::new();
        assert!(!registry.contains_shared("a"));
        assert!(!registry.contains_unique("a"));

        registry.insert_shared("a", 1u8);
        assert!(registry.contains_shared("a"));
        // Shared and unique objects are stored in separate maps.
        assert!(!registry.contains_unique("a"));

        registry.insert_unique("b", 2u8);
        assert!(registry.contains_unique("b"));
        assert!(!registry.contains_shared("b"));

        // Taking removes the unique entry.
        assert_eq!(registry.take_unique::<u8>("b").unwrap(), 2);
        assert!(!registry.contains_unique("b"));
    }

    #[test]
    fn test_missing_keys_report_errors() {
        let mut registry = Registry::new();
        assert_eq!(
            registry.get_shared::<i32>("nope").unwrap_err(),
            "Shared key not found: nope"
        );
        assert_eq!(
            registry.clone_unique::<i32>("nope").unwrap_err(),
            "Takable key not found: nope"
        );
        assert_eq!(
            registry.take_unique::<i32>("nope").unwrap_err(),
            "Takable key not found: nope"
        );
    }

    #[test]
    fn test_clone_unique_type_mismatch_leaves_value() {
        let mut registry = Registry::new();
        registry.insert_unique("u", 5i32);
        assert!(registry.clone_unique::<String>("u").is_err());
        // A failed clone only borrows the entry, so the value is still retrievable.
        assert_eq!(registry.clone_unique::<i32>("u").unwrap(), 5);
    }

    #[test]
    fn test_get_shared_returns_handles_to_same_value() {
        let mut registry = Registry::new();
        registry.insert_shared("s", String::from("x"));
        let first = registry.get_shared::<String>("s").unwrap();
        let second = registry.get_shared::<String>("s").unwrap();
        assert!(Arc::ptr_eq(&first, &second));
    }
}