use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use thiserror::Error;
/// Errors produced by the parallel scheduler API.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum ParallelError {
/// A worker queue has reached `ParallelConfig::max_queue_size`.
#[error("Task queue is full")]
QueueFull,
/// A submitted task's dependency list forms a cycle.
#[error("Task dependency cycle detected")]
DependencyCycle,
/// A referenced task id is unknown to the scheduler.
#[error("Task {0} not found")]
TaskNotFound(String),
/// The worker count was rejected (zero is invalid; see `ParallelConfig::new`).
#[error("Invalid worker count: {0}")]
InvalidWorkerCount(usize),
/// Execution-time failure, carrying a human-readable message.
#[error("Parallel execution failed: {0}")]
ExecutionFailed(String),
/// NUMA-related allocation problem, carrying a human-readable message.
#[error("NUMA allocation failed: {0}")]
NumaAllocationFailed(String),
}
/// Victim-selection policy used when a worker attempts to steal work.
///
/// NOTE(review): in this file `select_victim` implements `Random`, `LRU`
/// and `RoundRobin` identically as a deterministic next-neighbour pick;
/// only `MaxLoad` has distinct behaviour. No RNG is wired up for `Random`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum StealStrategy {
Random,
MaxLoad,
LRU,
RoundRobin,
}
/// Identifier of a NUMA node (plain index newtype; mapped to a worker
/// via modulo in `WorkStealingScheduler::numa_node_to_worker`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NumaNode(pub usize);
/// NUMA placement policy.
///
/// NOTE(review): only stored in `ParallelConfig`; no code in this file
/// branches on it (scheduling looks solely at `Task::numa_node`). Variant
/// meanings below are inferred from the names — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NumaStrategy {
/// NUMA-unaware placement.
None,
/// Presumably: prefer the local node, best effort.
LocalPreferred,
/// Presumably: require the local node.
LocalStrict,
/// Presumably: spread allocations across nodes.
Interleave,
}
/// Tunable parameters for the work-stealing scheduler.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParallelConfig {
/// Number of worker queues. Must be non-zero — enforced by `new`, but
/// NOT by `with_num_workers` or direct struct construction.
pub num_workers: usize,
/// Victim-selection policy for `try_steal`.
pub steal_strategy: StealStrategy,
/// NUMA placement policy.
/// NOTE(review): stored but not consulted by any code in this file.
pub numa_strategy: NumaStrategy,
/// Enables priority handling.
/// NOTE(review): not consulted by the scheduler in this file.
pub enable_priority: bool,
/// When set, `execute_all`/`try_steal` update the shared `SchedulerStats`.
pub enable_stats: bool,
/// Per-worker queue capacity; `submit` fails with `QueueFull` beyond it.
pub max_queue_size: usize,
/// Requests cache-line padding of worker queues.
/// NOTE(review): `WorkerQueue` is unconditionally `#[repr(align(64))]`,
/// so this flag is currently never read.
pub cache_line_padding: bool,
}
impl Default for ParallelConfig {
fn default() -> Self {
Self {
num_workers: num_cpus::get(),
steal_strategy: StealStrategy::Random,
numa_strategy: NumaStrategy::None,
enable_priority: false,
enable_stats: true,
max_queue_size: 10000,
cache_line_padding: true,
}
}
}
impl ParallelConfig {
    /// Builds a configuration with `num_workers` workers and defaults for
    /// every other field.
    ///
    /// # Errors
    /// Returns `ParallelError::InvalidWorkerCount` when `num_workers` is 0.
    pub fn new(num_workers: usize) -> Result<Self, ParallelError> {
        if num_workers == 0 {
            return Err(ParallelError::InvalidWorkerCount(num_workers));
        }
        Ok(Self {
            num_workers,
            ..Self::default()
        })
    }
    /// Overrides the worker count.
    ///
    /// NOTE(review): unlike `new`, this setter performs no zero check.
    pub fn with_num_workers(self, num_workers: usize) -> Self {
        Self { num_workers, ..self }
    }
    /// Overrides the steal strategy.
    pub fn with_steal_strategy(self, strategy: StealStrategy) -> Self {
        Self { steal_strategy: strategy, ..self }
    }
    /// Overrides the NUMA strategy.
    pub fn with_numa_strategy(self, strategy: NumaStrategy) -> Self {
        Self { numa_strategy: strategy, ..self }
    }
    /// Toggles priority handling.
    pub fn with_priority(self, enabled: bool) -> Self {
        Self { enable_priority: enabled, ..self }
    }
    /// Toggles statistics collection.
    pub fn with_stats(self, enabled: bool) -> Self {
        Self { enable_stats: enabled, ..self }
    }
}
/// Tasks are addressed by an owned string id.
pub type TaskId = String;
/// Priority levels. The derived `Ord` follows the explicit discriminants,
/// so `Low < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum TaskPriority {
Low = 0,
Normal = 1,
High = 2,
Critical = 3,
}
/// A schedulable unit of work (metadata only — it carries no payload).
#[derive(Debug, Clone)]
pub struct Task {
/// Unique identifier; also used as the dependency handle.
pub id: TaskId,
/// Scheduling priority.
/// NOTE(review): not consulted by the scheduler in this file.
pub priority: TaskPriority,
/// Ids of tasks that must complete before this one may run.
pub dependencies: Vec<TaskId>,
/// Optional NUMA pinning; when set, `select_worker` maps it to a queue.
pub numa_node: Option<NumaNode>,
/// Estimated runtime in µs; `execute_all` charges 1000 µs when `None`.
pub estimated_time_us: Option<u64>,
}
impl Task {
pub fn new(id: TaskId) -> Self {
Self {
id,
priority: TaskPriority::Normal,
dependencies: Vec::new(),
numa_node: None,
estimated_time_us: None,
}
}
pub fn with_priority(mut self, priority: TaskPriority) -> Self {
self.priority = priority;
self
}
pub fn with_dependency(mut self, dep: TaskId) -> Self {
self.dependencies.push(dep);
self
}
pub fn with_numa_node(mut self, node: NumaNode) -> Self {
self.numa_node = Some(node);
self
}
pub fn with_estimated_time(mut self, time_us: u64) -> Self {
self.estimated_time_us = Some(time_us);
self
}
}
/// Per-worker task deque plus local counters.
///
/// Aligned to 64 bytes — presumably to keep each queue on its own cache
/// line and avoid false sharing between workers (cf. the
/// `ParallelConfig::cache_line_padding` flag, which this type does not
/// actually consult) — TODO confirm intent.
#[repr(align(64))] struct WorkerQueue {
// Owner pops from the front; thieves pop from the back.
queue: VecDeque<Task>,
// Steal attempts made against this queue (counted even when empty).
steal_count: usize,
// Tasks this worker has executed.
tasks_executed: usize,
// Sum of the (estimated) execution times of those tasks, in µs.
total_execution_time_us: u64,
}
impl WorkerQueue {
fn new() -> Self {
Self {
queue: VecDeque::new(),
steal_count: 0,
tasks_executed: 0,
total_execution_time_us: 0,
}
}
fn push(&mut self, task: Task) {
self.queue.push_back(task);
}
fn pop(&mut self) -> Option<Task> {
self.queue.pop_front()
}
fn steal(&mut self) -> Option<Task> {
self.steal_count += 1;
self.queue.pop_back()
}
fn len(&self) -> usize {
self.queue.len()
}
}
/// Work-stealing scheduler: one lock-protected queue per worker plus
/// shared completion records and statistics.
pub struct WorkStealingScheduler {
    /// Configuration the scheduler was built with.
    config: ParallelConfig,
    /// One queue per worker, each behind its own mutex.
    workers: Vec<Arc<Mutex<WorkerQueue>>>,
    /// Ids of finished tasks mapped to their recorded execution time (µs).
    completed_tasks: Arc<Mutex<HashMap<TaskId, u64>>>,
    /// Aggregate counters, updated when `config.enable_stats` is set.
    stats: Arc<Mutex<SchedulerStats>>,
}
impl WorkStealingScheduler {
pub fn new(config: ParallelConfig) -> Self {
let mut workers = Vec::with_capacity(config.num_workers);
for _ in 0..config.num_workers {
workers.push(Arc::new(Mutex::new(WorkerQueue::new())));
}
Self {
config,
workers,
completed_tasks: Arc::new(Mutex::new(HashMap::new())),
stats: Arc::new(Mutex::new(SchedulerStats::default())),
}
}
pub fn submit(&self, task: Task) -> Result<(), ParallelError> {
self.validate_dependencies(&task)?;
let worker_idx = self.select_worker(&task);
let mut worker = self.workers[worker_idx]
.lock()
.expect("lock should not be poisoned");
if worker.len() >= self.config.max_queue_size {
return Err(ParallelError::QueueFull);
}
worker.push(task);
Ok(())
}
pub fn submit_batch(&self, tasks: Vec<Task>) -> Result<(), ParallelError> {
for task in tasks {
self.submit(task)?;
}
Ok(())
}
pub fn execute_all(&self) -> Result<Vec<TaskId>, ParallelError> {
let mut completed = Vec::new();
for worker in &self.workers {
let mut worker = worker.lock().expect("lock should not be poisoned");
while let Some(task) = worker.pop() {
if self.dependencies_satisfied(&task)? {
let execution_time = task.estimated_time_us.unwrap_or(1000);
worker.tasks_executed += 1;
worker.total_execution_time_us += execution_time;
self.completed_tasks
.lock()
.expect("lock should not be poisoned")
.insert(task.id.clone(), execution_time);
completed.push(task.id);
if self.config.enable_stats {
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.tasks_executed += 1;
stats.total_execution_time_us += execution_time;
}
} else {
worker.push(task);
}
}
}
Ok(completed)
}
pub fn stats(&self) -> SchedulerStats {
self.stats
.lock()
.expect("lock should not be poisoned")
.clone()
}
pub fn reset(&self) {
for worker in &self.workers {
let mut worker = worker.lock().expect("lock should not be poisoned");
worker.queue.clear();
worker.steal_count = 0;
worker.tasks_executed = 0;
worker.total_execution_time_us = 0;
}
self.completed_tasks
.lock()
.expect("lock should not be poisoned")
.clear();
*self.stats.lock().expect("lock should not be poisoned") = SchedulerStats::default();
}
fn validate_dependencies(&self, task: &Task) -> Result<(), ParallelError> {
let mut visited = std::collections::HashSet::new();
self.check_cycle(&task.id, &task.dependencies, &mut visited)
}
fn check_cycle(
&self,
current: &TaskId,
dependencies: &[TaskId],
visited: &mut std::collections::HashSet<TaskId>,
) -> Result<(), ParallelError> {
if visited.contains(current) {
return Err(ParallelError::DependencyCycle);
}
visited.insert(current.clone());
for _dep in dependencies {
}
Ok(())
}
fn dependencies_satisfied(&self, task: &Task) -> Result<bool, ParallelError> {
let completed = self
.completed_tasks
.lock()
.expect("lock should not be poisoned");
Ok(task
.dependencies
.iter()
.all(|dep| completed.contains_key(dep)))
}
fn select_worker(&self, task: &Task) -> usize {
if let Some(numa_node) = task.numa_node {
return self.numa_node_to_worker(numa_node);
}
let mut min_load = usize::MAX;
let mut selected = 0;
for (idx, worker) in self.workers.iter().enumerate() {
let worker = worker.lock().expect("lock should not be poisoned");
let load = worker.len();
if load < min_load {
min_load = load;
selected = idx;
}
}
selected
}
fn numa_node_to_worker(&self, node: NumaNode) -> usize {
node.0 % self.config.num_workers
}
pub fn try_steal(&self, thief_idx: usize) -> Option<Task> {
let victim_idx = self.select_victim(thief_idx);
if victim_idx == thief_idx {
return None;
}
let mut victim = self.workers[victim_idx]
.lock()
.expect("lock should not be poisoned");
let stolen = victim.steal();
if stolen.is_some() && self.config.enable_stats {
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.steal_count += 1;
}
stolen
}
fn select_victim(&self, thief_idx: usize) -> usize {
match self.config.steal_strategy {
StealStrategy::Random => {
(thief_idx + 1) % self.config.num_workers
}
StealStrategy::MaxLoad => {
let mut max_load = 0;
let mut victim = thief_idx;
for (idx, worker) in self.workers.iter().enumerate() {
if idx == thief_idx {
continue;
}
let worker = worker.lock().expect("lock should not be poisoned");
let load = worker.len();
if load > max_load {
max_load = load;
victim = idx;
}
}
victim
}
StealStrategy::LRU | StealStrategy::RoundRobin => {
(thief_idx + 1) % self.config.num_workers
}
}
}
pub fn load_balance_stats(&self) -> LoadBalanceStats {
let mut worker_loads = Vec::new();
let mut total_tasks = 0;
for worker in &self.workers {
let worker = worker.lock().expect("lock should not be poisoned");
let load = worker.tasks_executed;
worker_loads.push(load);
total_tasks += load;
}
let avg_load = total_tasks as f64 / self.config.num_workers as f64;
let variance = worker_loads
.iter()
.map(|&load| (load as f64 - avg_load).powi(2))
.sum::<f64>()
/ self.config.num_workers as f64;
let std_dev = variance.sqrt();
let cv = if avg_load > 0.0 {
std_dev / avg_load
} else {
0.0
};
let max_load = *worker_loads.iter().max().unwrap_or(&0);
LoadBalanceStats {
worker_loads,
avg_load,
std_dev,
coefficient_of_variation: cv,
imbalance_ratio: if avg_load > 0.0 {
max_load as f64 / avg_load
} else {
1.0
},
}
}
}
/// Aggregate counters maintained by the scheduler.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SchedulerStats {
    /// Total number of tasks executed.
    pub tasks_executed: usize,
    /// Sum of per-task execution times, in microseconds.
    pub total_execution_time_us: u64,
    /// Successful steal operations.
    pub steal_count: usize,
    /// Steal attempts that found nothing to take.
    pub failed_steals: usize,
}
impl SchedulerStats {
    /// Mean execution time per task in µs; 0.0 before any task has run.
    pub fn avg_execution_time_us(&self) -> f64 {
        match self.tasks_executed {
            0 => 0.0,
            n => self.total_execution_time_us as f64 / n as f64,
        }
    }
    /// Fraction of steal attempts that succeeded; 0.0 when none occurred.
    pub fn steal_success_rate(&self) -> f64 {
        match self.steal_count + self.failed_steals {
            0 => 0.0,
            attempts => self.steal_count as f64 / attempts as f64,
        }
    }
}
impl std::fmt::Display for SchedulerStats {
    /// Multi-line human-readable report.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Scheduler Statistics")?;
        writeln!(f, "====================")?;
        writeln!(f, "Tasks executed: {}", self.tasks_executed)?;
        writeln!(f, "Total time: {:.2} ms", self.total_execution_time_us as f64 / 1000.0)?;
        writeln!(f, "Avg time/task: {:.2} µs", self.avg_execution_time_us())?;
        writeln!(f, "Steal count: {}", self.steal_count)?;
        writeln!(f, "Failed steals: {}", self.failed_steals)?;
        writeln!(f, "Steal success rate: {:.2}%", self.steal_success_rate() * 100.0)
    }
}
/// Distribution of executed tasks across workers, with dispersion metrics.
#[derive(Debug, Clone, PartialEq)]
pub struct LoadBalanceStats {
    /// Executed-task count per worker, indexed by worker id.
    pub worker_loads: Vec<usize>,
    /// Mean load per worker.
    pub avg_load: f64,
    /// Standard deviation of the loads.
    pub std_dev: f64,
    /// `std_dev / avg_load` (0.0 when the average is zero).
    pub coefficient_of_variation: f64,
    /// Maximum load divided by the average (1.0 when the average is zero).
    pub imbalance_ratio: f64,
}
impl LoadBalanceStats {
    /// Balanced means the coefficient of variation stays below 0.2.
    pub fn is_well_balanced(&self) -> bool {
        self.coefficient_of_variation < 0.2
    }
}
impl std::fmt::Display for LoadBalanceStats {
    /// Multi-line human-readable report.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Load Balance Statistics")?;
        writeln!(f, "=======================")?;
        writeln!(f, "Worker loads: {:?}", self.worker_loads)?;
        writeln!(f, "Average load: {:.2}", self.avg_load)?;
        writeln!(f, "Std deviation: {:.2}", self.std_dev)?;
        writeln!(f, "CV: {:.4}", self.coefficient_of_variation)?;
        writeln!(f, "Imbalance: {:.2}x", self.imbalance_ratio)?;
        let verdict = if self.is_well_balanced() { "Yes" } else { "No" };
        writeln!(f, "Well balanced: {}", verdict)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE: `.expect("unwrap")` messages replaced with ones that state
    // the invariant, per the convention that `expect` explains why the
    // operation cannot fail. Assertions are unchanged.
    #[test]
    fn test_parallel_config_default() {
        let config = ParallelConfig::default();
        assert!(config.num_workers > 0);
        assert_eq!(config.steal_strategy, StealStrategy::Random);
        assert!(config.enable_stats);
    }
    #[test]
    fn test_parallel_config_builder() {
        let config = ParallelConfig::new(4)
            .expect("4 is a valid worker count")
            .with_steal_strategy(StealStrategy::MaxLoad)
            .with_numa_strategy(NumaStrategy::LocalPreferred)
            .with_priority(true);
        assert_eq!(config.num_workers, 4);
        assert_eq!(config.steal_strategy, StealStrategy::MaxLoad);
        assert_eq!(config.numa_strategy, NumaStrategy::LocalPreferred);
        assert!(config.enable_priority);
    }
    #[test]
    fn test_task_creation() {
        let task = Task::new("task1".to_string())
            .with_priority(TaskPriority::High)
            .with_dependency("task0".to_string())
            .with_estimated_time(1000);
        assert_eq!(task.id, "task1");
        assert_eq!(task.priority, TaskPriority::High);
        assert_eq!(task.dependencies.len(), 1);
        assert_eq!(task.estimated_time_us, Some(1000));
    }
    #[test]
    fn test_scheduler_creation() {
        let config = ParallelConfig::new(4).expect("4 is a valid worker count");
        let scheduler = WorkStealingScheduler::new(config);
        assert_eq!(scheduler.workers.len(), 4);
    }
    #[test]
    fn test_scheduler_submit() {
        let config = ParallelConfig::new(2).expect("2 is a valid worker count");
        let scheduler = WorkStealingScheduler::new(config);
        let task = Task::new("task1".to_string());
        assert!(scheduler.submit(task).is_ok());
    }
    #[test]
    fn test_scheduler_execute_simple() {
        let config = ParallelConfig::new(2).expect("2 is a valid worker count");
        let scheduler = WorkStealingScheduler::new(config);
        let task1 = Task::new("task1".to_string()).with_estimated_time(100);
        let task2 = Task::new("task2".to_string()).with_estimated_time(200);
        scheduler.submit(task1).expect("fresh queue has capacity");
        scheduler.submit(task2).expect("fresh queue has capacity");
        let completed = scheduler.execute_all().expect("execution should succeed");
        assert_eq!(completed.len(), 2);
    }
    #[test]
    fn test_scheduler_dependencies() {
        let config = ParallelConfig::new(2).expect("2 is a valid worker count");
        let scheduler = WorkStealingScheduler::new(config);
        let task1 = Task::new("task1".to_string());
        // task2 must not run before task1 has completed.
        let task2 = Task::new("task2".to_string()).with_dependency("task1".to_string());
        scheduler.submit(task1).expect("fresh queue has capacity");
        scheduler.submit(task2).expect("fresh queue has capacity");
        let completed = scheduler.execute_all().expect("execution should succeed");
        assert!(completed.contains(&"task1".to_string()));
    }
    #[test]
    fn test_scheduler_stats() {
        let config = ParallelConfig::new(2).expect("2 is a valid worker count");
        let scheduler = WorkStealingScheduler::new(config);
        let task1 = Task::new("task1".to_string()).with_estimated_time(1000);
        let task2 = Task::new("task2".to_string()).with_estimated_time(2000);
        scheduler.submit(task1).expect("fresh queue has capacity");
        scheduler.submit(task2).expect("fresh queue has capacity");
        scheduler.execute_all().expect("execution should succeed");
        let stats = scheduler.stats();
        assert_eq!(stats.tasks_executed, 2);
        assert_eq!(stats.total_execution_time_us, 3000);
    }
    #[test]
    fn test_load_balance_stats() {
        let config = ParallelConfig::new(4).expect("4 is a valid worker count");
        let scheduler = WorkStealingScheduler::new(config);
        for i in 0..8 {
            let task = Task::new(format!("task{}", i)).with_estimated_time(100);
            scheduler.submit(task).expect("fresh queue has capacity");
        }
        scheduler.execute_all().expect("execution should succeed");
        // 8 tasks spread over 4 workers by least-loaded placement: 2 each.
        let stats = scheduler.load_balance_stats();
        assert!((stats.avg_load - 2.0).abs() < 0.1);
    }
    #[test]
    fn test_scheduler_reset() {
        let config = ParallelConfig::new(2).expect("2 is a valid worker count");
        let scheduler = WorkStealingScheduler::new(config);
        let task = Task::new("task1".to_string());
        scheduler.submit(task).expect("fresh queue has capacity");
        scheduler.execute_all().expect("execution should succeed");
        let stats_before = scheduler.stats();
        assert_eq!(stats_before.tasks_executed, 1);
        scheduler.reset();
        let stats_after = scheduler.stats();
        assert_eq!(stats_after.tasks_executed, 0);
    }
    #[test]
    fn test_task_priority() {
        assert!(TaskPriority::Critical > TaskPriority::High);
        assert!(TaskPriority::High > TaskPriority::Normal);
        assert!(TaskPriority::Normal > TaskPriority::Low);
    }
    #[test]
    fn test_numa_node() {
        let node = NumaNode(0);
        assert_eq!(node.0, 0);
    }
    #[test]
    fn test_steal_strategy() {
        // StealStrategy is Copy, so both bindings observe the same value.
        let s1 = StealStrategy::Random;
        let s2 = s1;
        let s3 = s1;
        assert_eq!(s2, s3);
    }
}