pub mod error;
use std::{
collections::HashMap,
fmt::Display,
pin::Pin,
sync::{Arc, Mutex},
time::{Duration, Instant},
};
use bytes::Bytes;
pub mod endpoint;
use futures_util::{
stream::{self, SelectAll},
Future, StreamExt,
};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::sync::LazyLock;
use time::serde::rfc3339;
use time::OffsetDateTime;
use tokio::{sync::broadcast::Sender, task::JoinHandle};
use tracing::debug;
use crate::{
client::PublishErrorKind, Client, Error, HeaderMap, Message, PublishError, Subscriber,
};
use self::endpoint::Endpoint;
/// Root of the discovery subject namespace: verb requests arrive on `$SRV.<VERB>`,
/// `$SRV.<VERB>.<name>` and `$SRV.<VERB>.<name>.<id>`.
const SERVICE_API_PREFIX: &str = "$SRV";
/// Queue group used when the service config does not specify one.
const DEFAULT_QUEUE_GROUP: &str = "q";
/// Header carrying the error status text on error responses.
pub const NATS_SERVICE_ERROR: &str = "Nats-Service-Error";
/// Header carrying the numeric error code on error responses.
pub const NATS_SERVICE_ERROR_CODE: &str = "Nats-Service-Error-Code";
/// Matches a semantic-versioning string (`MAJOR.MINOR.PATCH` with optional pre-release and
/// build-metadata suffixes); used to validate the service version.
static SEMVER: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$")
        .unwrap()
});
/// Valid service name: one or more ASCII letters, digits, `-` or `_`.
static NAME: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^[A-Za-z0-9\-_]+$").unwrap());
/// Mutable per-endpoint state shared by a service, its groups, endpoint builders and requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Endpoints {
    /// Endpoint state keyed by endpoint name (the explicit builder name, or the full
    /// subject when no name was given).
    pub(crate) endpoints: HashMap<String, endpoint::Inner>,
}
/// Body published in reply to a PING discovery request.
#[derive(Serialize, Deserialize)]
pub struct PingResponse {
    /// Response type identifier, serialized as `type`; the service sends
    /// `io.nats.micro.v1.ping_response`.
    #[serde(rename = "type")]
    pub kind: String,
    /// Service name.
    pub name: String,
    /// Id generated for the service instance when it was added.
    pub id: String,
    /// Service version.
    pub version: String,
    /// Service metadata. Missing fields default to an empty map; deserialization goes through
    /// `endpoint::null_meta_as_default` (presumably mapping JSON `null` to an empty map too).
    #[serde(default, deserialize_with = "endpoint::null_meta_as_default")]
    pub metadata: HashMap<String, String>,
}
/// Body published in reply to a STATS discovery request.
#[derive(Serialize, Deserialize)]
pub struct Stats {
    /// Response type identifier, serialized as `type`; the service sends
    /// `io.nats.micro.v1.stats_response`.
    #[serde(rename = "type")]
    pub kind: String,
    /// Service name.
    pub name: String,
    /// Id generated for the service instance when it was added.
    pub id: String,
    /// Service version.
    pub version: String,
    /// UTC time at which the service was added, (de)serialized as an RFC 3339 string.
    #[serde(with = "rfc3339")]
    pub started: OffsetDateTime,
    /// Statistics of every endpoint registered on the service.
    pub endpoints: Vec<endpoint::Stats>,
}
/// Description of a service, published in reply to an INFO discovery request.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Info {
    /// Response type identifier, serialized as `type`; the service sends
    /// `io.nats.micro.v1.info_response`.
    #[serde(rename = "type")]
    pub kind: String,
    /// Service name.
    pub name: String,
    /// Id generated for the service instance when it was added.
    pub id: String,
    /// Service description (empty string when none was configured).
    pub description: String,
    /// Service version.
    pub version: String,
    /// Service metadata. Missing fields default to an empty map; deserialization goes through
    /// `endpoint::null_meta_as_default` (presumably mapping JSON `null` to an empty map too).
    #[serde(default, deserialize_with = "endpoint::null_meta_as_default")]
    pub metadata: HashMap<String, String>,
    /// Endpoints of the service; the INFO responder fills this from the current endpoint
    /// state when it builds a reply.
    pub endpoints: Vec<endpoint::Info>,
}
/// Settings for creating a service via `ServiceExt::add_service` (also assembled by
/// `ServiceBuilder::start`).
#[derive(Serialize, Deserialize, Debug)]
pub struct Config {
    /// Service name; must contain only ASCII letters, digits, `-` and `_`.
    pub name: String,
    /// Optional description; reported as an empty string when `None`.
    pub description: Option<String>,
    /// Service version; must be a valid semantic-versioning string.
    pub version: String,
    /// Optional callback invoked for every endpoint on each STATS request; its return value
    /// is stored as that endpoint's `data`. Skipped by serde.
    #[serde(skip)]
    pub stats_handler: Option<StatsHandler>,
    /// Metadata reported in PING and INFO responses; an empty map when `None`.
    pub metadata: Option<HashMap<String, String>>,
    /// Queue group used by endpoints of this service; `"q"` when `None`.
    pub queue_group: Option<String>,
}
/// Fluent builder for a service; obtain one with `ServiceExt::service_builder` and finish
/// with `ServiceBuilder::start`.
pub struct ServiceBuilder {
    // Connection the service will use for all subscriptions and replies.
    client: Client,
    // Optional settings copied verbatim into the resulting `Config`.
    description: Option<String>,
    stats_handler: Option<StatsHandler>,
    metadata: Option<HashMap<String, String>>,
    queue_group: Option<String>,
}
impl ServiceBuilder {
    /// Creates a builder bound to `client` with every optional setting unset.
    fn new(client: Client) -> Self {
        Self {
            client,
            queue_group: None,
            metadata: None,
            stats_handler: None,
            description: None,
        }
    }

