// Tuning knobs for `row_reduction_chunk_rows`.

/// Amount of work (in the same unit as `row_work_units`) aimed for per task;
/// divided by the per-row cost to estimate how many rows a task should carry.
const TARGET_WORK_PER_TASK: usize = 16_000_000;
/// Lower clamp applied to the work-derived rows-per-task estimate.
const MIN_ROWS_PER_TASK: usize = 512;
/// Upper clamp applied to the work-derived rows-per-task estimate.
const MAX_ROWS_PER_TASK: usize = 16_384;
/// Multiplier on the worker-thread count that caps how many tasks are created.
const MAX_TASKS_PER_WORKER: usize = 4;
/// Chooses how many rows each parallel task should process in a row reduction.
///
/// Returns `None` when the work should stay sequential, and `Some(chunk_rows)`
/// (always at least 1) when splitting into more than one task is worthwhile.
///
/// # Arguments
///
/// * `n_rows` - number of rows to reduce.
/// * `row_work_units` - cost estimate for a single row; expressed in the same
///   unit as `min_parallel_work` and `TARGET_WORK_PER_TASK`.
/// * `reduction_cells` - forwarded to `reduction_task_cap` to limit the task
///   count (presumably the size of each task's partial result; confirm with
///   callers).
/// * `min_parallel_work` - total work (`n_rows * row_work_units`, saturating)
///   below which no parallel plan is produced.
///
/// # Returns
///
/// `None` if any of the following holds: there are no rows, a row costs
/// nothing, the current rayon pool has at most one thread, the total work is
/// under `min_parallel_work`, or the heuristics end up with a single task.
/// Otherwise the rows are divided evenly (rounding up) over the chosen number
/// of tasks. Because of that even split, the returned value may be smaller
/// than `MIN_ROWS_PER_TASK`.
pub fn row_reduction_chunk_rows(
    n_rows: usize,
    row_work_units: usize,
    reduction_cells: usize,
    min_parallel_work: usize,
) -> Option<usize> {
    // Degenerate inputs are rejected before the thread pool is queried.
    if n_rows == 0 || row_work_units == 0 {
        return None;
    }

    let pool_threads = rayon::current_num_threads();
    if pool_threads <= 1 {
        return None;
    }
    if n_rows.saturating_mul(row_work_units) < min_parallel_work {
        return None;
    }

    // Rows one task needs to reach the work target, kept within the row bounds.
    // `row_work_units` is known to be non-zero here, so the division is safe.
    let rows_for_target = TARGET_WORK_PER_TASK.div_ceil(row_work_units);
    let rows_floor = rows_for_target.clamp(MIN_ROWS_PER_TASK, MAX_ROWS_PER_TASK);

    // The task count wanted by the work estimate, limited by the pool size and
    // by the cap derived from `reduction_cells`.
    let wanted_tasks = n_rows.div_ceil(rows_floor);
    let worker_limit = pool_threads.saturating_mul(MAX_TASKS_PER_WORKER);
    let reduction_limit = reduction_task_cap(reduction_cells);
    let task_count = wanted_tasks.min(worker_limit).min(reduction_limit);

    // A single task (or none) means there is nothing to parallelize.
    (task_count > 1).then(|| n_rows.div_ceil(task_count))
}
/// Returns how many chunks of `chunk_rows` rows are needed to cover `n_rows`.
///
/// The result is the ceiling of `n_rows / chunk_rows`. An empty input yields
/// `0`, and a `chunk_rows` of `0` is treated as `1` so the call never divides
/// by zero.
pub fn row_reduction_chunk_count(n_rows: usize, chunk_rows: usize) -> usize {
    let rows_per_chunk = chunk_rows.max(1);
    // Ceiling division of zero rows is already zero, so no special case is needed.
    n_rows.div_ceil(rows_per_chunk)
}
/// Maps a reduction size to the maximum number of tasks allowed.
///
/// `reduction_cells` is converted to bytes by multiplying (saturating) with
/// the size of an `f64`. The larger that footprint, the fewer tasks are
/// permitted:
///
/// | footprint        | cap          |
/// |------------------|--------------|
/// | up to 64 KiB     | `usize::MAX` |
/// | up to 1 MiB      | 128          |
/// | up to 8 MiB      | 32           |
/// | anything larger  | 8            |
fn reduction_task_cap(reduction_cells: usize) -> usize {
    const KIB: usize = 1024;
    const MIB: usize = 1024 * KIB;

    let footprint = reduction_cells.saturating_mul(std::mem::size_of::<f64>());
    match footprint {
        b if b <= 64 * KIB => usize::MAX,
        b if b <= MIB => 128,
        b if b <= 8 * MIB => 32,
        _ => 8,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- row_reduction_chunk_rows: cases that must stay sequential ---

    /// No rows means no parallel plan.
    #[test]
    fn chunk_rows_zero_n_rows_returns_none() {
        let plan = row_reduction_chunk_rows(0, 100, 1, 1);
        assert!(plan.is_none());
    }

    /// Rows that cost nothing are not worth splitting.
    #[test]
    fn chunk_rows_zero_work_units_returns_none() {
        let plan = row_reduction_chunk_rows(1000, 0, 1, 1);
        assert!(plan.is_none());
    }

    /// Total work of 10 * 5 is far below the 10_000 threshold.
    #[test]
    fn chunk_rows_below_min_parallel_work_returns_none() {
        let plan = row_reduction_chunk_rows(10, 5, 1, 10_000);
        assert!(plan.is_none());
    }

    // --- row_reduction_chunk_count ---

    /// An empty input needs no chunks at all.
    #[test]
    fn chunk_count_zero_rows_is_zero() {
        let chunks = row_reduction_chunk_count(0, 100);
        assert_eq!(chunks, 0);
    }

    /// Rows divide evenly into chunks.
    #[test]
    fn chunk_count_exact_division() {
        let chunks = row_reduction_chunk_count(9, 3);
        assert_eq!(chunks, 3);
    }

    /// A remainder adds one extra, partially filled chunk.
    #[test]
    fn chunk_count_ceiling_division() {
        let chunks = row_reduction_chunk_count(10, 3);
        assert_eq!(chunks, 4);
    }

    /// A chunk size of zero behaves like a chunk size of one.
    #[test]
    fn chunk_count_zero_chunk_size_treated_as_one() {
        let chunks = row_reduction_chunk_count(7, 0);
        assert_eq!(chunks, 7);
    }

    /// A chunk exactly as large as the input gives a single chunk.
    #[test]
    fn chunk_count_chunk_equals_n_rows() {
        let chunks = row_reduction_chunk_count(5, 5);
        assert_eq!(chunks, 1);
    }

    /// An oversized chunk still gives a single chunk.
    #[test]
    fn chunk_count_chunk_larger_than_n_rows() {
        let chunks = row_reduction_chunk_count(3, 100);
        assert_eq!(chunks, 1);
    }
}