1pub mod error;
15
16use std::{
17 collections::HashMap,
18 fmt::Display,
19 pin::Pin,
20 sync::{Arc, Mutex},
21 time::{Duration, Instant},
22};
23
24use bytes::Bytes;
25pub mod endpoint;
26use futures_util::{
27 stream::{self, SelectAll},
28 Future, StreamExt,
29};
30use regex::Regex;
31use serde::{Deserialize, Serialize};
32use std::sync::LazyLock;
33use time::serde::rfc3339;
34use time::OffsetDateTime;
35use tokio::{sync::broadcast::Sender, task::JoinHandle};
36use tracing::debug;
37
38use crate::{
39 client::PublishErrorKind, Client, Error, HeaderMap, Message, PublishError, Subscriber,
40};
41
42use self::endpoint::Endpoint;
43
44const SERVICE_API_PREFIX: &str = "$SRV";
45const DEFAULT_QUEUE_GROUP: &str = "q";
46pub const NATS_SERVICE_ERROR: &str = "Nats-Service-Error";
47pub const NATS_SERVICE_ERROR_CODE: &str = "Nats-Service-Error-Code";
48
49static SEMVER: LazyLock<Regex> = LazyLock::new(|| {
52 Regex::new(r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$")
53 .unwrap()
54});
55static NAME: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^[A-Za-z0-9\-_]+$").unwrap());
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub(crate) struct Endpoints {
61 pub(crate) endpoints: HashMap<String, endpoint::Inner>,
62}
63
64#[derive(Serialize, Deserialize)]
66pub struct PingResponse {
67 #[serde(rename = "type")]
69 pub kind: String,
70 pub name: String,
72 pub id: String,
74 pub version: String,
76 #[serde(default, deserialize_with = "endpoint::null_meta_as_default")]
78 pub metadata: HashMap<String, String>,
79}
80
81#[derive(Serialize, Deserialize)]
83pub struct Stats {
84 #[serde(rename = "type")]
86 pub kind: String,
87 pub name: String,
89 pub id: String,
91 pub version: String,
93 #[serde(with = "rfc3339")]
94 pub started: OffsetDateTime,
95 pub endpoints: Vec<endpoint::Stats>,
97}
98
99#[derive(Serialize, Deserialize, Debug, Clone)]
102pub struct Info {
103 #[serde(rename = "type")]
105 pub kind: String,
106 pub name: String,
108 pub id: String,
110 pub description: String,
112 pub version: String,
114 #[serde(default, deserialize_with = "endpoint::null_meta_as_default")]
116 pub metadata: HashMap<String, String>,
117 pub endpoints: Vec<endpoint::Info>,
119}
120
121#[derive(Serialize, Deserialize, Debug)]
123pub struct Config {
124 pub name: String,
127 pub description: Option<String>,
129 pub version: String,
131 #[serde(skip)]
133 pub stats_handler: Option<StatsHandler>,
134 pub metadata: Option<HashMap<String, String>>,
136 pub queue_group: Option<String>,
138}
139
140pub struct ServiceBuilder {
141 client: Client,
142 description: Option<String>,
143 stats_handler: Option<StatsHandler>,
144 metadata: Option<HashMap<String, String>>,
145 queue_group: Option<String>,
146}
147
148impl ServiceBuilder {
149 fn new(client: Client) -> Self {
150 Self {
151 client,
152 description: None,
153 stats_handler: None,
154 metadata: None,
155 queue_group: None,
156 }
157 }
158
159 pub fn description<S: ToString>(mut self, description: S) -> Self {
161 self.description = Some(description.to_string());
162 self
163 }
164
165 pub fn stats_handler<F>(mut self, handler: F) -> Self
167 where
168 F: FnMut(String, endpoint::Stats) -> serde_json::Value + Send + Sync + 'static,
169 {
170 self.stats_handler = Some(StatsHandler(Box::new(handler)));
171 self
172 }
173
174 pub fn metadata(mut self, metadata: HashMap<String, String>) -> Self {
176 self.metadata = Some(metadata);
177 self
178 }
179
180 pub fn queue_group<S: ToString>(mut self, queue_group: S) -> Self {
182 self.queue_group = Some(queue_group.to_string());
183 self
184 }
185
186 pub async fn start<N: ToString, V: ToString>(
188 self,
189 name: N,
190 version: V,
191 ) -> Result<Service, Error> {
192 Service::add(
193 self.client,
194 Config {
195 name: name.to_string(),
196 version: version.to_string(),
197 description: self.description,
198 stats_handler: self.stats_handler,
199 metadata: self.metadata,
200 queue_group: self.queue_group,
201 },
202 )
203 .await
204 }
205}
206
/// Verbs of the service API; each one renders as the upper-case token used in subjects.
pub enum Verb {
    Ping,
    Stats,
    Info,
    Schema,
}

impl Display for Verb {
    /// Writes the verb's subject token (`PING`, `STATS`, `INFO` or `SCHEMA`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let token = match self {
            Verb::Ping => "PING",
            Verb::Stats => "STATS",
            Verb::Info => "INFO",
            Verb::Schema => "SCHEMA",
        };
        f.write_str(token)
    }
}
225
226pub trait ServiceExt {
227 type Output: Future<Output = Result<Service, crate::Error>>;
228
229 fn add_service(&self, config: Config) -> Self::Output;
260
261 fn service_builder(&self) -> ServiceBuilder;
287}
288
289impl ServiceExt for Client {
290 type Output = Pin<Box<dyn Future<Output = Result<Service, crate::Error>> + Send>>;
291
292 fn add_service(&self, config: Config) -> Self::Output {
293 let client = self.clone();
294 Box::pin(async { Service::add(client, config).await })
295 }
296
297 fn service_builder(&self) -> ServiceBuilder {
298 ServiceBuilder::new(self.clone())
299 }
300}
301
302#[derive(Debug)]
323pub struct Service {
324 endpoints_state: Arc<Mutex<Endpoints>>,
325 info: Info,
326 client: Client,
327 handle: JoinHandle<Result<(), Error>>,
328 shutdown_tx: Sender<()>,
329 subjects: Arc<Mutex<Vec<String>>>,
330 queue_group: String,
331}
332
333impl Service {
334 async fn add(client: Client, config: Config) -> Result<Service, Error> {
335 if !SEMVER.is_match(config.version.as_str()) {
337 return Err(Box::new(std::io::Error::new(
338 std::io::ErrorKind::InvalidInput,
339 "service version is not a valid semver string",
340 )));
341 }
342 if !NAME.is_match(config.name.as_str()) {
344 return Err(Box::new(std::io::Error::new(
345 std::io::ErrorKind::InvalidInput,
346 "service name is not a valid string (only A-Z, a-z, 0-9, _, - are allowed)",
347 )));
348 }
349 let endpoints_state = Arc::new(Mutex::new(Endpoints {
350 endpoints: HashMap::new(),
351 }));
352
353 let queue_group = config
354 .queue_group
355 .unwrap_or(DEFAULT_QUEUE_GROUP.to_string());
356 let id = crate::id_generator::next();
357 let started = OffsetDateTime::now_utc();
358 let subjects = Arc::new(Mutex::new(Vec::new()));
359 let info = Info {
360 kind: "io.nats.micro.v1.info_response".to_string(),
361 name: config.name.clone(),
362 id: id.clone(),
363 description: config.description.clone().unwrap_or_default(),
364 version: config.version.clone(),
365 metadata: config.metadata.clone().unwrap_or_default(),
366 endpoints: Vec::new(),
367 };
368
369 let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);
370
371 let mut pings =
373 verb_subscription(client.clone(), Verb::Ping, config.name.clone(), id.clone()).await?;
374 let mut infos =
375 verb_subscription(client.clone(), Verb::Info, config.name.clone(), id.clone()).await?;
376 let mut stats =
377 verb_subscription(client.clone(), Verb::Stats, config.name.clone(), id.clone()).await?;
378
379 let handle = tokio::task::spawn({
381 let mut stats_callback = config.stats_handler;
382 let info = info.clone();
383 let endpoints_state = endpoints_state.clone();
384 let client = client.clone();
385 async move {
386 loop {
387 tokio::select! {
388 Some(ping) = pings.next() => {
389 let pong = serde_json::to_vec(&PingResponse{
390 kind: "io.nats.micro.v1.ping_response".to_string(),
391 name: info.name.clone(),
392 id: info.id.clone(),
393 version: info.version.clone(),
394 metadata: info.metadata.clone(),
395 })?;
396 client.publish(ping.reply.unwrap(), pong.into()).await?;
397 },
398 Some(info_request) = infos.next() => {
399 let info = info.clone();
400
401 let endpoints: Vec<endpoint::Info> = {
402 endpoints_state.lock().unwrap().endpoints.values().map(|value| {
403 endpoint::Info {
404 name: value.name.to_owned(),
405 subject: value.subject.to_owned(),
406 queue_group: value.queue_group.to_owned(),
407 metadata: value.metadata.to_owned()
408 }
409 }).collect()
410 };
411 let info = Info {
412 endpoints,
413 ..info
414 };
415 let info_json = serde_json::to_vec(&info).map(Bytes::from)?;
416 client.publish(info_request.reply.unwrap(), info_json.clone()).await?;
417 },
418 Some(stats_request) = stats.next() => {
419 if let Some(stats_callback) = stats_callback.as_mut() {
420 let mut endpoint_stats_locked = endpoints_state.lock().unwrap();
421 for (key, value) in &mut endpoint_stats_locked.endpoints {
422 let data = stats_callback.0(key.to_string(), value.clone().into());
423 value.data = Some(data);
424 }
425 }
426 let stats = serde_json::to_vec(&Stats {
427 kind: "io.nats.micro.v1.stats_response".to_string(),
428 name: info.name.clone(),
429 id: info.id.clone(),
430 version: info.version.clone(),
431 started,
432 endpoints: endpoints_state.lock().unwrap().endpoints.values().cloned().map(Into::into).collect(),
433 })?;
434 client.publish(stats_request.reply.unwrap(), stats.into()).await?;
435 },
436 else => break,
437 }
438 }
439 Ok(())
440 }
441 });
442 Ok(Service {
443 endpoints_state,
444 info,
445 client,
446 handle,
447 shutdown_tx,
448 subjects,
449 queue_group,
450 })
451 }
452 pub async fn stop(self) -> Result<(), Error> {
457 self.shutdown_tx.send(())?;
458 self.handle.abort();
459 Ok(())
460 }
461
462 pub async fn reset(&mut self) {
464 for value in self.endpoints_state.lock().unwrap().endpoints.values_mut() {
465 value.errors = 0;
466 value.processing_time = Duration::default();
467 value.requests = 0;
468 value.average_processing_time = Duration::default();
469 }
470 }
471
472 pub async fn stats(&self) -> HashMap<String, endpoint::Stats> {
474 self.endpoints_state
475 .lock()
476 .unwrap()
477 .endpoints
478 .iter()
479 .map(|(key, value)| (key.to_owned(), value.to_owned().into()))
480 .collect()
481 }
482
483 pub async fn info(&self) -> Info {
485 self.info.clone()
486 }
487
488 pub fn group<S: ToString>(&self, prefix: S) -> Group {
505 self.group_with_queue_group(prefix, self.queue_group.clone())
506 }
507
508 pub fn group_with_queue_group<S: ToString, Z: ToString>(
525 &self,
526 prefix: S,
527 queue_group: Z,
528 ) -> Group {
529 Group {
530 subjects: self.subjects.clone(),
531 prefix: prefix.to_string(),
532 stats: self.endpoints_state.clone(),
533 client: self.client.clone(),
534 shutdown_tx: self.shutdown_tx.clone(),
535 queue_group: queue_group.to_string(),
536 }
537 }
538
539 pub fn endpoint_builder(&self) -> EndpointBuilder {
559 EndpointBuilder::new(
560 self.client.clone(),
561 self.endpoints_state.clone(),
562 self.shutdown_tx.clone(),
563 self.subjects.clone(),
564 self.queue_group.clone(),
565 )
566 }
567
568 pub async fn endpoint<S: ToString>(&self, subject: S) -> Result<Endpoint, Error> {
584 EndpointBuilder::new(
585 self.client.clone(),
586 self.endpoints_state.clone(),
587 self.shutdown_tx.clone(),
588 self.subjects.clone(),
589 self.queue_group.clone(),
590 )
591 .add(subject)
592 .await
593 }
594}
595
596pub struct Group {
597 prefix: String,
598 stats: Arc<Mutex<Endpoints>>,
599 client: Client,
600 shutdown_tx: Sender<()>,
601 subjects: Arc<Mutex<Vec<String>>>,
602 queue_group: String,
603}
604
605impl Group {
606 pub fn group<S: ToString>(&self, prefix: S) -> Group {
623 self.group_with_queue_group(prefix, self.queue_group.clone())
624 }
625
626 pub fn group_with_queue_group<S: ToString, Z: ToString>(
643 &self,
644 prefix: S,
645 queue_group: Z,
646 ) -> Group {
647 Group {
648 prefix: format!("{}.{}", self.prefix, prefix.to_string()),
649 stats: self.stats.clone(),
650 client: self.client.clone(),
651 shutdown_tx: self.shutdown_tx.clone(),
652 subjects: self.subjects.clone(),
653 queue_group: queue_group.to_string(),
654 }
655 }
656
657 pub async fn endpoint<S: ToString>(&self, subject: S) -> Result<Endpoint, Error> {
674 let endpoint = self.endpoint_builder();
675 endpoint.add(subject.to_string()).await
676 }
677
678 pub fn endpoint_builder(&self) -> EndpointBuilder {
695 let mut endpoint = EndpointBuilder::new(
696 self.client.clone(),
697 self.stats.clone(),
698 self.shutdown_tx.clone(),
699 self.subjects.clone(),
700 self.queue_group.clone(),
701 );
702 endpoint.prefix = Some(self.prefix.clone());
703 endpoint
704 }
705}
706
707async fn verb_subscription(
708 client: Client,
709 verb: Verb,
710 name: String,
711 id: String,
712) -> Result<stream::Fuse<SelectAll<Subscriber>>, Error> {
713 let verb_all = client
714 .subscribe(format!("{SERVICE_API_PREFIX}.{verb}"))
715 .await?;
716 let verb_name = client
717 .subscribe(format!("{SERVICE_API_PREFIX}.{verb}.{name}"))
718 .await?;
719 let verb_id = client
720 .subscribe(format!("{SERVICE_API_PREFIX}.{verb}.{name}.{id}"))
721 .await?;
722 Ok(stream::select_all([verb_all, verb_id, verb_name]).fuse())
723}
724
725type ShutdownReceiverFuture = Pin<
726 Box<dyn Future<Output = Result<(), tokio::sync::broadcast::error::RecvError>> + Send + Sync>,
727>;
728
729#[derive(Debug)]
731pub struct Request {
732 issued: Instant,
733 client: Client,
734 pub message: Message,
735 endpoint: String,
736 stats: Arc<Mutex<Endpoints>>,
737}
738
739impl Request {
740 pub async fn respond(&self, response: Result<Bytes, error::Error>) -> Result<(), PublishError> {
759 self.respond_with_headers(response, HeaderMap::new()).await
760 }
761
762 pub async fn respond_with_headers(
791 &self,
792 response: Result<Bytes, error::Error>,
793 mut headers: HeaderMap,
794 ) -> Result<(), PublishError> {
795 let reply = match self.message.reply.clone() {
796 None => {
797 return Err(PublishError::with_source(
798 PublishErrorKind::InvalidSubject,
799 "Request is missing reply subject to respond to",
800 ))
801 }
802 Some(subject) => subject,
803 };
804 let result = match response {
805 Ok(payload) => {
806 if headers.is_empty() {
807 self.client.publish(reply, payload).await
808 } else {
809 self.client
810 .publish_with_headers(reply, headers, payload)
811 .await
812 }
813 }
814 Err(err) => {
815 self.stats
816 .lock()
817 .unwrap()
818 .endpoints
819 .entry(self.endpoint.clone())
820 .and_modify(|stats| {
821 stats.last_error = Some(err.clone());
822 stats.errors += 1;
823 })
824 .or_default();
825 headers.insert(NATS_SERVICE_ERROR, err.status.as_str());
826 headers.insert(NATS_SERVICE_ERROR_CODE, err.code.to_string().as_str());
827 self.client
828 .publish_with_headers(reply, headers, "".into())
829 .await
830 }
831 };
832 let elapsed = self.issued.elapsed();
833 let mut stats = self.stats.lock().unwrap();
834 let stats = stats.endpoints.get_mut(self.endpoint.as_str()).unwrap();
835 stats.requests += 1;
836 stats.processing_time += elapsed;
837 stats.average_processing_time = {
838 let avg_nanos = (stats.processing_time.as_nanos() / stats.requests as u128) as u64;
839 Duration::from_nanos(avg_nanos)
840 };
841 result
842 }
843}
844
845#[derive(Debug)]
846pub struct EndpointBuilder {
847 client: Client,
848 stats: Arc<Mutex<Endpoints>>,
849 shutdown_tx: Sender<()>,
850 name: Option<String>,
851 metadata: Option<HashMap<String, String>>,
852 subjects: Arc<Mutex<Vec<String>>>,
853 queue_group: String,
854 prefix: Option<String>,
855}
856
857impl EndpointBuilder {
858 fn new(
859 client: Client,
860 stats: Arc<Mutex<Endpoints>>,
861 shutdown_tx: Sender<()>,
862 subjects: Arc<Mutex<Vec<String>>>,
863 queue_group: String,
864 ) -> EndpointBuilder {
865 EndpointBuilder {
866 client,
867 stats,
868 subjects,
869 shutdown_tx,
870 name: None,
871 metadata: None,
872 queue_group,
873 prefix: None,
874 }
875 }
876
877 pub fn name<S: ToString>(mut self, name: S) -> EndpointBuilder {
879 self.name = Some(name.to_string());
880 self
881 }
882
883 pub fn metadata(mut self, metadata: HashMap<String, String>) -> EndpointBuilder {
885 self.metadata = Some(metadata);
886 self
887 }
888
889 pub fn queue_group<S: ToString>(mut self, queue_group: S) -> EndpointBuilder {
891 self.queue_group = queue_group.to_string();
892 self
893 }
894
895 pub async fn add<S: ToString>(self, subject: S) -> Result<Endpoint, Error> {
897 let mut subject = subject.to_string();
898 if let Some(prefix) = self.prefix {
899 subject = format!("{prefix}.{subject}");
900 }
901 let endpoint_name = self.name.clone().unwrap_or_else(|| subject.clone());
902 let name = self
903 .name
904 .clone()
905 .unwrap_or_else(|| subject.clone().replace('.', "-"));
906 let requests = self
907 .client
908 .queue_subscribe(subject.to_owned(), self.queue_group.to_string())
909 .await?;
910 debug!("created service for endpoint {subject}");
911
912 let shutdown_rx = self.shutdown_tx.subscribe();
913
914 let mut stats = self.stats.lock().unwrap();
915 stats
916 .endpoints
917 .entry(endpoint_name.clone())
918 .or_insert(endpoint::Inner {
919 name,
920 subject: subject.clone(),
921 metadata: self.metadata.unwrap_or_default(),
922 queue_group: self.queue_group.clone(),
923 ..Default::default()
924 });
925 self.subjects.lock().unwrap().push(subject.clone());
926 Ok(Endpoint {
927 requests,
928 stats: self.stats.clone(),
929 client: self.client.clone(),
930 endpoint: endpoint_name,
931 shutdown: Some(shutdown_rx),
932 shutdown_future: None,
933 })
934 }
935}
936
937pub struct StatsHandler(pub Box<dyn FnMut(String, endpoint::Stats) -> serde_json::Value + Send>);
938
939impl std::fmt::Debug for StatsHandler {
940 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
941 write!(f, "Stats handler")
942 }
943}
944
#[cfg(test)]
mod tests {
    use super::*;