    /// Sets the human-readable service description.
    pub fn description<S: ToString>(self, description: S) -> Self {
        Self {
            description: Some(description.to_string()),
            ..self
        }
    }

    /// Registers a callback run for every endpoint on each STATS request; whatever it
    /// returns is attached to that endpoint's statistics as custom data.
    pub fn stats_handler<F>(self, handler: F) -> Self
    where
        F: FnMut(String, endpoint::Stats) -> serde_json::Value + Send + Sync + 'static,
    {
        let boxed = StatsHandler(Box::new(handler));
        Self {
            stats_handler: Some(boxed),
            ..self
        }
    }

    /// Sets the metadata reported in PING and INFO responses.
    pub fn metadata(self, metadata: HashMap<String, String>) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Overrides the queue group used by the service's endpoints.
    pub fn queue_group<S: ToString>(self, queue_group: S) -> Self {
        Self {
            queue_group: Some(queue_group.to_string()),
            ..self
        }
    }

    /// Validates the settings and starts the service.
    ///
    /// # Errors
    ///
    /// Fails when `name` or `version` are invalid or a discovery subscription fails.
    pub async fn start<N: ToString, V: ToString>(
        self,
        name: N,
        version: V,
    ) -> Result<Service, Error> {
        let config = Config {
            name: name.to_string(),
            version: version.to_string(),
            description: self.description,
            stats_handler: self.stats_handler,
            metadata: self.metadata,
            queue_group: self.queue_group,
        };
        Service::add(self.client, config).await
    }
}
/// Discovery verbs; each renders as the upper-case token used in `$SRV.<VERB>` subjects.
pub enum Verb {
    Ping,
    Stats,
    Info,
    Schema,
}

impl Display for Verb {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let token = match self {
            Verb::Ping => "PING",
            Verb::Stats => "STATS",
            Verb::Info => "INFO",
            Verb::Schema => "SCHEMA",
        };
        f.write_str(token)
    }
}
/// Extension trait that lets a connection host services.
pub trait ServiceExt {
    /// Future returned by `add_service`, resolving to the started service.
    type Output: Future<Output = Result<Service, crate::Error>>;
    /// Starts a service described by `config`.
    fn add_service(&self, config: Config) -> Self::Output;
    /// Returns a builder for configuring and starting a service.
    fn service_builder(&self) -> ServiceBuilder;
}
impl ServiceExt for Client {
    type Output = Pin<Box<dyn Future<Output = Result<Service, crate::Error>> + Send>>;

    /// Starts a service on a clone of this client; the returned future owns everything it
    /// needs, so it is independent of `self`'s lifetime.
    fn add_service(&self, config: Config) -> Self::Output {
        let owned_client = self.clone();
        Box::pin(Service::add(owned_client, config))
    }

