use core::ops::Range;

/// Dictates how a parallel operation chunks a large number of items into
/// batches for iteration across threads.
#[derive(Clone, Debug)]
pub struct BatchingStrategy {
    /// The lower and upper limits for the size of a batch.
    /// Setting both bounds to the same value fixes the batch size.
    /// Defaults to `1..usize::MAX`.
    pub batch_size_limits: Range<usize>,
    /// The number of batches to create per thread. Defaults to 1.
    pub batches_per_thread: usize,
}

impl Default for BatchingStrategy {
    fn default() -> Self {
        Self::new()
    }
}

impl BatchingStrategy {
    /// Creates a new unconstrained strategy: batch sizes between 1 and
    /// `usize::MAX`, with one batch per thread.
    pub const fn new() -> Self {
        Self {
            batch_size_limits: 1..usize::MAX,
            batches_per_thread: 1,
        }
    }

    /// Declares a strategy with a fixed batch size.
    ///
    /// The resulting limits form an empty range, so [`Self::calc_batch_size`]
    /// always returns `batch_size` regardless of the item and thread counts.
    pub const fn fixed(batch_size: usize) -> Self {
        Self {
            batch_size_limits: batch_size..batch_size,
            batches_per_thread: 1,
        }
    }

    /// Configures the minimum allowed batch size of this instance.
    pub const fn min_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size_limits.start = batch_size;
        self
    }

    /// Configures the maximum allowed batch size of this instance.
    pub const fn max_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size_limits.end = batch_size;
        self
    }

    /// Configures the number of batches to assign to each thread.
    ///
    /// # Panics
    ///
    /// Panics if `batches_per_thread` is zero.
    pub fn batches_per_thread(mut self, batches_per_thread: usize) -> Self {
        assert!(
            batches_per_thread > 0,
            "The number of batches per thread must be non-zero."
        );
        self.batches_per_thread = batches_per_thread;
        self
    }

    /// Calculates the batch size for the given item and thread counts.
    ///
    /// The item count is passed as a closure so it is only evaluated when the
    /// batch size actually depends on it (i.e. the limits are not fixed).
    ///
    /// # Panics
    ///
    /// Panics if `thread_count` is zero.
    #[inline]
    pub fn calc_batch_size(&self, max_items: impl FnOnce() -> usize, thread_count: usize) -> usize {
        // An empty range means the batch size is fixed to `start`.
        if self.batch_size_limits.is_empty() {
            return self.batch_size_limits.start;
        }
        assert!(
            thread_count > 0,
            "Attempted to run parallel iteration with an empty TaskPool"
        );
        let batches = thread_count * self.batches_per_thread;
        // Divide the items evenly across batches, rounding up, then clamp the
        // result to the configured limits.
        let batch_size = max_items().div_ceil(batches);
        batch_size.clamp(self.batch_size_limits.start, self.batch_size_limits.end)
    }
}
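
// Illustrative tests, not part of the original module: a minimal sketch of how
// the strategies above behave, with expected values derived from the division
// and clamping logic in `calc_batch_size`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fixed_strategy_ignores_item_and_thread_counts() {
        // `fixed(64)` produces the empty range `64..64`, so the fixed size is returned.
        let strategy = BatchingStrategy::fixed(64);
        assert_eq!(strategy.calc_batch_size(|| 10_000, 8), 64);
    }

    #[test]
    fn default_strategy_divides_items_across_threads() {
        // 100 items over 4 threads * 1 batch per thread = 25 items per batch.
        let strategy = BatchingStrategy::new();
        assert_eq!(strategy.calc_batch_size(|| 100, 4), 25);
    }

    #[test]
    fn limits_clamp_the_computed_batch_size() {
        // 10 items over 4 threads rounds up to 3, but the minimum of 8 wins.
        let strategy = BatchingStrategy::new().min_batch_size(8).max_batch_size(32);
        assert_eq!(strategy.calc_batch_size(|| 10, 4), 8);
    }
}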