Skip to main content

dynamo_runtime/pipeline/network/ingress/push_endpoint.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::atomic::{AtomicU64, Ordering};
5
6use super::*;
7use crate::SystemHealth;
8use crate::config::HealthStatus;
9use crate::logging::make_handle_payload_span;
10use crate::protocols::LeaseId;
11use anyhow::Result;
12use async_nats::service::endpoint::Endpoint;
13use derive_builder::Builder;
14use parking_lot::Mutex;
15use std::collections::HashMap;
16use tokio::sync::Notify;
17use tokio_util::sync::CancellationToken;
18use tracing::Instrument;
19
20#[derive(Builder)]
21pub struct PushEndpoint {
22    pub service_handler: Arc<dyn PushWorkHandler>,
23    pub cancellation_token: CancellationToken,
24    #[builder(default = "true")]
25    pub graceful_shutdown: bool,
26}
27
28/// version of crate
29pub const VERSION: &str = env!("CARGO_PKG_VERSION");
30
31impl PushEndpoint {
32    pub fn builder() -> PushEndpointBuilder {
33        PushEndpointBuilder::default()
34    }
35
36    pub async fn start(
37        self,
38        endpoint: Endpoint,
39        namespace: String,
40        component_name: String,
41        endpoint_name: String,
42        instance_id: u64,
43        system_health: Arc<Mutex<SystemHealth>>,
44    ) -> Result<()> {
45        let mut endpoint = endpoint;
46
47        let inflight = Arc::new(AtomicU64::new(0));
48        let notify = Arc::new(Notify::new());
49        let component_name_local: Arc<String> = Arc::from(component_name);
50        let endpoint_name_local: Arc<String> = Arc::from(endpoint_name);
51        let namespace_local: Arc<String> = Arc::from(namespace);
52
53        system_health
54            .lock()
55            .set_endpoint_health_status(endpoint_name_local.as_str(), HealthStatus::Ready);
56
57        loop {
58            let req = tokio::select! {
59                biased;
60
61                // await on service request
62                req = endpoint.next() => {
63                    req
64                }
65
66                // process shutdown
67                _ = self.cancellation_token.cancelled() => {
68                    tracing::info!("PushEndpoint received cancellation signal, shutting down service");
69                    if let Err(e) = endpoint.stop().await {
70                        tracing::warn!("Failed to stop NATS service: {:?}", e);
71                    }
72                    break;
73                }
74            };
75
76            if let Some(req) = req {
77                let response = "".to_string();
78                if let Err(e) = req.respond(Ok(response.into())).await {
79                    tracing::warn!(
80                        "Failed to respond to request; this may indicate the request has shutdown: {:?}",
81                        e
82                    );
83                }
84
85                let ingress = self.service_handler.clone();
86                let endpoint_name: Arc<String> = Arc::clone(&endpoint_name_local);
87                let component_name: Arc<String> = Arc::clone(&component_name_local);
88                let namespace: Arc<String> = Arc::clone(&namespace_local);
89
90                // increment the inflight counter
91                inflight.fetch_add(1, Ordering::SeqCst);
92                let inflight_clone = inflight.clone();
93                let notify_clone = notify.clone();
94
95                // Handle headers here for tracing
96                let span = if let Some(headers) = req.message.headers.as_ref() {
97                    make_handle_payload_span(
98                        headers,
99                        component_name.as_ref(),
100                        endpoint_name.as_ref(),
101                        namespace.as_ref(),
102                        instance_id,
103                    )
104                } else {
105                    tracing::info_span!("handle_payload")
106                };
107
108                tokio::spawn(async move {
109                    tracing::trace!(instance_id, "handling new request");
110                    let result = ingress
111                        .handle_payload(req.message.payload)
112                        .instrument(span)
113                        .await;
114                    match result {
115                        Ok(_) => {
116                            tracing::trace!(instance_id, "request handled successfully");
117                        }
118                        Err(e) => {
119                            tracing::warn!("Failed to handle request: {}", e.to_string());
120                        }
121                    }
122
123                    // decrease the inflight counter
124                    inflight_clone.fetch_sub(1, Ordering::SeqCst);
125                    notify_clone.notify_one();
126                });
127            } else {
128                break;
129            }
130        }
131
132        system_health
133            .lock()
134            .set_endpoint_health_status(endpoint_name_local.as_str(), HealthStatus::NotReady);
135
136        // await for all inflight requests to complete if graceful shutdown
137        if self.graceful_shutdown {
138            let inflight_count = inflight.load(Ordering::SeqCst);
139            if inflight_count > 0 {
140                tracing::info!(
141                    endpoint_name = endpoint_name_local.as_str(),
142                    inflight_count = inflight_count,
143                    "Waiting for inflight NATS requests to complete"
144                );
145                while inflight.load(Ordering::SeqCst) > 0 {
146                    notify.notified().await;
147                }
148                tracing::info!(
149                    endpoint_name = endpoint_name_local.as_str(),
150                    "All inflight NATS requests completed"
151                );
152            }
153        } else {
154            tracing::info!(
155                endpoint_name = endpoint_name_local.as_str(),
156                "Skipping graceful shutdown, not waiting for inflight requests"
157            );
158        }
159
160        Ok(())
161    }
162}