    /// A nested group extends the parent prefix and takes the explicitly given queue group.
    #[tokio::test]
    async fn test_group_with_queue_group() {
        let server = nats_server::run_basic_server();
        let client = crate::connect(server.client_url()).await.unwrap();

        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);
        let parent = Group {
            prefix: "test".to_string(),
            stats: Arc::new(Mutex::new(Endpoints {
                endpoints: HashMap::new(),
            })),
            client,
            shutdown_tx,
            subjects: Arc::new(Mutex::new(vec![])),
            queue_group: "default".to_string(),
        };

        let child = parent.group_with_queue_group("v1", "custom_queue");

        assert_eq!(child.prefix, "test.v1");
        assert_eq!(child.queue_group, "custom_queue");
    }

    /// On an error response, the service error headers replace caller-supplied values while
    /// unrelated headers pass through.
    #[tokio::test]
    async fn test_respond_with_headers_overrides_error_headers() {
        let server = nats_server::run_basic_server();
        let client = crate::connect(server.client_url()).await.unwrap();

        let service = client
            .service_builder()
            .start("test-service", "1.0.0")
            .await
            .unwrap();

        let subject = "test.subject";
        let mut endpoint = service.endpoint(subject).await.unwrap();

        let handler = async {
            let Some(request) = endpoint.next().await else {
                return;
            };

            let mut resp_headers = HeaderMap::new();
            resp_headers.insert("x-success", "false");
            resp_headers.insert(NATS_SERVICE_ERROR, "user-supplied-value");
            resp_headers.insert(NATS_SERVICE_ERROR_CODE, "999");

            let failure = error::Error {
                status: "internal-error".to_string(),
                code: 500,
            };

            request
                .respond_with_headers(Err(failure), resp_headers)
                .await
                .expect("failed to send response");
        };

        let requester = crate::connect(server.client_url()).await.unwrap();
        let request_fut = async { requester.request(subject, "".into()).await.unwrap() };

        let (_, reply) = tokio::join!(handler, request_fut);

        let reply_headers = reply.headers.expect("expected headers on reply");
        assert_eq!(reply_headers.get("x-success").unwrap().as_str(), "false");
        assert_eq!(
            reply_headers.get(NATS_SERVICE_ERROR).unwrap().as_str(),
            "internal-error"
        );
        assert_eq!(
            reply_headers.get(NATS_SERVICE_ERROR_CODE).unwrap().as_str(),
            "500"
        );
    }

    /// On a successful response, every caller-supplied header is delivered unchanged,
    /// including ones named like the service error headers.
    #[tokio::test]
    async fn test_respond_with_headers_preserves_headers_on_success() {
        let server = nats_server::run_basic_server();
        let client = crate::connect(server.client_url()).await.unwrap();

        let service = client
            .service_builder()
            .start("test-service", "1.0.0")
            .await
            .unwrap();

        let subject = "test.subject";
        let mut endpoint = service.endpoint(subject).await.unwrap();

        let handler = async {
            let Some(request) = endpoint.next().await else {
                return;
            };

            let mut resp_headers = HeaderMap::new();
            resp_headers.insert("x-success", "false");
            resp_headers.insert("x-request-id", "req-123");
            resp_headers.insert(NATS_SERVICE_ERROR, "user-supplied-value");
            resp_headers.insert(NATS_SERVICE_ERROR_CODE, "999");

            request
                .respond_with_headers(Ok("ok".into()), resp_headers)
                .await
                .unwrap();
        };

        let requester = crate::connect(server.client_url()).await.unwrap();
        let request_fut = async { requester.request(subject, "".into()).await.unwrap() };

        let (_, reply) = tokio::join!(handler, request_fut);

        let reply_headers = reply.headers.expect("expected headers on reply");
        assert_eq!(reply_headers.get("x-success").unwrap().as_str(), "false");
        assert_eq!(
            reply_headers.get("x-request-id").unwrap().as_str(),
            "req-123"
        );
        assert_eq!(
            reply_headers.get(NATS_SERVICE_ERROR).unwrap().as_str(),
            "user-supplied-value"
        );
        assert_eq!(
            reply_headers.get(NATS_SERVICE_ERROR_CODE).unwrap().as_str(),
            "999"
        );
    }
}