use std::{
future::Future,
net::SocketAddr,
pin::Pin,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
};
use tokio::net::{TcpListener, TcpStream};
use crate::{
client::{ClientError, WireframeClient, WireframeClientBuilder},
serializer::{BincodeSerializer, MessageCompatibilitySerializer, Serializer},
};
/// Shared hook closure type produced by [`counting_hook`].
///
/// The closure takes a `T` by value and returns a boxed, pinned, `Send` future
/// that resolves to a `T`. The closure itself is `Send + Sync` and is held
/// behind an [`Arc`].
pub type CountingHookClosure<T> =
Arc<dyn Fn(T) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>;
/// Binds a TCP listener on `127.0.0.1:0` and spawns a task that accepts one
/// connection.
///
/// Returns the listener's local address (read back via `local_addr`) together
/// with the join handle of the accept task; the handle resolves to the accepted
/// [`TcpStream`].
///
/// # Panics
///
/// Panics if binding or reading the local address fails. The spawned task
/// panics if accepting the connection fails.
pub async fn spawn_listener() -> (SocketAddr, tokio::task::JoinHandle<TcpStream>) {
    let bound = TcpListener::bind("127.0.0.1:0").await;
    let listener = bound.expect("bind listener");
    let local_addr = listener.local_addr().expect("listener addr");

    // The listener is moved into the task so it stays alive until the single
    // accept completes.
    let accept_task = tokio::spawn(async move {
        listener
            .accept()
            .await
            .map(|(stream, _peer)| stream)
            .expect("accept client")
    });

    (local_addr, accept_task)
}
/// Connects a client built by `configure_builder` to a local listener and runs
/// `assert_option` against the connected client.
///
/// * `configure_builder` receives the builder from `WireframeClient::builder()`
///   and returns the builder to connect with.
/// * `assert_option` receives a shared reference to the connected client.
///
/// # Panics
///
/// Panics if the client cannot connect or the accept task cannot be joined,
/// in addition to any panic raised by `assert_option` itself.
pub async fn assert_builder_option<F, A>(configure_builder: F, assert_option: A)
where
    F: FnOnce(WireframeClientBuilder) -> WireframeClientBuilder,
    A: FnOnce(&WireframeClient<BincodeSerializer, crate::rewind_stream::RewindStream<TcpStream>>),
{
    let (server_addr, pending_accept) = spawn_listener().await;

    let connection = configure_builder(WireframeClient::builder())
        .connect(server_addr)
        .await;
    let connected = connection.expect("connect client");

    assert_option(&connected);

    // Join the accept task only after the assertion has run; the accepted
    // stream is kept until the end of the function.
    let _accepted = pending_accept.await.expect("join accept task");
}
/// Builds a client with `configure_builder`, connects it to a freshly spawned
/// local listener, and returns the connected client.
///
/// `configure_builder` receives the builder from `WireframeClient::builder()`
/// and returns a builder whose third type parameter is `C`.
///
/// The accept task is joined before returning; the accepted server-side stream
/// is not handed back to the caller.
///
/// # Panics
///
/// Panics if the client cannot connect or the accept task cannot be joined.
pub async fn test_with_client<F, C>(
    configure_builder: F,
) -> WireframeClient<BincodeSerializer, crate::rewind_stream::RewindStream<TcpStream>, C>
where
    F: FnOnce(WireframeClientBuilder) -> WireframeClientBuilder<BincodeSerializer, (), C>,
    C: Send + 'static,
{
    let (server_addr, pending_accept) = spawn_listener().await;

    let connection = configure_builder(WireframeClient::builder())
        .connect(server_addr)
        .await;
    let connected = connection.expect("connect client");

    let _accepted = pending_accept.await.expect("join accept task");
    connected
}
/// Creates a hook that counts how many times its future has been run.
///
/// Returns the shared counter (starting at zero) and a closure. Each call to
/// the closure yields a future that, when polled, increments the counter by one
/// and resolves to the value it was given, unchanged.
///
/// The increment happens inside the returned future, so merely calling the
/// closure without awaiting the result does not change the counter.
pub fn counting_hook<T>() -> (Arc<AtomicUsize>, CountingHookClosure<T>)
where
    T: Send + 'static,
{
    let invocations = Arc::new(AtomicUsize::new(0));
    let shared = Arc::clone(&invocations);

    let hook = move |value: T| -> Pin<Box<dyn Future<Output = T> + Send>> {
        // Each future owns its own handle to the counter so it can be `'static`.
        let tally = Arc::clone(&shared);
        Box::pin(async move {
            tally.fetch_add(1, Ordering::SeqCst);
            value
        })
    };

    (invocations, Arc::new(hook))
}
/// Connects a client, drops the server side of the connection, and asserts
/// that the client's next `receive` fails.
///
/// `configure_builder` receives the builder from `WireframeClient::builder()`
/// plus a clone of a fresh zeroed counter, and returns the builder to connect
/// with. The same counter is returned so the caller can inspect it afterwards.
/// NOTE(review): presumably the caller wires the counter into an error hook on
/// the builder; this helper never touches the count itself.
///
/// # Panics
///
/// Panics if the client cannot connect, the accept task cannot be joined, or
/// `receive` unexpectedly succeeds after the server stream is dropped.
pub async fn test_error_hook_on_disconnect<F, C>(configure_builder: F) -> Arc<AtomicUsize>
where
    F: FnOnce(
        WireframeClientBuilder,
        Arc<AtomicUsize>,
    ) -> WireframeClientBuilder<BincodeSerializer, (), C>,
    C: Send + 'static,
{
    let error_count = Arc::new(AtomicUsize::new(0));
    let (server_addr, pending_accept) = spawn_listener().await;

    let connection = configure_builder(WireframeClient::builder(), Arc::clone(&error_count))
        .connect(server_addr)
        .await;
    let mut client = connection.expect("connect client");

    // Drop the accepted stream straight away so the client sees a disconnect.
    drop(pending_accept.await.expect("join accept task"));

    let outcome: Result<Vec<u8>, ClientError> = client.receive().await;
    assert!(outcome.is_err(), "receive should fail after disconnect");

    error_count
}
/// Serializer whose `serialize` and `deserialize` always return an error.
///
/// Useful for exercising failure paths: neither method inspects its input.
pub struct FailingSerializer;

