use std::{collections::HashMap, sync::Arc};
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use log::{debug, warn};
use tokio::{
io::{self, AsyncRead, AsyncWrite, AsyncWriteExt},
time::{Duration, timeout},
};
use tokio_util::codec::{Encoder, Framed, LengthDelimitedCodec};
use super::{
builder::WireframeApp,
codec_driver::FramePipeline,
combined_codec::{CombinedCodec, ConnectionCodec},
envelope::{Envelope, Packet},
error::SendError,
frame_handling,
};
use crate::{
codec::{FrameCodec, LengthDelimitedFrameCodec, MAX_FRAME_LENGTH, clamp_frame_length},
frame::FrameMetadata,
message::{DecodeWith, DeserializeContext, EncodeWith},
message_assembler::MessageAssemblyState,
middleware::HandlerService,
serializer::Serializer,
};
fn purge_expired(
pipeline: &mut FramePipeline,
message_assembly: &mut Option<MessageAssemblyState>,
) {
pipeline.purge_expired();
frame_handling::purge_expired_assemblies(message_assembly);
}
const MAX_DESER_FAILURES: u32 = 10;
struct FrameHandlingContext<'a, E, W, F>
where
E: Packet,
W: AsyncRead + AsyncWrite + Unpin,
F: FrameCodec,
{
framed: &'a mut Framed<W, ConnectionCodec<F>>,
deser_failures: &'a mut u32,
routes: &'a HashMap<u32, HandlerService<E>>,
pipeline: &'a mut FramePipeline,
message_assembly: &'a mut Option<MessageAssemblyState>,
}
impl<S, C, E, F> WireframeApp<S, C, E, F>
where
S: Serializer + Send + Sync,
C: Send + 'static,
E: Packet,
F: FrameCodec,
{
pub async fn send_response<W, M>(
&self,
stream: &mut W,
msg: &M,
) -> std::result::Result<(), SendError>
where
W: AsyncWrite + Unpin,
M: EncodeWith<S>,
{
let bytes = self
.serializer
.serialize(msg)
.map_err(SendError::Serialize)?;
let payload_len = bytes.len();
let frame = self.codec.wrap_payload(Bytes::from(bytes));
let mut encoder = self.codec.encoder();
let mut encoded_buf = BytesMut::with_capacity(payload_len);
encoder
.encode(frame, &mut encoded_buf)
.map_err(SendError::Io)?;
stream
.write_all(&encoded_buf)
.await
.map_err(SendError::Io)?;
stream.flush().await.map_err(SendError::Io)
}
pub async fn send_response_framed_with_codec<W, M, Cc>(
&self,
framed: &mut Framed<W, Cc>,
msg: &M,
) -> std::result::Result<(), SendError>
where
W: AsyncRead + AsyncWrite + Unpin,
M: EncodeWith<S>,
Cc: Encoder<F::Frame, Error = io::Error>,
{
let bytes = self
.serializer
.serialize(msg)
.map_err(SendError::Serialize)?;
let frame = self.codec.wrap_payload(Bytes::from(bytes));
framed.send(frame).await.map_err(SendError::Io)
}
}
impl<S, C, E> WireframeApp<S, C, E, LengthDelimitedFrameCodec>
where
S: Serializer + Send + Sync,
C: Send + 'static,
E: Packet,
{
#[must_use]
pub fn length_codec(&self) -> LengthDelimitedCodec {
LengthDelimitedCodec::builder()
.max_frame_length(self.codec.max_frame_length())
.new_codec()
}
pub async fn send_response_framed<W, M>(
&self,
framed: &mut Framed<W, LengthDelimitedCodec>,
msg: &M,
) -> std::result::Result<(), SendError>
where
W: AsyncRead + AsyncWrite + Unpin,
M: EncodeWith<S>,
{
let bytes = self
.serializer
.serialize(msg)
.map_err(SendError::Serialize)?;
framed.send(bytes.into()).await.map_err(SendError::Io)
}
}
impl<S, C, E, F> WireframeApp<S, C, E, F>
where
S: Serializer + FrameMetadata<Frame = Envelope> + Send + Sync,
C: Send + 'static,
E: Packet,
F: FrameCodec,
Envelope: DecodeWith<S> + EncodeWith<S>,
{
fn parse_envelope(
&self,
payload: &[u8],
) -> std::result::Result<(Envelope, usize), Box<dyn std::error::Error + Send + Sync>> {
match self.serializer.parse(payload) {
Ok((parsed_envelope, metadata_bytes_consumed)) => {
if !self.serializer.should_deserialize_after_parse() {
return Ok((parsed_envelope, metadata_bytes_consumed));
}
let context = DeserializeContext {
frame_metadata: payload.get(..metadata_bytes_consumed),
message_id: Some(parsed_envelope.id),
correlation_id: parsed_envelope.correlation_id,
metadata_bytes_consumed: Some(metadata_bytes_consumed),
};
self.serializer
.deserialize_with_context::<Envelope>(payload, &context)
}
Err(_) => self.serializer.deserialize::<Envelope>(payload),
}
}
pub async fn handle_connection_result<W>(&self, stream: W) -> io::Result<()>
where
W: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let state = if let Some(setup) = &self.on_connect {
Some((setup)().await)
} else {
None
};
let routes = self
.routes
.get_or_init(|| async { Arc::new(self.build_chains().await) })
.await
.clone();
if let Err(e) = self.process_stream(stream, &routes).await {
warn!(
"connection terminated with error: correlation_id={:?}, error={e:?}",
None::<u64>
);
return Err(e);
}
if let (Some(teardown), Some(state)) = (&self.on_disconnect, state) {
teardown(state).await;
}
Ok(())
}
pub async fn handle_connection<W>(&self, stream: W)
where
W: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
if let Err(e) = self.handle_connection_result(stream).await {
warn!(
"connection handling completed with error: correlation_id={:?}, error={e:?}",
None::<u64>
);
}
}
async fn build_chains(&self) -> HashMap<u32, HandlerService<E>> {
let mut routes = HashMap::new();
for (&id, handler) in &self.handlers {
let mut service = HandlerService::new(id, handler.clone());
for mw in self.middleware.iter().rev() {
service = mw.transform(service).await;
}
routes.insert(id, service);
}
routes
}
async fn process_stream<W>(
&self,
stream: W,
routes: &Arc<HashMap<u32, HandlerService<E>>>,
) -> io::Result<()>
where
W: AsyncRead + AsyncWrite + Unpin,
{
let codec = self.codec.clone();
let combined = CombinedCodec::new(codec.decoder(), codec.encoder());
let mut framed = Framed::new(stream, combined);
let requested_frame_length = codec.max_frame_length();
let max_frame_length = clamp_frame_length(requested_frame_length);
if requested_frame_length > MAX_FRAME_LENGTH {
warn!(
"codec max frame length exceeds guardrail; clamping to {MAX_FRAME_LENGTH} bytes \
(requested={requested_frame_length})"
);
}
framed.read_buffer_mut().reserve(max_frame_length);
let effective_budgets =
frame_handling::resolve_effective_budgets(self.memory_budgets, requested_frame_length);
let mut deser_failures = 0u32;
let mut message_assembly = self.message_assembler.as_ref().map(|_| {
frame_handling::new_message_assembly_state(
self.fragmentation,
requested_frame_length,
Some(effective_budgets),
)
});
let mut pipeline = FramePipeline::new(self.fragmentation);
let timeout_dur = Duration::from_millis(self.read_timeout_ms);
loop {
let pressure = frame_handling::evaluate_memory_pressure(
message_assembly.as_ref(),
Some(effective_budgets),
);
frame_handling::apply_memory_pressure(pressure, || {
purge_expired(&mut pipeline, &mut message_assembly);
})
.await?;
match timeout(timeout_dur, framed.next()).await {
Ok(Some(Ok(frame))) => {
self.handle_frame(
&frame,
FrameHandlingContext {
framed: &mut framed,
deser_failures: &mut deser_failures,
routes,
message_assembly: &mut message_assembly,
pipeline: &mut pipeline,
},
&codec,
)
.await?;
}
Ok(Some(Err(e))) => return Err(e),
Ok(None) => break,
Err(_) => {
debug!("read timeout elapsed; continuing to wait for next frame");
purge_expired(&mut pipeline, &mut message_assembly);
}
}
}
Ok(())
}
async fn handle_frame<W>(
&self,
frame: &F::Frame,
ctx: FrameHandlingContext<'_, E, W, F>,
codec: &F,
) -> io::Result<()>
where
W: AsyncRead + AsyncWrite + Unpin,
{
let FrameHandlingContext {
framed,
deser_failures,
routes,
message_assembly,
pipeline,
} = ctx;
crate::metrics::inc_frames(crate::metrics::Direction::Inbound);
let Some(env) = frame_handling::decode_envelope::<F>(
self.parse_envelope(F::frame_payload(frame)),
frame,
deser_failures,
MAX_DESER_FAILURES,
)?
else {
return Ok(());
};
let Some(env) = frame_handling::reassemble_if_needed(
pipeline,
deser_failures,
env,
MAX_DESER_FAILURES,
)?
else {
return Ok(());
};
let Some(env) = frame_handling::assemble_if_needed(
frame_handling::AssemblyRuntime::new(self.message_assembler.as_ref(), message_assembly),
deser_failures,
env,
MAX_DESER_FAILURES,
)?
else {
return Ok(());
};
*deser_failures = 0;
if let Some(service) = routes.get(&env.id) {
frame_handling::forward_response(
env,
service,
frame_handling::ResponseContext::<S, W, F> {
serializer: &self.serializer,
framed,
pipeline,
codec,
},
)
.await?;
} else {
warn!(
"no handler for message id: id={}, correlation_id={:?}",
env.id, env.correlation_id
);
}
Ok(())
}
}
#[cfg(test)]
mod tests;