use scirs2_core::parallel_ops::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
/// Process-wide parallelization threshold shared by every caller in this module.
///
/// Starts at 1000 and is accessed through [`set_global_threshold`] and
/// [`get_global_threshold`].
static GLOBAL_THRESHOLD: AtomicUsize = AtomicUsize::new(1000);
/// Strategy used by [`get_optimal_threshold`] to pick a parallelization threshold.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ParallelizationThreshold {
    /// Use the carried value verbatim as the threshold.
    Fixed(usize),
    /// Derive the threshold from the core count, array size and per-element cost
    /// (see [`adaptive_threshold`]).
    Adaptive,
    /// Measure sequential versus parallel execution of a synthetic workload
    /// (see [`benchmark_threshold`]).
    Benchmarked,
    /// Use the current process-wide threshold (see [`get_global_threshold`]).
    AutoTuning,
}
/// Replaces the process-wide parallelization threshold with `threshold`.
///
/// The value is stored exactly as given; no range validation is applied here.
pub fn set_global_threshold(threshold: usize) {
    GLOBAL_THRESHOLD.store(threshold, Ordering::SeqCst);
}
/// Returns the current process-wide parallelization threshold.
///
/// This is 1000 until [`set_global_threshold`] or [`auto_tune_threshold`] changes it.
pub fn get_global_threshold() -> usize {
    GLOBAL_THRESHOLD.load(Ordering::SeqCst)
}
/// Computes a parallelization threshold from the machine's core count, the input
/// size and a relative per-element cost.
///
/// # Arguments
/// * `array_size` - Number of elements in the input; inputs above 100 000 and
///   1 000 000 elements scale the threshold by 0.9 and 0.8 respectively.
/// * `element_cost` - Relative cost of processing one element. `5.0` is the
///   neutral point: the base threshold is multiplied by `5.0 / element_cost`,
///   limited to the range `[0.1, 10.0]`, so cheaper elements raise the threshold
///   and costlier ones lower it. A cost of zero yields the maximum factor of 10;
///   a NaN cost ends up at the lower bound of the result range.
///
/// # Returns
/// `usize::MAX` when only one core is reported (parallel dispatch can only add
/// overhead there); otherwise a value in `100..=50_000`.
pub fn adaptive_threshold(array_size: usize, element_cost: f64) -> usize {
    let core_count = num_cpus::get();

    // Return the "never parallelize" sentinel directly. Feeding it through the
    // clamp below would silently turn it into 50_000 and re-enable parallelism
    // for large inputs on a single-core machine.
    if core_count <= 1 {
        return usize::MAX;
    }

    // More cores amortize the scheduling overhead sooner, so the base drops.
    let base_threshold: f64 = match core_count {
        2 => 5000.0,
        3..=4 => 2000.0,
        5..=8 => 1000.0,
        _ => 500.0,
    };

    let cost_factor = (5.0 / element_cost).clamp(0.1, 10.0);

    let size_factor = if array_size > 1_000_000 {
        0.8
    } else if array_size > 100_000 {
        0.9
    } else {
        1.0
    };

    // The float-to-integer cast saturates (and maps NaN to 0) before clamping.
    let calculated = (base_threshold * cost_factor * size_factor) as usize;
    calculated.clamp(100, 50_000)
}
/// Estimates a parallelization threshold by timing `element_cost` sequentially
/// and in parallel over a fixed ladder of input sizes.
///
/// For each candidate size, in ascending order, the closure is evaluated once
/// for every index in `0..size` through a sequential iterator and once through
/// a parallel iterator, and the two wall-clock durations are compared.
///
/// # Arguments
/// * `element_cost` - Workload to measure; called with each element index and
///   expected to return a value that is summed.
///
/// # Returns
/// The smallest candidate size at which the parallel run was strictly faster
/// than the sequential run, or the largest candidate (20 000) when the
/// sequential run was never beaten.
///
/// Each size is timed a single time, so the result is sensitive to timing noise
/// and can differ between calls.
pub fn benchmark_threshold<F>(element_cost: F) -> usize
where
    F: Fn(usize) -> f64 + Sync + Send,
{
    const CANDIDATE_SIZES: [usize; 7] = [100, 500, 1000, 2000, 5000, 10000, 20000];

    // Times one sequential and one parallel pass over `size` indices.
    let time_both = |size: usize| -> (Duration, Duration) {
        let data: Vec<usize> = (0..size).collect();

        // `black_box` keeps the otherwise-unused sums alive so an optimized
        // build cannot discard the work being measured.
        let start = Instant::now();
        std::hint::black_box(data.iter().map(|&i| element_cost(i)).sum::<f64>());
        let sequential_time = start.elapsed();

        let start = Instant::now();
        std::hint::black_box(data.par_iter().map(|&i| element_cost(i)).sum::<f64>());
        let parallel_time = start.elapsed();

        (sequential_time, parallel_time)
    };

    // Report the crossover size itself: the first size where parallel wins.
    // `find` stops at the first match, so larger sizes are not benchmarked.
    CANDIDATE_SIZES
        .iter()
        .copied()
        .find(|&size| {
            let (sequential_time, parallel_time) = time_both(size);
            parallel_time < sequential_time
        })
        .unwrap_or(CANDIDATE_SIZES[CANDIDATE_SIZES.len() - 1])
}
/// Resolves a [`ParallelizationThreshold`] strategy to a concrete threshold.
///
/// # Arguments
/// * `threshold_type` - Strategy to apply.
/// * `array_size` - Input size; only the `Adaptive` strategy reads it.
/// * `element_cost` - Relative per-element cost; read by `Adaptive`, and used by
///   `Benchmarked` as the iteration count of its synthetic workload.
///
/// # Returns
/// * `Fixed(v)` - `v` unchanged.
/// * `Adaptive` - the result of [`adaptive_threshold`].
/// * `Benchmarked` - the result of [`benchmark_threshold`] on a synthetic
///   `sin`/`cos` workload.
/// * `AutoTuning` - the current value of [`get_global_threshold`].
pub fn get_optimal_threshold(
    threshold_type: ParallelizationThreshold,
    array_size: usize,
    element_cost: f64,
) -> usize {
    match threshold_type {
        ParallelizationThreshold::Fixed(value) => value,
        ParallelizationThreshold::Adaptive => adaptive_threshold(array_size, element_cost),
        ParallelizationThreshold::Benchmarked => {
            // At least one sin/cos round per element, even for costs below 1.
            let rounds = (element_cost as usize).max(1);
            benchmark_threshold(move |index: usize| {
                (0..rounds).fold(index as f64, |value, _| value.sin().cos())
            })
        }
        ParallelizationThreshold::AutoTuning => get_global_threshold(),
    }
}
/// Nudges the process-wide threshold based on one observed workload and returns
/// the new value.
///
/// * If `array_size` exceeds the current threshold and `element_cost` is above
///   `1.0`, the threshold is lowered by 10%.
/// * Otherwise, if `array_size` is below half the current threshold, it is
///   raised by 10%.
/// * Otherwise it is left unchanged.
///
/// The scaled value is truncated to an integer, limited to `500..=50_000`, and
/// stored globally. The read and the write are two separate atomic operations,
/// so concurrent callers can overwrite each other's adjustment.
pub fn auto_tune_threshold(array_size: usize, element_cost: f64) -> usize {
    let previous = get_global_threshold();
    let scaled_by = |factor: f64| (previous as f64 * factor) as usize;

    let proposed = if array_size > previous && element_cost > 1.0 {
        scaled_by(0.9)
    } else if array_size < previous / 2 {
        scaled_by(1.1)
    } else {
        previous
    };

    let bounded = proposed.max(500).min(50_000);
    set_global_threshold(bounded);
    bounded
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A higher per-element cost or a larger array must never raise the
    /// adaptive threshold.
    #[test]
    fn test_adaptive_threshold() {
        let cheap_elements = adaptive_threshold(10000, 1.0);
        let costly_elements = adaptive_threshold(10000, 5.0);
        assert!(costly_elements <= cheap_elements);

        let small_array = adaptive_threshold(10000, 1.0);
        let large_array = adaptive_threshold(1000000, 1.0);
        assert!(large_array <= small_array);
    }

    /// A value written to the global threshold is read back unchanged.
    #[test]
    fn test_global_threshold() {
        set_global_threshold(5000);
        assert_eq!(get_global_threshold(), 5000);
    }
}