use std::net::SocketAddr;
use tokio::{sync::oneshot, task::JoinHandle};
use wireframe::{
app::Packet,
client::{WireframeClient, WireframeClientBuilder},
codec::FrameCodec,
rewind_stream::RewindStream,
serializer::BincodeSerializer,
server::{AppFactory, WireframeServer},
};
use crate::{TestError, TestResult, integration_helpers::unused_listener};
/// Live state owned by a `WireframePair`: the connected client plus the
/// handles used to stop and join the spawned server task.
///
/// Every field is an `Option` so teardown code can move each piece out
/// independently (see `WireframePair::shutdown` and its `Drop` impl).
struct Running {
    /// Connected client; taken and closed by `WireframePair::shutdown`.
    client: Option<WireframeClient<BincodeSerializer, RewindStream<tokio::net::TcpStream>, ()>>,
    /// Sending on this completes the server's `run_with_shutdown` future;
    /// taken when shutdown is signalled.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Join handle of the spawned server task.
    handle: Option<JoinHandle<Result<(), wireframe::server::ServerError>>>,
}
/// A spawned single-worker `WireframeServer` together with a client connected
/// to it, for use in integration tests.
///
/// Call `shutdown` for an orderly teardown that reports server errors.
/// Dropping the pair without doing so signals the server to stop and hands
/// its task to `spawn_bounded_shutdown`.
pub struct WireframePair {
    /// Address reported by the server's bound listener.
    addr: SocketAddr,
    /// Live client/server state; cleared by `shutdown` and taken by `Drop`.
    running: Option<Running>,
}
/// Prints the address plus a placeholder for the running state; the client
/// and task handles themselves are elided.
impl std::fmt::Debug for WireframePair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Some("..")` while running, `None` once torn down.
        let running_marker: Option<&str> = self.running.as_ref().map(|_| "..");
        let mut out = f.debug_struct("WireframePair");
        out.field("addr", &self.addr);
        out.field("running", &running_marker);
        out.finish()
    }
}
impl WireframePair {
    /// Borrow the connected client mutably.
    ///
    /// # Errors
    ///
    /// Returns `TestError::Msg` if the pair has already been shut down and
    /// the client has been removed.
    pub fn client_mut(
        &mut self,
    ) -> TestResult<&mut WireframeClient<BincodeSerializer, RewindStream<tokio::net::TcpStream>, ()>>
    {
        self.running
            .as_mut()
            .and_then(|r| r.client.as_mut())
            .ok_or_else(|| TestError::Msg("client_mut called after shutdown".into()))
    }

    /// Address the server's listener is bound to.
    #[must_use]
    pub const fn local_addr(&self) -> SocketAddr { self.addr }

    /// Tear the pair down in order: close the client, signal the server to
    /// stop, then wait for the server task to finish.
    ///
    /// Calling this again after it has returned is a no-op returning
    /// `Ok(())`. Once the server task has been awaited (or there was none),
    /// the pair holds no running state regardless of the outcome.
    ///
    /// # Errors
    ///
    /// Returns `TestError::Msg` if joining the server task fails or if the
    /// server itself returned an error.
    pub async fn shutdown(&mut self) -> TestResult<()> {
        let Some(running) = self.running.as_mut() else {
            return Ok(());
        };
        // The client is closed before the server is told to stop.
        if let Some(client) = running.client.take() {
            client.close().await;
        }
        if let Some(shutdown_tx) = running.shutdown_tx.take() {
            // The server may already have exited and dropped the receiver;
            // a failed send is therefore not an error.
            let _ = shutdown_tx.send(());
        }
        // Await via `as_mut()` rather than `take()`: if this future is dropped
        // mid-await, the handle stays in `running`, so `Drop` can still hand it
        // to `spawn_bounded_shutdown` instead of leaving the task detached.
        let result = match running.handle.as_mut() {
            None => Ok(()),
            Some(handle) => match handle.await {
                Err(join_err) => Err(TestError::Msg(format!(
                    "server task join error: {join_err}"
                ))),
                Ok(Err(server_err)) => Err(TestError::Msg(format!("server error: {server_err}"))),
                Ok(Ok(())) => Ok(()),
            },
        };
        // Clear state on every path (including the no-handle one) so the pair
        // never lingers as `Some(Running)` with all fields emptied.
        self.running = None;
        result
    }
}
impl Drop for WireframePair {
fn drop(&mut self) {
if let Some(running) = self.running.take() {
if let Some(shutdown_tx) = running.shutdown_tx {
let _ = shutdown_tx.send(());
}
if let Some(handle) = running.handle {
spawn_bounded_shutdown(handle, std::time::Duration::from_millis(100));
}
}
}
}
/// Give a server task up to `timeout` to finish, then abort it.
///
/// When called inside a Tokio runtime, the wait happens on a detached task,
/// so the caller is not blocked. Outside a runtime there is nothing to drive
/// the wait, so the task is aborted immediately.
fn spawn_bounded_shutdown(
    mut handle: JoinHandle<Result<(), wireframe::server::ServerError>>,
    timeout: std::time::Duration,
) {
    let Ok(runtime) = tokio::runtime::Handle::try_current() else {
        handle.abort();
        return;
    };
    runtime.spawn(async move {
        // Elapsed => the task overran its budget; the task's own result is ignored.
        if tokio::time::timeout(timeout, &mut handle).await.is_err() {
            handle.abort();
        }
    });
}
/// Drop guard for a server task that has been spawned but not yet handed over
/// to a `WireframePair`.
///
/// While it still holds its contents, dropping it sends the shutdown signal
/// and passes the task to `spawn_bounded_shutdown`. As a result, an early
/// return (e.g. a failed client connect) does not leave the server running.
/// Call `take` to disarm it.
struct PendingServer(
    Option<(
        oneshot::Sender<()>,
        JoinHandle<Result<(), wireframe::server::ServerError>>,
    )>,
);
impl PendingServer {
    /// Disarm the guard, returning the shutdown sender and server task.
    ///
    /// Returns `None` if they were already taken; afterwards the guard is
    /// empty and its `Drop` does nothing.
    fn take(
        &mut self,
    ) -> Option<(
        oneshot::Sender<()>,
        JoinHandle<Result<(), wireframe::server::ServerError>>,
    )> {
        std::mem::take(&mut self.0)
    }
}
/// If the guard is still armed, signal the server to stop and give it a
/// 100 ms window via `spawn_bounded_shutdown`.
impl Drop for PendingServer {
    fn drop(&mut self) {
        let Some((stop, server_task)) = self.0.take() else {
            return;
        };
        // The receiver may already be gone if the server exited; ignore that.
        let _ = stop.send(());
        spawn_bounded_shutdown(server_task, std::time::Duration::from_millis(100));
    }
}
/// Start a single-worker `WireframeServer` on the listener returned by
/// `unused_listener` and connect a client to it.
///
/// `configure_client` receives a fresh `WireframeClientBuilder` and may
/// customise it before the connection is made. Until the returned pair is
/// assembled, the server task is held by a `PendingServer` guard. If this
/// function fails or its future is dropped early, the guard signals the
/// server to stop.
///
/// # Errors
///
/// Fails if obtaining or binding the listener fails, if the server reports
/// no bound address, if the server stops before signalling readiness, or if
/// the client cannot connect.
pub async fn spawn_wireframe_pair<F, E, Codec, B>(
    app_factory: F,
    configure_client: B,
) -> TestResult<WireframePair>
where
    F: AppFactory<BincodeSerializer, (), E, Codec>,
    E: Packet,
    Codec: FrameCodec,
    B: FnOnce(
        WireframeClientBuilder<BincodeSerializer, (), ()>,
    ) -> WireframeClientBuilder<BincodeSerializer, (), ()>,
{
    let listener = unused_listener()?;
    let server = WireframeServer::new(app_factory)
        .workers(1)
        .bind_existing_listener(listener)?;
    let addr = server
        .local_addr()
        .ok_or("server did not report a bound address")?;

    let (stop_tx, stop_rx) = oneshot::channel();
    let (ready_tx, ready_rx) = oneshot::channel();
    let server_task = tokio::spawn(async move {
        server
            .ready_signal(ready_tx)
            .run_with_shutdown(async {
                let _ = stop_rx.await;
            })
            .await
    });
    let mut guard = PendingServer(Some((stop_tx, server_task)));

    // The ready sender is dropped without firing only if the server future
    // ended before reporting readiness; join it to learn why.
    if ready_rx.await.is_err() {
        let (_, server_task) = guard
            .take()
            .ok_or_else(|| TestError::Msg("pending server already taken".into()))?;
        let reason = match server_task.await {
            Err(join_err) => format!("server task failed to start: {join_err}"),
            Ok(Err(server_err)) => format!("server failed to start: {server_err}"),
            Ok(Ok(())) => String::from("server exited before signalling ready"),
        };
        return Err(TestError::Msg(reason));
    }

    let client = configure_client(WireframeClientBuilder::new())
        .connect(addr)
        .await?;

    // Disarm the guard only once the client is connected.
    let (stop_tx, server_task) = guard
        .take()
        .ok_or_else(|| TestError::Msg("pending server already taken".into()))?;
    let running = Running {
        client: Some(client),
        shutdown_tx: Some(stop_tx),
        handle: Some(server_task),
    };
    Ok(WireframePair {
        addr,
        running: Some(running),
    })
}
/// Spawn a server/client pair with the client builder left unchanged.
///
/// This is `spawn_wireframe_pair` with an identity `configure_client`.
///
/// # Errors
///
/// Propagates any error from `spawn_wireframe_pair`.
pub async fn spawn_wireframe_pair_default<F, E, Codec>(app_factory: F) -> TestResult<WireframePair>
where
    F: AppFactory<BincodeSerializer, (), E, Codec>,
    E: Packet,
    Codec: FrameCodec,
{
    spawn_wireframe_pair(app_factory, std::convert::identity).await
}