use super::*;
use crate::metrics::prometheus_names::work_handler;
use crate::metrics::work_handler_perf::{
WORK_HANDLER_NETWORK_TRANSIT_SECONDS, WORK_HANDLER_TIME_TO_FIRST_RESPONSE_SECONDS,
};
use crate::protocols::maybe_error::MaybeError;
use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use tracing::Instrument;
use tracing::info_span;
/// Prometheus collectors recorded by the work handler for a single endpoint.
///
/// Usually built with [`WorkHandlerMetrics::from_endpoint`], which creates each collector via
/// the endpoint's `metrics()` handle.
#[derive(Clone, Debug)]
pub struct WorkHandlerMetrics {
    /// Total number of requests processed by the work handler.
    pub request_counter: IntCounter,
    /// Time spent processing each request, observed in seconds.
    pub request_duration: Histogram,
    /// Number of requests currently being processed.
    pub inflight_requests: IntGauge,
    /// Total number of bytes received in request payloads.
    pub request_bytes: IntCounter,
    /// Total number of bytes sent in responses.
    pub response_bytes: IntCounter,
    /// Number of processing errors, labelled by error type.
    pub error_counter: IntCounterVec,
    /// Total number of requests cancelled by the work handler.
    pub cancellation_total: IntCounter,
}
impl WorkHandlerMetrics {
    /// Bundles already-created collectors into a `WorkHandlerMetrics`.
    pub fn new(
        request_counter: IntCounter,
        request_duration: Histogram,
        inflight_requests: IntGauge,
        request_bytes: IntCounter,
        response_bytes: IntCounter,
        error_counter: IntCounterVec,
        cancellation_total: IntCounter,
    ) -> Self {
        Self {
            request_counter,
            request_duration,
            inflight_requests,
            request_bytes,
            response_bytes,
            error_counter,
            cancellation_total,
        }
    }

