use async_nats::jetstream::{self, Context};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use thiserror::Error;
/// JetStream stream name used by [`NatsPublisher::new`] when no stream name is supplied.
pub const DEFAULT_STREAM: &str = "noetl_commands";
/// Subject notifications are published to when [`NatsPublisher::new`] is given no subject.
pub const DEFAULT_SUBJECT: &str = "noetl.commands";
#[derive(Debug, Error)]
pub enum NatsError {
#[error("NATS connection error: {0}")]
Connection(String),
#[error("JetStream error: {0}")]
JetStream(String),
#[error("Publish error: {0}")]
Publish(String),
#[error("Not connected to NATS")]
NotConnected,
}
/// Payload that [`NatsPublisher`] serializes to JSON and publishes for a command.
///
/// Field order is significant for the serialized JSON key order; keep it stable.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CommandNotification {
    /// Identifier of the execution the command belongs to.
    pub execution_id: i64,
    /// Identifier of the event associated with the command.
    pub event_id: i64,
    /// Identifier of the command itself.
    pub command_id: String,
    /// Name of the step the command is for (presumably a workflow step — confirm).
    pub step: String,
    /// Server base URL carried to consumers (e.g. `http://localhost:8082` in tests).
    pub server_url: String,
}
/// Publishes [`CommandNotification`]s to a single JetStream subject.
#[derive(Clone)]
pub struct NatsPublisher {
    /// Subject every notification is published to.
    subject: String,
    /// JetStream context used for stream management and publishing.
    js: Context,
}
impl NatsPublisher {
    /// Maximum message age for a stream that [`Self::ensure_stream`] creates (one hour).
    const STREAM_MAX_AGE: std::time::Duration = std::time::Duration::from_secs(3600);

    /// Builds a publisher for `subject`, creating the backing JetStream stream if needed.
    ///
    /// * `client` – NATS client; a clone of it backs the JetStream context.
    /// * `subject` – subject to publish to; defaults to [`DEFAULT_SUBJECT`].
    /// * `stream_name` – stream that should capture `subject`; defaults to [`DEFAULT_STREAM`].
    ///
    /// # Errors
    ///
    /// Returns [`NatsError::JetStream`] when the stream is missing and cannot be created.
    pub async fn new(
        client: Arc<async_nats::Client>,
        subject: Option<&str>,
        stream_name: Option<&str>,
    ) -> Result<Self, NatsError> {
        let subject = subject.unwrap_or(DEFAULT_SUBJECT).to_string();
        let stream = stream_name.unwrap_or(DEFAULT_STREAM);
        let js = jetstream::new((*client).clone());
        Self::ensure_stream(&js, stream, &subject).await?;
        Ok(Self { js, subject })
    }

    /// Makes sure a stream named `stream` exists, creating it when the lookup fails.
    ///
    /// An existing stream is used as-is: its configuration (including whether it
    /// captures `subject`) is neither checked nor updated. A newly created stream binds
    /// only `subject`, uses file storage, and keeps messages for [`Self::STREAM_MAX_AGE`].
    ///
    /// # Errors
    ///
    /// Returns [`NatsError::JetStream`] if stream creation fails.
    async fn ensure_stream(js: &Context, stream: &str, subject: &str) -> Result<(), NatsError> {
        match js.get_stream(stream).await {
            Ok(_) => {
                tracing::debug!(stream = %stream, "Using existing NATS stream");
                Ok(())
            }
            // Any lookup failure (not only "stream not found") falls through to creation;
            // record the cause so transport/permission problems are not silently hidden.
            Err(lookup_err) => {
                tracing::debug!(
                    stream = %stream,
                    error = %lookup_err,
                    "NATS stream lookup failed; attempting to create it"
                );
                let config = jetstream::stream::Config {
                    name: stream.to_string(),
                    subjects: vec![subject.to_string()],
                    max_age: Self::STREAM_MAX_AGE,
                    storage: jetstream::stream::StorageType::File,
                    ..Default::default()
                };
                js.create_stream(config)
                    .await
                    .map_err(|e| NatsError::JetStream(e.to_string()))?;
                tracing::info!(stream = %stream, subject = %subject, "Created NATS stream");
                Ok(())
            }
        }
    }

