use std::{io, num::NonZeroUsize};
use bytes::{Bytes, BytesMut};
use tokio::io::{
AsyncRead,
AsyncReadExt,
AsyncWrite,
AsyncWriteExt,
DuplexStream,
ReadHalf,
WriteHalf,
duplex,
split,
};
use tokio_util::codec::{Decoder, Encoder};
use crate::{
app::{Envelope, Packet, WireframeApp},
codec::FrameCodec,
frame::FrameMetadata,
serializer::{MessageCompatibilitySerializer, Serializer},
};
/// Default buffer capacity (4 KiB) for the in-memory duplex helpers in this module.
///
/// NOTE(review): no use site is visible in this file; presumably callers pass it
/// as the `capacity` argument of the `drive_*` helpers — confirm.
pub(crate) const DEFAULT_CAPACITY: usize = 4 * 1024;
/// Umbrella bound for serializers used by the test helpers.
///
/// Bundles [`Serializer`], [`MessageCompatibilitySerializer`] and
/// [`FrameMetadata`] with `Frame = Envelope`, plus `Send + Sync + 'static`.
/// A blanket implementation below covers every type that satisfies those
/// bounds, so this trait never needs to be implemented by hand.
pub trait TestSerializer:
    'static
    + Send
    + Sync
    + Serializer
    + MessageCompatibilitySerializer
    + FrameMetadata<Frame = Envelope>
{
}

