use anyhow::Result;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use super::{DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery};
/// Registry of discovery instances, held in three maps, one per instance kind.
///
/// Each map is keyed by the `to_path()` string of the instance's `DiscoveryInstanceId`
/// (see the `register_*` / `unregister_*` methods), so registering a second instance whose
/// id yields the same path replaces the first. The whole registry is serde-serializable.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DiscoveryMetadata {
    // Instances whose id is `DiscoveryInstanceId::Endpoint`, keyed by that id's path.
    endpoints: HashMap<String, DiscoveryInstance>,
    // Instances whose id is `DiscoveryInstanceId::Model`, keyed by that id's path.
    model_cards: HashMap<String, DiscoveryInstance>,
    // Instances whose id is `DiscoveryInstanceId::EventChannel`, keyed by that id's path.
    event_channels: HashMap<String, DiscoveryInstance>,
}
impl DiscoveryMetadata {
    /// Creates a registry with no endpoints, model cards, or event channels.
    pub fn new() -> Self {
        Self {
            endpoints: HashMap::default(),
            model_cards: HashMap::default(),
            event_channels: HashMap::default(),
        }
    }

    /// Stores an endpoint instance under the path of its endpoint id.
    ///
    /// Any instance already stored under the same path is replaced.
    ///
    /// # Errors
    ///
    /// Returns an error (and stores nothing) when `instance` is a model card or an event
    /// channel.
    pub fn register_endpoint(&mut self, instance: DiscoveryInstance) -> Result<()> {
        let path = match instance.id() {
            DiscoveryInstanceId::Endpoint(key) => key.to_path(),
            DiscoveryInstanceId::EventChannel(_) => {
                anyhow::bail!("Cannot register EventChannel instance as endpoint")
            }
            DiscoveryInstanceId::Model(_) => {
                anyhow::bail!("Cannot register non-endpoint instance as endpoint")
            }
        };
        self.endpoints.insert(path, instance);
        Ok(())
    }

    /// Removes the endpoint stored under the path of `instance`'s endpoint id.
    ///
    /// Removing a path that is not present is not an error.
    ///
    /// # Errors
    ///
    /// Returns an error when `instance` is a model card or an event channel.
    pub fn unregister_endpoint(&mut self, instance: &DiscoveryInstance) -> Result<()> {
        let path = match instance.id() {
            DiscoveryInstanceId::Endpoint(key) => key.to_path(),
            DiscoveryInstanceId::EventChannel(_) => {
                anyhow::bail!("Cannot unregister EventChannel instance as endpoint")
            }
            DiscoveryInstanceId::Model(_) => {
                anyhow::bail!("Cannot unregister non-endpoint instance as endpoint")
            }
        };
        self.endpoints.remove(&path);
        Ok(())
    }

    /// Stores a model-card instance under the path of its model id.
    ///
    /// Any instance already stored under the same path is replaced.
    ///
    /// # Errors
    ///
    /// Returns an error (and stores nothing) when `instance` is an endpoint or an event
    /// channel.
    pub fn register_model_card(&mut self, instance: DiscoveryInstance) -> Result<()> {
        let path = match instance.id() {
            DiscoveryInstanceId::Model(key) => key.to_path(),
            DiscoveryInstanceId::EventChannel(_) => {
                anyhow::bail!("Cannot register EventChannel instance as model card")
            }
            DiscoveryInstanceId::Endpoint(_) => {
                anyhow::bail!("Cannot register non-model-card instance as model card")
            }
        };
        self.model_cards.insert(path, instance);
        Ok(())
    }

    /// Removes the model card stored under the path of `instance`'s model id.
    ///
    /// Removing a path that is not present is not an error.
    ///
    /// # Errors
    ///
    /// Returns an error when `instance` is an endpoint or an event channel.
    pub fn unregister_model_card(&mut self, instance: &DiscoveryInstance) -> Result<()> {
        let path = match instance.id() {
            DiscoveryInstanceId::Model(key) => key.to_path(),
            DiscoveryInstanceId::EventChannel(_) => {
                anyhow::bail!("Cannot unregister EventChannel instance as model card")
            }
            DiscoveryInstanceId::Endpoint(_) => {
                anyhow::bail!("Cannot unregister non-model-card instance as model card")
            }
        };
        self.model_cards.remove(&path);
        Ok(())
    }

