Skip to main content

dynamo_runtime/pipeline/network/ingress/
push_handler.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::*;
5
6use crate::metrics::prometheus_names::work_handler;
7use crate::metrics::work_handler_perf::{
8    WORK_HANDLER_NETWORK_TRANSIT_SECONDS, WORK_HANDLER_TIME_TO_FIRST_RESPONSE_SECONDS,
9};
10use crate::protocols::maybe_error::MaybeError;
11use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14use std::time::Instant;
15use tracing::Instrument;
16use tracing::info_span;
17
18/// Metrics configuration for profiling work handlers
19#[derive(Clone, Debug)]
20pub struct WorkHandlerMetrics {
21    pub request_counter: IntCounter,
22    pub request_duration: Histogram,
23    pub inflight_requests: IntGauge,
24    pub request_bytes: IntCounter,
25    pub response_bytes: IntCounter,
26    pub error_counter: IntCounterVec,
27    pub cancellation_total: IntCounter,
28}
29
30impl WorkHandlerMetrics {
31    pub fn new(
32        request_counter: IntCounter,
33        request_duration: Histogram,
34        inflight_requests: IntGauge,
35        request_bytes: IntCounter,
36        response_bytes: IntCounter,
37        error_counter: IntCounterVec,
38        cancellation_total: IntCounter,
39    ) -> Self {
40        Self {
41            request_counter,
42            request_duration,
43            inflight_requests,
44            request_bytes,
45            response_bytes,
46            error_counter,
47            cancellation_total,
48        }
49    }
50
51    /// Create WorkHandlerMetrics from an endpoint using its built-in labeling
52    pub fn from_endpoint(
53        endpoint: &crate::component::Endpoint,
54        metrics_labels: Option<&[(&str, &str)]>,
55    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
56        let metrics_labels = metrics_labels.unwrap_or(&[]);
57        let metrics = endpoint.metrics();
58        let request_counter = metrics.create_intcounter(
59            work_handler::REQUESTS_TOTAL,
60            "Total number of requests processed by work handler",
61            metrics_labels,
62        )?;
63
64        // Custom buckets for inference workloads: retain sub-second resolution for
65        // fast operations, extend well beyond the default 10s ceiling to capture
66        // long-running generation requests that can last minutes.
67        let request_duration_buckets = vec![
68            0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0, 60.0, 120.0,
69            300.0, 600.0,
70        ];
71        let request_duration = metrics.create_histogram(
72            work_handler::REQUEST_DURATION_SECONDS,
73            "Time spent processing requests by work handler",
74            metrics_labels,
75            Some(request_duration_buckets),
76        )?;
77
78        let inflight_requests = metrics.create_intgauge(
79            work_handler::INFLIGHT_REQUESTS,
80            "Number of requests currently being processed by work handler",
81            metrics_labels,
82        )?;
83
84        let request_bytes = metrics.create_intcounter(
85            work_handler::REQUEST_BYTES_TOTAL,
86            "Total number of bytes received in requests by work handler",
87            metrics_labels,
88        )?;
89
90        let response_bytes = metrics.create_intcounter(
91            work_handler::RESPONSE_BYTES_TOTAL,
92            "Total number of bytes sent in responses by work handler",
93            metrics_labels,
94        )?;
95
96        let error_counter = metrics.create_intcountervec(
97            work_handler::ERRORS_TOTAL,
98            "Total number of errors in work handler processing",
99            &[work_handler::ERROR_TYPE_LABEL],
100            metrics_labels,
101        )?;
102
103        let cancellation_total = metrics.create_intcounter(
104            work_handler::CANCELLATION_TOTAL,
105            "Total number of requests cancelled by work handler",
106            metrics_labels,
107        )?;
108
109        Ok(Self::new(
110            request_counter,
111            request_duration,
112            inflight_requests,
113            request_bytes,
114            response_bytes,
115            error_counter,
116            cancellation_total,
117        ))
118    }
119}
120
121// RAII guard to ensure inflight gauge is decremented, request duration is observed,
122// and lifecycle logs are emitted on all code paths.
123struct RequestMetricsGuard {
124    inflight_requests: prometheus::IntGauge,
125    request_duration: prometheus::Histogram,
126    start_time: Instant,
127    request_id: Option<String>,
128}
129
130impl Drop for RequestMetricsGuard {
131    fn drop(&mut self) {
132        self.inflight_requests.dec();
133        self.request_duration
134            .observe(self.start_time.elapsed().as_secs_f64());
135        if let Some(request_id) = &self.request_id {
136            tracing::info!(request_id = %request_id, "request completed");
137        }
138    }
139}
140
141#[async_trait]
142impl<T: Data, U: Data> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
143where
144    T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
145    U: Data + Serialize + MaybeError + std::fmt::Debug,
146{
147    fn add_metrics(
148        &self,
149        endpoint: &crate::component::Endpoint,
150        metrics_labels: Option<&[(&str, &str)]>,
151    ) -> Result<()> {
152        // Call the Ingress-specific add_metrics implementation
153        use crate::pipeline::network::Ingress;
154        Ingress::add_metrics(self, endpoint, metrics_labels)
155    }
156
157    fn set_endpoint_health_check_notifier(&self, notifier: Arc<tokio::sync::Notify>) -> Result<()> {
158        use crate::pipeline::network::Ingress;
159        self.endpoint_health_check_notifier
160            .set(notifier)
161            .map_err(|_| anyhow::anyhow!("Endpoint health check notifier already set"))?;
162        Ok(())
163    }
164
165    async fn handle_payload(
166        &self,
167        payload: Bytes,
168        request_id: Option<String>,
169    ) -> Result<(), PipelineError> {
170        let t2_wallclock_ns = std::time::SystemTime::now()
171            .duration_since(std::time::UNIX_EPOCH)
172            .unwrap_or_default()
173            .as_nanos() as u64;
174        let start_time = std::time::Instant::now();
175
176        // Increment inflight and ensure it's decremented on all exits via RAII guard
177        let _inflight_guard = self.metrics().map(|m| {
178            m.request_counter.inc();
179            m.inflight_requests.inc();
180            m.request_bytes.inc_by(payload.len() as u64);
181            if let Some(rid) = &request_id {
182                tracing::info!(request_id = %rid, "request received");
183            }
184            RequestMetricsGuard {
185                inflight_requests: m.inflight_requests.clone(),
186                request_duration: m.request_duration.clone(),
187                start_time,
188                request_id: request_id.clone(),
189            }
190        });
191
192        // decode the control message and the request
193        let msg = TwoPartCodec::default()
194            .decode_message(payload)?
195            .into_message_type();
196
197        // we must have a header and a body
198        // it will be held by this closure as a Some(permit)
199        let (control_msg, request) = match msg {
200            TwoPartMessageType::HeaderAndData(header, data) => {
201                tracing::trace!(
202                    "received two part message with ctrl: {} bytes, data: {} bytes",
203                    header.len(),
204                    data.len()
205                );
206                let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
207                    Ok(cm) => cm,
208                    Err(err) => {
209                        let json_str = String::from_utf8_lossy(&header);
210                        if let Some(m) = self.metrics() {
211                            m.error_counter
212                                .with_label_values(&[work_handler::error_types::DESERIALIZATION])
213                                .inc();
214                        }
215                        return Err(PipelineError::DeserializationError(format!(
216                            "Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}"
217                        )));
218                    }
219                };
220                let request: T = serde_json::from_slice(&data)?;
221                (control_msg, request)
222            }
223            _ => {
224                if let Some(m) = self.metrics() {
225                    m.error_counter
226                        .with_label_values(&[work_handler::error_types::INVALID_MESSAGE])
227                        .inc();
228                }
229                return Err(PipelineError::Generic(String::from(
230                    "Unexpected message from work queue; unable extract a TwoPartMessage with a header and data",
231                )));
232            }
233        };
234
235        // Compute network transit time (T2 - T1) using cross-process wall-clock timestamps
236        if let Some(t1_ns) = control_msg.frontend_send_ts_ns {
237            let transit_ns = t2_wallclock_ns.saturating_sub(t1_ns);
238            WORK_HANDLER_NETWORK_TRANSIT_SECONDS.observe(transit_ns as f64 / 1_000_000_000.0);
239        }
240
241        // extend request with context
242        tracing::trace!("received control message: {:?}", control_msg);
243        tracing::trace!("received request: {:?}", request);
244        let request: context::Context<T> = Context::with_id(request, control_msg.id);
245
246        // todo - eventually have a handler class which will returned an abstracted object, but for now,
247        // we only support tcp here, so we can just unwrap the connection info
248        tracing::trace!("creating tcp response stream");
249        let mut publisher = tcp::client::TcpClient::create_response_stream(
250            request.context(),
251            control_msg.connection_info,
252            self.metrics().map(|m| m.cancellation_total.clone()),
253        )
254        .await
255        .map_err(|e| {
256            if let Some(m) = self.metrics() {
257                m.error_counter
258                    .with_label_values(&[work_handler::error_types::RESPONSE_STREAM])
259                    .inc();
260            }
261            PipelineError::Generic(format!("Failed to create response stream: {:?}", e,))
262        })?;
263
264        tracing::trace!("calling generate");
265        let stream = self
266            .segment
267            .get()
268            .expect("segment not set")
269            .generate(request)
270            .await
271            .map_err(|e| {
272                if let Some(m) = self.metrics() {
273                    m.error_counter
274                        .with_label_values(&[work_handler::error_types::GENERATE])
275                        .inc();
276                }
277                PipelineError::GenerateError(e)
278            });
279
280        // the prolouge is sent to the client to indicate that the stream is ready to receive data
281        // or if the generate call failed, the error is sent to the client
282        let mut stream = match stream {
283            Ok(stream) => {
284                tracing::trace!("Successfully generated response stream; sending prologue");
285                let _result = publisher.send_prologue(None).await;
286                WORK_HANDLER_TIME_TO_FIRST_RESPONSE_SECONDS
287                    .observe(start_time.elapsed().as_secs_f64());
288                stream
289            }
290            Err(e) => {
291                let error_string = e.to_string();
292
293                #[cfg(debug_assertions)]
294                {
295                    tracing::debug!(
296                        "Failed to generate response stream (with debug backtrace): {:?}",
297                        e
298                    );
299                }
300                #[cfg(not(debug_assertions))]
301                {
302                    tracing::error!("Failed to generate response stream: {error_string}");
303                }
304
305                let _result = publisher.send_prologue(Some(error_string)).await;
306                Err(e)?
307            }
308        };
309
310        let context = stream.context();
311
312        // TODO: Detect end-of-stream using Server-Sent Events (SSE)
313        let mut send_complete_final = true;
314        while let Some(resp) = stream.next().await {
315            tracing::trace!("Sending response: {:?}", resp);
316            let resp_wrapper = NetworkStreamWrapper {
317                data: Some(resp),
318                complete_final: false,
319            };
320            let resp_bytes = serde_json::to_vec(&resp_wrapper)
321                .expect("fatal error: invalid response object - this should never happen");
322            if let Some(m) = self.metrics() {
323                m.response_bytes.inc_by(resp_bytes.len() as u64);
324            }
325            if (publisher.send(resp_bytes.into()).await).is_err() {
326                send_complete_final = false;
327                if context.is_stopped() {
328                    // Say there are 2 threads accessing `context`, the sequence can be either:
329                    // 1. context.stop_generating (other) -> publisher.send failure (this)
330                    //    -> context.is_stopped (this)
331                    // 2. publisher.send failure (this) -> context.stop_generating (other)
332                    //    -> context.is_stopped (this)
333                    // Case 1 can happen when client closed the connection after receiving the
334                    // complete response from frontend. Hence, send failure can be expected in this
335                    // case.
336                    tracing::warn!("Failed to publish response for stream {}", context.id());
337                } else {
338                    // Otherwise, this is an error.
339                    tracing::error!("Failed to publish response for stream {}", context.id());
340                    context.stop_generating();
341                }
342                // Account errors in all cases, including cancellation. Therefore this metric can be
343                // inflated.
344                if let Some(m) = self.metrics() {
345                    m.error_counter
346                        .with_label_values(&[work_handler::error_types::PUBLISH_RESPONSE])
347                        .inc();
348                }
349                break;
350            }
351        }
352        if send_complete_final {
353            let resp_wrapper = NetworkStreamWrapper::<U> {
354                data: None,
355                complete_final: true,
356            };
357            let resp_bytes = serde_json::to_vec(&resp_wrapper)
358                .expect("fatal error: invalid response object - this should never happen");
359            if let Some(m) = self.metrics() {
360                m.response_bytes.inc_by(resp_bytes.len() as u64);
361            }
362            if (publisher.send(resp_bytes.into()).await).is_err() {
363                tracing::error!(
364                    "Failed to publish complete final for stream {}",
365                    context.id()
366                );
367                if let Some(m) = self.metrics() {
368                    m.error_counter
369                        .with_label_values(&[work_handler::error_types::PUBLISH_FINAL])
370                        .inc();
371                }
372            }
373            // Notify the health check manager that the stream has finished.
374            // This resets the timer, delaying the next canary health check.
375            if let Some(notifier) = self.endpoint_health_check_notifier.get() {
376                notifier.notify_one();
377            }
378        }
379
380        // Ensure the metrics guard is not dropped until the end of the function.
381        // Drop fires "request completed" log via RAII.
382        drop(_inflight_guard);
383
384        Ok(())
385    }
386}