use super::super::strategies::work_stealing::WorkStealingScheduler;
use super::super::WorkerConfig;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
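
/// A work-stealing scheduler that layers cache- and NUMA-aware chunking on
/// top of the base [`WorkStealingScheduler`].
///
/// Illustrative usage (module paths assumed, so the snippet is not compiled
/// as a doctest):
///
/// ```ignore
/// let scheduler = AdvancedWorkStealingScheduler::new(&WorkerConfig::default())
///     .with_numa_aware(true)
///     .with_cache_linesize(64);
/// let doubled = scheduler.execute_optimized(&[1, 2, 3], |x| x * 2);
/// assert_eq!(doubled, vec![2, 4, 6]);
/// ```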
pub struct AdvancedWorkStealingScheduler {
    /// Base scheduler that performs the actual work-stealing execution.
    base_scheduler: WorkStealingScheduler,
    /// Whether to distribute work with the NUMA topology in mind.
    numa_aware: bool,
    /// Assumed cache line size in bytes, used to size memory-bound chunks.
    cache_linesize: usize,
    /// Currently unused.
    #[allow(dead_code)]
    work_queue_per_thread: bool,
}

impl AdvancedWorkStealingScheduler {
    /// Creates a scheduler with NUMA awareness enabled and a 64-byte cache
    /// line size; both defaults can be overridden via the builder methods.
    pub fn new(config: &WorkerConfig) -> Self {
        Self {
            base_scheduler: WorkStealingScheduler::new(config),
            numa_aware: true,
            cache_linesize: 64,
            work_queue_per_thread: true,
        }
    }

    /// Enables or disables NUMA-aware execution.
    pub fn with_numa_aware(mut self, enabled: bool) -> Self {
        self.numa_aware = enabled;
        self
    }

    /// Overrides the assumed cache line size, in bytes.
    pub fn with_cache_linesize(mut self, size: usize) -> Self {
        self.cache_linesize = size;
        self
    }
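
    /// Executes `f` over `items`, picking a chunking strategy from a rough
    /// classification of the workload as memory-bound, CPU-bound, or mixed.
    ///
    /// Illustrative sketch (module paths assumed, so not compiled as a
    /// doctest):
    ///
    /// ```ignore
    /// let scheduler = AdvancedWorkStealingScheduler::new(&WorkerConfig::default());
    /// let squares = scheduler.execute_optimized(&[1, 2, 3, 4], |x| x * x);
    /// assert_eq!(squares, vec![1, 4, 9, 16]);
    /// ```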
    pub fn execute_optimized<T, R, F>(&self, items: &[T], f: F) -> Vec<R>
    where
        T: Send + Sync,
        R: Send + Default + Clone,
        F: Fn(&T) -> R + Send + Sync,
    {
        if items.is_empty() {
            return Vec::new();
        }
        let n = items.len();
        let workload_type = self.analyze_workload(n);
        let chunk_config = match workload_type {
            WorkloadType::MemoryBound => ChunkConfig {
                // Aim for one chunk per cache line; the guards keep the size
                // at least 1 for zero-sized or cache-line-sized items.
                size: (self.cache_linesize / std::mem::size_of::<T>().max(1)).max(1),
                strategy: ChunkStrategy::Sequential,
            },
            WorkloadType::CpuBound => ChunkConfig {
                // Oversubscribe by 4x for load balancing; clamp so small
                // inputs never yield a zero chunk size.
                size: (n / (self.base_scheduler.num_workers * 4)).max(1),
                strategy: ChunkStrategy::Interleaved,
            },
            WorkloadType::Mixed => ChunkConfig {
                size: self.adaptive_chunksize_enhanced(n),
                strategy: ChunkStrategy::Dynamic,
            },
        };
        self.execute_with_strategy(items, f, chunk_config)
    }

    /// Classifies a workload by its estimated memory footprint: inputs larger
    /// than a typical last-level cache are treated as memory-bound, small
    /// inputs as CPU-bound, and everything in between as mixed.
    fn analyze_workload(&self, size: usize) -> WorkloadType {
        // Rough estimate that assumes pointer-sized items.
        let memory_footprint = size * std::mem::size_of::<usize>();
        // Assume an 8 MiB last-level cache.
        let cache_size = 8 * 1024 * 1024;
        if memory_footprint > cache_size {
            WorkloadType::MemoryBound
        } else if size < 1000 {
            WorkloadType::CpuBound
        } else {
            WorkloadType::Mixed
        }
    }

    /// Picks a chunk size that balances cache friendliness against load
    /// balancing: the cache-optimal size wins when it is within 2x of the
    /// load-balancing size, otherwise load balancing wins.
    fn adaptive_chunksize_enhanced(&self, total_items: usize) -> usize {
        let num_workers = self.base_scheduler.num_workers;
        let items_per_worker = total_items / num_workers;
        let cache_optimal_size = self.cache_linesize / std::mem::size_of::<usize>();
        let load_balance_size = std::cmp::max(1, items_per_worker / 8);
        if cache_optimal_size > 0 && cache_optimal_size < load_balance_size * 2 {
            cache_optimal_size
        } else {
            load_balance_size
        }
    }

    /// Dispatches to the executor that implements the chosen strategy.
    fn execute_with_strategy<T, R, F>(&self, items: &[T], f: F, config: ChunkConfig) -> Vec<R>
    where
        T: Send + Sync,
        R: Send + Default + Clone,
        F: Fn(&T) -> R + Send + Sync,
    {
        match config.strategy {
            ChunkStrategy::Sequential => self.execute_sequential_chunks(items, f, config.size),
            ChunkStrategy::Interleaved => self.execute_interleaved_chunks(items, f, config.size),
            ChunkStrategy::Dynamic => self.execute_dynamic_chunks(items, f, config.size),
        }
    }

    /// Sequential chunking currently delegates to the base scheduler.
    fn execute_sequential_chunks<T, R, F>(&self, items: &[T], f: F, _chunksize: usize) -> Vec<R>
    where
        T: Send + Sync,
        R: Send + Default + Clone,
        F: Fn(&T) -> R + Send + Sync,
    {
        self.base_scheduler.execute(items, f)
    }

    /// Hands out fixed-size chunks through a shared atomic counter, so
    /// consecutive chunks are interleaved across whichever workers are free.
    /// Each worker computes a chunk into a local buffer and takes the results
    /// lock once per chunk rather than once per item.
    fn execute_interleaved_chunks<T, R, F>(&self, items: &[T], f: F, chunksize: usize) -> Vec<R>
    where
        T: Send + Sync,
        R: Send + Default + Clone,
        F: Fn(&T) -> R + Send + Sync,
    {
        let n = items.len();
        let chunksize = chunksize.max(1);
        let results = Arc::new(Mutex::new(vec![R::default(); n]));
        let work_counter = Arc::new(AtomicUsize::new(0));
        std::thread::scope(|s| {
            let handles: Vec<_> = (0..self.base_scheduler.num_workers)
                .map(|_worker_id| {
                    let items_ref = items;
                    let f_ref = &f;
                    let results = results.clone();
                    let work_counter = work_counter.clone();
                    s.spawn(move || loop {
                        // Claim the next chunk; stop once the input is exhausted.
                        let chunk_id = work_counter.fetch_add(1, Ordering::SeqCst);
                        let start = chunk_id * chunksize;
                        if start >= n {
                            break;
                        }
                        let end = std::cmp::min(start + chunksize, n);
                        // Compute the chunk locally, then publish it under a
                        // single lock acquisition.
                        let chunk: Vec<R> = items_ref[start..end].iter().map(f_ref).collect();
                        let mut results_guard = results.lock().expect("results mutex poisoned");
                        results_guard[start..end].clone_from_slice(&chunk);
                    })
                })
                .collect();
            for handle in handles {
                handle.join().expect("worker thread panicked");
            }
        });
        Arc::try_unwrap(results)
            .unwrap_or_else(|_| panic!("results still shared after workers joined"))
            .into_inner()
            .unwrap_or_else(|_| panic!("results mutex poisoned"))
    }

    /// Dynamic chunking currently delegates to the base scheduler, whose
    /// work stealing already rebalances chunks at runtime.
    fn execute_dynamic_chunks<T, R, F>(
        &self,
        items: &[T],
        f: F,
        _initial_chunksize: usize,
    ) -> Vec<R>
    where
        T: Send + Sync,
        R: Send + Default + Clone,
        F: Fn(&T) -> R + Send + Sync,
    {
        self.base_scheduler.execute(items, f)
    }

    /// Returns the detected NUMA topology for this machine.
    pub fn get_numa_topology(&self) -> NumaTopology {
        NumaTopology::detect()
    }

    /// Executes `f` over `items` with NUMA-aware distribution, falling back
    /// to the base scheduler when NUMA awareness is disabled or only one
    /// node is detected.
    pub fn execute_numa_aware<T, R, F>(&self, items: &[T], f: F) -> Vec<R>
    where
        T: Send + Sync,
        R: Send + Default + Clone,
        F: Fn(&T) -> R + Send + Sync,
    {
        if !self.numa_aware {
            return self.base_scheduler.execute(items, f);
        }
        let topology = self.get_numa_topology();
        if topology.num_nodes <= 1 {
            return self.base_scheduler.execute(items, f);
        }
        self.execute_with_numa_distribution(items, f, &topology)
    }

    /// Distributes contiguous chunks across workers grouped by NUMA node.
    /// Node assignment is computed per worker; actual thread pinning is not
    /// implemented here.
    fn execute_with_numa_distribution<T, R, F>(
        &self,
        items: &[T],
        f: F,
        topology: &NumaTopology,
    ) -> Vec<R>
    where
        T: Send + Sync,
        R: Send + Default + Clone,
        F: Fn(&T) -> R + Send + Sync,
    {
        let n = items.len();
        if n == 0 {
            return Vec::new();
        }
        let results = Arc::new(Mutex::new(vec![R::default(); n]));
        let work_counter = Arc::new(AtomicUsize::new(0));
        // Guard against more nodes than workers, which would otherwise cause
        // a division by zero below.
        let workers_per_node = (self.base_scheduler.num_workers / topology.num_nodes).max(1);
        let chunk_size = std::cmp::max(1, n / (topology.num_nodes * workers_per_node));
        std::thread::scope(|s| {
            let handles: Vec<_> = (0..self.base_scheduler.num_workers)
                .map(|worker_id| {
                    let work_counter = work_counter.clone();
                    let results = results.clone();
                    let items_ref = items;
                    let f_ref = &f;
                    // Node this worker would be pinned to, were pinning
                    // implemented.
                    let _numa_node = worker_id / workers_per_node;
                    s.spawn(move || loop {
                        let start = work_counter.fetch_add(chunk_size, Ordering::SeqCst);
                        if start >= n {
                            break;
                        }
                        let end = std::cmp::min(start + chunk_size, n);
                        // Compute the chunk locally, then publish it under a
                        // single lock acquisition.
                        let chunk: Vec<R> = items_ref[start..end].iter().map(f_ref).collect();
                        let mut results_guard = results.lock().expect("results mutex poisoned");
                        results_guard[start..end].clone_from_slice(&chunk);
                    })
                })
                .collect();
            for handle in handles {
                handle.join().expect("worker thread panicked");
            }
        });
        Arc::try_unwrap(results)
            .unwrap_or_else(|_| panic!("results still shared after workers joined"))
            .into_inner()
            .unwrap_or_else(|_| panic!("results mutex poisoned"))
    }
}

/// Coarse workload classification used to pick a chunking strategy.
#[derive(Debug, Clone, Copy)]
enum WorkloadType {
    MemoryBound,
    CpuBound,
    Mixed,
}

/// A chunk size paired with the strategy used to hand chunks to workers.
#[derive(Debug, Clone)]
struct ChunkConfig {
    size: usize,
    strategy: ChunkStrategy,
}

/// How chunks are assigned to workers.
#[derive(Debug, Clone, Copy)]
enum ChunkStrategy {
    Sequential,
    Interleaved,
    Dynamic,
}

/// Description of the machine's NUMA layout.
#[derive(Debug, Clone)]
pub struct NumaTopology {
    /// Number of NUMA nodes.
    pub num_nodes: usize,
    /// CPU core count for each node.
    pub cores_per_node: Vec<usize>,
    /// Approximate memory available on each node.
    pub memory_per_node: Vec<usize>,
}

impl NumaTopology {
    /// Detects the topology. This is a portable fallback that reports a
    /// single node containing all available CPUs and a placeholder memory
    /// size; real NUMA probing would require platform-specific APIs.
    pub fn detect() -> Self {
        let num_cpus = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1);
        Self {
            num_nodes: 1,
            cores_per_node: vec![num_cpus],
            memory_per_node: vec![8192],
        }
    }

    /// Returns `true` when more than one NUMA node was detected.
    pub fn is_numa_system(&self) -> bool {
        self.num_nodes > 1
    }
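
    /// Splits `total_workers` evenly across nodes; the first
    /// `total_workers % num_nodes` nodes each receive one extra worker.
    /// For example, 9 workers on 2 nodes yields `[5, 4]`.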
    pub fn optimal_worker_distribution(&self, total_workers: usize) -> Vec<usize> {
        if self.num_nodes <= 1 {
            return vec![total_workers];
        }
        let workers_per_node = total_workers / self.num_nodes;
        let remainder = total_workers % self.num_nodes;
        let mut distribution = vec![workers_per_node; self.num_nodes];
        for node in distribution.iter_mut().take(remainder) {
            *node += 1;
        }
        distribution
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_advanced_scheduler() {
        let config = WorkerConfig::default();
        let scheduler = AdvancedWorkStealingScheduler::new(&config);
        let items = vec![1, 2, 3, 4, 5];
        let results = scheduler.execute_optimized(&items, |x| x * 2);
        assert_eq!(results, vec![2, 4, 6, 8, 10]);
    }

    #[test]
    fn test_numa_topology() {
        let topology = NumaTopology::detect();
        assert!(topology.num_nodes >= 1);
        assert!(!topology.cores_per_node.is_empty());
    }

    #[test]
    fn test_worker_distribution() {
        let topology = NumaTopology {
            num_nodes: 2,
            cores_per_node: vec![4, 4],
            memory_per_node: vec![8192, 8192],
        };
        let distribution = topology.optimal_worker_distribution(8);
        assert_eq!(distribution, vec![4, 4]);
        let distribution = topology.optimal_worker_distribution(9);
        assert_eq!(distribution, vec![5, 4]);
    }

    #[test]
    fn test_cache_linesize() {
        let config = WorkerConfig::default();
        let scheduler = AdvancedWorkStealingScheduler::new(&config).with_cache_linesize(128);
        assert_eq!(scheduler.cache_linesize, 128);
    }
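
    // Sketch of an additional check: the portable detect() fallback reports
    // a single node, so execute_numa_aware should take the base-scheduler
    // path and match a plain map.
    #[test]
    fn test_numa_aware_single_node() {
        let config = WorkerConfig::default();
        let scheduler = AdvancedWorkStealingScheduler::new(&config);
        let items: Vec<u64> = (0..100).collect();
        let results = scheduler.execute_numa_aware(&items, |x| x + 1);
        let expected: Vec<u64> = items.iter().map(|x| x + 1).collect();
        assert_eq!(results, expected);
    }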
}