    /// Stores an event-channel instance under the path of its event-channel id.
    ///
    /// Any instance already stored under the same path is replaced.
    ///
    /// # Errors
    ///
    /// Returns an error (and stores nothing) when `instance` is an endpoint or a model card.
    pub fn register_event_channel(&mut self, instance: DiscoveryInstance) -> Result<()> {
        let path = match instance.id() {
            DiscoveryInstanceId::EventChannel(key) => key.to_path(),
            DiscoveryInstanceId::Model(_) => {
                anyhow::bail!("Cannot register Model instance as event channel")
            }
            DiscoveryInstanceId::Endpoint(_) => {
                anyhow::bail!("Cannot register Endpoint instance as event channel")
            }
        };
        self.event_channels.insert(path, instance);
        Ok(())
    }

    /// Removes the event channel stored under the path of `instance`'s event-channel id.
    ///
    /// Removing a path that is not present is not an error.
    ///
    /// # Errors
    ///
    /// Returns an error when `instance` is an endpoint or a model card.
    pub fn unregister_event_channel(&mut self, instance: &DiscoveryInstance) -> Result<()> {
        let path = match instance.id() {
            DiscoveryInstanceId::EventChannel(key) => key.to_path(),
            DiscoveryInstanceId::Model(_) => {
                anyhow::bail!("Cannot unregister Model instance as event channel")
            }
            DiscoveryInstanceId::Endpoint(_) => {
                anyhow::bail!("Cannot unregister Endpoint instance as event channel")
            }
        };
        self.event_channels.remove(&path);
        Ok(())
    }

    /// Returns clones of every registered endpoint, in map iteration order.
    pub fn get_all_endpoints(&self) -> Vec<DiscoveryInstance> {
        self.endpoints.values().map(DiscoveryInstance::clone).collect()
    }

    /// Returns clones of every registered model card, in map iteration order.
    pub fn get_all_model_cards(&self) -> Vec<DiscoveryInstance> {
        self.model_cards.values().map(DiscoveryInstance::clone).collect()
    }

    /// Returns clones of every registered event channel, in map iteration order.
    pub fn get_all_event_channels(&self) -> Vec<DiscoveryInstance> {
        self.event_channels.values().map(DiscoveryInstance::clone).collect()
    }

    /// Returns clones of every registered instance: endpoints first, then model cards,
    /// then event channels.
    pub fn get_all(&self) -> Vec<DiscoveryInstance> {
        let total = self.endpoints.len() + self.model_cards.len() + self.event_channels.len();
        let mut everything = Vec::with_capacity(total);
        everything.extend(self.endpoints.values().cloned());
        everything.extend(self.model_cards.values().cloned());
        everything.extend(self.event_channels.values().cloned());
        everything
    }

    /// Returns clones of the registered instances that satisfy `query`.
    ///
    /// Only the map matching the query's kind is consulted: endpoint queries look at
    /// endpoints, model queries at model cards, and event-channel queries at event channels.
    pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
        let candidates = match query {
            DiscoveryQuery::EventChannels(_) => self.get_all_event_channels(),
            DiscoveryQuery::AllModels
            | DiscoveryQuery::NamespacedModels { .. }
            | DiscoveryQuery::ComponentModels { .. }
            | DiscoveryQuery::EndpointModels { .. } => self.get_all_model_cards(),
            DiscoveryQuery::AllEndpoints
            | DiscoveryQuery::NamespacedEndpoints { .. }
            | DiscoveryQuery::ComponentEndpoints { .. }
            | DiscoveryQuery::Endpoint { .. } => self.get_all_endpoints(),
        };
        filter_instances(candidates, query)
    }
}
impl Default for DiscoveryMetadata {
fn default() -> Self {
Self::new()
}
}
/// Keeps only the instances in `instances` that satisfy `query`, preserving their order.
///
/// Every query variant selects exactly one instance kind (endpoint, model card, or event
/// channel). Instances of any other kind are always dropped, including for the unscoped
/// `AllEndpoints` / `AllModels` queries. Namespace, component, endpoint, and topic fields
/// must match exactly. For event-channel queries, a `None` field acts as a wildcard.
fn filter_instances(
    mut instances: Vec<DiscoveryInstance>,
    query: &DiscoveryQuery,
) -> Vec<DiscoveryInstance> {
    // Filter in place so the caller's allocation is reused.
    instances.retain(|inst| instance_matches_query(inst, query));
    instances
}

