use std::sync::Arc;
use std::time::Instant;
use super::unified_client::RequestPlaneClient;
use super::*;
use crate::dynamo_nvtx_range;
use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
use crate::error::{DynamoError, ErrorType};
use crate::logging::inject_trace_headers_into_map;
use crate::metrics::frontend_perf::STAGE_DURATION_SECONDS;
use crate::metrics::request_plane::{
REQUEST_PLANE_INFLIGHT, REQUEST_PLANE_QUEUE_SECONDS, REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS,
REQUEST_PLANE_SEND_SECONDS,
};
use crate::pipeline::network::ConnectionInfo;
use crate::pipeline::network::NetworkStreamWrapper;
use crate::pipeline::network::PendingConnections;
use crate::pipeline::network::StreamOptions;
use crate::pipeline::network::TwoPartCodec;
use crate::pipeline::network::codec::TwoPartMessage;
use crate::pipeline::network::tcp;
use crate::pipeline::{ManyOut, PipelineError, ResponseStream, SingleIn};
use crate::protocols::maybe_error::MaybeError;
use anyhow::{Error, Result};
use futures::stream::Stream;
use serde::Deserialize;
use serde::Serialize;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
use tracing::Instrument;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum RequestType {
SingleIn,
ManyIn,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ResponseType {
SingleOut,
ManyOut,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RequestControlMessage {
id: String,
request_type: RequestType,
response_type: ResponseType,
connection_info: ConnectionInfo,
#[serde(default, skip_serializing_if = "Option::is_none")]
frontend_send_ts_ns: Option<u64>,
}
struct InflightGuard {
armed: bool,
}
impl InflightGuard {
fn new() -> Self {
Self { armed: true }
}
fn disarm(mut self) {
self.armed = false;
}
}
impl Drop for InflightGuard {
fn drop(&mut self) {
if self.armed {
REQUEST_PLANE_INFLIGHT.dec();
}
}
}
struct InflightDecStream<S> {
inner: S,
}
impl<S, T> Stream for InflightDecStream<S>
where
S: Stream<Item = T> + Unpin,
{
type Item = T;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_next(cx)
}
}
impl<S> Drop for InflightDecStream<S> {
fn drop(&mut self) {
REQUEST_PLANE_INFLIGHT.dec();
}
}
pub struct AddressedRequest<T> {
request: T,
address: String,
}
impl<T> AddressedRequest<T> {
pub fn new(request: T, address: String) -> Self {
Self { request, address }
}
pub(crate) fn into_parts(self) -> (T, String) {
(self.request, self.address)
}
}
pub struct AddressedPushRouter {
req_client: Arc<dyn RequestPlaneClient>,
resp_transport: Arc<tcp::server::TcpStreamServer>,
}
impl AddressedPushRouter {
pub fn new(
req_client: Arc<dyn RequestPlaneClient>,
resp_transport: Arc<tcp::server::TcpStreamServer>,
) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
req_client,
resp_transport,
}))
}
}
#[async_trait::async_trait]
impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de> + MaybeError,
{
async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
let queue_start = Instant::now();
REQUEST_PLANE_INFLIGHT.inc();
let inflight_guard = InflightGuard::new();
let request_id = request.context().id().to_string();
let (addressed_request, context) = request.transfer(());
let (request, address) = addressed_request.into_parts();
let engine_ctx = context.context();
let engine_ctx_ = engine_ctx.clone();
let options = StreamOptions::builder()
.context(engine_ctx.clone())
.enable_request_stream(false)
.enable_response_stream(true)
.build()
.unwrap();
let pending_connections: PendingConnections = self.resp_transport.register(options).await;
let pending_response_stream = match pending_connections.into_parts() {
(None, Some(recv_stream)) => recv_stream,
_ => {
panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
}
};
let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
let control_message = RequestControlMessage {
id: engine_ctx.id().to_string(),
request_type: RequestType::SingleIn,
response_type: ResponseType::ManyOut,
connection_info,
frontend_send_ts_ns: None,
};
let ctrl = serde_json::to_vec(&control_message)?;
let data = serde_json::to_vec(&request)?;
tracing::trace!(
request_id,
"packaging two-part message; ctrl: {} bytes, data: {} bytes",
ctrl.len(),
data.len()
);
let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
let codec = TwoPartCodec::default();
let buffer = {
let _nvtx = dynamo_nvtx_range!("codec.encode");
codec.encode_message(msg)?
};
REQUEST_PLANE_QUEUE_SECONDS.observe(queue_start.elapsed().as_secs_f64());
let tx_start = Instant::now();
tracing::trace!(
request_id,
transport = self.req_client.transport_name(),
address = %address,
"Sending request via request plane client"
);
let mut headers = std::collections::HashMap::new();
inject_trace_headers_into_map(&mut headers);
headers.insert("request-id".to_string(), request_id.clone());
let send_ts_ns = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos() as u64;
headers.insert("x-frontend-send-ts-ns".to_string(), send_ts_ns.to_string());
let _nvtx_send = dynamo_nvtx_range!("transport.tcp.send");
let _response = self
.req_client
.send_request(address, buffer, headers)
.await?;
drop(_nvtx_send);
REQUEST_PLANE_SEND_SECONDS.observe(tx_start.elapsed().as_secs_f64());
let _nvtx_wait = dynamo_nvtx_range!("transport.tcp.wait_backend");
tracing::trace!(request_id, "awaiting transport handshake");
let response_stream = response_stream_provider
.await
.map_err(|_| PipelineError::DetachedStreamReceiver)?
.map_err(PipelineError::ConnectionFailed)?;
drop(_nvtx_wait);
let mut is_complete_final = false;
let mut first_response = true;
let stream = tokio_stream::StreamNotifyClose::new(
tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
)
.filter_map(move |res| {
if let Some(res_bytes) = res {
if first_response {
first_response = false;
let roundtrip_ttft = tx_start.elapsed().as_secs_f64();
REQUEST_PLANE_ROUNDTRIP_TTFT_SECONDS.observe(roundtrip_ttft);
STAGE_DURATION_SECONDS
.with_label_values(&["transport_roundtrip"])
.observe(queue_start.elapsed().as_secs_f64());
}
if is_complete_final {
let err = DynamoError::msg(
"Response received after generation ended - this should never happen",
);
return Some(U::from_err(err));
}
match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
Ok(item) => {
is_complete_final = item.complete_final;
if let Some(data) = item.data {
Some(data)
} else if is_complete_final {
None
} else {
let err = DynamoError::msg(
"Empty response received - this should never happen",
);
Some(U::from_err(err))
}
}
Err(err) => {
let json_str = String::from_utf8_lossy(&res_bytes);
tracing::warn!(%err, %json_str, "Failed deserializing JSON to response");
Some(U::from_err(DynamoError::msg(err.to_string())))
}
}
} else if is_complete_final {
None
} else if engine_ctx_.is_stopped() {
tracing::debug!("Request cancelled and then trying to read a response");
None
} else {
let err = DynamoError::builder()
.error_type(ErrorType::Disconnected)
.message("Stream ended before generation completed")
.build();
tracing::debug!("{err}");
Some(U::from_err(err))
}
});
inflight_guard.disarm();
let stream = InflightDecStream { inner: stream };
Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
}
}