// dynamo_runtime/pipeline/network/ingress/push_handler.rs

use super::*;

use crate::metrics::prometheus_names::work_handler;
use crate::protocols::maybe_error::MaybeError;
use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use tracing::Instrument;
use tracing::info_span;

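/// Prometheus metrics for a work handler: request count and duration,
/// in-flight gauge, request/response byte totals, and an error counter
/// labeled by error type.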
#[derive(Clone, Debug)]
pub struct WorkHandlerMetrics {
    pub request_counter: IntCounter,
    pub request_duration: Histogram,
    pub inflight_requests: IntGauge,
    pub request_bytes: IntCounter,
    pub response_bytes: IntCounter,
    pub error_counter: IntCounterVec,
}

impl WorkHandlerMetrics {
    pub fn new(
        request_counter: IntCounter,
        request_duration: Histogram,
        inflight_requests: IntGauge,
        request_bytes: IntCounter,
        response_bytes: IntCounter,
        error_counter: IntCounterVec,
    ) -> Self {
        Self {
            request_counter,
            request_duration,
            inflight_requests,
            request_bytes,
            response_bytes,
            error_counter,
        }
    }

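    /// Creates the full metric set and registers it with the endpoint's
    /// metrics registry; any `metrics_labels` are applied to every series.
    ///
    /// Illustrative call, assuming an `endpoint` is already in scope (the
    /// label values here are placeholders):
    ///
    /// ```ignore
    /// let metrics = WorkHandlerMetrics::from_endpoint(&endpoint, Some(&[("model", "example")]))?;
    /// ```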
    pub fn from_endpoint(
        endpoint: &crate::component::Endpoint,
        metrics_labels: Option<&[(&str, &str)]>,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let metrics_labels = metrics_labels.unwrap_or(&[]);
        let metrics = endpoint.metrics();
        let request_counter = metrics.create_intcounter(
            work_handler::REQUESTS_TOTAL,
            "Total number of requests processed by work handler",
            metrics_labels,
        )?;

        let request_duration = metrics.create_histogram(
            work_handler::REQUEST_DURATION_SECONDS,
            "Time spent processing requests by work handler",
            metrics_labels,
            None,
        )?;

        let inflight_requests = metrics.create_intgauge(
            work_handler::INFLIGHT_REQUESTS,
            "Number of requests currently being processed by work handler",
            metrics_labels,
        )?;

        let request_bytes = metrics.create_intcounter(
            work_handler::REQUEST_BYTES_TOTAL,
            "Total number of bytes received in requests by work handler",
            metrics_labels,
        )?;

        let response_bytes = metrics.create_intcounter(
            work_handler::RESPONSE_BYTES_TOTAL,
            "Total number of bytes sent in responses by work handler",
            metrics_labels,
        )?;

        let error_counter = metrics.create_intcountervec(
            work_handler::ERRORS_TOTAL,
            "Total number of errors in work handler processing",
            &[work_handler::ERROR_TYPE_LABEL],
            metrics_labels,
        )?;

        Ok(Self::new(
            request_counter,
            request_duration,
            inflight_requests,
            request_bytes,
            response_bytes,
            error_counter,
        ))
    }
}

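/// RAII guard for per-request metrics: when dropped it decrements the
/// in-flight gauge and observes the elapsed request duration, so both stay
/// accurate on every exit path (success, error, or early return).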
struct RequestMetricsGuard {
    inflight_requests: prometheus::IntGauge,
    request_duration: prometheus::Histogram,
    start_time: Instant,
}

impl Drop for RequestMetricsGuard {
    fn drop(&mut self) {
        self.inflight_requests.dec();
        self.request_duration
            .observe(self.start_time.elapsed().as_secs_f64());
    }
}

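// Work handler for single-input / many-output (streaming) engines: decode the
// pushed two-part payload, run generation, and stream each response back to
// the client over the TCP connection named in the control header.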
#[async_trait]
impl<T, U> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
where
    T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
    U: Data + Serialize + MaybeError + std::fmt::Debug,
{
    fn add_metrics(
        &self,
        endpoint: &crate::component::Endpoint,
        metrics_labels: Option<&[(&str, &str)]>,
    ) -> Result<()> {
        use crate::pipeline::network::Ingress;
        Ingress::add_metrics(self, endpoint, metrics_labels)
    }

    fn set_endpoint_health_check_notifier(&self, notifier: Arc<tokio::sync::Notify>) -> Result<()> {
        self.endpoint_health_check_notifier
            .set(notifier)
            .map_err(|_| anyhow::anyhow!("Endpoint health check notifier already set"))?;
        Ok(())
    }

    async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError> {
        let start_time = Instant::now();

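        // Record arrival immediately; the guard reverses the in-flight gauge
        // and observes the duration when this function returns by any path.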
        let _inflight_guard = self.metrics().map(|m| {
            m.request_counter.inc();
            m.inflight_requests.inc();
            m.request_bytes.inc_by(payload.len() as u64);
            RequestMetricsGuard {
                inflight_requests: m.inflight_requests.clone(),
                request_duration: m.request_duration.clone(),
                start_time,
            }
        });

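        // The payload is a two-part message: a JSON control header carrying
        // the request id and response connection info, followed by the
        // serialized request body.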
        let msg = TwoPartCodec::default()
            .decode_message(payload)?
            .into_message_type();

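        // Only the HeaderAndData variant is valid here: the header must
        // deserialize to a RequestControlMessage and the data to a `T`.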
158 let (control_msg, request) = match msg {
161 TwoPartMessageType::HeaderAndData(header, data) => {
162 tracing::trace!(
163 "received two part message with ctrl: {} bytes, data: {} bytes",
164 header.len(),
165 data.len()
166 );
167 let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
168 Ok(cm) => cm,
169 Err(err) => {
170 let json_str = String::from_utf8_lossy(&header);
171 if let Some(m) = self.metrics() {
172 m.error_counter
173 .with_label_values(&[work_handler::error_types::DESERIALIZATION])
174 .inc();
175 }
176 return Err(PipelineError::DeserializationError(format!(
177 "Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}"
178 )));
179 }
180 };
181 let request: T = serde_json::from_slice(&data)?;
182 (control_msg, request)
183 }
184 _ => {
185 if let Some(m) = self.metrics() {
186 m.error_counter
187 .with_label_values(&[work_handler::error_types::INVALID_MESSAGE])
188 .inc();
189 }
190 return Err(PipelineError::Generic(String::from(
191 "Unexpected message from work queue; unable extract a TwoPartMessage with a header and data",
192 )));
193 }
194 };

        tracing::trace!("received control message: {:?}", control_msg);
        tracing::trace!("received request: {:?}", request);
        let request: context::Context<T> = Context::with_id(request, control_msg.id);

        tracing::trace!("creating tcp response stream");
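        // Open the response stream back to the requester over TCP, using the
        // connection info supplied in the control message.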
        let mut publisher = tcp::client::TcpClient::create_response_stream(
            request.context(),
            control_msg.connection_info,
        )
        .await
        .map_err(|e| {
            if let Some(m) = self.metrics() {
                m.error_counter
                    .with_label_values(&[work_handler::error_types::RESPONSE_STREAM])
                    .inc();
            }
            PipelineError::Generic(format!("Failed to create response stream: {e:?}"))
        })?;

        tracing::trace!("calling generate");
        let stream = self
            .segment
            .get()
            .expect("segment not set")
            .generate(request)
            .await
            .map_err(|e| {
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::GENERATE])
                        .inc();
                }
                PipelineError::GenerateError(e)
            });

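        // The prologue tells the client whether generation started: None on
        // success, or the error message if the engine failed to produce a stream.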
        let mut stream = match stream {
            Ok(stream) => {
                tracing::trace!("Successfully generated response stream; sending prologue");
                let _result = publisher.send_prologue(None).await;
                stream
            }
            Err(e) => {
                let error_string = e.to_string();

                #[cfg(debug_assertions)]
                {
                    tracing::debug!(
                        "Failed to generate response stream (with debug backtrace): {:?}",
                        e
                    );
                }
                #[cfg(not(debug_assertions))]
                {
                    tracing::error!("Failed to generate response stream: {}", error_string);
                }

                let _result = publisher.send_prologue(Some(error_string)).await;
                Err(e)?
            }
        };

        let context = stream.context();

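        // Forward every item wrapped in a NetworkStreamWrapper. The trailing
        // sentinel (data: None, complete_final: true) is sent only when the
        // stream ends without a send failure or a sentinel stream error.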
        let mut send_complete_final = true;
        while let Some(resp) = stream.next().await {
            tracing::trace!("Sending response: {:?}", resp);
            if let Some(err) = resp.err()
                && format!("{:?}", err) == STREAM_ERR_MSG
            {
                tracing::warn!(STREAM_ERR_MSG);
                send_complete_final = false;
                break;
            }
            let resp_wrapper = NetworkStreamWrapper {
                data: Some(resp),
                complete_final: false,
            };
            let resp_bytes = serde_json::to_vec(&resp_wrapper)
                .expect("fatal error: invalid response object - this should never happen");
            if let Some(m) = self.metrics() {
                m.response_bytes.inc_by(resp_bytes.len() as u64);
            }
            if publisher.send(resp_bytes.into()).await.is_err() {
                send_complete_final = false;
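                // If the context is already stopped, the client most likely
                // went away and the failed send is expected; otherwise it is a
                // real error and we stop generation for this request.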
                if context.is_stopped() {
                    tracing::warn!("Failed to publish response for stream {}", context.id());
                } else {
                    tracing::error!("Failed to publish response for stream {}", context.id());
                    context.stop_generating();
                }
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::PUBLISH_RESPONSE])
                        .inc();
                }
                break;
            }
        }
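        // Clean end of stream: send the empty, complete_final wrapper so the
        // client knows no more responses are coming.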
        if send_complete_final {
            let resp_wrapper = NetworkStreamWrapper::<U> {
                data: None,
                complete_final: true,
            };
            let resp_bytes = serde_json::to_vec(&resp_wrapper)
                .expect("fatal error: invalid response object - this should never happen");
            if let Some(m) = self.metrics() {
                m.response_bytes.inc_by(resp_bytes.len() as u64);
            }
            if publisher.send(resp_bytes.into()).await.is_err() {
                tracing::error!(
                    "Failed to publish complete final for stream {}",
                    context.id()
                );
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::PUBLISH_FINAL])
                        .inc();
                }
            }
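            // The request completed with a clean final response, so wake any
            // pending endpoint health check to mark the endpoint healthy.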
            if let Some(notifier) = self.endpoint_health_check_notifier.get() {
                notifier.notify_one();
            }
        }

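        // Explicit drop closes the request's metrics window here: it records
        // the duration and decrements the in-flight gauge.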
        drop(_inflight_guard);

        Ok(())
    }
}