pub mod builder;
mod context;
#[cfg(feature = "affinity")]
pub mod cores;
#[allow(clippy::module_inception)]
mod hive;
mod husk;
mod inner;
pub mod mock;
mod outcome;
mod sentinel;
mod util;
#[cfg(feature = "local-batch")]
mod weighted;
pub use self::builder::{BeeBuilder, ChannelBuilder, FullBuilder, OpenBuilder, TaskQueuesBuilder};
pub use self::builder::{
channel as channel_builder, open as open_builder, workstealing as workstealing_builder,
};
#[cfg(feature = "affinity")]
pub use self::cores::{Core, Cores};
pub use self::hive::{DefaultHive, Hive, Poisoned};
pub use self::husk::Husk;
pub use self::inner::{
Builder, ChannelTaskQueues, TaskInput, TaskQueues, WorkstealingTaskQueues, set_config::*,
};
pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore};
#[cfg(feature = "local-batch")]
pub use self::weighted::{Weighted, WeightedIteratorExt};
use self::context::HiveLocalContext;
use self::inner::{Config, Shared, Task, WorkerQueues};
use self::outcome::{DerefOutcomes, OutcomeQueue, OwnedOutcomes};
use self::sentinel::Sentinel;
use crate::bee::Worker;
use crate::channel::{Receiver, Sender, channel};
use std::io::Error as SpawnError;
/// The sending half of a channel used to deliver task [`Outcome`]s for workers of type `W`.
pub type OutcomeSender<W> = Sender<Outcome<W>>;
/// The receiving half of a channel used to deliver task [`Outcome`]s for workers of type `W`.
pub type OutcomeReceiver<W> = Receiver<Outcome<W>>;
/// Creates a new channel pair for transmitting the [`Outcome`]s of tasks executed by
/// workers of type `W`. Thin wrapper around the crate's default `channel()` constructor.
#[inline]
pub fn outcome_channel<W: Worker>() -> (OutcomeSender<W>, OutcomeReceiver<W>) {
channel()
}
/// Commonly used types and functions, re-exported for convenient glob import
/// (`use …::hive::prelude::*;`).
pub mod prelude {
pub use super::{
Builder, Hive, Husk, Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore, Poisoned,
TaskQueuesBuilder, channel_builder, open_builder, outcome_channel, workstealing_builder,
};
// Weighted-task support is only compiled with the `local-batch` feature.
#[cfg(feature = "local-batch")]
pub use super::{Weighted, WeightedIteratorExt};
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use super::inner::TaskQueues;
use super::{
Builder, ChannelTaskQueues, Hive, Outcome, OutcomeIteratorExt, OutcomeStore,
TaskQueuesBuilder, WorkstealingTaskQueues, channel_builder, workstealing_builder,
};
use crate::barrier::IndexedBarrier;
use crate::bee::stock::{Caller, OnceCaller, RefCaller, Thunk, ThunkWorker};
use crate::bee::{
ApplyError, ApplyRefError, Context, DefaultQueen, QueenMut, RefWorker, RefWorkerResult,
TaskId, Worker, WorkerResult,
};
use crate::channel::{Message, ReceiverExt};
use crate::hive::outcome::DerefOutcomes;
use rstest::*;
use std::fmt::Debug;
use std::io::{self, BufRead, BufReader, Write};
use std::process::{Child, ChildStdin, ChildStdout, Command, ExitStatus, Stdio};
use std::sync::{
Arc, Barrier,
atomic::{AtomicUsize, Ordering},
mpsc,
};
use std::thread;
use std::time::Duration;
// Default number of worker threads (and often of submitted tasks) used by most tests.
const TEST_TASKS: usize = 4;
// Sleep durations used to let the hive reach a steady state before asserting;
// these tests are timing-based, so the margins are intentionally generous.
const ONE_SEC: Duration = Duration::from_secs(1);
const SHORT_TASK: Duration = Duration::from_secs(2);
const LONG_TASK: Duration = Duration::from_secs(5);
// Shorthand for a worker that executes `Thunk`s producing values of type `I`.
type TWrk<I> = ThunkWorker<I>;
/// Builds a `Hive` with `num_threads` workers executing `Thunk`s that return values of
/// type `I`, using the task-queue implementation selected by `builder`.
pub fn thunk_hive<I, T, B>(num_threads: usize, builder: B) -> Hive<DefaultQueen<TWrk<I>>, T>
where
I: Send + Sync + Debug + 'static,
T: TaskQueues<TWrk<I>>,
B: TaskQueuesBuilder<TaskQueues<TWrk<I>> = T>,
{
builder
.num_threads(num_threads)
.with_queen_default()
.build()
}
/// Convenience wrapper around [`thunk_hive`] for thunks that return `()` (side-effect-only
/// tasks).
pub fn void_thunk_hive<T, B>(num_threads: usize, builder: B) -> Hive<DefaultQueen<TWrk<()>>, T>
where
T: TaskQueues<TWrk<()>>,
B: TaskQueuesBuilder<TaskQueues<TWrk<()>> = T>,
{
thunk_hive(num_threads, builder)
}
/// Smoke test: a freshly built hive reports the expected worker counts and
/// executes every submitted task exactly once.
#[rstest]
fn test_works<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
    B: TaskQueuesBuilder,
    F: Fn(bool) -> B,
{
    let hive = thunk_hive(TEST_TASKS, builder_factory(true));
    // All workers should be alive immediately after construction.
    assert_eq!(hive.max_workers(), TEST_TASKS);
    assert_eq!(hive.alive_workers(), TEST_TASKS);
    assert!(!hive.has_dead_workers());
    // Each task reports completion by sending a `1` back to the test thread.
    let (sender, receiver) = mpsc::channel();
    (0..TEST_TASKS).for_each(|_| {
        let sender = sender.clone();
        hive.apply_store(Thunk::from(move || {
            sender.send(1).unwrap();
        }));
    });
    let completed: usize = receiver.iter().take(TEST_TASKS).sum();
    assert_eq!(completed, TEST_TASKS);
}
// A hive started with zero threads must queue tasks until `grow` adds a worker.
#[rstest]
fn test_grow_from_zero<B, F>(
#[values(channel_builder, workstealing_builder)] builder_factory: F,
) where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = thunk_hive::<u8, _, _>(0, builder_factory(true));
let (tx, rx) = super::outcome_channel();
let _ = hive.apply_send(Thunk::from(|| 0), &tx);
thread::sleep(ONE_SEC);
// No workers yet: the task remains queued and no outcome has been delivered.
assert_eq!(hive.num_tasks().0, 1);
assert!(matches!(rx.try_recv_msg(), Message::ChannelEmpty));
// Growing by zero threads is a no-op and must not start the task.
assert!(matches!(hive.grow(0), Ok(0)));
thread::sleep(ONE_SEC);
assert_eq!(hive.num_tasks().0, 1);
// Adding one worker lets the queued task run to completion.
assert!(matches!(hive.grow(1), Ok(1)));
thread::sleep(ONE_SEC);
assert_eq!(hive.num_tasks().0, 0);
assert!(matches!(
rx.try_recv_msg(),
Message::Received(Outcome::Success { value: 0, .. })
));
}
// `grow` adds workers to an already-running hive so additional tasks can
// execute concurrently with the ones already in flight.
#[rstest]
fn test_grow_from_nonzero<B, F>(
#[values(channel_builder, workstealing_builder)] builder_factory: F,
) where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = void_thunk_hive(TEST_TASKS, builder_factory(false));
// Saturate the initial workers with long-running tasks.
for _ in 0..TEST_TASKS {
hive.apply_store(Thunk::from(|| thread::sleep(LONG_TASK)));
}
thread::sleep(ONE_SEC);
assert_eq!(hive.num_tasks().1, TEST_TASKS as u64);
let new_threads = 4;
let total_threads = new_threads + TEST_TASKS;
hive.grow(new_threads).expect("error spawning threads");
// The newly spawned workers should pick up these additional tasks.
for _ in 0..new_threads {
hive.apply_store(Thunk::from(|| thread::sleep(LONG_TASK)));
}
thread::sleep(ONE_SEC);
assert_eq!(hive.num_tasks().1, total_threads as u64);
// Converting into a husk waits for the tasks; all of them must have succeeded.
let husk = hive.try_into_husk(false).unwrap();
assert_eq!(husk.iter_successes().count(), total_threads);
}
// `use_all_cores` grows a zero-thread hive to one worker per logical CPU.
#[rstest]
fn test_use_all_cores<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = void_thunk_hive(0, builder_factory(false));
let num_cores = num_cpus::get();
for _ in 0..num_cores {
hive.apply_store(Thunk::from(|| thread::sleep(LONG_TASK)));
}
thread::sleep(ONE_SEC);
// With zero workers, every submitted task is still queued.
assert_eq!(hive.num_tasks().0, num_cores as u64);
// Growing to all cores should report the number of threads spawned.
assert_eq!(hive.use_all_cores().unwrap(), num_cores);
assert_eq!(hive.max_workers(), num_cores);
thread::sleep(ONE_SEC);
let husk = hive.try_into_husk(false).unwrap();
assert_eq!(husk.iter_successes().count(), num_cores);
}
// `suspend` stops workers from starting queued tasks (tasks already in flight
// still finish); `resume` lets the remaining queued tasks run.
#[rstest]
fn test_suspend<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = void_thunk_hive(TEST_TASKS, builder_factory(false));
let total_tasks = 2 * TEST_TASKS;
for _ in 0..total_tasks {
hive.apply_store(Thunk::from(|| thread::sleep(SHORT_TASK)));
}
thread::sleep(ONE_SEC);
// Half the tasks are active (one per worker), half still queued.
assert_eq!(hive.num_tasks(), (TEST_TASKS as u64, TEST_TASKS as u64));
hive.suspend();
thread::sleep(SHORT_TASK);
// The in-flight tasks completed, but the queued ones were not started.
assert_eq!(hive.num_tasks(), (TEST_TASKS as u64, 0));
assert_eq!(hive.num_successes(), TEST_TASKS);
hive.resume();
thread::sleep(ONE_SEC);
// After resuming, the second batch becomes active and then completes.
assert_eq!(hive.num_tasks(), (0, TEST_TASKS as u64));
thread::sleep(SHORT_TASK);
assert_eq!(hive.num_tasks(), (0, 0));
assert_eq!(hive.num_successes(), total_tasks);
}
// Worker that sleeps ~3 seconds in 1-second increments, checking for cancellation
// between increments; on success it echoes its input back as the output.
#[derive(Debug, Default)]
struct MyRefWorker;
impl RefWorker for MyRefWorker {
type Input = u8;
type Output = u8;
type Error = ();
fn apply_ref(
&mut self,
input: &Self::Input,
ctx: &Context<Self::Input>,
) -> RefWorkerResult<Self> {
// Poll for cancellation once per second so that `suspend` can interrupt
// this task cooperatively; the input is then retained as "unprocessed".
for _ in 0..3 {
thread::sleep(Duration::from_secs(1));
if ctx.is_cancelled() {
return Err(ApplyRefError::Cancelled);
}
}
Ok(*input)
}
}
// Suspending mid-task cancels cooperative workers (`MyRefWorker`); the retained
// unprocessed inputs can then be resubmitted with `swarm_unprocessed_send`
// after resuming, and their outcomes received on a channel.
#[rstest]
fn test_suspend_resume_send_with_cancelled_tasks<B, F>(
#[values(channel_builder, workstealing_builder)] builder_factory: F,
) where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive: Hive<_, _> = builder_factory(false)
.num_threads(TEST_TASKS)
.with_worker_default::<MyRefWorker>()
.build();
let _ = hive.swarm_store(0..TEST_TASKS as u8);
thread::sleep(Duration::from_millis(500));
assert_eq!(hive.num_tasks(), (0, TEST_TASKS as u64));
hive.suspend();
thread::sleep(Duration::from_secs(2));
// Every worker observed the cancellation; inputs were kept as unprocessed.
assert_eq!(hive.num_tasks(), (0, 0));
assert_eq!(hive.num_unprocessed(), TEST_TASKS);
hive.resume();
let (tx, rx) = super::outcome_channel();
let new_task_ids = hive.swarm_unprocessed_send(tx);
assert_eq!(new_task_ids.len(), TEST_TASKS);
thread::sleep(Duration::from_millis(500));
assert_eq!(hive.num_tasks(), (0, TEST_TASKS as u64));
hive.join();
// The resubmitted tasks echo their inputs, so sorted outputs are 0..TEST_TASKS.
let mut outputs = rx
.into_iter()
.select_ordered_outputs(new_task_ids)
.collect::<Vec<_>>();
outputs.sort();
assert_eq!(outputs, (0..TEST_TASKS as u8).collect::<Vec<_>>());
}
// Same cancellation scenario as the `_send` variant, but the resubmitted tasks
// store their outcomes in the hive instead of sending them on a channel.
#[rstest]
fn test_suspend_resume_store_with_cancelled_tasks<B, F>(
#[values(channel_builder, workstealing_builder)] builder_factory: F,
) where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive: Hive<_, _> = builder_factory(false)
.num_threads(TEST_TASKS)
.with_worker_default::<MyRefWorker>()
.build();
hive.swarm_store(0..TEST_TASKS as u8);
hive.suspend();
// Give the workers time to notice the cancellation (they poll every second).
thread::sleep(Duration::from_secs(2));
hive.resume();
hive.swarm_unprocessed_store();
thread::sleep(Duration::from_secs(1));
assert_eq!(hive.num_tasks().1, TEST_TASKS as u64);
// Each task takes ~3s; after another 3s all must have succeeded.
thread::sleep(Duration::from_secs(3));
assert_eq!(hive.num_successes(), TEST_TASKS);
}
// The number of active tasks is capped by the number of worker threads, even
// when more tasks than workers have been submitted.
#[rstest]
fn test_num_tasks_active<B, F>(
#[values(channel_builder, workstealing_builder)] builder_factory: F,
) where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = void_thunk_hive(TEST_TASKS, builder_factory(false));
// Submit twice as many never-terminating tasks as there are workers.
for _ in 0..2 * TEST_TASKS {
hive.apply_store(Thunk::from(|| {
loop {
thread::sleep(LONG_TASK)
}
}));
}
thread::sleep(ONE_SEC);
// Only one task per worker can ever be active.
assert_eq!(hive.num_tasks().1, TEST_TASKS as u64);
let num_threads = hive.max_workers();
assert_eq!(num_threads, TEST_TASKS);
}
// `with_thread_per_core` must spawn exactly one worker per logical CPU.
#[rstest]
fn test_all_threads<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive: Hive<DefaultQueen<TWrk<()>>, _> = builder_factory(false)
.with_queen_default()
.with_thread_per_core()
.build();
let num_threads = num_cpus::get();
// One never-terminating task per core keeps every worker busy.
for _ in 0..num_threads {
hive.apply_store(Thunk::from(|| {
loop {
thread::sleep(LONG_TASK)
}
}));
}
thread::sleep(ONE_SEC);
assert_eq!(hive.num_tasks().1, num_threads as u64);
let max_workers = hive.max_workers();
assert_eq!(num_threads, max_workers);
}
// Panicking tasks are counted by `num_panics`, and the count is carried over
// into the husk when the hive is shut down.
#[rstest]
fn test_panic<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = thunk_hive(TEST_TASKS, builder_factory(true));
// The receiver is dropped; outcomes are discarded, only the panic count matters.
let (tx, _) = super::outcome_channel();
for _ in 0..TEST_TASKS {
hive.apply_send(Thunk::from(|| panic!("intentional panic")), &tx);
}
hive.join();
assert_eq!(hive.num_panics(), TEST_TASKS);
let husk = hive.try_into_husk(false).unwrap();
assert_eq!(husk.num_panics(), TEST_TASKS);
}
// With a `RefCaller` worker, panics inside the task are caught and reported
// as `Outcome::Panic` values instead of incrementing the hive's panic count
// (contrast with `test_panic`, which uses a plain thunk worker).
#[rstest]
fn test_catch_panic<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive: Hive<_, _> = builder_factory(false)
.with_worker(RefCaller::from(|_: &u8| -> Result<u8, String> {
panic!("intentional panic")
}))
.num_threads(TEST_TASKS)
.build();
let (tx, rx) = super::outcome_channel();
for i in 0..TEST_TASKS {
hive.apply_send(i as u8, &tx);
}
hive.join();
// Panics were caught, so the hive itself records none.
assert_eq!(hive.num_panics(), 0);
for outcome in rx.into_iter().take(TEST_TASKS) {
assert!(matches!(outcome, Outcome::Panic { .. }));
}
}
// Dropping a hive while its tasks are still running must not itself panic,
// even if those tasks panic after the drop. The barrier holds the tasks at
// the brink of panicking until after `drop(hive)` has returned.
#[rstest]
fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop<B, F>(
#[values(channel_builder, workstealing_builder)] builder_factory: F,
) where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = void_thunk_hive(TEST_TASKS, builder_factory(false));
// TEST_TASKS worker tasks + the main thread all rendezvous on this barrier.
let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
let waiter_count = Arc::new(AtomicUsize::new(0));
for _ in 0..TEST_TASKS {
let waiter = waiter.clone();
let waiter_count = waiter_count.clone();
hive.apply_store(Thunk::from(move || {
waiter_count.fetch_add(1, Ordering::SeqCst);
waiter.wait();
panic!("intentional panic");
}));
}
thread::sleep(Duration::from_secs(1));
// All tasks have started and are parked on the barrier.
assert_eq!(waiter_count.load(Ordering::SeqCst), TEST_TASKS);
drop(hive);
// Release the tasks; their panics occur after the hive is gone.
waiter.wait();
}
// Stress test: millions of queued tasks must all be executed, and the active
// count must return to zero after `join`.
#[rstest]
fn test_massive_task_creation<B, F>(
#[values(channel_builder, workstealing_builder)] builder_factory: F,
) where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let test_tasks = 4_200_000;
let hive = thunk_hive(TEST_TASKS, builder_factory(true));
// NOTE(review): IndexedBarrier is a project type; it appears to rendezvous the
// TEST_TASKS in-flight tasks with the main thread — confirm its exact semantics.
let b0 = IndexedBarrier::new(TEST_TASKS);
let b1 = IndexedBarrier::new(TEST_TASKS);
let (tx, rx) = mpsc::channel();
for _ in 0..test_tasks {
let tx = tx.clone();
let (b0, b1) = (b0.clone(), b1.clone());
hive.apply_store(Thunk::from(move || {
b0.wait();
b1.wait();
assert!(tx.send(1).is_ok());
}));
}
// After the first rendezvous, exactly one task per worker is active.
b0.wait();
assert_eq!(hive.num_tasks().1, TEST_TASKS as u64);
b1.wait();
// Every task sends a 1; the sum equals the total number of tasks.
assert_eq!(rx.iter().take(test_tasks).sum::<usize>(), test_tasks);
hive.join();
let atomic_num_tasks_active = hive.num_tasks().1;
assert!(
atomic_num_tasks_active == 0,
"atomic_num_tasks_active: {}",
atomic_num_tasks_active
);
}
// Worker threads — including ones added later via `grow` — must carry the
// thread name configured on the builder.
#[rstest]
fn test_name<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let name = "test";
let hive: Hive<DefaultQueen<TWrk<()>>, B::TaskQueues<_>> = builder_factory(false)
.with_queen_default()
.thread_name(name.to_owned())
.num_threads(2)
.build();
let (tx, rx) = mpsc::channel();
// Tasks report the name of the thread they run on.
for _ in 0..2 {
let tx = tx.clone();
hive.apply_store(Thunk::from(move || {
let name = thread::current().name().unwrap().to_owned();
tx.send(name).unwrap();
}));
}
// Threads spawned by `grow` should inherit the same configured name.
hive.grow(3).expect("error spawning threads");
let tx_clone = tx.clone();
hive.apply_store(Thunk::from(move || {
let name = thread::current().name().unwrap().to_owned();
tx_clone.send(name).unwrap();
}));
for thread_name in rx.iter().take(3) {
assert_eq!(name, thread_name);
}
}
// The configured `thread_stack_size` must be honored: the remaining stack
// measured inside a task (via the `stacker` crate) is within 1% of it.
#[rstest]
fn test_stack_size<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let stack_size = 4_000_000;
let hive: Hive<DefaultQueen<TWrk<usize>>, B::TaskQueues<_>> = builder_factory(false)
.with_queen_default()
.num_threads(1)
.thread_stack_size(stack_size)
.build();
let actual_stack_size = hive
.apply(Thunk::from(|| {
stacker::remaining_stack().unwrap()
}))
.unwrap() as f64;
// Allow a ±1% tolerance since some stack is consumed by frames above the task.
assert!(actual_stack_size > (stack_size as f64 * 0.99));
assert!(actual_stack_size < (stack_size as f64 * 1.01));
}
// Pins the exact `Debug` representation of a hive: unnamed, named, and with
// one task actively running.
#[rstest]
fn test_debug<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = void_thunk_hive(4, builder_factory(true));
let debug = format!("{:?}", hive);
assert_eq!(
debug,
"Hive(Some(Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 }))"
);
// A named hive renders the name instead of `None`.
let hive: Hive<DefaultQueen<TWrk<usize>>, B::TaskQueues<_>> = builder_factory(false)
.with_queen_default()
.thread_name("hello")
.num_threads(4)
.build();
let debug = format!("{:?}", hive);
assert_eq!(
debug,
"Hive(Some(Shared { name: \"hello\", num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 }))"
);
// One long-running task should show up as a single active task.
let hive = thunk_hive(4, builder_factory(true));
hive.apply_store(Thunk::from(|| thread::sleep(LONG_TASK)));
thread::sleep(ONE_SEC);
let debug = format!("{:?}", hive);
assert_eq!(
debug,
"Hive(Some(Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 1 }))"
);
}
// `join` can be called repeatedly: it must wait for all tasks submitted since
// the previous join, and submitting after a join must work normally.
#[rstest]
fn test_repeated_join<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive: Hive<DefaultQueen<TWrk<()>>, B::TaskQueues<_>> = builder_factory(false)
.with_queen_default()
.thread_name("repeated join test")
.num_threads(8)
.build();
let test_count = Arc::new(AtomicUsize::new(0));
for _ in 0..42 {
let test_count = test_count.clone();
hive.apply_store(Thunk::from(move || {
thread::sleep(SHORT_TASK);
test_count.fetch_add(1, Ordering::Release);
}));
}
hive.join();
// Release/Acquire pairing makes the increments visible after `join` returns.
assert_eq!(42, test_count.load(Ordering::Acquire));
for _ in 0..42 {
let test_count = test_count.clone();
hive.apply_store(Thunk::from(move || {
thread::sleep(SHORT_TASK);
test_count.fetch_add(1, Ordering::Relaxed);
}));
}
hive.join();
assert_eq!(84, test_count.load(Ordering::Relaxed));
}
// Tasks running on one hive may submit to — and join on — another hive without
// deadlocking, and all results still arrive.
#[rstest]
fn test_multi_join<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive0: Hive<DefaultQueen<TWrk<()>>, B::TaskQueues<_>> = builder_factory(false)
.with_queen_default()
.thread_name("multi join pool0")
.num_threads(4)
.build();
let hive1: Hive<DefaultQueen<TWrk<()>>, B::TaskQueues<_>> = builder_factory(false)
.with_queen_default()
.thread_name("multi join pool1")
.num_threads(4)
.build();
let (tx, rx) = crate::channel::channel();
// Each hive0 task enqueues a hive1 task that joins on hive0 before sending.
for i in 0..8 {
let hive1_clone = hive1.clone();
let hive0_clone = hive0.clone();
let tx = tx.clone();
hive0.apply_store(Thunk::from(move || {
hive1_clone.apply_store(Thunk::from(move || {
hive0_clone.join();
thread::sleep(Duration::from_millis(10));
tx.send(i).expect("send failed from hive1_clone to main");
}));
}));
}
// Drop the original sender so the channel closes once all clones finish.
drop(tx);
// Nothing can have been sent yet: every sender is gated on hive0.join().
let before_any_send = rx.try_recv_msg();
assert!(matches!(before_any_send, Message::ChannelEmpty));
hive0.join();
hive1.join();
assert_eq!(rx.into_iter().sum::<u32>(), (0..8).sum());
}
/// Joining a hive that was never given any tasks must return immediately
/// rather than block.
#[rstest]
fn test_empty_hive<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
    B: TaskQueuesBuilder,
    F: Fn(bool) -> B,
{
    void_thunk_hive(4, builder_factory(true)).join();
}
/// Regression test: `join` must not deadlock or miss tasks while another
/// thread is still submitting work to a clone of the same hive.
#[rstest]
fn test_no_fun_or_joy<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
    B: TaskQueuesBuilder,
    F: Fn(bool) -> B,
{
    // Long enough that tasks are still running when the spawned thread submits more.
    fn sleepy_function() {
        thread::sleep(LONG_TASK);
    }
    let hive: Hive<DefaultQueen<TWrk<()>>, B::TaskQueues<_>> = builder_factory(false)
        .with_queen_default()
        .thread_name("no fun or joy")
        .num_threads(8)
        .build();
    hive.apply_store(Thunk::from(sleepy_function));
    let p_t = hive.clone();
    thread::spawn(move || {
        // A plain loop expresses the repeated side effect directly, instead of
        // abusing `inspect(..).count()` to force a lazy iterator.
        for _ in 0..23 {
            p_t.apply_store(Thunk::from(sleepy_function));
        }
    });
    hive.join();
}
/// `map` must yield outcomes in submission order even though earlier tasks
/// sleep longer (and thus finish later) than later ones.
#[rstest]
fn test_map<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
    B: TaskQueuesBuilder,
    F: Fn(bool) -> B,
{
    let hive = thunk_hive::<u8, _, _>(2, builder_factory(false));
    // Task i sleeps (10 - i) * 100 ms, so completion order is reversed.
    let tasks = (0..10u8).map(|i| {
        Thunk::from(move || {
            thread::sleep(Duration::from_millis((10 - i as u64) * 100));
            i
        })
    });
    let outputs = hive.map(tasks).map(Outcome::unwrap).collect::<Vec<_>>();
    assert_eq!(outputs, (0..10).collect::<Vec<_>>())
}
// `map_unordered` yields outcomes as they complete; sorting the outputs must
// recover the full input set.
#[rstest]
fn test_map_unordered<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = thunk_hive::<u8, _, _>(8, builder_factory(false));
// Task i sleeps (8 - i) * 100 ms, so completion order is reversed.
let mut outputs: Vec<_> = hive
.map_unordered((0..8u8).map(|i| {
Thunk::from(move || {
thread::sleep(Duration::from_millis((8 - i as u64) * 100));
i
})
}))
.map(Outcome::unwrap)
.collect();
outputs.sort();
assert_eq!(outputs, (0..8).collect::<Vec<_>>())
}
// `map_send` returns the submitted task ids and delivers outcomes on the
// provided channel; the two id sets must match exactly.
#[rstest]
fn test_map_send<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = thunk_hive::<u8, _, _>(8, builder_factory(false));
let (tx, rx) = super::outcome_channel();
let mut task_ids = hive.map_send(
(0..8u8).map(|i| {
Thunk::from(move || {
thread::sleep(Duration::from_millis((8 - i as u64) * 100));
i
})
}),
tx,
);
// Split the received outcomes into their ids and values.
let (mut outcome_task_ids, mut values): (Vec<TaskId>, Vec<u8>) = rx
.iter()
.map(|outcome| match outcome {
Outcome::Success { value, task_id } => (task_id, value),
_ => panic!("unexpected error"),
})
.unzip();
task_ids.sort();
outcome_task_ids.sort();
assert_eq!(task_ids, outcome_task_ids);
values.sort();
assert_eq!(values, (0..8).collect::<Vec<_>>());
}
/// `map_store` retains outcomes inside the hive; after `join`, each stored
/// outcome is a success retrievable by task id via `remove_success`.
#[rstest]
fn test_map_store<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
    B: TaskQueuesBuilder,
    F: Fn(bool) -> B,
{
    let mut hive = thunk_hive::<u8, _, _>(8, builder_factory(false));
    // Task i sleeps (8 - i) * 100 ms, so completion order is reversed.
    let mut task_ids = hive.map_store((0..8u8).map(|i| {
        Thunk::from(move || {
            thread::sleep(Duration::from_millis((8 - i as u64) * 100));
            i
        })
    }));
    hive.join();
    for i in task_ids.iter() {
        assert!(hive.outcomes_deref().get(i).unwrap().is_success());
    }
    // `unzip` splits the (task_id, value) pairs; this matches the idiom used by
    // the other `*_store` tests and avoids relying on the tuple `FromIterator`
    // impl (stable only since Rust 1.79).
    let (mut outcome_task_ids, values): (Vec<TaskId>, Vec<u8>) = task_ids
        .clone()
        .into_iter()
        .map(|i| (i, hive.remove_success(i).unwrap()))
        .unzip();
    // Ids preserve submission order, so values come back in input order.
    assert_eq!(values, (0..8).collect::<Vec<_>>());
    task_ids.sort();
    outcome_task_ids.sort();
    assert_eq!(task_ids, outcome_task_ids);
}
// `swarm` (like `map`) must yield outcomes in submission order even though
// earlier tasks sleep longer than later ones.
#[rstest]
fn test_swarm<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = thunk_hive::<u8, _, _>(2, builder_factory(false));
let outputs: Vec<_> = hive
.swarm((0..10u8).map(|i| {
Thunk::from(move || {
thread::sleep(Duration::from_millis((10 - i as u64) * 100));
i
})
}))
.map(Outcome::unwrap)
.collect();
assert_eq!(outputs, (0..10).collect::<Vec<_>>())
}
// `swarm_unordered` yields outcomes as they complete; sorting must recover
// the full input set.
#[rstest]
fn test_swarm_unordered<B, F>(
#[values(channel_builder, workstealing_builder)] builder_factory: F,
) where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = thunk_hive::<u8, _, _>(8, builder_factory(false));
let mut outputs: Vec<_> = hive
.swarm_unordered((0..8u8).map(|i| {
Thunk::from(move || {
thread::sleep(Duration::from_millis((8 - i as u64) * 100));
i
})
}))
.map(Outcome::unwrap)
.collect();
outputs.sort();
assert_eq!(outputs, (0..8).collect::<Vec<_>>())
}
// `swarm_send` delivers outcomes on a channel in completion order: with one
// worker per task and sleep times decreasing in i, values arrive reversed.
#[rstest]
fn test_swarm_send<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = thunk_hive::<u8, _, _>(8, builder_factory(false));
// Batching is disabled by default when the `local-batch` feature is on.
#[cfg(feature = "local-batch")]
assert_eq!(hive.worker_batch_limit(), 0);
let (tx, rx) = super::outcome_channel();
let mut task_ids = hive.swarm_send(
(0..8u8).map(|i| {
Thunk::from(move || {
thread::sleep(Duration::from_millis((8 - i as u64) * 200));
i
})
}),
tx,
);
let (mut outcome_task_ids, values): (Vec<TaskId>, Vec<u8>) = rx
.iter()
.map(|outcome| match outcome {
Outcome::Success { value, task_id } => (task_id, value),
_ => panic!("unexpected error"),
})
.unzip();
// Task i sleeps (8 - i) * 200 ms, so i = 7 finishes first and i = 0 last.
assert_eq!(values, (0..8).rev().collect::<Vec<_>>());
task_ids.sort();
outcome_task_ids.sort();
assert_eq!(task_ids, outcome_task_ids);
}
/// `swarm_store` retains outcomes inside the hive; after `join`, each stored
/// outcome is a success retrievable by task id via `remove_success`.
#[rstest]
fn test_swarm_store<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
    B: TaskQueuesBuilder,
    F: Fn(bool) -> B,
{
    let mut hive = thunk_hive::<u8, _, _>(8, builder_factory(false));
    // Task i sleeps (8 - i) * 100 ms, so completion order is reversed.
    let mut task_ids = hive.swarm_store((0..8u8).map(|i| {
        Thunk::from(move || {
            thread::sleep(Duration::from_millis((8 - i as u64) * 100));
            i
        })
    }));
    hive.join();
    for i in task_ids.iter() {
        assert!(hive.outcomes_deref().get(i).unwrap().is_success());
    }
    // `unzip` splits the (task_id, value) pairs; this matches the idiom used by
    // the other `*_store` tests and avoids relying on the tuple `FromIterator`
    // impl (stable only since Rust 1.79).
    let (mut outcome_task_ids, values): (Vec<TaskId>, Vec<u8>) = task_ids
        .clone()
        .into_iter()
        .map(|i| (i, hive.remove_success(i).unwrap()))
        .unzip();
    // Ids preserve submission order, so values come back in input order.
    assert_eq!(values, (0..8).collect::<Vec<_>>());
    task_ids.sort();
    outcome_task_ids.sort();
    assert_eq!(task_ids, outcome_task_ids);
}
// `scan` threads a mutable accumulator through the inputs (running sum here)
// and submits each intermediate accumulator value as a task input; the final
// state and the squared prefix sums are verified.
#[rstest]
fn test_scan<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = builder_factory(false)
.with_worker(Caller::from(|i: usize| i * i))
.num_threads(4)
.build();
let (outputs, state) = hive.scan(0..10usize, 0, |acc, i| {
*acc += i;
*acc
});
let mut outputs = outputs.unwrap();
// Prefix sums are monotonically increasing, so sorting restores task order.
outputs.sort();
assert_eq!(outputs.len(), 10);
// Final state is the sum 0 + 1 + … + 9.
assert_eq!(state, 45);
assert_eq!(
outputs,
(0..10)
.scan(0, |acc, i| {
*acc += i;
Some(*acc)
})
.map(|i| i * i)
.collect::<Vec<_>>()
);
}
// `scan_send`: like `scan`, but outcomes are delivered on a channel; the
// returned task ids must match the ids attached to the received outcomes.
#[rstest]
fn test_scan_send<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = builder_factory(false)
.with_worker(Caller::from(|i: i32| i * i))
.num_threads(4)
.build();
let (tx, rx) = super::outcome_channel();
let (mut task_ids, state) = hive.scan_send(0..10, tx, 0, |acc, i| {
*acc += i;
*acc
});
assert_eq!(task_ids.len(), 10);
// Final state is the sum 0 + 1 + … + 9.
assert_eq!(state, 45);
let (mut outcome_task_ids, mut values): (Vec<TaskId>, Vec<i32>) = rx
.iter()
.map(|outcome| match outcome {
Outcome::Success { value, task_id } => (task_id, value),
_ => panic!("unexpected error"),
})
.unzip();
// Squared prefix sums are increasing, so sorting restores input order.
values.sort();
assert_eq!(
values,
(0..10)
.scan(0, |acc, i| {
*acc += i;
Some(*acc)
})
.map(|i| i * i)
.collect::<Vec<_>>()
);
task_ids.sort();
outcome_task_ids.sort();
assert_eq!(task_ids, outcome_task_ids);
}
// `try_scan` is the fallible variant of `scan`: the accumulator closure
// returns a `Result`; when all items succeed, there are no errors and every
// outcome is a success.
#[rstest]
fn test_try_scan<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = builder_factory(false)
.with_worker(Caller::from(|i: i32| i * i))
.num_threads(4)
.build();
let (outcomes, error, state) = hive.try_scan(0..10, 0, |acc, i| {
*acc += i;
Ok::<_, String>(*acc)
});
let task_ids: Vec<_> = outcomes.success_task_ids();
assert_eq!(task_ids.len(), 10);
// All closure invocations returned Ok, so no errors were collected.
assert_eq!(error.len(), 0);
assert_eq!(state, 45);
let mut values: Vec<_> = outcomes
.into_iter()
.select_unordered(task_ids)
.into_outputs()
.collect();
values.sort();
assert_eq!(
values,
(0..10)
.scan(0, |acc, i| {
*acc += i;
Some(*acc)
})
.map(|i| i * i)
.collect::<Vec<_>>()
);
}
// When every `try_scan` closure invocation fails, there are no successes, so
// the first assertion below fails — which is what `#[should_panic]` expects.
#[rstest]
#[should_panic]
fn test_try_scan_fail<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = builder_factory(false)
.with_worker(Caller::from(|i: i32| i * i))
.num_threads(4)
.build();
let (outcomes, error, state) = hive.try_scan(0..10, 0, |_, _| Err::<i32, _>("fail"));
let task_ids: Vec<_> = outcomes.success_task_ids();
// NOTE(review): all items failed, so this assertion panics first; the
// remaining statements are never reached.
assert_eq!(task_ids.len(), 10);
assert_eq!(error.len(), 0);
assert_eq!(state, 45);
let _ = outcomes
.into_iter()
.select_unordered(task_ids)
.into_outputs()
.collect::<Vec<_>>();
}
// `try_scan_send`: fallible scan whose outcomes are delivered on a channel;
// the returned `Result<TaskId, _>`s must all be Ok and match the outcome ids.
#[rstest]
fn test_try_scan_send<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = builder_factory(false)
.with_worker(Caller::from(|i: i32| i * i))
.num_threads(4)
.build();
let (tx, rx) = super::outcome_channel();
let (results, state) = hive.try_scan_send(0..10, tx, 0, |acc, i| {
*acc += i;
Ok::<_, String>(*acc)
});
// Every submission succeeded, so every result unwraps to a task id.
let mut task_ids: Vec<_> = results.into_iter().map(Result::unwrap).collect();
assert_eq!(task_ids.len(), 10);
assert_eq!(state, 45);
let (mut outcome_task_ids, mut values): (Vec<TaskId>, Vec<i32>) = rx
.iter()
.map(|outcome| match outcome {
Outcome::Success { value, task_id } => (task_id, value),
_ => panic!("unexpected error"),
})
.unzip();
values.sort();
assert_eq!(
values,
(0..10)
.scan(0, |acc, i| {
*acc += i;
Some(*acc)
})
.map(|i| i * i)
.collect::<Vec<_>>()
);
task_ids.sort();
outcome_task_ids.sort();
assert_eq!(task_ids, outcome_task_ids);
}
// When every `try_scan_send` closure invocation fails, the returned results
// are all Err, so `Result::unwrap` panics — satisfying `#[should_panic]`.
#[rstest]
#[should_panic]
fn test_try_scan_send_fail<B, F>(
#[values(channel_builder, workstealing_builder)] builder_factory: F,
) where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = builder_factory(false)
.with_worker(OnceCaller::from(|i: i32| Ok::<_, String>(i * i)))
.num_threads(4)
.build();
let (tx, _) = super::outcome_channel();
let _ = hive
.try_scan_send(0..10, &tx, 0, |_, _| Err::<i32, _>("fail"))
.0
.into_iter()
.map(Result::unwrap)
.collect::<Vec<_>>();
}
// `scan_store`: scan whose outcomes are retained in the hive and retrieved
// afterwards by task id via `remove_success`.
#[rstest]
fn test_scan_store<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let mut hive = builder_factory(false)
.with_worker(Caller::from(|i: i32| i * i))
.num_threads(4)
.build();
let (mut task_ids, state) = hive.scan_store(0..10, 0, |acc, i| {
*acc += i;
*acc
});
assert_eq!(task_ids.len(), 10);
// Final state is the sum 0 + 1 + … + 9.
assert_eq!(state, 45);
hive.join();
for i in task_ids.iter() {
assert!(hive.outcomes_deref().get(i).unwrap().is_success());
}
// Ids preserve submission order, so values come back in input order.
let (mut outcome_task_ids, values): (Vec<TaskId>, Vec<i32>) = task_ids
.clone()
.into_iter()
.map(|i| (i, hive.remove_success(i).unwrap()))
.unzip();
assert_eq!(
values,
(0..10)
.scan(0, |acc, i| {
*acc += i;
Some(*acc)
})
.map(|i| i * i)
.collect::<Vec<_>>()
);
task_ids.sort();
outcome_task_ids.sort();
assert_eq!(task_ids, outcome_task_ids);
}
// `try_scan_store`: fallible scan whose outcomes are retained in the hive;
// with an always-Ok closure, all results unwrap and all outcomes succeed.
#[rstest]
fn test_try_scan_store<B, F>(
#[values(channel_builder, workstealing_builder)] builder_factory: F,
) where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let mut hive = builder_factory(false)
.with_worker(Caller::from(|i: i32| i * i))
.num_threads(4)
.build();
let (results, state) = hive.try_scan_store(0..10, 0, |acc, i| {
*acc += i;
Ok::<i32, String>(*acc)
});
// Every submission succeeded, so every result unwraps to a task id.
let mut task_ids: Vec<_> = results.into_iter().map(Result::unwrap).collect();
assert_eq!(task_ids.len(), 10);
assert_eq!(state, 45);
hive.join();
for i in task_ids.iter() {
assert!(hive.outcomes_deref().get(i).unwrap().is_success());
}
// Ids preserve submission order, so values come back in input order.
let (mut outcome_task_ids, values): (Vec<TaskId>, Vec<i32>) = task_ids
.clone()
.into_iter()
.map(|i| (i, hive.remove_success(i).unwrap()))
.unzip();
assert_eq!(
values,
(0..10)
.scan(0, |acc, i| {
*acc += i;
Some(*acc)
})
.map(|i| i * i)
.collect::<Vec<_>>()
);
task_ids.sort();
outcome_task_ids.sort();
assert_eq!(task_ids, outcome_task_ids);
}
// When every `try_scan_store` closure invocation fails, the returned results
// are all Err, so `Result::unwrap` panics — satisfying `#[should_panic]`.
#[rstest]
#[should_panic]
fn test_try_scan_store_fail<B, F>(
#[values(channel_builder, workstealing_builder)] builder_factory: F,
) where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = builder_factory(false)
.with_worker(OnceCaller::from(|i: i32| Ok::<i32, String>(i * i)))
.num_threads(4)
.build();
let _ = hive
.try_scan_store(0..10, 0, |_, _| Err::<i32, _>("fail"))
.0
.into_iter()
.map(Result::unwrap)
.collect::<Vec<_>>();
}
// Number of tasks submitted directly by `test_send_from_task`; each one spawns
// a single follow-up task from inside the hive.
const NUM_FIRST_TASKS: usize = 4;
// Worker that, for inputs below NUM_FIRST_TASKS, submits a follow-up task with
// input `input + NUM_FIRST_TASKS` via the task context, then returns its input.
#[derive(Debug, Default)]
struct SendWorker;
impl Worker for SendWorker {
type Input = usize;
type Output = usize;
type Error = ();
fn apply(&mut self, input: Self::Input, ctx: &Context<Self::Input>) -> WorkerResult<Self> {
if input < NUM_FIRST_TASKS {
// If the in-task submission fails, request a retry with the same input.
ctx.submit(input + NUM_FIRST_TASKS)
.map_err(|input| ApplyError::Retryable { input, error: () })?;
}
Ok(input)
}
}
// Tasks may submit new tasks from within the hive via the task context; the
// outcomes of those child tasks are delivered alongside the parents'.
#[rstest]
fn test_send_from_task<B, F>(
#[values(channel_builder, workstealing_builder)] builder_factory: F,
) where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive = builder_factory(false)
.num_threads(2)
.with_worker_default::<SendWorker>()
.build();
let (tx, rx) = super::outcome_channel();
let task_ids = hive.map_send(0..NUM_FIRST_TASKS, tx);
hive.join();
assert_eq!(task_ids.len(), NUM_FIRST_TASKS);
// Each of the NUM_FIRST_TASKS parents spawned one child, doubling the outputs.
let outputs: Vec<_> = rx.select_ordered_outputs(task_ids).collect();
assert_eq!(outputs.len(), NUM_FIRST_TASKS * 2);
assert_eq!(outputs, (0..NUM_FIRST_TASKS * 2).collect::<Vec<_>>());
}
/// `close` only reports `true` for the last clone of a hive: the first call
/// returns `false` while another clone is still alive.
#[rstest]
fn test_close<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
    B: TaskQueuesBuilder,
    F: Fn(bool) -> B,
{
    let hive1 = thunk_hive::<u8, _, _>(8, builder_factory(false));
    let thunks = (0..8u8).map(|i| Thunk::from(move || i));
    let _ = hive1.map_store(thunks);
    hive1.join();
    let hive2 = hive1.clone();
    // hive2 still exists, so closing hive1 does not finalize the hive.
    assert!(!hive1.close(false));
    // hive2 is the last handle; closing it succeeds.
    assert!(hive2.close(false));
}
/// After all tasks finish, `try_into_outcomes` yields a success for every
/// stored task id.
#[rstest]
fn test_into_outcomes<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
    B: TaskQueuesBuilder,
    F: Fn(bool) -> B,
{
    let hive = thunk_hive::<u8, _, _>(8, builder_factory(false));
    let task_ids = hive.map_store((0..8u8).map(|i| Thunk::from(move || i)));
    hive.join();
    let outcomes = hive.try_into_outcomes(false).unwrap();
    for id in &task_ids {
        let outcome = outcomes.get(id).expect("missing outcome");
        assert!(outcome.is_success());
        assert!(matches!(outcome, Outcome::Success { .. }));
    }
}
// A husk retains the outcomes of a finished hive; it can be turned back into a
// builder (`as_builder`) or directly into a new hive (`into_hive`), and both
// resulting hives produce the same outputs as the original.
#[rstest]
fn test_husk<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
B: TaskQueuesBuilder,
F: Fn(bool) -> B,
{
let hive1 = thunk_hive::<u8, _, _>(8, builder_factory(false));
let task_ids = hive1.map_store((0..8u8).map(|i| Thunk::from(move || i)));
hive1.join();
// The husk holds one success per submitted task.
let mut husk1 = hive1.try_into_husk(false).unwrap();
for i in task_ids.iter() {
assert!(husk1.outcomes_deref().get(i).unwrap().is_success());
assert!(matches!(husk1.get(*i), Some(Outcome::Success { .. })));
}
// Rebuild a hive from the husk's configuration via a fresh builder.
let builder = husk1.as_builder();
let hive2 = builder
.num_threads(4)
.with_worker_default::<ThunkWorker<u8>>()
.with_channel_queues()
.build();
hive2.map_store((0..8u8).map(|i| {
Thunk::from(move || {
thread::sleep(Duration::from_millis((8 - i as u64) * 100));
i
})
}));
hive2.join();
let mut husk2 = hive2.try_into_husk(false).unwrap();
// Both runs must produce the same (sorted) set of outputs.
let mut outputs1 = husk1
.remove_all()
.into_iter()
.map(Outcome::unwrap)
.collect::<Vec<_>>();
outputs1.sort();
let mut outputs2 = husk2
.remove_all()
.into_iter()
.map(Outcome::unwrap)
.collect::<Vec<_>>();
outputs2.sort();
assert_eq!(outputs1, outputs2);
// `into_hive` converts the husk directly into a new hive.
let hive3 = husk1.into_hive::<ChannelTaskQueues<_>>();
hive3.map_store((0..8u8).map(|i| {
Thunk::from(move || {
thread::sleep(Duration::from_millis((8 - i as u64) * 100));
i
})
}));
hive3.join();
let husk3 = hive3.try_into_husk(false).unwrap();
let (_, outcomes3) = husk3.into_parts();
let mut outputs3 = outcomes3
.into_iter()
.map(Outcome::unwrap)
.collect::<Vec<_>>();
outputs3.sort();
assert_eq!(outputs1, outputs3);
}
#[rstest]
fn test_clone<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
    B: TaskQueuesBuilder,
    F: Fn(bool) -> B,
{
    // Clones of a hive share the same task queues and worker threads, so
    // tasks submitted through any clone are processed by the same pool.
    let hive: Hive<DefaultQueen<TWrk<()>>, B::TaskQueues<_>> = builder_factory(false)
        .with_worker_default()
        .thread_name("clone example")
        .num_threads(2)
        .build();
    // Keep the workers busy so the spawned threads' `join` calls block.
    for _ in 0..6 {
        hive.apply_store(Thunk::from(|| thread::sleep(SHORT_TASK)));
    }
    let adder = {
        let hive = hive.clone();
        thread::spawn(move || {
            hive.join();
            let (tx, rx) = mpsc::channel();
            for i in 0..42 {
                let tx = tx.clone();
                hive.apply_store(Thunk::from(move || {
                    tx.send(i).expect("channel will be waiting");
                }));
            }
            drop(tx);
            rx.iter().sum::<i32>()
        })
    };
    let multiplier = {
        let hive = hive.clone();
        thread::spawn(move || {
            hive.join();
            let (tx, rx) = mpsc::channel();
            for i in 1..12 {
                let tx = tx.clone();
                hive.apply_store(Thunk::from(move || {
                    tx.send(i).expect("channel will be waiting");
                }));
            }
            drop(tx);
            rx.iter().product::<i32>()
        })
    };
    // sum(0..42) == 861
    let sum = adder
        .join()
        .expect("thread 0 will return after calculating additions");
    assert_eq!(861, sum);
    // product(1..=11) == 11! == 39_916_800
    let product = multiplier
        .join()
        .expect("thread 1 will return after calculating multiplications");
    assert_eq!(39916800, product);
}
#[rstest]
fn test_clone_into_husk_fails<B, F>(
    #[values(channel_builder, workstealing_builder)] builder_factory: F,
) where
    B: TaskQueuesBuilder,
    F: Fn(bool) -> B,
{
    // `try_into_husk` must fail while another clone of the hive is still
    // alive, and succeed once the last remaining clone is consumed.
    let original: Hive<DefaultQueen<TWrk<()>>, B::TaskQueues<_>> = builder_factory(false)
        .with_worker_default()
        .num_threads(2)
        .build();
    let clone = original.clone();
    assert!(original.try_into_husk(false).is_none());
    assert!(clone.try_into_husk(false).is_some());
}
#[rstest]
fn test_channel_hive_send() {
    // Compile-time check: a channel-queue hive can be moved across threads.
    fn require_send<T: Send>() {}
    require_send::<Hive<DefaultQueen<TWrk<()>>, ChannelTaskQueues<_>>>();
}
#[rstest]
fn test_workstealing_hive_send() {
    // Compile-time check: a workstealing-queue hive can be moved across threads.
    fn require_send<T: Send>() {}
    require_send::<Hive<DefaultQueen<TWrk<()>>, WorkstealingTaskQueues<_>>>();
}
#[rstest]
fn test_cloned_eq<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
    B: TaskQueuesBuilder,
    F: Fn(bool) -> B,
{
    // A clone shares state with the original, so the two must compare equal.
    let hive = thunk_hive::<(), _, _>(2, builder_factory(true));
    let clone = hive.clone();
    assert_eq!(hive, clone);
}
#[rstest]
fn test_join_wavesurfer() {
    // Workers in `waiter_hive` call `join` on `clock_hive` at staggered
    // times ("waves") and record the wave counter before and after the
    // join; later waiters must observe the counter advance while they
    // block, whereas the first wave may see no change.
    let n_waves = 4;
    let n_workers = 4;
    let (tx, rx) = mpsc::channel();
    let builder = channel_builder(false)
        .num_threads(n_workers)
        .thread_name("join wavesurfer");
    let waiter_hive = builder
        .clone()
        .with_worker_default::<ThunkWorker<()>>()
        .build();
    let clock_hive = builder.with_worker_default::<ThunkWorker<()>>().build();
    let barrier = Arc::new(Barrier::new(3));
    let wave_counter = Arc::new(AtomicUsize::new(0));
    // Background "clock": advances the wave counter once per second.
    let clock_thread = {
        let barrier = barrier.clone();
        let wave_counter = wave_counter.clone();
        thread::spawn(move || {
            barrier.wait();
            for wave_num in 0..n_waves {
                let _ = wave_counter.swap(wave_num, Ordering::SeqCst);
                thread::sleep(ONE_SEC);
            }
        })
    };
    // Seed the clock hive with one task so joining it actually blocks.
    {
        let barrier = barrier.clone();
        clock_hive.apply_store(Thunk::from(move || {
            barrier.wait();
            thread::sleep(Duration::from_millis(100));
        }));
    }
    for worker in 0..(3 * n_workers) {
        let tx = tx.clone();
        let clock_hive = clock_hive.clone();
        let wave_counter = wave_counter.clone();
        waiter_hive.apply_store(Thunk::from(move || {
            let wave_before = wave_counter.load(Ordering::SeqCst);
            clock_hive.join();
            // Queue more work so later waiters also have something to join on.
            clock_hive.apply_store(Thunk::from(|| thread::sleep(ONE_SEC)));
            let wave_after = wave_counter.load(Ordering::SeqCst);
            tx.send((wave_before, wave_after, worker)).unwrap();
        }));
    }
    barrier.wait();
    clock_hive.join();
    drop(tx);
    let mut hist = vec![0; n_waves];
    let mut data = vec![];
    for (before, after, worker) in rx.iter() {
        // Clamp into the last bucket so outliers cannot index out of bounds.
        let dur = (after - before).min(n_waves - 1);
        hist[dur] += 1;
        data.push((before, after, worker));
    }
    println!("Histogram of wave duration:");
    for (i, n) in hist.iter().enumerate() {
        // `str::repeat` builds the bar in O(n), replacing the O(n^2)
        // fold-with-concatenation this used to do.
        println!("\t{}: {} {}", i, n, "*".repeat(*n));
    }
    for (wave_before, wave_after, worker) in data.iter() {
        if *worker < n_workers {
            // The first `n_workers` waiters saw no wave pass during join.
            assert_eq!(wave_before, wave_after);
        } else {
            // Every later waiter must have observed at least one wave pass.
            assert!(wave_before < wave_after);
        }
    }
    clock_thread.join().unwrap();
}
#[rstest]
fn doctest_lib_2<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
    B: TaskQueuesBuilder,
    F: Fn(bool) -> B,
{
    // Mirrors the second crate-level doctest: squares are sent through an
    // outcome channel, then negated squares are collected via `swarm`.
    let hive: Hive<DefaultQueen<TWrk<i32>>, B::TaskQueues<_>> = builder_factory(false)
        .with_worker_default()
        .num_threads(4)
        .thread_name("thunk_hive")
        .build();
    let (tx, rx) = crate::hive::outcome_channel();
    let task_ids = hive.swarm_send((0..10).map(|i: i32| Thunk::from(move || i * i)), &tx);
    // 0^2 + 1^2 + ... + 9^2 == 285
    let total: i32 = rx.select_unordered_outputs(task_ids).sum();
    assert_eq!(285, total);
    let negated_total: i32 = hive
        .swarm((0..10).map(|i: i32| Thunk::from(move || i * -i)))
        .into_outputs()
        .sum();
    assert_eq!(-285, negated_total);
}
#[rstest]
fn doctest_lib_3<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
where
    B: TaskQueuesBuilder,
    F: Fn(bool) -> B,
{
    // Mirrors the third crate-level doctest: each worker wraps a spawned
    // `cat` child process, echoing one byte at a time through its
    // stdin/stdout, and the queen keeps the child handles so the
    // processes can be reaped when the hive is torn down.

    /// Worker that writes a byte (plus newline) to `cat`'s stdin and reads
    /// the echoed line back from its stdout.
    #[derive(Debug)]
    struct CatWorker {
        stdin: ChildStdin,
        stdout: BufReader<ChildStdout>,
    }
    impl CatWorker {
        fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
            Self {
                stdin,
                stdout: BufReader::new(stdout),
            }
        }
        // Sends `c` followed by a newline, then returns the echoed line
        // with its trailing newline removed.
        fn write_char(&mut self, c: u8) -> io::Result<String> {
            self.stdin.write_all(&[c])?;
            self.stdin.write_all(b"\n")?;
            self.stdin.flush()?;
            let mut s = String::new();
            self.stdout.read_line(&mut s)?;
            // Drop the trailing '\n' read by `read_line`.
            s.pop(); Ok(s)
        }
    }
    impl Worker for CatWorker {
        type Input = u8;
        type Output = String;
        type Error = io::Error;
        // Any I/O failure is treated as fatal: the child process pipe is
        // in an unknown state, so the input is returned unprocessed.
        fn apply(&mut self, input: Self::Input, _: &Context<u8>) -> WorkerResult<Self> {
            self.write_char(input).map_err(|error| ApplyError::Fatal {
                input: Some(input),
                error,
            })
        }
    }
    /// Queen that spawns one `cat` child per worker it creates and keeps
    /// the handles so the processes can be waited on later.
    #[derive(Default)]
    struct CatQueen {
        children: Vec<Child>,
    }
    impl CatQueen {
        // Waits on (and removes) every tracked child process.
        fn wait_for_all(&mut self) -> Vec<io::Result<ExitStatus>> {
            self.children
                .drain(..)
                .map(|mut child| child.wait())
                .collect()
        }
    }
    impl QueenMut for CatQueen {
        type Kind = CatWorker;
        fn create(&mut self) -> Self::Kind {
            let mut child = Command::new("cat")
                .stdin(Stdio::piped())
                .stdout(Stdio::piped())
                .stderr(Stdio::inherit())
                .spawn()
                .unwrap();
            let stdin = child.stdin.take().unwrap();
            let stdout = child.stdout.take().unwrap();
            // Keep the handle so Drop / wait_for_all can reap the child.
            self.children.push(child);
            CatWorker::new(stdin, stdout)
        }
    }
    impl Drop for CatQueen {
        // Best-effort reaping of any children not already waited on;
        // failures are only logged to stderr, never panicked on.
        fn drop(&mut self) {
            self.wait_for_all()
                .into_iter()
                .for_each(|result| match result {
                    Ok(status) if status.success() => (),
                    Ok(status) => eprintln!("Child process failed: {}", status),
                    Err(e) => eprintln!("Error waiting for child process: {}", e),
                })
        }
    }
    // Echo the bytes b'a'..=b'h' (97..=104) through the pool of `cat`
    // processes and check they come back in submission order.
    let hive = builder_factory(false)
        .with_queen_mut_default::<CatQueen>()
        .num_threads(4)
        .build();
    let inputs: Vec<u8> = (0..8).map(|i| 97 + i).collect();
    let output = hive
        .swarm(inputs)
        .into_outputs()
        .fold(String::new(), |mut a, b| {
            a.push_str(&b);
            a
        })
        .into_bytes();
    assert_eq!(output, b"abcdefgh");
    // Tear down the hive, recover the queen from the husk, and verify
    // every child process exited successfully.
    let mut queen = hive
        .try_into_husk(false)
        .unwrap()
        .into_parts()
        .0
        .into_inner();
    let (wait_ok, wait_err): (Vec<_>, Vec<_>) =
        queen.wait_for_all().into_iter().partition(Result::is_ok);
    if !wait_err.is_empty() {
        panic!(
            "Error(s) occurred while waiting for child processes: {:?}",
            wait_err
        );
    }
    let exec_err_codes: Vec<_> = wait_ok
        .into_iter()
        .map(Result::unwrap)
        .filter(|status| !status.success())
        .filter_map(|status| status.code())
        .collect();
    if !exec_err_codes.is_empty() {
        panic!(
            "Child process(es) failed with exit codes: {:?}",
            exec_err_codes
        );
    }
}
}
#[cfg(all(test, feature = "affinity"))]
mod affinity_tests {
    //! Tests for pinning worker threads to CPU cores (the `affinity` feature).
    use crate::bee::stock::{Thunk, ThunkWorker};
    use crate::channel::{Message, ReceiverExt};
    use crate::hive::{Builder, Outcome, TaskQueuesBuilder, channel_builder, workstealing_builder};
    use rstest::*;
    use std::thread;
    use std::time::Duration;

