Skip to main content

async_nats/service/
mod.rs

1// Copyright 2020-2023 The NATS Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14pub mod error;
15
16use std::{
17    collections::HashMap,
18    fmt::Display,
19    pin::Pin,
20    sync::{Arc, Mutex},
21    time::{Duration, Instant},
22};
23
24use bytes::Bytes;
25pub mod endpoint;
26use futures_util::{
27    stream::{self, SelectAll},
28    Future, StreamExt,
29};
30use regex::Regex;
31use serde::{Deserialize, Serialize};
32use std::sync::LazyLock;
33use time::serde::rfc3339;
34use time::OffsetDateTime;
35use tokio::{sync::broadcast::Sender, task::JoinHandle};
36use tracing::debug;
37
38use crate::{
39    client::PublishErrorKind, Client, Error, HeaderMap, Message, PublishError, Subscriber,
40};
41
42use self::endpoint::Endpoint;
43
/// Root prefix of every service-discovery subject (e.g. `$SRV.PING`).
const SERVICE_API_PREFIX: &str = "$SRV";
/// Queue group used when the service configuration does not provide one.
const DEFAULT_QUEUE_GROUP: &str = "q";
/// Response header set to the error status text when a request fails.
pub const NATS_SERVICE_ERROR: &str = "Nats-Service-Error";
/// Response header set to the error code when a request fails.
pub const NATS_SERVICE_ERROR_CODE: &str = "Nats-Service-Error-Code";
48
49// uses recommended semver validation expression from
50// https://semver.org/#is-there-a-suggested-regular-expression-regex-to-check-a-semver-string
51static SEMVER: LazyLock<Regex> = LazyLock::new(|| {
52    Regex::new(r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$")
53        .unwrap()
54});
55// From ADR-33: Name can only have A-Z, a-z, 0-9, dash, underscore.
56static NAME: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^[A-Za-z0-9\-_]+$").unwrap());
57
58/// Represents state for all endpoints.
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub(crate) struct Endpoints {
61    pub(crate) endpoints: HashMap<String, endpoint::Inner>,
62}
63
64/// Response for `PING` requests.
65#[derive(Serialize, Deserialize)]
66pub struct PingResponse {
67    /// Response type.
68    #[serde(rename = "type")]
69    pub kind: String,
70    /// Service name.
71    pub name: String,
72    /// Service id.
73    pub id: String,
74    /// Service version.
75    pub version: String,
76    /// Additional metadata
77    #[serde(default, deserialize_with = "endpoint::null_meta_as_default")]
78    pub metadata: HashMap<String, String>,
79}
80
81/// Response for `STATS` requests.
82#[derive(Serialize, Deserialize)]
83pub struct Stats {
84    /// Response type.
85    #[serde(rename = "type")]
86    pub kind: String,
87    /// Service name.
88    pub name: String,
89    /// Service id.
90    pub id: String,
91    // Service version.
92    pub version: String,
93    #[serde(with = "rfc3339")]
94    pub started: OffsetDateTime,
95    /// Statistics of all endpoints.
96    pub endpoints: Vec<endpoint::Stats>,
97}
98
99/// Information about service instance.
100/// Service name.
101#[derive(Serialize, Deserialize, Debug, Clone)]
102pub struct Info {
103    /// Response type.
104    #[serde(rename = "type")]
105    pub kind: String,
106    /// Service name.
107    pub name: String,
108    /// Service id.
109    pub id: String,
110    /// Service description.
111    pub description: String,
112    /// Service version.
113    pub version: String,
114    /// Additional metadata
115    #[serde(default, deserialize_with = "endpoint::null_meta_as_default")]
116    pub metadata: HashMap<String, String>,
117    /// Info about all service endpoints.
118    pub endpoints: Vec<endpoint::Info>,
119}
120
121/// Configuration of the [Service].
122#[derive(Serialize, Deserialize, Debug)]
123pub struct Config {
124    /// Really the kind of the service. Shared by all the services that have the same name.
125    /// This name can only have A-Z, a-z, 0-9, dash, underscore
126    pub name: String,
127    /// a human-readable description about the service
128    pub description: Option<String>,
129    /// A SemVer valid service version.
130    pub version: String,
131    /// Custom handler for providing the `EndpointStats.data` value.
132    #[serde(skip)]
133    pub stats_handler: Option<StatsHandler>,
134    /// Additional service metadata
135    pub metadata: Option<HashMap<String, String>>,
136    /// Custom queue group config
137    pub queue_group: Option<String>,
138}
139
140pub struct ServiceBuilder {
141    client: Client,
142    description: Option<String>,
143    stats_handler: Option<StatsHandler>,
144    metadata: Option<HashMap<String, String>>,
145    queue_group: Option<String>,
146}
147
148impl ServiceBuilder {
149    fn new(client: Client) -> Self {
150        Self {
151            client,
152            description: None,
153            stats_handler: None,
154            metadata: None,
155            queue_group: None,
156        }
157    }
158
159    /// Description for the service.
160    pub fn description<S: ToString>(mut self, description: S) -> Self {
161        self.description = Some(description.to_string());
162        self
163    }
164
165    /// Handler for custom service statistics.
166    pub fn stats_handler<F>(mut self, handler: F) -> Self
167    where
168        F: FnMut(String, endpoint::Stats) -> serde_json::Value + Send + Sync + 'static,
169    {
170        self.stats_handler = Some(StatsHandler(Box::new(handler)));
171        self
172    }
173
174    /// Additional service metadata.
175    pub fn metadata(mut self, metadata: HashMap<String, String>) -> Self {
176        self.metadata = Some(metadata);
177        self
178    }
179
180    /// Custom queue group. Default is `q`.
181    pub fn queue_group<S: ToString>(mut self, queue_group: S) -> Self {
182        self.queue_group = Some(queue_group.to_string());
183        self
184    }
185
186    /// Starts the service with configured options.
187    pub async fn start<N: ToString, V: ToString>(
188        self,
189        name: N,
190        version: V,
191    ) -> Result<Service, Error> {
192        Service::add(
193            self.client,
194            Config {
195                name: name.to_string(),
196                version: version.to_string(),
197                description: self.description,
198                stats_handler: self.stats_handler,
199                metadata: self.metadata,
200                queue_group: self.queue_group,
201            },
202        )
203        .await
204    }
205}
206
/// Verbs that can be used to acquire information from the services.
///
/// The [Display] form is the upper-case token used in discovery subjects,
/// e.g. `$SRV.PING`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Verb {
    /// `PING`, answered with a [PingResponse].
    Ping,
    /// `STATS`, answered with [Stats].
    Stats,
    /// `INFO`, answered with [Info].
    Info,
    /// `SCHEMA`.
    Schema,
}