    /// Returns a `ServiceBuilder` bound to a clone of this client.
    fn service_builder(&self) -> ServiceBuilder {
        let owned_client = self.clone();
        ServiceBuilder::new(owned_client)
    }
}
/// A running service: answers PING/INFO/STATS discovery requests in a background task and
/// hands out endpoints and groups that share its state.
#[derive(Debug)]
pub struct Service {
    // Per-endpoint state, shared with groups, endpoint builders and requests.
    endpoints_state: Arc<Mutex<Endpoints>>,
    // Static description built at start-up; its `endpoints` list is empty.
    info: Info,
    client: Client,
    // Background task answering the discovery verbs.
    handle: JoinHandle<Result<(), Error>>,
    // Broadcast channel whose receivers are handed to endpoints as a shutdown signal.
    shutdown_tx: Sender<()>,
    // Subjects of all endpoints added through this service or its groups.
    subjects: Arc<Mutex<Vec<String>>>,
    // Default queue group for endpoints and groups created from this service.
    queue_group: String,
}
impl Service {
async fn add(client: Client, config: Config) -> Result<Service, Error> {
if !SEMVER.is_match(config.version.as_str()) {
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"service version is not a valid semver string",
)));
}
if !NAME.is_match(config.name.as_str()) {
return Err(Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"service name is not a valid string (only A-Z, a-z, 0-9, _, - are allowed)",
)));
}
let endpoints_state = Arc::new(Mutex::new(Endpoints {
endpoints: HashMap::new(),
}));
let queue_group = config
.queue_group
.unwrap_or(DEFAULT_QUEUE_GROUP.to_string());
let id = crate::id_generator::next();
let started = OffsetDateTime::now_utc();
let subjects = Arc::new(Mutex::new(Vec::new()));
let info = Info {
kind: "io.nats.micro.v1.info_response".to_string(),
name: config.name.clone(),
id: id.clone(),
description: config.description.clone().unwrap_or_default(),
version: config.version.clone(),
metadata: config.metadata.clone().unwrap_or_default(),
endpoints: Vec::new(),
};
let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);
let mut pings =
verb_subscription(client.clone(), Verb::Ping, config.name.clone(), id.clone()).await?;
let mut infos =
verb_subscription(client.clone(), Verb::Info, config.name.clone(), id.clone()).await?;
let mut stats =
verb_subscription(client.clone(), Verb::Stats, config.name.clone(), id.clone()).await?;
let handle = tokio::task::spawn({
let mut stats_callback = config.stats_handler;
let info = info.clone();
let endpoints_state = endpoints_state.clone();
let client = client.clone();
async move {
loop {
tokio::select! {
Some(ping) = pings.next() => {
let pong = serde_json::to_vec(&PingResponse{
kind: "io.nats.micro.v1.ping_response".to_string(),
name: info.name.clone(),
id: info.id.clone(),
version: info.version.clone(),
metadata: info.metadata.clone(),
})?;
client.publish(ping.reply.unwrap(), pong.into()).await?;
},
Some(info_request) = infos.next() => {
let info = info.clone();
let endpoints: Vec<endpoint::Info> = {
endpoints_state.lock().unwrap().endpoints.values().map(|value| {
endpoint::Info {
name: value.name.to_owned(),
subject: value.subject.to_owned(),
queue_group: value.queue_group.to_owned(),
metadata: value.metadata.to_owned()
}
}).collect()
};
let info = Info {
endpoints,
..info
};
let info_json = serde_json::to_vec(&info).map(Bytes::from)?;
client.publish(info_request.reply.unwrap(), info_json.clone()).await?;
},
Some(stats_request) = stats.next() => {
if let Some(stats_callback) = stats_callback.as_mut() {
let mut endpoint_stats_locked = endpoints_state.lock().unwrap();
for (key, value) in &mut endpoint_stats_locked.endpoints {
let data = stats_callback.0(key.to_string(), value.clone().into());
value.data = Some(data);
}
}
let stats = serde_json::to_vec(&Stats {
kind: "io.nats.micro.v1.stats_response".to_string(),
name: info.name.clone(),
id: info.id.clone(),
version: info.version.clone(),
started,
endpoints: endpoints_state.lock().unwrap().endpoints.values().cloned().map(Into::into).collect(),
})?;
client.publish(stats_request.reply.unwrap(), stats.into()).await?;
},
else => break,
}
}
Ok(())
}
});
Ok(Service {
endpoints_state,
info,
client,
handle,
shutdown_tx,
subjects,
queue_group,
})
}
pub async fn stop(self) -> Result<(), Error> {
self.shutdown_tx.send(())?;
self.handle.abort();
Ok(())
}
pub async fn reset(&mut self) {
for value in self.endpoints_state.lock().unwrap().endpoints.values_mut() {
value.errors = 0;
value.processing_time = Duration::default();
value.requests = 0;
value.average_processing_time = Duration::default();
}
}
pub async fn stats(&self) -> HashMap<String, endpoint::Stats> {
self.endpoints_state
.lock()
.unwrap()
.endpoints
.iter()
.map(|(key, value)| (key.to_owned(), value.to_owned().into()))
.collect()
}
pub async fn info(&self) -> Info {
self.info.clone()
}
pub fn group<S: ToString>(&self, prefix: S) -> Group {
self.group_with_queue_group(prefix, self.queue_group.clone())
}
pub fn group_with_queue_group<S: ToString, Z: ToString>(
&self,
prefix: S,
queue_group: Z,
) -> Group {
Group {
subjects: self.subjects.clone(),
prefix: prefix.to_string(),
stats: self.endpoints_state.clone(),
client: self.client.clone(),
shutdown_tx: self.shutdown_tx.clone(),
queue_group: queue_group.to_string(),
}
}
pub fn endpoint_builder(&self) -> EndpointBuilder {
EndpointBuilder::new(
self.client.clone(),
self.endpoints_state.clone(),
self.shutdown_tx.clone(),
self.subjects.clone(),
self.queue_group.clone(),
)
}
pub async fn endpoint<S: ToString>(&self, subject: S) -> Result<Endpoint, Error> {
EndpointBuilder::new(
self.client.clone(),
self.endpoints_state.clone(),
self.shutdown_tx.clone(),
self.subjects.clone(),
self.queue_group.clone(),
)
.add(subject)
.await
}
}
/// A subject prefix under which endpoints are added; groups can be nested.
pub struct Group {
    // Dot-separated prefix prepended (with a `.`) to every endpoint subject in this group.
    prefix: String,
    // Per-endpoint state shared with the owning service.
    stats: Arc<Mutex<Endpoints>>,
    client: Client,
    // Shutdown broadcast of the owning service, passed on to endpoints.
    shutdown_tx: Sender<()>,
    // Subject registry shared with the owning service.
    subjects: Arc<Mutex<Vec<String>>>,
    // Queue group used by endpoints created from this group.
    queue_group: String,
}
impl Group {
pub fn group<S: ToString>(&self, prefix: S) -> Group {
self.group_with_queue_group(prefix, self.queue_group.clone())
}
pub fn group_with_queue_group<S: ToString, Z: ToString>(
&self,
prefix: S,
queue_group: Z,
) -> Group {
Group {
prefix: format!("{}.{}", self.prefix, prefix.to_string()),
stats: self.stats.clone(),
client: self.client.clone(),
shutdown_tx: self.shutdown_tx.clone(),
subjects: self.subjects.clone(),
queue_group: queue_group.to_string(),
}
}
pub async fn endpoint<S: ToString>(&self, subject: S) -> Result<Endpoint, Error> {
let endpoint = self.endpoint_builder();
endpoint.add(subject.to_string()).await
}
pub fn endpoint_builder(&self) -> EndpointBuilder {
let mut endpoint = EndpointBuilder::new(
self.client.clone(),
self.stats.clone(),
self.shutdown_tx.clone(),
self.subjects.clone(),
self.queue_group.clone(),
);
endpoint.prefix = Some(self.prefix.clone());
endpoint
}
}
/// Subscribes to the three subjects on which `verb` requests can reach this instance and
/// merges them into one fused stream. The subjects go from broadest to narrowest: all
/// services, services named `name`, and this exact instance `id`.
///
/// # Errors
///
/// Returns the first subscription error.
async fn verb_subscription(
    client: Client,
    verb: Verb,
    name: String,
    id: String,
) -> Result<stream::Fuse<SelectAll<Subscriber>>, Error> {
    let all_subject = format!("{SERVICE_API_PREFIX}.{verb}");
    let name_subject = format!("{SERVICE_API_PREFIX}.{verb}.{name}");
    let id_subject = format!("{SERVICE_API_PREFIX}.{verb}.{name}.{id}");

    let by_verb = client.subscribe(all_subject).await?;
    let by_name = client.subscribe(name_subject).await?;
    let by_id = client.subscribe(id_subject).await?;

    let merged = stream::select_all([by_verb, by_id, by_name]);
    Ok(merged.fuse())
}
/// Boxed `Send + Sync` future resolving to the outcome of waiting on the shutdown broadcast.
/// NOTE(review): not used in this file; presumably backs `Endpoint::shutdown_future` in
/// `endpoint.rs` — confirm there.
type ShutdownReceiverFuture = Pin<
    Box<dyn Future<Output = Result<(), tokio::sync::broadcast::error::RecvError>> + Send + Sync>,
