dynamo_runtime/pipeline/network/ingress/push_handler.rs

use super::*;

use crate::metrics::prometheus_names::work_handler;
use crate::protocols::maybe_error::MaybeError;
use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use tracing::Instrument;
use tracing::info_span;
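/// Prometheus metrics tracked per push work handler: request count and
/// latency, in-flight requests, request/response payload bytes, and an error
/// counter labeled by error type.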
#[derive(Clone, Debug)]
pub struct WorkHandlerMetrics {
    pub request_counter: IntCounter,
    pub request_duration: Histogram,
    pub inflight_requests: IntGauge,
    pub request_bytes: IntCounter,
    pub response_bytes: IntCounter,
    pub error_counter: IntCounterVec,
}

impl WorkHandlerMetrics {
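    /// Bundles already-created metric handles into a single value.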
    pub fn new(
        request_counter: IntCounter,
        request_duration: Histogram,
        inflight_requests: IntGauge,
        request_bytes: IntCounter,
        response_bytes: IntCounter,
        error_counter: IntCounterVec,
    ) -> Self {
        Self {
            request_counter,
            request_duration,
            inflight_requests,
            request_bytes,
            response_bytes,
            error_counter,
        }
    }

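    /// Creates and registers the full metric set against the endpoint's
    /// metrics registry, attaching `metrics_labels` as constant labels on
    /// every metric.
    ///
    /// A hypothetical call site (endpoint construction elided):
    /// ```ignore
    /// let metrics =
    ///     WorkHandlerMetrics::from_endpoint(&endpoint, Some(&[("component", "ingress")]))?;
    /// ```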
    pub fn from_endpoint(
        endpoint: &crate::component::Endpoint,
        metrics_labels: Option<&[(&str, &str)]>,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let metrics_labels = metrics_labels.unwrap_or(&[]);
        let metrics = endpoint.metrics();
        let request_counter = metrics.create_intcounter(
            work_handler::REQUESTS_TOTAL,
            "Total number of requests processed by work handler",
            metrics_labels,
        )?;

        let request_duration = metrics.create_histogram(
            work_handler::REQUEST_DURATION_SECONDS,
            "Time spent processing requests by work handler",
            metrics_labels,
            None,
        )?;

        let inflight_requests = metrics.create_intgauge(
            work_handler::INFLIGHT_REQUESTS,
            "Number of requests currently being processed by work handler",
            metrics_labels,
        )?;

        let request_bytes = metrics.create_intcounter(
            work_handler::REQUEST_BYTES_TOTAL,
            "Total number of bytes received in requests by work handler",
            metrics_labels,
        )?;

        let response_bytes = metrics.create_intcounter(
            work_handler::RESPONSE_BYTES_TOTAL,
            "Total number of bytes sent in responses by work handler",
            metrics_labels,
        )?;

        let error_counter = metrics.create_intcountervec(
            work_handler::ERRORS_TOTAL,
            "Total number of errors in work handler processing",
            &[work_handler::ERROR_TYPE_LABEL],
            metrics_labels,
        )?;

        Ok(Self::new(
            request_counter,
            request_duration,
            inflight_requests,
            request_bytes,
            response_bytes,
            error_counter,
        ))
    }
}

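/// RAII guard for per-request metrics: when dropped it decrements the
/// in-flight gauge and records the request duration, so the bookkeeping runs
/// on every exit path from `handle_payload`.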
struct RequestMetricsGuard {
    inflight_requests: prometheus::IntGauge,
    request_duration: prometheus::Histogram,
    start_time: Instant,
}

impl Drop for RequestMetricsGuard {
    fn drop(&mut self) {
        self.inflight_requests.dec();
        self.request_duration
            .observe(self.start_time.elapsed().as_secs_f64());
    }
}

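/// Push-mode work handler for a single-in/many-out ingress: each pushed
/// payload carries one request, answered with a stream of responses sent
/// back over a dedicated TCP response stream.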
#[async_trait]
impl<T: Data, U: Data> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
where
    T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
    U: Data + Serialize + MaybeError + std::fmt::Debug,
{
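    /// Registers work-handler metrics for this ingress on the given endpoint.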
    fn add_metrics(
        &self,
        endpoint: &crate::component::Endpoint,
        metrics_labels: Option<&[(&str, &str)]>,
    ) -> Result<()> {
        use crate::pipeline::network::Ingress;

        Ingress::add_metrics(self, endpoint, metrics_labels)
    }

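    /// Stores the notifier used to signal endpoint health checks; the
    /// notifier can only be set once.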
    fn set_endpoint_health_check_notifier(&self, notifier: Arc<tokio::sync::Notify>) -> Result<()> {
        self.endpoint_health_check_notifier
            .set(notifier)
            .map_err(|_| anyhow::anyhow!("Endpoint health check notifier already set"))?;
        Ok(())
    }

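    /// Handles one pushed request end to end: decodes the two-part payload,
    /// opens a TCP response stream back to the requester, runs `generate`,
    /// and forwards each response until the stream finishes or the client
    /// disconnects.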
    async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError> {
        let start_time = std::time::Instant::now();

        let _inflight_guard = self.metrics().map(|m| {
            m.request_counter.inc();
            m.inflight_requests.inc();
            m.request_bytes.inc_by(payload.len() as u64);
            RequestMetricsGuard {
                inflight_requests: m.inflight_requests.clone(),
                request_duration: m.request_duration.clone(),
                start_time,
            }
        });

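        // Split the framed payload into its header (control) and data parts.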
        let msg = TwoPartCodec::default()
            .decode_message(payload)?
            .into_message_type();

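        // A valid request arrives as header + data: the header is a JSON
        // RequestControlMessage (request id plus connection info for the
        // response stream); the data is the serialized request itself.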
        let (control_msg, request) = match msg {
            TwoPartMessageType::HeaderAndData(header, data) => {
                tracing::trace!(
                    "received two part message with ctrl: {} bytes, data: {} bytes",
                    header.len(),
                    data.len()
                );
                let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
                    Ok(cm) => cm,
                    Err(err) => {
                        let json_str = String::from_utf8_lossy(&header);
                        if let Some(m) = self.metrics() {
                            m.error_counter
                                .with_label_values(&[work_handler::error_types::DESERIALIZATION])
                                .inc();
                        }
                        return Err(PipelineError::DeserializationError(format!(
                            "Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}"
                        )));
                    }
                };
                let request: T = serde_json::from_slice(&data)?;
                (control_msg, request)
            }
            _ => {
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::INVALID_MESSAGE])
                        .inc();
                }
                return Err(PipelineError::Generic(String::from(
                    "Unexpected message from work queue; unable to extract a TwoPartMessage with a header and data",
                )));
            }
        };

        tracing::trace!("received control message: {:?}", control_msg);
        tracing::trace!("received request: {:?}", request);
        let request: context::Context<T> = Context::with_id(request, control_msg.id);

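        // Dial back to the requester using the connection info from the
        // control message; responses travel over this dedicated stream rather
        // than the work queue.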
        tracing::trace!("creating tcp response stream");
        let mut publisher = tcp::client::TcpClient::create_response_stream(
            request.context(),
            control_msg.connection_info,
        )
        .await
        .map_err(|e| {
            if let Some(m) = self.metrics() {
                m.error_counter
                    .with_label_values(&[work_handler::error_types::RESPONSE_STREAM])
                    .inc();
            }
            PipelineError::Generic(format!("Failed to create response stream: {:?}", e))
        })?;

        tracing::trace!("calling generate");
        let stream = self
            .segment
            .get()
            .expect("segment not set")
            .generate(request)
            .await
            .map_err(|e| {
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::GENERATE])
                        .inc();
                }
                PipelineError::GenerateError(e)
            });

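        // The prologue tells the client whether generation started: `None`
        // signals success (responses will follow), `Some(error)` aborts the
        // stream before any data is sent.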
        let mut stream = match stream {
            Ok(stream) => {
                tracing::trace!("Successfully generated response stream; sending prologue");
                let _result = publisher.send_prologue(None).await;
                stream
            }
            Err(e) => {
                let error_string = e.to_string();

                #[cfg(debug_assertions)]
                {
                    tracing::debug!(
                        "Failed to generate response stream (with debug backtrace): {:?}",
                        e
                    );
                }
                #[cfg(not(debug_assertions))]
                {
                    tracing::error!("Failed to generate response stream: {}", error_string);
                }

                let _result = publisher.send_prologue(Some(error_string)).await;
                Err(e)?
            }
        };

        let context = stream.context();

        let mut send_complete_final = true;
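        // Forward each response, wrapped so the client can tell data frames
        // from the final completion marker.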
        while let Some(resp) = stream.next().await {
            tracing::trace!("Sending response: {:?}", resp);
            let resp_wrapper = NetworkStreamWrapper {
                data: Some(resp),
                complete_final: false,
            };
            let resp_bytes = serde_json::to_vec(&resp_wrapper)
                .expect("fatal error: invalid response object - this should never happen");
            if let Some(m) = self.metrics() {
                m.response_bytes.inc_by(resp_bytes.len() as u64);
            }
            if publisher.send(resp_bytes.into()).await.is_err() {
                send_complete_final = false;
                if context.is_stopped() {
                    tracing::warn!("Failed to publish response for stream {}", context.id());
                } else {
                    tracing::error!("Failed to publish response for stream {}", context.id());
                    context.stop_generating();
                }
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::PUBLISH_RESPONSE])
                        .inc();
                }
                break;
            }
        }
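        // Only send the completion marker if every frame made it out;
        // otherwise the client has already seen the stream break.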
        if send_complete_final {
            let resp_wrapper = NetworkStreamWrapper::<U> {
                data: None,
                complete_final: true,
            };
            let resp_bytes = serde_json::to_vec(&resp_wrapper)
                .expect("fatal error: invalid response object - this should never happen");
            if let Some(m) = self.metrics() {
                m.response_bytes.inc_by(resp_bytes.len() as u64);
            }
            if publisher.send(resp_bytes.into()).await.is_err() {
                tracing::error!(
                    "Failed to publish complete final for stream {}",
                    context.id()
                );
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::PUBLISH_FINAL])
                        .inc();
                }
            }
            // Wake the endpoint health check now that the request completed
            // cleanly.
            if let Some(notifier) = self.endpoint_health_check_notifier.get() {
                notifier.notify_one();
            }
        }

        // Dropping the guard decrements the in-flight gauge and records the
        // request duration.
        drop(_inflight_guard);

        Ok(())
    }
}