impl Display for Verb {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let token = match self {
            Verb::Ping => "PING",
            Verb::Stats => "STATS",
            Verb::Info => "INFO",
            Verb::Schema => "SCHEMA",
        };
        f.write_str(token)
    }
}
225
226pub trait ServiceExt {
227    type Output: Future<Output = Result<Service, crate::Error>>;
228
229    /// Adds a Service instance.
230    ///
231    /// # Examples
232    ///
233    /// ```no_run
234    /// # #[tokio::main]
235    /// # async fn main() -> Result<(), async_nats::Error> {
236    /// use async_nats::service::ServiceExt;
237    /// use futures_util::StreamExt;
238    /// let client = async_nats::connect("demo.nats.io").await?;
239    /// let mut service = client
240    ///     .add_service(async_nats::service::Config {
241    ///         name: "generator".to_string(),
242    ///         version: "1.0.0".to_string(),
243    ///         description: None,
244    ///         stats_handler: None,
245    ///         metadata: None,
246    ///         queue_group: None,
247    ///     })
248    ///     .await?;
249    ///
250    /// let mut endpoint = service.endpoint("get").await?;
251    ///
252    /// if let Some(request) = endpoint.next().await {
253    ///     request.respond(Ok("hello".into())).await?;
254    /// }
255    ///
256    /// # Ok(())
257    /// # }
258    /// ```
259    fn add_service(&self, config: Config) -> Self::Output;
260
261    /// Returns Service instance builder.
262    ///
263    /// # Examples
264    ///
265    /// ```no_run
266    /// # #[tokio::main]
267    /// # async fn main() -> Result<(), async_nats::Error> {
268    /// use async_nats::service::ServiceExt;
269    /// use futures_util::StreamExt;
270    /// let client = async_nats::connect("demo.nats.io").await?;
271    /// let mut service = client
272    ///     .service_builder()
273    ///     .description("some service")
274    ///     .stats_handler(|endpoint, stats| serde_json::json!({ "endpoint": endpoint }))
275    ///     .start("products", "1.0.0")
276    ///     .await?;
277    ///
278    /// let mut endpoint = service.endpoint("get").await?;
279    ///
280    /// if let Some(request) = endpoint.next().await {
281    ///     request.respond(Ok("hello".into())).await?;
282    /// }
283    /// # Ok(())
284    /// # }
285    /// ```
286    fn service_builder(&self) -> ServiceBuilder;
287}
288
289impl ServiceExt for Client {
290    type Output = Pin<Box<dyn Future<Output = Result<Service, crate::Error>> + Send>>;
291
292    fn add_service(&self, config: Config) -> Self::Output {
293        let client = self.clone();
294        Box::pin(async { Service::add(client, config).await })
295    }
296
297    fn service_builder(&self) -> ServiceBuilder {
298        ServiceBuilder::new(self.clone())
299    }
300}
301
302/// Service instance.
303///
304/// # Examples
305///
306/// ```no_run
307/// # #[tokio::main]
308/// # async fn main() -> Result<(), async_nats::Error> {
309/// use async_nats::service::ServiceExt;
310/// use futures_util::StreamExt;
311/// let client = async_nats::connect("demo.nats.io").await?;
312/// let mut service = client.service_builder().start("generator", "1.0.0").await?;
313/// let mut endpoint = service.endpoint("get").await?;
314///
315/// if let Some(request) = endpoint.next().await {
316///     request.respond(Ok("hello".into())).await?;
317/// }
318///
319/// # Ok(())
320/// # }
321/// ```
322#[derive(Debug)]
323pub struct Service {
324    endpoints_state: Arc<Mutex<Endpoints>>,
325    info: Info,
326    client: Client,
327    handle: JoinHandle<Result<(), Error>>,
328    shutdown_tx: Sender<()>,
329    subjects: Arc<Mutex<Vec<String>>>,
330    queue_group: String,
331}
332
333impl Service {
334    async fn add(client: Client, config: Config) -> Result<Service, Error> {
335        // validate service version semver string.
336        if !SEMVER.is_match(config.version.as_str()) {
337            return Err(Box::new(std::io::Error::new(
338                std::io::ErrorKind::InvalidInput,
339                "service version is not a valid semver string",
340            )));
341        }
342        // validate service name.
343        if !NAME.is_match(config.name.as_str()) {
344            return Err(Box::new(std::io::Error::new(
345                std::io::ErrorKind::InvalidInput,
346                "service name is not a valid string (only A-Z, a-z, 0-9, _, - are allowed)",
347            )));
348        }
349        let endpoints_state = Arc::new(Mutex::new(Endpoints {
350            endpoints: HashMap::new(),
351        }));
352
353        let queue_group = config
354            .queue_group
355            .unwrap_or(DEFAULT_QUEUE_GROUP.to_string());
356        let id = crate::id_generator::next();
357        let started = OffsetDateTime::now_utc();
358        let subjects = Arc::new(Mutex::new(Vec::new()));
359        let info = Info {
360            kind: "io.nats.micro.v1.info_response".to_string(),
361            name: config.name.clone(),
362            id: id.clone(),
363            description: config.description.clone().unwrap_or_default(),
364            version: config.version.clone(),
365            metadata: config.metadata.clone().unwrap_or_default(),
366            endpoints: Vec::new(),
367        };
368
369        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);
370
371        // create subscriptions for all verbs.
372        let mut pings =
373            verb_subscription(client.clone(), Verb::Ping, config.name.clone(), id.clone()).await?;
374        let mut infos =
375            verb_subscription(client.clone(), Verb::Info, config.name.clone(), id.clone()).await?;
376        let mut stats =
377            verb_subscription(client.clone(), Verb::Stats, config.name.clone(), id.clone()).await?;
378
379        // Start a task for handling verbs subscriptions.
380        let handle = tokio::task::spawn({
381            let mut stats_callback = config.stats_handler;
382            let info = info.clone();
383            let endpoints_state = endpoints_state.clone();
384            let client = client.clone();
385            async move {
386                loop {
387                    tokio::select! {
388                        Some(ping) = pings.next() => {
389                            let pong = serde_json::to_vec(&PingResponse{
390                                kind: "io.nats.micro.v1.ping_response".to_string(),
391                                name: info.name.clone(),
392                                id: info.id.clone(),
393                                version: info.version.clone(),
394                                metadata: info.metadata.clone(),
395                            })?;
396                            client.publish(ping.reply.unwrap(), pong.into()).await?;
397                        },
398                        Some(info_request) = infos.next() => {
399                            let info = info.clone();
400
401                            let endpoints: Vec<endpoint::Info> = {
402                                endpoints_state.lock().unwrap().endpoints.values().map(|value| {
403                                    endpoint::Info {
404                                        name: value.name.to_owned(),
405                                        subject: value.subject.to_owned(),
406                                        queue_group: value.queue_group.to_owned(),
407                                        metadata: value.metadata.to_owned()
408                                    }
409                                }).collect()
410                            };
411                            let info = Info {
412                                endpoints,
413                                ..info
414                            };
415                            let info_json = serde_json::to_vec(&info).map(Bytes::from)?;
416                            client.publish(info_request.reply.unwrap(), info_json.clone()).await?;
417                        },
418                        Some(stats_request) = stats.next() => {
419                            if let Some(stats_callback) = stats_callback.as_mut() {
420                                let mut endpoint_stats_locked = endpoints_state.lock().unwrap();
421                                for (key, value) in &mut endpoint_stats_locked.endpoints {
422                                    let data = stats_callback.0(key.to_string(), value.clone().into());
423                                    value.data = Some(data);
424                                }
425                            }
426                            let stats = serde_json::to_vec(&Stats {
427                                kind: "io.nats.micro.v1.stats_response".to_string(),
428                                name: info.name.clone(),
429                                id: info.id.clone(),
430                                version: info.version.clone(),
431                                started,
432                                endpoints: endpoints_state.lock().unwrap().endpoints.values().cloned().map(Into::into).collect(),
433                            })?;
434                            client.publish(stats_request.reply.unwrap(), stats.into()).await?;
435                        },
436                        else => break,
437                    }
438                }
439                Ok(())
440            }
441        });
442        Ok(Service {
443            endpoints_state,
444            info,
445            client,
446            handle,
447            shutdown_tx,
448            subjects,
449            queue_group,
450        })
451    }
452    /// Stops this instance of the [Service].
453    /// If there are more instances of [Services][Service] with the same name, the [Service] will
454    /// be scaled down by one instance. If it was the only running instance, it will effectively
455    /// remove the service entirely.
456    pub async fn stop(self) -> Result<(), Error> {
457        self.shutdown_tx.send(())?;
458        self.handle.abort();
459        Ok(())
460    }
461
462    /// Resets [Stats] of the [Service] instance.
463    pub async fn reset(&mut self) {
464        for value in self.endpoints_state.lock().unwrap().endpoints.values_mut() {
465            value.errors = 0;
466            value.processing_time = Duration::default();
467            value.requests = 0;
468            value.average_processing_time = Duration::default();
469        }
470    }
471
472    /// Returns [Stats] for this service instance.
473    pub async fn stats(&self) -> HashMap<String, endpoint::Stats> {
474        self.endpoints_state
475            .lock()
476            .unwrap()
477            .endpoints
478            .iter()
479            .map(|(key, value)| (key.to_owned(), value.to_owned().into()))
480            .collect()
481    }
482
483    /// Returns [Info] for this service instance.
484    pub async fn info(&self) -> Info {
485        self.info.clone()
486    }
487
488    /// Creates a group for endpoints under common prefix.
489    ///
490    /// # Examples
491    ///
492    /// ```no_run
493    /// # #[tokio::main]
494    /// # async fn main() -> Result<(), async_nats::Error> {
495    /// use async_nats::service::ServiceExt;
496    /// let client = async_nats::connect("demo.nats.io").await?;
497    /// let mut service = client.service_builder().start("service", "1.0.0").await?;
498    ///
499    /// let v1 = service.group("v1");
500    /// let products = v1.endpoint("products").await?;
501    /// # Ok(())
502    /// # }
503    /// ```
504    pub fn group<S: ToString>(&self, prefix: S) -> Group {
505        self.group_with_queue_group(prefix, self.queue_group.clone())
506    }
507
508    /// Creates a group for endpoints under common prefix with custom queue group.
509    ///
510    /// # Examples
511    ///
512    /// ```no_run
513    /// # #[tokio::main]
514    /// # async fn main() -> Result<(), async_nats::Error> {
515    /// use async_nats::service::ServiceExt;
516    /// let client = async_nats::connect("demo.nats.io").await?;
517    /// let mut service = client.service_builder().start("service", "1.0.0").await?;
518    ///
519    /// let v1 = service.group("v1");
520    /// let products = v1.endpoint("products").await?;
521    /// # Ok(())
522    /// # }
523    /// ```
524    pub fn group_with_queue_group<S: ToString, Z: ToString>(
525        &self,
526        prefix: S,
527        queue_group: Z,
528    ) -> Group {
529        Group {
530            subjects: self.subjects.clone(),
531            prefix: prefix.to_string(),
532            stats: self.endpoints_state.clone(),
533            client: self.client.clone(),
534            shutdown_tx: self.shutdown_tx.clone(),
535            queue_group: queue_group.to_string(),
536        }
537    }
538
539    /// Builder for customized [Endpoint] creation.
540    ///
541    /// # Examples
542    ///
543    /// ```no_run
544    /// # #[tokio::main]
545    /// # async fn main() -> Result<(), async_nats::Error> {
546    /// use async_nats::service::ServiceExt;
547    /// let client = async_nats::connect("demo.nats.io").await?;
548    /// let mut service = client.service_builder().start("service", "1.0.0").await?;
549    ///
550    /// let products = service
551    ///     .endpoint_builder()
552    ///     .name("api")
553    ///     .add("products")
554    ///     .await?;
555    /// # Ok(())
556    /// # }
557    /// ```
558    pub fn endpoint_builder(&self) -> EndpointBuilder {
559        EndpointBuilder::new(
560            self.client.clone(),
561            self.endpoints_state.clone(),
562            self.shutdown_tx.clone(),
563            self.subjects.clone(),
564            self.queue_group.clone(),
565        )
566    }
567
568    /// Adds a new endpoint to the [Service].
569    ///
570    /// # Examples
571    ///
572    /// ```no_run
573    /// # #[tokio::main]
574    /// # async fn main() -> Result<(), async_nats::Error> {
575    /// use async_nats::service::ServiceExt;
576    /// let client = async_nats::connect("demo.nats.io").await?;
577    /// let mut service = client.service_builder().start("service", "1.0.0").await?;
578    ///
579    /// let products = service.endpoint("products").await?;
580    /// # Ok(())
581    /// # }
582    /// ```
583    pub async fn endpoint<S: ToString>(&self, subject: S) -> Result<Endpoint, Error> {
584        EndpointBuilder::new(
585            self.client.clone(),
586            self.endpoints_state.clone(),
587            self.shutdown_tx.clone(),
588            self.subjects.clone(),
589            self.queue_group.clone(),
590        )
591        .add(subject)
592        .await
593    }
594}
595
596pub struct Group {
597    prefix: String,
598    stats: Arc<Mutex<Endpoints>>,
599    client: Client,
600    shutdown_tx: Sender<()>,
601    subjects: Arc<Mutex<Vec<String>>>,
602    queue_group: String,
603}
604
605impl Group {
606    /// Creates a group for [Endpoints][Endpoint] under common prefix.
607    ///
608    /// # Examples
609    ///
610    /// ```no_run
611    /// # #[tokio::main]
612    /// # async fn main() -> Result<(), async_nats::Error> {
613    /// use async_nats::service::ServiceExt;
614    /// let client = async_nats::connect("demo.nats.io").await?;
615    /// let mut service = client.service_builder().start("service", "1.0.0").await?;
616    ///
617    /// let v1 = service.group("v1");
618    /// let products = v1.endpoint("products").await?;
619    /// # Ok(())
620    /// # }
621    /// ```
622    pub fn group<S: ToString>(&self, prefix: S) -> Group {
623        self.group_with_queue_group(prefix, self.queue_group.clone())
624    }
625
626    /// Creates a group for [Endpoints][Endpoint] under common prefix with custom queue group.
627    ///
628    /// # Examples
629    ///
630    /// ```no_run
631    /// # #[tokio::main]
632    /// # async fn main() -> Result<(), async_nats::Error> {
633    /// use async_nats::service::ServiceExt;
634    /// let client = async_nats::connect("demo.nats.io").await?;
635    /// let mut service = client.service_builder().start("service", "1.0.0").await?;
636    ///
637    /// let v1 = service.group("v1");
638    /// let products = v1.endpoint("products").await?;
639    /// # Ok(())
640    /// # }
641    /// ```
642    pub fn group_with_queue_group<S: ToString, Z: ToString>(
643        &self,
644        prefix: S,
645        queue_group: Z,
646    ) -> Group {
647        Group {
648            prefix: format!("{}.{}", self.prefix, prefix.to_string()),
649            stats: self.stats.clone(),
650            client: self.client.clone(),
651            shutdown_tx: self.shutdown_tx.clone(),
652            subjects: self.subjects.clone(),
653            queue_group: queue_group.to_string(),
654        }
655    }
656
657    /// Adds a new endpoint to the [Service] under current [Group]
658    ///
659    /// # Examples
660    ///
661    /// ```no_run
662    /// # #[tokio::main]
663    /// # async fn main() -> Result<(), async_nats::Error> {
664    /// use async_nats::service::ServiceExt;
665    /// let client = async_nats::connect("demo.nats.io").await?;
666    /// let mut service = client.service_builder().start("service", "1.0.0").await?;
667    /// let v1 = service.group("v1");
668    ///
669    /// let products = v1.endpoint("products").await?;
670    /// # Ok(())
671    /// # }
672    /// ```
673    pub async fn endpoint<S: ToString>(&self, subject: S) -> Result<Endpoint, Error> {
674        let endpoint = self.endpoint_builder();
675        endpoint.add(subject.to_string()).await
676    }
677
678    /// Builder for customized [Endpoint] creation under current [Group]
679    ///
680    /// # Examples
681    ///
682    /// ```no_run
683    /// # #[tokio::main]
684    /// # async fn main() -> Result<(), async_nats::Error> {
685    /// use async_nats::service::ServiceExt;
686    /// let client = async_nats::connect("demo.nats.io").await?;
687    /// let mut service = client.service_builder().start("service", "1.0.0").await?;
688    /// let v1 = service.group("v1");
689    ///
690    /// let products = v1.endpoint_builder().name("api").add("products").await?;
691    /// # Ok(())
692    /// # }
693    /// ```
694    pub fn endpoint_builder(&self) -> EndpointBuilder {
695        let mut endpoint = EndpointBuilder::new(
696            self.client.clone(),
697            self.stats.clone(),
698            self.shutdown_tx.clone(),
699            self.subjects.clone(),
700            self.queue_group.clone(),
701        );
702        endpoint.prefix = Some(self.prefix.clone());
703        endpoint
704    }
705}
706
707async fn verb_subscription(
708    client: Client,
709    verb: Verb,
710    name: String,
711    id: String,
712) -> Result<stream::Fuse<SelectAll<Subscriber>>, Error> {
713    let verb_all = client
714        .subscribe(format!("{SERVICE_API_PREFIX}.{verb}"))
715        .await?;
716    let verb_name = client
717        .subscribe(format!("{SERVICE_API_PREFIX}.{verb}.{name}"))
718        .await?;
719    let verb_id = client
720        .subscribe(format!("{SERVICE_API_PREFIX}.{verb}.{name}.{id}"))
721        .await?;
722    Ok(stream::select_all([verb_all, verb_id, verb_name]).fuse())
723}
724
725type ShutdownReceiverFuture = Pin<
726    Box<dyn Future<Output = Result<(), tokio::sync::broadcast::error::RecvError>> + Send + Sync>,
727>;
728
729/// Request returned by [Service] [Stream][futures_util::Stream].
730#[derive(Debug)]
731pub struct Request {
732    issued: Instant,
733    client: Client,
734    pub message: Message,
735    endpoint: String,
736    stats: Arc<Mutex<Endpoints>>,
737}
738
739impl Request {
740    /// Sends response for the request.
741    ///
742    /// # Examples
743    ///
744    /// ```no_run
745    /// # #[tokio::main]
746    /// # async fn main() -> Result<(), async_nats::Error> {
747    /// use async_nats::service::ServiceExt;
748    /// use futures_util::StreamExt;
749    /// # let client = async_nats::connect("demo.nats.io").await?;
750    /// # let mut service = client
751    /// #    .service_builder().start("serviceA", "1.0.0.1").await?;
752    /// let mut endpoint = service.endpoint("endpoint").await?;
753    /// let request = endpoint.next().await.unwrap();
754    /// request.respond(Ok("hello".into())).await?;
755    /// # Ok(())
756    /// # }
757    /// ```
758    pub async fn respond(&self, response: Result<Bytes, error::Error>) -> Result<(), PublishError> {
759        self.respond_with_headers(response, HeaderMap::new()).await
760    }
761
762    /// Sends response for the request with headers.
763    ///
764    /// On error responses, [Nats-Service-Error][NATS_SERVICE_ERROR] and
765    /// [Nats-Service-Error-Code][NATS_SERVICE_ERROR_CODE] are always set from the provided
766    /// [`error::Error`]. If the provided [HeaderMap] already contains values for either
767    /// of those headers, they will be overridden. All other user-supplied headers
768    /// are preserved.
769    ///
770    /// # Examples
771    ///
772    /// ```no_run
773    /// # #[tokio::main]
774    /// # async fn main() -> Result<(), async_nats::Error> {
775    /// use async_nats::service::ServiceExt;
776    /// use futures_util::StreamExt;
777    /// # let client = async_nats::connect("demo.nats.io").await?;
778    /// # let mut service = client
779    /// #    .service_builder().start("serviceA", "1.0.0.1").await?;
780    /// let mut endpoint = service.endpoint("endpoint").await?;
781    /// let request = endpoint.next().await.unwrap();
782    /// let mut headers = async_nats::HeaderMap::new();
783    /// headers.insert("x-success", "true");
784    /// request
785    ///     .respond_with_headers(Ok("hello".into()), headers)
786    ///     .await?;
787    /// # Ok(())
788    /// # }
789    /// ```
790    pub async fn respond_with_headers(
791        &self,
792        response: Result<Bytes, error::Error>,
793        mut headers: HeaderMap,
794    ) -> Result<(), PublishError> {
795        let reply = match self.message.reply.clone() {
796            None => {
797                return Err(PublishError::with_source(
798                    PublishErrorKind::InvalidSubject,
799                    "Request is missing reply subject to respond to",
800                ))
801            }
802            Some(subject) => subject,
803        };
804        let result = match response {
805            Ok(payload) => {
806                if headers.is_empty() {
807                    self.client.publish(reply, payload).await
808                } else {
809                    self.client
810                        .publish_with_headers(reply, headers, payload)
811                        .await
812                }
813            }
814            Err(err) => {
815                self.stats
816                    .lock()
817                    .unwrap()
818                    .endpoints
819                    .entry(self.endpoint.clone())
820                    .and_modify(|stats| {
821                        stats.last_error = Some(err.clone());
822                        stats.errors += 1;
823                    })
824                    .or_default();
825                headers.insert(NATS_SERVICE_ERROR, err.status.as_str());
826                headers.insert(NATS_SERVICE_ERROR_CODE, err.code.to_string().as_str());
827                self.client
828                    .publish_with_headers(reply, headers, "".into())
829                    .await
830            }
831        };
832        let elapsed = self.issued.elapsed();
833        let mut stats = self.stats.lock().unwrap();
834        let stats = stats.endpoints.get_mut(self.endpoint.as_str()).unwrap();
835        stats.requests += 1;
836        stats.processing_time += elapsed;
837        stats.average_processing_time = {
838            let avg_nanos = (stats.processing_time.as_nanos() / stats.requests as u128) as u64;
839            Duration::from_nanos(avg_nanos)
840        };
841        result
842    }
843}
844
845#[derive(Debug)]
846pub struct EndpointBuilder {
847    client: Client,
848    stats: Arc<Mutex<Endpoints>>,
849    shutdown_tx: Sender<()>,
850    name: Option<String>,
851    metadata: Option<HashMap<String, String>>,
852    subjects: Arc<Mutex<Vec<String>>>,
853    queue_group: String,
854    prefix: Option<String>,
855}
856
857impl EndpointBuilder {
858    fn new(
859        client: Client,
860        stats: Arc<Mutex<Endpoints>>,
861        shutdown_tx: Sender<()>,
862        subjects: Arc<Mutex<Vec<String>>>,
863        queue_group: String,
864    ) -> EndpointBuilder {
865        EndpointBuilder {
866            client,
867            stats,
868            subjects,
869            shutdown_tx,
870            name: None,
871            metadata: None,
872            queue_group,
873            prefix: None,
874        }
875    }
876
877    /// Name of the [Endpoint]. By default, the subject of the endpoint is used.
878    pub fn name<S: ToString>(mut self, name: S) -> EndpointBuilder {
879        self.name = Some(name.to_string());
880        self
881    }
882
883    /// Metadata specific for the [Endpoint].
884    pub fn metadata(mut self, metadata: HashMap<String, String>) -> EndpointBuilder {
885        self.metadata = Some(metadata);
886        self
887    }
888
889    /// Custom queue group for the [Endpoint]. Otherwise, it will be derived from group or service.
890    pub fn queue_group<S: ToString>(mut self, queue_group: S) -> EndpointBuilder {
891        self.queue_group = queue_group.to_string();
892        self
893    }
894
895    /// Finalizes the builder and adds the [Endpoint].
896    pub async fn add<S: ToString>(self, subject: S) -> Result<Endpoint, Error> {
897        let mut subject = subject.to_string();
898        if let Some(prefix) = self.prefix {
899            subject = format!("{prefix}.{subject}");
900        }
901        let endpoint_name = self.name.clone().unwrap_or_else(|| subject.clone());
902        let name = self
903            .name
904            .clone()
905            .unwrap_or_else(|| subject.clone().replace('.', "-"));
906        let requests = self
907            .client
908            .queue_subscribe(subject.to_owned(), self.queue_group.to_string())
909            .await?;
910        debug!("created service for endpoint {subject}");
911
912        let shutdown_rx = self.shutdown_tx.subscribe();
913
914        let mut stats = self.stats.lock().unwrap();
915        stats
916            .endpoints
917            .entry(endpoint_name.clone())
918            .or_insert(endpoint::Inner {
919                name,
920                subject: subject.clone(),
921                metadata: self.metadata.unwrap_or_default(),
922                queue_group: self.queue_group.clone(),
923                ..Default::default()
924            });
925        self.subjects.lock().unwrap().push(subject.clone());
926        Ok(Endpoint {
927            requests,
928            stats: self.stats.clone(),
929            client: self.client.clone(),
930            endpoint: endpoint_name,
931            shutdown: Some(shutdown_rx),
932            shutdown_future: None,
933        })
934    }
935}
936
/// Boxed user callback that produces a JSON value for stats.
///
/// The closure is called with a `String` (presumably the endpoint name — TODO
/// confirm at the call site) and that endpoint's [`endpoint::Stats`], and must
/// return a `serde_json::Value`. It must be `Send`.
pub struct StatsHandler(pub Box<dyn FnMut(String, endpoint::Stats) -> serde_json::Value + Send>);
938
939impl std::fmt::Debug for StatsHandler {
940    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
941        write!(f, "Stats handler")
942    }
943}
944
#[cfg(test)]
mod tests {
    use super::*;

