Skip to main content

dynamo_runtime/pipeline/network/ingress/
push_handler.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::*;
5
6use crate::metrics::prometheus_names::work_handler;
7use crate::protocols::maybe_error::MaybeError;
8use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use std::time::Instant;
12use tracing::Instrument;
13use tracing::info_span;
14
15/// Metrics configuration for profiling work handlers
16#[derive(Clone, Debug)]
17pub struct WorkHandlerMetrics {
18    pub request_counter: IntCounter,
19    pub request_duration: Histogram,
20    pub inflight_requests: IntGauge,
21    pub request_bytes: IntCounter,
22    pub response_bytes: IntCounter,
23    pub error_counter: IntCounterVec,
24}
25
26impl WorkHandlerMetrics {
27    pub fn new(
28        request_counter: IntCounter,
29        request_duration: Histogram,
30        inflight_requests: IntGauge,
31        request_bytes: IntCounter,
32        response_bytes: IntCounter,
33        error_counter: IntCounterVec,
34    ) -> Self {
35        Self {
36            request_counter,
37            request_duration,
38            inflight_requests,
39            request_bytes,
40            response_bytes,
41            error_counter,
42        }
43    }
44
45    /// Create WorkHandlerMetrics from an endpoint using its built-in labeling
46    pub fn from_endpoint(
47        endpoint: &crate::component::Endpoint,
48        metrics_labels: Option<&[(&str, &str)]>,
49    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
50        let metrics_labels = metrics_labels.unwrap_or(&[]);
51        let metrics = endpoint.metrics();
52        let request_counter = metrics.create_intcounter(
53            work_handler::REQUESTS_TOTAL,
54            "Total number of requests processed by work handler",
55            metrics_labels,
56        )?;
57
58        let request_duration = metrics.create_histogram(
59            work_handler::REQUEST_DURATION_SECONDS,
60            "Time spent processing requests by work handler",
61            metrics_labels,
62            None,
63        )?;
64
65        let inflight_requests = metrics.create_intgauge(
66            work_handler::INFLIGHT_REQUESTS,
67            "Number of requests currently being processed by work handler",
68            metrics_labels,
69        )?;
70
71        let request_bytes = metrics.create_intcounter(
72            work_handler::REQUEST_BYTES_TOTAL,
73            "Total number of bytes received in requests by work handler",
74            metrics_labels,
75        )?;
76
77        let response_bytes = metrics.create_intcounter(
78            work_handler::RESPONSE_BYTES_TOTAL,
79            "Total number of bytes sent in responses by work handler",
80            metrics_labels,
81        )?;
82
83        let error_counter = metrics.create_intcountervec(
84            work_handler::ERRORS_TOTAL,
85            "Total number of errors in work handler processing",
86            &[work_handler::ERROR_TYPE_LABEL],
87            metrics_labels,
88        )?;
89
90        Ok(Self::new(
91            request_counter,
92            request_duration,
93            inflight_requests,
94            request_bytes,
95            response_bytes,
96            error_counter,
97        ))
98    }
99}
100
101// RAII guard to ensure inflight gauge is decremented and request duration is observed on all code paths.
102struct RequestMetricsGuard {
103    inflight_requests: prometheus::IntGauge,
104    request_duration: prometheus::Histogram,
105    start_time: Instant,
106}
107impl Drop for RequestMetricsGuard {
108    fn drop(&mut self) {
109        self.inflight_requests.dec();
110        self.request_duration
111            .observe(self.start_time.elapsed().as_secs_f64());
112    }
113}
114
115#[async_trait]
116impl<T: Data, U: Data> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
117where
118    T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
119    U: Data + Serialize + MaybeError + std::fmt::Debug,
120{
121    fn add_metrics(
122        &self,
123        endpoint: &crate::component::Endpoint,
124        metrics_labels: Option<&[(&str, &str)]>,
125    ) -> Result<()> {
126        // Call the Ingress-specific add_metrics implementation
127        use crate::pipeline::network::Ingress;
128        Ingress::add_metrics(self, endpoint, metrics_labels)
129    }
130
131    fn set_endpoint_health_check_notifier(&self, notifier: Arc<tokio::sync::Notify>) -> Result<()> {
132        use crate::pipeline::network::Ingress;
133        self.endpoint_health_check_notifier
134            .set(notifier)
135            .map_err(|_| anyhow::anyhow!("Endpoint health check notifier already set"))?;
136        Ok(())
137    }
138
139    async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError> {
140        let start_time = std::time::Instant::now();
141
142        // Increment inflight and ensure it's decremented on all exits via RAII guard
143        let _inflight_guard = self.metrics().map(|m| {
144            m.request_counter.inc();
145            m.inflight_requests.inc();
146            m.request_bytes.inc_by(payload.len() as u64);
147            RequestMetricsGuard {
148                inflight_requests: m.inflight_requests.clone(),
149                request_duration: m.request_duration.clone(),
150                start_time,
151            }
152        });
153
154        // decode the control message and the request
155        let msg = TwoPartCodec::default()
156            .decode_message(payload)?
157            .into_message_type();
158
159        // we must have a header and a body
160        // it will be held by this closure as a Some(permit)
161        let (control_msg, request) = match msg {
162            TwoPartMessageType::HeaderAndData(header, data) => {
163                tracing::trace!(
164                    "received two part message with ctrl: {} bytes, data: {} bytes",
165                    header.len(),
166                    data.len()
167                );
168                let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
169                    Ok(cm) => cm,
170                    Err(err) => {
171                        let json_str = String::from_utf8_lossy(&header);
172                        if let Some(m) = self.metrics() {
173                            m.error_counter
174                                .with_label_values(&[work_handler::error_types::DESERIALIZATION])
175                                .inc();
176                        }
177                        return Err(PipelineError::DeserializationError(format!(
178                            "Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}"
179                        )));
180                    }
181                };
182                let request: T = serde_json::from_slice(&data)?;
183                (control_msg, request)
184            }
185            _ => {
186                if let Some(m) = self.metrics() {
187                    m.error_counter
188                        .with_label_values(&[work_handler::error_types::INVALID_MESSAGE])
189                        .inc();
190                }
191                return Err(PipelineError::Generic(String::from(
192                    "Unexpected message from work queue; unable extract a TwoPartMessage with a header and data",
193                )));
194            }
195        };
196
197        // extend request with context
198        tracing::trace!("received control message: {:?}", control_msg);
199        tracing::trace!("received request: {:?}", request);
200        let request: context::Context<T> = Context::with_id(request, control_msg.id);
201
202        // todo - eventually have a handler class which will returned an abstracted object, but for now,
203        // we only support tcp here, so we can just unwrap the connection info
204        tracing::trace!("creating tcp response stream");
205        let mut publisher = tcp::client::TcpClient::create_response_stream(
206            request.context(),
207            control_msg.connection_info,
208        )
209        .await
210        .map_err(|e| {
211            if let Some(m) = self.metrics() {
212                m.error_counter
213                    .with_label_values(&[work_handler::error_types::RESPONSE_STREAM])
214                    .inc();
215            }
216            PipelineError::Generic(format!("Failed to create response stream: {:?}", e,))
217        })?;
218
219        tracing::trace!("calling generate");
220        let stream = self
221            .segment
222            .get()
223            .expect("segment not set")
224            .generate(request)
225            .await
226            .map_err(|e| {
227                if let Some(m) = self.metrics() {
228                    m.error_counter
229                        .with_label_values(&[work_handler::error_types::GENERATE])
230                        .inc();
231                }
232                PipelineError::GenerateError(e)
233            });
234
235        // the prolouge is sent to the client to indicate that the stream is ready to receive data
236        // or if the generate call failed, the error is sent to the client
237        let mut stream = match stream {
238            Ok(stream) => {
239                tracing::trace!("Successfully generated response stream; sending prologue");
240                let _result = publisher.send_prologue(None).await;
241                stream
242            }
243            Err(e) => {
244                let error_string = e.to_string();
245
246                #[cfg(debug_assertions)]
247                {
248                    tracing::debug!(
249                        "Failed to generate response stream (with debug backtrace): {:?}",
250                        e
251                    );
252                }
253                #[cfg(not(debug_assertions))]
254                {
255                    tracing::error!("Failed to generate response stream: {}", error_string);
256                }
257
258                let _result = publisher.send_prologue(Some(error_string)).await;
259                Err(e)?
260            }
261        };
262
263        let context = stream.context();
264
265        // TODO: Detect end-of-stream using Server-Sent Events (SSE)
266        let mut send_complete_final = true;
267        while let Some(resp) = stream.next().await {
268            tracing::trace!("Sending response: {:?}", resp);
269            let resp_wrapper = NetworkStreamWrapper {
270                data: Some(resp),
271                complete_final: false,
272            };
273            let resp_bytes = serde_json::to_vec(&resp_wrapper)
274                .expect("fatal error: invalid response object - this should never happen");
275            if let Some(m) = self.metrics() {
276                m.response_bytes.inc_by(resp_bytes.len() as u64);
277            }
278            if (publisher.send(resp_bytes.into()).await).is_err() {
279                send_complete_final = false;
280                if context.is_stopped() {
281                    // Say there are 2 threads accessing `context`, the sequence can be either:
282                    // 1. context.stop_generating (other) -> publisher.send failure (this)
283                    //    -> context.is_stopped (this)
284                    // 2. publisher.send failure (this) -> context.stop_generating (other)
285                    //    -> context.is_stopped (this)
286                    // Case 1 can happen when client closed the connection after receiving the
287                    // complete response from frontend. Hence, send failure can be expected in this
288                    // case.
289                    tracing::warn!("Failed to publish response for stream {}", context.id());
290                } else {
291                    // Otherwise, this is an error.
292                    tracing::error!("Failed to publish response for stream {}", context.id());
293                    context.stop_generating();
294                }
295                // Account errors in all cases, including cancellation. Therefore this metric can be
296                // inflated.
297                if let Some(m) = self.metrics() {
298                    m.error_counter
299                        .with_label_values(&[work_handler::error_types::PUBLISH_RESPONSE])
300                        .inc();
301                }
302                break;
303            }
304        }
305        if send_complete_final {
306            let resp_wrapper = NetworkStreamWrapper::<U> {
307                data: None,
308                complete_final: true,
309            };
310            let resp_bytes = serde_json::to_vec(&resp_wrapper)
311                .expect("fatal error: invalid response object - this should never happen");
312            if let Some(m) = self.metrics() {
313                m.response_bytes.inc_by(resp_bytes.len() as u64);
314            }
315            if (publisher.send(resp_bytes.into()).await).is_err() {
316                tracing::error!(
317                    "Failed to publish complete final for stream {}",
318                    context.id()
319                );
320                if let Some(m) = self.metrics() {
321                    m.error_counter
322                        .with_label_values(&[work_handler::error_types::PUBLISH_FINAL])
323                        .inc();
324                }
325            }
326            // Notify the health check manager that the stream has finished.
327            // This resets the timer, delaying the next canary health check.
328            if let Some(notifier) = self.endpoint_health_check_notifier.get() {
329                notifier.notify_one();
330            }
331        }
332
333        // Ensure the metrics guard is not dropped until the end of the function.
334        drop(_inflight_guard);
335
336        Ok(())
337    }
338}