use super::super::WorkerConfig;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
/// Schedules parallel work across a pool of scoped worker threads.
///
/// Workers pull chunks of items from a shared atomic counter, so faster
/// workers naturally pick up more chunks (a simple work-stealing style
/// load balancer).
#[derive(Debug, Clone)]
pub struct WorkStealingScheduler {
    /// Number of worker threads spawned per parallel operation.
    pub(crate) num_workers: usize,
    /// Base chunk size, used directly when adaptive chunking is disabled.
    chunksize: usize,
    /// When true, the chunk size is derived from the workload size instead
    /// of using `chunksize` verbatim.
    adaptive_chunking: bool,
}
impl WorkStealingScheduler {
    /// Creates a scheduler sized from `config`.
    ///
    /// The worker count falls back to the machine's available parallelism
    /// (or 4 if that cannot be determined) and is clamped to at least 1 so
    /// a misconfigured `workers: Some(0)` cannot cause a division by zero
    /// in `adaptive_chunksize` or a zero-thread run that silently returns
    /// default values.
    pub fn new(config: &WorkerConfig) -> Self {
        let num_workers = config
            .workers
            .unwrap_or_else(|| {
                std::thread::available_parallelism()
                    .map(|n| n.get())
                    .unwrap_or(4)
            })
            .max(1); // guard: never run with 0 workers
        Self {
            num_workers,
            chunksize: config.chunksize,
            adaptive_chunking: true,
        }
    }

    /// Enables or disables adaptive chunk sizing (builder style).
    pub fn with_adaptive_chunking(mut self, adaptive: bool) -> Self {
        self.adaptive_chunking = adaptive;
        self
    }

    /// Applies `f` to every element of `items` in parallel and returns the
    /// results in input order.
    ///
    /// Workers claim chunks of indices via a shared atomic counter; each
    /// chunk is computed locally and written back under the results mutex
    /// once per chunk (not once per item), keeping lock contention low.
    ///
    /// # Panics
    /// Panics if a worker thread panics or the results mutex is poisoned.
    pub fn execute<T, R, F>(&self, items: &[T], f: F) -> Vec<R>
    where
        T: Send + Sync,
        R: Send + Default + Clone,
        F: Fn(&T) -> R + Send + Sync,
    {
        let n = items.len();
        if n == 0 {
            return Vec::new();
        }
        // A chunk size of 0 would make `fetch_add` a no-op, so every worker
        // would see `start == 0 < n` forever; clamp to at least 1.
        let chunksize = if self.adaptive_chunking {
            self.adaptive_chunksize(n)
        } else {
            self.chunksize
        }
        .max(1);
        let work_counter = Arc::new(AtomicUsize::new(0));
        let results = Arc::new(Mutex::new(vec![R::default(); n]));
        std::thread::scope(|s| {
            let handles: Vec<_> = (0..self.num_workers)
                .map(|_| {
                    let work_counter = work_counter.clone();
                    let results = results.clone();
                    let items_ref = items;
                    let f_ref = &f;
                    s.spawn(move || loop {
                        // Claim the next chunk of indices.
                        let start = work_counter.fetch_add(chunksize, Ordering::SeqCst);
                        if start >= n {
                            break;
                        }
                        let end = std::cmp::min(start + chunksize, n);
                        // Compute the whole chunk without holding the lock,
                        // then take the mutex once to publish it.
                        let local: Vec<R> = items_ref[start..end].iter().map(f_ref).collect();
                        let mut results_guard = results.lock().expect("Operation failed");
                        results_guard[start..end].clone_from_slice(&local);
                    })
                })
                .collect();
            for handle in handles {
                handle.join().expect("Operation failed");
            }
        });
        // All worker clones were dropped when the scope ended, so this Arc
        // is unique and the mutex can be unwrapped.
        Arc::try_unwrap(results)
            .unwrap_or_else(|_| panic!("Failed to extract results"))
            .into_inner()
            .unwrap_or_else(|_| panic!("Failed to extract mutex inner value"))
    }

    /// Picks a chunk size from the total workload: small workloads get tiny
    /// chunks so they still spread across workers, large workloads get
    /// bigger chunks to amortize scheduling overhead.
    ///
    /// Always returns at least 1.
    fn adaptive_chunksize(&self, totalitems: usize) -> usize {
        // `num_workers >= 1` by construction, but divide defensively in case
        // the pub(crate) field was set directly.
        let items_per_worker = totalitems / self.num_workers.max(1);
        let chunk = if items_per_worker < 100 {
            items_per_worker / 4
        } else if items_per_worker < 1000 {
            items_per_worker / 8
        } else {
            // Cap at the configured chunk size for very large workloads.
            std::cmp::min(self.chunksize, items_per_worker / 16)
        };
        chunk.max(1)
    }