    /// Creates every work-handler collector through `endpoint.metrics()`.
    ///
    /// `metrics_labels` are forwarded unchanged to each `create_*` call; `None` is treated as
    /// an empty label list. The request-duration histogram uses fixed buckets ranging from
    /// 5 ms to 600 s.
    ///
    /// # Errors
    /// Returns the first error reported while creating a collector.
    pub fn from_endpoint(
        endpoint: &crate::component::Endpoint,
        metrics_labels: Option<&[(&str, &str)]>,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let labels = metrics_labels.unwrap_or(&[]);
        let registry = endpoint.metrics();

        // Histogram upper bounds, in seconds.
        let duration_buckets = vec![
            0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0, 60.0, 120.0,
            300.0, 600.0,
        ];

        // Struct-expression fields are evaluated in the order written, so collectors are
        // created in this exact sequence and the first failure short-circuits.
        Ok(Self {
            request_counter: registry.create_intcounter(
                work_handler::REQUESTS_TOTAL,
                "Total number of requests processed by work handler",
                labels,
            )?,
            request_duration: registry.create_histogram(
                work_handler::REQUEST_DURATION_SECONDS,
                "Time spent processing requests by work handler",
                labels,
                Some(duration_buckets),
            )?,
            inflight_requests: registry.create_intgauge(
                work_handler::INFLIGHT_REQUESTS,
                "Number of requests currently being processed by work handler",
                labels,
            )?,
            request_bytes: registry.create_intcounter(
                work_handler::REQUEST_BYTES_TOTAL,
                "Total number of bytes received in requests by work handler",
                labels,
            )?,
            response_bytes: registry.create_intcounter(
                work_handler::RESPONSE_BYTES_TOTAL,
                "Total number of bytes sent in responses by work handler",
                labels,
            )?,
            error_counter: registry.create_intcountervec(
                work_handler::ERRORS_TOTAL,
                "Total number of errors in work handler processing",
                &[work_handler::ERROR_TYPE_LABEL],
                labels,
            )?,
            cancellation_total: registry.create_intcounter(
                work_handler::CANCELLATION_TOTAL,
                "Total number of requests cancelled by work handler",
                labels,
            )?,
        })
    }
}
/// Per-request guard whose `Drop` impl decrements the in-flight gauge, observes the elapsed
/// time since `start_time` in the duration histogram, and logs completion when a request id
/// is present.
struct RequestMetricsGuard {
    /// In-flight gauge that is decremented when the guard is dropped.
    inflight_requests: prometheus::IntGauge,
    /// Histogram that receives the elapsed seconds when the guard is dropped.
    request_duration: prometheus::Histogram,
    /// Moment request handling began; the reference point for the duration measurement.
    start_time: Instant,
    /// Optional request id; when set, a "request completed" log line is emitted on drop.
    request_id: Option<String>,
}
impl Drop for RequestMetricsGuard {
    /// Closes out the request's metrics. The in-flight gauge goes back down, the seconds
    /// elapsed since `start_time` are observed by the duration histogram, and a completion
    /// line is logged when a request id is known.
    fn drop(&mut self) {
        self.inflight_requests.dec();
        let elapsed_secs = self.start_time.elapsed().as_secs_f64();
        self.request_duration.observe(elapsed_secs);
        if let Some(rid) = self.request_id.as_deref() {
            tracing::info!(request_id = %rid, "request completed");
        }
    }
}
#[async_trait]
impl<T: Data, U: Data> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
where
    T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
    U: Data + Serialize + MaybeError + std::fmt::Debug,
{
    /// Registers work-handler metrics for `endpoint` by delegating to `Ingress::add_metrics`.
    fn add_metrics(
        &self,
        endpoint: &crate::component::Endpoint,
        metrics_labels: Option<&[(&str, &str)]>,
    ) -> Result<()> {
        // In a path call, inherent associated functions take precedence over trait methods,
        // so this targets the inherent `Ingress::add_metrics` (assumed to exist) rather than
        // this trait method.
        use crate::pipeline::network::Ingress;
        Ingress::add_metrics(self, endpoint, metrics_labels)
    }

    /// Installs the notifier that `handle_payload` signals once all of a request's responses
    /// were published and the `complete_final` marker send has been attempted.
    ///
    /// # Errors
    /// Returns an error if a notifier has already been installed.
    fn set_endpoint_health_check_notifier(&self, notifier: Arc<tokio::sync::Notify>) -> Result<()> {
        self.endpoint_health_check_notifier
            .set(notifier)
            .map_err(|_| anyhow::anyhow!("Endpoint health check notifier already set"))?;
        Ok(())
    }

    /// Handles one raw payload received from the work queue.
    ///
    /// The payload is decoded as a two-part message: a JSON `RequestControlMessage` header and
    /// a JSON-encoded request of type `T`. A TCP response stream is opened using the header's
    /// connection info, and the request is passed to the segment's `generate`. Each produced
    /// response is serialized and published. If every publish succeeds, a `complete_final`
    /// marker follows and the health-check notifier (if any) is signalled.
    ///
    /// When metrics are registered, the following are recorded: request count, in-flight
    /// gauge, request/response byte totals, request duration, and per-stage error counts.
    ///
    /// # Errors
    /// Returns a `PipelineError` in these cases:
    /// - the payload cannot be decoded;
    /// - the header or request body cannot be deserialized;
    /// - the response stream cannot be created;
    /// - `generate` fails.
    ///
    /// Failures while publishing responses are logged and counted but are not returned.
    ///
    /// # Panics
    /// Panics if the segment has not been set, or if a response cannot be serialized to JSON.
    async fn handle_payload(
        &self,
        payload: Bytes,
        request_id: Option<String>,
    ) -> Result<(), PipelineError> {
        // Receive wall-clock time (t2). It is compared with the frontend's send time (t1)
        // below to estimate network transit; a clock set before the Unix epoch yields 0.
        let t2_wallclock_ns = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos() as u64;
        let start_time = std::time::Instant::now();

        // Increments the error counter for `error_type` when metrics are registered.
        let count_error = |error_type: &str| {
            if let Some(m) = self.metrics() {
                m.error_counter.with_label_values(&[error_type]).inc();
            }
        };

        // Held until the end of the function, so every exit path (including early error
        // returns) decrements the in-flight gauge and records the request duration.
        let _inflight_guard = self.metrics().map(|m| {
            m.request_counter.inc();
            m.inflight_requests.inc();
            m.request_bytes.inc_by(payload.len() as u64);
            if let Some(rid) = &request_id {
                tracing::info!(request_id = %rid, "request received");
            }
            RequestMetricsGuard {
                inflight_requests: m.inflight_requests.clone(),
                request_duration: m.request_duration.clone(),
                start_time,
                request_id: request_id.clone(),
            }
        });

        let msg = TwoPartCodec::default()
            .decode_message(payload)?
            .into_message_type();

        let (control_msg, request) = match msg {
            TwoPartMessageType::HeaderAndData(header, data) => {
                tracing::trace!(
                    "received two part message with ctrl: {} bytes, data: {} bytes",
                    header.len(),
                    data.len()
                );
                let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
                    Ok(cm) => cm,
                    Err(err) => {
                        let json_str = String::from_utf8_lossy(&header);
                        count_error(work_handler::error_types::DESERIALIZATION);
                        return Err(PipelineError::DeserializationError(format!(
                            "Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}"
                        )));
                    }
                };
                // Body failures are counted under the same label as header failures. The
                // error is still converted into `PipelineError` by `?`, exactly as before.
                let request: T = serde_json::from_slice(&data).map_err(|err| {
                    count_error(work_handler::error_types::DESERIALIZATION);
                    err
                })?;
                (control_msg, request)
            }
            _ => {
                count_error(work_handler::error_types::INVALID_MESSAGE);
                return Err(PipelineError::Generic(String::from(
                    "Unexpected message from work queue; unable extract a TwoPartMessage with a header and data",
                )));
            }
        };

        if let Some(t1_ns) = control_msg.frontend_send_ts_ns {
            // Saturate so that clock skew making t1 > t2 records 0 instead of wrapping.
            let transit_ns = t2_wallclock_ns.saturating_sub(t1_ns);
            WORK_HANDLER_NETWORK_TRANSIT_SECONDS.observe(transit_ns as f64 / 1_000_000_000.0);
        }

        tracing::trace!("received control message: {:?}", control_msg);
        tracing::trace!("received request: {:?}", request);
        let request: context::Context<T> = Context::with_id(request, control_msg.id);

        tracing::trace!("creating tcp response stream");
        let mut publisher = tcp::client::TcpClient::create_response_stream(
            request.context(),
            control_msg.connection_info,
            self.metrics().map(|m| m.cancellation_total.clone()),
        )
        .await
        .map_err(|e| {
            count_error(work_handler::error_types::RESPONSE_STREAM);
            PipelineError::Generic(format!("Failed to create response stream: {:?}", e,))
        })?;

        tracing::trace!("calling generate");
        let stream = self
            .segment
            .get()
            .expect("segment not set")
            .generate(request)
            .await
            .map_err(|e| {
                count_error(work_handler::error_types::GENERATE);
                PipelineError::GenerateError(e)
            });

        let mut stream = match stream {
            Ok(stream) => {
                tracing::trace!("Successfully generated response stream; sending prologue");
                // Best effort: a failed prologue send is ignored here.
                let _result = publisher.send_prologue(None).await;
                WORK_HANDLER_TIME_TO_FIRST_RESPONSE_SECONDS
                    .observe(start_time.elapsed().as_secs_f64());
                stream
            }
            Err(e) => {
                let error_string = e.to_string();
                // Debug builds log the Debug form at debug level; release builds log the
                // Display text at error level.
                #[cfg(debug_assertions)]
                {
                    tracing::debug!(
                        "Failed to generate response stream (with debug backtrace): {:?}",
                        e
                    );
                }
                #[cfg(not(debug_assertions))]
                {
                    tracing::error!("Failed to generate response stream: {error_string}");
                }
                // Forward the failure text in the prologue before returning the error.
                let _result = publisher.send_prologue(Some(error_string)).await;
                return Err(e);
            }
        };

        let context = stream.context();
        // Cleared when a response fails to publish. In that case no `complete_final` marker
        // is sent and the health-check notifier is not signalled.
        let mut send_complete_final = true;
        while let Some(resp) = stream.next().await {
            tracing::trace!("Sending response: {:?}", resp);
            let resp_wrapper = NetworkStreamWrapper {
                data: Some(resp),
                complete_final: false,
            };
            let resp_bytes = serde_json::to_vec(&resp_wrapper)
                .expect("fatal error: invalid response object - this should never happen");
            if let Some(m) = self.metrics() {
                m.response_bytes.inc_by(resp_bytes.len() as u64);
            }
            if (publisher.send(resp_bytes.into()).await).is_err() {
                send_complete_final = false;
                // If generation was already stopped, only warn. Otherwise log an error and
                // stop generation.
                if context.is_stopped() {
                    tracing::warn!("Failed to publish response for stream {}", context.id());
                } else {
                    tracing::error!("Failed to publish response for stream {}", context.id());
                    context.stop_generating();
                }
                count_error(work_handler::error_types::PUBLISH_RESPONSE);
                break;
            }
        }

        if send_complete_final {
            let resp_wrapper = NetworkStreamWrapper::<U> {
                data: None,
                complete_final: true,
            };
            let resp_bytes = serde_json::to_vec(&resp_wrapper)
                .expect("fatal error: invalid response object - this should never happen");
            if let Some(m) = self.metrics() {
                m.response_bytes.inc_by(resp_bytes.len() as u64);
            }
            if (publisher.send(resp_bytes.into()).await).is_err() {
                tracing::error!(
                    "Failed to publish complete final for stream {}",
                    context.id()
                );
                count_error(work_handler::error_types::PUBLISH_FINAL);
            }
            if let Some(notifier) = self.endpoint_health_check_notifier.get() {
                notifier.notify_one();
            }
        }

        // Finalize request metrics only after the last message has been handled.
        drop(_inflight_guard);
        Ok(())
    }
}