use super::scheduling::SchedulingStrategy;
use scirs2_core::parallel_ops::*;
use std::ops::Range;
/// Strategy used by [`partition_workload`] to split `0..array_size` into ranges.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkloadPartitioning {
/// Contiguous chunks that all share one size; only the final chunk may be shorter.
EqualChunks,
/// Chunks whose size depends on their position in the sequence.
VariableChunks,
/// Chunks whose size is rounded up to a power of two; only the final chunk may be shorter.
PowerOfTwoChunks,
/// Chunks whose size is a multiple of 8 elements; only the final chunk may be shorter.
CacheOptimizedChunks,
/// Currently partitioned exactly like `VariableChunks`.
DynamicPartitioning,
}
pub fn partition_workload(
array_size: usize,
partitioning: WorkloadPartitioning,
num_threads: usize,
) -> Vec<Range<usize>> {
let n_threads = if num_threads == 0 {
scirs2_core::parallel_ops::num_threads()
} else {
num_threads
};
match partitioning {
WorkloadPartitioning::EqualChunks => equal_chunks(array_size, n_threads),
WorkloadPartitioning::VariableChunks => variable_chunks(array_size, n_threads),
WorkloadPartitioning::PowerOfTwoChunks => power_of_two_chunks(array_size, n_threads),
WorkloadPartitioning::CacheOptimizedChunks => cache_optimized_chunks(array_size, n_threads),
WorkloadPartitioning::DynamicPartitioning => dynamic_partitioning(array_size, n_threads),
}
}
/// Splits `0..array_size` into contiguous chunks that all share one size.
///
/// The number of chunks never exceeds `num_threads`, and is further limited to
/// one chunk per 100 elements so that small arrays are not over-split. Only the
/// final chunk may be shorter than the others.
///
/// # Arguments
///
/// * `array_size` - Number of elements to distribute.
/// * `num_threads` - Upper bound on the number of chunks; `0` is treated as `1`.
///
/// # Returns
///
/// Ascending, gap-free ranges covering `0..array_size`; empty when
/// `array_size` is `0`.
fn equal_chunks(array_size: usize, num_threads: usize) -> Vec<Range<usize>> {
    // An empty array would yield a chunk size of 0, and `step_by(0)` panics.
    if array_size == 0 {
        return Vec::new();
    }

    // Limit to one worker per 100 elements, but never fewer than one worker so
    // the division below is always defined (even for `num_threads == 0`).
    let actual_threads = num_threads.min(array_size / 100).max(1);
    let chunk_size = array_size.div_ceil(actual_threads);

    (0..array_size)
        .step_by(chunk_size)
        .map(|start| start..(start + chunk_size).min(array_size))
        .collect()
}
/// Splits `0..array_size` into at most `num_threads` chunks of varying size.
///
/// When there are no more elements than threads, every element becomes its own
/// one-element chunk. Otherwise slot `i` takes an even share of whatever is
/// still unassigned, scaled by `1 - (i / num_threads)^0.75` and clamped to at
/// least one element. Anything the loop leaves unassigned is appended to the
/// final chunk.
///
/// # Arguments
///
/// * `array_size` - Number of elements to distribute.
/// * `num_threads` - Maximum number of chunks to create.
///
/// # Returns
///
/// Ascending, gap-free ranges in the order they were carved out.
fn variable_chunks(array_size: usize, num_threads: usize) -> Vec<Range<usize>> {
    if array_size <= num_threads {
        return (0..array_size).map(|idx| idx..idx + 1).collect();
    }

    let mut ranges: Vec<Range<usize>> = Vec::with_capacity(num_threads);
    // Index of the first element that has not been handed out yet.
    let mut cursor = 0;

    for slot in 0..num_threads {
        let left = array_size - cursor;
        if left == 0 {
            break;
        }

        // Later slots get a progressively smaller fraction of their even share.
        let shrink = 1.0 - (slot as f64 / num_threads as f64).powf(0.75);
        let even_share = left as f64 / (num_threads - slot) as f64;
        // Truncate toward zero, then clamp into `1..=left`.
        let take = ((even_share * shrink) as usize).max(1).min(left);

        ranges.push(cursor..cursor + take);
        cursor += take;
    }

    // Fold whatever the scaled shares did not cover into the final chunk.
    if cursor < array_size {
        if let Some(tail) = ranges.last_mut() {
            tail.end = array_size;
        }
    }

    ranges
}
/// Splits `0..array_size` into chunks whose size is a power of two.
///
/// The thread count is first rounded down to a power of two (`0` counts as
/// `1`). The per-thread share for that count is then rounded up to the next
/// power of two and used as the chunk size. Only the final chunk may be
/// shorter.
///
/// # Arguments
///
/// * `array_size` - Number of elements to distribute.
/// * `num_threads` - Worker count the chunk size is derived from.
///
/// # Returns
///
/// Ascending, gap-free ranges covering `0..array_size`.
fn power_of_two_chunks(array_size: usize, num_threads: usize) -> Vec<Range<usize>> {
    // Largest power of two that does not exceed `num_threads`; 1 when it is 0.
    let workers = num_threads
        .checked_ilog2()
        .map_or(1, |exponent| 1usize << exponent);

    // Smallest power of two that is >= the per-worker share (1 for a share of 0).
    let stride = array_size.div_ceil(workers).next_power_of_two();

    (0..array_size)
        .step_by(stride)
        .map(|lo| lo..(lo + stride).min(array_size))
        .collect()
}
/// Splits `0..array_size` into chunks whose size is a multiple of 8 elements.
///
/// If the array holds more than 512 elements per thread, the chunk size is
/// fixed at 512 (8 elements x 64 lines). Otherwise the per-thread share is
/// rounded up to the next multiple of 8. Only the final chunk may be shorter.
///
/// # Arguments
///
/// * `array_size` - Number of elements to distribute.
/// * `num_threads` - Worker count the chunk size is derived from.
///
/// # Returns
///
/// Ascending, gap-free ranges covering `0..array_size`.
fn cache_optimized_chunks(array_size: usize, num_threads: usize) -> Vec<Range<usize>> {
    // NOTE(review): 8 elements per line presumably assumes 8-byte elements on
    // 64-byte cache lines; confirm against the element types actually used.
    const ELEMENTS_PER_CACHE_LINE: usize = 8;
    const CACHE_LINES_PER_CHUNK: usize = 64;
    const IDEAL_CHUNK_SIZE: usize = ELEMENTS_PER_CACHE_LINE * CACHE_LINES_PER_CHUNK;

    let plenty_of_work = array_size > IDEAL_CHUNK_SIZE * num_threads;
    let unaligned = if plenty_of_work {
        IDEAL_CHUNK_SIZE
    } else {
        array_size.div_ceil(num_threads)
    };
    let stride = unaligned.next_multiple_of(ELEMENTS_PER_CACHE_LINE);

    // `stride` is 0 only when `array_size` is 0; the range is then empty, so
    // the `.max(1)` merely satisfies `step_by`'s non-zero requirement.
    (0..array_size)
        .step_by(stride.max(1))
        .map(|lo| lo..(lo + stride).min(array_size))
        .collect()
}
/// Partitions `0..array_size` for the `DynamicPartitioning` strategy.
///
/// Currently a direct delegation to `variable_chunks`, so both strategies
/// return the same ranges for the same arguments.
fn dynamic_partitioning(array_size: usize, num_threads: usize) -> Vec<Range<usize>> {
variable_chunks(array_size, num_threads)
}
/// Partitions `0..array_size` and applies `op` to every chunk via a parallel iterator.
///
/// The chunks come from [`partition_workload`] with a thread count of `0`,
/// i.e. the worker count is auto-detected. `scheduling` only decides whether a
/// `with_min_len` hint is attached to the parallel iterator, and with what value:
///
/// * `Static`, `WorkStealing` - no hint.
/// * `Dynamic` - hint of `1`.
/// * `Guided` - hint of `chunks / 10 + 1`.
/// * `Adaptive` - no hint for fewer than 10 chunks; otherwise `1` when the
///   largest chunk is more than twice the smallest, else `chunks / 20 + 1`.
///
/// # Arguments
///
/// * `array_size` - Number of elements to distribute.
/// * `partitioning` - How chunk boundaries are chosen.
/// * `scheduling` - Which `with_min_len` hint (if any) is used.
/// * `op` - Callback invoked once per chunk with that chunk's index range.
///
/// # Returns
///
/// One result per chunk, collected into a `Vec`.
pub fn parallel_execute<F, R>(
    array_size: usize,
    partitioning: WorkloadPartitioning,
    scheduling: SchedulingStrategy,
    op: F,
) -> Vec<R>
where
    F: Fn(Range<usize>) -> R + Send + Sync,
    R: Send,
{
    let chunks = partition_workload(array_size, partitioning, 0);
    let chunk_count = chunks.len();

    // First decide on the hint; `None` means "run the plain parallel iterator".
    let min_len: Option<usize> = match scheduling {
        SchedulingStrategy::Static | SchedulingStrategy::WorkStealing => None,
        SchedulingStrategy::Dynamic => Some(1),
        SchedulingStrategy::Guided => Some(chunk_count / 10 + 1),
        SchedulingStrategy::Adaptive => {
            if chunk_count < 10 {
                None
            } else {
                let largest = chunks.iter().map(|c| c.len()).max().unwrap_or(0);
                let smallest = chunks.iter().map(|c| c.len()).min().unwrap_or(0);
                if largest > smallest * 2 {
                    // Uneven chunks: use the smallest hint.
                    Some(1)
                } else {
                    Some(chunk_count / 20 + 1)
                }
            }
        }
    };

    // `Range` is not `Copy`, so each borrowed chunk is cloned before the call.
    match min_len {
        Some(hint) => chunks
            .par_iter()
            .with_min_len(hint)
            .map(|range| op(range.clone()))
            .collect(),
        None => chunks.par_iter().map(|range| op(range.clone())).collect(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `chunks` tile `0..total` exactly: ascending, gap-free,
    /// non-overlapping, and without empty ranges.
    fn assert_tiles(chunks: &[std::ops::Range<usize>], total: usize) {
        let mut expected_start = 0;
        for chunk in chunks {
            assert_eq!(chunk.start, expected_start);
            assert!(chunk.end > chunk.start);
            expected_start = chunk.end;
        }
        assert_eq!(expected_start, total);
    }

    #[test]
    fn test_equal_chunks() {
        let chunks = equal_chunks(10, 3);
        assert!(chunks.len() <= 3);
        assert_tiles(&chunks, 10);

        // Fewer elements than threads must still cover every element.
        let chunks_small = equal_chunks(3, 10);
        assert!(chunks_small.len() <= 3);
        assert_tiles(&chunks_small, 3);
    }

    #[test]
    fn test_variable_chunks() {
        let chunks = variable_chunks(100, 4);
        assert!(chunks.len() <= 4);
        assert_tiles(&chunks, 100);
    }

    #[test]
    fn test_power_of_two_chunks() {
        let chunks = power_of_two_chunks(100, 4);
        // Every chunk except the final one must have a power-of-two length.
        if let Some((_, leading)) = chunks.split_last() {
            for chunk in leading {
                assert!(chunk.len().is_power_of_two());
            }
        }
        assert_tiles(&chunks, 100);
    }

    #[test]
    fn test_cache_optimized_chunks() {
        let chunks = cache_optimized_chunks(1000, 4);
        // Every chunk except the final one must be a whole number of 8-element lines.
        if let Some((_, leading)) = chunks.split_last() {
            for chunk in leading {
                assert_eq!(chunk.len() % 8, 0);
            }
        }
        assert_tiles(&chunks, 1000);
    }
}