    /// Builds a two-thread hive pinned to cores 0..2 and runs tasks that
    /// log the affinity of the thread they execute on.
    #[rstest]
    fn test_affinity<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
    where
        B: TaskQueuesBuilder,
        F: Fn(bool) -> B,
    {
        let hive = builder_factory(false)
            .thread_name("affinity example")
            .num_threads(2)
            .core_affinity(0..2)
            .with_worker_default::<ThunkWorker<()>>()
            .build();
        hive.map_store((0..10).map(move |task_num| {
            Thunk::from(move || {
                if let Some(affinity) = core_affinity::get_core_ids() {
                    eprintln!("task {} on thread with affinity {:?}", task_num, affinity);
                }
            })
        }));
    }

    /// Builds a hive with one thread per core, each pinned using the
    /// default core affinity.
    #[rstest]
    fn test_use_all_cores_builder<B, F>(
        #[values(channel_builder, workstealing_builder)] builder_factory: F,
    ) where
        B: TaskQueuesBuilder,
        F: Fn(bool) -> B,
    {
        let hive = builder_factory(false)
            .thread_name("affinity example")
            .with_thread_per_core()
            .with_default_core_affinity()
            .with_worker_default::<ThunkWorker<()>>()
            .build();
        hive.map_store((0..num_cpus::get()).map(move |task_num| {
            Thunk::from(move || {
                if let Some(affinity) = core_affinity::get_core_ids() {
                    eprintln!("task {} on thread with affinity {:?}", task_num, affinity);
                }
            })
        }));
    }