/// Returns `true` when the single instance `inst` satisfies `query`.
fn instance_matches_query(inst: &DiscoveryInstance, query: &DiscoveryQuery) -> bool {
    match query {
        // Unscoped queries still only accept their own instance kind.
        DiscoveryQuery::AllEndpoints => matches!(inst, DiscoveryInstance::Endpoint(_)),
        DiscoveryQuery::NamespacedEndpoints { namespace } => matches!(
            inst,
            DiscoveryInstance::Endpoint(i) if &i.namespace == namespace
        ),
        DiscoveryQuery::ComponentEndpoints {
            namespace,
            component,
        } => matches!(
            inst,
            DiscoveryInstance::Endpoint(i)
                if &i.namespace == namespace && &i.component == component
        ),
        DiscoveryQuery::Endpoint {
            namespace,
            component,
            endpoint,
        } => matches!(
            inst,
            DiscoveryInstance::Endpoint(i)
                if &i.namespace == namespace
                    && &i.component == component
                    && &i.endpoint == endpoint
        ),
        DiscoveryQuery::AllModels => matches!(inst, DiscoveryInstance::Model { .. }),
        DiscoveryQuery::NamespacedModels { namespace } => matches!(
            inst,
            DiscoveryInstance::Model { namespace: ns, .. } if ns == namespace
        ),
        DiscoveryQuery::ComponentModels {
            namespace,
            component,
        } => matches!(
            inst,
            DiscoveryInstance::Model {
                namespace: ns,
                component: comp,
                ..
            } if ns == namespace && comp == component
        ),
        DiscoveryQuery::EndpointModels {
            namespace,
            component,
            endpoint,
        } => matches!(
            inst,
            DiscoveryInstance::Model {
                namespace: ns,
                component: comp,
                endpoint: ep,
                ..
            } if ns == namespace && comp == component && ep == endpoint
        ),
        // `None` in any query field matches every value of that field.
        DiscoveryQuery::EventChannels(q) => matches!(
            inst,
            DiscoveryInstance::EventChannel {
                namespace: ns,
                component: comp,
                topic: t,
                ..
            } if q.namespace.as_ref().is_none_or(|qns| qns == ns)
                && q.component.as_ref().is_none_or(|qc| qc == comp)
                && q.topic.as_ref().is_none_or(|qt| qt == t)
        ),
    }
}
/// Snapshot of discovery metadata collected under `u64` ids, with per-id generation
/// numbers used for change detection (see `has_changes_from`).
#[derive(Clone, Debug)]
pub struct MetadataSnapshot {
    /// Discovery metadata per `u64` id. Presumably these are the same ids as
    /// `generations` — confirm against the code that builds snapshots.
    pub instances: HashMap<u64, Arc<DiscoveryMetadata>>,
    /// Generation number per id. `has_changes_from` compares only this map to detect
    /// added, removed, or updated ids.
    pub generations: HashMap<u64, i64>,
    /// Sequence number of this snapshot, reported as `seq` in change logs.
    pub sequence: u64,
    /// Time associated with the snapshot (`empty()` uses `Instant::now()`).
    pub timestamp: std::time::Instant,
}
impl MetadataSnapshot {
    /// Builds a snapshot with no instances, no generations, and `sequence` 0, stamped with
    /// the current time.
    pub fn empty() -> Self {
        Self {
            instances: HashMap::default(),
            generations: HashMap::default(),
            sequence: 0,
            timestamp: std::time::Instant::now(),
        }
    }

