use super::*;
use crate::metrics::prometheus_names::work_handler;
use crate::metrics::work_handler_perf::{
WORK_HANDLER_NETWORK_TRANSIT_SECONDS, WORK_HANDLER_TIME_TO_FIRST_RESPONSE_SECONDS,
};
use crate::protocols::maybe_error::MaybeError;
use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use tracing::Instrument;
use tracing::info_span;
/// Prometheus collectors recorded by the work handler.
///
/// The per-field descriptions below mirror the help strings used when the
/// collectors are created in `WorkHandlerMetrics::from_endpoint`.
#[derive(Clone, Debug)]
pub struct WorkHandlerMetrics {
/// Total number of requests processed by the work handler.
pub request_counter: IntCounter,
/// Time spent processing requests, observed in seconds.
pub request_duration: Histogram,
/// Number of requests currently being processed.
pub inflight_requests: IntGauge,
/// Total number of bytes received in requests.
pub request_bytes: IntCounter,
/// Total number of bytes sent in responses.
pub response_bytes: IntCounter,
/// Total number of processing errors, partitioned by an error-type label.
pub error_counter: IntCounterVec,
/// Total number of cancelled requests.
pub cancellation_total: IntCounter,
}
impl WorkHandlerMetrics {
pub fn new(
request_counter: IntCounter,
request_duration: Histogram,
inflight_requests: IntGauge,
request_bytes: IntCounter,
response_bytes: IntCounter,
error_counter: IntCounterVec,
cancellation_total: IntCounter,
) -> Self {
Self {
request_counter,
request_duration,
inflight_requests,
request_bytes,
response_bytes,
error_counter,
cancellation_total,
}
}
pub fn from_endpoint(
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let metrics_labels = metrics_labels.unwrap_or(&[]);
let metrics = endpoint.metrics();
let request_counter = metrics.create_intcounter(
work_handler::REQUESTS_TOTAL,
"Total number of requests processed by work handler",
metrics_labels,
)?;
let request_duration_buckets = vec![
0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 30.0, 60.0, 120.0,
300.0, 600.0,
];
let request_duration = metrics.create_histogram(
work_handler::REQUEST_DURATION_SECONDS,
"Time spent processing requests by work handler",
metrics_labels,
Some(request_duration_buckets),
)?;
let inflight_requests = metrics.create_intgauge(
work_handler::INFLIGHT_REQUESTS,
"Number of requests currently being processed by work handler",
metrics_labels,
)?;
let request_bytes = metrics.create_intcounter(
work_handler::REQUEST_BYTES_TOTAL,
"Total number of bytes received in requests by work handler",
metrics_labels,
)?;
let response_bytes = metrics.create_intcounter(
work_handler::RESPONSE_BYTES_TOTAL,
"Total number of bytes sent in responses by work handler",
metrics_labels,
)?;
let error_counter = metrics.create_intcountervec(
work_handler::ERRORS_TOTAL,
"Total number of errors in work handler processing",
&[work_handler::ERROR_TYPE_LABEL],
metrics_labels,
)?;
let cancellation_total = metrics.create_intcounter(
work_handler::CANCELLATION_TOTAL,
"Total number of requests cancelled by work handler",
metrics_labels,
)?;
Ok(Self::new(
request_counter,
request_duration,
inflight_requests,
request_bytes,
response_bytes,
error_counter,
cancellation_total,
))
}
}
struct RequestMetricsGuard {
inflight_requests: prometheus::IntGauge,
request_duration: prometheus::Histogram,
start_time: Instant,
request_id: Option<String>,
}
impl Drop for RequestMetricsGuard {
    /// Decrements the in-flight gauge, records the elapsed request time, and
    /// logs completion when a request id is available.
    fn drop(&mut self) {
        self.inflight_requests.dec();

        let elapsed_secs = self.start_time.elapsed().as_secs_f64();
        self.request_duration.observe(elapsed_secs);

        if let Some(rid) = self.request_id.as_deref() {
            tracing::info!(request_id = %rid, "request completed");
        }
    }
}
impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
    /// Forwards every item of `stream` to `publisher`, then sends a terminating
    /// frame marked `complete_final`.
    ///
    /// Each item is wrapped in a `NetworkStreamWrapper`, serialized to JSON and
    /// published. If publishing an item fails, forwarding stops, the failure is
    /// counted and no terminating frame is sent. The endpoint health-check
    /// notifier (when set) is signalled after each successfully published
    /// non-error item, and once more after the terminating frame if the stream
    /// produced no error items.
    ///
    /// # Panics
    ///
    /// Panics if a response wrapper cannot be serialized to JSON.
    async fn pump_response_stream<U>(&self, mut stream: ManyOut<U>, publisher: &StreamSender)
    where
        U: Data + Serialize + MaybeError + std::fmt::Debug,
    {
        let context = stream.context();
        let mut publish_failed = false;
        let mut saw_error_response = false;

        while let Some(resp) = stream.next().await {
            tracing::trace!("Sending response: {:?}", resp);
            let is_error = resp.err().is_some();
            saw_error_response |= is_error;

            let frame = NetworkStreamWrapper {
                data: Some(resp),
                complete_final: false,
            };
            let frame_bytes = serde_json::to_vec(&frame)
                .expect("fatal error: invalid response object - this should never happen");
            // Bytes are counted before the send is attempted.
            if let Some(m) = self.metrics() {
                m.response_bytes.inc_by(frame_bytes.len() as u64);
            }

            let send_failed = publisher.send(frame_bytes.into()).await.is_err();
            if !send_failed {
                // Only healthy (non-error) responses signal the health check.
                if !is_error {
                    if let Some(notifier) = self.endpoint_health_check_notifier.get() {
                        notifier.notify_one();
                    }
                }
                continue;
            }

            publish_failed = true;
            if context.is_stopped() {
                // The context was already stopped, so a failed publish is only
                // a warning.
                tracing::warn!("Failed to publish response for stream {}", context.id());
            } else {
                tracing::error!("Failed to publish response for stream {}", context.id());
                context.stop_generating();
            }
            if let Some(m) = self.metrics() {
                m.error_counter
                    .with_label_values(&[work_handler::error_types::PUBLISH_RESPONSE])
                    .inc();
            }
            break;
        }

        // After a failed publish no terminating frame is attempted.
        if publish_failed {
            return;
        }

        let final_frame = NetworkStreamWrapper::<U> {
            data: None,
            complete_final: true,
        };
        let final_bytes = serde_json::to_vec(&final_frame)
            .expect("fatal error: invalid response object - this should never happen");
        if let Some(m) = self.metrics() {
            m.response_bytes.inc_by(final_bytes.len() as u64);
        }

        let final_send_failed = publisher.send(final_bytes.into()).await.is_err();
        if final_send_failed {
            tracing::error!(
                "Failed to publish complete final for stream {}",
                context.id()
            );
            if let Some(m) = self.metrics() {
                m.error_counter
                    .with_label_values(&[work_handler::error_types::PUBLISH_FINAL])
                    .inc();
            }
        }

        if !saw_error_response {
            if let Some(notifier) = self.endpoint_health_check_notifier.get() {
                notifier.notify_one();
            }
        }
    }
}
/// Outcome of decoding an incoming payload: the request itself plus the
/// routing and timing data taken from its control message.
struct ParsedRequest<Req> {
/// The decoded request handed to the processing segment.
request: Req,
/// Connection info from the control message, used to open the response stream.
response_connection_info: ConnectionInfo,
/// Optional send timestamp from the control message. The handler subtracts it
/// from its own wall-clock nanoseconds since the UNIX epoch to estimate
/// network transit time.
frontend_send_ts_ns: Option<u64>,
}
/// Decoding step of the ingress: turns a raw payload into a typed request.
#[async_trait]
trait IngressDispatch: Send + Sync {
/// Request type produced by decoding a payload.
type Request: PipelineIO;
/// Decodes `payload` into a request together with the connection info and
/// optional send timestamp carried alongside it.
///
/// Returns a `PipelineError` when the payload cannot be decoded.
async fn parse_and_build_request(
&self,
payload: Bytes,
) -> Result<ParsedRequest<Self::Request>, PipelineError>;
}
#[async_trait]
impl<T, U> IngressDispatch for Ingress<SingleIn<T>, ManyOut<U>>
where
    T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
    U: Data + Serialize + MaybeError + std::fmt::Debug,
{
    type Request = SingleIn<T>;

    /// Decodes a two-part (header + data) payload into a request context.
    ///
    /// The header is parsed as a JSON `RequestControlMessage` and the data part
    /// as a JSON `T`. The request context is built from the decoded data and
    /// the control message's id and metadata.
    ///
    /// # Errors
    ///
    /// * Propagates any error from `TwoPartCodec::decode_message`.
    /// * `PipelineError::DeserializationError` when the header is not a valid
    ///   `RequestControlMessage`.
    /// * The converted `serde_json` error when the data part is not a valid `T`.
    /// * `PipelineError::Generic` when the message is not header-and-data.
    ///
    /// Both deserialization failures increment the `DESERIALIZATION` error
    /// counter; an unexpected message shape increments `INVALID_MESSAGE`.
    async fn parse_and_build_request(
        &self,
        payload: Bytes,
    ) -> Result<ParsedRequest<SingleIn<T>>, PipelineError> {
        let msg = TwoPartCodec::default()
            .decode_message(payload)?
            .into_message_type();

        let (control_msg, request_t) = match msg {
            TwoPartMessageType::HeaderAndData(header, data) => {
                tracing::trace!(
                    "received two part message with ctrl: {} bytes, data: {} bytes",
                    header.len(),
                    data.len()
                );
                let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
                    Ok(cm) => cm,
                    Err(err) => {
                        let json_str = String::from_utf8_lossy(&header);
                        if let Some(m) = self.metrics() {
                            m.error_counter
                                .with_label_values(&[work_handler::error_types::DESERIALIZATION])
                                .inc();
                        }
                        return Err(PipelineError::DeserializationError(format!(
                            "Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}, header_len={}",
                            header.len(),
                        )));
                    }
                };
                // Count body deserialization failures too, so the
                // DESERIALIZATION counter covers both parts of the message.
                // The error itself is passed through unchanged and converted
                // by `?` as before.
                let request_t: T = serde_json::from_slice(&data).map_err(|err| {
                    if let Some(m) = self.metrics() {
                        m.error_counter
                            .with_label_values(&[work_handler::error_types::DESERIALIZATION])
                            .inc();
                    }
                    err
                })?;
                (control_msg, request_t)
            }
            _ => {
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::INVALID_MESSAGE])
                        .inc();
                }
                return Err(PipelineError::Generic(String::from(
                    "Unexpected message from work queue; unable extract a TwoPartMessage with a header and data",
                )));
            }
        };

        tracing::trace!(
            request_id = %control_msg.id,
            metadata_entries = control_msg.metadata.len(),
            "received control message"
        );
        tracing::trace!("received request: {:?}", request_t);

        let request: context::Context<T> =
            Context::with_id_and_metadata(request_t, control_msg.id, control_msg.metadata);
        Ok(ParsedRequest {
            request,
            response_connection_info: control_msg.connection_info,
            frontend_send_ts_ns: control_msg.frontend_send_ts_ns,
        })
    }
}
impl<Req: PipelineIO + Sync, U> Ingress<Req, ManyOut<U>>
where
    U: Data + Serialize + MaybeError + std::fmt::Debug,
{
    /// Handles one incoming payload end to end: decode, open the response
    /// stream, run the segment, and pump its output back to the caller.
    ///
    /// When metrics are configured, the request counters are bumped up front
    /// and a guard records the duration and in-flight count on every exit path.
    /// The "request received" log line is emitted only in that case and only
    /// when `request_id` is `Some`.
    ///
    /// # Errors
    ///
    /// * Any error from `parse_and_build_request`.
    /// * `PipelineError::Generic` when the response stream cannot be created.
    /// * `PipelineError::GenerateError` when the segment fails to generate; the
    ///   error text is first sent to the caller in the prologue.
    ///
    /// # Panics
    ///
    /// Panics if the segment has not been set.
    async fn handle_payload_shared(
        &self,
        payload: Bytes,
        request_id: Option<String>,
    ) -> Result<(), PipelineError>
    where
        Self: IngressDispatch<Request = Req>,
    {
        // Wall-clock arrival time, later compared with the sender's timestamp.
        let arrival_wallclock_ns = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos() as u64;
        let started_at = std::time::Instant::now();

        let inflight_guard = self.metrics().map(|m| {
            m.request_counter.inc();
            m.inflight_requests.inc();
            m.request_bytes.inc_by(payload.len() as u64);
            if let Some(rid) = &request_id {
                tracing::info!(request_id = %rid, "request received");
            }
            RequestMetricsGuard {
                inflight_requests: m.inflight_requests.clone(),
                request_duration: m.request_duration.clone(),
                start_time: started_at,
                request_id: request_id.clone(),
            }
        });

        let ParsedRequest {
            request,
            response_connection_info,
            frontend_send_ts_ns,
        } = self.parse_and_build_request(payload).await?;

        if let Some(sent_ns) = frontend_send_ts_ns {
            // saturating_sub clamps to zero if the sender's timestamp is ahead
            // of the local clock.
            let transit_ns = arrival_wallclock_ns.saturating_sub(sent_ns);
            WORK_HANDLER_NETWORK_TRANSIT_SECONDS.observe(transit_ns as f64 / 1_000_000_000.0);
        }

        tracing::trace!("creating tcp response stream");
        let stream_result = tcp::client::TcpClient::create_response_stream(
            request.context(),
            response_connection_info,
            self.metrics().map(|m| m.cancellation_total.clone()),
        )
        .await;
        let mut publisher = match stream_result {
            Ok(publisher) => publisher,
            Err(e) => {
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::RESPONSE_STREAM])
                        .inc();
                }
                return Err(PipelineError::Generic(format!(
                    "Failed to create response stream: {:?}",
                    e,
                )));
            }
        };

        tracing::trace!("calling generate");
        let generated = self
            .segment
            .get()
            .expect("segment not set")
            .generate(request)
            .await;

        let response_stream = match generated {
            Ok(response_stream) => {
                tracing::trace!("Successfully generated response stream; sending prologue");
                // The prologue result is deliberately ignored; a broken
                // connection surfaces on the first publish instead.
                let _result = publisher.send_prologue(None).await;
                WORK_HANDLER_TIME_TO_FIRST_RESPONSE_SECONDS
                    .observe(started_at.elapsed().as_secs_f64());
                response_stream
            }
            Err(source) => {
                if let Some(m) = self.metrics() {
                    m.error_counter
                        .with_label_values(&[work_handler::error_types::GENERATE])
                        .inc();
                }
                let failure = PipelineError::GenerateError(source);
                let error_string = failure.to_string();
                #[cfg(debug_assertions)]
                {
                    tracing::debug!(
                        "Failed to generate response stream (with debug backtrace): {:?}",
                        failure
                    );
                }
                #[cfg(not(debug_assertions))]
                {
                    tracing::error!("Failed to generate response stream: {error_string}");
                }
                // Tell the caller why no stream is coming before bailing out.
                let _result = publisher.send_prologue(Some(error_string)).await;
                return Err(failure);
            }
        };

        self.pump_response_stream(response_stream, &publisher).await;
        // Dropping the guard records the duration and decrements in-flight.
        drop(inflight_guard);
        Ok(())
    }
}
#[async_trait]
impl<T, U> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
where
    T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
    U: Data + Serialize + MaybeError + std::fmt::Debug,
{
    /// Forwards to `Ingress::add_metrics` with the same arguments.
    ///
    /// NOTE(review): the path call presumably resolves to the inherent
    /// `Ingress::add_metrics` rather than this trait method — confirm.
    fn add_metrics(
        &self,
        endpoint: &crate::component::Endpoint,
        metrics_labels: Option<&[(&str, &str)]>,
    ) -> Result<()> {
        use crate::pipeline::network::Ingress;
        Ingress::add_metrics(self, endpoint, metrics_labels)
    }

    /// Stores the notifier that is signalled when responses are published.
    ///
    /// # Errors
    ///
    /// Returns an error if a notifier has already been set.
    fn set_endpoint_health_check_notifier(&self, notifier: Arc<tokio::sync::Notify>) -> Result<()> {
        self.endpoint_health_check_notifier
            .set(notifier)
            .map_err(|_| anyhow::anyhow!("Endpoint health check notifier already set"))?;
        Ok(())
    }

    /// Handles one incoming payload by delegating to `handle_payload_shared`.
    async fn handle_payload(
        &self,
        payload: Bytes,
        request_id: Option<String>,
    ) -> Result<(), PipelineError> {
        self.handle_payload_shared(payload, request_id).await
    }
}