use std::net::SocketAddr;
use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use log::warn;
use rstest::fixture;
use tokio::{net::TcpListener, task::JoinHandle};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use wireframe::{
app::{Envelope, Packet},
client::{ClientError, WireframeClient},
correlation::CorrelatableFrame,
rewind_stream::RewindStream,
serializer::BincodeSerializer,
};
pub use wireframe_testing::TestResult;
use wireframe_testing::{ServerMode, process_frame};
/// Mutable state shared by the steps of a client-messaging test scenario.
///
/// Holds the loopback test server, the connected client, the envelope staged
/// for sending, and whatever outcomes the verification steps inspect.
#[derive(Debug, Default)]
pub struct ClientMessagingWorld {
    /// Address the test server listens on; set by `start_server_with_mode`
    /// and read by `connect_client`.
    addr: Option<SocketAddr>,
    /// Handle of the spawned server task; aborted by `abort_server`.
    server: Option<JoinHandle<()>>,
    /// Client connected to `addr`; `None` until `connect_client` succeeds.
    client: Option<WireframeClient<BincodeSerializer, RewindStream<tokio::net::TcpStream>>>,
    /// Envelope staged by one of the `set_envelope_*` methods; taken (and thus
    /// cleared) by `send_envelope` and `call_correlated`.
    envelope: Option<Envelope>,
    /// Correlation IDs returned by the client from `send_envelope` /
    /// `send_multiple_envelopes`, in send order.
    sent_correlation_ids: Vec<u64>,
    /// Response from the most recent successful `call_correlated`; reset to
    /// `None` when that call fails.
    pub response: Option<Envelope>,
    /// Error from the most recent failed `call_correlated`; reset to `None`
    /// when that call succeeds.
    last_error: Option<ClientError>,
    /// Message ID recorded by `set_envelope_with_payload` for later comparison.
    expected_message_id: Option<u32>,
    /// Payload recorded by `set_envelope_with_payload` for later comparison.
    expected_payload: Option<String>,
}
#[rustfmt::skip]
#[fixture]
pub fn client_messaging_world() -> ClientMessagingWorld {
ClientMessagingWorld::default()
}
impl ClientMessagingWorld {
    /// Starts a loopback test server that handles frames in
    /// [`ServerMode::Echo`] mode.
    ///
    /// # Errors
    /// Returns an error if the listener cannot be bound or its local address
    /// cannot be read.
    pub async fn start_echo_server(&mut self) -> TestResult {
        self.start_server_with_mode(ServerMode::Echo).await
    }

    /// Starts a loopback test server that handles frames in
    /// [`ServerMode::Mismatch`] mode (presumably replying with a correlation ID
    /// that does not match the request — see `verify_mismatch_error`).
    ///
    /// # Errors
    /// Returns an error if the listener cannot be bound or its local address
    /// cannot be read.
    pub async fn start_mismatch_server(&mut self) -> TestResult {
        self.start_server_with_mode(ServerMode::Mismatch).await
    }

    /// Binds a listener on an ephemeral `127.0.0.1` port and spawns a task that
    /// accepts a single connection and serves it with `run_frame_loop`, using
    /// length-delimited framing and the given `mode`.
    ///
    /// The bound address and the task handle are stored so that
    /// `connect_client` and `abort_server` can use them.
    ///
    /// # Errors
    /// Returns an error if binding the listener or reading its local address
    /// fails.
    async fn start_server_with_mode(&mut self, mode: ServerMode) -> TestResult {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;
        let handle = tokio::spawn(async move {
            // Only one connection is accepted; the task ends when that
            // connection's frame loop returns.
            let (stream, _) = match listener.accept().await {
                Ok(connection) => connection,
                Err(err) => {
                    warn!("client messaging server failed to accept connection: {err}");
                    return;
                }
            };
            let mut framed = Framed::new(stream, LengthDelimitedCodec::new());
            run_frame_loop(&mut framed, mode).await;
        });
        self.addr = Some(addr);
        self.server = Some(handle);
        Ok(())
    }

    /// Connects a [`WireframeClient`] to the previously started server.
    ///
    /// # Errors
    /// Returns an error if no server has been started or the connection fails.
    pub async fn connect_client(&mut self) -> TestResult {
        let addr = self.addr.ok_or("server address missing")?;
        let client = WireframeClient::builder().connect(addr).await?;
        self.client = Some(client);
        Ok(())
    }