    /// Reports whether this snapshot differs from `prev`, judged solely by the generation
    /// maps.
    ///
    /// - Unchanged: a `trace` event is emitted and `false` is returned.
    /// - Changed: an `info` event lists, as hex strings, the ids that were added, removed,
    ///   or whose generation changed, and `true` is returned.
    pub fn has_changes_from(&self, prev: &MetadataSnapshot) -> bool {
        if self.generations != prev.generations {
            let to_hex = |id: &u64| format!("{:x}", id);
            let current_ids: HashSet<u64> = self.generations.keys().copied().collect();
            let previous_ids: HashSet<u64> = prev.generations.keys().copied().collect();

            let added: Vec<String> = current_ids.difference(&previous_ids).map(to_hex).collect();
            let removed: Vec<String> =
                previous_ids.difference(&current_ids).map(to_hex).collect();
            // Ids present in both snapshots whose generation number moved.
            let updated: Vec<String> = self
                .generations
                .iter()
                .filter_map(|(id, generation)| match prev.generations.get(id) {
                    Some(old) if old != generation => Some(to_hex(id)),
                    _ => None,
                })
                .collect();

            tracing::info!(
                "Snapshot (seq={}): {} instances, added={:?}, removed={:?}, updated={:?}",
                self.sequence,
                self.instances.len(),
                added,
                removed,
                updated
            );
            return true;
        }

        tracing::trace!(
            "Snapshot (seq={}): no changes, {} instances",
            self.sequence,
            self.instances.len()
        );
        false
    }

    /// Applies `query` to every metadata entry in the snapshot and concatenates the results
    /// (entries are visited in map iteration order).
    pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
        let mut matched = Vec::new();
        for metadata in self.instances.values() {
            matched.extend(metadata.filter(query));
        }
        matched
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::component::{Instance, TransportType};
    use crate::discovery::{EventChannelQuery, EventTransport};

