use std::{
cell::{Cell, RefCell},
net::SocketAddr,
};
use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use log::warn;
use rstest::fixture;
use tokio::{net::TcpListener, task::JoinHandle};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use wireframe::{
WireframeError,
app::Envelope,
client::{ClientCodecConfig, ClientError, ClientProtocolError, WireframeClient},
message::Message,
rewind_stream::RewindStream,
serializer::BincodeSerializer,
};
#[path = "../../examples/support/echo_login_contract.rs"]
mod echo_login_contract;
use echo_login_contract::{LOGIN_ROUTE_ID, LoginAck, LoginRequest};
pub use wireframe_testing::TestResult;
/// Bytes the TLS-mismatch test server writes back in place of a wireframe
/// response frame.
///
/// Presumably modelled on a TLS record: a 5-byte header (content type `0x16`,
/// version `0x03 0x03`, body length `0x00 0x02`) followed by a 2-byte body.
/// NOTE(review): the body `0x02 0x28` reads like a fatal `handshake_failure`
/// alert, whose record content type would be `0x15` rather than `0x16`;
/// confirm which was intended.
pub const SIMULATED_TLS_RECORD: [u8; 7] = [
    0x16, 0x03, 0x03, // content type, protocol version
    0x00, 0x02, // body length
    0x02, 0x28, // body
];
/// Concrete client type held by [`ClientRuntimeWorld`].
type RuntimeClient = WireframeClient<BincodeSerializer, RewindStream<tokio::net::TcpStream>>;

/// Shared state for the client-runtime test scenarios.
///
/// Steps are synchronous and drive async work through the owned Tokio
/// runtime. `Cell`/`RefCell` fields let every step method take `&self`.
#[derive(Debug)]
pub struct ClientRuntimeWorld {
    /// Runtime used to run client calls and spawn test servers; `None` when
    /// it could not be constructed.
    runtime: Option<tokio::runtime::Runtime>,
    /// Why the runtime could not be constructed, if it could not.
    runtime_error: Option<String>,
    /// Address of the most recently started test server.
    addr: Cell<Option<SocketAddr>>,
    /// Handle of the spawned server task, awaited by the verification steps.
    server: RefCell<Option<JoinHandle<()>>>,
    /// Connected client, once a step has connected one.
    client: RefCell<Option<RuntimeClient>>,
    /// Payload sent in the last successful echo exchange.
    payload: RefCell<Option<ClientPayload>>,
    /// Response received for `payload`.
    response: RefCell<Option<ClientPayload>>,
    /// Acknowledgement from the last successful login request.
    login_ack: RefCell<Option<LoginAck>>,
    /// Client error captured by a step that expected the call to fail.
    last_error: RefCell<Option<ClientError>>,
}
impl ClientRuntimeWorld {
    /// Creates a world with its own Tokio runtime and no server, client, or
    /// recorded results.
    ///
    /// Failure to build the runtime does not panic. The error message is kept
    /// and returned by the first step that needs the runtime.
    pub fn new() -> Self {
        // Only the runtime pair depends on the outcome; every other field
        // starts empty either way, so it is initialised once below.
        let (runtime, runtime_error) = match tokio::runtime::Runtime::new() {
            Ok(runtime) => (Some(runtime), None),
            Err(err) => (None, Some(format!("failed to create runtime: {err}"))),
        };
        Self {
            runtime,
            runtime_error,
            addr: Cell::new(None),
            server: RefCell::new(None),
            client: RefCell::new(None),
            payload: RefCell::new(None),
            response: RefCell::new(None),
            login_ack: RefCell::new(None),
            last_error: RefCell::new(None),
        }
    }

    /// Returns the world's runtime.
    ///
    /// # Errors
    /// Returns the recorded construction error, or "runtime unavailable" when
    /// no error text was recorded, if there is no runtime.
    fn runtime(&self) -> TestResult<&tokio::runtime::Runtime> {
        self.runtime.as_ref().ok_or_else(|| {
            self.runtime_error
                .as_deref()
                .unwrap_or("runtime unavailable")
                .into()
        })
    }

