use crate::types::{Error, Result, ShipEvent, StreamConfig};
use clap::Parser;
use futures::StreamExt;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::path::PathBuf;
pub fn parse_payload<T: DeserializeOwned>(payload: &[u8]) -> Result<T> {
serde_json::from_slice(payload).map_err(Error::Serialization)
}
/// Encodes `response` as JSON bytes.
///
/// This never fails from the caller's point of view: if serialization returns
/// an error, the empty JSON object `{}` is returned instead.
pub fn serialize_response<T: Serialize>(response: &T) -> Vec<u8> {
    match serde_json::to_vec(response) {
        Ok(encoded) => encoded,
        // Best-effort fallback so a reply can always be sent.
        Err(_) => b"{}".to_vec(),
    }
}
/// Builds a JSON error body with an `error` string and a numeric `code`.
///
/// The body is produced by serializing an `ErrorResponse`. Should that
/// serialization fail, the body is assembled by hand instead; in that path the
/// message is JSON-escaped first, so quotes or backslashes in `message` cannot
/// yield malformed JSON.
pub fn error_response(message: &str, code: u16) -> Vec<u8> {
    let body = ErrorResponse {
        error: message.to_string(),
        code,
    };
    serde_json::to_vec(&body).unwrap_or_else(|_| {
        // `to_string` on a `&str` yields a quoted, escaped JSON string literal,
        // so it is interpolated without surrounding quotes below.
        let escaped = serde_json::to_string(message)
            .unwrap_or_else(|_| String::from("\"error serialization failed\""));
        format!("{{\"error\":{escaped},\"code\":{code}}}").into_bytes()
    })
}
/// Serializable shape of the body produced by `error_response`.
#[derive(serde::Serialize)]
struct ErrorResponse {
/// Error message text.
error: String,
/// Numeric error code supplied by the caller (e.g. `dispatch` passes 404 for unknown commands).
code: u16,
}
// Command-line arguments for a service, parsed through clap's derive API.
// NOTE: plain `//` comments are used here on purpose. clap's derive picks up
// `///` doc comments as help text, which would change the program's output.
#[derive(Parser, Debug, Clone)]
pub struct ServiceArgs {
// Long-form option with default ".". Presumably the directory the service
// keeps its data in; it is not read anywhere in this file, so verify with callers.
#[arg(long, default_value = ".")]
pub data_dir: PathBuf,
// Long-form option with default "nats://localhost:4222", the same default URL
// that `NatsServiceBuilder::new` uses.
#[arg(long, default_value = "nats://localhost:4222")]
pub nats_url: String,
}
/// Boxed handler for one named command: takes the raw payload and mutable
/// service state, returns the raw response bytes.
type HandlerFn<S> = Box<dyn Fn(&[u8], &mut S) -> Vec<u8> + Send + Sync>;
/// Boxed callback invoked by `run` after a message has been handled. It
/// receives the command, the response bytes and the state, and may return an
/// `(event_type, event_bytes)` pair to be published as an event.
type MutationCallback<S> = Box<dyn Fn(&str, &[u8], &S) -> Option<(String, Vec<u8>)> + Send + Sync>;
/// Boxed one-shot callback invoked with the state after the `run` loop exits.
type ShutdownCallback<S> = Box<dyn FnOnce(&S) + Send>;
/// Boxed fallback handler used when no named handler matches. Unlike
/// `HandlerFn`, its first argument is the full message subject.
type DefaultHandlerFn<S> = Box<dyn Fn(&str, &[u8], &mut S) -> Vec<u8> + Send + Sync>;
/// Builder for a request/reply service over NATS that owns a state value `S`.
///
/// Configure it with the chained setter methods, then call `run` to connect,
/// subscribe to `<subject_prefix>.>` and serve messages.
pub struct NatsServiceBuilder<S: Send + Sync + 'static> {
/// NATS server URL used by `run` to connect.
nats_url: String,
/// Subject prefix; `run` subscribes to `<prefix>.>` and strips `<prefix>.`
/// from incoming subjects to obtain the command name.
subject_prefix: String,
/// Service state handed to handlers (mutably) and callbacks (immutably).
state: S,
/// Named command handlers, keyed by command name.
handlers: HashMap<String, HandlerFn<S>>,
/// Optional fallback used when no named handler matches the command.
default_handler: Option<DefaultHandlerFn<S>>,
/// Optional callback consulted after each handled message to produce an event.
mutation_callback: Option<MutationCallback<S>>,
/// Optional prefix joined with `.` in front of the event type to form the
/// event subject; without it the event type itself is the subject.
event_subject_prefix: Option<String>,
/// Optional stream configuration. When set, `run` ensures the stream exists
/// and publishes events through JetStream wrapped in a `ShipEvent`.
event_bus_stream: Option<StreamConfig>,
/// Source value passed to `ShipEvent::new`; an empty string is used when unset.
event_source: Option<String>,
/// Optional callback invoked once with the state after the `run` loop exits.
shutdown_callback: Option<ShutdownCallback<S>>,
}
impl<S: Send + Sync + 'static> NatsServiceBuilder<S> {
/// Creates a builder for the given subject prefix and initial state.
///
/// The NATS URL starts out as `nats://localhost:4222`; no handlers,
/// callbacks, event prefix or event stream are configured.
pub fn new(subject_prefix: impl Into<String>, state: S) -> Self {
    let nats_url = String::from("nats://localhost:4222");
    let subject_prefix: String = subject_prefix.into();
    Self {
        nats_url,
        subject_prefix,
        state,
        handlers: HashMap::default(),
        default_handler: None,
        mutation_callback: None,
        event_subject_prefix: None,
        event_bus_stream: None,
        event_source: None,
        shutdown_callback: None,
    }
}
/// Replaces the NATS server URL that `run` connects to.
pub fn nats_url(mut self, url: &str) -> Self {
    self.nats_url = String::from(url);
    self
}
pub fn handler<F>(mut self, command: &str, handler: F) -> Self
where
F: Fn(&[u8], &mut S) -> Vec<u8> + Send + Sync + 'static,
{
self.handlers.insert(command.to_string(), Box::new(handler));
self
}
pub fn default_handler<F>(mut self, handler: F) -> Self
where
F: Fn(&str, &[u8], &mut S) -> Vec<u8> + Send + Sync + 'static,
{
self.default_handler = Some(Box::new(handler));
self
}
pub fn mutation_callback<F>(mut self, callback: F) -> Self
where
F: Fn(&str, &[u8], &S) -> Option<(String, Vec<u8>)> + Send + Sync + 'static,
{
self.mutation_callback = Some(Box::new(callback));
self
}
/// Sets the prefix that `run` joins with `.` in front of each event type to
/// form the event subject.
pub fn event_prefix(mut self, prefix: &str) -> Self {
    let owned = String::from(prefix);
    self.event_subject_prefix = Some(owned);
    self
}
/// Enables event publishing through JetStream.
///
/// `stream` is the stream configuration that `run` ensures at startup, and
/// `source` is the value passed to `ShipEvent::new` for every event.
pub fn event_bus_stream(mut self, stream: StreamConfig, source: impl Into<String>) -> Self {
    let source: String = source.into();
    self.event_source = Some(source);
    self.event_bus_stream = Some(stream);
    self
}
pub fn on_shutdown<F>(mut self, callback: F) -> Self
where
F: FnOnce(&S) + Send + 'static,
{
self.shutdown_callback = Some(Box::new(callback));
self
}
/// Connects to NATS and serves messages until the subscription closes or
/// Ctrl-C is received.
///
/// In order, this:
/// 1. connects using the configured URL and, when an event stream was
///    configured, ensures that stream exists;
/// 2. subscribes to `<subject_prefix>.>`;
/// 3. for each message, routes it to the named handler for its command, else
///    to the default handler, else produces an `UNKNOWN_COMMAND` error body;
///    sends the response to the message's reply subject if it has one; then
///    consults the mutation callback and publishes the event it returns, if any;
/// 4. after the loop ends, invokes the shutdown callback (if set) and disconnects.
///
/// # Errors
///
/// Returns an error if connecting, obtaining the client, ensuring the stream,
/// subscribing, or obtaining the JetStream context fails. Failures to send a
/// reply or publish an event are logged and do not stop the loop.
pub async fn run(mut self) -> Result<()> {
    let mut conn = crate::connection::ConnectionManager::new(crate::types::NatsConfig::new(
        &self.nats_url,
    ));
    conn.connect()
        .await
        .map_err(|e| Error::Connection(e.to_string()))?;
    let client = conn
        .client()
        .map_err(|e| Error::Connection(e.to_string()))?
        .clone();
    if let Some(ref stream_config) = self.event_bus_stream {
        conn.ensure_stream(stream_config).await?;
    }
    let subscribe_subject = format!("{}.>", self.subject_prefix);
    let mut subscriber = client
        .subscribe(subscribe_subject)
        .await
        .map_err(|e| Error::Connection(e.to_string()))?;
    tracing::info!(
        prefix = %self.subject_prefix,
        url = %self.nats_url,
        handlers = self.handlers.len(),
        has_default = self.default_handler.is_some(),
        jetstream = self.event_bus_stream.is_some(),
        "service ready"
    );
    // A JetStream context is only needed when an event stream was configured.
    let js = if self.event_bus_stream.is_some() {
        Some(
            conn.jetstream()
                .map_err(|e| Error::Connection(e.to_string()))?,
        )
    } else {
        None
    };
    let event_source = self.event_source.clone().unwrap_or_default();
    // Create the Ctrl-C future once and poll the same instance on every
    // iteration. Building a fresh `ctrl_c()` inside `select!` drops the
    // listener each time a message wins the race, so a signal delivered while
    // a message is being handled could be missed.
    let mut shutdown_signal = std::pin::pin!(tokio::signal::ctrl_c());
    loop {
        tokio::select! {
            msg = subscriber.next() => {
                match msg {
                    Some(msg) => {
                        let subject = msg.subject.to_string();
                        let command = self.strip_prefix(&subject);
                        // Named handlers win over the default handler; with neither,
                        // `dispatch` yields the UNKNOWN_COMMAND error body.
                        let response = if self.handlers.contains_key(command) {
                            self.dispatch(command, &msg.payload)
                        } else if let Some(ref default) = self.default_handler {
                            // The default handler sees the full subject, not the stripped command.
                            default(&subject, &msg.payload, &mut self.state)
                        } else {
                            self.dispatch(command, &msg.payload)
                        };
                        // The response is cloned because the mutation callback below still reads it.
                        if let Some(reply_to) = msg.reply
                            && let Err(e) = client.publish(reply_to, response.clone().into()).await
                        {
                            tracing::error!(error = %e, subject = %subject, "reply failed");
                        }
                        if let Some(ref callback) = self.mutation_callback
                            && let Some((event_type, event_bytes)) = callback(command, &response, &self.state)
                        {
                            let full_subject = match &self.event_subject_prefix {
                                Some(prefix) => format!("{}.{}", prefix, event_type),
                                None => event_type.clone(),
                            };
                            if let Some(js) = &js {
                                // Wrap the event in a `ShipEvent`; bytes that are not valid
                                // JSON become a `null` payload.
                                let envelope = ShipEvent::new(
                                    &event_type,
                                    &event_source,
                                    serde_json::from_slice(&event_bytes)
                                        .unwrap_or(serde_json::Value::Null),
                                );
                                // If the envelope cannot be serialized, publish the raw event bytes.
                                let data = serde_json::to_vec(&envelope)
                                    .unwrap_or_else(|_| event_bytes.clone());
                                match js.publish(full_subject, data.into()).await {
                                    Ok(ack) => { drop(ack); }
                                    Err(e) => {
                                        tracing::error!(error = %e, "jetstream event publish failed");
                                    }
                                }
                            } else {
                                if let Err(e) = client.publish(full_subject, event_bytes.into()).await {
                                    tracing::error!(error = %e, "event publish failed");
                                }
                            }
                        }
                    }
                    None => {
                        tracing::warn!("NATS subscription closed");
                        break;
                    }
                }
            }
            _ = &mut shutdown_signal => {
                tracing::info!("shutting down");
                break;
            }
        }
    }
    if let Some(callback) = self.shutdown_callback.take() {
        callback(&self.state);
    }
    // Best-effort disconnect: a failure here is deliberately ignored.
    let _ = conn.disconnect().await;
    Ok(())
}
/// Runs the handler registered for `command` against `payload` and the
/// service state, returning the handler's response bytes.
///
/// When no handler is registered for `command`, an error body with the
/// message `UNKNOWN_COMMAND: <command>` and code 404 is returned instead.
pub fn dispatch(&mut self, command: &str, payload: &[u8]) -> Vec<u8> {
    let Some(handler) = self.handlers.get(command) else {
        return error_response(&format!("UNKNOWN_COMMAND: {}", command), 404);
    };
    handler(payload, &mut self.state)
}
/// Returns the part of `subject` after `<subject_prefix>.`.
///
/// If `subject` does not start with the prefix followed by a dot, it is
/// returned unchanged.
fn strip_prefix<'a>(&self, subject: &'a str) -> &'a str {
    let Some(after_prefix) = subject.strip_prefix(self.subject_prefix.as_str()) else {
        return subject;
    };
    match after_prefix.strip_prefix('.') {
        Some(command) => command,
        None => subject,
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal state type used to exercise handlers that mutate state.
    #[derive(Default)]
    struct TestState {
        counter: u64,
    }

    /// Test-only copy of the routing decision made in `run`: a named handler
    /// wins, then the default handler, then `dispatch`'s unknown-command error.
    fn route(svc: &mut NatsServiceBuilder<TestState>, subject: &str, payload: &[u8]) -> Vec<u8> {
        let command = svc.strip_prefix(subject);
        if svc.handlers.contains_key(command) {
            svc.dispatch(command, payload)
        } else if let Some(ref default) = svc.default_handler {
            default(subject, payload, &mut svc.state)
        } else {
            svc.dispatch(command, payload)
        }
    }

    #[test]
    fn dispatch_echo() {
        let mut svc = NatsServiceBuilder::new("test.cmd", TestState::default())
            .handler("echo", |payload, _state| payload.to_vec());
        assert_eq!(svc.dispatch("echo", b"hello"), b"hello");
    }

    #[test]
    fn dispatch_with_state() {
        let mut svc = NatsServiceBuilder::new("test.cmd", TestState::default()).handler(
            "increment",
            |_payload, state: &mut TestState| {
                state.counter += 1;
                serialize_response(&state.counter)
            },
        );
        // Two warm-up calls; the third response must reflect all three increments.
        for _ in 0..2 {
            svc.dispatch("increment", b"");
        }
        let response = svc.dispatch("increment", b"");
        let count: u64 = serde_json::from_slice(&response).unwrap();
        assert_eq!(count, 3);
    }

    #[test]
    fn dispatch_unknown_command() {
        let mut svc = NatsServiceBuilder::new("test.cmd", TestState::default());
        let response = svc.dispatch("nonexistent", b"");
        let parsed: serde_json::Value = serde_json::from_slice(&response).unwrap();
        let message = parsed["error"].as_str().unwrap();
        assert!(message.contains("UNKNOWN_COMMAND"));
        assert_eq!(parsed["code"], 404);
    }

    #[test]
    fn strip_prefix() {
        let svc = NatsServiceBuilder::new("kanban.cmd", TestState::default());
        assert_eq!(svc.strip_prefix("kanban.cmd.create"), "create");
        assert_eq!(svc.strip_prefix("kanban.cmd.pr.merge"), "pr.merge");
        // Subjects outside the prefix come back unchanged.
        assert_eq!(svc.strip_prefix("other.subject"), "other.subject");
    }

    #[test]
    fn parse_payload_success() {
        #[derive(serde::Deserialize)]
        struct Req {
            name: String,
        }
        let payload = serde_json::to_vec(&json!({"name": "test"})).unwrap();
        let req: Req = parse_payload(&payload).unwrap();
        assert_eq!(req.name, "test");
    }

    #[test]
    fn parse_payload_error() {
        #[derive(serde::Deserialize)]
        struct Req {
            // Only the decode failure matters here; the field is never read.
            #[allow(dead_code)]
            name: String,
        }
        let result: Result<Req> = parse_payload(b"not json");
        assert!(result.is_err());
    }

    #[test]
    fn serialize_response_json() {
        let data = json!({"status": "ok", "count": 42});
        let bytes = serialize_response(&data);
        let back: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(back["count"], 42);
    }

    #[test]
    fn error_response_format() {
        let bytes = error_response("not found", 404);
        let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(parsed["error"], "not found");
        assert_eq!(parsed["code"], 404);
    }

    #[test]
    fn default_handler_catch_all() {
        let mut svc = NatsServiceBuilder::new("kanban.cmd", TestState::default()).default_handler(
            |subject, _payload, _state| format!("handled: {subject}").into_bytes(),
        );
        let response = route(&mut svc, "kanban.cmd.create", b"");
        assert_eq!(response, b"handled: kanban.cmd.create");
    }

    #[test]
    fn named_handler_takes_precedence_over_default() {
        let mut svc = NatsServiceBuilder::new("svc", TestState::default())
            .handler("ping", |_, _| b"pong".to_vec())
            .default_handler(|_subject, _payload, _state| b"default".to_vec());
        let response = route(&mut svc, "svc.ping", b"");
        assert_eq!(response, b"pong");
    }

    #[test]
    fn multiple_handlers() {
        let mut svc = NatsServiceBuilder::new("svc", TestState::default())
            .handler("a", |_, _| b"handler_a".to_vec())
            .handler("b", |_, _| b"handler_b".to_vec())
            .handler("c", |_, _| b"handler_c".to_vec());
        assert_eq!(svc.dispatch("a", b""), b"handler_a");
        assert_eq!(svc.dispatch("b", b""), b"handler_b");
        assert_eq!(svc.dispatch("c", b""), b"handler_c");
    }

    #[test]
    fn mutation_callback_fires() {
        let mut svc = NatsServiceBuilder::new("svc", TestState::default())
            .handler("create", |_, _| b"created".to_vec())
            .mutation_callback(|cmd, _response, _state| {
                (cmd == "create").then(|| ("item.created".to_string(), b"event_data".to_vec()))
            });
        let response = svc.dispatch("create", b"");
        assert_eq!(response, b"created");

        let callback = svc.mutation_callback.as_ref().unwrap();
        let (subject, _) = callback("create", &response, &svc.state)
            .expect("the create command should produce an event");
        assert_eq!(subject, "item.created");
        assert!(callback("query", &[], &svc.state).is_none());
    }

    #[test]
    fn service_args_defaults() {
        // Parse a command line with no flags so the clap-declared defaults are
        // what is checked, rather than values written by hand in the test.
        // `Parser` and `PathBuf` are in scope through `use super::*`.
        let args = ServiceArgs::try_parse_from(["service"])
            .expect("parsing with no flags should succeed");
        assert_eq!(args.data_dir, PathBuf::from("."));
        assert_eq!(args.nats_url, "nats://localhost:4222");
    }
}