1pub mod error;
15
16use std::{
17    collections::HashMap,
18    fmt::Display,
19    pin::Pin,
20    sync::{Arc, Mutex},
21    time::{Duration, Instant},
22};
23
24use bytes::Bytes;
25pub mod endpoint;
26use futures::{
27    stream::{self, SelectAll},
28    Future, StreamExt,
29};
30use once_cell::sync::Lazy;
31use regex::Regex;
32use serde::{Deserialize, Serialize};
33use time::serde::rfc3339;
34use time::OffsetDateTime;
35use tokio::{sync::broadcast::Sender, task::JoinHandle};
36use tracing::debug;
37
38use crate::{Client, Error, HeaderMap, Message, PublishError, Subscriber};
39
40use self::endpoint::Endpoint;
41
/// Subject prefix under which the service's verb subscriptions
/// (`PING`, `INFO`, `STATS`) are created.
const SERVICE_API_PREFIX: &str = "$SRV";
/// Queue group used when the service configuration does not provide one.
const DEFAULT_QUEUE_GROUP: &str = "q";
/// Header set on error responses; carries the error's `status` text.
pub const NATS_SERVICE_ERROR: &str = "Nats-Service-Error";
/// Header set on error responses; carries the error's `code`, rendered as a string.
pub const NATS_SERVICE_ERROR_CODE: &str = "Nats-Service-Error-Code";
46
// Matches a semantic-version string: a `MAJOR.MINOR.PATCH` numeric core without
// leading zeros, optionally followed by a `-pre.release` part and a
// `+build.metadata` part. Compiled lazily on first use; the `unwrap` can only
// fire if the pattern literal itself is invalid.
static SEMVER: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$")
        .unwrap()
});
// Matches a valid service name: one or more of `A-Z`, `a-z`, `0-9`, `-`, `_`.
static NAME: Lazy<Regex> = Lazy::new(|| Regex::new(r"^[A-Za-z0-9\-_]+$").unwrap());
55
/// Registry of a service's endpoints, shared between the service, its groups,
/// its endpoints and in-flight requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Endpoints {
    /// Per-endpoint state and statistics, keyed by the endpoint's name
    /// (the explicitly configured name, or the full subject when none was set).
    pub(crate) endpoints: HashMap<String, endpoint::Inner>,
}
61
/// Payload published in reply to a `PING` verb request.
#[derive(Serialize, Deserialize)]
pub struct PingResponse {
    /// Response type identifier; serialized under the key `type`.
    #[serde(rename = "type")]
    pub kind: String,
    /// Name of the responding service.
    pub name: String,
    /// Unique id of the responding service instance.
    pub id: String,
    /// Version of the responding service.
    pub version: String,
    /// Service metadata. A missing field deserializes to an empty map;
    /// `endpoint::null_meta_as_default` presumably does the same for `null` — confirm.
    #[serde(default, deserialize_with = "endpoint::null_meta_as_default")]
    pub metadata: HashMap<String, String>,
}
78
/// Payload published in reply to a `STATS` verb request.
#[derive(Serialize, Deserialize)]
pub struct Stats {
    /// Response type identifier; serialized under the key `type`.
    #[serde(rename = "type")]
    pub kind: String,
    /// Name of the responding service.
    pub name: String,
    /// Unique id of the responding service instance.
    pub id: String,
    /// Version of the responding service.
    pub version: String,
    /// UTC time at which the service was created; serialized as RFC 3339.
    #[serde(with = "rfc3339")]
    pub started: OffsetDateTime,
    /// Statistics for every endpoint registered on the service.
    pub endpoints: Vec<endpoint::Stats>,
}
96
/// Static description of a service; also the payload published in reply to an
/// `INFO` verb request.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Info {
    /// Response type identifier; serialized under the key `type`.
    #[serde(rename = "type")]
    pub kind: String,
    /// Name of the service.
    pub name: String,
    /// Unique id of the service instance.
    pub id: String,
    /// Human-readable description (empty when none was configured).
    pub description: String,
    /// Version of the service.
    pub version: String,
    /// Service metadata. A missing field deserializes to an empty map;
    /// `endpoint::null_meta_as_default` presumably does the same for `null` — confirm.
    #[serde(default, deserialize_with = "endpoint::null_meta_as_default")]
    pub metadata: HashMap<String, String>,
    /// Description of every endpoint registered on the service.
    pub endpoints: Vec<endpoint::Info>,
}
118
/// Configuration used to create a [`Service`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Config {
    /// Service name. Must consist only of `A-Z`, `a-z`, `0-9`, `_` and `-`,
    /// otherwise service creation fails.
    pub name: String,
    /// Optional human-readable description; an empty string is reported when `None`.
    pub description: Option<String>,
    /// Service version. Must be a valid semver string, otherwise service creation fails.
    pub version: String,
    /// Optional callback invoked for every endpoint on each `STATS` request; its
    /// return value is stored as that endpoint's custom data. Never (de)serialized.
    #[serde(skip)]
    pub stats_handler: Option<StatsHandler>,
    /// Optional service metadata; an empty map is reported when `None`.
    pub metadata: Option<HashMap<String, String>>,
    /// Optional queue group for the service's endpoints; `"q"` is used when `None`.
    pub queue_group: Option<String>,
}
137
/// Fluent builder that collects the optional parts of a [`Config`] and then
/// creates the [`Service`] via [`ServiceBuilder::start`].
pub struct ServiceBuilder {
    /// Client the service will be created on.
    client: Client,
    /// Optional service description.
    description: Option<String>,
    /// Optional custom statistics callback.
    stats_handler: Option<StatsHandler>,
    /// Optional service metadata.
    metadata: Option<HashMap<String, String>>,
    /// Optional queue group override.
    queue_group: Option<String>,
}
145
146impl ServiceBuilder {
147    fn new(client: Client) -> Self {
148        Self {
149            client,
150            description: None,
151            stats_handler: None,
152            metadata: None,
153            queue_group: None,
154        }
155    }
156
157    pub fn description<S: ToString>(mut self, description: S) -> Self {
159        self.description = Some(description.to_string());
160        self
161    }
162
163    pub fn stats_handler<F>(mut self, handler: F) -> Self
165    where
166        F: FnMut(String, endpoint::Stats) -> serde_json::Value + Send + Sync + 'static,
167    {
168        self.stats_handler = Some(StatsHandler(Box::new(handler)));
169        self
170    }
171
172    pub fn metadata(mut self, metadata: HashMap<String, String>) -> Self {
174        self.metadata = Some(metadata);
175        self
176    }
177
178    pub fn queue_group<S: ToString>(mut self, queue_group: S) -> Self {
180        self.queue_group = Some(queue_group.to_string());
181        self
182    }
183
184    pub async fn start<S: ToString>(self, name: S, version: S) -> Result<Service, Error> {
186        Service::add(
187            self.client,
188            Config {
189                name: name.to_string(),
190                version: version.to_string(),
191                description: self.description,
192                stats_handler: self.stats_handler,
193                metadata: self.metadata,
194                queue_group: self.queue_group,
195            },
196        )
197        .await
198    }
199}
200
/// Verbs of the service API that a service can be queried with.
pub enum Verb {
    Ping,
    Stats,
    Info,
    Schema,
}

