1use anyhow::Result;
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use super::{DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery};
9
10#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct DiscoveryMetadata {
14 endpoints: HashMap<String, DiscoveryInstance>,
16 model_cards: HashMap<String, DiscoveryInstance>,
18 event_channels: HashMap<String, DiscoveryInstance>,
20}
21
22impl DiscoveryMetadata {
23 pub fn new() -> Self {
25 Self {
26 endpoints: HashMap::new(),
27 model_cards: HashMap::new(),
28 event_channels: HashMap::new(),
29 }
30 }
31
32 pub fn register_endpoint(&mut self, instance: DiscoveryInstance) -> Result<()> {
34 match instance.id() {
35 DiscoveryInstanceId::Endpoint(key) => {
36 self.endpoints.insert(key.to_path(), instance);
37 Ok(())
38 }
39 DiscoveryInstanceId::Model(_) => {
40 anyhow::bail!("Cannot register non-endpoint instance as endpoint")
41 }
42 DiscoveryInstanceId::EventChannel(_) => {
43 anyhow::bail!("Cannot register EventChannel instance as endpoint")
44 }
45 }
46 }
47
48 pub fn register_model_card(&mut self, instance: DiscoveryInstance) -> Result<()> {
50 match instance.id() {
51 DiscoveryInstanceId::Model(key) => {
52 self.model_cards.insert(key.to_path(), instance);
53 Ok(())
54 }
55 DiscoveryInstanceId::Endpoint(_) => {
56 anyhow::bail!("Cannot register non-model-card instance as model card")
57 }
58 DiscoveryInstanceId::EventChannel(_) => {
59 anyhow::bail!("Cannot register EventChannel instance as model card")
60 }
61 }
62 }
63
64 pub fn unregister_endpoint(&mut self, instance: &DiscoveryInstance) -> Result<()> {
66 match instance.id() {
67 DiscoveryInstanceId::Endpoint(key) => {
68 self.endpoints.remove(&key.to_path());
69 Ok(())
70 }
71 DiscoveryInstanceId::Model(_) => {
72 anyhow::bail!("Cannot unregister non-endpoint instance as endpoint")
73 }
74 DiscoveryInstanceId::EventChannel(_) => {
75 anyhow::bail!("Cannot unregister EventChannel instance as endpoint")
76 }
77 }
78 }
79
80 pub fn unregister_model_card(&mut self, instance: &DiscoveryInstance) -> Result<()> {
82 match instance.id() {
83 DiscoveryInstanceId::Model(key) => {
84 self.model_cards.remove(&key.to_path());
85 Ok(())
86 }
87 DiscoveryInstanceId::Endpoint(_) => {
88 anyhow::bail!("Cannot unregister non-model-card instance as model card")
89 }
90 DiscoveryInstanceId::EventChannel(_) => {
91 anyhow::bail!("Cannot unregister EventChannel instance as model card")
92 }
93 }
94 }
95
96 pub fn register_event_channel(&mut self, instance: DiscoveryInstance) -> Result<()> {
98 match instance.id() {
99 DiscoveryInstanceId::EventChannel(key) => {
100 self.event_channels.insert(key.to_path(), instance);
101 Ok(())
102 }
103 DiscoveryInstanceId::Endpoint(_) => {
104 anyhow::bail!("Cannot register Endpoint instance as event channel")
105 }
106 DiscoveryInstanceId::Model(_) => {
107 anyhow::bail!("Cannot register Model instance as event channel")
108 }
109 }
110 }
111
112 pub fn unregister_event_channel(&mut self, instance: &DiscoveryInstance) -> Result<()> {
114 match instance.id() {
115 DiscoveryInstanceId::EventChannel(key) => {
116 self.event_channels.remove(&key.to_path());
117 Ok(())
118 }
119 DiscoveryInstanceId::Endpoint(_) => {
120 anyhow::bail!("Cannot unregister Endpoint instance as event channel")
121 }
122 DiscoveryInstanceId::Model(_) => {
123 anyhow::bail!("Cannot unregister Model instance as event channel")
124 }
125 }
126 }
127
128 pub fn get_all_endpoints(&self) -> Vec<DiscoveryInstance> {
130 self.endpoints.values().cloned().collect()
131 }
132
133 pub fn get_all_model_cards(&self) -> Vec<DiscoveryInstance> {
135 self.model_cards.values().cloned().collect()
136 }
137
138 pub fn get_all_event_channels(&self) -> Vec<DiscoveryInstance> {
140 self.event_channels.values().cloned().collect()
141 }
142
143 pub fn get_all(&self) -> Vec<DiscoveryInstance> {
145 self.endpoints
146 .values()
147 .chain(self.model_cards.values())
148 .chain(self.event_channels.values())
149 .cloned()
150 .collect()
151 }
152
153 pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
155 let all_instances = match query {
156 DiscoveryQuery::AllEndpoints
157 | DiscoveryQuery::NamespacedEndpoints { .. }
158 | DiscoveryQuery::ComponentEndpoints { .. }
159 | DiscoveryQuery::Endpoint { .. } => self.get_all_endpoints(),
160
161 DiscoveryQuery::AllModels
162 | DiscoveryQuery::NamespacedModels { .. }
163 | DiscoveryQuery::ComponentModels { .. }
164 | DiscoveryQuery::EndpointModels { .. } => self.get_all_model_cards(),
165
166 DiscoveryQuery::EventChannels(_) => self.get_all_event_channels(),
168 };
169
170 filter_instances(all_instances, query)
171 }
172}
173
174impl Default for DiscoveryMetadata {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
180fn filter_instances(
182 instances: Vec<DiscoveryInstance>,
183 query: &DiscoveryQuery,
184) -> Vec<DiscoveryInstance> {
185 match query {
186 DiscoveryQuery::AllEndpoints | DiscoveryQuery::AllModels => instances,
187
188 DiscoveryQuery::NamespacedEndpoints { namespace } => instances
189 .into_iter()
190 .filter(|inst| match inst {
191 DiscoveryInstance::Endpoint(i) => &i.namespace == namespace,
192 _ => false,
193 })
194 .collect(),
195
196 DiscoveryQuery::ComponentEndpoints {
197 namespace,
198 component,
199 } => instances
200 .into_iter()
201 .filter(|inst| match inst {
202 DiscoveryInstance::Endpoint(i) => {
203 &i.namespace == namespace && &i.component == component
204 }
205 _ => false,
206 })
207 .collect(),
208
209 DiscoveryQuery::Endpoint {
210 namespace,
211 component,
212 endpoint,
213 } => instances
214 .into_iter()
215 .filter(|inst| match inst {
216 DiscoveryInstance::Endpoint(i) => {
217 &i.namespace == namespace
218 && &i.component == component
219 && &i.endpoint == endpoint
220 }
221 _ => false,
222 })
223 .collect(),
224
225 DiscoveryQuery::NamespacedModels { namespace } => instances
226 .into_iter()
227 .filter(|inst| match inst {
228 DiscoveryInstance::Model { namespace: ns, .. } => ns == namespace,
229 _ => false,
230 })
231 .collect(),
232
233 DiscoveryQuery::ComponentModels {
234 namespace,
235 component,
236 } => instances
237 .into_iter()
238 .filter(|inst| match inst {
239 DiscoveryInstance::Model {
240 namespace: ns,
241 component: comp,
242 ..
243 } => ns == namespace && comp == component,
244 _ => false,
245 })
246 .collect(),
247
248 DiscoveryQuery::EndpointModels {
249 namespace,
250 component,
251 endpoint,
252 } => instances
253 .into_iter()
254 .filter(|inst| match inst {
255 DiscoveryInstance::Model {
256 namespace: ns,
257 component: comp,
258 endpoint: ep,
259 ..
260 } => ns == namespace && comp == component && ep == endpoint,
261 _ => false,
262 })
263 .collect(),
264
265 DiscoveryQuery::EventChannels(query) => instances
267 .into_iter()
268 .filter(|inst| match inst {
269 DiscoveryInstance::EventChannel {
270 namespace: ns,
271 component: comp,
272 topic: t,
273 ..
274 } => {
275 query.namespace.as_ref().is_none_or(|qns| qns == ns)
277 && query.component.as_ref().is_none_or(|qc| qc == comp)
279 && query.topic.as_ref().is_none_or(|qt| qt == t)
281 }
282 _ => false,
283 })
284 .collect(),
285 }
286}
287
288#[derive(Clone, Debug)]
290pub struct MetadataSnapshot {
291 pub instances: HashMap<u64, Arc<DiscoveryMetadata>>,
293 pub generations: HashMap<u64, i64>,
296 pub sequence: u64,
298 pub timestamp: std::time::Instant,
300}
301
302impl MetadataSnapshot {
303 pub fn empty() -> Self {
304 Self {
305 instances: HashMap::new(),
306 generations: HashMap::new(),
307 sequence: 0,
308 timestamp: std::time::Instant::now(),
309 }
310 }
311
312 pub fn has_changes_from(&self, prev: &MetadataSnapshot) -> bool {
316 if self.generations == prev.generations {
317 tracing::trace!(
318 "Snapshot (seq={}): no changes, {} instances",
319 self.sequence,
320 self.instances.len()
321 );
322 return false;
323 }
324
325 let curr_ids: HashSet<u64> = self.generations.keys().copied().collect();
327 let prev_ids: HashSet<u64> = prev.generations.keys().copied().collect();
328
329 let added: Vec<_> = curr_ids
330 .difference(&prev_ids)
331 .map(|id| format!("{:x}", id))
332 .collect();
333 let removed: Vec<_> = prev_ids
334 .difference(&curr_ids)
335 .map(|id| format!("{:x}", id))
336 .collect();
337 let updated: Vec<_> = self
338 .generations
339 .iter()
340 .filter(|(k, v)| prev.generations.get(*k).is_some_and(|pv| pv != *v))
341 .map(|(k, _)| format!("{:x}", k))
342 .collect();
343
344 tracing::info!(
345 "Snapshot (seq={}): {} instances, added={:?}, removed={:?}, updated={:?}",
346 self.sequence,
347 self.instances.len(),
348 added,
349 removed,
350 updated
351 );
352
353 true
354 }
355
356 pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
358 self.instances
359 .values()
360 .flat_map(|metadata| metadata.filter(query))
361 .collect()
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::component::{Instance, TransportType};
    use crate::discovery::EventChannelQuery;

    /// A registry with one endpoint survives a JSON round trip.
    #[test]
    fn test_metadata_serde() {
        let mut registry = DiscoveryMetadata::new();

        registry
            .register_endpoint(DiscoveryInstance::Endpoint(Instance {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                endpoint: "ep1".to_string(),
                instance_id: 123,
                transport: TransportType::Nats("nats://localhost:4222".to_string()),
            }))
            .unwrap();

        let encoded = serde_json::to_string(&registry).unwrap();
        let decoded: DiscoveryMetadata = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.endpoints.len(), 1);
        assert_eq!(decoded.model_cards.len(), 0);
    }

    /// Ten tasks each register a distinct endpoint through a shared lock.
    #[tokio::test]
    async fn test_concurrent_registration() {
        use tokio::sync::RwLock;

        let shared = Arc::new(RwLock::new(DiscoveryMetadata::new()));

        let mut tasks = Vec::new();
        for idx in 0..10 {
            let registry = Arc::clone(&shared);
            tasks.push(tokio::spawn(async move {
                let mut guard = registry.write().await;
                let endpoint = DiscoveryInstance::Endpoint(Instance {
                    namespace: "test".to_string(),
                    component: "comp1".to_string(),
                    endpoint: format!("ep{}", idx),
                    instance_id: idx,
                    transport: TransportType::Nats("nats://localhost:4222".to_string()),
                });
                guard.register_endpoint(endpoint).unwrap();
            }));
        }

        // Every task must finish without panicking before the count is checked.
        for task in tasks {
            task.await.unwrap();
        }

        let guard = shared.read().await;
        assert_eq!(guard.endpoints.len(), 10);
    }

    /// The per-kind accessors and `get_all` report the expected counts.
    #[tokio::test]
    async fn test_metadata_accessors() {
        let mut registry = DiscoveryMetadata::new();

        for idx in 0..3 {
            registry
                .register_endpoint(DiscoveryInstance::Endpoint(Instance {
                    namespace: "test".to_string(),
                    component: "comp1".to_string(),
                    endpoint: format!("ep{}", idx),
                    instance_id: idx,
                    transport: TransportType::Nats("nats://localhost:4222".to_string()),
                }))
                .unwrap();
        }

        for idx in 0..2 {
            registry
                .register_model_card(DiscoveryInstance::Model {
                    namespace: "test".to_string(),
                    component: "comp1".to_string(),
                    endpoint: format!("ep{}", idx),
                    instance_id: idx,
                    card_json: serde_json::json!({"model": "test"}),
                    model_suffix: None,
                })
                .unwrap();
        }

        assert_eq!(registry.get_all_endpoints().len(), 3);
        assert_eq!(registry.get_all_model_cards().len(), 2);
        assert_eq!(registry.get_all().len(), 5);
    }

    /// Event channels can be registered, queried and unregistered.
    #[tokio::test]
    async fn test_event_channel_registration() {
        use crate::discovery::EventTransport;

        let mut registry = DiscoveryMetadata::new();

        for idx in 0..3 {
            registry
                .register_event_channel(DiscoveryInstance::EventChannel {
                    namespace: "test".to_string(),
                    component: "comp1".to_string(),
                    topic: "test-topic".to_string(),
                    instance_id: idx,
                    transport: EventTransport::zmq(format!("tcp://localhost:{}", 5000 + idx)),
                })
                .unwrap();
        }

        assert_eq!(registry.get_all_event_channels().len(), 3);

        // Only event channels were registered, so they are the whole registry.
        assert_eq!(registry.get_all().len(), 3);

        let everything = registry.filter(&DiscoveryQuery::EventChannels(EventChannelQuery::all()));
        assert_eq!(everything.len(), 3);

        let same_component = registry.filter(&DiscoveryQuery::EventChannels(
            EventChannelQuery::component("test", "comp1"),
        ));
        assert_eq!(same_component.len(), 3);

        let other_namespace = registry.filter(&DiscoveryQuery::EventChannels(
            EventChannelQuery::component("other", "comp1"),
        ));
        assert_eq!(other_namespace.len(), 0);

        // Unregistering an instance equal to the first registration removes one entry.
        let first = DiscoveryInstance::EventChannel {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            topic: "test-topic".to_string(),
            instance_id: 0,
            transport: EventTransport::zmq("tcp://localhost:5000"),
        };
        registry.unregister_event_channel(&first).unwrap();
        assert_eq!(registry.get_all_event_channels().len(), 2);
    }

    /// One instance of each kind lands in its own map.
    #[tokio::test]
    async fn test_mixed_instances() {
        use crate::discovery::EventTransport;

        let mut registry = DiscoveryMetadata::new();

        registry
            .register_endpoint(DiscoveryInstance::Endpoint(Instance {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                endpoint: "ep1".to_string(),
                instance_id: 1,
                transport: TransportType::Nats("nats://localhost:4222".to_string()),
            }))
            .unwrap();

        registry
            .register_model_card(DiscoveryInstance::Model {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                endpoint: "ep1".to_string(),
                instance_id: 2,
                card_json: serde_json::json!({"model": "test"}),
                model_suffix: None,
            })
            .unwrap();

        registry
            .register_event_channel(DiscoveryInstance::EventChannel {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                topic: "test-topic".to_string(),
                instance_id: 3,
                transport: EventTransport::zmq("tcp://localhost:5000"),
            })
            .unwrap();

        assert_eq!(registry.get_all().len(), 3);
        assert_eq!(registry.get_all_endpoints().len(), 1);
        assert_eq!(registry.get_all_model_cards().len(), 1);
        assert_eq!(registry.get_all_event_channels().len(), 1);
    }
}