    /// A task submitted before any threads exist stays queued; growing
    /// the hive by one pinned thread lets it run to completion.
    #[rstest]
    fn test_grow_with_affinity<B, F>(
        #[values(channel_builder, workstealing_builder)] builder_factory: F,
    ) where
        B: TaskQueuesBuilder,
        F: Fn(bool) -> B,
    {
        let hive = builder_factory(false)
            .thread_name("affinity example")
            .with_default_core_affinity()
            .with_worker_default::<ThunkWorker<usize>>()
            .build();
        let (tx, rx) = super::outcome_channel();
        let _ = hive.apply_send(Thunk::from(|| 0), &tx);
        thread::sleep(Duration::from_secs(1));
        // No `num_threads` was configured, so the task is still pending.
        assert_eq!(hive.num_tasks().0, 1);
        assert!(matches!(rx.try_recv_msg(), Message::ChannelEmpty));
        // Growing by zero threads is a no-op.
        assert!(matches!(hive.grow_with_affinity(0, vec![]), Ok(0)));
        thread::sleep(Duration::from_secs(1));
        assert_eq!(hive.num_tasks().0, 1);
        // Adding one thread pinned to core 0 drains the queue.
        assert!(matches!(hive.grow_with_affinity(1, vec![0]), Ok(1)));
        thread::sleep(Duration::from_secs(1));
        assert_eq!(hive.num_tasks().0, 0);
        assert!(matches!(
            rx.try_recv_msg(),
            Message::Received(Outcome::Success { value: 0, .. })
        ));
    }