    // `group_with_queue_group` appends the new segment to the parent prefix
    // and uses the supplied queue group instead of the parent's.
    #[tokio::test]
    async fn test_group_with_queue_group() {
        let server = nats_server::run_basic_server();
        let client = crate::connect(server.client_url()).await.unwrap();

        // Built directly from its fields with a "default" queue group so the
        // override below is observable.
        let group = Group {
            prefix: "test".to_string(),
            stats: Arc::new(Mutex::new(Endpoints {
                endpoints: HashMap::new(),
            })),
            client,
            shutdown_tx: tokio::sync::broadcast::channel(1).0,
            subjects: Arc::new(Mutex::new(vec![])),
            queue_group: "default".to_string(),
        };

        let new_group = group.group_with_queue_group("v1", "custom_queue");

        assert_eq!(new_group.prefix, "test.v1");
        assert_eq!(new_group.queue_group, "custom_queue");
    }

    // Responding with `Err` through `respond_with_headers` keeps unrelated
    // user headers, but user-supplied values for the service-error headers are
    // replaced by the error's status and code.
    #[tokio::test]
    async fn test_respond_with_headers_overrides_error_headers() {
        let server = nats_server::run_basic_server();
        let client = crate::connect(server.client_url()).await.unwrap();

        let service = client
            .service_builder()
            .start("test-service", "1.0.0")
            .await
            .unwrap();

        let subject = "test.subject";
        let mut endpoint = service.endpoint(subject).await.unwrap();

        // Answers exactly one request with an error plus custom headers.
        let handler = async {
            if let Some(request) = endpoint.next().await {
                let mut resp_headers = HeaderMap::new();
                resp_headers.insert("x-success", "false");
                resp_headers.insert(NATS_SERVICE_ERROR, "user-supplied-value");
                resp_headers.insert(NATS_SERVICE_ERROR_CODE, "999");

                let err = error::Error {
                    status: "internal-error".to_string(),
                    code: 500,
                };

                request
                    .respond_with_headers(Err(err), resp_headers)
                    .await
                    .expect("failed to send response");
            }
        };

        // NOTE(review): the request goes out on a separate connection and
        // nothing here explicitly waits for the server to register the
        // endpoint's subscription first — confirm this cannot race.
        let requester = crate::connect(server.client_url()).await.unwrap();
        let request_fut = async { requester.request(subject, "".into()).await.unwrap() };

        let (_, resp) = tokio::join!(handler, request_fut);

        let headers = resp.headers.expect("expected headers on reply");
        assert_eq!(headers.get("x-success").unwrap().as_str(), "false");
        assert_eq!(
            headers.get(NATS_SERVICE_ERROR).unwrap().as_str(),
            "internal-error"
        );
        assert_eq!(
            headers.get(NATS_SERVICE_ERROR_CODE).unwrap().as_str(),
            "500"
        );
    }