    /// Serializes `notification` as JSON and publishes it, waiting for the JetStream ack.
    ///
    /// # Errors
    ///
    /// Returns [`NatsError::Publish`] if serialization, the publish request, or the
    /// acknowledgement fails.
    pub async fn publish_command(
        &self,
        notification: CommandNotification,
    ) -> Result<(), NatsError> {
        let payload = serde_json::to_vec(&notification)
            .map_err(|e| NatsError::Publish(format!("Serialization error: {}", e)))?;
        // The first `.await` sends the message and yields an ack future; the second
        // awaits the server acknowledgement. Failures at either stage are reported.
        self.js
            .publish(self.subject.clone(), payload.into())
            .await
            .map_err(|e| NatsError::Publish(e.to_string()))?
            .await
            .map_err(|e| NatsError::Publish(e.to_string()))?;
        tracing::debug!(
            execution_id = notification.execution_id,
            event_id = notification.event_id,
            command_id = %notification.command_id,
            step = %notification.step,
            "Published command notification"
        );
        Ok(())
    }

    /// Convenience wrapper that builds a [`CommandNotification`] from borrowed parts
    /// and delegates to [`Self::publish_command`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::publish_command`].
    pub async fn publish(
        &self,
        execution_id: i64,
        event_id: i64,
        command_id: &str,
        step: &str,
        server_url: &str,
    ) -> Result<(), NatsError> {
        let notification = CommandNotification {
            execution_id,
            event_id,
            command_id: command_id.to_string(),
            step: step.to_string(),
            server_url: server_url.to_string(),
        };
        self.publish_command(notification).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Notification with distinctive values used by the serialization tests.
    fn sample_notification() -> CommandNotification {
        CommandNotification {
            execution_id: 12345,
            event_id: 67890,
            command_id: "cmd-abc123".to_string(),
            step: "process_data".to_string(),
            server_url: "http://localhost:8082".to_string(),
        }
    }

    #[test]
    fn test_command_notification_serialization() {
        let notification = sample_notification();
        let json = serde_json::to_string(&notification).unwrap();
        assert!(json.contains("12345"));
        assert!(json.contains("67890"));
        assert!(json.contains("cmd-abc123"));
        assert!(json.contains("process_data"));
        assert!(json.contains("http://localhost:8082"));
    }

    #[test]
    fn test_command_notification_deserialization() {
        let json = r#"{
            "execution_id": 12345,
            "event_id": 67890,
            "command_id": "cmd-abc123",
            "step": "process_data",
            "server_url": "http://localhost:8082"
        }"#;
        let notification: CommandNotification = serde_json::from_str(json).unwrap();
        assert_eq!(notification.execution_id, 12345);
        assert_eq!(notification.event_id, 67890);
        assert_eq!(notification.command_id, "cmd-abc123");
        assert_eq!(notification.step, "process_data");
        assert_eq!(notification.server_url, "http://localhost:8082");
    }

    #[test]
    fn test_command_notification_round_trip() {
        let original = sample_notification();
        let bytes = serde_json::to_vec(&original).unwrap();
        let decoded: CommandNotification = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(decoded.execution_id, original.execution_id);
        assert_eq!(decoded.event_id, original.event_id);
        assert_eq!(decoded.command_id, original.command_id);
        assert_eq!(decoded.step, original.step);
        assert_eq!(decoded.server_url, original.server_url);
    }

    #[test]
    fn test_default_constants() {
        assert_eq!(DEFAULT_SUBJECT, "noetl.commands");
        assert_eq!(DEFAULT_STREAM, "noetl_commands");
    }
}