use std::{
net::SocketAddr,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
};
use futures::StreamExt;
use rstest::fixture;
use tokio::{io::AsyncWriteExt, net::TcpListener, task::JoinHandle};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use super::streaming_infra::TestServer;
use crate::client::WireframeClient;
/// Frame-length limit that [`create_send_client`] forwards to the client
/// builder's `max_frame_length` setting.
pub(super) const DEFAULT_MAX_FRAME: usize = 1024;
/// Fixture yielding the fixed four-byte header `CA FE BA BE` as an owned
/// vector.
#[rustfmt::skip]
#[fixture]
pub(super) fn protocol_header() -> Vec<u8> {
    const HEADER: [u8; 4] = [0xCA, 0xFE, 0xBA, 0xBE];
    HEADER.to_vec()
}
/// Handle to a background server task that gathers the frames it receives.
///
/// Created by [`spawn_receiving_server`]; the gathered frames are retrieved
/// with [`ReceivingServer::collect_frames`].
pub(super) struct ReceivingServer {
/// Local address the server's listener is bound to.
pub addr: SocketAddr,
/// Join handle of the collecting task; becomes `None` once it has been
/// taken by `collect_frames` or by `Drop`.
handle: Option<JoinHandle<Vec<Vec<u8>>>>,
}
impl ReceivingServer {
    /// Waits for the server task to finish and returns every frame it
    /// gathered, in arrival order.
    ///
    /// The join handle is consumed by the first call, so the frames can be
    /// collected only once.
    ///
    /// # Errors
    ///
    /// Returns an error when the handle has already been taken by an earlier
    /// call, or when joining the server task fails.
    pub(super) async fn collect_frames(
        &mut self,
    ) -> Result<Vec<Vec<u8>>, Box<dyn std::error::Error + Send + Sync>> {
        // Taking the handle leaves `None` behind, which is what makes a
        // second call fail instead of awaiting a finished task again.
        let Some(task) = self.handle.take() else {
            return Err("server handle already collected".into());
        };
        match task.await {
            Ok(frames) => Ok(frames),
            Err(join_err) => Err(format!("server task panicked: {join_err}").into()),
        }
    }
}
impl Drop for ReceivingServer {
    /// Aborts the server task if its handle was never collected, so the task
    /// does not keep running after the owning value is gone.
    fn drop(&mut self) {
        let Some(task) = self.handle.take() else {
            // Already collected: nothing left to cancel.
            return;
        };
        task.abort();
    }
}
/// Starts a loopback server that accepts one connection and records every
/// length-delimited frame received on it.
///
/// The listener binds to `127.0.0.1` on an OS-assigned port. The spawned task
/// stops reading at the first decode error or when the stream ends, and then
/// yields the frames gathered so far; if accepting the connection fails it
/// yields an empty list.
///
/// # Errors
///
/// Returns an error when binding the listener or reading its local address
/// fails.
pub(super) async fn spawn_receiving_server()
-> Result<ReceivingServer, Box<dyn std::error::Error + Send + Sync>> {
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let addr = listener.local_addr()?;

    let handle = tokio::spawn(async move {
        let mut frames = Vec::new();
        if let Ok((socket, _peer)) = listener.accept().await {
            let mut framed = Framed::new(socket, LengthDelimitedCodec::new());
            // Stop on either end-of-stream (`None`) or a codec error.
            while let Some(Ok(frame)) = framed.next().await {
                frames.push(frame.to_vec());
            }
        }
        frames
    });

    let server = ReceivingServer {
        addr,
        handle: Some(handle),
    };
    Ok(server)
}
pub(super) async fn spawn_dropping_server()
-> Result<(TestServer, Arc<tokio::sync::Notify>), Box<dyn std::error::Error + Send + Sync>> {
let listener = TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let shutdown_done = Arc::new(tokio::sync::Notify::new());
let notify = shutdown_done.clone();
let handle = tokio::spawn(async move {
let Ok((mut tcp, _)) = listener.accept().await else {
return;
};
let _ = tcp.shutdown().await;
notify.notify_one();
});
Ok((TestServer::from_handle(addr, handle), shutdown_done))
}
/// Connects a client to `addr` using [`DEFAULT_MAX_FRAME`] as the frame-length
/// limit.
///
/// # Errors
///
/// Propagates any error from [`create_send_client_with_max_frame`].
pub(super) async fn create_send_client(
    addr: SocketAddr,
) -> Result<
    WireframeClient<
        crate::serializer::BincodeSerializer,
        crate::rewind_stream::RewindStream<tokio::net::TcpStream>,
    >,
    Box<dyn std::error::Error + Send + Sync>,