    /// Stages an envelope with message ID 1, no correlation ID, and payload
    /// `[1, 2, 3]`.
    pub fn set_envelope_without_correlation(&mut self) {
        self.envelope = Some(Envelope::new(1, None, vec![1, 2, 3]));
    }

    /// Stages an envelope with message ID 1, the given `correlation_id`, and
    /// payload `[1, 2, 3]`.
    pub fn set_envelope_with_correlation(&mut self, correlation_id: u64) {
        self.envelope = Some(Envelope::new(1, Some(correlation_id), vec![1, 2, 3]));
    }

    /// Stages an envelope with `message_id` and the UTF-8 bytes of `payload`
    /// (no correlation ID), and records both as the values that
    /// `verify_response_matches_expected` will compare against.
    pub fn set_envelope_with_payload(&mut self, message_id: u32, payload: &str) {
        self.envelope = Some(Envelope::new(message_id, None, payload.as_bytes().to_vec()));
        self.expected_message_id = Some(message_id);
        self.expected_payload = Some(payload.to_string());
    }

    /// Sends the staged envelope and records the correlation ID the client
    /// reports for it.
    ///
    /// The staged envelope is consumed even if the send fails.
    ///
    /// # Errors
    /// Returns an error if the client is not connected, no envelope is staged,
    /// or sending fails.
    pub async fn send_envelope(&mut self) -> TestResult {
        let client = self.client.as_mut().ok_or("client not connected")?;
        let envelope = self.envelope.take().ok_or("envelope not configured")?;
        let correlation_id = client.send_envelope(envelope).await?;
        self.sent_correlation_ids.push(correlation_id);
        Ok(())
    }

    /// Performs a correlated request/response call with the staged envelope.
    ///
    /// The call's outcome is stored rather than propagated: on success
    /// `response` is set and `last_error` cleared; on failure the error is
    /// stored in `last_error` and `response` cleared.
    ///
    /// # Errors
    /// Returns an error only if the client is not connected or no envelope is
    /// staged.
    pub async fn call_correlated(&mut self) -> TestResult {
        let client = self.client.as_mut().ok_or("client not connected")?;
        let envelope = self.envelope.take().ok_or("envelope not configured")?;
        match client.call_correlated(envelope).await {
            Ok(response) => {
                self.response = Some(response);
                self.last_error = None;
            }
            Err(err) => {
                self.last_error = Some(err);
                self.response = None;
            }
        }
        Ok(())
    }

