use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use super::{DiscoveryInstance, DiscoveryQuery};
/// Builds the map key for an endpoint by joining its three path segments
/// with `/`, yielding `namespace/component/endpoint`.
///
/// The segments are used verbatim: no escaping or validation is applied.
fn make_endpoint_key(namespace: &str, component: &str, endpoint: &str) -> String {
    let segments = [namespace, component, endpoint];
    segments.join("/")
}
/// Registry of discovery instances, split into endpoints and model cards.
///
/// Entries added through `register_endpoint` / `register_model_card` are keyed
/// by the `namespace/component/endpoint` string produced by
/// `make_endpoint_key`, so each map holds at most one entry per such key.
/// The type is serde-serializable; the map keys are part of the serialized form.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DiscoveryMetadata {
/// Endpoint instances, keyed by `namespace/component/endpoint` on registration.
endpoints: HashMap<String, DiscoveryInstance>,
/// Model-card instances, keyed by `namespace/component/endpoint` on registration.
model_cards: HashMap<String, DiscoveryInstance>,
}
impl DiscoveryMetadata {
    /// Creates an empty registry with no endpoints and no model cards.
    pub fn new() -> Self {
        Self {
            endpoints: HashMap::new(),
            model_cards: HashMap::new(),
        }
    }

    /// Registers an endpoint instance under its `namespace/component/endpoint` key.
    ///
    /// An existing entry stored under the same key is replaced.
    ///
    /// # Errors
    ///
    /// Returns an error if `instance` is not a [`DiscoveryInstance::Endpoint`].
    pub fn register_endpoint(&mut self, instance: DiscoveryInstance) -> Result<()> {
        // Compute the key from a borrow first so `instance` can be moved into the map.
        let key = match &instance {
            DiscoveryInstance::Endpoint(inst) => {
                make_endpoint_key(&inst.namespace, &inst.component, &inst.endpoint)
            }
            _ => anyhow::bail!("Cannot register non-endpoint instance as endpoint"),
        };
        self.endpoints.insert(key, instance);
        Ok(())
    }

    /// Registers a model card under its `namespace/component/endpoint` key.
    ///
    /// An existing entry stored under the same key is replaced.
    ///
    /// # Errors
    ///
    /// Returns an error if `instance` is not a [`DiscoveryInstance::Model`].
    pub fn register_model_card(&mut self, instance: DiscoveryInstance) -> Result<()> {
        // Compute the key from a borrow first so `instance` can be moved into the map.
        let key = match &instance {
            DiscoveryInstance::Model {
                namespace,
                component,
                endpoint,
                ..
            } => make_endpoint_key(namespace, component, endpoint),
            _ => anyhow::bail!("Cannot register non-model-card instance as model card"),
        };
        self.model_cards.insert(key, instance);
        Ok(())
    }

    /// Returns a clone of every registered endpoint, in the map's iteration order.
    pub fn get_all_endpoints(&self) -> Vec<DiscoveryInstance> {
        self.endpoints.values().cloned().collect()
    }

    /// Returns a clone of every registered model card, in the map's iteration order.
    pub fn get_all_model_cards(&self) -> Vec<DiscoveryInstance> {
        self.model_cards.values().cloned().collect()
    }

    /// Returns clones of all registered instances: endpoints first, then model cards.
    pub fn get_all(&self) -> Vec<DiscoveryInstance> {
        self.endpoints
            .values()
            .chain(self.model_cards.values())
            .cloned()
            .collect()
    }

    /// Returns clones of the registered instances that satisfy `query`.
    ///
    /// Endpoint queries are answered from the endpoint map and model queries
    /// from the model-card map. Entries are tested while still borrowed, so
    /// only the matching instances are cloned.
    pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
        let candidates = match query {
            DiscoveryQuery::AllEndpoints
            | DiscoveryQuery::NamespacedEndpoints { .. }
            | DiscoveryQuery::ComponentEndpoints { .. }
            | DiscoveryQuery::Endpoint { .. } => &self.endpoints,
            DiscoveryQuery::AllModels
            | DiscoveryQuery::NamespacedModels { .. }
            | DiscoveryQuery::ComponentModels { .. }
            | DiscoveryQuery::EndpointModels { .. } => &self.model_cards,
        };
        candidates
            .values()
            .filter(|instance| instance_matches(instance, query))
            .cloned()
            .collect()
    }
}

