#![expect(
clippy::excessive_nesting,
reason = "spawned server tasks naturally nest several levels deep"
)]
use std::{
net::SocketAddr,
sync::{
Arc,
Mutex,
atomic::{AtomicUsize, Ordering},
},
};
use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use rstest::fixture;
use tokio::{net::TcpListener, task::JoinHandle};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use wireframe::{
app::Envelope,
client::WireframeClient,
rewind_stream::RewindStream,
serializer::BincodeSerializer,
};
pub use wireframe_testing::TestResult;
use wireframe_testing::{ServerMode, process_frame};
/// Byte pushed onto the outgoing buffer by the hook that
/// `connect_with_before_send_marker` installs; also returned by `marker_byte()`.
const MARKER_BYTE: u8 = 0xff;
/// Concrete client type held by the world: a `WireframeClient` parameterised
/// with `BincodeSerializer` over a `RewindStream`-wrapped `TcpStream`.
type TestClient = WireframeClient<BincodeSerializer, RewindStream<tokio::net::TcpStream>>;
/// Shared state for client request-hook test scenarios.
///
/// Holds the spawned server task handles, the connected client, and the
/// counters/logs that the installed hooks write to so tests can inspect them.
#[derive(Debug, Default)]
pub struct ClientRequestHooksWorld {
/// Address of the most recently started server; set by the `start_*` methods.
addr: Option<SocketAddr>,
/// Handle of the task spawned by `start_echo_server`.
server: Option<JoinHandle<()>>,
/// Handle of the task spawned by `start_capturing_server`; it resolves to
/// the frames that task received.
capturing_server: Option<JoinHandle<Vec<Vec<u8>>>>,
/// Client stored by the `connect_with_*` methods.
client: Option<TestClient>,
/// Incremented by the counting `before_send` hooks.
before_send_count: Arc<AtomicUsize>,
/// Incremented by the counting `after_receive` hooks.
after_receive_count: Arc<AtomicUsize>,
/// Bytes (`b'A'`, `b'B'`) appended by the hooks from `connect_with_marker_hooks`.
marker_log: Arc<Mutex<Vec<u8>>>,
/// Payload bytes of the response stored by `send_and_receive_envelope`.
received_payload: Option<Vec<u8>>,
}
impl Drop for ClientRequestHooksWorld {
    /// Aborts any server tasks the world still owns.
    fn drop(&mut self) {
        // Take both handles out first; each is aborted only if it was present.
        let echo_task = self.server.take();
        let capturing_task = self.capturing_server.take();
        if let Some(task) = echo_task {
            task.abort();
        }
        if let Some(task) = capturing_task {
            task.abort();
        }
    }
}
/// `rstest` fixture that yields a `ClientRequestHooksWorld` built via `Default`.
#[rustfmt::skip]
#[fixture]
pub fn client_request_hooks_world() -> ClientRequestHooksWorld {
ClientRequestHooksWorld::default()
}
impl ClientRequestHooksWorld {
/// Binds a listener on `127.0.0.1:0` and spawns a single-connection echo task.
///
/// The task accepts one connection, reads length-delimited frames, passes
/// each through `process_frame` with `ServerMode::Echo`, and sends back the
/// bytes that call returns. It stops when the stream ends, a frame fails to
/// decode, `process_frame` returns `None`, or a send fails.
///
/// The bound address is recorded in `self.addr`. If an echo task from an
/// earlier call is still held, it is aborted before being replaced.
///
/// # Errors
///
/// Returns an error if binding the listener or reading its local address
/// fails.
pub async fn start_echo_server(&mut self) -> TestResult {
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let addr = listener.local_addr()?;
    let handle = tokio::spawn(async move {
        let Ok((stream, _)) = listener.accept().await else {
            return;
        };
        let mut framed = Framed::new(stream, LengthDelimitedCodec::new());
        while let Some(Ok(bytes)) = framed.next().await {
            let Some(response_bytes) = process_frame(ServerMode::Echo, &bytes) else {
                break;
            };
            if framed.send(Bytes::from(response_bytes)).await.is_err() {
                break;
            }
        }
    });
    self.addr = Some(addr);
    // Dropping a `JoinHandle` only detaches its task, which would put an
    // earlier echo task beyond the reach of `Drop`; abort it explicitly.
    if let Some(previous) = self.server.replace(handle) {
        previous.abort();
    }
    Ok(())
}
/// Binds a listener on `127.0.0.1:0` and spawns a task that records and
/// echoes every frame it receives.
///
/// The task accepts one connection, copies each length-delimited frame into
/// a list, and sends the frame back unchanged. It stops when the stream
/// ends, a frame fails to decode, or a send fails, and then resolves to the
/// recorded frames (an empty list if `accept` failed).
///
/// The bound address is recorded in `self.addr`. If a capturing task from an
/// earlier call is still held, it is aborted before being replaced.
///
/// # Errors
///
/// Returns an error if binding the listener or reading its local address
/// fails.
pub async fn start_capturing_server(&mut self) -> TestResult {
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let addr = listener.local_addr()?;
    let handle = tokio::spawn(async move {
        let Ok((stream, _)) = listener.accept().await else {
            return Vec::new();
        };
        let mut framed = Framed::new(stream, LengthDelimitedCodec::new());
        let mut captured = Vec::new();
        while let Some(Ok(bytes)) = framed.next().await {
            captured.push(bytes.to_vec());
            if framed.send(bytes.freeze()).await.is_err() {
                break;
            }
        }
        captured
    });
    self.addr = Some(addr);
    // Dropping a `JoinHandle` only detaches its task, which would put an
    // earlier capturing task beyond the reach of `Drop`; abort it explicitly.
    if let Some(previous) = self.capturing_server.replace(handle) {
        previous.abort();
    }
    Ok(())
}
/// Builds a client with `configure` applied to a fresh builder, connects it
/// to the recorded server address, and stores it in `self.client`.
///
/// # Errors
///
/// Returns `"server address missing"` when no address has been recorded, or
/// the error produced by `connect`.
async fn connect_with_hooks<F>(&mut self, configure: F) -> TestResult
where
    F: FnOnce(
        wireframe::client::WireframeClientBuilder,
    ) -> wireframe::client::WireframeClientBuilder,
{
    let Some(server_addr) = self.addr else {
        return Err("server address missing".into());
    };
    let builder = configure(WireframeClient::builder());
    self.client = Some(builder.connect(server_addr).await?);
    Ok(())
}
/// Connects a client whose `before_send` hook increments
/// `before_send_count` by one per invocation.
///
/// # Errors
///
/// Propagates any error from `connect_with_hooks`.
pub async fn connect_with_before_send_counter(&mut self) -> TestResult {
    let sends = Arc::clone(&self.before_send_count);
    let bump = move |_bytes: &mut Vec<u8>| {
        sends.fetch_add(1, Ordering::SeqCst);
    };
    self.connect_with_hooks(|builder| builder.before_send(bump))
        .await
}
/// Connects a client whose `after_receive` hook increments
/// `after_receive_count` by one per invocation.
///
/// # Errors
///
/// Propagates any error from `connect_with_hooks`.
pub async fn connect_with_after_receive_counter(&mut self) -> TestResult {
    let receives = Arc::clone(&self.after_receive_count);
    let bump = move |_frame: &mut bytes::BytesMut| {
        receives.fetch_add(1, Ordering::SeqCst);
    };
    self.connect_with_hooks(|builder| builder.after_receive(bump))
        .await
}
/// Connects a client with two `before_send` hooks registered in order: the
/// first appends `b'A'` to `marker_log`, the second appends `b'B'`.
///
/// A hook that cannot acquire the log's lock records nothing.
///
/// # Errors
///
/// Propagates any error from `connect_with_hooks`.
pub async fn connect_with_marker_hooks(&mut self) -> TestResult {
    let first_log = Arc::clone(&self.marker_log);
    let second_log = Arc::clone(&self.marker_log);
    let record_a = move |_bytes: &mut Vec<u8>| {
        if let Ok(mut entries) = first_log.lock() {
            entries.push(b'A');
        }
    };
    let record_b = move |_bytes: &mut Vec<u8>| {
        if let Ok(mut entries) = second_log.lock() {
            entries.push(b'B');
        }
    };
    self.connect_with_hooks(|builder| builder.before_send(record_a).before_send(record_b))
        .await
}
/// Connects a client with a counting `before_send` hook and a counting
/// `after_receive` hook, each incrementing its own counter on the world.
///
/// # Errors
///
/// Propagates any error from `connect_with_hooks`.
pub async fn connect_with_both_counters(&mut self) -> TestResult {
    let sends = Arc::clone(&self.before_send_count);
    let receives = Arc::clone(&self.after_receive_count);
    let on_send = move |_bytes: &mut Vec<u8>| {
        sends.fetch_add(1, Ordering::SeqCst);
    };
    let on_receive = move |_frame: &mut bytes::BytesMut| {
        receives.fetch_add(1, Ordering::SeqCst);
    };
    self.connect_with_hooks(|builder| builder.before_send(on_send).after_receive(on_receive))
        .await
}
/// Connects a client whose `before_send` hook pushes `MARKER_BYTE` onto the
/// byte buffer it is handed.
///
/// # Errors
///
/// Propagates any error from `connect_with_hooks`.
pub async fn connect_with_before_send_marker(&mut self) -> TestResult {
    let append_marker = |outgoing: &mut Vec<u8>| {
        outgoing.push(MARKER_BYTE);
    };
    self.connect_with_hooks(|builder| builder.before_send(append_marker))
        .await
}
/// Connects a client whose `after_receive` hook overwrites every received
/// buffer with the bincode encoding of
/// `Envelope::new(42, Some(1), vec![99, 98, 97])`.
///
/// # Errors
///
/// Returns an error if encoding the replacement envelope fails, or any error
/// from `connect_with_hooks`.
pub async fn connect_with_after_receive_replacement(&mut self) -> TestResult {
    let canned = Envelope::new(42, Some(1), vec![99, 98, 97]);
    let canned_bytes = bincode::encode_to_vec(&canned, bincode::config::standard())?;
    // Encoded once up front; the hook only copies the prepared bytes.
    let overwrite = move |frame: &mut bytes::BytesMut| {
        frame.clear();
        frame.extend_from_slice(&canned_bytes);
    };
    self.connect_with_hooks(move |builder| builder.after_receive(overwrite))
        .await
}
/// Sends the fixed envelope `Envelope::new(1, None, vec![1, 2, 3])` through
/// the stored client without waiting for a reply.
///
/// # Errors
///
/// Returns `"client missing"` when no client is stored, or the error from
/// the client's `send_envelope`.
pub async fn send_envelope(&mut self) -> TestResult {
    let Some(client) = self.client.as_mut() else {
        return Err("client missing".into());
    };
    client
        .send_envelope(Envelope::new(1, None, vec![1, 2, 3]))
        .await?;
    Ok(())
}
/// Sends the fixed envelope `Envelope::new(1, None, vec![1, 2, 3])`, waits
/// for one response envelope, and stores its payload bytes in
/// `received_payload`.
///
/// # Errors
///
/// Returns `"client missing"` when no client is stored, or the error from
/// sending or receiving.
pub async fn send_and_receive_envelope(&mut self) -> TestResult {
    let Some(client) = self.client.as_mut() else {
        return Err("client missing".into());
    };
    client
        .send_envelope(Envelope::new(1, None, vec![1, 2, 3]))
        .await?;
    let reply: Envelope = client.receive_envelope().await?;
    let payload = reply.payload_bytes().to_vec();
    self.received_payload = Some(payload);
    Ok(())
}
/// Issues `call_correlated` with the fixed request
/// `Envelope::new(1, None, vec![10, 20])` and discards the response.
///
/// # Errors
///
/// Returns `"client missing"` when no client is stored, or the error from
/// `call_correlated`.
pub async fn perform_correlated_call(&mut self) -> TestResult {
    let Some(client) = self.client.as_mut() else {
        return Err("client missing".into());
    };
    // The annotation names the response type; the value itself is unused.
    let _reply: Envelope = client
        .call_correlated(Envelope::new(1, None, vec![10, 20]))
        .await?;
    Ok(())
}
/// Drops the stored client, then waits for the capturing server task and
/// returns the frames it recorded.
///
/// # Errors
///
/// Returns `"capturing server missing"` when no capturing task is held, or
/// the join error if the task did not complete normally.
pub async fn collect_captured_frames(
    &mut self,
) -> Result<Vec<Vec<u8>>, Box<dyn std::error::Error + Send + Sync>> {
    // Release the client before joining — presumably this closes the
    // connection so the server's read loop ends; confirm against the client.
    drop(self.client.take());
    let Some(server_task) = self.capturing_server.take() else {
        return Err("capturing server missing".into());
    };
    server_task.await.map_err(Into::into)
}
/// Returns the current value of the `before_send` invocation counter.
#[must_use]
pub fn before_send_count(&self) -> usize {
    let counter: &AtomicUsize = &self.before_send_count;
    counter.load(Ordering::SeqCst)
}
/// Returns the current value of the `after_receive` invocation counter.
#[must_use]
pub fn after_receive_count(&self) -> usize {
    let counter: &AtomicUsize = &self.after_receive_count;
    counter.load(Ordering::SeqCst)
}
/// Returns a copy of the bytes recorded in the marker log.
///
/// # Errors
///
/// Returns the poison error's message if the log's mutex is poisoned.
pub fn marker_log(&self) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
    match self.marker_log.lock() {
        Ok(entries) => Ok(entries.clone()),
        Err(poisoned) => Err(poisoned.to_string().into()),
    }
}
/// Returns the payload bytes stored by `send_and_receive_envelope`, or
/// `None` if no payload has been stored.
#[must_use]
pub fn received_payload(&self) -> Option<&[u8]> { self.received_payload.as_deref() }
/// Returns `MARKER_BYTE`, the byte pushed by the hook that
/// `connect_with_before_send_marker` installs.
#[must_use]
pub fn marker_byte() -> u8 { MARKER_BYTE }
}