>;
/// A request received by an endpoint, able to reply and record endpoint statistics.
#[derive(Debug)]
pub struct Request {
    // Start instant used to compute processing time when the reply is sent.
    issued: Instant,
    client: Client,
    /// The raw request message.
    pub message: Message,
    // Key of this request's endpoint in the shared stats map.
    endpoint: String,
    // Per-endpoint state shared with the owning service.
    stats: Arc<Mutex<Endpoints>>,
}
impl Request {
    /// Replies to the request with `response` and no extra headers.
    /// See `respond_with_headers` for details.
    pub async fn respond(&self, response: Result<Bytes, error::Error>) -> Result<(), PublishError> {
        let no_headers = HeaderMap::new();
        self.respond_with_headers(response, no_headers).await
    }

    /// Publishes `response` to the request's reply subject together with `headers`.
    ///
    /// On `Ok`, the payload is sent as-is (without headers when `headers` is empty). On `Err`,
    /// the error is recorded in the endpoint statistics. An empty payload is then sent with the
    /// `Nats-Service-Error` / `Nats-Service-Error-Code` headers, replacing any user header of
    /// the same name. In both cases the request count and processing time are updated.
    ///
    /// # Errors
    ///
    /// `PublishErrorKind::InvalidSubject` if the request has no reply subject; otherwise any
    /// error from publishing.
    ///
    /// # Panics
    ///
    /// If the stats mutex is poisoned or the endpoint has no stats entry.
    pub async fn respond_with_headers(
        &self,
        response: Result<Bytes, error::Error>,
        mut headers: HeaderMap,
    ) -> Result<(), PublishError> {
        let Some(reply) = self.message.reply.clone() else {
            return Err(PublishError::with_source(
                PublishErrorKind::InvalidSubject,
                "Request is missing reply subject to respond to",
            ));
        };

        let published = match response {
            Ok(payload) if headers.is_empty() => self.client.publish(reply, payload).await,
            Ok(payload) => {
                self.client
                    .publish_with_headers(reply, headers, payload)
                    .await
            }
            Err(err) => {
                // Single statement, so the lock guard is released before the publish below.
                self.stats
                    .lock()
                    .unwrap()
                    .endpoints
                    .entry(self.endpoint.clone())
                    .and_modify(|entry| {
                        entry.last_error = Some(err.clone());
                        entry.errors += 1;
                    })
                    .or_default();
                let code = err.code.to_string();
                headers.insert(NATS_SERVICE_ERROR, err.status.as_str());
                headers.insert(NATS_SERVICE_ERROR_CODE, code.as_str());
                self.client
                    .publish_with_headers(reply, headers, "".into())
                    .await
            }
        };

        let elapsed = self.issued.elapsed();
        let mut state = self.stats.lock().unwrap();
        let endpoint_stats = state.endpoints.get_mut(self.endpoint.as_str()).unwrap();
        endpoint_stats.requests += 1;
        endpoint_stats.processing_time += elapsed;
        let average_nanos =
            endpoint_stats.processing_time.as_nanos() / endpoint_stats.requests as u128;
        endpoint_stats.average_processing_time = Duration::from_nanos(average_nanos as u64);
        published
    }
}
/// Builder for an endpoint of a service or group; finish with `EndpointBuilder::add`.
#[derive(Debug)]
pub struct EndpointBuilder {
    client: Client,
    // Per-endpoint state shared with the owning service.
    stats: Arc<Mutex<Endpoints>>,
    // Shutdown broadcast of the owning service; the endpoint gets its own receiver.
    shutdown_tx: Sender<()>,
    // Explicit endpoint name; when `None`, derived from the subject.
    name: Option<String>,
    metadata: Option<HashMap<String, String>>,
    // Subject registry shared with the owning service.
    subjects: Arc<Mutex<Vec<String>>>,
    queue_group: String,
    // Group prefix prepended (with a `.`) to the subject, if built from a group.
    prefix: Option<String>,
}
impl EndpointBuilder {
    /// Creates a builder with no name, metadata or prefix.
    fn new(
        client: Client,
        stats: Arc<Mutex<Endpoints>>,
        shutdown_tx: Sender<()>,
        subjects: Arc<Mutex<Vec<String>>>,
        queue_group: String,
    ) -> EndpointBuilder {
        EndpointBuilder {
            client,
            stats,
            shutdown_tx,
            subjects,
            queue_group,
            name: None,
            metadata: None,
            prefix: None,
        }
    }