    // Responding with `Ok` through `respond_with_headers` passes every user
    // header through unchanged, including ones named like the service-error
    // headers.
    #[tokio::test]
    async fn test_respond_with_headers_preserves_headers_on_success() {
        let server = nats_server::run_basic_server();
        let client = crate::connect(server.client_url()).await.unwrap();

        let service = client
            .service_builder()
            .start("test-service", "1.0.0")
            .await
            .unwrap();

        let subject = "test.subject";
        let mut endpoint = service.endpoint(subject).await.unwrap();

        // Answers exactly one request successfully with custom headers.
        let handler = async {
            if let Some(request) = endpoint.next().await {
                let mut resp_headers = HeaderMap::new();
                resp_headers.insert("x-success", "false");
                resp_headers.insert("x-request-id", "req-123");
                resp_headers.insert(NATS_SERVICE_ERROR, "user-supplied-value");
                resp_headers.insert(NATS_SERVICE_ERROR_CODE, "999");

                request
                    .respond_with_headers(Ok("ok".into()), resp_headers)
                    .await
                    .unwrap();
            }
        };

        // NOTE(review): same separate-connection ordering assumption as above.
        let requester = crate::connect(server.client_url()).await.unwrap();
        let request_fut = async { requester.request(subject, "".into()).await.unwrap() };

        let (_, resp) = tokio::join!(handler, request_fut);

        let headers = resp.headers.expect("expected headers on reply");
        assert_eq!(headers.get("x-success").unwrap().as_str(), "false");
        assert_eq!(headers.get("x-request-id").unwrap().as_str(), "req-123");
        assert_eq!(
            headers.get(NATS_SERVICE_ERROR).unwrap().as_str(),
            "user-supplied-value"
        );
        assert_eq!(
            headers.get(NATS_SERVICE_ERROR_CODE).unwrap().as_str(),
            "999"
        );
    }
}