// dynamo_runtime/pipeline/network/ingress/push_handler.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::*;
5
6use crate::metrics::prometheus_names::work_handler;
7use crate::metrics::work_handler_perf::{
8    WORK_HANDLER_NETWORK_TRANSIT_SECONDS, WORK_HANDLER_TIME_TO_FIRST_RESPONSE_SECONDS,
9};
10use crate::protocols::maybe_error::MaybeError;
11use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14use std::time::Instant;
15use tracing::Instrument;
16use tracing::info_span;
17
18/// Metrics configuration for profiling work handlers
19#[derive(Clone, Debug)]
20pub struct WorkHandlerMetrics {
21    pub request_counter: IntCounter,
22    pub request_duration: Histogram,
23    pub inflight_requests: IntGauge,
24    pub request_bytes: IntCounter,
25    pub response_bytes: IntCounter,
26    pub error_counter: IntCounterVec,
27    pub cancellation_total: IntCounter,
28}
29
30impl WorkHandlerMetrics {
31    pub fn new(
32        request_counter: IntCounter,
33        request_duration: Histogram,
34        inflight_requests: IntGauge,
35        request_bytes: IntCounter,
36        response_bytes: IntCounter,
37        error_counter: IntCounterVec,
38        cancellation_total: IntCounter,
39    ) -> Self {
40        Self {
41            request_counter,
42            request_duration,
43            inflight_requests,
44            request_bytes,
45            response_bytes,
46            error_counter,
47            cancellation_total,
48        }
49    }
50
51    /// Create WorkHandlerMetrics from an endpoint using its built-in labeling
52    pub fn from_endpoint(
53        endpoint: &crate::component::Endpoint,
54        metrics_labels: Option<&[(&str, &str)]>,
55    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
56        let metrics_labels = metrics_labels.unwrap_or(&[]);
57        let metrics = endpoint.metrics();
58        let request_counter = metrics.create_intcounter(
59            work_handler::REQUESTS_TOTAL,
60            "Total number of requests processed by work handler",
61            metrics_labels,
62        )?;
63
64        // Custom buckets for inference workloads: retain sub-second resolution for
65        // fast operations, extend well beyond the default 10s ceiling to capture
66        // long-running generation requests that can last minutes.
67        let request_duration_buckets = vec![
68            0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0, 60.0, 120.0,
69            300.0, 600.0,
70        ];
71        let request_duration = metrics.create_histogram(
72            work_handler::REQUEST_DURATION_SECONDS,
73            "Time spent processing requests by work handler",
74            metrics_labels,
75            Some(request_duration_buckets),
76        )?;
77
78        let inflight_requests = metrics.create_intgauge(
79            work_handler::INFLIGHT_REQUESTS,
80            "Number of requests currently being processed by work handler",
81            metrics_labels,
82        )?;
83
84        let request_bytes = metrics.create_intcounter(
85            work_handler::REQUEST_BYTES_TOTAL,
86            "Total number of bytes received in requests by work handler",
87            metrics_labels,
88        )?;
89
90        let response_bytes = metrics.create_intcounter(
91            work_handler::RESPONSE_BYTES_TOTAL,
92            "Total number of bytes sent in responses by work handler",
93            metrics_labels,
94        )?;
95
96        let error_counter = metrics.create_intcountervec(
97            work_handler::ERRORS_TOTAL,
98            "Total number of errors in work handler processing",
99            &[work_handler::ERROR_TYPE_LABEL],
100            metrics_labels,
101        )?;
102
103        let cancellation_total = metrics.create_intcounter(
104            work_handler::CANCELLATION_TOTAL,
105            "Total number of requests cancelled by work handler",
106            metrics_labels,
107        )?;
108
109        Ok(Self::new(
110            request_counter,
111            request_duration,
112            inflight_requests,
113            request_bytes,
114            response_bytes,
115            error_counter,
116            cancellation_total,
117        ))
118    }
119}
120
121// RAII guard to ensure inflight gauge is decremented, request duration is observed,
122// and lifecycle logs are emitted on all code paths.
123struct RequestMetricsGuard {
124    inflight_requests: prometheus::IntGauge,
125    request_duration: prometheus::Histogram,
126    start_time: Instant,
127    request_id: Option<String>,
128}
129
130impl Drop for RequestMetricsGuard {
131    fn drop(&mut self) {
132        self.inflight_requests.dec();
133        self.request_duration
134            .observe(self.start_time.elapsed().as_secs_f64());
135        if let Some(request_id) = &self.request_id {
136            tracing::info!(request_id = %request_id, "request completed");
137        }
138    }
139}
140
141impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
142    /// Pump every chunk from the engine's response stream out to the
143    /// upstream-side `StreamSender`, plus the terminal complete-final
144    /// frame. Captures the per-frame metrics, the publish-failure error
145    /// classification (client-side disconnect vs. real failure), and the
146    /// health-check notifier policy (notify only on non-error chunks and
147    /// at clean stream end).
148    async fn pump_response_stream<U>(&self, mut stream: ManyOut<U>, publisher: &StreamSender)
149    where
150        U: Data + Serialize + MaybeError + std::fmt::Debug,
151    {
152        let context = stream.context();
153
154        // TODO: Detect end-of-stream using Server-Sent Events (SSE)
155        let mut send_complete_final = true;
156        let mut saw_error_response = false;
157        while let Some(resp) = stream.next().await {
158            tracing::trace!("Sending response: {:?}", resp);
159            let is_error = resp.err().is_some();
160            if is_error {
161                saw_error_response = true;
162            }
163            let resp_wrapper = NetworkStreamWrapper {
164                data: Some(resp),
165                complete_final: false,
166            };
167            let resp_bytes = serde_json::to_vec(&resp_wrapper)
168                .expect("fatal error: invalid response object - this should never happen");
169            if let Some(m) = self.metrics() {
170                m.response_bytes.inc_by(resp_bytes.len() as u64);
171            }
172            if (publisher.send(resp_bytes.into()).await).is_err() {
173                send_complete_final = false;
174                if context.is_stopped() {
175                    // Say there are 2 threads accessing `context`, the sequence can be either:
176                    // 1. context.stop_generating (other) -> publisher.send failure (this)
177                    //    -> context.is_stopped (this)
178                    // 2. publisher.send failure (this) -> context.stop_generating (other)
179                    //    -> context.is_stopped (this)
180                    // Case 1 can happen when client closed the connection after receiving the
181                    // complete response from frontend. Hence, send failure can be expected in this
182                    // case.
183                    tracing::warn!("Failed to publish response for stream {}", context.id());
184                } else {
185                    // Otherwise, this is an error.
186                    tracing::error!("Failed to publish response for stream {}", context.id());
187                    context.stop_generating();
188                }
189                // Account errors in all cases, including cancellation. Therefore this metric can be
190                // inflated.
191                if let Some(m) = self.metrics() {
192                    m.error_counter
193                        .with_label_values(&[work_handler::error_types::PUBLISH_RESPONSE])
194                        .inc();
195                }
196                break;
197            } else if !is_error {
198                // Only notify on non-error chunks — error responses don't prove
199                // the engine is healthy and should not reset the canary timer.
200                if let Some(notifier) = self.endpoint_health_check_notifier.get() {
201                    notifier.notify_one();
202                }
203            }
204        }
205        if send_complete_final {
206            let resp_wrapper = NetworkStreamWrapper::<U> {
207                data: None,
208                complete_final: true,
209            };
210            let resp_bytes = serde_json::to_vec(&resp_wrapper)
211                .expect("fatal error: invalid response object - this should never happen");
212            if let Some(m) = self.metrics() {
213                m.response_bytes.inc_by(resp_bytes.len() as u64);
214            }
215            if (publisher.send(resp_bytes.into()).await).is_err() {
216                tracing::error!(
217                    "Failed to publish complete final for stream {}",
218                    context.id()
219                );
220                if let Some(m) = self.metrics() {
221                    m.error_counter
222                        .with_label_values(&[work_handler::error_types::PUBLISH_FINAL])
223                        .inc();
224                }
225            }
226            // Only notify on stream completion if no error responses were seen
227            if let (false, Some(notifier)) = (
228                saw_error_response,
229                self.endpoint_health_check_notifier.get(),
230            ) {
231                notifier.notify_one();
232            }
233        }
234    }
235}
236/// The output of [`IngressDispatch::parse_and_build_request`]: the typed
237/// request the engine consumes, plus the bits of the on-wire control
238/// message the shared handler needs after parsing (the response-stream
239/// connection info and the frontend send timestamp).
240struct ParsedRequest<Req> {
241    request: Req,
242    response_connection_info: ConnectionInfo,
243    frontend_send_ts_ns: Option<u64>,
244}
245
246/// Per-shape strategy for turning a raw payload into a typed engine
247/// request. Captures the wire-shape-specific parsing of the request
248/// envelope; everything else — metrics-guard, response stream open,
249/// `segment.generate`, prologue, pump — lives in
250/// [`Ingress::handle_payload_shared`] below.
251///
252/// Currently has a single impl (the unary `HeaderAndData` shape). The
253/// abstraction exists to keep an additional impl for the bidirectional
254/// `HeaderOnly` + dial-in shape addition cheap when that lands.
255#[async_trait]
256trait IngressDispatch: Send + Sync {
257    type Request: PipelineIO;
258
259    async fn parse_and_build_request(
260        &self,
261        payload: Bytes,
262    ) -> Result<ParsedRequest<Self::Request>, PipelineError>;
263}
264
265#[async_trait]
266impl<T, U> IngressDispatch for Ingress<SingleIn<T>, ManyOut<U>>
267where
268    T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
269    U: Data + Serialize + MaybeError + std::fmt::Debug,
270{
271    type Request = SingleIn<T>;
272
273    async fn parse_and_build_request(
274        &self,
275        payload: Bytes,
276    ) -> Result<ParsedRequest<SingleIn<T>>, PipelineError> {
277        // decode the control message and the request
278        let msg = TwoPartCodec::default()
279            .decode_message(payload)?
280            .into_message_type();
281
282        // we must have a header and a body
283        let (control_msg, request_t) = match msg {
284            TwoPartMessageType::HeaderAndData(header, data) => {
285                tracing::trace!(
286                    "received two part message with ctrl: {} bytes, data: {} bytes",
287                    header.len(),
288                    data.len()
289                );
290                let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
291                    Ok(cm) => cm,
292                    Err(err) => {
293                        let json_str = String::from_utf8_lossy(&header);
294                        if let Some(m) = self.metrics() {
295                            m.error_counter
296                                .with_label_values(&[work_handler::error_types::DESERIALIZATION])
297                                .inc();
298                        }
299                        return Err(PipelineError::DeserializationError(format!(
300                            "Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}, header_len={}",
301                            header.len(),
302                        )));
303                    }
304                };
305                let request_t: T = serde_json::from_slice(&data)?;
306                (control_msg, request_t)
307            }
308            _ => {
309                if let Some(m) = self.metrics() {
310                    m.error_counter
311                        .with_label_values(&[work_handler::error_types::INVALID_MESSAGE])
312                        .inc();
313                }
314                return Err(PipelineError::Generic(String::from(
315                    "Unexpected message from work queue; unable extract a TwoPartMessage with a header and data",
316                )));
317            }
318        };
319
320        // extend request with context
321        tracing::trace!(
322            request_id = %control_msg.id,
323            metadata_entries = control_msg.metadata.len(),
324            "received control message"
325        );
326        tracing::trace!("received request: {:?}", request_t);
327
328        let request: context::Context<T> =
329            Context::with_id_and_metadata(request_t, control_msg.id, control_msg.metadata);
330
331        Ok(ParsedRequest {
332            request,
333            response_connection_info: control_msg.connection_info,
334            frontend_send_ts_ns: control_msg.frontend_send_ts_ns,
335        })
336    }
337}
338
339impl<Req: PipelineIO + Sync, U> Ingress<Req, ManyOut<U>>
340where
341    U: Data + Serialize + MaybeError + std::fmt::Debug,
342{
343    /// Shared body of `PushWorkHandler::handle_payload` for every
344    /// `Ingress<Req, ManyOut<U>>` shape that has an [`IngressDispatch`]
345    /// impl. Sets up the inflight metrics guard, calls
346    /// `parse_and_build_request` for the wire-shape-specific request
347    /// building, opens the response stream uniformly, dispatches via
348    /// the engine, sends the prologue, and pumps the response through
349    /// [`Self::pump_response_stream`].
350    async fn handle_payload_shared(
351        &self,
352        payload: Bytes,
353        request_id: Option<String>,
354    ) -> Result<(), PipelineError>
355    where
356        Self: IngressDispatch<Request = Req>,
357    {
358        let t2_wallclock_ns = std::time::SystemTime::now()
359            .duration_since(std::time::UNIX_EPOCH)
360            .unwrap_or_default()
361            .as_nanos() as u64;
362        let start_time = std::time::Instant::now();
363
364        // Increment inflight and ensure it's decremented on all exits via RAII guard
365        let _inflight_guard = self.metrics().map(|m| {
366            m.request_counter.inc();
367            m.inflight_requests.inc();
368            m.request_bytes.inc_by(payload.len() as u64);
369            if let Some(rid) = &request_id {
370                tracing::info!(request_id = %rid, "request received");
371            }
372            RequestMetricsGuard {
373                inflight_requests: m.inflight_requests.clone(),
374                request_duration: m.request_duration.clone(),
375                start_time,
376                request_id: request_id.clone(),
377            }
378        });
379
380        let ParsedRequest {
381            request,
382            response_connection_info,
383            frontend_send_ts_ns,
384        } = self.parse_and_build_request(payload).await?;
385
386        // Compute network transit time (T2 - T1) using cross-process wall-clock timestamps
387        if let Some(t1_ns) = frontend_send_ts_ns {
388            let transit_ns = t2_wallclock_ns.saturating_sub(t1_ns);
389            WORK_HANDLER_NETWORK_TRANSIT_SECONDS.observe(transit_ns as f64 / 1_000_000_000.0);
390        }
391
392        // todo - eventually have a handler class which will returned an abstracted object, but for now,
393        // we only support tcp here, so we can just unwrap the connection info
394        tracing::trace!("creating tcp response stream");
395        let mut publisher = tcp::client::TcpClient::create_response_stream(
396            request.context(),
397            response_connection_info,
398            self.metrics().map(|m| m.cancellation_total.clone()),
399        )
400        .await
401        .map_err(|e| {
402            if let Some(m) = self.metrics() {
403                m.error_counter
404                    .with_label_values(&[work_handler::error_types::RESPONSE_STREAM])
405                    .inc();
406            }
407            PipelineError::Generic(format!("Failed to create response stream: {:?}", e,))
408        })?;
409
410        tracing::trace!("calling generate");
411        let stream = self
412            .segment
413            .get()
414            .expect("segment not set")
415            .generate(request)
416            .await
417            .map_err(|e| {
418                if let Some(m) = self.metrics() {
419                    m.error_counter
420                        .with_label_values(&[work_handler::error_types::GENERATE])
421                        .inc();
422                }
423                PipelineError::GenerateError(e)
424            });
425
426        // the prolouge is sent to the client to indicate that the stream is ready to receive data
427        // or if the generate call failed, the error is sent to the client
428        let stream = match stream {
429            Ok(stream) => {
430                tracing::trace!("Successfully generated response stream; sending prologue");
431                let _result = publisher.send_prologue(None).await;
432                WORK_HANDLER_TIME_TO_FIRST_RESPONSE_SECONDS
433                    .observe(start_time.elapsed().as_secs_f64());
434                stream
435            }
436            Err(e) => {
437                let error_string = e.to_string();
438
439                #[cfg(debug_assertions)]
440                {
441                    tracing::debug!(
442                        "Failed to generate response stream (with debug backtrace): {:?}",
443                        e
444                    );
445                }
446                #[cfg(not(debug_assertions))]
447                {
448                    tracing::error!("Failed to generate response stream: {error_string}");
449                }
450
451                let _result = publisher.send_prologue(Some(error_string)).await;
452                Err(e)?
453            }
454        };
455
456        self.pump_response_stream(stream, &publisher).await;
457
458        // Ensure the metrics guard is not dropped until the end of the function.
459        // Drop fires "request completed" log via RAII.
460        drop(_inflight_guard);
461
462        Ok(())
463    }
464}
465
466#[async_trait]
467impl<T, U> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
468where
469    T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
470    U: Data + Serialize + MaybeError + std::fmt::Debug,
471{
472    fn add_metrics(
473        &self,
474        endpoint: &crate::component::Endpoint,
475        metrics_labels: Option<&[(&str, &str)]>,
476    ) -> Result<()> {
477        // Call the Ingress-specific add_metrics implementation
478        use crate::pipeline::network::Ingress;
479        Ingress::add_metrics(self, endpoint, metrics_labels)
480    }
481
482    fn set_endpoint_health_check_notifier(&self, notifier: Arc<tokio::sync::Notify>) -> Result<()> {
483        use crate::pipeline::network::Ingress;
484        self.endpoint_health_check_notifier
485            .set(notifier)
486            .map_err(|_| anyhow::anyhow!("Endpoint health check notifier already set"))?;
487        Ok(())
488    }
489
490    async fn handle_payload(
491        &self,
492        payload: Bytes,
493        request_id: Option<String>,
494    ) -> Result<(), PipelineError> {
495        self.handle_payload_shared(payload, request_id).await
496    }
497}