    /// A registry with one endpoint survives a JSON round trip.
    #[test]
    fn test_metadata_serde() {
        let mut metadata = DiscoveryMetadata::new();
        let instance = DiscoveryInstance::Endpoint(Instance {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 123,
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
        });
        metadata.register_endpoint(instance).unwrap();
        let json = serde_json::to_string(&metadata).unwrap();
        let deserialized: DiscoveryMetadata = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.endpoints.len(), 1);
        assert_eq!(deserialized.model_cards.len(), 0);
        assert_eq!(deserialized.event_channels.len(), 0);
    }

    /// Registrations from several tasks sharing one lock all land in the registry.
    #[tokio::test]
    async fn test_concurrent_registration() {
        use tokio::sync::RwLock;
        let metadata = Arc::new(RwLock::new(DiscoveryMetadata::new()));
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let metadata = metadata.clone();
                tokio::spawn(async move {
                    let mut meta = metadata.write().await;
                    let instance = DiscoveryInstance::Endpoint(Instance {
                        namespace: "test".to_string(),
                        component: "comp1".to_string(),
                        endpoint: format!("ep{}", i),
                        instance_id: i,
                        transport: TransportType::Nats("nats://localhost:4222".to_string()),
                    });
                    meta.register_endpoint(instance).unwrap();
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }
        let meta = metadata.read().await;
        assert_eq!(meta.endpoints.len(), 10);
    }

    // The tests below never await, so they run as plain synchronous tests.

    #[test]
    fn test_metadata_accessors() {
        let mut metadata = DiscoveryMetadata::new();
        for i in 0..3 {
            let instance = DiscoveryInstance::Endpoint(Instance {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                endpoint: format!("ep{}", i),
                instance_id: i,
                transport: TransportType::Nats("nats://localhost:4222".to_string()),
            });
            metadata.register_endpoint(instance).unwrap();
        }
        for i in 0..2 {
            let instance = DiscoveryInstance::Model {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                endpoint: format!("ep{}", i),
                instance_id: i,
                card_json: serde_json::json!({"model": "test"}),
                model_suffix: None,
            };
            metadata.register_model_card(instance).unwrap();
        }
        assert_eq!(metadata.get_all_endpoints().len(), 3);
        assert_eq!(metadata.get_all_model_cards().len(), 2);
        assert_eq!(metadata.get_all().len(), 5);
    }

    #[test]
    fn test_event_channel_registration() {
        let mut metadata = DiscoveryMetadata::new();
        for i in 0..3 {
            let instance = DiscoveryInstance::EventChannel {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                topic: "test-topic".to_string(),
                instance_id: i,
                transport: EventTransport::zmq(format!("tcp://localhost:{}", 5000 + i)),
            };
            metadata.register_event_channel(instance).unwrap();
        }
        assert_eq!(metadata.get_all_event_channels().len(), 3);
        assert_eq!(metadata.get_all().len(), 3);
        let filtered = metadata.filter(&DiscoveryQuery::EventChannels(EventChannelQuery::all()));
        assert_eq!(filtered.len(), 3);
        let filtered = metadata.filter(&DiscoveryQuery::EventChannels(
            EventChannelQuery::component("test", "comp1"),
        ));
        assert_eq!(filtered.len(), 3);
        let filtered = metadata.filter(&DiscoveryQuery::EventChannels(
            EventChannelQuery::component("other", "comp1"),
        ));
        assert_eq!(filtered.len(), 0);
        let instance = DiscoveryInstance::EventChannel {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            topic: "test-topic".to_string(),
            instance_id: 0,
            transport: EventTransport::zmq("tcp://localhost:5000"),
        };
        metadata.unregister_event_channel(&instance).unwrap();
        assert_eq!(metadata.get_all_event_channels().len(), 2);
    }

    #[test]
    fn test_mixed_instances() {
        let mut metadata = DiscoveryMetadata::new();
        let endpoint = DiscoveryInstance::Endpoint(Instance {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 1,
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
        });
        metadata.register_endpoint(endpoint).unwrap();
        let model = DiscoveryInstance::Model {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 2,
            card_json: serde_json::json!({"model": "test"}),
            model_suffix: None,
        };
        metadata.register_model_card(model).unwrap();
        let event_channel = DiscoveryInstance::EventChannel {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            topic: "test-topic".to_string(),
            instance_id: 3,
            transport: EventTransport::zmq("tcp://localhost:5000"),
        };
        metadata.register_event_channel(event_channel).unwrap();
        assert_eq!(metadata.get_all().len(), 3);
        assert_eq!(metadata.get_all_endpoints().len(), 1);
        assert_eq!(metadata.get_all_model_cards().len(), 1);
        assert_eq!(metadata.get_all_event_channels().len(), 1);
    }

    /// Each register/unregister method rejects instances of the wrong kind and leaves the
    /// registry untouched.
    #[test]
    fn test_register_rejects_mismatched_kind() {
        let mut metadata = DiscoveryMetadata::new();
        let endpoint = DiscoveryInstance::Endpoint(Instance {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 1,
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
        });
        let model = DiscoveryInstance::Model {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 2,
            card_json: serde_json::json!({"model": "test"}),
            model_suffix: None,
        };
        let event_channel = DiscoveryInstance::EventChannel {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            topic: "test-topic".to_string(),
            instance_id: 3,
            transport: EventTransport::zmq("tcp://localhost:5000"),
        };

        assert!(metadata.register_model_card(endpoint.clone()).is_err());
        assert!(metadata.register_event_channel(endpoint.clone()).is_err());
        assert!(metadata.register_endpoint(model.clone()).is_err());
        assert!(metadata.register_event_channel(model.clone()).is_err());
        assert!(metadata.register_endpoint(event_channel.clone()).is_err());
        assert!(metadata.register_model_card(event_channel.clone()).is_err());

        assert!(metadata.unregister_model_card(&endpoint).is_err());
        assert!(metadata.unregister_event_channel(&endpoint).is_err());
        assert!(metadata.unregister_endpoint(&model).is_err());
        assert!(metadata.unregister_event_channel(&model).is_err());
        assert!(metadata.unregister_endpoint(&event_channel).is_err());
        assert!(metadata.unregister_model_card(&event_channel).is_err());

        assert!(metadata.get_all().is_empty());
    }

    /// Endpoint and model queries narrow results by namespace, component, and endpoint.
    #[test]
    fn test_filter_endpoint_and_model_queries() {
        let mut metadata = DiscoveryMetadata::new();
        for (instance_id, component, endpoint) in
            [(1, "comp1", "ep1"), (2, "comp1", "ep2"), (3, "comp2", "ep1")]
        {
            let instance = DiscoveryInstance::Endpoint(Instance {
                namespace: "test".to_string(),
                component: component.to_string(),
                endpoint: endpoint.to_string(),
                instance_id,
                transport: TransportType::Nats("nats://localhost:4222".to_string()),
            });
            metadata.register_endpoint(instance).unwrap();
        }
        for (instance_id, component, endpoint) in [(4, "comp1", "ep1"), (5, "comp2", "ep1")] {
            let instance = DiscoveryInstance::Model {
                namespace: "test".to_string(),
                component: component.to_string(),
                endpoint: endpoint.to_string(),
                instance_id,
                card_json: serde_json::json!({"model": "test"}),
                model_suffix: None,
            };
            metadata.register_model_card(instance).unwrap();
        }

        let count = |query: DiscoveryQuery| metadata.filter(&query).len();

        assert_eq!(count(DiscoveryQuery::AllEndpoints), 3);
        assert_eq!(
            count(DiscoveryQuery::NamespacedEndpoints {
                namespace: "test".to_string(),
            }),
            3
        );
        assert_eq!(
            count(DiscoveryQuery::NamespacedEndpoints {
                namespace: "other".to_string(),
            }),
            0
        );
        assert_eq!(
            count(DiscoveryQuery::ComponentEndpoints {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
            }),
            2
        );
        assert_eq!(
            count(DiscoveryQuery::Endpoint {
                namespace: "test".to_string(),
                component: "comp2".to_string(),
                endpoint: "ep1".to_string(),
            }),
            1
        );

        assert_eq!(count(DiscoveryQuery::AllModels), 2);
        assert_eq!(
            count(DiscoveryQuery::NamespacedModels {
                namespace: "other".to_string(),
            }),
            0
        );
        assert_eq!(
            count(DiscoveryQuery::ComponentModels {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
            }),
            1
        );
        assert_eq!(
            count(DiscoveryQuery::EndpointModels {
                namespace: "test".to_string(),
                component: "comp2".to_string(),
                endpoint: "ep1".to_string(),
            }),
            1
        );
    }

    /// `has_changes_from` reacts only to generation changes; `filter` spans all entries.
    #[test]
    fn test_snapshot_change_detection() {
        let mut metadata = DiscoveryMetadata::new();
        let instance = DiscoveryInstance::Endpoint(Instance {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 1,
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
        });
        metadata.register_endpoint(instance).unwrap();

        let base = MetadataSnapshot {
            instances: HashMap::from([(1, Arc::new(metadata))]),
            generations: HashMap::from([(1, 1)]),
            sequence: 1,
            timestamp: std::time::Instant::now(),
        };

        let same = MetadataSnapshot {
            sequence: 2,
            ..base.clone()
        };
        assert!(!same.has_changes_from(&base));

        let mut bumped = base.clone();
        bumped.generations.insert(1, 2);
        assert!(bumped.has_changes_from(&base));

        let mut added = base.clone();
        added.generations.insert(2, 1);
        assert!(added.has_changes_from(&base));

        assert!(MetadataSnapshot::empty().has_changes_from(&base));

        assert_eq!(base.filter(&DiscoveryQuery::AllEndpoints).len(), 1);
        assert!(MetadataSnapshot::empty()
            .filter(&DiscoveryQuery::AllEndpoints)
            .is_empty());
    }
}