use crate::horizontal_adapter::{BroadcastMessage, RequestBody, ResponseBody};
use crate::horizontal_transport::{HorizontalTransport, TransportConfig, TransportHandlers};
use async_nats::{Client as NatsClient, ConnectOptions as NatsOptions, Subject};
use async_trait::async_trait;
use bytes::Bytes;
use futures::StreamExt;
use sockudo_core::error::{Error, Result};
use sockudo_core::metrics::MetricsInterface;
use sockudo_core::options::NatsAdapterConfig;
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::Notify;
use tracing::{debug, error, info, warn};
/// NATS system subject pinged during node discovery; each reply received on the
/// ping's inbox is counted as one NATS server.
const NATS_SYS_SERVER_PING_SUBJECT: &str = "$SYS.REQ.SERVER.PING";
/// Lock-free counters describing how many transport messages were seen,
/// handled, or discarded.
///
/// Every counter is updated and read with `Ordering::Relaxed`. The values are
/// independent statistics and impose no ordering on other memory operations.
#[derive(Default, Debug)]
pub struct TransportMetrics {
    /// Messages pulled off a subscription, counted before parsing.
    pub messages_received: AtomicU64,
    /// Messages that parsed successfully and whose handler has completed.
    pub messages_processed: AtomicU64,
    /// Messages discarded because their payload failed to deserialize.
    pub messages_dropped_parse_error: AtomicU64,
}

impl TransportMetrics {
    /// Creates a metrics set with every counter at zero.
    pub fn new() -> Self {
        Self {
            messages_received: AtomicU64::new(0),
            messages_processed: AtomicU64::new(0),
            messages_dropped_parse_error: AtomicU64::new(0),
        }
    }

    /// Counts one message taken off a subscription.
    #[inline]
    pub fn record_received(&self) {
        Self::bump(&self.messages_received);
    }

    /// Counts one message that was parsed and handled.
    #[inline]
    pub fn record_processed(&self) {
        Self::bump(&self.messages_processed);
    }

    /// Counts one message dropped because it could not be parsed.
    #[inline]
    pub fn record_parse_error(&self) {
        Self::bump(&self.messages_dropped_parse_error);
    }

    /// Returns `(received, processed, dropped_parse_error)` as currently observed.
    ///
    /// Each counter is read independently, so the tuple is not an atomic snapshot
    /// across all three.
    pub fn snapshot(&self) -> (u64, u64, u64) {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        (
            read(&self.messages_received),
            read(&self.messages_processed),
            read(&self.messages_dropped_parse_error),
        )
    }

    /// Increments a single counter by one.
    #[inline]
    fn bump(counter: &AtomicU64) {
        counter.fetch_add(1, Ordering::Relaxed);
    }
}
/// NATS-backed horizontal transport.
///
/// Handles produced by `clone` share the client, counters and shutdown signal.
/// `owner_count` tracks how many handles are alive, so the `Drop` impl can signal
/// the listener tasks once the last handle goes away.
pub struct NatsTransport {
    /// Connected NATS client used for publishing and subscribing.
    client: NatsClient,
    /// Subject broadcast messages are published on and listened to.
    broadcast_subject: String,
    /// Subject requests are published on and listened to.
    request_subject: String,
    /// Subject responses are published on and listened to.
    response_subject: String,
    /// Configuration the transport was created with.
    config: NatsAdapterConfig,
    /// Transport-local message counters.
    metrics: Arc<TransportMetrics>,
    /// Optional external metrics sink; a `OnceLock`, so it can be set at most once.
    metrics_driver: Arc<OnceLock<Arc<dyn MetricsInterface + Send + Sync>>>,
    /// Used to wake listener tasks when the transport shuts down.
    shutdown: Arc<Notify>,
    /// Cleared on shutdown; listener tasks check it on every iteration.
    is_running: Arc<AtomicBool>,
    /// Number of live handles sharing this state.
    owner_count: Arc<AtomicUsize>,
}

impl NatsTransport {
    /// Returns a shared handle to this transport's message counters.
    pub fn get_metrics(&self) -> Arc<TransportMetrics> {
        self.metrics.clone()
    }

