dynamo_runtime/pipeline/network/ingress/
push_handler.rs1use super::*;
5
6use crate::metrics::prometheus_names::work_handler;
7use crate::metrics::work_handler_perf::{
8 WORK_HANDLER_NETWORK_TRANSIT_SECONDS, WORK_HANDLER_TIME_TO_FIRST_RESPONSE_SECONDS,
9};
10use crate::protocols::maybe_error::MaybeError;
11use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14use std::time::Instant;
15use tracing::Instrument;
16use tracing::info_span;
17
18#[derive(Clone, Debug)]
20pub struct WorkHandlerMetrics {
21 pub request_counter: IntCounter,
22 pub request_duration: Histogram,
23 pub inflight_requests: IntGauge,
24 pub request_bytes: IntCounter,
25 pub response_bytes: IntCounter,
26 pub error_counter: IntCounterVec,
27 pub cancellation_total: IntCounter,
28}
29
30impl WorkHandlerMetrics {
31 pub fn new(
32 request_counter: IntCounter,
33 request_duration: Histogram,
34 inflight_requests: IntGauge,
35 request_bytes: IntCounter,
36 response_bytes: IntCounter,
37 error_counter: IntCounterVec,
38 cancellation_total: IntCounter,
39 ) -> Self {
40 Self {
41 request_counter,
42 request_duration,
43 inflight_requests,
44 request_bytes,
45 response_bytes,
46 error_counter,
47 cancellation_total,
48 }
49 }
50
51 pub fn from_endpoint(
53 endpoint: &crate::component::Endpoint,
54 metrics_labels: Option<&[(&str, &str)]>,
55 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
56 let metrics_labels = metrics_labels.unwrap_or(&[]);
57 let metrics = endpoint.metrics();
58 let request_counter = metrics.create_intcounter(
59 work_handler::REQUESTS_TOTAL,
60 "Total number of requests processed by work handler",
61 metrics_labels,
62 )?;
63
64 let request_duration_buckets = vec![
68 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0, 60.0, 120.0,
69 300.0, 600.0,
70 ];
71 let request_duration = metrics.create_histogram(
72 work_handler::REQUEST_DURATION_SECONDS,
73 "Time spent processing requests by work handler",
74 metrics_labels,
75 Some(request_duration_buckets),
76 )?;
77
78 let inflight_requests = metrics.create_intgauge(
79 work_handler::INFLIGHT_REQUESTS,
80 "Number of requests currently being processed by work handler",
81 metrics_labels,
82 )?;
83
84 let request_bytes = metrics.create_intcounter(
85 work_handler::REQUEST_BYTES_TOTAL,
86 "Total number of bytes received in requests by work handler",
87 metrics_labels,
88 )?;
89
90 let response_bytes = metrics.create_intcounter(
91 work_handler::RESPONSE_BYTES_TOTAL,
92 "Total number of bytes sent in responses by work handler",
93 metrics_labels,
94 )?;
95
96 let error_counter = metrics.create_intcountervec(
97 work_handler::ERRORS_TOTAL,
98 "Total number of errors in work handler processing",
99 &[work_handler::ERROR_TYPE_LABEL],
100 metrics_labels,
101 )?;
102
103 let cancellation_total = metrics.create_intcounter(
104 work_handler::CANCELLATION_TOTAL,
105 "Total number of requests cancelled by work handler",
106 metrics_labels,
107 )?;
108
109 Ok(Self::new(
110 request_counter,
111 request_duration,
112 inflight_requests,
113 request_bytes,
114 response_bytes,
115 error_counter,
116 cancellation_total,
117 ))
118 }
119}
120
121struct RequestMetricsGuard {
124 inflight_requests: prometheus::IntGauge,
125 request_duration: prometheus::Histogram,
126 start_time: Instant,
127 request_id: Option<String>,
128}
129
130impl Drop for RequestMetricsGuard {
131 fn drop(&mut self) {
132 self.inflight_requests.dec();
133 self.request_duration
134 .observe(self.start_time.elapsed().as_secs_f64());
135 if let Some(request_id) = &self.request_id {
136 tracing::info!(request_id = %request_id, "request completed");
137 }
138 }
139}
140
141#[async_trait]
142impl<T: Data, U: Data> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
143where
144 T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
145 U: Data + Serialize + MaybeError + std::fmt::Debug,
146{
147 fn add_metrics(
148 &self,
149 endpoint: &crate::component::Endpoint,
150 metrics_labels: Option<&[(&str, &str)]>,
151 ) -> Result<()> {
152 use crate::pipeline::network::Ingress;
154 Ingress::add_metrics(self, endpoint, metrics_labels)
155 }
156
157 fn set_endpoint_health_check_notifier(&self, notifier: Arc<tokio::sync::Notify>) -> Result<()> {
158 use crate::pipeline::network::Ingress;
159 self.endpoint_health_check_notifier
160 .set(notifier)
161 .map_err(|_| anyhow::anyhow!("Endpoint health check notifier already set"))?;
162 Ok(())
163 }
164
165 async fn handle_payload(
166 &self,
167 payload: Bytes,
168 request_id: Option<String>,
169 ) -> Result<(), PipelineError> {
170 let t2_wallclock_ns = std::time::SystemTime::now()
171 .duration_since(std::time::UNIX_EPOCH)
172 .unwrap_or_default()
173 .as_nanos() as u64;
174 let start_time = std::time::Instant::now();
175
176 let _inflight_guard = self.metrics().map(|m| {
178 m.request_counter.inc();
179 m.inflight_requests.inc();
180 m.request_bytes.inc_by(payload.len() as u64);
181 if let Some(rid) = &request_id {
182 tracing::info!(request_id = %rid, "request received");
183 }
184 RequestMetricsGuard {
185 inflight_requests: m.inflight_requests.clone(),
186 request_duration: m.request_duration.clone(),
187 start_time,
188 request_id: request_id.clone(),
189 }
190 });
191
192 let msg = TwoPartCodec::default()
194 .decode_message(payload)?
195 .into_message_type();
196
197 let (control_msg, request) = match msg {
200 TwoPartMessageType::HeaderAndData(header, data) => {
201 tracing::trace!(
202 "received two part message with ctrl: {} bytes, data: {} bytes",
203 header.len(),
204 data.len()
205 );
206 let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
207 Ok(cm) => cm,
208 Err(err) => {
209 let json_str = String::from_utf8_lossy(&header);
210 if let Some(m) = self.metrics() {
211 m.error_counter
212 .with_label_values(&[work_handler::error_types::DESERIALIZATION])
213 .inc();
214 }
215 return Err(PipelineError::DeserializationError(format!(
216 "Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}"
217 )));
218 }
219 };
220 let request: T = serde_json::from_slice(&data)?;
221 (control_msg, request)
222 }
223 _ => {
224 if let Some(m) = self.metrics() {
225 m.error_counter
226 .with_label_values(&[work_handler::error_types::INVALID_MESSAGE])
227 .inc();
228 }
229 return Err(PipelineError::Generic(String::from(
230 "Unexpected message from work queue; unable extract a TwoPartMessage with a header and data",
231 )));
232 }
233 };
234
235 if let Some(t1_ns) = control_msg.frontend_send_ts_ns {
237 let transit_ns = t2_wallclock_ns.saturating_sub(t1_ns);
238 WORK_HANDLER_NETWORK_TRANSIT_SECONDS.observe(transit_ns as f64 / 1_000_000_000.0);
239 }
240
241 tracing::trace!("received control message: {:?}", control_msg);
243 tracing::trace!("received request: {:?}", request);
244 let request: context::Context<T> = Context::with_id(request, control_msg.id);
245
246 tracing::trace!("creating tcp response stream");
249 let mut publisher = tcp::client::TcpClient::create_response_stream(
250 request.context(),
251 control_msg.connection_info,
252 self.metrics().map(|m| m.cancellation_total.clone()),
253 )
254 .await
255 .map_err(|e| {
256 if let Some(m) = self.metrics() {
257 m.error_counter
258 .with_label_values(&[work_handler::error_types::RESPONSE_STREAM])
259 .inc();
260 }
261 PipelineError::Generic(format!("Failed to create response stream: {:?}", e,))
262 })?;
263
264 tracing::trace!("calling generate");
265 let stream = self
266 .segment
267 .get()
268 .expect("segment not set")
269 .generate(request)
270 .await
271 .map_err(|e| {
272 if let Some(m) = self.metrics() {
273 m.error_counter
274 .with_label_values(&[work_handler::error_types::GENERATE])
275 .inc();
276 }
277 PipelineError::GenerateError(e)
278 });
279
280 let mut stream = match stream {
283 Ok(stream) => {
284 tracing::trace!("Successfully generated response stream; sending prologue");
285 let _result = publisher.send_prologue(None).await;
286 WORK_HANDLER_TIME_TO_FIRST_RESPONSE_SECONDS
287 .observe(start_time.elapsed().as_secs_f64());
288 stream
289 }
290 Err(e) => {
291 let error_string = e.to_string();
292
293 #[cfg(debug_assertions)]
294 {
295 tracing::debug!(
296 "Failed to generate response stream (with debug backtrace): {:?}",
297 e
298 );
299 }
300 #[cfg(not(debug_assertions))]
301 {
302 tracing::error!("Failed to generate response stream: {error_string}");
303 }
304
305 let _result = publisher.send_prologue(Some(error_string)).await;
306 Err(e)?
307 }
308 };
309
310 let context = stream.context();
311
312 let mut send_complete_final = true;
314 while let Some(resp) = stream.next().await {
315 tracing::trace!("Sending response: {:?}", resp);
316 let resp_wrapper = NetworkStreamWrapper {
317 data: Some(resp),
318 complete_final: false,
319 };
320 let resp_bytes = serde_json::to_vec(&resp_wrapper)
321 .expect("fatal error: invalid response object - this should never happen");
322 if let Some(m) = self.metrics() {
323 m.response_bytes.inc_by(resp_bytes.len() as u64);
324 }
325 if (publisher.send(resp_bytes.into()).await).is_err() {
326 send_complete_final = false;
327 if context.is_stopped() {
328 tracing::warn!("Failed to publish response for stream {}", context.id());
337 } else {
338 tracing::error!("Failed to publish response for stream {}", context.id());
340 context.stop_generating();
341 }
342 if let Some(m) = self.metrics() {
345 m.error_counter
346 .with_label_values(&[work_handler::error_types::PUBLISH_RESPONSE])
347 .inc();
348 }
349 break;
350 }
351 }
352 if send_complete_final {
353 let resp_wrapper = NetworkStreamWrapper::<U> {
354 data: None,
355 complete_final: true,
356 };
357 let resp_bytes = serde_json::to_vec(&resp_wrapper)
358 .expect("fatal error: invalid response object - this should never happen");
359 if let Some(m) = self.metrics() {
360 m.response_bytes.inc_by(resp_bytes.len() as u64);
361 }
362 if (publisher.send(resp_bytes.into()).await).is_err() {
363 tracing::error!(
364 "Failed to publish complete final for stream {}",
365 context.id()
366 );
367 if let Some(m) = self.metrics() {
368 m.error_counter
369 .with_label_values(&[work_handler::error_types::PUBLISH_FINAL])
370 .inc();
371 }
372 }
373 if let Some(notifier) = self.endpoint_health_check_notifier.get() {
376 notifier.notify_one();
377 }
378 }
379
380 drop(_inflight_guard);
383
384 Ok(())
385 }
386}