use super::*;
use crate::SystemHealth;
use crate::config::HealthStatus;
use crate::pipeline::network::ingress::push_endpoint::PushEndpoint;
use anyhow::Result;
use async_trait::async_trait;
use dashmap::DashMap;
use parking_lot::Mutex;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
/// Request-plane server that exposes endpoints as NATS service endpoints.
///
/// Each registered endpoint is served by a `PushEndpoint` listener running on its own spawned
/// tokio task; `handlers` tracks the cancellation token for each of those listeners.
pub struct NatsMultiplexedServer {
    /// NATS client connection.
    /// NOTE(review): not read anywhere in this file — confirm it is still needed.
    nats_client: async_nats::Client,
    /// Registry used to look up the NATS service group for a `namespace_component` service.
    component_registry: crate::component::Registry,
    /// Currently registered endpoint listeners, keyed by endpoint name.
    handlers: Arc<DashMap<String, EndpointTask>>,
    /// Server-wide cancellation token supplied to `new`.
    cancellation_token: CancellationToken,
}
/// Bookkeeping for a single registered endpoint's background listener task.
struct EndpointTask {
    /// Token handed to the endpoint's `PushEndpoint` as its cancellation token;
    /// it is cancelled when the endpoint is unregistered.
    cancel_token: CancellationToken,
    /// Endpoint name, retained for diagnostics; never read (hence the leading underscore).
    _endpoint_name: String,
}
impl NatsMultiplexedServer {
    /// Builds a new server with an empty endpoint table and returns it behind an `Arc`.
    ///
    /// * `nats_client` – connection used by the server.
    /// * `component_registry` – registry consulted when endpoints are registered.
    /// * `cancellation_token` – server-wide cancellation token.
    pub fn new(
        nats_client: async_nats::Client,
        component_registry: crate::component::Registry,
        cancellation_token: CancellationToken,
    ) -> Arc<Self> {
        // No endpoints are registered yet; they are added by `register_endpoint`.
        let handlers = Arc::new(DashMap::new());
        let server = Self {
            nats_client,
            component_registry,
            handlers,
            cancellation_token,
        };
        Arc::new(server)
    }
}
#[async_trait]
impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer {
    /// Registers `endpoint_name` as a NATS service endpoint and starts serving it.
    ///
    /// The NATS service group is looked up in the component registry under the slugified
    /// `"{namespace}_{component_name}"` name. The endpoint is created in that group, suffixed
    /// with the hex `instance_id` unless the registry is static. A `PushEndpoint` listener is
    /// then spawned on a background tokio task, so this call does not wait for the listener
    /// to finish.
    ///
    /// The listener's cancellation token is a child of the server's token. Cancelling the
    /// server therefore stops every endpoint, while `unregister_endpoint` stops just this one.
    /// If `endpoint_name` is already registered, the previous listener is cancelled and
    /// replaced, so it is not left running untracked.
    ///
    /// # Errors
    /// Returns an error if the service is not in the registry, if the NATS endpoint cannot be
    /// created, or if the push endpoint cannot be built. Errors returned by the listener after
    /// it has been spawned are only logged.
    async fn register_endpoint(
        &self,
        endpoint_name: String,
        service_handler: Arc<dyn PushWorkHandler>,
        instance_id: u64,
        namespace: String,
        component_name: String,
        system_health: Arc<Mutex<SystemHealth>>,
    ) -> Result<()> {
        tracing::info!(
            endpoint_name = %endpoint_name,
            namespace = %namespace,
            component = %component_name,
            instance_id = instance_id,
            "NatsMultiplexedServer::register_endpoint called"
        );
        use crate::transports::nats::Slug;
        let service_name_raw = format!("{}_{}", namespace, component_name);
        let service_name = Slug::slugify(&service_name_raw).to_string();
        tracing::debug!(
            service_name_raw = %service_name_raw,
            service_name = %service_name,
            "Looking up service group in registry"
        );
        let registry = self.component_registry.inner.lock().await;
        let service_group = registry
            .services
            .get(&service_name)
            .map(|service| service.group(&service_name))
            .ok_or_else(|| anyhow::anyhow!("Service '{}' not found in registry", service_name))?;
        // Release the registry lock before awaiting on NATS below.
        drop(registry);
        tracing::info!("Successfully retrieved service group");

        // Dynamic registries share a service group across instances, so the endpoint subject is
        // disambiguated with the instance id; static registries use the bare name.
        let endpoint_with_id = if self.component_registry.is_static() {
            endpoint_name.clone()
        } else {
            format!("{}-{:x}", endpoint_name, instance_id)
        };
        let service_endpoint = service_group
            .endpoint(&endpoint_with_id)
            .await
            .map_err(|e| {
                anyhow::anyhow!(
                    "Failed to create NATS endpoint '{}': {}",
                    endpoint_with_id,
                    e
                )
            })?;
        tracing::info!(
            endpoint_name = %endpoint_name,
            endpoint_with_id = %endpoint_with_id,
            namespace = %namespace,
            component = %component_name,
            instance_id = instance_id,
            "Registering NATS endpoint"
        );

        // Derive the endpoint token from the server token so server-wide cancellation reaches
        // this listener too; cancelling the child alone does not affect the server.
        let endpoint_cancel = self.cancellation_token.child_token();
        let push_endpoint = PushEndpoint::builder()
            .service_handler(service_handler)
            .cancellation_token(endpoint_cancel.clone())
            .graceful_shutdown(true)
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to build NATS push endpoint: {}", e))?;
        tracing::info!(
            endpoint_name = %endpoint_name,
            endpoint_with_id = %endpoint_with_id,
            "Starting NATS push endpoint listener (blocking)"
        );

        // The listener runs until it is cancelled or fails; its outcome is only logged because
        // nothing awaits the JoinHandle.
        let endpoint_name_clone = endpoint_name.clone();
        tokio::spawn(async move {
            if let Err(e) = push_endpoint
                .start(
                    service_endpoint,
                    namespace,
                    component_name,
                    endpoint_name_clone.clone(),
                    instance_id,
                    system_health,
                )
                .await
            {
                tracing::error!(
                    endpoint_name = %endpoint_name_clone,
                    error = %e,
                    "NATS endpoint task failed"
                );
            } else {
                tracing::info!(
                    endpoint_name = %endpoint_name_clone,
                    "NATS push endpoint listener completed"
                );
            }
        });

        // Give the spawned listener a brief head start before reporting success.
        // NOTE(review): this is a timing heuristic, not a readiness guarantee.
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        // `insert` hands back any task previously registered under this name. Cancel it;
        // otherwise it would keep running with no remaining handle to stop it.
        let previous = self.handlers.insert(
            endpoint_name.clone(),
            EndpointTask {
                cancel_token: endpoint_cancel,
                _endpoint_name: endpoint_name.clone(),
            },
        );
        if let Some(previous) = previous {
            tracing::warn!(
                endpoint_name = %endpoint_name,
                "Endpoint was already registered; cancelling the previous NATS listener"
            );
            previous.cancel_token.cancel();
        }
        Ok(())
    }

    /// Stops the listener registered under `endpoint_name` by cancelling its token.
    ///
    /// Unknown names are ignored, so this always returns `Ok(())`.
    async fn unregister_endpoint(&self, endpoint_name: &str) -> Result<()> {
        if let Some((_, task)) = self.handlers.remove(endpoint_name) {
            tracing::info!(
                endpoint_name = %endpoint_name,
                "Unregistering NATS endpoint"
            );
            task.cancel_token.cancel();
        }
        Ok(())
    }

    /// Placeholder address: NATS endpoints are reached by subject, not by a listen address.
    fn address(&self) -> String {
        "nats://connected".to_string()
    }

    /// Name of this transport.
    fn transport_name(&self) -> &'static str {
        "nats"
    }

    /// Always reports healthy; no connection or listener state is inspected.
    fn is_healthy(&self) -> bool {
        true
    }
}