    /// Counts NATS servers by publishing a `$SYS.REQ.SERVER.PING` and counting replies.
    ///
    /// The overall wait is `request_timeout_ms`, bounded below by
    /// `discovery_idle_wait_ms` and above by `discovery_max_wait_ms`. After the
    /// first reply arrives, collection stops as soon as no further reply shows up
    /// within `discovery_idle_wait_ms`.
    ///
    /// NOTE(review): `$SYS` subjects presumably require system-account permissions.
    /// Without them, no replies arrive and this returns `Ok(None)` after the full
    /// wait. Confirm against the deployment's NATS account setup.
    ///
    /// Returns `Ok(Some(n))` with the number of replies, or `Ok(None)` if nobody replied.
    ///
    /// # Errors
    /// Returns `Error::Internal` if subscribing to the reply inbox or publishing
    /// the ping fails.
    async fn discover_node_count_via_system_ping(&self) -> Result<Option<usize>> {
        let reply_subject = self.client.new_inbox();
        let mut responses = self
            .client
            .subscribe(reply_subject.clone())
            .await
            .map_err(|e| Error::Internal(format!("Failed to subscribe for NATS discovery: {e}")))?;
        self.client
            .publish_with_reply(
                Subject::from(NATS_SYS_SERVER_PING_SUBJECT),
                reply_subject,
                Bytes::new(),
            )
            .await
            .map_err(|e| {
                Error::Internal(format!("Failed to publish NATS discovery request: {e}"))
            })?;
        // Bound the overall wait by the configured discovery window. `u64::clamp`
        // panics when `discovery_idle_wait_ms > discovery_max_wait_ms`. This form
        // gives the same result for a valid window and lets the upper bound win
        // for a misconfigured one, instead of crashing.
        let max_wait_ms = self
            .config
            .request_timeout_ms
            .max(self.config.discovery_idle_wait_ms)
            .min(self.config.discovery_max_wait_ms);
        let max_wait = Duration::from_millis(max_wait_ms);
        let idle_wait = Duration::from_millis(self.config.discovery_idle_wait_ms);
        let start = tokio::time::Instant::now();
        let mut count = 0usize;
        loop {
            let elapsed = start.elapsed();
            if elapsed >= max_wait {
                break;
            }
            let remaining = max_wait.saturating_sub(elapsed);
            // Wait the whole remaining budget for the first reply; after that,
            // only wait `idle_wait` for each further reply.
            let wait_for = if count == 0 {
                remaining
            } else {
                remaining.min(idle_wait)
            };
            match tokio::time::timeout(wait_for, responses.next()).await {
                Ok(Some(_message)) => count += 1,
                Ok(None) | Err(_) => break,
            }
        }
        if count > 0 {
            debug!("Detected {} NATS server(s) via system ping", count);
            Ok(Some(count))
        } else {
            debug!(
                "NATS system discovery returned no responses; falling back to configured/default node count"
            );
            Ok(None)
        }
    }
}
/// Exposes the generic transport settings stored in the NATS adapter config.
impl TransportConfig for NatsAdapterConfig {
    /// Subject prefix that all transport subjects are derived from.
    fn prefix(&self) -> &str {
        &self.prefix
    }