    /// `use_all_cores_with_affinity` grows an existing hive to one pinned
    /// thread per core and reports how many threads were added.
    #[rstest]
    fn test_use_all_cores_hive() {
        let hive = crate::hive::channel_builder(false)
            .thread_name("affinity example")
            .with_default_core_affinity()
            .with_worker_default::<ThunkWorker<()>>()
            .build();
        let num_cores = num_cpus::get();
        assert_eq!(hive.use_all_cores_with_affinity().unwrap(), num_cores);
        hive.map_store((0..num_cpus::get()).map(move |task_num| {
            Thunk::from(move || {
                if let Some(affinity) = core_affinity::get_core_ids() {
                    eprintln!("task {} on thread with affinity {:?}", task_num, affinity);
                }
            })
        }));
    }
}
#[cfg(all(test, feature = "local-batch"))]
mod local_batch_tests {
    //! Tests for worker-local task batching (the `local-batch` feature).
    use crate::barrier::IndexedBarrier;
    use crate::bee::DefaultQueen;
    use crate::bee::stock::{Thunk, ThunkWorker};
    use crate::hive::{
        Builder, Hive, OutcomeIteratorExt, OutcomeReceiver, OutcomeSender, TaskQueues,
        TaskQueuesBuilder, channel_builder, workstealing_builder,
    };
    use rstest::*;
    use std::collections::HashMap;
    use std::thread::{self, ThreadId};
    use std::time::Duration;

