use std::sync::Arc;
use super::unified_client::RequestPlaneClient;
use super::*;
use crate::engine::{AsyncEngine, AsyncEngineContextProvider, Data};
use crate::error::{DynamoError, ErrorType};
use crate::logging::inject_trace_headers_into_map;
use crate::pipeline::network::ConnectionInfo;
use crate::pipeline::network::NetworkStreamWrapper;
use crate::pipeline::network::PendingConnections;
use crate::pipeline::network::StreamOptions;
use crate::pipeline::network::TwoPartCodec;
use crate::pipeline::network::codec::TwoPartMessage;
use crate::pipeline::network::tcp;
use crate::pipeline::{ManyOut, PipelineError, ResponseStream, SingleIn};
use crate::protocols::maybe_error::MaybeError;
use anyhow::{Error, Result};
use serde::Deserialize;
use serde::Serialize;
use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
use tracing::Instrument;
/// Shape of the inbound side of a request carried in the control header.
///
/// Serialized with `snake_case` variant names (`"single_in"`, `"many_in"`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
enum RequestType {
    /// A single request message (the only variant constructed in this file).
    SingleIn,
    /// Presumably a streamed sequence of request messages — not constructed here.
    ManyIn,
}
/// Shape of the outbound side of a request carried in the control header.
///
/// Serialized with `snake_case` variant names (`"single_out"`, `"many_out"`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
enum ResponseType {
    /// Presumably a single response message — not constructed here.
    SingleOut,
    /// A stream of response messages (the only variant constructed in this file).
    ManyOut,
}
/// Control header: the first part of the two-part request message.
///
/// It is JSON-encoded and paired with the JSON-encoded request body. Field
/// order is part of the serialized output and should be kept stable.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct RequestControlMessage {
    /// Request identifier; in this file it is set from the engine context id.
    id: String,
    /// Inbound shape of the request.
    request_type: RequestType,
    /// Outbound shape of the response.
    response_type: ResponseType,
    /// Connection details of the response stream registered by the caller.
    /// Presumably used by the receiver to connect back — confirm on the server side.
    connection_info: ConnectionInfo,
}
/// A request payload paired with the request-plane address it is destined for.
pub struct AddressedRequest<T> {
    /// The payload to deliver.
    request: T,
    /// Destination address on the request plane.
    address: String,
}

impl<T> AddressedRequest<T> {
    /// Pairs `request` with its destination `address`.
    pub fn new(request: T, address: String) -> Self {
        AddressedRequest { address, request }
    }

    /// Consumes `self`, returning the `(request, address)` pair.
    pub(crate) fn into_parts(self) -> (T, String) {
        let AddressedRequest { request, address } = self;
        (request, address)
    }
}
/// Router that pushes each request to an explicit address.
///
/// Requests go out through a request-plane client; responses come back
/// through a TCP stream server.
pub struct AddressedPushRouter {
    /// Transport-agnostic client used to deliver encoded requests.
    req_client: Arc<dyn RequestPlaneClient>,
    /// TCP stream server on which response streams are registered.
    resp_transport: Arc<tcp::server::TcpStreamServer>,
}

impl AddressedPushRouter {
    /// Builds a router from a request-plane client and a response transport.
    ///
    /// Returns the router wrapped in an `Arc`. As written it never returns `Err`.
    pub fn new(
        req_client: Arc<dyn RequestPlaneClient>,
        resp_transport: Arc<tcp::server::TcpStreamServer>,
    ) -> Result<Arc<Self>> {
        let router = AddressedPushRouter {
            resp_transport,
            req_client,
        };
        Ok(Arc::new(router))
    }
}
#[async_trait::async_trait]
impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de> + MaybeError,
{
async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
let request_id = request.context().id().to_string();
let (addressed_request, context) = request.transfer(());
let (request, address) = addressed_request.into_parts();
let engine_ctx = context.context();
let engine_ctx_ = engine_ctx.clone();
let options = StreamOptions::builder()
.context(engine_ctx.clone())
.enable_request_stream(false)
.enable_response_stream(true)
.build()
.unwrap();
let pending_connections: PendingConnections = self.resp_transport.register(options).await;
let pending_response_stream = match pending_connections.into_parts() {
(None, Some(recv_stream)) => recv_stream,
_ => {
panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
}
};
let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
let control_message = RequestControlMessage {
id: engine_ctx.id().to_string(),
request_type: RequestType::SingleIn,
response_type: ResponseType::ManyOut,
connection_info,
};
let ctrl = serde_json::to_vec(&control_message)?;
let data = serde_json::to_vec(&request)?;
tracing::trace!(
request_id,
"packaging two-part message; ctrl: {} bytes, data: {} bytes",
ctrl.len(),
data.len()
);
let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
let codec = TwoPartCodec::default();
let buffer = codec.encode_message(msg)?;
tracing::trace!(
request_id,
transport = self.req_client.transport_name(),
address = %address,
"Sending request via request plane client"
);
let mut headers = std::collections::HashMap::new();
inject_trace_headers_into_map(&mut headers);
let _response = self
.req_client
.send_request(address, buffer, headers)
.await?;
tracing::trace!(request_id, "awaiting transport handshake");
let response_stream = response_stream_provider
.await
.map_err(|_| PipelineError::DetachedStreamReceiver)?
.map_err(PipelineError::ConnectionFailed)?;
let mut is_complete_final = false;
let stream = tokio_stream::StreamNotifyClose::new(
tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
)
.filter_map(move |res| {
if let Some(res_bytes) = res {
if is_complete_final {
let err = DynamoError::msg(
"Response received after generation ended - this should never happen",
);
return Some(U::from_err(err));
}
match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
Ok(item) => {
is_complete_final = item.complete_final;
if let Some(data) = item.data {
Some(data)
} else if is_complete_final {
None
} else {
let err = DynamoError::msg(
"Empty response received - this should never happen",
);
Some(U::from_err(err))
}
}
Err(err) => {
let json_str = String::from_utf8_lossy(&res_bytes);
tracing::warn!(%err, %json_str, "Failed deserializing JSON to response");
Some(U::from_err(DynamoError::msg(err.to_string())))
}
}
} else if is_complete_final {
None
} else if engine_ctx_.is_stopped() {
tracing::debug!("Request cancelled and then trying to read a response");
None
} else {
let err = DynamoError::builder()
.error_type(ErrorType::Disconnected)
.message("Stream ended before generation completed")
.build();
tracing::debug!("{}", err);
Some(U::from_err(err))
}
});
Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
}
}