// No items are overridden here; the trait is implemented with an empty body.
impl MessageCompatibilitySerializer for FailingSerializer {}

impl Serializer for FailingSerializer {
    /// Always fails with an I/O error carrying a fixed message.
    fn serialize<M>(&self, _value: &M) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>
    where
        M: crate::message::EncodeWith<Self>,
    {
        let failure = std::io::Error::other("forced serialization failure");
        Err(failure.into())
    }

    /// Always fails with an I/O error carrying a fixed message; `_bytes` is ignored.
    fn deserialize<M>(
        &self,
        _bytes: &[u8],
    ) -> Result<(M, usize), Box<dyn std::error::Error + Send + Sync>>
    where
        M: crate::message::DecodeWith<Self>,
    {
        let failure = std::io::Error::other("forced deserialization failure");
        Err(failure.into())
    }
}
/// Generates a `#[tokio::test]` async function that awaits
/// `assert_builder_option` with the given arguments.
///
/// Arguments, in order: the name of the generated test function, the
/// builder-configuring expression, and the assertion expression. A trailing
/// comma is accepted.
macro_rules! socket_option_test {
    ($test_name:ident, $configure_builder:expr, $assert_option:expr $(,)?) => {
        #[tokio::test]
        async fn $test_name() {
            $crate::client::tests::helpers::assert_builder_option(
                $configure_builder,
                $assert_option,
            )
            .await;
        }
    };
}
pub(crate) use socket_option_test;