    /// Evaluates `f(i, j)` for every cell of a `rows` x `cols` matrix in
    /// parallel and returns the assembled result.
    ///
    /// The matrix is tiled into 64x64 blocks; workers claim whole blocks
    /// from a shared counter, accumulate results thread-locally, and merge
    /// into the shared vector once per thread to minimize lock traffic.
    ///
    /// # Panics
    /// Panics if a worker thread panics or the results mutex is poisoned.
    pub fn execute_matrix<R, F>(&self, rows: usize, cols: usize, f: F) -> scirs2_core::ndarray::Array2<R>
    where
        R: Send + Default + Clone,
        F: Fn(usize, usize) -> R + Send + Sync,
    {
        // Block side length; 64x64 tiles keep each work item coarse enough
        // to amortize the per-item atomic increment.
        let blocksize = 64;
        let work_items: Vec<(usize, usize)> = (0..rows)
            .step_by(blocksize)
            .flat_map(|i| (0..cols).step_by(blocksize).map(move |j| (i, j)))
            .collect();
        let work_counter = Arc::new(AtomicUsize::new(0));
        let results_vec = Arc::new(Mutex::new(Vec::new()));
        std::thread::scope(|s| {
            let handles: Vec<_> = (0..self.num_workers)
                .map(|_| {
                    let work_counter = work_counter.clone();
                    let results_vec = results_vec.clone();
                    let work_items_ref = &work_items;
                    let f_ref = &f;
                    s.spawn(move || {
                        let mut local_results = Vec::new();
                        loop {
                            // Claim the next block.
                            let idx = work_counter.fetch_add(1, Ordering::SeqCst);
                            if idx >= work_items_ref.len() {
                                break;
                            }
                            let (block_i, block_j) = work_items_ref[idx];
                            // Clip the block against the matrix edges.
                            let i_end = std::cmp::min(block_i + blocksize, rows);
                            let j_end = std::cmp::min(block_j + blocksize, cols);
                            for i in block_i..i_end {
                                for j in block_j..j_end {
                                    local_results.push((i, j, f_ref(i, j)));
                                }
                            }
                        }
                        // Merge into the shared vector once per thread.
                        if !local_results.is_empty() {
                            let mut global_results = results_vec.lock().expect("Operation failed");
                            global_results.extend(local_results);
                        }
                    })
                })
                .collect();
            for handle in handles {
                handle.join().expect("Operation failed");
            }
        });
        let mut result = scirs2_core::ndarray::Array2::default((rows, cols));
        let results = Arc::try_unwrap(results_vec)
            .unwrap_or_else(|_| panic!("Failed to extract results"))
            .into_inner()
            .unwrap_or_else(|_| panic!("Failed to extract mutex inner value"));
        // Scatter the (i, j, value) triples into the dense matrix.
        for (i, j, val) in results {
            result[[i, j]] = val;
        }
        result
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Doubling each element must preserve the input ordering.
    #[test]
    fn test_work_stealing_scheduler() {
        let scheduler = WorkStealingScheduler::new(&WorkerConfig::default());
        let input = vec![1, 2, 3, 4, 5];
        let doubled = scheduler.execute(&input, |v| v * 2);
        assert_eq!(doubled, vec![2, 4, 6, 8, 10]);
    }

    /// Adaptive chunk sizes are positive, and large workloads get chunks
    /// well above the minimum.
    #[test]
    fn test_adaptive_chunking() {
        let scheduler =
            WorkStealingScheduler::new(&WorkerConfig::default()).with_adaptive_chunking(true);
        assert!(scheduler.adaptive_chunksize(10) >= 1);
        assert!(scheduler.adaptive_chunksize(10000) > 10);
    }

    /// A 3x3 matrix of `i + j` has shape [3, 3] and diagonal values `2 * d`.
    #[test]
    fn test_matrix_execution() {
        let scheduler = WorkStealingScheduler::new(&WorkerConfig::default());
        let matrix = scheduler.execute_matrix(3, 3, |row, col| row + col);
        assert_eq!(matrix.shape(), &[3, 3]);
        for d in 0..3 {
            assert_eq!(matrix[[d, d]], 2 * d);
        }
    }
}