    /// Sets an explicit endpoint name (otherwise derived from the subject).
    pub fn name<S: ToString>(self, name: S) -> EndpointBuilder {
        EndpointBuilder {
            name: Some(name.to_string()),
            ..self
        }
    }

    /// Sets metadata reported for this endpoint.
    pub fn metadata(self, metadata: HashMap<String, String>) -> EndpointBuilder {
        EndpointBuilder {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Overrides the queue group the endpoint subscribes with.
    pub fn queue_group<S: ToString>(self, queue_group: S) -> EndpointBuilder {
        EndpointBuilder {
            queue_group: queue_group.to_string(),
            ..self
        }
    }

    /// Subscribes to `subject` (prefixed with the group prefix, if any) and registers the
    /// endpoint's statistics and subject with the service.
    ///
    /// The stats entry is keyed by the explicit name, or by the full subject when no name was
    /// set. The reported name defaults to the subject with `.` replaced by `-`. If an entry
    /// with the same key already exists, it is kept and shared.
    ///
    /// # Errors
    ///
    /// Returns the subscription error if the queue subscription fails.
    pub async fn add<S: ToString>(self, subject: S) -> Result<Endpoint, Error> {
        let raw_subject = subject.to_string();
        let subject = match &self.prefix {
            Some(prefix) => format!("{prefix}.{raw_subject}"),
            None => raw_subject,
        };
        let stats_key = match &self.name {
            Some(name) => name.clone(),
            None => subject.clone(),
        };
        let display_name = match &self.name {
            Some(name) => name.clone(),
            None => subject.replace('.', "-"),
        };

        let requests = self
            .client
            .queue_subscribe(subject.clone(), self.queue_group.clone())
            .await?;
        debug!("created service for endpoint {subject}");
        let shutdown_rx = self.shutdown_tx.subscribe();

        // No `.await` follows, so holding the guard until the end is fine.
        let mut state = self.stats.lock().unwrap();
        state
            .endpoints
            .entry(stats_key.clone())
            .or_insert(endpoint::Inner {
                name: display_name,
                subject: subject.clone(),
                metadata: self.metadata.unwrap_or_default(),
                queue_group: self.queue_group.clone(),
                ..Default::default()
            });
        self.subjects.lock().unwrap().push(subject.clone());

        Ok(Endpoint {
            requests,
            stats: Arc::clone(&self.stats),
            client: self.client.clone(),
            endpoint: stats_key,
            shutdown: Some(shutdown_rx),
            shutdown_future: None,
        })
    }
}
/// Boxed callback invoked with each endpoint's name and statistics on a STATS request; the
/// returned JSON value is stored as that endpoint's custom data.
pub struct StatsHandler(pub Box<dyn FnMut(String, endpoint::Stats) -> serde_json::Value + Send>);

impl std::fmt::Debug for StatsHandler {
    // The boxed closure has no useful debug form, so print a fixed label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Stats handler")
    }
}
// Integration tests; each one starts a local NATS server via the `nats_server` test helper.
#[cfg(test)]
mod tests {
    use super::*;

