use super::{choke::*, TestResult, NUM_CLIENTS, NUM_CONCURRENT_CLIENTS};
use color_eyre::eyre::{bail, Context};
use std::{
borrow::Borrow,
io,
sync::mpsc::{channel, Sender},
thread,
};
pub fn drive_pair<T, Ld, Fl>(
leader: Ld,
leader_name: &str,
follower: Fl,
follower_name: &str,
) -> TestResult
where
T: Send,
Ld: FnOnce(Sender<T>) -> TestResult + Send,
Fl: FnOnce(T) -> TestResult,
{
thread::scope(|scope| {
let (sender, receiver) = channel();
let ltname = leader_name.to_lowercase();
let leading_thread = thread::Builder::new()
.name(ltname)
.spawn_scoped(scope, move || leader(sender))
.with_context(|| format!("{leader_name} thread launch failed"))?;
if let Ok(msg) = receiver.recv() {
let rslt = follower(msg);
exclude_deadconn(rslt)
.with_context(|| format!("{follower_name} exited early with error"))?;
}
let Ok(rslt) = leading_thread.join() else {
bail!("{leader_name} panicked");
};
exclude_deadconn(rslt).with_context(|| format!("{leader_name} exited early with error"))
})
}
/// Turns "the connection is dead" failures into success and passes everything else through.
///
/// An `Err` is replaced by `Ok(())` only when its root cause is an [`io::Error`] whose kind is
/// one of `ConnectionRefused`, `ConnectionReset`, `ConnectionAborted`, `NotConnected`,
/// `BrokenPipe`, `WriteZero` or `UnexpectedEof`. An `Ok` input, an error whose root cause is
/// not an `io::Error`, and an `io::Error` of any other kind are all returned unchanged.
///
/// NOTE(review): presumably these kinds are what one side observes when the other side has
/// already failed and hung up, so filtering them lets the original failure surface instead;
/// confirm against the callers' expectations.
#[rustfmt::skip]
fn exclude_deadconn(r: TestResult) -> TestResult {
    use io::ErrorKind as Kind;

    let report = match r {
        Ok(()) => return Ok(()),
        Err(report) => report,
    };

    // `kind()` returns a plain `Copy` value, so the borrow of `report` ends on this line and
    // the report itself can still be moved into the returned `Err` below.
    let root_kind = report
        .root_cause()
        .downcast_ref::<io::Error>()
        .map(io::Error::kind);

    let is_dead_connection = matches!(
        root_kind,
        Some(
            Kind::ConnectionRefused
                | Kind::ConnectionReset
                | Kind::ConnectionAborted
                | Kind::NotConnected
                | Kind::BrokenPipe
                | Kind::WriteZero
                | Kind::UnexpectedEof
        )
    );

    if is_dead_connection {
        Ok(())
    } else {
        Err(report)
    }
}
/// Drives one `server` against `NUM_CLIENTS` client threads using [`drive_pair`].
///
/// The server plays the leader role: it is called with the channel sender and `NUM_CLIENTS`.
/// Once it sends its message, the follower side spawns `NUM_CLIENTS` scoped threads named
/// `client 1`, `client 2`, … and calls `client` in each with a `&B` borrowed from that single
/// message.
///
/// Before each spawn a guard is taken from a `Choke` created with `NUM_CONCURRENT_CLIENTS`; the
/// guard is moved into the client thread and dropped when the client function returns.
/// NOTE(review): presumably `Choke::take` blocks while `NUM_CONCURRENT_CLIENTS` guards are
/// outstanding, which would cap the number of simultaneously running clients; confirm in the
/// `choke` module.
///
/// # Errors
/// On the client side: failure to spawn a client thread, a panicking client thread, or the
/// first client error encountered in join order. These, and any server-side failure, are then
/// reported according to the rules of [`drive_pair`].
///
/// # Panics
/// Panics if `NUM_CLIENTS` cannot be converted to `usize`.
pub fn drive_server_and_multiple_clients<T, B, Srv, Clt>(server: Srv, client: Clt) -> TestResult
where
    T: Send + Borrow<B>,
    B: Send + Sync + ?Sized,
    Srv: FnOnce(Sender<T>, u32) -> TestResult + Send,
    Clt: Fn(&B) -> TestResult + Send + Sync,
{
    let limiter = Choke::new(NUM_CONCURRENT_CLIENTS);

    let run_server = move |sender: Sender<T>| server(sender, NUM_CLIENTS);

    let run_clients = |msg: T| {
        thread::scope(|scope| {
            let expected = usize::try_from(NUM_CLIENTS).unwrap();
            let mut handles = Vec::with_capacity(expected);

            for n in 1..=NUM_CLIENTS {
                let tname = format!("client {n}");

                // Acquired on the spawning thread, before the client thread exists.
                let slot = limiter.take();

                let bclient = &client;
                let bmsg: &B = msg.borrow();

                // `move` so that the guard travels into the thread and is released only when
                // the client function has finished.
                let body = move || {
                    let _slot = slot;
                    bclient(bmsg)
                };

                let handle = thread::Builder::new()
                    .name(tname.clone())
                    .spawn_scoped(scope, body)
                    .with_context(|| format!("{tname} thread launch failed"))?;
                handles.push(handle);
            }

            for handle in handles {
                match handle.join() {
                    // Return the first error as-is; threads not yet joined here are still
                    // waited for when the scope ends.
                    Ok(outcome) => outcome?,
                    Err(_panic_payload) => {
                        bail!("client thread panicked");
                    }
                }
            }
            Ok(())
        })
    };

    drive_pair(run_server, "server", run_clients, "client")
}