1use anyhow::Result;
5use serde::Deserialize as _;
6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use super::{DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery};
10
11fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result<T, D::Error>
25where
26 D: serde::Deserializer<'de>,
27 T: Default + serde::Deserialize<'de>,
28{
29 Ok(Option::<T>::deserialize(deserializer)?.unwrap_or_default())
30}
31
32#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
35pub struct DiscoveryMetadata {
36 #[serde(default, deserialize_with = "deserialize_null_default")]
38 endpoints: HashMap<String, DiscoveryInstance>,
39 #[serde(default, deserialize_with = "deserialize_null_default")]
41 model_cards: HashMap<String, DiscoveryInstance>,
42 #[serde(default, deserialize_with = "deserialize_null_default")]
44 event_channels: HashMap<String, DiscoveryInstance>,
45}
46
47impl DiscoveryMetadata {
48 pub fn new() -> Self {
50 Self {
51 endpoints: HashMap::new(),
52 model_cards: HashMap::new(),
53 event_channels: HashMap::new(),
54 }
55 }
56
57 pub fn register_endpoint(&mut self, instance: DiscoveryInstance) -> Result<()> {
59 match instance.id() {
60 DiscoveryInstanceId::Endpoint(key) => {
61 self.endpoints.insert(key.to_path(), instance);
62 Ok(())
63 }
64 DiscoveryInstanceId::Model(_) => {
65 anyhow::bail!("Cannot register non-endpoint instance as endpoint")
66 }
67 DiscoveryInstanceId::EventChannel(_) => {
68 anyhow::bail!("Cannot register EventChannel instance as endpoint")
69 }
70 }
71 }
72
73 pub fn register_model_card(&mut self, instance: DiscoveryInstance) -> Result<()> {
75 match instance.id() {
76 DiscoveryInstanceId::Model(key) => {
77 self.model_cards.insert(key.to_path(), instance);
78 Ok(())
79 }
80 DiscoveryInstanceId::Endpoint(_) => {
81 anyhow::bail!("Cannot register non-model-card instance as model card")
82 }
83 DiscoveryInstanceId::EventChannel(_) => {
84 anyhow::bail!("Cannot register EventChannel instance as model card")
85 }
86 }
87 }
88
89 pub fn unregister_endpoint(&mut self, instance: &DiscoveryInstance) -> Result<()> {
91 match instance.id() {
92 DiscoveryInstanceId::Endpoint(key) => {
93 self.endpoints.remove(&key.to_path());
94 Ok(())
95 }
96 DiscoveryInstanceId::Model(_) => {
97 anyhow::bail!("Cannot unregister non-endpoint instance as endpoint")
98 }
99 DiscoveryInstanceId::EventChannel(_) => {
100 anyhow::bail!("Cannot unregister EventChannel instance as endpoint")
101 }
102 }
103 }
104
105 pub fn unregister_model_card(&mut self, instance: &DiscoveryInstance) -> Result<()> {
107 match instance.id() {
108 DiscoveryInstanceId::Model(key) => {
109 self.model_cards.remove(&key.to_path());
110 Ok(())
111 }
112 DiscoveryInstanceId::Endpoint(_) => {
113 anyhow::bail!("Cannot unregister non-model-card instance as model card")
114 }
115 DiscoveryInstanceId::EventChannel(_) => {
116 anyhow::bail!("Cannot unregister EventChannel instance as model card")
117 }
118 }
119 }
120
121 pub fn register_event_channel(&mut self, instance: DiscoveryInstance) -> Result<()> {
123 match instance.id() {
124 DiscoveryInstanceId::EventChannel(key) => {
125 self.event_channels.insert(key.to_path(), instance);
126 Ok(())
127 }
128 DiscoveryInstanceId::Endpoint(_) => {
129 anyhow::bail!("Cannot register Endpoint instance as event channel")
130 }
131 DiscoveryInstanceId::Model(_) => {
132 anyhow::bail!("Cannot register Model instance as event channel")
133 }
134 }
135 }
136
137 pub fn unregister_event_channel(&mut self, instance: &DiscoveryInstance) -> Result<()> {
139 match instance.id() {
140 DiscoveryInstanceId::EventChannel(key) => {
141 self.event_channels.remove(&key.to_path());
142 Ok(())
143 }
144 DiscoveryInstanceId::Endpoint(_) => {
145 anyhow::bail!("Cannot unregister Endpoint instance as event channel")
146 }
147 DiscoveryInstanceId::Model(_) => {
148 anyhow::bail!("Cannot unregister Model instance as event channel")
149 }
150 }
151 }
152
153 pub fn get_all_endpoints(&self) -> Vec<DiscoveryInstance> {
155 self.endpoints.values().cloned().collect()
156 }
157
158 pub fn get_all_model_cards(&self) -> Vec<DiscoveryInstance> {
160 self.model_cards.values().cloned().collect()
161 }
162
163 pub fn get_all_event_channels(&self) -> Vec<DiscoveryInstance> {
165 self.event_channels.values().cloned().collect()
166 }
167
168 pub fn get_all(&self) -> Vec<DiscoveryInstance> {
170 self.endpoints
171 .values()
172 .chain(self.model_cards.values())
173 .chain(self.event_channels.values())
174 .cloned()
175 .collect()
176 }
177
178 pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
180 let all_instances = match query {
181 DiscoveryQuery::AllEndpoints
182 | DiscoveryQuery::NamespacedEndpoints { .. }
183 | DiscoveryQuery::ComponentEndpoints { .. }
184 | DiscoveryQuery::Endpoint { .. } => self.get_all_endpoints(),
185
186 DiscoveryQuery::AllModels
187 | DiscoveryQuery::NamespacedModels { .. }
188 | DiscoveryQuery::ComponentModels { .. }
189 | DiscoveryQuery::EndpointModels { .. } => self.get_all_model_cards(),
190
191 DiscoveryQuery::EventChannels(_) => self.get_all_event_channels(),
193 };
194
195 filter_instances(all_instances, query)
196 }
197}
198
199impl Default for DiscoveryMetadata {
200 fn default() -> Self {
201 Self::new()
202 }
203}
204
205fn filter_instances(
207 instances: Vec<DiscoveryInstance>,
208 query: &DiscoveryQuery,
209) -> Vec<DiscoveryInstance> {
210 match query {
211 DiscoveryQuery::AllEndpoints | DiscoveryQuery::AllModels => instances,
212
213 DiscoveryQuery::NamespacedEndpoints { namespace } => instances
214 .into_iter()
215 .filter(|inst| match inst {
216 DiscoveryInstance::Endpoint(i) => &i.namespace == namespace,
217 _ => false,
218 })
219 .collect(),
220
221 DiscoveryQuery::ComponentEndpoints {
222 namespace,
223 component,
224 } => instances
225 .into_iter()
226 .filter(|inst| match inst {
227 DiscoveryInstance::Endpoint(i) => {
228 &i.namespace == namespace && &i.component == component
229 }
230 _ => false,
231 })
232 .collect(),
233
234 DiscoveryQuery::Endpoint {
235 namespace,
236 component,
237 endpoint,
238 } => instances
239 .into_iter()
240 .filter(|inst| match inst {
241 DiscoveryInstance::Endpoint(i) => {
242 &i.namespace == namespace
243 && &i.component == component
244 && &i.endpoint == endpoint
245 }
246 _ => false,
247 })
248 .collect(),
249
250 DiscoveryQuery::NamespacedModels { namespace } => instances
251 .into_iter()
252 .filter(|inst| match inst {
253 DiscoveryInstance::Model { namespace: ns, .. } => ns == namespace,
254 _ => false,
255 })
256 .collect(),
257
258 DiscoveryQuery::ComponentModels {
259 namespace,
260 component,
261 } => instances
262 .into_iter()
263 .filter(|inst| match inst {
264 DiscoveryInstance::Model {
265 namespace: ns,
266 component: comp,
267 ..
268 } => ns == namespace && comp == component,
269 _ => false,
270 })
271 .collect(),
272
273 DiscoveryQuery::EndpointModels {
274 namespace,
275 component,
276 endpoint,
277 } => instances
278 .into_iter()
279 .filter(|inst| match inst {
280 DiscoveryInstance::Model {
281 namespace: ns,
282 component: comp,
283 endpoint: ep,
284 ..
285 } => ns == namespace && comp == component && ep == endpoint,
286 _ => false,
287 })
288 .collect(),
289
290 DiscoveryQuery::EventChannels(query) => instances
292 .into_iter()
293 .filter(|inst| match inst {
294 DiscoveryInstance::EventChannel {
295 namespace: ns,
296 component: comp,
297 topic: t,
298 ..
299 } => {
300 query.namespace.as_ref().is_none_or(|qns| qns == ns)
302 && query.component.as_ref().is_none_or(|qc| qc == comp)
304 && query.topic.as_ref().is_none_or(|qt| qt == t)
306 }
307 _ => false,
308 })
309 .collect(),
310 }
311}
312
313#[derive(Clone, Debug)]
315pub struct MetadataSnapshot {
316 pub instances: HashMap<u64, Arc<DiscoveryMetadata>>,
318 pub generations: HashMap<u64, i64>,
321 pub sequence: u64,
323 pub timestamp: std::time::Instant,
325}
326
327impl MetadataSnapshot {
328 pub fn empty() -> Self {
329 Self {
330 instances: HashMap::new(),
331 generations: HashMap::new(),
332 sequence: 0,
333 timestamp: std::time::Instant::now(),
334 }
335 }
336
337 pub fn has_changes_from(&self, prev: &MetadataSnapshot) -> bool {
341 if self.generations == prev.generations {
342 tracing::trace!(
343 "Snapshot (seq={}): no changes, {} instances",
344 self.sequence,
345 self.instances.len()
346 );
347 return false;
348 }
349
350 let curr_ids: HashSet<u64> = self.generations.keys().copied().collect();
352 let prev_ids: HashSet<u64> = prev.generations.keys().copied().collect();
353
354 let added: Vec<_> = curr_ids
355 .difference(&prev_ids)
356 .map(|id| format!("{:x}", id))
357 .collect();
358 let removed: Vec<_> = prev_ids
359 .difference(&curr_ids)
360 .map(|id| format!("{:x}", id))
361 .collect();
362 let updated: Vec<_> = self
363 .generations
364 .iter()
365 .filter(|(k, v)| prev.generations.get(*k).is_some_and(|pv| pv != *v))
366 .map(|(k, _)| format!("{:x}", k))
367 .collect();
368
369 tracing::info!(
370 "Snapshot (seq={}): {} instances, added={:?}, removed={:?}, updated={:?}",
371 self.sequence,
372 self.instances.len(),
373 added,
374 removed,
375 updated
376 );
377
378 true
379 }
380
381 pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
383 self.instances
384 .values()
385 .flat_map(|metadata| metadata.filter(query))
386 .collect()
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::component::{Instance, TransportType};
    use crate::discovery::{EventChannelQuery, EventTransport};

    /// An endpoint survives a JSON round trip and the other maps stay empty.
    #[test]
    fn test_metadata_serde() {
        let mut original = DiscoveryMetadata::new();
        original
            .register_endpoint(DiscoveryInstance::Endpoint(Instance {
                namespace: String::from("test"),
                component: String::from("comp1"),
                endpoint: String::from("ep1"),
                instance_id: 123,
                transport: TransportType::Nats(String::from("nats://localhost:4222")),
                device_type: None,
            }))
            .unwrap();

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: DiscoveryMetadata = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.endpoints.len(), 1);
        assert!(decoded.model_cards.is_empty());
    }

    /// Ten tasks registering distinct endpoints behind a shared lock all land.
    #[tokio::test]
    async fn test_concurrent_registration() {
        use tokio::sync::RwLock;

        let shared = Arc::new(RwLock::new(DiscoveryMetadata::new()));

        let mut tasks = Vec::with_capacity(10);
        for idx in 0..10 {
            let shared = Arc::clone(&shared);
            tasks.push(tokio::spawn(async move {
                let mut guard = shared.write().await;
                guard
                    .register_endpoint(DiscoveryInstance::Endpoint(Instance {
                        namespace: String::from("test"),
                        component: String::from("comp1"),
                        endpoint: format!("ep{}", idx),
                        instance_id: idx,
                        transport: TransportType::Nats(String::from("nats://localhost:4222")),
                        device_type: None,
                    }))
                    .unwrap();
            }));
        }

        for task in tasks {
            task.await.unwrap();
        }

        let guard = shared.read().await;
        assert_eq!(guard.endpoints.len(), 10);
    }

    /// The per-kind accessors and `get_all` report the registered counts.
    #[tokio::test]
    async fn test_metadata_accessors() {
        let mut registry = DiscoveryMetadata::new();

        for idx in 0..3 {
            registry
                .register_endpoint(DiscoveryInstance::Endpoint(Instance {
                    namespace: String::from("test"),
                    component: String::from("comp1"),
                    endpoint: format!("ep{}", idx),
                    instance_id: idx,
                    transport: TransportType::Nats(String::from("nats://localhost:4222")),
                    device_type: None,
                }))
                .unwrap();
        }

        for idx in 0..2 {
            registry
                .register_model_card(DiscoveryInstance::Model {
                    namespace: String::from("test"),
                    component: String::from("comp1"),
                    endpoint: format!("ep{}", idx),
                    instance_id: idx,
                    card_json: serde_json::json!({"model": "test"}),
                    model_suffix: None,
                })
                .unwrap();
        }

        assert_eq!(registry.get_all_endpoints().len(), 3);
        assert_eq!(registry.get_all_model_cards().len(), 2);
        assert_eq!(registry.get_all().len(), 5);
    }

    /// Event channels register, filter by query, and unregister.
    #[tokio::test]
    async fn test_event_channel_registration() {
        let mut registry = DiscoveryMetadata::new();

        for idx in 0..3 {
            registry
                .register_event_channel(DiscoveryInstance::EventChannel {
                    namespace: String::from("test"),
                    component: String::from("comp1"),
                    topic: String::from("test-topic"),
                    instance_id: idx,
                    transport: EventTransport::zmq(format!("tcp://localhost:{}", 5000 + idx)),
                })
                .unwrap();
        }

        assert_eq!(registry.get_all_event_channels().len(), 3);
        assert_eq!(registry.get_all().len(), 3);

        let everything =
            registry.filter(&DiscoveryQuery::EventChannels(EventChannelQuery::all()));
        assert_eq!(everything.len(), 3);

        let same_component = registry.filter(&DiscoveryQuery::EventChannels(
            EventChannelQuery::component("test", "comp1"),
        ));
        assert_eq!(same_component.len(), 3);

        let other_namespace = registry.filter(&DiscoveryQuery::EventChannels(
            EventChannelQuery::component("other", "comp1"),
        ));
        assert!(other_namespace.is_empty());

        let first = DiscoveryInstance::EventChannel {
            namespace: String::from("test"),
            component: String::from("comp1"),
            topic: String::from("test-topic"),
            instance_id: 0,
            transport: EventTransport::zmq("tcp://localhost:5000"),
        };
        registry.unregister_event_channel(&first).unwrap();
        assert_eq!(registry.get_all_event_channels().len(), 2);
    }

    /// One instance of each kind lands in its own map.
    #[tokio::test]
    async fn test_mixed_instances() {
        let mut registry = DiscoveryMetadata::new();

        registry
            .register_endpoint(DiscoveryInstance::Endpoint(Instance {
                namespace: String::from("test"),
                component: String::from("comp1"),
                endpoint: String::from("ep1"),
                instance_id: 1,
                transport: TransportType::Nats(String::from("nats://localhost:4222")),
                device_type: None,
            }))
            .unwrap();

        registry
            .register_model_card(DiscoveryInstance::Model {
                namespace: String::from("test"),
                component: String::from("comp1"),
                endpoint: String::from("ep1"),
                instance_id: 2,
                card_json: serde_json::json!({"model": "test"}),
                model_suffix: None,
            })
            .unwrap();

        registry
            .register_event_channel(DiscoveryInstance::EventChannel {
                namespace: String::from("test"),
                component: String::from("comp1"),
                topic: String::from("test-topic"),
                instance_id: 3,
                transport: EventTransport::zmq("tcp://localhost:5000"),
            })
            .unwrap();

        assert_eq!(registry.get_all().len(), 3);
        assert_eq!(registry.get_all_endpoints().len(), 1);
        assert_eq!(registry.get_all_model_cards().len(), 1);
        assert_eq!(registry.get_all_event_channels().len(), 1);
    }
}