use std::io;
use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream, duplex};
use wireframe::app::{Packet, WireframeApp};
use super::{DEFAULT_CAPACITY, TestSerializer};
/// Run `server_fn` on one end of an in-memory duplex pipe while a client
/// writes `frames` to the other end and collects everything sent back.
///
/// `capacity` is passed to [`duplex`] as its maximum buffer size. The client
/// reads the server's output *concurrently* with writing `frames`. Otherwise
/// a server that replies to each frame before reading the next could fill
/// its outbound buffer and block while the client is itself blocked writing,
/// deadlocking both sides.
///
/// After the last frame the client shuts down its write half to signal
/// end-of-stream. The returned bytes are everything read until the server
/// side of the pipe is closed.
///
/// # Errors
///
/// Returns the first I/O error raised while writing, shutting down, or
/// reading on the client side. If the future produced by `server_fn`
/// panics, the panic is caught and reported as an [`io::Error`] whose
/// message starts with `server task failed:`.
pub(super) async fn drive_internal<F, Fut>(
    server_fn: F,
    frames: Vec<Vec<u8>>,
    capacity: usize,
) -> io::Result<Vec<u8>>
where
    F: FnOnce(DuplexStream) -> Fut,
    Fut: std::future::Future<Output = ()> + Send,
{
    let (client, server) = duplex(capacity);
    // Split the client so reading and writing can make progress independently
    // within the same `try_join!`; see the deadlock note above.
    let (mut reader, mut writer) = tokio::io::split(client);

    let server_fut = async {
        use futures::FutureExt as _;
        std::panic::AssertUnwindSafe(server_fn(server))
            .catch_unwind()
            .await
            .map_err(|panic| {
                let panic_msg = wireframe::panic::format_panic(&panic);
                io::Error::new(
                    io::ErrorKind::Other,
                    format!("server task failed: {panic_msg}"),
                )
            })
    };

    let write_fut = async {
        for frame in &frames {
            writer.write_all(frame).await?;
        }
        // Closing our write direction lets the server observe end-of-input.
        writer.shutdown().await
    };

    let read_fut = async {
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await?;
        io::Result::Ok(buf)
    };

    let ((), (), buf) = tokio::try_join!(server_fut, write_fut, read_fut)?;
    Ok(buf)
}
/// Generate an `async fn` that forwards to a `*_with_capacity` driver,
/// appending `DEFAULT_CAPACITY` as the final argument.
///
/// The generated function is generic over `<S, C, E>` with the bounds
/// `S: TestSerializer`, `C: Send + 'static` and `E: Packet`, matching the
/// hand-written drivers in this module. Attributes written before `fn`
/// (including `///` doc comments) are copied onto the generated function.
/// A trailing comma is accepted after the last parameter and after the last
/// forwarded argument, so multi-line invocations may use the usual
/// rustfmt-style trailing comma.
macro_rules! forward_default {
    (
        $(#[$docs:meta])* $vis:vis fn $name:ident(
            $app:ident : $app_ty:ty,
            $arg:ident : $arg_ty:ty $(,)?
        ) -> $ret:ty
        => $inner:ident($app_expr:ident, $arg_expr:expr $(,)?)
    ) => {
        $(#[$docs])*
        $vis async fn $name<S, C, E>(
            $app: $app_ty,
            $arg: $arg_ty,
        ) -> $ret
        where
            S: TestSerializer,
            C: Send + 'static,
            E: Packet,
        {
            $inner($app_expr, $arg_expr, DEFAULT_CAPACITY).await
        }
    };
}
/// Generate an `async fn` that takes an explicit `capacity: usize` and
/// forwards it, unchanged, as the final argument to another driver.
///
/// The generated function is generic over `<S, C, E>` with the bounds
/// `S: TestSerializer`, `C: Send + 'static` and `E: Packet`, matching the
/// hand-written drivers in this module. Attributes written before `fn`
/// (including `///` doc comments) are copied onto the generated function.
/// The literal `capacity: usize` parameter and the literal `capacity`
/// argument are required by the matcher. A trailing comma is accepted after
/// each of them.
macro_rules! forward_with_capacity {
    (
        $(#[$docs:meta])* $vis:vis fn $name:ident(
            $app:ident : $app_ty:ty,
            $arg:ident : $arg_ty:ty,
            capacity: usize $(,)?
        ) -> $ret:ty
        => $inner:ident($app_expr:ident, $arg_expr:expr, capacity $(,)?)
    ) => {
        $(#[$docs])*
        $vis async fn $name<S, C, E>(
            $app: $app_ty,
            $arg: $arg_ty,
            capacity: usize,
        ) -> $ret
        where
            S: TestSerializer,
            C: Send + 'static,
            E: Packet,
        {
            $inner($app_expr, $arg_expr, capacity).await
        }
    };
}
pub async fn drive_with_frame<S, C, E>(
app: WireframeApp<S, C, E>,
frame: Vec<u8>,
) -> io::Result<Vec<u8>>
where
S: TestSerializer,
C: Send + 'static,
E: Packet,
{
drive_with_frame_with_capacity(app, frame, DEFAULT_CAPACITY).await
}
forward_with_capacity! {
    /// Drive an owned `app` with a single `frame`, passing `capacity` to the
    /// duplex pipe as its buffer size, and return the bytes the server wrote
    /// back.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from driving the connection, or an error
    /// describing a panic in the server future.
    pub fn drive_with_frame_with_capacity(
        app: WireframeApp<S, C, E>,
        frame: Vec<u8>,
        capacity: usize
    ) -> io::Result<Vec<u8>>
    => drive_with_frames_with_capacity(app, vec![frame], capacity)
}
forward_default! {
    /// Drive an owned `app` with every frame in `frames`, using
    /// `DEFAULT_CAPACITY` as the duplex buffer size, and return the bytes the
    /// server wrote back.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from driving the connection, or an error
    /// describing a panic in the server future.
    pub fn drive_with_frames(
        app: WireframeApp<S, C, E>,
        frames: Vec<Vec<u8>>
    ) -> io::Result<Vec<u8>>
    => drive_with_frames_with_capacity(app, frames)
}
pub async fn drive_with_frames_with_capacity<S, C, E>(
app: WireframeApp<S, C, E>,
frames: Vec<Vec<u8>>,
capacity: usize,
) -> io::Result<Vec<u8>>
where
S: TestSerializer,
C: Send + 'static,
E: Packet,
{
drive_internal(
|server| async move { app.handle_connection(server).await },
frames,
capacity,
)
.await
}
forward_default! {
    /// Drive a mutably borrowed `app` with a single `frame`, using
    /// `DEFAULT_CAPACITY` as the duplex buffer size, and return the bytes the
    /// server wrote back. The app is borrowed rather than consumed, so the
    /// caller keeps ownership of it.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from driving the connection, or an error
    /// describing a panic in the server future.
    pub fn drive_with_frame_mut(
        app: &mut WireframeApp<S, C, E>,
        frame: Vec<u8>
    ) -> io::Result<Vec<u8>>
    => drive_with_frame_with_capacity_mut(app, frame)
}
forward_with_capacity! {
    /// Drive a mutably borrowed `app` with a single `frame`, passing
    /// `capacity` to the duplex pipe as its buffer size, and return the bytes
    /// the server wrote back.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from driving the connection, or an error
    /// describing a panic in the server future.
    pub fn drive_with_frame_with_capacity_mut(
        app: &mut WireframeApp<S, C, E>,
        frame: Vec<u8>,
        capacity: usize
    ) -> io::Result<Vec<u8>>
    => drive_with_frames_with_capacity_mut(app, vec![frame], capacity)
}
forward_default! {
    /// Drive a mutably borrowed `app` with every frame in `frames`, using
    /// `DEFAULT_CAPACITY` as the duplex buffer size, and return the bytes the
    /// server wrote back.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from driving the connection, or an error
    /// describing a panic in the server future.
    pub fn drive_with_frames_mut(
        app: &mut WireframeApp<S, C, E>,
        frames: Vec<Vec<u8>>
    ) -> io::Result<Vec<u8>>
    => drive_with_frames_with_capacity_mut(app, frames)
}
pub async fn drive_with_frames_with_capacity_mut<S, C, E>(
app: &mut WireframeApp<S, C, E>,
frames: Vec<Vec<u8>>,
capacity: usize,
) -> io::Result<Vec<u8>>
where
S: TestSerializer,
C: Send + 'static,
E: Packet,
{
drive_internal(
|server| async { app.handle_connection(server).await },
frames,
capacity,
)
.await
}