impl<T> TestSerializer for T
where
    T: 'static
        + Send
        + Sync
        + Serializer
        + MessageCompatibilitySerializer
        + FrameMetadata<Frame = Envelope>,
{
}
async fn panic_guarded_server<F, Fut>(server_fn: F, server: DuplexStream) -> io::Result<()>
where
F: FnOnce(DuplexStream) -> Fut,
Fut: std::future::Future<Output = ()>,
{
use futures::FutureExt as _;
let result = std::panic::AssertUnwindSafe(server_fn(server))
.catch_unwind()
.await;
match result {
Ok(()) => Ok(()),
Err(panic) => {
let panic_msg = crate::panic::format_panic(&panic);
Err(io::Error::other(format!("server task failed: {panic_msg}")))
}
}
}
/// Connects `server_fn` to a client through an in-memory duplex pipe and drives
/// the client side with the given strategies.
///
/// `capacity` is passed to [`duplex`]. The client end is split into halves:
/// - `write_strategy` receives the write half and must hand it back. It is then
///   shut down, signalling EOF to the server.
/// - `read_strategy` consumes the read half and yields the bytes it collected.
///
/// The server, writer and reader futures are polled concurrently via
/// `tokio::try_join!`.
///
/// # Errors
///
/// Returns the first error produced by any of the three futures:
/// - the server, including a panic converted by `panic_guarded_server`;
/// - the write strategy or the shutdown;
/// - the read strategy.
pub(crate) async fn drive_with_strategies<F, Fut, WFn, WFut, RFn, RFut>(
    server_fn: F,
    capacity: usize,
    write_strategy: WFn,
    read_strategy: RFn,
) -> io::Result<Vec<u8>>
where
    F: FnOnce(DuplexStream) -> Fut,
    Fut: std::future::Future<Output = ()>,
    WFn: FnOnce(WriteHalf<DuplexStream>) -> WFut,
    WFut: std::future::Future<Output = io::Result<WriteHalf<DuplexStream>>>,
    RFn: FnOnce(ReadHalf<DuplexStream>) -> RFut,
    RFut: std::future::Future<Output = io::Result<Vec<u8>>>,
{
    let (client_end, server_end) = duplex(capacity);
    let (read_half, write_half) = split(client_end);

    let serve = panic_guarded_server(server_fn, server_end);
    let send = async move {
        let mut write_half = write_strategy(write_half).await?;
        // Closing the write half is what lets the server observe EOF.
        write_half.shutdown().await
    };
    let receive = read_strategy(read_half);

    tokio::try_join!(serve, send, receive).map(|((), (), received)| received)
}
/// Writes each byte buffer in `frames` to `writer` in order, using one
/// `write_all` call per buffer.
///
/// The writer is neither flushed nor shut down here; that is left to the caller.
///
/// # Errors
///
/// Stops at, and returns, the first write error.
pub(crate) async fn write_frames(
    mut writer: impl AsyncWrite + Unpin,
    frames: &[Vec<u8>],
) -> io::Result<()> {
    for bytes in frames.iter().map(Vec::as_slice) {
        writer.write_all(bytes).await?;
    }
    Ok(())
}
/// Writes `bytes` to `writer` in consecutive chunks of at most `chunk_size` bytes.
///
/// Every chunk except possibly the last is exactly `chunk_size` bytes long, and
/// each chunk is sent with its own `write_all` call. An empty `bytes` slice
/// performs no writes. The writer is neither flushed nor shut down here.
///
/// # Errors
///
/// Propagates the first error returned by `write_all`.
pub(crate) async fn write_chunked(
    mut writer: impl AsyncWrite + Unpin,
    bytes: &[u8],
    chunk_size: NonZeroUsize,
) -> io::Result<()> {
    // `slice::chunks` yields exactly the `offset..min(offset + step, len)` windows
    // the hand-rolled loop computed, without its unreachable out-of-bounds error
    // branch. `NonZeroUsize` guarantees `chunks` never sees a zero size (which
    // would panic).
    for chunk in bytes.chunks(chunk_size.get()) {
        writer.write_all(chunk).await?;
    }
    Ok(())
}
/// Reads from `reader` until EOF and returns every byte read.
///
/// # Errors
///
/// Propagates any error from `read_to_end`.
pub(crate) async fn read_all(mut reader: impl AsyncRead + Unpin) -> io::Result<Vec<u8>> {
    let mut collected = Vec::new();
    reader
        .read_to_end(&mut collected)
        .await
        .map(|_bytes_read| collected)
}
/// Write strategy for `drive_with_strategies`.
///
/// Writes every buffer in `frames` via `write_frames`, then hands the write half
/// back so the driver can shut it down.
///
/// # Errors
///
/// Propagates the first error from `write_frames`.
async fn write_frames_strategy(
    mut writer: WriteHalf<DuplexStream>,
    frames: Vec<Vec<u8>>,
) -> io::Result<WriteHalf<DuplexStream>> {
    let outcome = write_frames(&mut writer, frames.as_slice()).await;
    outcome.map(|()| writer)
}
/// Write strategy for `drive_with_strategies`.
///
/// Sends `bytes` in `chunk_size` pieces via `write_chunked`, then hands the write
/// half back so the driver can shut it down.
///
/// # Errors
///
/// Propagates the first error from `write_chunked`.
async fn write_chunked_strategy(
    mut writer: WriteHalf<DuplexStream>,
    bytes: Vec<u8>,
    chunk_size: NonZeroUsize,
) -> io::Result<WriteHalf<DuplexStream>> {
    let outcome = write_chunked(&mut writer, bytes.as_slice(), chunk_size).await;
    outcome.map(|()| writer)
}
pub(crate) async fn drive_internal<F, Fut>(
server_fn: F,
frames: Vec<Vec<u8>>,
capacity: usize,
) -> io::Result<Vec<u8>>
where
F: FnOnce(DuplexStream) -> Fut,
Fut: std::future::Future<Output = ()> + Send,
{
drive_with_strategies(
server_fn,
capacity,
|writer| write_frames_strategy(writer, frames),
read_all,
)
.await
}
pub(crate) async fn drive_chunked_internal<F, Fut>(
server_fn: F,
wire_bytes: Vec<u8>,
chunk_size: NonZeroUsize,
capacity: usize,
) -> io::Result<Vec<u8>>
where
F: FnOnce(DuplexStream) -> Fut,
Fut: std::future::Future<Output = ()> + Send,
{
drive_with_strategies(
server_fn,
capacity,
|writer| write_chunked_strategy(writer, wire_bytes, chunk_size),
read_all,
)
.await
}
/// Wraps each payload with `codec` and encodes it into its own byte buffer.
///
/// A single encoder obtained from `codec.encoder()` is reused for every payload;
/// no new encoder is created per frame. The result holds one encoded buffer per
/// payload, in input order.
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidData`] error as soon as the encoder
/// rejects a frame. Later payloads are not encoded.
pub(crate) fn encode_payloads_with_codec<F: FrameCodec>(
    codec: &F,
    payloads: Vec<Vec<u8>>,
) -> io::Result<Vec<Vec<u8>>> {
    let mut encoder = codec.encoder();
    let mut encoded = Vec::with_capacity(payloads.len());
    for payload in payloads {
        let frame = codec.wrap_payload(Bytes::from(payload));
        let mut scratch = BytesMut::new();
        if let Err(error) = encoder.encode(frame, &mut scratch) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("codec encode failed: {error}"),
            ));
        }
        encoded.push(scratch.to_vec());
    }
    Ok(encoded)
}
/// Decodes `bytes` into frames using a fresh decoder from `codec`.
///
/// Decoding runs in two phases:
/// 1. [`Decoder::decode`] is called until it yields `None`.
/// 2. [`Decoder::decode_eof`] is called until it yields `None`, so the decoder
///    can emit any frame it only releases at end of input.
///
/// # Errors
///
/// - Propagates any decoder error.
/// - Returns an [`io::ErrorKind::InvalidData`] error if bytes are still
///   unconsumed after both phases; the message suggests a truncated frame.
pub(crate) fn decode_frames_with_codec<F: FrameCodec>(
    codec: &F,
    bytes: &[u8],
) -> io::Result<Vec<F::Frame>> {
    let mut decoder = codec.decoder();
    let mut pending = BytesMut::from(bytes);
    let mut decoded = Vec::new();

    // `transpose` turns `Result<Option<_>>` into `Option<Result<_>>`, so each
    // iterator ends exactly when the decoder reports `Ok(None)`.
    for frame in std::iter::from_fn(|| decoder.decode(&mut pending).transpose()) {
        decoded.push(frame?);
    }
    for frame in std::iter::from_fn(|| decoder.decode_eof(&mut pending).transpose()) {
        decoded.push(frame?);
    }

    match pending.len() {
        0 => Ok(decoded),
        leftover => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("trailing {leftover} byte(s) after decoding - possible truncated frame"),
        )),
    }
}
/// Round-trips `payloads` through a server handler using `codec` on both sides.
///
/// The payloads are encoded with `codec` and concatenated into one wire buffer.
/// That buffer is passed to `drive` together with `handler`, and the bytes
/// `drive` returns are decoded back into frames with `codec`.
///
/// # Errors
///
/// Propagates encoding failures, any error returned by `drive`, and decoding
/// failures (including trailing, undecodable bytes).
pub(crate) async fn drive_codec_roundtrip<F, H, Fut>(
    handler: H,
    codec: &F,
    payloads: Vec<Vec<u8>>,
    drive: impl AsyncFnOnce(H, Vec<u8>) -> io::Result<Vec<u8>>,
) -> io::Result<Vec<F::Frame>>
where
    F: FrameCodec,
    H: FnOnce(DuplexStream) -> Fut,
    Fut: std::future::Future<Output = ()> + Send,
{
    let wire_bytes = encode_payloads_with_codec(codec, payloads)?.concat();
    let response = drive(handler, wire_bytes).await?;
    decode_frames_with_codec(codec, &response)
}
/// Copies the payload of each frame into an owned `Vec<u8>`, preserving order.
///
/// The payload is obtained with `F::frame_payload`.
pub(crate) fn extract_payloads<F: FrameCodec>(frames: &[F::Frame]) -> Vec<Vec<u8>> {
    let mut payloads = Vec::with_capacity(frames.len());
    for frame in frames {
        payloads.push(F::frame_payload(frame).to_vec());
    }
    payloads
}
/// Takes ownership of `app` and serves one connection over `server`.
///
/// Completes when `app.handle_connection(server)` resolves. Its output is
/// discarded.
pub(crate) async fn run_owned_app<S, C, E, F>(app: WireframeApp<S, C, E, F>, server: DuplexStream)
where
    S: TestSerializer,
    C: Send + 'static,
    E: Packet,
    F: FrameCodec,
{
    let connection = app.handle_connection(server);
    connection.await;
}