Skip to main content

dynamo_runtime/pipeline/network/ingress/
push_handler.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::*;
5use crate::metrics::prometheus_names::work_handler;
6use crate::protocols::maybe_error::MaybeError;
7use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10use std::time::Instant;
11use tracing::Instrument;
12use tracing::info_span;
13
14/// Metrics configuration for profiling work handlers
15#[derive(Clone, Debug)]
16pub struct WorkHandlerMetrics {
17    pub request_counter: IntCounter,
18    pub request_duration: Histogram,
19    pub inflight_requests: IntGauge,
20    pub request_bytes: IntCounter,
21    pub response_bytes: IntCounter,
22    pub error_counter: IntCounterVec,
23}
24
25impl WorkHandlerMetrics {
26    pub fn new(
27        request_counter: IntCounter,
28        request_duration: Histogram,
29        inflight_requests: IntGauge,
30        request_bytes: IntCounter,
31        response_bytes: IntCounter,
32        error_counter: IntCounterVec,
33    ) -> Self {
34        Self {
35            request_counter,
36            request_duration,
37            inflight_requests,
38            request_bytes,
39            response_bytes,
40            error_counter,
41        }
42    }
43
44    /// Create WorkHandlerMetrics from an endpoint using its built-in labeling
45    pub fn from_endpoint(
46        endpoint: &crate::component::Endpoint,
47        metrics_labels: Option<&[(&str, &str)]>,
48    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
49        let metrics_labels = metrics_labels.unwrap_or(&[]);
50        let metrics = endpoint.metrics();
51        let request_counter = metrics.create_intcounter(
52            work_handler::REQUESTS_TOTAL,
53            "Total number of requests processed by work handler",
54            metrics_labels,
55        )?;
56
57        let request_duration = metrics.create_histogram(
58            work_handler::REQUEST_DURATION_SECONDS,
59            "Time spent processing requests by work handler",
60            metrics_labels,
61            None,
62        )?;
63
64        let inflight_requests = metrics.create_intgauge(
65            work_handler::INFLIGHT_REQUESTS,
66            "Number of requests currently being processed by work handler",
67            metrics_labels,
68        )?;
69
70        let request_bytes = metrics.create_intcounter(
71            work_handler::REQUEST_BYTES_TOTAL,
72            "Total number of bytes received in requests by work handler",
73            metrics_labels,
74        )?;
75
76        let response_bytes = metrics.create_intcounter(
77            work_handler::RESPONSE_BYTES_TOTAL,
78            "Total number of bytes sent in responses by work handler",
79            metrics_labels,
80        )?;
81
82        let error_counter = metrics.create_intcountervec(
83            work_handler::ERRORS_TOTAL,
84            "Total number of errors in work handler processing",
85            &[work_handler::ERROR_TYPE_LABEL],
86            metrics_labels,
87        )?;
88
89        Ok(Self::new(
90            request_counter,
91            request_duration,
92            inflight_requests,
93            request_bytes,
94            response_bytes,
95            error_counter,
96        ))
97    }
98}
99
100// RAII guard to ensure inflight gauge is decremented and request duration is observed on all code paths.
101struct RequestMetricsGuard {
102    inflight_requests: prometheus::IntGauge,
103    request_duration: prometheus::Histogram,
104    start_time: Instant,
105}
106impl Drop for RequestMetricsGuard {
107    fn drop(&mut self) {
108        self.inflight_requests.dec();
109        self.request_duration
110            .observe(self.start_time.elapsed().as_secs_f64());
111    }
112}
113
114#[async_trait]
115impl<T: Data, U: Data> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
116where
117    T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
118    U: Data + Serialize + MaybeError + std::fmt::Debug,
119{
120    fn add_metrics(
121        &self,
122        endpoint: &crate::component::Endpoint,
123        metrics_labels: Option<&[(&str, &str)]>,
124    ) -> Result<()> {
125        // Call the Ingress-specific add_metrics implementation
126        use crate::pipeline::network::Ingress;
127        Ingress::add_metrics(self, endpoint, metrics_labels)
128    }
129
130    fn set_endpoint_health_check_notifier(&self, notifier: Arc<tokio::sync::Notify>) -> Result<()> {
131        use crate::pipeline::network::Ingress;
132        self.endpoint_health_check_notifier
133            .set(notifier)
134            .map_err(|_| anyhow::anyhow!("Endpoint health check notifier already set"))?;
135        Ok(())
136    }
137
138    async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError> {
139        let start_time = std::time::Instant::now();
140
141        // Increment inflight and ensure it's decremented on all exits via RAII guard
142        let _inflight_guard = self.metrics().map(|m| {
143            m.request_counter.inc();
144            m.inflight_requests.inc();
145            m.request_bytes.inc_by(payload.len() as u64);
146            RequestMetricsGuard {
147                inflight_requests: m.inflight_requests.clone(),
148                request_duration: m.request_duration.clone(),
149                start_time,
150            }
151        });
152
153        // decode the control message and the request
154        let msg = TwoPartCodec::default()
155            .decode_message(payload)?
156            .into_message_type();
157
158        // we must have a header and a body
159        // it will be held by this closure as a Some(permit)
160        let (control_msg, request) = match msg {
161            TwoPartMessageType::HeaderAndData(header, data) => {
162                tracing::trace!(
163                    "received two part message with ctrl: {} bytes, data: {} bytes",
164                    header.len(),
165                    data.len()
166                );
167                let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
168                    Ok(cm) => cm,
169                    Err(err) => {
170                        let json_str = String::from_utf8_lossy(&header);
171                        if let Some(m) = self.metrics() {
172                            m.error_counter
173                                .with_label_values(&[work_handler::error_types::DESERIALIZATION])
174                                .inc();
175                        }
176                        return Err(PipelineError::DeserializationError(format!(
177                            "Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}"
178                        )));
179                    }
180                };
181                let request: T = serde_json::from_slice(&data)?;
182                (control_msg, request)
183            }
184            _ => {
185                if let Some(m) = self.metrics() {
186                    m.error_counter
187                        .with_label_values(&[work_handler::error_types::INVALID_MESSAGE])
188                        .inc();
189                }
190                return Err(PipelineError::Generic(String::from(
191                    "Unexpected message from work queue; unable extract a TwoPartMessage with a header and data",
192                )));
193            }
194        };
195
196        // extend request with context
197        tracing::trace!("received control message: {:?}", control_msg);
198        tracing::trace!("received request: {:?}", request);
199        let request: context::Context<T> = Context::with_id(request, control_msg.id);
200
201        // todo - eventually have a handler class which will returned an abstracted object, but for now,
202        // we only support tcp here, so we can just unwrap the connection info
203        tracing::trace!("creating tcp response stream");
204        let mut publisher = tcp::client::TcpClient::create_response_stream(
205            request.context(),
206            control_msg.connection_info,
207        )
208        .await
209        .map_err(|e| {
210            if let Some(m) = self.metrics() {
211                m.error_counter
212                    .with_label_values(&[work_handler::error_types::RESPONSE_STREAM])
213                    .inc();
214            }
215            PipelineError::Generic(format!("Failed to create response stream: {:?}", e,))
216        })?;
217
218        tracing::trace!("calling generate");
219        let stream = self
220            .segment
221            .get()
222            .expect("segment not set")
223            .generate(request)
224            .await
225            .map_err(|e| {
226                if let Some(m) = self.metrics() {
227                    m.error_counter
228                        .with_label_values(&[work_handler::error_types::GENERATE])
229                        .inc();
230                }
231                PipelineError::GenerateError(e)
232            });
233
234        // the prolouge is sent to the client to indicate that the stream is ready to receive data
235        // or if the generate call failed, the error is sent to the client
236        let mut stream = match stream {
237            Ok(stream) => {
238                tracing::trace!("Successfully generated response stream; sending prologue");
239                let _result = publisher.send_prologue(None).await;
240                stream
241            }
242            Err(e) => {
243                let error_string = e.to_string();
244
245                #[cfg(debug_assertions)]
246                {
247                    tracing::debug!(
248                        "Failed to generate response stream (with debug backtrace): {:?}",
249                        e
250                    );
251                }
252                #[cfg(not(debug_assertions))]
253                {
254                    tracing::error!("Failed to generate response stream: {}", error_string);
255                }
256
257                let _result = publisher.send_prologue(Some(error_string)).await;
258                Err(e)?
259            }
260        };
261
262        let context = stream.context();
263
264        // TODO: Detect end-of-stream using Server-Sent Events (SSE)
265        let mut send_complete_final = true;
266        while let Some(resp) = stream.next().await {
267            tracing::trace!("Sending response: {:?}", resp);
268            if let Some(err) = resp.err()
269                && format!("{:?}", err) == STREAM_ERR_MSG
270            {
271                tracing::warn!(STREAM_ERR_MSG);
272                send_complete_final = false;
273                break;
274            }
275            let resp_wrapper = NetworkStreamWrapper {
276                data: Some(resp),
277                complete_final: false,
278            };
279            let resp_bytes = serde_json::to_vec(&resp_wrapper)
280                .expect("fatal error: invalid response object - this should never happen");
281            if let Some(m) = self.metrics() {
282                m.response_bytes.inc_by(resp_bytes.len() as u64);
283            }
284            if (publisher.send(resp_bytes.into()).await).is_err() {
285                send_complete_final = false;
286                if context.is_stopped() {
287                    // Say there are 2 threads accessing `context`, the sequence can be either:
288                    // 1. context.stop_generating (other) -> publisher.send failure (this)
289                    //    -> context.is_stopped (this)
290                    // 2. publisher.send failure (this) -> context.stop_generating (other)
291                    //    -> context.is_stopped (this)
292                    // Case 1 can happen when client closed the connection after receiving the
293                    // complete response from frontend. Hence, send failure can be expected in this
294                    // case.
295                    tracing::warn!("Failed to publish response for stream {}", context.id());
296                } else {
297                    // Otherwise, this is an error.
298                    tracing::error!("Failed to publish response for stream {}", context.id());
299                    context.stop_generating();
300                }
301                // Account errors in all cases, including cancellation. Therefore this metric can be
302                // inflated.
303                if let Some(m) = self.metrics() {
304                    m.error_counter
305                        .with_label_values(&[work_handler::error_types::PUBLISH_RESPONSE])
306                        .inc();
307                }
308                break;
309            }
310        }
311        if send_complete_final {
312            let resp_wrapper = NetworkStreamWrapper::<U> {
313                data: None,
314                complete_final: true,
315            };
316            let resp_bytes = serde_json::to_vec(&resp_wrapper)
317                .expect("fatal error: invalid response object - this should never happen");
318            if let Some(m) = self.metrics() {
319                m.response_bytes.inc_by(resp_bytes.len() as u64);
320            }
321            if (publisher.send(resp_bytes.into()).await).is_err() {
322                tracing::error!(
323                    "Failed to publish complete final for stream {}",
324                    context.id()
325                );
326                if let Some(m) = self.metrics() {
327                    m.error_counter
328                        .with_label_values(&[work_handler::error_types::PUBLISH_FINAL])
329                        .inc();
330                }
331            }
332            // Notify the health check manager that the stream has finished.
333            // This resets the timer, delaying the next canary health check.
334            if let Some(notifier) = self.endpoint_health_check_notifier.get() {
335                notifier.notify_one();
336            }
337        }
338
339        // Ensure the metrics guard is not dropped until the end of the function.
340        drop(_inflight_guard);
341
342        Ok(())
343    }
344}