    /// Sends `count` envelopes one at a time, receiving one reply after each,
    /// and records every correlation ID the client reports.
    ///
    /// Previously recorded correlation IDs are discarded first. Envelope `i`
    /// carries message ID `i` and the single payload byte `i as u8`.
    ///
    /// # Errors
    /// Returns an error if the client is not connected, an envelope index does
    /// not fit in `u32`, or any send/receive fails.
    #[expect(
        clippy::cast_possible_truncation,
        reason = "payload byte intentionally wraps; it only needs to vary per envelope"
    )]
    pub async fn send_multiple_envelopes(&mut self, count: usize) -> TestResult {
        let client = self.client.as_mut().ok_or("client not connected")?;
        self.sent_correlation_ids.clear();
        for i in 0..count {
            // Checked: a truncated message ID would silently alias an earlier
            // envelope instead of failing the step.
            let message_id = u32::try_from(i).map_err(|_| "envelope index exceeds u32::MAX")?;
            let envelope = Envelope::new(message_id, None, vec![i as u8]);
            let correlation_id = client.send_envelope(envelope).await?;
            self.sent_correlation_ids.push(correlation_id);
            let _: Envelope = client.receive_envelope().await?;
        }
        Ok(())
    }

    /// Returns the first recorded correlation ID, or an error if none exist.
    fn get_first_correlation_id(&self) -> TestResult<u64> {
        self.sent_correlation_ids
            .first()
            .copied()
            .ok_or_else(|| "no correlation ID captured".into())
    }

    /// Checks that the first recorded correlation ID is non-zero.
    ///
    /// # Errors
    /// Returns an error if no ID was recorded or the first ID is `0`.
    pub fn verify_auto_generated_correlation(&self) -> TestResult {
        let id = self.get_first_correlation_id()?;
        if id == 0 {
            return Err("correlation ID should be non-zero".into());
        }
        Ok(())
    }

    /// Checks that the first recorded correlation ID equals `expected`.
    ///
    /// # Errors
    /// Returns an error if no ID was recorded or the first ID differs.
    pub fn verify_correlation_id(&self, expected: u64) -> TestResult {
        let id = self.get_first_correlation_id()?;
        if id != expected {
            return Err(format!("expected correlation ID {expected}, got {id}").into());
        }
        Ok(())
    }

    /// Checks that the captured response carries a correlation ID.
    ///
    /// NOTE: despite the name, this only checks that an ID is present; it does
    /// not compare it with the ID that was sent.
    ///
    /// # Errors
    /// Returns an error if no response was captured or it has no correlation ID.
    pub fn verify_response_correlation_matches(&self) -> TestResult {
        let response = self.response.as_ref().ok_or("no response captured")?;
        if response.correlation_id().is_none() {
            return Err("response should have correlation ID".into());
        }
        Ok(())
    }

    /// Checks that the most recent `call_correlated` did not fail.
    ///
    /// Any stored error (not only a correlation mismatch) fails this check, and
    /// the error is included in the message to aid diagnosis.
    ///
    /// # Errors
    /// Returns an error describing the stored `ClientError`, if any.
    pub fn verify_no_mismatch_error(&self) -> TestResult {
        match &self.last_error {
            Some(err) => Err(format!("unexpected error occurred: {err:?}").into()),
            None => Ok(()),
        }
    }

    /// Checks that the most recent `call_correlated` failed with
    /// [`ClientError::CorrelationMismatch`].
    ///
    /// # Errors
    /// Returns an error if no error was stored or it is a different variant.
    pub fn verify_mismatch_error(&self) -> TestResult {
        match &self.last_error {
            Some(ClientError::CorrelationMismatch { .. }) => Ok(()),
            Some(err) => Err(format!("expected CorrelationMismatch, got {err:?}").into()),
            None => Err("expected CorrelationMismatch error, but none occurred".into()),
        }
    }

    /// Checks that all recorded correlation IDs are distinct.
    ///
    /// Succeeds vacuously when no IDs have been recorded.
    ///
    /// # Errors
    /// Returns an error if any correlation ID appears more than once.
    pub fn verify_unique_correlation_ids(&self) -> TestResult {
        let mut sorted = self.sent_correlation_ids.clone();
        sorted.sort_unstable();
        sorted.dedup();
        if sorted.len() != self.sent_correlation_ids.len() {
            return Err("correlation IDs are not unique".into());
        }
        Ok(())
    }

    /// Checks that the captured response has the message ID and payload
    /// recorded by `set_envelope_with_payload`.
    ///
    /// # Errors
    /// Returns an error if no response or expectations were captured, or if
    /// the message ID or payload bytes differ.
    pub fn verify_response_matches_expected(&self) -> TestResult {
        let response = self.response.as_ref().ok_or("no response captured")?;
        let expected_id = self
            .expected_message_id
            .ok_or("expected message ID not set")?;
        let expected_payload = self
            .expected_payload
            .as_ref()
            .ok_or("expected payload not set")?;
        if response.id() != expected_id {
            return Err(format!("expected message ID {expected_id}, got {}", response.id()).into());
        }
        if response.payload_bytes() != expected_payload.as_bytes() {
            return Err(format!(
                "expected payload {:?}, got {:?}",
                expected_payload.as_bytes(),
                response.payload_bytes()
            )
            .into());
        }
        Ok(())
    }

    /// Aborts the server task if one is running; does nothing otherwise.
    pub fn abort_server(&mut self) {
        if let Some(handle) = self.server.take() {
            handle.abort();
        }
    }
}
/// Serves length-delimited frames on `framed` until the stream ends or a
/// failure occurs.
///
/// Each decoded frame is passed to `process_frame` together with `mode`, and
/// the bytes it returns are written back as a single frame. The loop stops,
/// after logging a warning, when a frame cannot be decoded, `process_frame`
/// returns `None`, or the response cannot be sent.
async fn run_frame_loop<T>(framed: &mut Framed<T, LengthDelimitedCodec>, mode: ServerMode)
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    while let Some(result) = framed.next().await {
        let bytes = match result {
            Ok(bytes) => bytes,
            Err(err) => {
                warn!("client messaging server failed to decode frame: {err}");
                break;
            }
        };
        let Some(response_bytes) = process_frame(mode, &bytes) else {
            warn!("client messaging server failed to process frame");
            break;
        };
        // Report send failures like the other exit paths instead of stopping silently.
        if let Err(err) = framed.send(Bytes::from(response_bytes)).await {
            warn!("client messaging server failed to send response: {err}");
            break;
        }
    }
}