    /// Request timeout in milliseconds, taken verbatim from the config.
    fn request_timeout_ms(&self) -> u64 {
        self.request_timeout_ms
    }
}
#[async_trait]
impl HorizontalTransport for NatsTransport {
type Config = NatsAdapterConfig;
async fn new(config: Self::Config) -> Result<Self> {
info!(
"NATS transport config: servers={:?}, prefix={}, request_timeout={}ms, connection_timeout={}ms",
config.servers, config.prefix, config.request_timeout_ms, config.connection_timeout_ms
);
debug!(
"NATS transport credentials: username={:?}, password={:?}, token={:?}",
config.username, config.password, config.token
);
let mut nats_options = NatsOptions::new();
if let (Some(username), Some(password)) =
(config.username.as_deref(), config.password.as_deref())
{
nats_options =
nats_options.user_and_password(username.to_string(), password.to_string());
} else if let Some(token) = config.token.as_deref() {
nats_options = nats_options.token(token.to_string());
}
nats_options =
nats_options.connection_timeout(Duration::from_millis(config.connection_timeout_ms));
let client = nats_options
.connect(&config.servers)
.await
.map_err(|e| Error::Internal(format!("Failed to connect to NATS: {e}")))?;
let broadcast_subject = format!("{}.broadcast", config.prefix);
let request_subject = format!("{}.requests", config.prefix);
let response_subject = format!("{}.responses", config.prefix);
Ok(Self {
client,
broadcast_subject,
request_subject,
response_subject,
config,
metrics: Arc::new(TransportMetrics::new()),
metrics_driver: Arc::new(OnceLock::new()),
shutdown: Arc::new(Notify::new()),
is_running: Arc::new(AtomicBool::new(true)),
owner_count: Arc::new(AtomicUsize::new(1)),
})
}
async fn publish_broadcast(&self, message: &BroadcastMessage) -> Result<()> {
let message_data = sonic_rs::to_vec(message)
.map_err(|e| Error::Other(format!("Failed to serialize broadcast message: {e}")))?;
self.client
.publish(
Subject::from(self.broadcast_subject.clone()),
message_data.into(),
)
.await
.map_err(|e| Error::Internal(format!("Failed to publish broadcast: {e}")))?;
debug!("Published broadcast message via NATS");
Ok(())
}
async fn publish_request(&self, request: &RequestBody) -> Result<()> {
let request_data = sonic_rs::to_vec(request)
.map_err(|e| Error::Other(format!("Failed to serialize request: {e}")))?;
self.client
.publish(
Subject::from(self.request_subject.clone()),
request_data.into(),
)
.await
.map_err(|e| Error::Internal(format!("Failed to publish request: {e}")))?;
debug!("Broadcasted request {} via NATS", request.request_id);
Ok(())
}
async fn publish_response(&self, response: &ResponseBody) -> Result<()> {
let response_data = sonic_rs::to_vec(response)
.map_err(|e| Error::Other(format!("Failed to serialize response: {e}")))?;
self.client
.publish(
Subject::from(self.response_subject.clone()),
response_data.into(),
)
.await
.map_err(|e| Error::Internal(format!("Failed to publish response: {e}")))?;
debug!("Published response via NATS");
Ok(())
}
async fn start_listeners(&self, handlers: TransportHandlers) -> Result<()> {
let client = self.client.clone();
let broadcast_subject = self.broadcast_subject.clone();
let request_subject = self.request_subject.clone();
let response_subject = self.response_subject.clone();
let response_client = self.client.clone();
let metrics_broadcast = self.metrics.clone();
let metrics_request = self.metrics.clone();
let metrics_response = self.metrics.clone();
let metrics_driver_broadcast = self.metrics_driver.clone();
let metrics_driver_request = self.metrics_driver.clone();
let metrics_driver_response = self.metrics_driver.clone();
let shutdown_broadcast = self.shutdown.clone();
let shutdown_request = self.shutdown.clone();
let shutdown_response = self.shutdown.clone();
let running_broadcast = self.is_running.clone();
let running_request = self.is_running.clone();
let running_response = self.is_running.clone();
let mut broadcast_subscription = client
.subscribe(Subject::from(broadcast_subject.clone()))
.await
.map_err(|e| {
Error::Internal(format!("Failed to subscribe to broadcast subject: {e}"))
})?;
let mut request_subscription = client
.subscribe(Subject::from(request_subject.clone()))
.await
.map_err(|e| Error::Internal(format!("Failed to subscribe to request subject: {e}")))?;
let mut response_subscription = client
.subscribe(Subject::from(response_subject.clone()))
.await
.map_err(|e| {
Error::Internal(format!("Failed to subscribe to response subject: {e}"))
})?;
info!(
"NATS transport listening on subjects: {}, {}, {}",
broadcast_subject, request_subject, response_subject
);
let broadcast_handler = handlers.on_broadcast.clone();
tokio::spawn(async move {
loop {
if !running_broadcast.load(Ordering::Relaxed) {
break;
}
let msg = tokio::select! {
_ = shutdown_broadcast.notified() => break,
msg = broadcast_subscription.next() => msg,
};
let Some(msg) = msg else {
break;
};
metrics_broadcast.record_received();
match sonic_rs::from_slice::<BroadcastMessage>(&msg.payload) {
Ok(broadcast) => {
broadcast_handler(broadcast).await;
metrics_broadcast.record_processed();
}
Err(e) => {
metrics_broadcast.record_parse_error();
if let Some(metrics) = metrics_driver_broadcast.get() {
metrics.mark_horizontal_transport_message_dropped("nats");
}
let payload_preview =
String::from_utf8_lossy(&msg.payload[..msg.payload.len().min(200)]);
error!(
"Failed to parse broadcast message: {} - payload preview: {}",
e, payload_preview
);
}
}
}
warn!("Broadcast subscription ended unexpectedly");
});
let request_handler = handlers.on_request.clone();
tokio::spawn(async move {
loop {
if !running_request.load(Ordering::Relaxed) {
break;
}
let msg = tokio::select! {
_ = shutdown_request.notified() => break,
msg = request_subscription.next() => msg,
};
let Some(msg) = msg else {
break;
};
metrics_request.record_received();
match sonic_rs::from_slice::<RequestBody>(&msg.payload) {
Ok(request) => {
let response_result = request_handler(request).await;
if let Ok(response) = response_result
&& let Ok(response_data) = sonic_rs::to_vec(&response)
&& let Err(e) = response_client
.publish(
Subject::from(response_subject.clone()),
response_data.into(),
)
.await
{
warn!("Failed to publish response: {}", e);
}
metrics_request.record_processed();
}
Err(e) => {
metrics_request.record_parse_error();
if let Some(metrics) = metrics_driver_request.get() {
metrics.mark_horizontal_transport_message_dropped("nats");
}
let payload_preview =
String::from_utf8_lossy(&msg.payload[..msg.payload.len().min(200)]);
error!(
"Failed to parse request message: {} - payload preview: {}",
e, payload_preview
);
}
}
}
warn!("Request subscription ended unexpectedly");
});
let response_handler = handlers.on_response.clone();
tokio::spawn(async move {
loop {
if !running_response.load(Ordering::Relaxed) {
break;
}
let msg = tokio::select! {
_ = shutdown_response.notified() => break,
msg = response_subscription.next() => msg,
};
let Some(msg) = msg else {
break;
};
metrics_response.record_received();
match sonic_rs::from_slice::<ResponseBody>(&msg.payload) {
Ok(response) => {
response_handler(response).await;
metrics_response.record_processed();
}
Err(e) => {
metrics_response.record_parse_error();
if let Some(metrics) = metrics_driver_response.get() {
metrics.mark_horizontal_transport_message_dropped("nats");
}
let payload_preview =
String::from_utf8_lossy(&msg.payload[..msg.payload.len().min(200)]);
error!(
"Failed to parse response message: {} - payload preview: {}",
e, payload_preview
);
}
}
}
warn!("Response subscription ended unexpectedly");
});
Ok(())
}
async fn get_node_count(&self) -> Result<usize> {
if let Some(nodes) = self.config.nodes_number {
return Ok(nodes as usize);
}
match self.discover_node_count_via_system_ping().await {
Ok(Some(nodes)) => Ok(nodes.max(1)),
Ok(None) => Ok(1),
Err(error) => {
warn!(
"NATS node discovery via system ping failed: {}. Falling back to 1 node",
error
);
Ok(1)
}
}
}
async fn check_health(&self) -> Result<()> {
let state = self.client.connection_state();
match state {
async_nats::connection::State::Connected => Ok(()),
async_nats::connection::State::Disconnected => Err(
sockudo_core::error::Error::Connection("NATS client is disconnected".to_string()),
),
other_state => Err(sockudo_core::error::Error::Connection(format!(
"NATS client in transitional state: {other_state:?}"
))),
}
}
fn set_metrics(&self, metrics: Arc<dyn MetricsInterface + Send + Sync>) {
let _ = self.metrics_driver.set(metrics);
}
}
impl Drop for NatsTransport {
    /// Releases this handle; the final handle flips `is_running` off and wakes
    /// every task currently waiting on the shutdown `Notify`.
    fn drop(&mut self) {
        // `fetch_sub` yields the count before decrementing, so 1 means this was
        // the last live handle.
        let was_last_owner = self.owner_count.fetch_sub(1, Ordering::AcqRel) == 1;
        if !was_last_owner {
            return;
        }
        self.is_running.store(false, Ordering::Relaxed);
        self.shutdown.notify_waiters();
    }
}
impl Clone for NatsTransport {
fn clone(&self) -> Self {
self.owner_count.fetch_add(1, Ordering::Relaxed);
Self {
client: self.client.clone(),
broadcast_subject: self.broadcast_subject.clone(),
request_subject: self.request_subject.clone(),
response_subject: self.response_subject.clone(),
config: self.config.clone(),
metrics: self.metrics.clone(),
metrics_driver: self.metrics_driver.clone(),
shutdown: self.shutdown.clone(),
is_running: self.is_running.clone(),
owner_count: self.owner_count.clone(),
}
}
}