use std::{io, num::NonZeroUsize};
use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream, duplex};
use wireframe::{
app::{Packet, WireframeApp},
codec::FrameCodec,
};
use super::{
DEFAULT_CAPACITY,
TestSerializer,
codec_ext::{decode_frames_with_codec, encode_payloads_with_codec, extract_payloads},
};
/// Parameters controlling how request bytes are fragmented and buffered when
/// driving a server over an in-memory duplex pipe.
#[derive(Debug, Clone, Copy)]
pub(super) struct ChunkConfig {
    /// Maximum number of bytes handed to a single client write.
    pub chunk_size: NonZeroUsize,
    /// Buffer size passed to `tokio::io::duplex` when the pipe is created.
    pub capacity: usize,
}

impl ChunkConfig {
    /// Build a config that uses `DEFAULT_CAPACITY` for the duplex buffer.
    pub fn new(chunk_size: NonZeroUsize) -> Self {
        Self::with_capacity(chunk_size, DEFAULT_CAPACITY)
    }

    /// Build a config with an explicit duplex buffer `capacity`.
    pub fn with_capacity(chunk_size: NonZeroUsize, capacity: usize) -> Self {
        Self { chunk_size, capacity }
    }
}
/// Feed `wire_bytes` to a server in writes of at most `chunk_size` bytes and
/// collect everything the server writes back.
///
/// `server_fn` receives the server half of an in-memory duplex pipe created
/// with `capacity` as its buffer size. The client half is split so request
/// bytes are written while response bytes are read concurrently.
///
/// A write-everything-then-read client can deadlock. If the server's replies
/// fill the bounded response buffer before the request is fully written, the
/// server stops reading, and both sides wait on each other forever.
///
/// After the last chunk the client shuts down its write side so the server
/// observes end-of-stream. The read completes once the server drops its half.
///
/// # Errors
///
/// - `InvalidInput` if `capacity` is zero: the duplex pipe could never accept
///   a byte, so any non-empty write would hang.
/// - An `Other` error containing `server task failed: <panic message>` if the
///   future returned by `server_fn` panics.
/// - Any I/O error raised while writing, shutting down, or reading the client.
pub(super) async fn drive_chunked_internal<F, Fut>(
    server_fn: F,
    wire_bytes: Vec<u8>,
    chunk_size: NonZeroUsize,
    capacity: usize,
) -> io::Result<Vec<u8>>
where
    F: FnOnce(DuplexStream) -> Fut,
    Fut: std::future::Future<Output = ()> + Send,
{
    if capacity == 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "duplex capacity must be non-zero",
        ));
    }

    let (client, server) = duplex(capacity);

    // Convert a server panic into an I/O error so the caller sees a failure
    // instead of the whole test task unwinding.
    let server_fut = async {
        use futures::FutureExt as _;
        std::panic::AssertUnwindSafe(server_fn(server))
            .catch_unwind()
            .await
            .map_err(|panic| {
                let panic_msg = wireframe::panic::format_panic(&panic);
                io::Error::other(format!("server task failed: {panic_msg}"))
            })
    };

    // Split the client so writing requests and draining responses can make
    // progress independently (see the deadlock note above).
    let (mut client_rx, mut client_tx) = tokio::io::split(client);

    let write_fut = async {
        // `chunk_size` is non-zero, so `chunks` is well-defined; the final
        // chunk may be shorter.
        for chunk in wire_bytes.chunks(chunk_size.get()) {
            client_tx.write_all(chunk).await?;
        }
        // Signal end-of-stream so the server stops waiting for more input.
        client_tx.shutdown().await
    };

    let read_fut = async {
        let mut buf = Vec::new();
        client_rx.read_to_end(&mut buf).await?;
        io::Result::Ok(buf)
    };

    let ((), (), buf) = tokio::try_join!(server_fut, write_fut, read_fut)?;
    Ok(buf)
}
async fn drive_partial_frames_internal<F, H, Fut>(
handler: H,
codec: &F,
payloads: Vec<Vec<u8>>,
config: ChunkConfig,
) -> io::Result<Vec<F::Frame>>
where
F: FrameCodec,
H: FnOnce(DuplexStream) -> Fut,
Fut: std::future::Future<Output = ()> + Send,
{
let encoded = encode_payloads_with_codec(codec, payloads)?;
let wire_bytes: Vec<u8> = encoded.into_iter().flatten().collect();
let raw =
drive_chunked_internal(handler, wire_bytes, config.chunk_size, config.capacity).await?;
decode_frames_with_codec(codec, raw)
}
/// Drive `app` with `payloads` delivered in `chunk_size`-byte fragments and
/// return the payloads extracted from its decoded response frames.
///
/// Uses `DEFAULT_CAPACITY` for the duplex buffer; see
/// `drive_with_partial_frames_with_capacity` to choose a different size.
///
/// # Errors
///
/// Returns any encoding, transport, or decoding error, including an error
/// reporting a panic inside the app's connection handler.
pub async fn drive_with_partial_frames<S, C, E, F>(
    app: WireframeApp<S, C, E, F>,
    codec: &F,
    payloads: Vec<Vec<u8>>,
    chunk_size: NonZeroUsize,
) -> io::Result<Vec<Vec<u8>>>
where
    S: TestSerializer,
    C: Send + 'static,
    E: Packet,
    F: FrameCodec,
{
    let extracted = drive_with_partial_frames_with_capacity(
        app,
        codec,
        payloads,
        chunk_size,
        DEFAULT_CAPACITY,
    )
    .await?;
    Ok(extracted)
}
/// Like `drive_with_partial_frames`, but with an explicit duplex buffer
/// `capacity` instead of `DEFAULT_CAPACITY`.
///
/// The app is consumed and used to handle a single connection. Its decoded
/// response frames are converted to payload bytes via `extract_payloads`.
///
/// # Errors
///
/// Returns any encoding, transport, or decoding error, including an error
/// reporting a panic inside the app's connection handler.
pub async fn drive_with_partial_frames_with_capacity<S, C, E, F>(
    app: WireframeApp<S, C, E, F>,
    codec: &F,
    payloads: Vec<Vec<u8>>,
    chunk_size: NonZeroUsize,
    capacity: usize,
) -> io::Result<Vec<Vec<u8>>>
where
    S: TestSerializer,
    C: Send + 'static,
    E: Packet,
    F: FrameCodec,
{
    let config = ChunkConfig::with_capacity(chunk_size, capacity);
    drive_partial_frames_internal(
        |server| async move { app.handle_connection(server).await },
        codec,
        payloads,
        config,
    )
    .await
    .map(|frames| extract_payloads::<F>(&frames))
}
/// Borrowing variant of `drive_with_partial_frames`: the app is taken by
/// mutable reference rather than by value, so the caller keeps ownership.
///
/// Uses `DEFAULT_CAPACITY` for the duplex buffer. Response frames are
/// converted to payload bytes via `extract_payloads`.
///
/// # Errors
///
/// Returns any encoding, transport, or decoding error, including an error
/// reporting a panic inside the app's connection handler.
pub async fn drive_with_partial_frames_mut<S, C, E, F>(
    app: &mut WireframeApp<S, C, E, F>,
    codec: &F,
    payloads: Vec<Vec<u8>>,
    chunk_size: NonZeroUsize,
) -> io::Result<Vec<Vec<u8>>>
where
    S: TestSerializer,
    C: Send + 'static,
    E: Packet,
    F: FrameCodec,
{
    let config = ChunkConfig::new(chunk_size);
    drive_partial_frames_internal(
        |server| async move { app.handle_connection(server).await },
        codec,
        payloads,
        config,
    )
    .await
    .map(|frames| extract_payloads::<F>(&frames))
}
/// Drive `app` with `payloads` delivered in `chunk_size`-byte fragments and
/// return the decoded response frames as-is (no payload extraction).
///
/// Uses `DEFAULT_CAPACITY` for the duplex buffer.
///
/// # Errors
///
/// Returns any encoding, transport, or decoding error, including an error
/// reporting a panic inside the app's connection handler.
pub async fn drive_with_partial_codec_frames<S, C, E, F>(
    app: WireframeApp<S, C, E, F>,
    codec: &F,
    payloads: Vec<Vec<u8>>,
    chunk_size: NonZeroUsize,
) -> io::Result<Vec<F::Frame>>
where
    S: TestSerializer,
    C: Send + 'static,
    E: Packet,
    F: FrameCodec,
{
    let config = ChunkConfig::new(chunk_size);
    let frames = drive_partial_frames_internal(
        |server| async move { app.handle_connection(server).await },
        codec,
        payloads,
        config,
    )
    .await?;
    Ok(frames)
}