dynamo_runtime/discovery/
metadata.rs1use anyhow::Result;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use super::{DiscoveryInstance, DiscoveryQuery};
9
/// Builds the map key under which an endpoint or model card is stored.
///
/// The three components are joined with `/`, e.g. `("ns", "comp", "ep")` ->
/// `"ns/comp/ep"`. Components are not escaped, so a component that itself
/// contains `/` can yield the same key as a different triple.
fn make_endpoint_key(namespace: &str, component: &str, endpoint: &str) -> String {
    [namespace, component, endpoint].join("/")
}
15
16#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
19pub struct DiscoveryMetadata {
20 endpoints: HashMap<String, DiscoveryInstance>,
22 model_cards: HashMap<String, DiscoveryInstance>,
24}
25
26impl DiscoveryMetadata {
27 pub fn new() -> Self {
29 Self {
30 endpoints: HashMap::new(),
31 model_cards: HashMap::new(),
32 }
33 }
34
35 pub fn register_endpoint(&mut self, instance: DiscoveryInstance) -> Result<()> {
37 if let DiscoveryInstance::Endpoint(ref inst) = instance {
38 let key = make_endpoint_key(&inst.namespace, &inst.component, &inst.endpoint);
39 self.endpoints.insert(key, instance);
40 Ok(())
41 } else {
42 anyhow::bail!("Cannot register non-endpoint instance as endpoint")
43 }
44 }
45
46 pub fn register_model_card(&mut self, instance: DiscoveryInstance) -> Result<()> {
48 if let DiscoveryInstance::Model {
49 ref namespace,
50 ref component,
51 ref endpoint,
52 ..
53 } = instance
54 {
55 let key = make_endpoint_key(namespace, component, endpoint);
56 self.model_cards.insert(key, instance);
57 Ok(())
58 } else {
59 anyhow::bail!("Cannot register non-model-card instance as model card")
60 }
61 }
62
63 pub fn get_all_endpoints(&self) -> Vec<DiscoveryInstance> {
65 self.endpoints.values().cloned().collect()
66 }
67
68 pub fn get_all_model_cards(&self) -> Vec<DiscoveryInstance> {
70 self.model_cards.values().cloned().collect()
71 }
72
73 pub fn get_all(&self) -> Vec<DiscoveryInstance> {
75 self.endpoints
76 .values()
77 .chain(self.model_cards.values())
78 .cloned()
79 .collect()
80 }
81
82 pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
84 let all_instances = match query {
85 DiscoveryQuery::AllEndpoints
86 | DiscoveryQuery::NamespacedEndpoints { .. }
87 | DiscoveryQuery::ComponentEndpoints { .. }
88 | DiscoveryQuery::Endpoint { .. } => self.get_all_endpoints(),
89
90 DiscoveryQuery::AllModels
91 | DiscoveryQuery::NamespacedModels { .. }
92 | DiscoveryQuery::ComponentModels { .. }
93 | DiscoveryQuery::EndpointModels { .. } => self.get_all_model_cards(),
94 };
95
96 filter_instances(all_instances, query)
97 }
98}
99
100impl Default for DiscoveryMetadata {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
106fn filter_instances(
108 instances: Vec<DiscoveryInstance>,
109 query: &DiscoveryQuery,
110) -> Vec<DiscoveryInstance> {
111 match query {
112 DiscoveryQuery::AllEndpoints | DiscoveryQuery::AllModels => instances,
113
114 DiscoveryQuery::NamespacedEndpoints { namespace } => instances
115 .into_iter()
116 .filter(|inst| match inst {
117 DiscoveryInstance::Endpoint(i) => &i.namespace == namespace,
118 _ => false,
119 })
120 .collect(),
121
122 DiscoveryQuery::ComponentEndpoints {
123 namespace,
124 component,
125 } => instances
126 .into_iter()
127 .filter(|inst| match inst {
128 DiscoveryInstance::Endpoint(i) => {
129 &i.namespace == namespace && &i.component == component
130 }
131 _ => false,
132 })
133 .collect(),
134
135 DiscoveryQuery::Endpoint {
136 namespace,
137 component,
138 endpoint,
139 } => instances
140 .into_iter()
141 .filter(|inst| match inst {
142 DiscoveryInstance::Endpoint(i) => {
143 &i.namespace == namespace
144 && &i.component == component
145 && &i.endpoint == endpoint
146 }
147 _ => false,
148 })
149 .collect(),
150
151 DiscoveryQuery::NamespacedModels { namespace } => instances
152 .into_iter()
153 .filter(|inst| match inst {
154 DiscoveryInstance::Model { namespace: ns, .. } => ns == namespace,
155 _ => false,
156 })
157 .collect(),
158
159 DiscoveryQuery::ComponentModels {
160 namespace,
161 component,
162 } => instances
163 .into_iter()
164 .filter(|inst| match inst {
165 DiscoveryInstance::Model {
166 namespace: ns,
167 component: comp,
168 ..
169 } => ns == namespace && comp == component,
170 _ => false,
171 })
172 .collect(),
173
174 DiscoveryQuery::EndpointModels {
175 namespace,
176 component,
177 endpoint,
178 } => instances
179 .into_iter()
180 .filter(|inst| match inst {
181 DiscoveryInstance::Model {
182 namespace: ns,
183 component: comp,
184 endpoint: ep,
185 ..
186 } => ns == namespace && comp == component && ep == endpoint,
187 _ => false,
188 })
189 .collect(),
190 }
191}
192
193#[derive(Clone, Debug)]
195pub struct MetadataSnapshot {
196 pub instances: HashMap<u64, Arc<DiscoveryMetadata>>,
198 pub sequence: u64,
200 pub timestamp: std::time::Instant,
202}
203
204impl MetadataSnapshot {
205 pub fn empty() -> Self {
206 Self {
207 instances: HashMap::new(),
208 sequence: 0,
209 timestamp: std::time::Instant::now(),
210 }
211 }
212
213 pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
215 self.instances
216 .values()
217 .flat_map(|metadata| metadata.filter(query))
218 .collect()
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::component::{Instance, TransportType};

    /// Serializing to JSON and back keeps both maps' contents.
    #[test]
    fn test_metadata_serde() {
        let mut metadata = DiscoveryMetadata::new();

        let instance = DiscoveryInstance::Endpoint(Instance {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 123,
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
        });

        metadata.register_endpoint(instance).unwrap();

        let json = serde_json::to_string(&metadata).unwrap();

        let deserialized: DiscoveryMetadata = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.endpoints.len(), 1);
        assert_eq!(deserialized.model_cards.len(), 0);
    }

    /// Registrations from many tasks behind a shared lock are all retained.
    #[tokio::test]
    async fn test_concurrent_registration() {
        use tokio::sync::RwLock;

        let metadata = Arc::new(RwLock::new(DiscoveryMetadata::new()));

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let metadata = metadata.clone();
                tokio::spawn(async move {
                    let mut meta = metadata.write().await;
                    let instance = DiscoveryInstance::Endpoint(Instance {
                        namespace: "test".to_string(),
                        component: "comp1".to_string(),
                        endpoint: format!("ep{}", i),
                        instance_id: i,
                        transport: TransportType::Nats("nats://localhost:4222".to_string()),
                    });
                    meta.register_endpoint(instance).unwrap();
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        let meta = metadata.read().await;
        assert_eq!(meta.endpoints.len(), 10);
    }

    /// The accessor methods report the right counts per kind.
    ///
    /// Nothing here awaits, so a plain `#[test]` is enough; no async runtime needed.
    #[test]
    fn test_metadata_accessors() {
        let mut metadata = DiscoveryMetadata::new();

        for i in 0..3 {
            let instance = DiscoveryInstance::Endpoint(Instance {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                endpoint: format!("ep{}", i),
                instance_id: i,
                transport: TransportType::Nats("nats://localhost:4222".to_string()),
            });
            metadata.register_endpoint(instance).unwrap();
        }

        for i in 0..2 {
            let instance = DiscoveryInstance::Model {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                endpoint: format!("ep{}", i),
                instance_id: i,
                card_json: serde_json::json!({"model": "test"}),
            };
            metadata.register_model_card(instance).unwrap();
        }

        assert_eq!(metadata.get_all_endpoints().len(), 3);
        assert_eq!(metadata.get_all_model_cards().len(), 2);
        assert_eq!(metadata.get_all().len(), 5);
    }

    /// Each query variant selects the expected records, and registering the
    /// wrong kind through a typed `register_*` call is rejected.
    #[test]
    fn test_filter_queries() {
        fn endpoint(component: &str, endpoint: &str) -> DiscoveryInstance {
            DiscoveryInstance::Endpoint(Instance {
                namespace: "test".to_string(),
                component: component.to_string(),
                endpoint: endpoint.to_string(),
                instance_id: 1,
                transport: TransportType::Nats("nats://localhost:4222".to_string()),
            })
        }

        let mut metadata = DiscoveryMetadata::new();
        metadata.register_endpoint(endpoint("comp1", "ep0")).unwrap();
        metadata.register_endpoint(endpoint("comp1", "ep1")).unwrap();
        metadata.register_endpoint(endpoint("comp2", "ep0")).unwrap();

        let model = DiscoveryInstance::Model {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep0".to_string(),
            instance_id: 1,
            card_json: serde_json::json!({"model": "test"}),
        };

        // Wrong-kind registrations fail and store nothing.
        assert!(metadata.register_endpoint(model.clone()).is_err());
        assert!(metadata
            .register_model_card(endpoint("comp1", "ep0"))
            .is_err());
        assert_eq!(metadata.get_all_model_cards().len(), 0);

        metadata.register_model_card(model).unwrap();

        let count = |query: DiscoveryQuery| metadata.filter(&query).len();

        assert_eq!(count(DiscoveryQuery::AllEndpoints), 3);
        assert_eq!(count(DiscoveryQuery::AllModels), 1);
        assert_eq!(
            count(DiscoveryQuery::NamespacedEndpoints {
                namespace: "test".to_string(),
            }),
            3
        );
        assert_eq!(
            count(DiscoveryQuery::NamespacedEndpoints {
                namespace: "other".to_string(),
            }),
            0
        );
        assert_eq!(
            count(DiscoveryQuery::ComponentEndpoints {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
            }),
            2
        );
        assert_eq!(
            count(DiscoveryQuery::Endpoint {
                namespace: "test".to_string(),
                component: "comp2".to_string(),
                endpoint: "ep0".to_string(),
            }),
            1
        );
        assert_eq!(
            count(DiscoveryQuery::NamespacedModels {
                namespace: "test".to_string(),
            }),
            1
        );
        assert_eq!(
            count(DiscoveryQuery::ComponentModels {
                namespace: "test".to_string(),
                component: "comp2".to_string(),
            }),
            0
        );
        assert_eq!(
            count(DiscoveryQuery::EndpointModels {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                endpoint: "ep0".to_string(),
            }),
            1
        );
    }
}