// dynamo_runtime/pipeline/network/ingress/push_handler.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4use super::*;
5use crate::metrics::prometheus_names::work_handler;
6use crate::protocols::maybe_error::MaybeError;
7use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10use std::time::Instant;
11use tracing::Instrument;
12use tracing::info_span;
13
14/// Metrics configuration for profiling work handlers
15#[derive(Clone, Debug)]
16pub struct WorkHandlerMetrics {
17    pub request_counter: IntCounter,
18    pub request_duration: Histogram,
19    pub inflight_requests: IntGauge,
20    pub request_bytes: IntCounter,
21    pub response_bytes: IntCounter,
22    pub error_counter: IntCounterVec,
23}
24
25impl WorkHandlerMetrics {
26    pub fn new(
27        request_counter: IntCounter,
28        request_duration: Histogram,
29        inflight_requests: IntGauge,
30        request_bytes: IntCounter,
31        response_bytes: IntCounter,
32        error_counter: IntCounterVec,
33    ) -> Self {
34        Self {
35            request_counter,
36            request_duration,
37            inflight_requests,
38            request_bytes,
39            response_bytes,
40            error_counter,
41        }
42    }
43
44    /// Create WorkHandlerMetrics from an endpoint using its built-in labeling
45    pub fn from_endpoint(
46        endpoint: &crate::component::Endpoint,
47        metrics_labels: Option<&[(&str, &str)]>,
48    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
49        let metrics_labels = metrics_labels.unwrap_or(&[]);
50        let request_counter = endpoint.create_intcounter(
51            work_handler::REQUESTS_TOTAL,
52            "Total number of requests processed by work handler",
53            metrics_labels,
54        )?;
55
56        let request_duration = endpoint.create_histogram(
57            work_handler::REQUEST_DURATION_SECONDS,
58            "Time spent processing requests by work handler",
59            metrics_labels,
60            None,
61        )?;
62
63        let inflight_requests = endpoint.create_intgauge(
64            work_handler::INFLIGHT_REQUESTS,
65            "Number of requests currently being processed by work handler",
66            metrics_labels,
67        )?;
68
69        let request_bytes = endpoint.create_intcounter(
70            work_handler::REQUEST_BYTES_TOTAL,
71            "Total number of bytes received in requests by work handler",
72            metrics_labels,
73        )?;
74
75        let response_bytes = endpoint.create_intcounter(
76            work_handler::RESPONSE_BYTES_TOTAL,
77            "Total number of bytes sent in responses by work handler",
78            metrics_labels,
79        )?;
80
81        let error_counter = endpoint.create_intcountervec(
82            work_handler::ERRORS_TOTAL,
83            "Total number of errors in work handler processing",
84            &[work_handler::ERROR_TYPE_LABEL],
85            metrics_labels,
86        )?;
87
88        Ok(Self::new(
89            request_counter,
90            request_duration,
91            inflight_requests,
92            request_bytes,
93            response_bytes,
94            error_counter,
95        ))
96    }
97}
98
99// RAII guard to ensure inflight gauge is decremented and request duration is observed on all code paths.
100struct RequestMetricsGuard {
101    inflight_requests: prometheus::IntGauge,
102    request_duration: prometheus::Histogram,
103    start_time: Instant,
104}
105impl Drop for RequestMetricsGuard {
106    fn drop(&mut self) {
107        self.inflight_requests.dec();
108        self.request_duration
109            .observe(self.start_time.elapsed().as_secs_f64());
110    }
111}
112
113#[async_trait]
114impl<T: Data, U: Data> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
115where
116    T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
117    U: Data + Serialize + MaybeError + std::fmt::Debug,
118{
119    fn add_metrics(
120        &self,
121        endpoint: &crate::component::Endpoint,
122        metrics_labels: Option<&[(&str, &str)]>,
123    ) -> Result<()> {
124        // Call the Ingress-specific add_metrics implementation
125        use crate::pipeline::network::Ingress;
126        Ingress::add_metrics(self, endpoint, metrics_labels)
127    }
128
129    fn set_endpoint_health_check_notifier(&self, notifier: Arc<tokio::sync::Notify>) -> Result<()> {
130        use crate::pipeline::network::Ingress;
131        self.endpoint_health_check_notifier
132            .set(notifier)
133            .map_err(|_| anyhow::anyhow!("Endpoint health check notifier already set"))?;
134        Ok(())
135    }
136
137    async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError> {
138        let start_time = std::time::Instant::now();
139
140        // Increment inflight and ensure it's decremented on all exits via RAII guard
141        let _inflight_guard = self.metrics().map(|m| {
142            m.request_counter.inc();
143            m.inflight_requests.inc();
144            m.request_bytes.inc_by(payload.len() as u64);
145            RequestMetricsGuard {
146                inflight_requests: m.inflight_requests.clone(),
147                request_duration: m.request_duration.clone(),
148                start_time,
149            }
150        });
151
152        // decode the control message and the request
153        let msg = TwoPartCodec::default()
154            .decode_message(payload)?
155            .into_message_type();
156
157        // we must have a header and a body
158        // it will be held by this closure as a Some(permit)
159        let (control_msg, request) = match msg {
160            TwoPartMessageType::HeaderAndData(header, data) => {
161                tracing::trace!(
162                    "received two part message with ctrl: {} bytes, data: {} bytes",
163                    header.len(),
164                    data.len()
165                );
166                let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
167                    Ok(cm) => cm,
168                    Err(err) => {
169                        let json_str = String::from_utf8_lossy(&header);
170                        if let Some(m) = self.metrics() {
171                            m.error_counter
172                                .with_label_values(&[work_handler::error_types::DESERIALIZATION])
173                                .inc();
174                        }
175                        return Err(PipelineError::DeserializationError(format!(
176                            "Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}"
177                        )));
178                    }
179                };
180                let request: T = serde_json::from_slice(&data)?;
181                (control_msg, request)
182            }
183            _ => {
184                if let Some(m) = self.metrics() {
185                    m.error_counter
186                        .with_label_values(&[work_handler::error_types::INVALID_MESSAGE])
187                        .inc();
188                }
189                return Err(PipelineError::Generic(String::from(
190                    "Unexpected message from work queue; unable extract a TwoPartMessage with a header and data",
191                )));
192            }
193        };
194
195        // extend request with context
196        tracing::trace!("received control message: {:?}", control_msg);
197        tracing::trace!("received request: {:?}", request);
198        let request: context::Context<T> = Context::with_id(request, control_msg.id);
199
200        // todo - eventually have a handler class which will returned an abstracted object, but for now,
201        // we only support tcp here, so we can just unwrap the connection info
202        tracing::trace!("creating tcp response stream");
203        let mut publisher = tcp::client::TcpClient::create_response_stream(
204            request.context(),
205            control_msg.connection_info,
206        )
207        .await
208        .map_err(|e| {
209            if let Some(m) = self.metrics() {
210                m.error_counter
211                    .with_label_values(&[work_handler::error_types::RESPONSE_STREAM])
212                    .inc();
213            }
214            PipelineError::Generic(format!("Failed to create response stream: {:?}", e,))
215        })?;
216
217        tracing::trace!("calling generate");
218        let stream = self
219            .segment
220            .get()
221            .expect("segment not set")
222            .generate(request)
223            .await
224            .map_err(|e| {
225                if let Some(m) = self.metrics() {
226                    m.error_counter
227                        .with_label_values(&[work_handler::error_types::GENERATE])
228                        .inc();
229                }
230                PipelineError::GenerateError(e)
231            });
232
233        // the prolouge is sent to the client to indicate that the stream is ready to receive data
234        // or if the generate call failed, the error is sent to the client
235        let mut stream = match stream {
236            Ok(stream) => {
237                tracing::trace!("Successfully generated response stream; sending prologue");
238                let _result = publisher.send_prologue(None).await;
239                stream
240            }
241            Err(e) => {
242                let error_string = e.to_string();
243
244                #[cfg(debug_assertions)]
245                {
246                    tracing::debug!(
247                        "Failed to generate response stream (with debug backtrace): {:?}",
248                        e
249                    );
250                }
251                #[cfg(not(debug_assertions))]
252                {
253                    tracing::error!("Failed to generate response stream: {}", error_string);
254                }
255
256                let _result = publisher.send_prologue(Some(error_string)).await;
257                Err(e)?
258            }
259        };
260
261        let context = stream.context();
262
263        // TODO: Detect end-of-stream using Server-Sent Events (SSE)
264        let mut send_complete_final = true;
265        while let Some(resp) = stream.next().await {
266            tracing::trace!("Sending response: {:?}", resp);
267            if let Some(err) = resp.err()
268                && format!("{:?}", err) == STREAM_ERR_MSG
269            {
270                tracing::warn!(STREAM_ERR_MSG);
271                send_complete_final = false;
272                break;
273            }
274            let resp_wrapper = NetworkStreamWrapper {
275                data: Some(resp),
276                complete_final: false,
277            };
278            let resp_bytes = serde_json::to_vec(&resp_wrapper)
279                .expect("fatal error: invalid response object - this should never happen");
280            if let Some(m) = self.metrics() {
281                m.response_bytes.inc_by(resp_bytes.len() as u64);
282            }
283            if (publisher.send(resp_bytes.into()).await).is_err() {
284                tracing::error!("Failed to publish response for stream {}", context.id());
285                context.stop_generating();
286                send_complete_final = false;
287                if let Some(m) = self.metrics() {
288                    m.error_counter
289                        .with_label_values(&[work_handler::error_types::PUBLISH_RESPONSE])
290                        .inc();
291                }
292                break;
293            }
294        }
295        if send_complete_final {
296            let resp_wrapper = NetworkStreamWrapper::<U> {
297                data: None,
298                complete_final: true,
299            };
300            let resp_bytes = serde_json::to_vec(&resp_wrapper)
301                .expect("fatal error: invalid response object - this should never happen");
302            if let Some(m) = self.metrics() {
303                m.response_bytes.inc_by(resp_bytes.len() as u64);
304            }
305            if (publisher.send(resp_bytes.into()).await).is_err() {
306                tracing::error!(
307                    "Failed to publish complete final for stream {}",
308                    context.id()
309                );
310                if let Some(m) = self.metrics() {
311                    m.error_counter
312                        .with_label_values(&[work_handler::error_types::PUBLISH_FINAL])
313                        .inc();
314                }
315            }
316            // Notify the health check manager that the stream has finished.
317            // This resets the timer, delaying the next canary health check.
318            if let Some(notifier) = self.endpoint_health_check_notifier.get() {
319                notifier.notify_one();
320            }
321        }
322
323        // Ensure the metrics guard is not dropped until the end of the function.
324        drop(_inflight_guard);
325
326        Ok(())
327    }
328}