1use super::*;
5
6use crate::metrics::prometheus_names::work_handler;
7use crate::metrics::work_handler_perf::{
8 WORK_HANDLER_NETWORK_TRANSIT_SECONDS, WORK_HANDLER_TIME_TO_FIRST_RESPONSE_SECONDS,
9};
10use crate::protocols::maybe_error::MaybeError;
11use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14use std::time::Instant;
15use tracing::Instrument;
16use tracing::info_span;
17
18#[derive(Clone, Debug)]
20pub struct WorkHandlerMetrics {
21 pub request_counter: IntCounter,
22 pub request_duration: Histogram,
23 pub inflight_requests: IntGauge,
24 pub request_bytes: IntCounter,
25 pub response_bytes: IntCounter,
26 pub error_counter: IntCounterVec,
27 pub cancellation_total: IntCounter,
28}
29
30impl WorkHandlerMetrics {
31 pub fn new(
32 request_counter: IntCounter,
33 request_duration: Histogram,
34 inflight_requests: IntGauge,
35 request_bytes: IntCounter,
36 response_bytes: IntCounter,
37 error_counter: IntCounterVec,
38 cancellation_total: IntCounter,
39 ) -> Self {
40 Self {
41 request_counter,
42 request_duration,
43 inflight_requests,
44 request_bytes,
45 response_bytes,
46 error_counter,
47 cancellation_total,
48 }
49 }
50
51 pub fn from_endpoint(
53 endpoint: &crate::component::Endpoint,
54 metrics_labels: Option<&[(&str, &str)]>,
55 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
56 let metrics_labels = metrics_labels.unwrap_or(&[]);
57 let metrics = endpoint.metrics();
58 let request_counter = metrics.create_intcounter(
59 work_handler::REQUESTS_TOTAL,
60 "Total number of requests processed by work handler",
61 metrics_labels,
62 )?;
63
64 let request_duration_buckets = vec![
68 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0, 60.0, 120.0,
69 300.0, 600.0,
70 ];
71 let request_duration = metrics.create_histogram(
72 work_handler::REQUEST_DURATION_SECONDS,
73 "Time spent processing requests by work handler",
74 metrics_labels,
75 Some(request_duration_buckets),
76 )?;
77
78 let inflight_requests = metrics.create_intgauge(
79 work_handler::INFLIGHT_REQUESTS,
80 "Number of requests currently being processed by work handler",
81 metrics_labels,
82 )?;
83
84 let request_bytes = metrics.create_intcounter(
85 work_handler::REQUEST_BYTES_TOTAL,
86 "Total number of bytes received in requests by work handler",
87 metrics_labels,
88 )?;
89
90 let response_bytes = metrics.create_intcounter(
91 work_handler::RESPONSE_BYTES_TOTAL,
92 "Total number of bytes sent in responses by work handler",
93 metrics_labels,
94 )?;
95
96 let error_counter = metrics.create_intcountervec(
97 work_handler::ERRORS_TOTAL,
98 "Total number of errors in work handler processing",
99 &[work_handler::ERROR_TYPE_LABEL],
100 metrics_labels,
101 )?;
102
103 let cancellation_total = metrics.create_intcounter(
104 work_handler::CANCELLATION_TOTAL,
105 "Total number of requests cancelled by work handler",
106 metrics_labels,
107 )?;
108
109 Ok(Self::new(
110 request_counter,
111 request_duration,
112 inflight_requests,
113 request_bytes,
114 response_bytes,
115 error_counter,
116 cancellation_total,
117 ))
118 }
119}
120
121struct RequestMetricsGuard {
124 inflight_requests: prometheus::IntGauge,
125 request_duration: prometheus::Histogram,
126 start_time: Instant,
127 request_id: Option<String>,
128}
129
130impl Drop for RequestMetricsGuard {
131 fn drop(&mut self) {
132 self.inflight_requests.dec();
133 self.request_duration
134 .observe(self.start_time.elapsed().as_secs_f64());
135 if let Some(request_id) = &self.request_id {
136 tracing::info!(request_id = %request_id, "request completed");
137 }
138 }
139}
140
141impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
142 async fn pump_response_stream<U>(&self, mut stream: ManyOut<U>, publisher: &StreamSender)
149 where
150 U: Data + Serialize + MaybeError + std::fmt::Debug,
151 {
152 let context = stream.context();
153
154 let mut send_complete_final = true;
156 let mut saw_error_response = false;
157 while let Some(resp) = stream.next().await {
158 tracing::trace!("Sending response: {:?}", resp);
159 let is_error = resp.err().is_some();
160 if is_error {
161 saw_error_response = true;
162 }
163 let resp_wrapper = NetworkStreamWrapper {
164 data: Some(resp),
165 complete_final: false,
166 };
167 let resp_bytes = serde_json::to_vec(&resp_wrapper)
168 .expect("fatal error: invalid response object - this should never happen");
169 if let Some(m) = self.metrics() {
170 m.response_bytes.inc_by(resp_bytes.len() as u64);
171 }
172 if (publisher.send(resp_bytes.into()).await).is_err() {
173 send_complete_final = false;
174 if context.is_stopped() {
175 tracing::warn!("Failed to publish response for stream {}", context.id());
184 } else {
185 tracing::error!("Failed to publish response for stream {}", context.id());
187 context.stop_generating();
188 }
189 if let Some(m) = self.metrics() {
192 m.error_counter
193 .with_label_values(&[work_handler::error_types::PUBLISH_RESPONSE])
194 .inc();
195 }
196 break;
197 } else if !is_error {
198 if let Some(notifier) = self.endpoint_health_check_notifier.get() {
201 notifier.notify_one();
202 }
203 }
204 }
205 if send_complete_final {
206 let resp_wrapper = NetworkStreamWrapper::<U> {
207 data: None,
208 complete_final: true,
209 };
210 let resp_bytes = serde_json::to_vec(&resp_wrapper)
211 .expect("fatal error: invalid response object - this should never happen");
212 if let Some(m) = self.metrics() {
213 m.response_bytes.inc_by(resp_bytes.len() as u64);
214 }
215 if (publisher.send(resp_bytes.into()).await).is_err() {
216 tracing::error!(
217 "Failed to publish complete final for stream {}",
218 context.id()
219 );
220 if let Some(m) = self.metrics() {
221 m.error_counter
222 .with_label_values(&[work_handler::error_types::PUBLISH_FINAL])
223 .inc();
224 }
225 }
226 if let (false, Some(notifier)) = (
228 saw_error_response,
229 self.endpoint_health_check_notifier.get(),
230 ) {
231 notifier.notify_one();
232 }
233 }
234 }
235}
236struct ParsedRequest<Req> {
241 request: Req,
242 response_connection_info: ConnectionInfo,
243 frontend_send_ts_ns: Option<u64>,
244}
245
246#[async_trait]
256trait IngressDispatch: Send + Sync {
257 type Request: PipelineIO;
258
259 async fn parse_and_build_request(
260 &self,
261 payload: Bytes,
262 ) -> Result<ParsedRequest<Self::Request>, PipelineError>;
263}
264
265#[async_trait]
266impl<T, U> IngressDispatch for Ingress<SingleIn<T>, ManyOut<U>>
267where
268 T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
269 U: Data + Serialize + MaybeError + std::fmt::Debug,
270{
271 type Request = SingleIn<T>;
272
273 async fn parse_and_build_request(
274 &self,
275 payload: Bytes,
276 ) -> Result<ParsedRequest<SingleIn<T>>, PipelineError> {
277 let msg = TwoPartCodec::default()
279 .decode_message(payload)?
280 .into_message_type();
281
282 let (control_msg, request_t) = match msg {
284 TwoPartMessageType::HeaderAndData(header, data) => {
285 tracing::trace!(
286 "received two part message with ctrl: {} bytes, data: {} bytes",
287 header.len(),
288 data.len()
289 );
290 let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
291 Ok(cm) => cm,
292 Err(err) => {
293 let json_str = String::from_utf8_lossy(&header);
294 if let Some(m) = self.metrics() {
295 m.error_counter
296 .with_label_values(&[work_handler::error_types::DESERIALIZATION])
297 .inc();
298 }
299 return Err(PipelineError::DeserializationError(format!(
300 "Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}, header_len={}",
301 header.len(),
302 )));
303 }
304 };
305 let request_t: T = serde_json::from_slice(&data)?;
306 (control_msg, request_t)
307 }
308 _ => {
309 if let Some(m) = self.metrics() {
310 m.error_counter
311 .with_label_values(&[work_handler::error_types::INVALID_MESSAGE])
312 .inc();
313 }
314 return Err(PipelineError::Generic(String::from(
315 "Unexpected message from work queue; unable extract a TwoPartMessage with a header and data",
316 )));
317 }
318 };
319
320 tracing::trace!(
322 request_id = %control_msg.id,
323 metadata_entries = control_msg.metadata.len(),
324 "received control message"
325 );
326 tracing::trace!("received request: {:?}", request_t);
327
328 let request: context::Context<T> =
329 Context::with_id_and_metadata(request_t, control_msg.id, control_msg.metadata);
330
331 Ok(ParsedRequest {
332 request,
333 response_connection_info: control_msg.connection_info,
334 frontend_send_ts_ns: control_msg.frontend_send_ts_ns,
335 })
336 }
337}
338
339impl<Req: PipelineIO + Sync, U> Ingress<Req, ManyOut<U>>
340where
341 U: Data + Serialize + MaybeError + std::fmt::Debug,
342{
343 async fn handle_payload_shared(
351 &self,
352 payload: Bytes,
353 request_id: Option<String>,
354 ) -> Result<(), PipelineError>
355 where
356 Self: IngressDispatch<Request = Req>,
357 {
358 let t2_wallclock_ns = std::time::SystemTime::now()
359 .duration_since(std::time::UNIX_EPOCH)
360 .unwrap_or_default()
361 .as_nanos() as u64;
362 let start_time = std::time::Instant::now();
363
364 let _inflight_guard = self.metrics().map(|m| {
366 m.request_counter.inc();
367 m.inflight_requests.inc();
368 m.request_bytes.inc_by(payload.len() as u64);
369 if let Some(rid) = &request_id {
370 tracing::info!(request_id = %rid, "request received");
371 }
372 RequestMetricsGuard {
373 inflight_requests: m.inflight_requests.clone(),
374 request_duration: m.request_duration.clone(),
375 start_time,
376 request_id: request_id.clone(),
377 }
378 });
379
380 let ParsedRequest {
381 request,
382 response_connection_info,
383 frontend_send_ts_ns,
384 } = self.parse_and_build_request(payload).await?;
385
386 if let Some(t1_ns) = frontend_send_ts_ns {
388 let transit_ns = t2_wallclock_ns.saturating_sub(t1_ns);
389 WORK_HANDLER_NETWORK_TRANSIT_SECONDS.observe(transit_ns as f64 / 1_000_000_000.0);
390 }
391
392 tracing::trace!("creating tcp response stream");
395 let mut publisher = tcp::client::TcpClient::create_response_stream(
396 request.context(),
397 response_connection_info,
398 self.metrics().map(|m| m.cancellation_total.clone()),
399 )
400 .await
401 .map_err(|e| {
402 if let Some(m) = self.metrics() {
403 m.error_counter
404 .with_label_values(&[work_handler::error_types::RESPONSE_STREAM])
405 .inc();
406 }
407 PipelineError::Generic(format!("Failed to create response stream: {:?}", e,))
408 })?;
409
410 tracing::trace!("calling generate");
411 let stream = self
412 .segment
413 .get()
414 .expect("segment not set")
415 .generate(request)
416 .await
417 .map_err(|e| {
418 if let Some(m) = self.metrics() {
419 m.error_counter
420 .with_label_values(&[work_handler::error_types::GENERATE])
421 .inc();
422 }
423 PipelineError::GenerateError(e)
424 });
425
426 let stream = match stream {
429 Ok(stream) => {
430 tracing::trace!("Successfully generated response stream; sending prologue");
431 let _result = publisher.send_prologue(None).await;
432 WORK_HANDLER_TIME_TO_FIRST_RESPONSE_SECONDS
433 .observe(start_time.elapsed().as_secs_f64());
434 stream
435 }
436 Err(e) => {
437 let error_string = e.to_string();
438
439 #[cfg(debug_assertions)]
440 {
441 tracing::debug!(
442 "Failed to generate response stream (with debug backtrace): {:?}",
443 e
444 );
445 }
446 #[cfg(not(debug_assertions))]
447 {
448 tracing::error!("Failed to generate response stream: {error_string}");
449 }
450
451 let _result = publisher.send_prologue(Some(error_string)).await;
452 Err(e)?
453 }
454 };
455
456 self.pump_response_stream(stream, &publisher).await;
457
458 drop(_inflight_guard);
461
462 Ok(())
463 }
464}
465
466#[async_trait]
467impl<T, U> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
468where
469 T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
470 U: Data + Serialize + MaybeError + std::fmt::Debug,
471{
472 fn add_metrics(
473 &self,
474 endpoint: &crate::component::Endpoint,
475 metrics_labels: Option<&[(&str, &str)]>,
476 ) -> Result<()> {
477 use crate::pipeline::network::Ingress;
479 Ingress::add_metrics(self, endpoint, metrics_labels)
480 }
481
482 fn set_endpoint_health_check_notifier(&self, notifier: Arc<tokio::sync::Notify>) -> Result<()> {
483 use crate::pipeline::network::Ingress;
484 self.endpoint_health_check_notifier
485 .set(notifier)
486 .map_err(|_| anyhow::anyhow!("Endpoint health check notifier already set"))?;
487 Ok(())
488 }
489
490 async fn handle_payload(
491 &self,
492 payload: Bytes,
493 request_id: Option<String>,
494 ) -> Result<(), PipelineError> {
495 self.handle_payload_shared(payload, request_id).await
496 }
497}