> {
    let client = create_send_client_with_max_frame(addr, DEFAULT_MAX_FRAME).await?;
    Ok(client)
}
/// Connects a client to `addr` with the builder's `max_frame_length` set to
/// the given value.
///
/// # Errors
///
/// Returns the connection error, boxed, when the builder fails to connect.
pub(super) async fn create_send_client_with_max_frame(
    addr: SocketAddr,
    max_frame_length: usize,
) -> Result<
    WireframeClient<
        crate::serializer::BincodeSerializer,
        crate::rewind_stream::RewindStream<tokio::net::TcpStream>,
    >,
    Box<dyn std::error::Error + Send + Sync>,
> {
    let builder = WireframeClient::builder().max_frame_length(max_frame_length);
    let client = builder.connect(addr).await?;
    Ok(client)
}
/// Connects a client to `addr` with an `on_error` hook that records its own
/// invocation.
///
/// Returns the client together with a shared flag that starts as `false` and
/// is set to `true` whenever the hook's future runs.
///
/// # Errors
///
/// Returns the connection error, boxed, when the builder fails to connect.
pub(super) async fn create_send_client_with_error_hook(
    addr: SocketAddr,
) -> Result<
    (
        WireframeClient<
            crate::serializer::BincodeSerializer,
            crate::rewind_stream::RewindStream<tokio::net::TcpStream>,
        >,
        Arc<AtomicBool>,
    ),
    Box<dyn std::error::Error + Send + Sync>,
> {
    let hook_invoked = Arc::new(AtomicBool::new(false));

    let client = {
        let observed = Arc::clone(&hook_invoked);
        WireframeClient::builder()
            .on_error(move |_err| {
                // Each invocation gets its own handle so the returned future
                // owns everything it touches.
                let observed = Arc::clone(&observed);
                async move {
                    observed.store(true, Ordering::SeqCst);
                }
            })
            .connect(addr)
            .await?
    };

    Ok((client, hook_invoked))
}
/// Builds a deterministic payload of `n` bytes in which the byte at index `i`
/// equals `i % 256`, so the pattern `0, 1, …, 255` repeats.
#[expect(
    clippy::integer_division_remainder_used,
    reason = "modulo generates a deterministic test byte pattern"
)]
#[expect(
    clippy::cast_possible_truncation,
    reason = "value is modulo 256, guaranteed to fit in u8"
)]
pub(super) fn test_body(n: usize) -> Vec<u8> {
    let byte_at = |offset: usize| (offset % 256) as u8;
    (0..n).map(byte_at).collect()
}
/// Item sent through the reader's channel: a chunk of bytes or an I/O error.
type ByteResult = Result<bytes::Bytes, std::io::Error>;
/// Stream view over the receiving half of a channel of [`ByteResult`] items.
type BlockingStream = tokio_stream::wrappers::ReceiverStream<ByteResult>;
/// Pair produced by [`blocking_reader`]: the reader over the channel-backed
/// stream, and the sender used to feed that stream.
pub(super) type BlockingReader = (
tokio_util::io::StreamReader<BlockingStream, bytes::Bytes>,
tokio::sync::mpsc::Sender<ByteResult>,
);
/// Creates a reader fed by a bounded channel of capacity 1, returned together
/// with the sending half of that channel.
///
/// The reader's data comes only from items pushed through the returned sender.
// NOTE(review): the name suggests tests rely on reads staying pending until
// something is sent — confirm against callers.
pub(super) fn blocking_reader() -> BlockingReader {
    let (sender, receiver) = tokio::sync::mpsc::channel::<ByteResult>(1);
    let reader = tokio_util::io::StreamReader::new(
        tokio_stream::wrappers::ReceiverStream::new(receiver),
    );
    (reader, sender)
}