    // Nested groups join prefixes with `.` and use the explicitly supplied queue group.
    #[tokio::test]
    async fn test_group_with_queue_group() {
        let server = nats_server::run_basic_server();
        let client = crate::connect(server.client_url()).await.unwrap();
        // Build the parent group by hand so the test does not depend on `Service`.
        let group = Group {
            prefix: "test".to_string(),
            stats: Arc::new(Mutex::new(Endpoints {
                endpoints: HashMap::new(),
            })),
            client,
            shutdown_tx: tokio::sync::broadcast::channel(1).0,
            subjects: Arc::new(Mutex::new(vec![])),
            queue_group: "default".to_string(),
        };
        let new_group = group.group_with_queue_group("v1", "custom_queue");
        assert_eq!(new_group.prefix, "test.v1");
        assert_eq!(new_group.queue_group, "custom_queue");
    }

    // On an error reply, the service error headers come from the `error::Error` value and
    // replace user-supplied headers of the same name; other user headers are kept.
    #[tokio::test]
    async fn test_respond_with_headers_overrides_error_headers() {
        let server = nats_server::run_basic_server();
        let client = crate::connect(server.client_url()).await.unwrap();
        let service = client
            .service_builder()
            .start("test-service", "1.0.0")
            .await
            .unwrap();
        let subject = "test.subject";
        let mut endpoint = service.endpoint(subject).await.unwrap();
        // Handle exactly one request, replying with an error plus conflicting user headers.
        let handler = async {
            if let Some(request) = endpoint.next().await {
                let mut resp_headers = HeaderMap::new();
                resp_headers.insert("x-success", "false");
                resp_headers.insert(NATS_SERVICE_ERROR, "user-supplied-value");
                resp_headers.insert(NATS_SERVICE_ERROR_CODE, "999");
                let err = error::Error {
                    status: "internal-error".to_string(),
                    code: 500,
                };
                request
                    .respond_with_headers(Err(err), resp_headers)
                    .await
                    .expect("failed to send response");
            }
        };
        // A separate connection acts as the requester; both futures are driven together.
        let requester = crate::connect(server.client_url()).await.unwrap();
        let request_fut = async { requester.request(subject, "".into()).await.unwrap() };
        let (_, resp) = tokio::join!(handler, request_fut);
        let headers = resp.headers.expect("expected headers on reply");
        assert_eq!(headers.get("x-success").unwrap().as_str(), "false");
        assert_eq!(
            headers.get(NATS_SERVICE_ERROR).unwrap().as_str(),
            "internal-error"
        );
        assert_eq!(
            headers.get(NATS_SERVICE_ERROR_CODE).unwrap().as_str(),
            "500"
        );
    }