    /// Submits `num_threads * num_tasks_per_thread` tasks that each report
    /// the `ThreadId` of the worker that executed them, returning all task
    /// ids. The first `num_threads` tasks wait on `barrier` so that every
    /// worker thread is occupied before the bulk of the tasks is submitted.
    fn launch_tasks<T: TaskQueues<ThunkWorker<ThreadId>>>(
        hive: &Hive<DefaultQueen<ThunkWorker<ThreadId>>, T>,
        num_threads: usize,
        num_tasks_per_thread: usize,
        barrier: &IndexedBarrier,
        tx: &OutcomeSender<ThunkWorker<ThreadId>>,
    ) -> Vec<usize> {
        let total_tasks = num_threads * num_tasks_per_thread;
        // One barrier-gated task per worker; the sleep between submissions
        // gives each task time to land on a different worker thread.
        let init_task_ids: Vec<_> = (0..num_threads)
            .map(|_| {
                let barrier = barrier.clone();
                let task_id = hive.apply_send(
                    Thunk::from(move || {
                        barrier.wait();
                        thread::sleep(Duration::from_millis(100));
                        thread::current().id()
                    }),
                    tx,
                );
                thread::sleep(Duration::from_millis(100));
                task_id
            })
            .collect();
        // The remaining tasks are short and submitted in a single batch.
        let rest_task_ids = hive.map_send(
            (num_threads..total_tasks).map(|_| {
                Thunk::from(move || {
                    thread::sleep(Duration::from_millis(1));
                    thread::current().id()
                })
            }),
            tx,
        );
        init_task_ids.into_iter().chain(rest_task_ids).collect()
    }

