use {
super::{choke::*, num_clients, num_concurrent_clients, TestResult},
color_eyre::eyre::{bail, Context},
std::{
borrow::Borrow,
io,
sync::mpsc::{channel, Sender},
thread,
},
};
pub fn drive_pair<T, Ld, Fl>(
leader: Ld,
leader_name: &str,
follower: Fl,
follower_name: &str,
) -> TestResult
where
T: Send,
Ld: FnOnce(Sender<T>) -> TestResult + Send,
Fl: FnOnce(T) -> TestResult,
{
thread::scope(|scope| {
let (sender, receiver) = channel();
let ltname = leader_name.to_lowercase();
let leading_thread = thread::Builder::new()
.name(ltname)
.spawn_scoped(scope, move || leader(sender))
.with_context(|| format!("{leader_name} thread launch failed"))?;
if let Ok(msg) = receiver.recv() {
let rslt = follower(msg);
exclude_deadconn(rslt)
.with_context(|| format!("{follower_name} exited early with error"))?;
}
let Ok(rslt) = leading_thread.join() else {
bail!("{leader_name} panicked");
};
exclude_deadconn(rslt).with_context(|| format!("{leader_name} exited early with error"))
})
}
/// Turns "the other side hung up" failures into success.
///
/// `Ok` values are returned unchanged. An `Err` is replaced with `Ok(())` when its root cause is an
/// `io::Error` of one of these kinds: `ConnectionRefused`, `ConnectionReset`,
/// `ConnectionAborted`, `NotConnected`, `BrokenPipe`, `WriteZero` or `UnexpectedEof`. Every other
/// error is returned as is.
#[rustfmt::skip] fn exclude_deadconn(r: TestResult) -> TestResult {
    use io::ErrorKind as K;
    match r {
        Err(e) => match e.root_cause().downcast_ref::<io::Error>().map(io::Error::kind) {
            Some(
                K::ConnectionRefused
                | K::ConnectionReset
                | K::ConnectionAborted
                | K::NotConnected
                | K::BrokenPipe
                | K::WriteZero
                | K::UnexpectedEof,
            ) => Ok(()),
            _ => Err(e),
        },
        passthrough => passthrough,
    }
}
pub fn drive_server_and_multiple_clients<T, B, Srv, Clt, Fclt>(
server: Srv,
first_client: Fclt,
client: Clt,
) -> TestResult<T>
where
T: Send + Borrow<B>,
B: Send + Sync + ?Sized,
Srv: FnOnce(Sender<T>, u32) -> TestResult + Send,
Fclt: FnOnce(&B) -> TestResult + Send,
Clt: Fn(&B) -> TestResult + Send + Sync,
{
let choke = Choke::new(num_concurrent_clients());
let num_clients = num_clients();
let mut t_return = None::<T>;
let client_wrapper = |msg: T| {
first_client(msg.borrow())?;
let ret = thread::scope(|scope| {
let mut client_threads = Vec::with_capacity(usize::try_from(num_clients).unwrap());
for n in 1..=num_clients {
let tname = format!("client {n}");
let choke_guard = choke.take();
let (bclient, bmsg) = (&client, msg.borrow());
let jhndl = thread::Builder::new()
.name(tname.clone())
.spawn_scoped(scope, move || {
let _cg = choke_guard;
bclient(bmsg)
})
.with_context(|| format!("{tname} thread launch failed"))?;
client_threads.push(jhndl);
}
for client in client_threads {
let Ok(rslt) = client.join() else {
bail!("client thread panicked");
};
rslt?; }
Ok(())
});
t_return = Some(msg);
ret
};
let server_wrapper = move |sender: Sender<T>| server(sender, num_clients);
drive_pair(server_wrapper, "server", client_wrapper, "client")?;
Ok(t_return.unwrap())
}