    /// Runs `future` to completion on the world's runtime and returns its
    /// output.
    ///
    /// # Errors
    /// Fails without polling `future` when a Tokio runtime is already current
    /// on this thread, or when the world's runtime is unavailable.
    fn block_on<F, T>(&self, future: F) -> TestResult<T>
    where
        F: std::future::Future<Output = T>,
    {
        if tokio::runtime::Handle::try_current().is_ok() {
            return Err("nested Tokio runtime detected in client runtime fixture".into());
        }
        let runtime = self.runtime()?;
        Ok(runtime.block_on(future))
    }
}
/// Message sent by the echo scenarios. Its response is compared against it
/// for equality.
#[derive(Debug, Clone, PartialEq, Eq, bincode::Encode, bincode::BorrowDecode)]
struct ClientPayload {
    /// Raw bytes carried by the message.
    data: Vec<u8>,
}
/// rstest fixture that hands each scenario a freshly constructed
/// [`ClientRuntimeWorld`].
#[rustfmt::skip]
#[fixture]
pub fn client_runtime_world() -> ClientRuntimeWorld {
    ClientRuntimeWorld::new()
}
impl ClientRuntimeWorld {
pub fn start_server(&self, max_frame_length: usize) -> TestResult {
let listener = self.block_on(async { TcpListener::bind("127.0.0.1:0").await })??;
let addr = listener.local_addr()?;
let handle = self.runtime()?.spawn(async move {
let Ok((stream, _)) = listener.accept().await else {
warn!("client runtime server failed to accept connection");
return;
};
let codec = LengthDelimitedCodec::builder()
.max_frame_length(max_frame_length)
.new_codec();
let mut framed = Framed::new(stream, codec);
let Some(result) = framed.next().await else {
warn!("client runtime server closed before receiving a frame");
return;
};
let Ok(frame) = result else {
warn!("client runtime server failed to decode frame");
return;
};
if let Err(err) = framed.send(frame.freeze()).await {
warn!("client runtime server failed to send response: {err:?}");
}
});
self.addr.set(Some(addr));
*self.server.borrow_mut() = Some(handle);
Ok(())
}
pub fn start_malformed_response_server(&self) -> TestResult {
let listener = self.block_on(async { TcpListener::bind("127.0.0.1:0").await })??;
let addr = listener.local_addr()?;
let handle = self.runtime()?.spawn(async move {
let Ok((stream, _)) = listener.accept().await else {
warn!("client runtime malformed server failed to accept connection");
return;
};
let mut framed = Framed::new(stream, LengthDelimitedCodec::new());
let Some(result) = framed.next().await else {
warn!("client runtime malformed server closed before receiving a frame");
return;
};
let Ok(_frame) = result else {
warn!("client runtime malformed server failed to decode request frame");
return;
};
if let Err(err) = framed
.send(Bytes::from_static(&[0xff, 0xff, 0xff, 0xff]))
.await
{
warn!("client runtime malformed server failed to send invalid frame: {err:?}");
}
});
self.addr.set(Some(addr));
*self.server.borrow_mut() = Some(handle);
Ok(())
}
pub fn start_tls_mismatch_server(&self) -> TestResult {
let listener = self.block_on(async { TcpListener::bind("127.0.0.1:0").await })??;
let addr = listener.local_addr()?;
let handle = self.runtime()?.spawn(async move {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let Ok((mut stream, _)) = listener.accept().await else {
warn!("client runtime TLS-mismatch server failed to accept connection");
return;
};
let mut request_buf = [0_u8; 64];
match stream.read(&mut request_buf).await {
Ok(_) => {}
Err(err) => {
warn!("client runtime TLS-mismatch server failed to read request: {err:?}");
return;
}
}
if let Err(err) = stream.write_all(&SIMULATED_TLS_RECORD).await {
warn!("client runtime TLS-mismatch server failed to write response: {err:?}");
return;
}
if let Err(err) = stream.shutdown().await {
warn!("client runtime TLS-mismatch server failed to shutdown: {err:?}");
}
});
self.addr.set(Some(addr));
*self.server.borrow_mut() = Some(handle);
Ok(())
}
pub fn connect_client(&self, max_frame_length: usize) -> TestResult {
let addr = self.addr.get().ok_or("server address missing")?;
let codec_config = ClientCodecConfig::default().max_frame_length(max_frame_length);
let client = self.block_on(async {
WireframeClient::builder()
.codec_config(codec_config)
.connect(addr)
.await
})??;
*self.client.borrow_mut() = Some(client);
Ok(())
}
pub fn send_payload(&self, size: usize) -> TestResult {
let (payload, result) = self.send_payload_inner(size)?;
let response = result?;
*self.payload.borrow_mut() = Some(payload);
*self.response.borrow_mut() = Some(response);
*self.login_ack.borrow_mut() = None;
*self.last_error.borrow_mut() = None;
Ok(())
}
pub fn send_payload_expect_error(&self, size: usize) -> TestResult {
let (_payload, result) = self.send_payload_inner(size)?;
match result {
Ok(_) => return Err("expected client error for oversized payload".into()),
Err(err) => {
*self.last_error.borrow_mut() = Some(err);
*self.login_ack.borrow_mut() = None;
}
}
Ok(())
}
fn send_payload_inner(
&self,
size: usize,
) -> TestResult<(ClientPayload, Result<ClientPayload, ClientError>)> {
let payload = ClientPayload {
data: vec![7_u8; size],
};
let mut client = self
.client
.borrow_mut()
.take()
.ok_or("client not connected")?;
let result = self.block_on(async { client.call(&payload).await })?;
*self.client.borrow_mut() = Some(client);
Ok((payload, result))
}
pub fn verify_echo(&self) -> TestResult {
let payload_ref = self.payload.borrow();
let response_ref = self.response.borrow();
let payload = payload_ref.as_ref().ok_or("payload missing")?;
let response = response_ref.as_ref().ok_or("response missing")?;
if payload != response {
return Err("response did not match payload".into());
}
self.await_server()?;
Ok(())
}
pub fn send_login_request(&self, username: String) -> TestResult {
*self.login_ack.borrow_mut() = None;
let login = LoginRequest { username };
let mut client = self
.client
.borrow_mut()
.take()
.ok_or("client not connected")?;
let request = Envelope::new(LOGIN_ROUTE_ID, None, login.to_bytes()?);
let response: Envelope =
self.block_on(async { client.call_correlated(request).await })??;
let (ack, _) = LoginAck::from_bytes(response.payload_bytes())?;
*self.client.borrow_mut() = Some(client);
*self.login_ack.borrow_mut() = Some(ack);
*self.last_error.borrow_mut() = None;
Ok(())
}
pub fn verify_login_acknowledgement(&self, expected: &str) -> TestResult {
let ack_ref = self.login_ack.borrow();
let ack = ack_ref.as_ref().ok_or("login acknowledgement missing")?;
if ack.username != expected {
return Err(format!(
"expected login acknowledgement for '{expected}', got '{}'",
ack.username
)
.into());
}
self.await_server()?;
Ok(())
}
pub fn verify_transport_wireframe_error(&self) -> TestResult {
let error_ref = self.last_error.borrow();
let err = error_ref
.as_ref()
.ok_or("expected client error was not captured")?;
if !matches!(err, ClientError::Wireframe(WireframeError::Io(_))) {
return Err(format!("expected transport WireframeError::Io, got {err:?}").into());
}
self.await_server()?;
Ok(())
}
pub fn verify_decode_wireframe_error(&self) -> TestResult {
let error_ref = self.last_error.borrow();
let err = error_ref
.as_ref()
.ok_or("expected client error was not captured")?;
if !matches!(
err,
ClientError::Wireframe(WireframeError::Protocol(ClientProtocolError::Deserialize(
_
)))
) {
return Err(format!(
"expected decode WireframeError::Protocol(ClientProtocolError::Deserialize(_)), \
got {err:?}"
)
.into());
}
self.await_server()?;
Ok(())
}
fn await_server(&self) -> TestResult {
if let Some(handle) = self.server.borrow_mut().take() {
self.block_on(async {
handle
.await
.map_err(|err| format!("server task failed: {err}"))
})??;
}
Ok(())
}
}