    /// Drains `rx` for the given task ids and counts how many tasks each
    /// worker thread executed.
    fn count_thread_ids(
        rx: OutcomeReceiver<ThunkWorker<ThreadId>>,
        task_ids: Vec<usize>,
    ) -> HashMap<ThreadId, usize> {
        rx.select_unordered_outputs(task_ids)
            .fold(HashMap::new(), |mut counter, id| {
                *counter.entry(id).or_insert(0) += 1;
                counter
            })
    }

    /// Runs one batching scenario: `batch_limit + 2` tasks per worker are
    /// submitted and all outcomes collected. With `assert_exact` (used for
    /// channel queues) each thread must execute exactly its share; without
    /// it (workstealing queues) each thread only needs to have run at
    /// least one task.
    fn run_test<T: TaskQueues<ThunkWorker<ThreadId>>>(
        hive: &Hive<DefaultQueen<ThunkWorker<ThreadId>>, T>,
        num_threads: usize,
        batch_limit: usize,
        assert_exact: bool,
    ) {
        let tasks_per_thread = batch_limit + 2;
        let (tx, rx) = crate::hive::outcome_channel();
        let barrier = IndexedBarrier::new(num_threads);
        let task_ids = launch_tasks(hive, num_threads, tasks_per_thread, &barrier, &tx);
        barrier.wait();
        hive.join();
        let thread_counts = count_thread_ids(rx, task_ids);
        // Every worker thread must have executed at least one task...
        assert_eq!(thread_counts.len(), num_threads);
        // ...and all submitted tasks must be accounted for.
        assert_eq!(
            thread_counts.values().sum::<usize>(),
            tasks_per_thread * num_threads
        );
        if assert_exact {
            assert!(
                thread_counts
                    .values()
                    .all(|&count| count == tasks_per_thread)
            );
        } else {
            assert!(thread_counts.values().all(|&count| count > 0));
        }
    }

