dynamo_runtime/pipeline/network/ingress/push_handler.rs

use super::*;
use crate::metrics::prometheus_names::work_handler;
use crate::protocols::maybe_error::MaybeError;
use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use tracing::Instrument;
use tracing::info_span;

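/// Prometheus instruments for a single work handler: request count, request
/// latency, in-flight gauge, request/response byte totals, and an error
/// counter partitioned by an error-type label.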
#[derive(Clone, Debug)]
pub struct WorkHandlerMetrics {
    pub request_counter: IntCounter,
    pub request_duration: Histogram,
    pub inflight_requests: IntGauge,
    pub request_bytes: IntCounter,
    pub response_bytes: IntCounter,
    pub error_counter: IntCounterVec,
}

impl WorkHandlerMetrics {
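    /// Assembles the metric set from instruments that were already created
    /// and registered elsewhere.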
    pub fn new(
        request_counter: IntCounter,
        request_duration: Histogram,
        inflight_requests: IntGauge,
        request_bytes: IntCounter,
        response_bytes: IntCounter,
        error_counter: IntCounterVec,
    ) -> Self {
        Self {
            request_counter,
            request_duration,
            inflight_requests,
            request_bytes,
            response_bytes,
            error_counter,
        }
    }

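    /// Creates and registers all work-handler metrics against the given
    /// endpoint, applying any extra constant labels.
    ///
    /// A minimal usage sketch (the `endpoint` handle and the label values are
    /// illustrative, not taken from this crate's tests):
    ///
    /// ```ignore
    /// let metrics = WorkHandlerMetrics::from_endpoint(&endpoint, Some(&[("model", "demo")]))?;
    /// metrics.request_counter.inc();
    /// ```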
    pub fn from_endpoint(
        endpoint: &crate::component::Endpoint,
        metrics_labels: Option<&[(&str, &str)]>,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let metrics_labels = metrics_labels.unwrap_or(&[]);
        let request_counter = endpoint.create_intcounter(
            work_handler::REQUESTS_TOTAL,
            "Total number of requests processed by work handler",
            metrics_labels,
        )?;

        let request_duration = endpoint.create_histogram(
            work_handler::REQUEST_DURATION_SECONDS,
            "Time spent processing requests by work handler",
            metrics_labels,
            None,
        )?;

        let inflight_requests = endpoint.create_intgauge(
            work_handler::INFLIGHT_REQUESTS,
            "Number of requests currently being processed by work handler",
            metrics_labels,
        )?;

        let request_bytes = endpoint.create_intcounter(
            work_handler::REQUEST_BYTES_TOTAL,
            "Total number of bytes received in requests by work handler",
            metrics_labels,
        )?;

        let response_bytes = endpoint.create_intcounter(
            work_handler::RESPONSE_BYTES_TOTAL,
            "Total number of bytes sent in responses by work handler",
            metrics_labels,
        )?;

        let error_counter = endpoint.create_intcountervec(
            work_handler::ERRORS_TOTAL,
            "Total number of errors in work handler processing",
            &[work_handler::ERROR_TYPE_LABEL],
            metrics_labels,
        )?;

        Ok(Self::new(
            request_counter,
            request_duration,
            inflight_requests,
            request_bytes,
            response_bytes,
            error_counter,
        ))
    }
}

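/// RAII guard for per-request metrics: holds clones of the in-flight gauge
/// and duration histogram taken when the request was admitted.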
struct RequestMetricsGuard {
    inflight_requests: prometheus::IntGauge,
    request_duration: prometheus::Histogram,
    start_time: Instant,
}
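// Drop runs on every exit path, so the gauge and histogram stay consistent
// even when `handle_payload` bails out early.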
impl Drop for RequestMetricsGuard {
    fn drop(&mut self) {
        self.inflight_requests.dec();
        self.request_duration
            .observe(self.start_time.elapsed().as_secs_f64());
    }
}

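// Handler for requests pushed over the network: decodes a two-part payload
// (JSON control header plus the serialized request), runs the pipeline
// segment, and streams the responses back to the requester over TCP.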
#[async_trait]
impl<T, U> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
where
    T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
    U: Data + Serialize + MaybeError + std::fmt::Debug,
{
    fn add_metrics(
        &self,
        endpoint: &crate::component::Endpoint,
        metrics_labels: Option<&[(&str, &str)]>,
    ) -> Result<()> {
        use crate::pipeline::network::Ingress;
        Ingress::add_metrics(self, endpoint, metrics_labels)
    }

    fn set_endpoint_health_check_notifier(&self, notifier: Arc<tokio::sync::Notify>) -> Result<()> {
        self.endpoint_health_check_notifier
            .set(notifier)
            .map_err(|_| anyhow::anyhow!("Endpoint health check notifier already set"))?;
        Ok(())
    }

    async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError> {
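        // Flow: record request metrics, decode the two-part message, open the
        // response stream back to the client, run generate(), then forward
        // each response followed by a `complete_final` sentinel.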
        let start_time = std::time::Instant::now();

        let _inflight_guard = self.metrics().map(|m| {
            m.request_counter.inc();
            m.inflight_requests.inc();
            m.request_bytes.inc_by(payload.len() as u64);
            RequestMetricsGuard {
                inflight_requests: m.inflight_requests.clone(),
                request_duration: m.request_duration.clone(),
                start_time,
            }
        });

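        // The payload is a two-part message: a JSON control header (request id
        // plus connection info for the response stream) followed by the
        // serialized request body.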
        let msg = TwoPartCodec::default()
            .decode_message(payload)?
            .into_message_type();

        let (control_msg, request) = match msg {
            TwoPartMessageType::HeaderAndData(header, data) => {
                tracing::trace!(
                    "received two part message with ctrl: {} bytes, data: {} bytes",
                    header.len(),
                    data.len()
                );
                let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
                    Ok(cm) => cm,
                    Err(err) => {
                        let json_str = String::from_utf8_lossy(&header);
                        if let Some(m) = self.metrics() {
                            m.error_counter
                                .with_label_values(&[work_handler::error_types::DESERIALIZATION])
                                .inc();
                        }
                        return Err(PipelineError::DeserializationError(format!(
                            "Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}"
                        )));
                    }
                };
                let request: T = serde_json::from_slice(&data)?;
                (control_msg, request)
            }
            _ => {
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::INVALID_MESSAGE])
                        .inc();
                }
                return Err(PipelineError::Generic(String::from(
                    "Unexpected message from work queue; unable to extract a TwoPartMessage with a header and data",
                )));
            }
        };

        tracing::trace!("received control message: {:?}", control_msg);
        tracing::trace!("received request: {:?}", request);
        let request: context::Context<T> = Context::with_id(request, control_msg.id);

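        // Connect back to the requester's TCP endpoint described in the control
        // message; all responses for this request are published on this stream.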
        tracing::trace!("creating tcp response stream");
        let mut publisher = tcp::client::TcpClient::create_response_stream(
            request.context(),
            control_msg.connection_info,
        )
        .await
        .map_err(|e| {
            if let Some(m) = self.metrics() {
                m.error_counter
                    .with_label_values(&[work_handler::error_types::RESPONSE_STREAM])
                    .inc();
            }
            PipelineError::Generic(format!("Failed to create response stream: {:?}", e))
        })?;

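        // Hand the request to the downstream segment. A failure here is
        // reported to the client via the prologue rather than silently dropped.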
        tracing::trace!("calling generate");
        let stream = self
            .segment
            .get()
            .expect("segment not set")
            .generate(request)
            .await
            .map_err(|e| {
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::GENERATE])
                        .inc();
                }
                PipelineError::GenerateError(e)
            });

        let mut stream = match stream {
            Ok(stream) => {
                tracing::trace!("Successfully generated response stream; sending prologue");
                let _result = publisher.send_prologue(None).await;
                stream
            }
            Err(e) => {
                let error_string = e.to_string();

                #[cfg(debug_assertions)]
                {
                    tracing::debug!(
                        "Failed to generate response stream (with debug backtrace): {:?}",
                        e
                    );
                }
                #[cfg(not(debug_assertions))]
                {
                    tracing::error!("Failed to generate response stream: {}", error_string);
                }

                let _result = publisher.send_prologue(Some(error_string)).await;
                Err(e)?
            }
        };

        let context = stream.context();

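        // Forward each response wrapped in a NetworkStreamWrapper. The final
        // sentinel (data: None, complete_final: true) is suppressed when the
        // stream errored or a publish failed, since the client connection is
        // then presumed broken.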
        let mut send_complete_final = true;
        while let Some(resp) = stream.next().await {
            tracing::trace!("Sending response: {:?}", resp);
            if let Some(err) = resp.err()
                && format!("{:?}", err) == STREAM_ERR_MSG
            {
                tracing::warn!(STREAM_ERR_MSG);
                send_complete_final = false;
                break;
            }
            let resp_wrapper = NetworkStreamWrapper {
                data: Some(resp),
                complete_final: false,
            };
            let resp_bytes = serde_json::to_vec(&resp_wrapper)
                .expect("fatal error: invalid response object - this should never happen");
            if let Some(m) = self.metrics() {
                m.response_bytes.inc_by(resp_bytes.len() as u64);
            }
            if (publisher.send(resp_bytes.into()).await).is_err() {
                tracing::error!("Failed to publish response for stream {}", context.id());
                context.stop_generating();
                send_complete_final = false;
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::PUBLISH_RESPONSE])
                        .inc();
                }
                break;
            }
        }
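        // A clean end of stream: tell the client no more data is coming.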
        if send_complete_final {
            let resp_wrapper = NetworkStreamWrapper::<U> {
                data: None,
                complete_final: true,
            };
            let resp_bytes = serde_json::to_vec(&resp_wrapper)
                .expect("fatal error: invalid response object - this should never happen");
            if let Some(m) = self.metrics() {
                m.response_bytes.inc_by(resp_bytes.len() as u64);
            }
            if (publisher.send(resp_bytes.into()).await).is_err() {
                tracing::error!(
                    "Failed to publish complete final for stream {}",
                    context.id()
                );
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::PUBLISH_FINAL])
                        .inc();
                }
            }
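            // Wake any task waiting on the endpoint health-check notifier now
            // that this request completed cleanly.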
            if let Some(notifier) = self.endpoint_health_check_notifier.get() {
                notifier.notify_one();
            }
        }

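        // Dropping the guard here records the request duration (including the
        // response send phase) and decrements the in-flight gauge.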
        drop(_inflight_guard);

        Ok(())
    }
}