impl Display for Verb {
    /// Writes the upper-case verb token used as a subject segment.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let token = match self {
            Verb::Ping => "PING",
            Verb::Stats => "STATS",
            Verb::Info => "INFO",
            Verb::Schema => "SCHEMA",
        };
        f.write_str(token)
    }
}
219
/// Extension trait that adds service creation to a client type.
pub trait ServiceExt {
    /// Future returned by [`ServiceExt::add_service`]; resolves to the created
    /// [`Service`] or an error.
    type Output: Future<Output = Result<Service, crate::Error>>;

    /// Creates a service from a fully specified [`Config`].
    fn add_service(&self, config: Config) -> Self::Output;

    /// Returns a [`ServiceBuilder`] for configuring and starting a service step by step.
    fn service_builder(&self) -> ServiceBuilder;
}
282
283impl ServiceExt for crate::Client {
284    type Output = Pin<Box<dyn Future<Output = Result<Service, crate::Error>> + Send>>;
285
286    fn add_service(&self, config: Config) -> Self::Output {
287        let client = self.clone();
288        Box::pin(async { Service::add(client, config).await })
289    }
290
291    fn service_builder(&self) -> ServiceBuilder {
292        ServiceBuilder::new(self.clone())
293    }
294}
295
/// A running service: owns the background task answering the `PING`, `INFO`
/// and `STATS` verbs and is the entry point for adding endpoints and groups.
#[derive(Debug)]
pub struct Service {
    /// Endpoint registry shared with endpoints, groups and requests.
    endpoints_state: Arc<Mutex<Endpoints>>,
    /// Service description built at creation time.
    info: Info,
    /// Client the service operates on.
    client: Client,
    /// Handle of the background task answering verb requests.
    handle: JoinHandle<Result<(), Error>>,
    /// Broadcast sender used to signal shutdown; endpoints subscribe to it when added.
    shutdown_tx: tokio::sync::broadcast::Sender<()>,
    /// Subjects of all endpoints added so far.
    subjects: Arc<Mutex<Vec<String>>>,
    /// Queue group applied to endpoints and groups unless overridden.
    queue_group: String,
}
326
327impl Service {
328    async fn add(client: Client, config: Config) -> Result<Service, Error> {
329        if !SEMVER.is_match(config.version.as_str()) {
331            return Err(Box::new(std::io::Error::new(
332                std::io::ErrorKind::InvalidInput,
333                "service version is not a valid semver string",
334            )));
335        }
336        if !NAME.is_match(config.name.as_str()) {
338            return Err(Box::new(std::io::Error::new(
339                std::io::ErrorKind::InvalidInput,
340                "service name is not a valid string (only A-Z, a-z, 0-9, _, - are allowed)",
341            )));
342        }
343        let endpoints_state = Arc::new(Mutex::new(Endpoints {
344            endpoints: HashMap::new(),
345        }));
346
347        let queue_group = config
348            .queue_group
349            .unwrap_or(DEFAULT_QUEUE_GROUP.to_string());
350        let id = nuid::next().to_string();
351        let started = time::OffsetDateTime::now_utc();
352        let subjects = Arc::new(Mutex::new(Vec::new()));
353        let info = Info {
354            kind: "io.nats.micro.v1.info_response".to_string(),
355            name: config.name.clone(),
356            id: id.clone(),
357            description: config.description.clone().unwrap_or_default(),
358            version: config.version.clone(),
359            metadata: config.metadata.clone().unwrap_or_default(),
360            endpoints: Vec::new(),
361        };
362
363        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);
364
365        let mut pings =
367            verb_subscription(client.clone(), Verb::Ping, config.name.clone(), id.clone()).await?;
368        let mut infos =
369            verb_subscription(client.clone(), Verb::Info, config.name.clone(), id.clone()).await?;
370        let mut stats =
371            verb_subscription(client.clone(), Verb::Stats, config.name.clone(), id.clone()).await?;
372
373        let handle = tokio::task::spawn({
375            let mut stats_callback = config.stats_handler;
376            let info = info.clone();
377            let endpoints_state = endpoints_state.clone();
378            let client = client.clone();
379            async move {
380                loop {
381                    tokio::select! {
382                        Some(ping) = pings.next() => {
383                            let pong = serde_json::to_vec(&PingResponse{
384                                kind: "io.nats.micro.v1.ping_response".to_string(),
385                                name: info.name.clone(),
386                                id: info.id.clone(),
387                                version: info.version.clone(),
388                                metadata: info.metadata.clone(),
389                            })?;
390                            client.publish(ping.reply.unwrap(), pong.into()).await?;
391                        },
392                        Some(info_request) = infos.next() => {
393                            let info = info.clone();
394
395                            let endpoints: Vec<endpoint::Info> = {
396                                endpoints_state.lock().unwrap().endpoints.values().map(|value| {
397                                    endpoint::Info {
398                                        name: value.name.to_owned(),
399                                        subject: value.subject.to_owned(),
400                                        queue_group: value.queue_group.to_owned(),
401                                        metadata: value.metadata.to_owned()
402                                    }
403                                }).collect()
404                            };
405                            let info = Info {
406                                endpoints,
407                                ..info
408                            };
409                            let info_json = serde_json::to_vec(&info).map(Bytes::from)?;
410                            client.publish(info_request.reply.unwrap(), info_json.clone()).await?;
411                        },
412                        Some(stats_request) = stats.next() => {
413                            if let Some(stats_callback) = stats_callback.as_mut() {
414                                let mut endpoint_stats_locked = endpoints_state.lock().unwrap();
415                                for (key, value) in &mut endpoint_stats_locked.endpoints {
416                                    let data = stats_callback.0(key.to_string(), value.clone().into());
417                                    value.data = Some(data);
418                                }
419                            }
420                            let stats = serde_json::to_vec(&Stats {
421                                kind: "io.nats.micro.v1.stats_response".to_string(),
422                                name: info.name.clone(),
423                                id: info.id.clone(),
424                                version: info.version.clone(),
425                                started,
426                                endpoints: endpoints_state.lock().unwrap().endpoints.values().cloned().map(Into::into).collect(),
427                            })?;
428                            client.publish(stats_request.reply.unwrap(), stats.into()).await?;
429                        },
430                        else => break,
431                    }
432                }
433                Ok(())
434            }
435        });
436        Ok(Service {
437            endpoints_state,
438            info,
439            client,
440            handle,
441            shutdown_tx,
442            subjects,
443            queue_group,
444        })
445    }
446    pub async fn stop(self) -> Result<(), Error> {
451        self.shutdown_tx.send(())?;
452        self.handle.abort();
453        Ok(())
454    }
455
456    pub async fn reset(&mut self) {
458        for value in self.endpoints_state.lock().unwrap().endpoints.values_mut() {
459            value.errors = 0;
460            value.processing_time = Duration::default();
461            value.requests = 0;
462            value.average_processing_time = Duration::default();
463        }
464    }
465
466    pub async fn stats(&self) -> HashMap<String, endpoint::Stats> {
468        self.endpoints_state
469            .lock()
470            .unwrap()
471            .endpoints
472            .iter()
473            .map(|(key, value)| (key.to_owned(), value.to_owned().into()))
474            .collect()
475    }
476
477    pub async fn info(&self) -> Info {
479        self.info.clone()
480    }
481
482    pub fn group<S: ToString>(&self, prefix: S) -> Group {
499        self.group_with_queue_group(prefix, self.queue_group.clone())
500    }
501
502    pub fn group_with_queue_group<S: ToString, Z: ToString>(
519        &self,
520        prefix: S,
521        queue_group: Z,
522    ) -> Group {
523        Group {
524            subjects: self.subjects.clone(),
525            prefix: prefix.to_string(),
526            stats: self.endpoints_state.clone(),
527            client: self.client.clone(),
528            shutdown_tx: self.shutdown_tx.clone(),
529            queue_group: queue_group.to_string(),
530        }
531    }
532
533    pub fn endpoint_builder(&self) -> EndpointBuilder {
553        EndpointBuilder::new(
554            self.client.clone(),
555            self.endpoints_state.clone(),
556            self.shutdown_tx.clone(),
557            self.subjects.clone(),
558            self.queue_group.clone(),
559        )
560    }
561
562    pub async fn endpoint<S: ToString>(&self, subject: S) -> Result<Endpoint, Error> {
578        EndpointBuilder::new(
579            self.client.clone(),
580            self.endpoints_state.clone(),
581            self.shutdown_tx.clone(),
582            self.subjects.clone(),
583            self.queue_group.clone(),
584        )
585        .add(subject)
586        .await
587    }
588}
589
/// A subject-prefix scope for endpoints: endpoints added through a group are
/// subscribed on `<prefix>.<subject>`.
pub struct Group {
    /// Subject prefix prepended (dot-separated) to endpoint subjects.
    prefix: String,
    /// Endpoint registry shared with the owning service.
    stats: Arc<Mutex<Endpoints>>,
    /// Client used to create endpoint subscriptions.
    client: Client,
    /// Shutdown broadcast sender shared with the owning service.
    shutdown_tx: tokio::sync::broadcast::Sender<()>,
    /// Subjects of all endpoints added so far, shared with the owning service.
    subjects: Arc<Mutex<Vec<String>>>,
    /// Queue group applied to endpoints created through this group.
    queue_group: String,
}
598
599impl Group {
600    pub fn group<S: ToString>(&self, prefix: S) -> Group {
617        self.group_with_queue_group(prefix, self.queue_group.clone())
618    }
619
620    pub fn group_with_queue_group<S: ToString, Z: ToString>(
637        &self,
638        prefix: S,
639        queue_group: Z,
640    ) -> Group {
641        Group {
642            prefix: format!("{}.{}", self.prefix, prefix.to_string()),
643            stats: self.stats.clone(),
644            client: self.client.clone(),
645            shutdown_tx: self.shutdown_tx.clone(),
646            subjects: self.subjects.clone(),
647            queue_group: queue_group.to_string(),
648        }
649    }
650
651    pub async fn endpoint<S: ToString>(&self, subject: S) -> Result<Endpoint, Error> {
668        let mut endpoint = EndpointBuilder::new(
669            self.client.clone(),
670            self.stats.clone(),
671            self.shutdown_tx.clone(),
672            self.subjects.clone(),
673            self.queue_group.clone(),
674        );
675        endpoint.prefix = Some(self.prefix.clone());
676        endpoint.add(subject.to_string()).await
677    }
678
679    pub fn endpoint_builder(&self) -> EndpointBuilder {
696        let mut endpoint = EndpointBuilder::new(
697            self.client.clone(),
698            self.stats.clone(),
699            self.shutdown_tx.clone(),
700            self.subjects.clone(),
701            self.queue_group.clone(),
702        );
703        endpoint.prefix = Some(self.prefix.clone());
704        endpoint
705    }
706}
707
708async fn verb_subscription(
709    client: Client,
710    verb: Verb,
711    name: String,
712    id: String,
713) -> Result<futures::stream::Fuse<SelectAll<Subscriber>>, Error> {
714    let verb_all = client
715        .subscribe(format!("{SERVICE_API_PREFIX}.{verb}"))
716        .await?;
717    let verb_name = client
718        .subscribe(format!("{SERVICE_API_PREFIX}.{verb}.{name}"))
719        .await?;
720    let verb_id = client
721        .subscribe(format!("{SERVICE_API_PREFIX}.{verb}.{name}.{id}"))
722        .await?;
723    Ok(stream::select_all([verb_all, verb_id, verb_name]).fuse())
724}
725
/// Boxed, pinned future that resolves to the outcome of receiving from the
/// shutdown broadcast channel.
// NOTE(review): not referenced in the visible part of this module; presumably
// used by the `endpoint` submodule — confirm.
type ShutdownReceiverFuture = Pin<
    Box<dyn Future<Output = Result<(), tokio::sync::broadcast::error::RecvError>> + Send + Sync>,