    /// Channel queues: task distribution is deterministic, so per-thread
    /// counts are checked exactly.
    #[rstest]
    fn test_local_batch_channel() {
        const NUM_THREADS: usize = 4;
        const BATCH_LIMIT: usize = 24;
        let hive = channel_builder(false)
            .with_worker_default()
            .num_threads(NUM_THREADS)
            .batch_limit(BATCH_LIMIT)
            .build();
        run_test(&hive, NUM_THREADS, BATCH_LIMIT, true);
    }

    /// Workstealing queues: threads may steal from each other, so only
    /// require that every thread ran something.
    #[rstest]
    fn test_local_batch_workstealing() {
        const NUM_THREADS: usize = 4;
        const BATCH_LIMIT: usize = 24;
        let hive = workstealing_builder(false)
            .with_worker_default()
            .num_threads(NUM_THREADS)
            .batch_limit(BATCH_LIMIT)
            .build();
        run_test(&hive, NUM_THREADS, BATCH_LIMIT, false);
    }

    /// The batch limit can be raised and lowered on a live hive with
    /// channel queues, with exact distribution at each setting.
    #[rstest]
    fn test_set_batch_limit_channel() {
        const NUM_THREADS: usize = 4;
        const BATCH_LIMIT_0: usize = 10;
        const BATCH_LIMIT_1: usize = 50;
        const BATCH_LIMIT_2: usize = 20;
        let hive = channel_builder(false)
            .with_worker_default()
            .num_threads(NUM_THREADS)
            .batch_limit(BATCH_LIMIT_0)
            .build();
        run_test(&hive, NUM_THREADS, BATCH_LIMIT_0, true);
        hive.set_worker_batch_limit(BATCH_LIMIT_1);
        run_test(&hive, NUM_THREADS, BATCH_LIMIT_1, true);
        hive.set_worker_batch_limit(BATCH_LIMIT_2);
        run_test(&hive, NUM_THREADS, BATCH_LIMIT_2, true);
    }

    /// Same as above for workstealing queues (non-exact distribution).
    #[rstest]
    fn test_set_batch_limit_workstealing() {
        const NUM_THREADS: usize = 4;
        const BATCH_LIMIT_0: usize = 10;
        const BATCH_LIMIT_1: usize = 50;
        const BATCH_LIMIT_2: usize = 20;
        let hive = workstealing_builder(false)
            .with_worker_default()
            .num_threads(NUM_THREADS)
            .batch_limit(BATCH_LIMIT_0)
            .build();
        run_test(&hive, NUM_THREADS, BATCH_LIMIT_0, false);
        hive.set_worker_batch_limit(BATCH_LIMIT_1);
        run_test(&hive, NUM_THREADS, BATCH_LIMIT_1, false);
        hive.set_worker_batch_limit(BATCH_LIMIT_2);
        run_test(&hive, NUM_THREADS, BATCH_LIMIT_2, false);
    }

    /// Shrinking the batch limit while tasks are queued must not lose
    /// tasks: every submitted task still produces an outcome.
    #[rstest]
    fn test_shrink_batch_limit() {
        const NUM_THREADS: usize = 4;
        const NUM_TASKS_PER_THREAD: usize = 125;
        const BATCH_LIMIT_0: usize = 100;
        const BATCH_LIMIT_1: usize = 10;
        let hive = channel_builder(false)
            .with_worker_default()
            .num_threads(NUM_THREADS)
            .batch_limit(BATCH_LIMIT_0)
            .build();
        let (tx, rx) = crate::hive::outcome_channel();
        let barrier = IndexedBarrier::new(NUM_THREADS);
        let task_ids = launch_tasks(&hive, NUM_THREADS, NUM_TASKS_PER_THREAD, &barrier, &tx);
        let total_tasks = NUM_THREADS * NUM_TASKS_PER_THREAD;
        assert_eq!(task_ids.len(), total_tasks);
        barrier.wait();
        // Shrink the limit while workers still hold large local batches.
        hive.set_worker_batch_limit(BATCH_LIMIT_1);
        hive.join();
        let thread_counts = count_thread_ids(rx, task_ids);
        // Each thread processed more than the original batch limit of
        // tasks, and no tasks were dropped by the resize.
        assert!(thread_counts.values().all(|count| *count > BATCH_LIMIT_0));
        assert_eq!(thread_counts.values().sum::<usize>(), total_tasks);
    }

    // TODO(review): stub with no body or assertions — presumably intended
    // to cover changing the channel batch limit while the queue is
    // non-empty; implement or remove.
    #[test]
    fn test_change_channel_batch_limit_nonempty() {}
}
#[cfg(all(test, feature = "local-batch"))]
mod weighted_map_tests {
    //! Tests for `map`-style submission of weighted tasks.
    use crate::bee::stock::{RetryCaller, Thunk, ThunkWorker};
    use crate::bee::{ApplyError, Context};
    use crate::hive::{
        Builder, Outcome, OutcomeIteratorExt, TaskQueuesBuilder, Weighted, WeightedIteratorExt,
        channel_builder, workstealing_builder,
    };
    use rstest::*;
    use std::collections::HashMap;
    use std::thread;
    use std::time::Duration;

    /// `map` over zero-weighted tasks yields outputs in submission order
    /// even though the tasks complete in reverse order.
    #[rstest]
    fn test_map_weighted<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
    where
        B: TaskQueuesBuilder,
        F: Fn(bool) -> B,
    {
        const NUM_THREADS: usize = 4;
        const BATCH_LIMIT: usize = 24;
        let hive = builder_factory(false)
            .with_worker_default::<ThunkWorker<u8>>()
            .num_threads(NUM_THREADS)
            .batch_limit(BATCH_LIMIT)
            .build();
        // Earlier tasks sleep longer, so completion order is reversed.
        let inputs = (0..10u8)
            .map(|i| {
                Thunk::from(move || {
                    thread::sleep(Duration::from_millis((10 - i as u64) * 100));
                    i
                })
            })
            .map(|thunk| (thunk, 0))
            .into_weighted();
        let outputs: Vec<_> = hive.map(inputs).map(Outcome::unwrap).collect();
        assert_eq!(outputs, (0..10).collect::<Vec<_>>())
    }

    /// With a per-worker weight limit equal to `NUM_TASKS_PER_THREAD`
    /// tasks' worth of weight, the work should end up spread evenly over
    /// all threads once the hive is grown.
    #[rstest]
    fn test_map_weighted_with_limit<B, F>(
        #[values(channel_builder, workstealing_builder)] builder_factory: F,
    ) where
        B: TaskQueuesBuilder,
        F: Fn(bool) -> B,
    {
        const NUM_THREADS: usize = 4;
        const NUM_TASKS_PER_THREAD: usize = 3;
        const NUM_TASKS: usize = NUM_THREADS * NUM_TASKS_PER_THREAD;
        const BATCH_LIMIT: usize = 10;
        const WEIGHT: u32 = 25;
        // Each worker may hold at most NUM_TASKS_PER_THREAD tasks of weight WEIGHT.
        const WEIGHT_LIMIT: u64 = WEIGHT as u64 * NUM_TASKS_PER_THREAD as u64;
        // The worker returns its input plus the index of the thread that ran it.
        let hive = builder_factory(false)
            .with_worker(RetryCaller::from(
                |i: u8, ctx: &Context<u8>| -> Result<(u8, Option<usize>), ApplyError<u8, ()>> {
                    thread::sleep(Duration::from_millis(500));
                    Ok((i, ctx.thread_index()))
                },
            ))
            .batch_limit(BATCH_LIMIT)
            .weight_limit(WEIGHT_LIMIT)
            .build();
        let inputs = (0..NUM_TASKS as u8).map(|i| (i, WEIGHT)).into_weighted();
        let (tx, rx) = crate::hive::outcome_channel();
        let task_ids = hive.map_send(inputs, tx);
        // No `num_threads` was configured above, so the worker threads are
        // added via `grow` only after all tasks have been submitted.
        thread::sleep(Duration::from_secs(1));
        assert_eq!(hive.grow(NUM_THREADS).unwrap(), NUM_THREADS);
        hive.join();
        let (mut outputs, thread_indices) = rx
            .into_iter()
            .select_unordered_outputs(task_ids)
            .unzip::<_, _, Vec<_>, Vec<_>>();
        outputs.sort();
        assert_eq!(outputs, (0..NUM_TASKS as u8).collect::<Vec<_>>());
        // Count how many tasks each worker thread executed.
        let counts =
            thread_indices
                .into_iter()
                .flatten()
                .fold(HashMap::new(), |mut counts, index| {
                    counts
                        .entry(index)
                        .and_modify(|count| *count += 1)
                        .or_insert(1);
                    counts
                });
        assert!(counts.values().all(|&count| count == NUM_TASKS_PER_THREAD));
    }