impl Default for DiscoveryMetadata {
    /// Equivalent to [`DiscoveryMetadata::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Reports whether `instance` satisfies `query`.
///
/// The `All*` queries accept every instance. Each scoped query accepts only
/// instances of the corresponding kind (endpoint or model) whose namespace,
/// component and endpoint fields equal the ones named in the query.
fn instance_matches(instance: &DiscoveryInstance, query: &DiscoveryQuery) -> bool {
    match query {
        DiscoveryQuery::AllEndpoints | DiscoveryQuery::AllModels => true,
        DiscoveryQuery::NamespacedEndpoints { namespace } => matches!(
            instance,
            DiscoveryInstance::Endpoint(i) if &i.namespace == namespace
        ),
        DiscoveryQuery::ComponentEndpoints {
            namespace,
            component,
        } => matches!(
            instance,
            DiscoveryInstance::Endpoint(i)
                if &i.namespace == namespace && &i.component == component
        ),
        DiscoveryQuery::Endpoint {
            namespace,
            component,
            endpoint,
        } => matches!(
            instance,
            DiscoveryInstance::Endpoint(i)
                if &i.namespace == namespace
                    && &i.component == component
                    && &i.endpoint == endpoint
        ),
        DiscoveryQuery::NamespacedModels { namespace } => matches!(
            instance,
            DiscoveryInstance::Model { namespace: ns, .. } if ns == namespace
        ),
        DiscoveryQuery::ComponentModels {
            namespace,
            component,
        } => matches!(
            instance,
            DiscoveryInstance::Model {
                namespace: ns,
                component: comp,
                ..
            } if ns == namespace && comp == component
        ),
        DiscoveryQuery::EndpointModels {
            namespace,
            component,
            endpoint,
        } => matches!(
            instance,
            DiscoveryInstance::Model {
                namespace: ns,
                component: comp,
                endpoint: ep,
                ..
            } if ns == namespace && comp == component && ep == endpoint
        ),
    }
}
/// A collection of shared `DiscoveryMetadata` records together with a
/// sequence number and the `Instant` recorded for the snapshot.
///
/// Cloning is shallow for the metadata itself: each record sits behind an `Arc`.
#[derive(Clone, Debug)]
pub struct MetadataSnapshot {
/// Metadata records by `u64` key.
/// NOTE(review): the key is presumably the owning instance id — confirm with the producer.
pub instances: HashMap<u64, Arc<DiscoveryMetadata>>,
/// Sequence number of this snapshot; `empty()` starts it at 0.
/// NOTE(review): presumably incremented per published snapshot — confirm.
pub sequence: u64,
/// Monotonic timestamp; `empty()` sets it to the moment of construction.
pub timestamp: std::time::Instant,
}
impl MetadataSnapshot {
    /// Builds a snapshot holding no metadata, with sequence `0` and the
    /// current instant as its timestamp.
    pub fn empty() -> Self {
        let created_at = std::time::Instant::now();
        Self {
            instances: HashMap::new(),
            sequence: 0,
            timestamp: created_at,
        }
    }

    /// Collects the instances matching `query` from every metadata record in
    /// the snapshot, concatenated in the map's iteration order.
    pub fn filter(&self, query: &DiscoveryQuery) -> Vec<DiscoveryInstance> {
        let mut matched = Vec::new();
        for metadata in self.instances.values() {
            matched.extend(metadata.filter(query));
        }
        matched
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::component::{Instance, TransportType};

    /// Transport value shared by every endpoint fixture in this module.
    fn nats_transport() -> TransportType {
        TransportType::Nats("nats://localhost:4222".to_string())
    }

    /// A registry holding one endpoint must survive a JSON round trip.
    #[test]
    fn test_metadata_serde() {
        let endpoint = DiscoveryInstance::Endpoint(Instance {
            namespace: "test".to_string(),
            component: "comp1".to_string(),
            endpoint: "ep1".to_string(),
            instance_id: 123,
            transport: nats_transport(),
        });
        let mut original = DiscoveryMetadata::new();
        original.register_endpoint(endpoint).unwrap();

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: DiscoveryMetadata = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.endpoints.len(), 1);
        assert_eq!(decoded.model_cards.len(), 0);
    }

    /// Ten tasks registering distinct endpoints through a shared lock must
    /// leave ten entries behind.
    #[tokio::test]
    async fn test_concurrent_registration() {
        use tokio::sync::RwLock;

        let shared = Arc::new(RwLock::new(DiscoveryMetadata::new()));
        let mut tasks = Vec::with_capacity(10);
        for idx in 0..10 {
            let shared = Arc::clone(&shared);
            tasks.push(tokio::spawn(async move {
                let mut guard = shared.write().await;
                let endpoint = DiscoveryInstance::Endpoint(Instance {
                    namespace: "test".to_string(),
                    component: "comp1".to_string(),
                    endpoint: format!("ep{}", idx),
                    instance_id: idx,
                    transport: nats_transport(),
                });
                guard.register_endpoint(endpoint).unwrap();
            }));
        }
        for task in tasks {
            task.await.unwrap();
        }

        let registry = shared.read().await;
        assert_eq!(registry.endpoints.len(), 10);
    }

    /// The accessors must report endpoints, model cards, and their union.
    #[tokio::test]
    async fn test_metadata_accessors() {
        let mut registry = DiscoveryMetadata::new();

        (0..3)
            .map(|idx| {
                DiscoveryInstance::Endpoint(Instance {
                    namespace: "test".to_string(),
                    component: "comp1".to_string(),
                    endpoint: format!("ep{}", idx),
                    instance_id: idx,
                    transport: nats_transport(),
                })
            })
            .for_each(|endpoint| registry.register_endpoint(endpoint).unwrap());

        (0..2)
            .map(|idx| DiscoveryInstance::Model {
                namespace: "test".to_string(),
                component: "comp1".to_string(),
                endpoint: format!("ep{}", idx),
                instance_id: idx,
                card_json: serde_json::json!({"model": "test"}),
            })
            .for_each(|card| registry.register_model_card(card).unwrap());

        assert_eq!(registry.get_all_endpoints().len(), 3);
        assert_eq!(registry.get_all_model_cards().len(), 2);
        assert_eq!(registry.get_all().len(), 5);
    }
}