    // On a successful reply, user headers are published untouched, including ones named
    // like the service error headers.
    #[tokio::test]
    async fn test_respond_with_headers_preserves_headers_on_success() {
        let server = nats_server::run_basic_server();
        let client = crate::connect(server.client_url()).await.unwrap();
        let service = client
            .service_builder()
            .start("test-service", "1.0.0")
            .await
            .unwrap();
        let subject = "test.subject";
        let mut endpoint = service.endpoint(subject).await.unwrap();
        // Handle exactly one request, replying with `Ok` and a set of user headers.
        let handler = async {
            if let Some(request) = endpoint.next().await {
                let mut resp_headers = HeaderMap::new();
                resp_headers.insert("x-success", "false");
                resp_headers.insert("x-request-id", "req-123");
                resp_headers.insert(NATS_SERVICE_ERROR, "user-supplied-value");
                resp_headers.insert(NATS_SERVICE_ERROR_CODE, "999");
                request
                    .respond_with_headers(Ok("ok".into()), resp_headers)
                    .await
                    .unwrap();
            }
        };
        let requester = crate::connect(server.client_url()).await.unwrap();
        let request_fut = async { requester.request(subject, "".into()).await.unwrap() };
        let (_, resp) = tokio::join!(handler, request_fut);
        let headers = resp.headers.expect("expected headers on reply");
        assert_eq!(headers.get("x-success").unwrap().as_str(), "false");
        assert_eq!(headers.get("x-request-id").unwrap().as_str(), "req-123");
        assert_eq!(
            headers.get(NATS_SERVICE_ERROR).unwrap().as_str(),
            "user-supplied-value"
        );
        assert_eq!(
            headers.get(NATS_SERVICE_ERROR_CODE).unwrap().as_str(),
            "999"
        );
    }
}