    /// A task whose weight exceeds the configured limit is rejected with
    /// `Outcome::WeightLimitExceeded` rather than executed.
    #[rstest]
    fn test_overweight() {
        const WEIGHT_LIMIT: u64 = 99;
        let hive = channel_builder(false)
            .with_worker_default::<ThunkWorker<u8>>()
            .num_threads(1)
            .weight_limit(WEIGHT_LIMIT)
            .build();
        let outcome = hive.apply(Weighted::new(Thunk::from(|| 0), 100));
        assert!(matches!(
            outcome,
            Outcome::WeightLimitExceeded { weight: 100, .. }
        ))
    }

    /// Raising the weight limit on a live hive allows a previously
    /// overweight task to run successfully.
    #[rstest]
    fn test_set_weight_limit() {
        const WEIGHT_LIMIT: u64 = 99;
        let hive = channel_builder(false)
            .with_worker_default::<ThunkWorker<u8>>()
            .num_threads(1)
            .weight_limit(WEIGHT_LIMIT)
            .build();
        assert_eq!(WEIGHT_LIMIT, hive.worker_weight_limit());
        // WEIGHT_LIMIT + 1 == 100 exceeds the current limit.
        let outcome = hive.apply(Weighted::new(Thunk::from(|| 0), WEIGHT_LIMIT + 1));
        assert!(matches!(
            outcome,
            Outcome::WeightLimitExceeded { weight: 100, .. }
        ));
        hive.set_worker_weight_limit(WEIGHT_LIMIT + 1);
        assert_eq!(WEIGHT_LIMIT + 1, hive.worker_weight_limit());
        // The same weight is now within the limit.
        let outcome = hive.apply(Weighted::new(Thunk::from(|| 0), WEIGHT_LIMIT + 1));
        assert!(matches!(outcome, Outcome::Success { .. }));
    }
}
#[cfg(all(test, feature = "local-batch"))]
mod weighted_swarm_tests {
    //! Tests for `swarm`-style submission of weighted tasks.
    use crate::bee::stock::{EchoWorker, Thunk, ThunkWorker};
    use crate::hive::{
        Builder, Outcome, TaskQueuesBuilder, WeightedIteratorExt, channel_builder,
        workstealing_builder,
    };
    use rstest::*;
    use std::thread;
    use std::time::Duration;

    /// `swarm` over explicitly weighted `(thunk, weight)` pairs yields
    /// outputs in submission order even though tasks finish in reverse.
    #[rstest]
    fn test_swarm_weighted<B, F>(
        #[values(channel_builder, workstealing_builder)] builder_factory: F,
    ) where
        B: TaskQueuesBuilder,
        F: Fn(bool) -> B,
    {
        const NUM_THREADS: usize = 4;
        const BATCH_LIMIT: usize = 24;
        let hive = builder_factory(false)
            .with_worker_default::<ThunkWorker<u8>>()
            .num_threads(NUM_THREADS)
            .batch_limit(BATCH_LIMIT)
            .build();
        // Earlier tasks sleep longer, reversing completion order.
        let tasks = (0..10u8).map(|i| {
            let thunk = Thunk::from(move || {
                thread::sleep(Duration::from_millis((10 - i as u64) * 100));
                i
            });
            (thunk, 0)
        });
        let outputs: Vec<_> = hive
            .swarm(tasks.into_weighted_exact())
            .map(Outcome::unwrap)
            .collect();
        let expected: Vec<u8> = (0..10).collect();
        assert_eq!(outputs, expected)
    }

    /// Same scenario, with weights supplied by `into_default_weighted_exact`.
    #[rstest]
    fn test_swarm_default_weighted<B, F>(
        #[values(channel_builder, workstealing_builder)] builder_factory: F,
    ) where
        B: TaskQueuesBuilder,
        F: Fn(bool) -> B,
    {
        const NUM_THREADS: usize = 4;
        const BATCH_LIMIT: usize = 24;
        let hive = builder_factory(false)
            .with_worker_default::<ThunkWorker<u8>>()
            .num_threads(NUM_THREADS)
            .batch_limit(BATCH_LIMIT)
            .build();
        let tasks = (0..10u8).map(|i| {
            Thunk::from(move || {
                thread::sleep(Duration::from_millis((10 - i as u64) * 100));
                i
            })
        });
        let outputs: Vec<_> = hive
            .swarm(tasks.into_default_weighted_exact())
            .map(Outcome::unwrap)
            .collect();
        let expected: Vec<u8> = (0..10).collect();
        assert_eq!(outputs, expected)
    }

    /// Same scenario, with a single constant weight of 0 for every task.
    #[rstest]
    fn test_swarm_const_weighted<B, F>(
        #[values(channel_builder, workstealing_builder)] builder_factory: F,
    ) where
        B: TaskQueuesBuilder,
        F: Fn(bool) -> B,
    {
        const NUM_THREADS: usize = 4;
        const BATCH_LIMIT: usize = 24;
        let hive = builder_factory(false)
            .with_worker_default::<ThunkWorker<u8>>()
            .num_threads(NUM_THREADS)
            .batch_limit(BATCH_LIMIT)
            .build();
        let tasks = (0..10u8).map(|i| {
            Thunk::from(move || {
                thread::sleep(Duration::from_millis((10 - i as u64) * 100));
                i
            })
        });
        let outputs: Vec<_> = hive
            .swarm(tasks.into_const_weighted_exact(0))
            .map(Outcome::unwrap)
            .collect();
        let expected: Vec<u8> = (0..10).collect();
        assert_eq!(outputs, expected)
    }

    /// Echoed inputs weighted via `into_identity_weighted_exact` still come
    /// back in submission order.
    #[rstest]
    fn test_swarm_identity_weighted<B, F>(
        #[values(channel_builder, workstealing_builder)] builder_factory: F,
    ) where
        B: TaskQueuesBuilder,
        F: Fn(bool) -> B,
    {
        const NUM_THREADS: usize = 4;
        const BATCH_LIMIT: usize = 24;
        let hive = builder_factory(false)
            .with_worker_default::<EchoWorker<u8>>()
            .num_threads(NUM_THREADS)
            .batch_limit(BATCH_LIMIT)
            .build();
        let outputs: Vec<_> = hive
            .swarm((0..10u8).into_identity_weighted_exact())
            .map(Outcome::unwrap)
            .collect();
        let expected: Vec<u8> = (0..10).collect();
        assert_eq!(outputs, expected)
    }
}
#[cfg(all(test, feature = "retry"))]
mod retry_tests {
    //! Tests for automatic retrying of failed tasks (the `retry` feature).
    use crate::bee::stock::RetryCaller;
    use crate::bee::{ApplyError, Context};
    use crate::hive::{
        Builder, Outcome, OutcomeIteratorExt, TaskQueuesBuilder, channel_builder,
        workstealing_builder,
    };
    use rstest::*;
    use std::time::{Duration, SystemTime};

    /// Returns a retryable error on every attempt before the third, then
    /// succeeds.
    fn echo_time(i: usize, ctx: &Context<usize>) -> Result<String, ApplyError<usize, String>> {
        let attempt = ctx.attempt();
        if attempt == 3 {
            return Ok("Success".into());
        }
        eprintln!("Task {} attempt {}: {:?}", i, attempt, SystemTime::now());
        Err(ApplyError::Retryable {
            input: i,
            error: "Retryable".into(),
        })
    }

    /// With up to 3 retries allowed, every `echo_time` task eventually
    /// succeeds.
    #[rstest]
    fn test_retries<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
    where
        B: TaskQueuesBuilder,
        F: Fn(bool) -> B,
    {
        let hive = builder_factory(false)
            .with_worker(RetryCaller::from(echo_time))
            .with_thread_per_core()
            .max_retries(3)
            .retry_factor(Duration::from_secs(1))
            .build();
        let results: Result<Vec<_>, _> = hive.swarm(0..10usize).into_results().collect();
        assert_eq!(results.unwrap().len(), 10);
    }

    /// Mixed outcomes: inputs with `i % 3 == 0` succeed, `i % 3 == 1` keep
    /// failing retryably until retries are exhausted, and `i % 3 == 2`
    /// fail fatally without being retried.
    #[rstest]
    fn test_retries_fail<B, F>(#[values(channel_builder, workstealing_builder)] builder_factory: F)
    where
        B: TaskQueuesBuilder,
        F: Fn(bool) -> B,
    {
        fn sometimes_fail(
            i: usize,
            _: &Context<usize>,
        ) -> Result<String, ApplyError<usize, String>> {
            match i % 3 {
                0 => Ok("Success".into()),
                1 => Err(ApplyError::Retryable {
                    input: i,
                    error: "Retryable".into(),
                }),
                2 => Err(ApplyError::Fatal {
                    input: Some(i),
                    error: "Fatal".into(),
                }),
                _ => unreachable!(),
            }
        }
        let hive = builder_factory(false)
            .with_worker(RetryCaller::from(sometimes_fail))
            .with_thread_per_core()
            .max_retries(3)
            .build();
        let mut success = 0;
        let mut retry_failed = 0;
        let mut not_retried = 0;
        for outcome in hive.swarm(0..10usize) {
            match outcome {
                Outcome::Success { .. } => success += 1,
                Outcome::MaxRetriesAttempted { .. } => retry_failed += 1,
                Outcome::Failure { .. } => not_retried += 1,
                _ => unreachable!(),
            }
        }
        // Of the inputs 0..10: four are 0 mod 3, three are 1, three are 2.
        assert_eq!(success, 4);
        assert_eq!(retry_failed, 3);
        assert_eq!(not_retried, 3);
    }

    /// With retries disabled, `echo_time` never reaches its succeeding
    /// third attempt, so collecting the results fails.
    #[rstest]
    fn test_disable_retries<B, F>(
        #[values(channel_builder, workstealing_builder)] builder_factory: F,
    ) where
        B: TaskQueuesBuilder,
        F: Fn(bool) -> B,
    {
        let hive = builder_factory(false)
            .with_worker(RetryCaller::from(echo_time))
            .with_thread_per_core()
            .with_no_retries()
            .build();
        let results: Result<Vec<_>, _> = hive.swarm(0..10usize).into_results().collect();
        assert!(results.is_err());
    }

    /// Retry settings can be changed on a live hive: enabling retries
    /// turns previously failing tasks into successes.
    #[rstest]
    fn test_change_retry_limit<B, F>(
        #[values(channel_builder, workstealing_builder)] builder_factory: F,
    ) where
        B: TaskQueuesBuilder,
        F: Fn(bool) -> B,
    {
        let hive = builder_factory(false)
            .with_worker(RetryCaller::from(echo_time))
            .with_thread_per_core()
            .with_no_retries()
            .build();
        assert_eq!(hive.worker_retry_limit(), 0);
        assert_eq!(hive.worker_retry_factor(), Duration::from_secs(0));
        let before: Result<Vec<_>, _> = hive.swarm(0..10usize).into_results().collect();
        assert!(before.is_err());
        hive.set_worker_retry_limit(3);
        hive.set_worker_retry_factor(Duration::from_secs(1));
        assert_eq!(hive.worker_retry_limit(), 3);
        assert_eq!(hive.worker_retry_factor(), Duration::from_secs(1));
        let after: Result<Vec<_>, _> = hive.swarm(0..10usize).into_results().collect();
        assert_eq!(after.unwrap().len(), 10);
    }
}