>;
729
/// A single request received by an endpoint, together with what is needed to
/// answer it and account for it in the endpoint's statistics.
#[derive(Debug)]
pub struct Request {
    /// Start of the processing-time measurement for this request.
    issued: Instant,
    /// Client used to publish the response.
    client: Client,
    /// The received message; its `reply` subject is where the response is sent.
    pub message: Message,
    /// Name of the endpoint that received the request (key into the registry).
    endpoint: String,
    /// Endpoint registry in which statistics are recorded.
    stats: Arc<Mutex<Endpoints>>,
}
739
740impl Request {
741    pub async fn respond(&self, response: Result<Bytes, error::Error>) -> Result<(), PublishError> {
760        let reply = self.message.reply.clone().unwrap();
761        let result = match response {
762            Ok(payload) => self.client.publish(reply, payload).await,
763            Err(err) => {
764                self.stats
765                    .lock()
766                    .unwrap()
767                    .endpoints
768                    .entry(self.endpoint.clone())
769                    .and_modify(|stats| {
770                        stats.last_error = Some(err.clone());
771                        stats.errors += 1;
772                    })
773                    .or_default();
774                let mut headers = HeaderMap::new();
775                headers.insert(NATS_SERVICE_ERROR, err.status.as_str());
776                headers.insert(NATS_SERVICE_ERROR_CODE, err.code.to_string().as_str());
777                self.client
778                    .publish_with_headers(reply, headers, "".into())
779                    .await
780            }
781        };
782        let elapsed = self.issued.elapsed();
783        let mut stats = self.stats.lock().unwrap();
784        let stats = stats.endpoints.get_mut(self.endpoint.as_str()).unwrap();
785        stats.requests += 1;
786        stats.processing_time += elapsed;
787        stats.average_processing_time = {
788            let avg_nanos = (stats.processing_time.as_nanos() / stats.requests as u128) as u64;
789            Duration::from_nanos(avg_nanos)
790        };
791        result
792    }
793}
794
/// Builder for a service endpoint; finished with [`EndpointBuilder::add`].
#[derive(Debug)]
pub struct EndpointBuilder {
    /// Client used to create the endpoint's queue subscription.
    client: Client,
    /// Endpoint registry in which the new endpoint is recorded.
    stats: Arc<Mutex<Endpoints>>,
    /// Shutdown broadcast sender; the endpoint subscribes to it when added.
    shutdown_tx: Sender<()>,
    /// Optional explicit endpoint name; derived from the subject when `None`.
    name: Option<String>,
    /// Optional endpoint metadata; an empty map is used when `None`.
    metadata: Option<HashMap<String, String>>,
    /// List to which the endpoint's subject is appended when added.
    subjects: Arc<Mutex<Vec<String>>>,
    /// Queue group for the endpoint's subscription.
    queue_group: String,
    /// Optional subject prefix (set when the builder comes from a group).
    prefix: Option<String>,
}
806
807impl EndpointBuilder {
808    fn new(
809        client: Client,
810        stats: Arc<Mutex<Endpoints>>,
811        shutdown_tx: Sender<()>,
812        subjects: Arc<Mutex<Vec<String>>>,
813        queue_group: String,
814    ) -> EndpointBuilder {
815        EndpointBuilder {
816            client,
817            stats,
818            subjects,
819            shutdown_tx,
820            name: None,
821            metadata: None,
822            queue_group,
823            prefix: None,
824        }
825    }
826
827    pub fn name<S: ToString>(mut self, name: S) -> EndpointBuilder {
829        self.name = Some(name.to_string());
830        self
831    }
832
833    pub fn metadata(mut self, metadata: HashMap<String, String>) -> EndpointBuilder {
835        self.metadata = Some(metadata);
836        self
837    }
838
839    pub fn queue_group<S: ToString>(mut self, queue_group: S) -> EndpointBuilder {
841        self.queue_group = queue_group.to_string();
842        self
843    }
844
845    pub async fn add<S: ToString>(self, subject: S) -> Result<Endpoint, Error> {
847        let mut subject = subject.to_string();
848        if let Some(prefix) = self.prefix {
849            subject = format!("{}.{}", prefix, subject);
850        }
851        let endpoint_name = self.name.clone().unwrap_or_else(|| subject.clone());
852        let name = self
853            .name
854            .clone()
855            .unwrap_or_else(|| subject.clone().replace('.', "-"));
856        let requests = self
857            .client
858            .queue_subscribe(subject.to_owned(), self.queue_group.to_string())
859            .await?;
860        debug!("created service for endpoint {subject}");
861
862        let shutdown_rx = self.shutdown_tx.subscribe();
863
864        let mut stats = self.stats.lock().unwrap();
865        stats
866            .endpoints
867            .entry(endpoint_name.clone())
868            .or_insert(endpoint::Inner {
869                name,
870                subject: subject.clone(),
871                metadata: self.metadata.unwrap_or_default(),
872                queue_group: self.queue_group.clone(),
873                ..Default::default()
874            });
875        self.subjects.lock().unwrap().push(subject.clone());
876        Ok(Endpoint {
877            requests,
878            stats: self.stats.clone(),
879            client: self.client.clone(),
880            endpoint: endpoint_name,
881            shutdown: Some(shutdown_rx),
882            shutdown_future: None,
883        })
884    }
885}
886
/// Boxed callback producing custom statistics data for an endpoint: it receives
/// the endpoint name and its current statistics and returns a JSON value.
pub struct StatsHandler(pub Box<dyn FnMut(String, endpoint::Stats) -> serde_json::Value + Send>);
888
889impl std::fmt::Debug for StatsHandler {
890    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
891        write!(f, "Stats handler")
892    }
893}
894
#[cfg(test)]
mod tests {
    use super::*;

    /// A nested group appends its prefix to the parent's and takes the queue
    /// group passed to `group_with_queue_group` instead of the parent's.
    #[tokio::test]
    async fn test_group_with_queue_group() {
        let server = nats_server::run_basic_server();
        let client = crate::connect(server.client_url()).await.unwrap();

        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);
        let registry = Endpoints {
            endpoints: HashMap::new(),
        };
        let parent = Group {
            prefix: "test".to_string(),
            stats: Arc::new(Mutex::new(registry)),
            client,
            shutdown_tx,
            subjects: Arc::new(Mutex::new(Vec::new())),
            queue_group: "default".to_string(),
        };

        let child = parent.group_with_queue_group("v1", "custom_queue");

        assert_eq!(child.prefix, "test.v1");
        assert_eq!(child.queue_group, "custom_queue");
    }
}