use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
/// Priority levels for scheduled tasks; higher values are dequeued first.
///
/// NOTE(review): the original `#[cfg(feature = "parallel")]` gate covered only
/// this enum while the rest of the file (`submit`, `PrioritizedTask`, the
/// scheduler API) uses `TaskPriority` unconditionally, so a build without the
/// feature could not compile. The gate is removed here; if the whole module is
/// meant to be feature-gated, gate the `mod` declaration instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum TaskPriority {
    /// Background work, served only when nothing else is queued.
    Low = 0,
    /// Default priority for ordinary tasks.
    #[default]
    Normal = 1,
    /// Latency-sensitive work, served before `Normal` and `Low`.
    High = 2,
    /// Served before everything else.
    Critical = 3,
}
/// Description of a NUMA node: its CPU cores plus a memory budget with a
/// concurrently-updated usage counter.
#[derive(Debug)]
pub struct NumaNode {
    /// NUMA node identifier.
    pub id: usize,
    /// Logical CPU cores belonging to this node.
    pub cpu_cores: Vec<usize>,
    /// Total memory budget of the node in bytes.
    pub memory_size: usize,
    /// Bytes currently accounted as used; atomic because workers update it
    /// concurrently.
    pub memory_used: AtomicUsize,
}

impl Clone for NumaNode {
    /// Manual `Clone` because `AtomicUsize` is not `Clone`; the clone receives
    /// a snapshot of the current usage counter.
    fn clone(&self) -> Self {
        Self {
            id: self.id,
            cpu_cores: self.cpu_cores.clone(),
            memory_size: self.memory_size,
            memory_used: AtomicUsize::new(self.memory_used.load(Ordering::Relaxed)),
        }
    }
}

impl NumaNode {
    /// Creates a node with zero recorded memory usage.
    pub fn new(id: usize, cpu_cores: Vec<usize>, memory_size: usize) -> Self {
        Self {
            id,
            cpu_cores,
            memory_size,
            memory_used: AtomicUsize::new(0),
        }
    }

    /// Fraction of the memory budget currently in use. Returns 0.0 for a
    /// zero-sized budget to avoid division by zero.
    pub fn memory_utilization(&self) -> f64 {
        if self.memory_size == 0 {
            0.0
        } else {
            self.memory_used.load(Ordering::Relaxed) as f64 / self.memory_size as f64
        }
    }

    /// Returns `true` if `required` additional bytes fit within the budget.
    ///
    /// Uses `checked_add` so a huge `required` value (or an over-recorded
    /// usage counter) cannot overflow: the original `used + required` panicked
    /// in debug builds and wrapped around — answering `true` — in release.
    pub fn has_available_memory(&self, required: usize) -> bool {
        self.memory_used
            .load(Ordering::Relaxed)
            .checked_add(required)
            .map_or(false, |total| total <= self.memory_size)
    }
}
/// Tunable parameters for the work-stealing scheduler.
#[derive(Debug, Clone)]
pub struct WorkStealingConfig {
    /// Worker thread count; `None` means "use available parallelism".
    pub num_workers: Option<usize>,
    /// Capacity of the global queue (each worker's local queue gets a quarter).
    pub max_queue_size: usize,
    /// Minimum interval between steal attempts, in milliseconds.
    pub steal_timeout_ms: u64,
    /// How many sibling queues to probe per steal round.
    pub max_steal_attempts: usize,
    /// Whether workers are assigned NUMA nodes.
    pub numa_aware: bool,
    /// Whether dequeueing honours task priority.
    pub priority_scheduling: bool,
    /// Optional CPU list for pinning worker threads.
    pub thread_affinity: Option<Vec<usize>>,
    /// Optional per-worker memory budget in bytes.
    pub max_memory_per_worker: Option<usize>,
    /// Whether a background monitoring thread is started.
    pub enable_monitoring: bool,
    /// Polling interval of the monitoring thread.
    pub stats_interval: Duration,
    /// Whether adaptive load balancing is enabled.
    pub adaptive_balancing: bool,
    /// Utilization threshold triggering load balancing.
    pub load_balance_threshold: f64,
}

impl Default for WorkStealingConfig {
    /// Defaults: auto-detected worker count, 10k-task queue, 1 ms steal
    /// interval, priority scheduling, monitoring, and adaptive balancing on.
    fn default() -> Self {
        WorkStealingConfig {
            adaptive_balancing: true,
            load_balance_threshold: 0.8,
            stats_interval: Duration::from_secs(1),
            enable_monitoring: true,
            max_memory_per_worker: None,
            thread_affinity: None,
            priority_scheduling: true,
            numa_aware: false,
            max_steal_attempts: 3,
            steal_timeout_ms: 1,
            max_queue_size: 10_000,
            num_workers: None,
        }
    }
}
/// A unit of work executable on the scheduler's worker threads.
pub trait WorkStealingTask: Send + 'static {
    /// Value produced by the task.
    type Output: Send + 'static;
    /// Runs the task to completion, consuming it.
    fn execute(self) -> Self::Output;
    /// Optional wall-clock estimate used by scheduling heuristics.
    fn estimated_duration(&self) -> Option<Duration> {
        None
    }
    /// Optional peak-memory estimate in bytes.
    fn estimated_memory(&self) -> Option<usize> {
        None
    }
    /// Whether the task can be decomposed into smaller subtasks.
    fn can_split(&self) -> bool {
        false
    }
    /// Splits the task into subtasks; the default returns the task unchanged.
    fn split(self) -> Vec<Box<dyn WorkStealingTask<Output = Self::Output>>>
    where
        Self: Sized,
    {
        vec![Box::new(self)]
    }
}

/// Adapter turning a plain closure into a [`WorkStealingTask`].
struct FunctionTask<F, R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    // The closure itself. `execute` consumes `self`, so the former
    // `Option<F>` wrapper (and its "Function already executed" panic path)
    // was unnecessary and has been removed.
    func: F,
    estimated_duration: Option<Duration>,
    estimated_memory: Option<usize>,
    // Ties the otherwise-unused `R` parameter to the struct.
    _result: std::marker::PhantomData<R>,
}

impl<F, R> FunctionTask<F, R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    /// Wraps `func` with no resource estimates.
    fn new(func: F) -> Self {
        Self {
            func,
            estimated_duration: None,
            estimated_memory: None,
            _result: std::marker::PhantomData,
        }
    }

    /// Wraps `func` together with optional duration and memory estimates.
    fn with_estimates(func: F, duration: Option<Duration>, memory: Option<usize>) -> Self {
        Self {
            func,
            estimated_duration: duration,
            estimated_memory: memory,
            _result: std::marker::PhantomData,
        }
    }
}

impl<F, R> WorkStealingTask for FunctionTask<F, R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    type Output = R;

    fn execute(self) -> Self::Output {
        // Consuming `self` moves the closure out; it can only run once by
        // construction, so no runtime check is needed.
        (self.func)()
    }

    fn estimated_duration(&self) -> Option<Duration> {
        self.estimated_duration
    }

    fn estimated_memory(&self) -> Option<usize> {
        self.estimated_memory
    }
}
/// A type-erased task queued with a priority, a submission timestamp, and an
/// optional NUMA placement hint.
struct PrioritizedTask {
    // Type-erased closure producing a boxed `Any` result. `execute` consumes
    // `self`, so the former `Option` wrapper (and its "Task already executed"
    // panic path) was unnecessary and has been removed.
    task: Box<dyn FnOnce() -> Box<dyn std::any::Any + Send> + Send>,
    priority: TaskPriority,
    submitted_at: Instant,
    numa_hint: Option<usize>,
}

impl std::fmt::Debug for PrioritizedTask {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PrioritizedTask")
            .field("priority", &self.priority)
            .field("submitted_at", &self.submitted_at)
            .field("numa_hint", &self.numa_hint)
            // A pending task always holds its closure (`execute` consumes
            // `self`), so this matches the original `is_some()` output.
            .field("task", &true)
            .finish()
    }
}

impl PrioritizedTask {
    /// Wraps `task` for queueing, recording the submission time.
    fn new<T: WorkStealingTask>(task: T, priority: TaskPriority, numa_hint: Option<usize>) -> Self
    where
        T::Output: 'static,
    {
        Self {
            // Box the output so heterogeneous results can share one channel;
            // `try_recv`/`recv_timeout` downcast on the consumer side.
            task: Box::new(move || Box::new(task.execute())),
            priority,
            submitted_at: Instant::now(),
            numa_hint,
        }
    }

    /// Runs the task and returns its boxed result.
    fn execute(self) -> Box<dyn std::any::Any + Send> {
        (self.task)()
    }
}
/// Four fixed-priority FIFO lanes with a shared capacity limit.
#[derive(Debug)]
struct PriorityTaskQueue {
    /// One FIFO lane per priority level; index = `TaskPriority as usize`.
    queues: [VecDeque<PrioritizedTask>; 4],
    /// Number of tasks across all lanes.
    total_size: usize,
    /// Upper bound on `total_size`.
    maxsize: usize,
}

impl PriorityTaskQueue {
    /// Creates an empty queue bounded to `max_size` tasks.
    fn new(max_size: usize) -> Self {
        Self {
            queues: Default::default(),
            total_size: 0,
            maxsize: max_size,
        }
    }

    /// Enqueues `task` into its priority lane, handing it back in `Err` when
    /// the queue is at capacity.
    fn push(&mut self, task: PrioritizedTask) -> Result<(), PrioritizedTask> {
        if self.total_size >= self.maxsize {
            Err(task)
        } else {
            self.queues[task.priority as usize].push_back(task);
            self.total_size += 1;
            Ok(())
        }
    }

    /// Dequeues the oldest task from the highest non-empty priority lane
    /// (owner side of the deque).
    fn pop(&mut self) -> Option<PrioritizedTask> {
        let task = self.queues.iter_mut().rev().find_map(|lane| lane.pop_front())?;
        self.total_size -= 1;
        Some(task)
    }

    /// Removes a task for a thief: the newest task from the *lowest*
    /// non-empty lane, leaving high-priority work for the owner.
    fn steal(&mut self) -> Option<PrioritizedTask> {
        let task = self.queues.iter_mut().find_map(|lane| lane.pop_back())?;
        self.total_size -= 1;
        Some(task)
    }

    /// Total queued tasks across every lane.
    fn len(&self) -> usize {
        self.total_size
    }

    #[allow(dead_code)]
    fn is_empty(&self) -> bool {
        self.total_size == 0
    }

    #[allow(dead_code)]
    fn is_full(&self) -> bool {
        self.total_size >= self.maxsize
    }
}
/// Per-thread worker state: the worker's own task queue plus handles to the
/// global queue and to every sibling queue it is allowed to steal from.
struct Worker {
    #[allow(dead_code)]
    id: usize,
    // This worker's preferred queue (capacity: max_queue_size / 4, see `new`).
    local_queue: Arc<Mutex<PriorityTaskQueue>>,
    // Shared queue that submissions without a usable NUMA hint land in.
    global_queue: Arc<Mutex<PriorityTaskQueue>>,
    // Sibling local queues, probed as steal victims when this worker is idle.
    other_workers: Vec<Arc<Mutex<PriorityTaskQueue>>>,
    #[allow(dead_code)]
    numa_node: Option<usize>,
    // Cooperative shutdown flag shared with the scheduler.
    shutdown: Arc<AtomicBool>,
    // Activity counters updated by the worker loop.
    stats: Arc<WorkerStats>,
    config: WorkStealingConfig,
}
/// Per-worker activity counters, updated from the worker loop.
#[derive(Debug, Default)]
struct WorkerStats {
    // Tasks this worker has executed to completion.
    tasks_executed: AtomicU64,
    // Successful steals performed by this worker.
    tasks_stolen: AtomicU64,
    #[allow(dead_code)]
    tasks_provided: AtomicU64,
    // Cumulative idle time in microseconds.
    idle_time: AtomicU64,
    // Cumulative task-execution time in microseconds.
    active_time: AtomicU64,
    // NOTE(review): written with the last task's *duration* in seconds (see
    // `Worker::run`), not a timestamp — confirm the intended semantics.
    last_activity: AtomicU64,
}
impl Worker {
    /// Creates a worker whose local queue holds a quarter of the global
    /// capacity; sibling queues are attached later via `add_other_worker`.
    fn new(
        id: usize,
        global_queue: Arc<Mutex<PriorityTaskQueue>>,
        numa_node: Option<usize>,
        shutdown: Arc<AtomicBool>,
        config: WorkStealingConfig,
    ) -> Self {
        let local_queue = Arc::new(Mutex::new(PriorityTaskQueue::new(
            config.max_queue_size / 4,
        )));
        Self {
            id,
            local_queue,
            global_queue,
            other_workers: Vec::new(),
            numa_node,
            shutdown,
            stats: Arc::new(WorkerStats::default()),
            config,
        }
    }

    /// Registers a sibling worker's queue as a steal target.
    fn add_other_worker(&mut self, worker_queue: Arc<Mutex<PriorityTaskQueue>>) {
        self.other_workers.push(worker_queue);
    }

    /// Main worker loop: executes queued tasks, stealing from siblings when
    /// idle, until `shutdown` is set or the result channel closes.
    fn run(self, result_sender: crossbeam::channel::Sender<Box<dyn std::any::Any + Send>>) {
        let mut consecutive_steals = 0;
        let mut last_steal_attempt = Instant::now();
        while !self.shutdown.load(Ordering::Relaxed) {
            let task_start = Instant::now();
            if let Some(task) = self.get_task() {
                let result = task.execute();
                // Receiver dropped => the scheduler is gone; stop the loop.
                if result_sender.send(result).is_err() {
                    break;
                }
                self.stats.tasks_executed.fetch_add(1, Ordering::Relaxed);
                self.stats
                    .active_time
                    .fetch_add(task_start.elapsed().as_micros() as u64, Ordering::Relaxed);
                // NOTE(review): stores the last task's *duration* in seconds,
                // not a timestamp, despite the field name — confirm before
                // relying on `last_activity` for staleness detection.
                self.stats
                    .last_activity
                    .store(task_start.elapsed().as_secs(), Ordering::Relaxed);
                consecutive_steals = 0;
            } else {
                let idle_start = Instant::now();
                // Rate-limit steal attempts to one per `steal_timeout_ms`.
                if last_steal_attempt.elapsed()
                    >= Duration::from_millis(self.config.steal_timeout_ms)
                {
                    if self.try_steal_work() {
                        consecutive_steals += 1;
                        self.stats.tasks_stolen.fetch_add(1, Ordering::Relaxed);
                    }
                    last_steal_attempt = Instant::now();
                }
                self.stats
                    .idle_time
                    .fetch_add(idle_start.elapsed().as_micros() as u64, Ordering::Relaxed);
                // Back off harder after a burst of steals to reduce contention.
                let backoff_duration = if consecutive_steals > 5 {
                    Duration::from_millis(10)
                } else {
                    Duration::from_micros(100)
                };
                thread::sleep(backoff_duration);
            }
        }
    }

    /// Takes the next task, preferring the local queue over the global one.
    /// Uses `try_lock` so a contended queue is skipped rather than waited on.
    fn get_task(&self) -> Option<PrioritizedTask> {
        if let Ok(mut local) = self.local_queue.try_lock() {
            if let Some(task) = local.pop() {
                return Some(task);
            }
        }
        if let Ok(mut global) = self.global_queue.try_lock() {
            if let Some(task) = global.pop() {
                return Some(task);
            }
        }
        None
    }

    /// Attempts to move one task from a sibling queue into the local queue.
    ///
    /// Fixes two defects in the original implementation: a stolen task was
    /// silently *dropped* (lost forever) when the local queue was full or its
    /// lock unavailable, and taking a blocking `lock()` on our own queue
    /// while holding the victim's lock could deadlock two workers stealing
    /// from each other. The local lock is now only `try_lock`ed, and any task
    /// we cannot keep is pushed back to the victim's queue — which cannot be
    /// full, since we just removed a task from it.
    fn try_steal_work(&self) -> bool {
        let mut attempts = 0;
        let mut workers = self.other_workers.clone();
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        // Derive a per-worker permutation of the victim list so the workers
        // do not all probe siblings in the same order (less lock contention).
        let mut hasher = DefaultHasher::new();
        self.id.hash(&mut hasher);
        let seed = hasher.finish() as usize;
        for i in 0..workers.len() {
            let j = (seed + i) % workers.len();
            workers.swap(i, j);
        }
        for worker_queue in workers {
            if attempts >= self.config.max_steal_attempts {
                break;
            }
            if let Ok(mut victim) = worker_queue.try_lock() {
                if let Some(task) = victim.steal() {
                    if let Ok(mut local) = self.local_queue.try_lock() {
                        match local.push(task) {
                            Ok(()) => return true,
                            // Local queue full: give the task back rather
                            // than dropping it on the floor.
                            Err(task) => {
                                let _ = victim.push(task);
                            }
                        }
                    } else {
                        // Our own queue is contended; return the task.
                        let _ = victim.push(task);
                    }
                }
            }
            attempts += 1;
        }
        false
    }
}
/// Aggregated scheduler metrics exposed via `WorkStealingScheduler::stats`.
#[derive(Debug, Clone, Default)]
pub struct SchedulerStats {
    /// Tasks accepted by the `submit*` family of methods.
    pub tasks_submitted: u64,
    /// Results received back through `try_recv`/`recv_timeout`.
    pub tasks_completed: u64,
    /// Submitted minus completed (never decremented below zero).
    pub tasks_pending: u64,
    /// NOTE(review): never written by the code in this file — confirm whether
    /// it should aggregate the per-worker `tasks_stolen` counters.
    pub total_steals: u64,
    /// NOTE(review): never written by the code in this file.
    pub avg_execution_time_us: f64,
    /// NOTE(review): never written by the code in this file.
    pub worker_utilization: f64,
    /// NOTE(review): never written by the code in this file.
    pub memory_usage_per_worker: Vec<usize>,
    /// NOTE(review): never written by the code in this file.
    pub load_balance_operations: u64,
    /// Completed tasks per second; updated by the monitoring thread.
    pub throughput: f64,
}
/// Work-stealing task scheduler: worker threads with per-worker priority
/// queues, a shared global queue, and a channel carrying boxed results.
pub struct WorkStealingScheduler {
    config: WorkStealingConfig,
    // NOTE(review): drained by `start()` via `mem::take`, so it is empty
    // after startup — which disables the NUMA fast path in `submit_task` and
    // the local-queue accounting in `pending_tasks`. Confirm whether the
    // local queue handles should be retained here instead.
    #[allow(dead_code)]
    workers: Vec<Worker>,
    // Join handles for spawned worker threads (empty until `start`).
    worker_handles: Vec<JoinHandle<()>>,
    global_queue: Arc<Mutex<PriorityTaskQueue>>,
    // Results arrive as boxed `Any`; `try_recv`/`recv_timeout` downcast them.
    result_receiver: crossbeam::channel::Receiver<Box<dyn std::any::Any + Send>>,
    result_sender: crossbeam::channel::Sender<Box<dyn std::any::Any + Send>>,
    // Cooperative shutdown flag polled by workers and the monitor thread.
    shutdown: Arc<AtomicBool>,
    #[allow(dead_code)]
    numa_nodes: Vec<NumaNode>,
    stats: Arc<RwLock<SchedulerStats>>,
    // Set by `start()`; used by the monitor thread to compute throughput.
    start_time: Option<Instant>,
}
impl WorkStealingScheduler {
    /// Builds the scheduler and its workers but spawns no threads; call
    /// [`start`](Self::start) to begin execution.
    pub fn new(config: WorkStealingConfig) -> CoreResult<Self> {
        // Default worker count: one per logical CPU, falling back to 4.
        let num_workers = config.num_workers.unwrap_or_else(|| {
            std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4)
        });
        let global_queue = Arc::new(Mutex::new(PriorityTaskQueue::new(config.max_queue_size)));
        let (result_sender, result_receiver) = crossbeam::channel::unbounded();
        let shutdown = Arc::new(AtomicBool::new(false));
        let numa_nodes = if config.numa_aware {
            Self::detect_numa_topology(num_workers)
        } else {
            // Single pseudo-node covering all workers when NUMA is disabled.
            vec![NumaNode::new(0, (0..num_workers).collect(), 0)]
        };
        let mut workers = Vec::with_capacity(num_workers);
        for i in 0..num_workers {
            // Round-robin workers across the detected NUMA nodes.
            let numa_node = if config.numa_aware {
                Some(i % numa_nodes.len())
            } else {
                None
            };
            let worker = Worker::new(
                i,
                global_queue.clone(),
                numa_node,
                shutdown.clone(),
                config.clone(),
            );
            workers.push(worker);
        }
        // Wire every worker to every sibling's local queue for stealing.
        let local_queues: Vec<_> = workers.iter().map(|w| w.local_queue.clone()).collect();
        for (i, worker) in workers.iter_mut().enumerate() {
            for (j, queue) in local_queues.iter().enumerate() {
                if i != j {
                    worker.add_other_worker(queue.clone());
                }
            }
        }
        Ok(Self {
            config,
            workers,
            worker_handles: Vec::new(),
            global_queue,
            result_receiver,
            result_sender,
            shutdown,
            numa_nodes,
            stats: Arc::new(RwLock::new(SchedulerStats::default())),
            start_time: None,
        })
    }

    /// Spawns one OS thread per worker plus, when configured, a monitoring
    /// thread. Errors if called twice.
    ///
    /// NOTE(review): `std::mem::take` empties `self.workers`, so after a
    /// successful start the NUMA fast path in `submit_task` and the local
    /// queue accounting in `pending_tasks` see no workers — confirm whether
    /// the local queue handles should be retained.
    pub fn start(&mut self) -> CoreResult<()> {
        if !self.worker_handles.is_empty() {
            return Err(CoreError::StreamError(
                ErrorContext::new("Scheduler already started".to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }
        self.start_time = Some(Instant::now());
        let workers = std::mem::take(&mut self.workers);
        for worker in workers {
            let worker_id = worker.id;
            let result_sender = self.result_sender.clone();
            let handle = thread::Builder::new()
                .name(format!("worker-{worker_id}"))
                .spawn(move || {
                    worker.run(result_sender);
                })
                .map_err(|e| {
                    CoreError::StreamError(
                        ErrorContext::new(format!("{e}"))
                            .with_location(ErrorLocation::new(file!(), line!())),
                    )
                })?;
            self.worker_handles.push(handle);
        }
        if self.config.enable_monitoring {
            self.start_monitoring();
        }
        Ok(())
    }

    /// Submits a closure for execution at the given priority.
    pub fn submit<F, R>(&self, priority: TaskPriority, func: F) -> CoreResult<()>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let task = FunctionTask::new(func);
        self.submit_task(priority, task, None)
    }

    /// Submits a closure along with optional duration/memory estimates that
    /// scheduling heuristics may consult.
    pub fn submit_with_estimates<F, R>(
        &self,
        priority: TaskPriority,
        func: F,
        duration_estimate: Option<Duration>,
        memory_estimate: Option<usize>,
    ) -> CoreResult<()>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let task = FunctionTask::with_estimates(func, duration_estimate, memory_estimate);
        self.submit_task(priority, task, None)
    }

    /// Submits a closure with a NUMA placement hint.
    pub fn submit_to_numa<F, R>(
        &self,
        priority: TaskPriority,
        numa_node: usize,
        func: F,
    ) -> CoreResult<()>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let task = FunctionTask::new(func);
        self.submit_task(priority, task, Some(numa_node))
    }

    /// Routes a task either to the hinted worker's local queue or, failing
    /// that, to the global queue.
    ///
    /// NOTE(review): the hint is used as a *worker index* rather than being
    /// mapped through `numa_nodes`, and because `start()` drains
    /// `self.workers`, the fast path is unreachable on a running scheduler —
    /// hinted tasks silently fall through to the global queue.
    fn submit_task<T>(
        &self,
        priority: TaskPriority,
        task: T,
        numa_hint: Option<usize>,
    ) -> CoreResult<()>
    where
        T: WorkStealingTask,
        T::Output: 'static,
    {
        if let Some(numa_node) = numa_hint {
            if numa_node < self.workers.len() {
                // `try_lock` only: a contended local queue falls through to
                // the global queue below instead of blocking the caller.
                if let Ok(mut local_queue) = self.workers[numa_node].local_queue.try_lock() {
                    let prioritized_task = PrioritizedTask::new(task, priority, numa_hint);
                    local_queue.push(prioritized_task).map_err(|_| {
                        CoreError::StreamError(
                            ErrorContext::new("Local task queue is full".to_string())
                                .with_location(ErrorLocation::new(file!(), line!())),
                        )
                    })?;
                    self.update_submit_stats();
                    return Ok(());
                }
            }
        }
        let prioritized_task = PrioritizedTask::new(task, priority, numa_hint);
        // Panics if the global queue mutex is poisoned by a crashed thread.
        let mut global_queue = self.global_queue.lock().expect("Operation failed");
        global_queue.push(prioritized_task).map_err(|_| {
            CoreError::StreamError(
                ErrorContext::new("Global task queue is full".to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            )
        })?;
        self.update_submit_stats();
        Ok(())
    }

    /// Non-blocking receive of one completed result, downcast to `T`.
    ///
    /// NOTE(review): a result whose type is not `T` is consumed and dropped
    /// (and still counted as completed) — callers cannot recover it.
    pub fn try_recv<T: 'static>(&self) -> Option<T> {
        if let Ok(result) = self.result_receiver.try_recv() {
            self.update_completion_stats();
            if let Ok(typed_result) = result.downcast::<T>() {
                Some(*typed_result)
            } else {
                None
            }
        } else {
            None
        }
    }

    /// Blocking receive with timeout; errors on timeout or type mismatch.
    ///
    /// NOTE(review): on a type mismatch the mismatched result is dropped.
    pub fn recv_timeout<T: 'static>(&self, timeout: Duration) -> CoreResult<T> {
        let result = self.result_receiver.recv_timeout(timeout).map_err(|_| {
            CoreError::TimeoutError(
                ErrorContext::new("Timeout waiting for task result".to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            )
        })?;
        self.update_completion_stats();
        result.downcast::<T>().map(|r| *r).map_err(|_| {
            CoreError::ValidationError(
                ErrorContext::new("Task result type mismatch".to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            )
        })
    }

    /// Snapshot of the aggregated scheduler metrics.
    /// Panics if the stats lock is poisoned.
    pub fn stats(&self) -> SchedulerStats {
        self.stats.read().expect("Operation failed").clone()
    }

    /// Signals shutdown and joins every worker thread. The monitoring thread
    /// (if any) is detached and exits on its next interval tick.
    pub fn stop(&mut self) -> CoreResult<()> {
        self.shutdown.store(true, Ordering::Relaxed);
        for handle in self.worker_handles.drain(..) {
            handle.join().map_err(|_| {
                CoreError::StreamError(
                    ErrorContext::new("Failed to join worker thread".to_string())
                        .with_location(ErrorLocation::new(file!(), line!())),
                )
            })?;
        }
        Ok(())
    }

    /// Count of queued-but-unexecuted tasks.
    ///
    /// NOTE(review): after `start()` the `workers` vector is empty, so local
    /// queues are not counted and this reports only the global queue.
    pub fn pending_tasks(&self) -> usize {
        let global_pending = self.global_queue.lock().expect("Operation failed").len();
        let local_pending: usize = self
            .workers
            .iter()
            .map(|w| w.local_queue.lock().expect("Operation failed").len())
            .sum();
        global_pending + local_pending
    }

    /// Stub topology detection: always a single node owning every worker.
    fn detect_numa_topology(num_workers: usize) -> Vec<NumaNode> {
        vec![NumaNode::new(0, (0..num_workers).collect(), 0)]
    }

    /// Records one accepted submission; silently skipped on lock poisoning.
    fn update_submit_stats(&self) {
        if let Ok(mut stats) = self.stats.write() {
            stats.tasks_submitted += 1;
            stats.tasks_pending += 1;
        }
    }

    /// Records one received result; pending count saturates at zero.
    fn update_completion_stats(&self) {
        if let Ok(mut stats) = self.stats.write() {
            stats.tasks_completed += 1;
            if stats.tasks_pending > 0 {
                stats.tasks_pending -= 1;
            }
        }
    }

    /// Spawns a detached thread that periodically recomputes throughput
    /// until the shutdown flag is set (checked once per interval).
    fn start_monitoring(&self) {
        let stats = self.stats.clone();
        let shutdown = self.shutdown.clone();
        let interval = self.config.stats_interval;
        let start_time = self.start_time;
        thread::spawn(move || {
            while !shutdown.load(Ordering::Relaxed) {
                thread::sleep(interval);
                if let Ok(mut stats_guard) = stats.write() {
                    if let Some(start) = start_time {
                        let elapsed = start.elapsed().as_secs_f64();
                        if elapsed > 0.0 {
                            stats_guard.throughput = stats_guard.tasks_completed as f64 / elapsed;
                        }
                    }
                }
            }
        });
    }
}
/// Ensures worker threads are signalled and joined when the scheduler is
/// dropped without an explicit `stop()` call.
impl Drop for WorkStealingScheduler {
    fn drop(&mut self) {
        // Best effort only: a join failure cannot be reported from `drop`.
        self.stop().ok();
    }
}
/// Fluent builder for [`WorkStealingConfig`]; start with `new()` and finish
/// with `build()`.
#[derive(Debug, Clone)]
pub struct WorkStealingConfigBuilder {
    // Accumulated configuration, seeded from `WorkStealingConfig::default()`.
    config: WorkStealingConfig,
}
impl WorkStealingConfigBuilder {
    /// Starts from [`WorkStealingConfig::default`].
    pub fn new() -> Self {
        Self {
            config: WorkStealingConfig::default(),
        }
    }

    /// Sets an explicit worker-thread count (default: available parallelism).
    /// Made `const` for consistency with the other primitive setters.
    pub const fn num_workers(mut self, workers: usize) -> Self {
        self.config.num_workers = Some(workers);
        self
    }

    /// Caps the number of tasks queued in the global queue.
    pub const fn max_queue_size(mut self, size: usize) -> Self {
        self.config.max_queue_size = size;
        self
    }

    /// Enables NUMA-aware worker placement.
    pub const fn numa_aware(mut self, enable: bool) -> Self {
        self.config.numa_aware = enable;
        self
    }

    /// Enables priority-ordered dequeueing.
    pub const fn priority_scheduling(mut self, enable: bool) -> Self {
        self.config.priority_scheduling = enable;
        self
    }

    /// Pins worker threads to the given CPU list. Not `const`: assigning may
    /// drop a previously-set `Vec`, which `const fn` forbids.
    pub fn thread_affinity(mut self, affinity: Vec<usize>) -> Self {
        self.config.thread_affinity = Some(affinity);
        self
    }

    /// Sets a per-worker memory budget in bytes.
    /// Made `const` for consistency with the other primitive setters.
    pub const fn max_memory_per_worker(mut self, memory: usize) -> Self {
        self.config.max_memory_per_worker = Some(memory);
        self
    }

    /// Toggles the background statistics thread.
    pub const fn enable_monitoring(mut self, enable: bool) -> Self {
        self.config.enable_monitoring = enable;
        self
    }

    /// Toggles adaptive load balancing.
    pub const fn adaptive_balancing(mut self, enable: bool) -> Self {
        self.config.adaptive_balancing = enable;
        self
    }

    /// Finalizes and returns the configuration.
    pub fn build(self) -> WorkStealingConfig {
        self.config
    }
}
impl Default for WorkStealingConfigBuilder {
fn default() -> Self {
Self::new()
}
}
/// Convenience constructor: a scheduler with all-default configuration.
#[allow(dead_code)]
pub fn create_work_stealing_scheduler() -> CoreResult<WorkStealingScheduler> {
    let config = WorkStealingConfig::default();
    WorkStealingScheduler::new(config)
}
/// Builds a scheduler tuned for CPU-bound work: one worker per available
/// core (fallback 4), with priority scheduling, adaptive balancing, and
/// monitoring all enabled.
#[allow(dead_code)]
pub fn create_cpu_intensive_scheduler() -> CoreResult<WorkStealingScheduler> {
    let cores = std::thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(4);
    let config = WorkStealingConfigBuilder::new()
        .enable_monitoring(true)
        .adaptive_balancing(true)
        .priority_scheduling(true)
        .num_workers(cores)
        .build();
    WorkStealingScheduler::new(config)
}
/// Builds a scheduler tuned for I/O-bound work: workers oversubscribed 2x
/// per core (fallback 8), a 50k-task queue, monitoring on, and adaptive
/// balancing off.
#[allow(dead_code)]
pub fn create_io_intensive_scheduler() -> CoreResult<WorkStealingScheduler> {
    let workers = match std::thread::available_parallelism() {
        Ok(n) => n.get() * 2,
        Err(_) => 8,
    };
    let config = WorkStealingConfigBuilder::new()
        .num_workers(workers)
        .priority_scheduling(true)
        .enable_monitoring(true)
        .adaptive_balancing(false)
        .max_queue_size(50000)
        .build();
    WorkStealingScheduler::new(config)
}
#[cfg(all(test, feature = "parallel"))]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicU32;

    // Smoke test: the default scheduler can be constructed.
    #[test]
    fn test_work_stealing_scheduler_creation() {
        let scheduler = create_work_stealing_scheduler();
        assert!(scheduler.is_ok());
    }

    // A submitted closure's result can be received back and downcast.
    // NOTE(review): the assertion is conditional on `try_recv` returning
    // within the 100 ms sleep, so a slow machine silently skips it.
    #[test]
    fn test_task_submission_and_execution() {
        let mut scheduler = create_work_stealing_scheduler().expect("Operation failed");
        scheduler.start().expect("Operation failed");
        scheduler
            .submit(TaskPriority::Normal, || 42)
            .expect("Operation failed");
        std::thread::sleep(Duration::from_millis(100));
        if let Some(result) = scheduler.try_recv::<i32>() {
            assert_eq!(result, 42);
        }
        scheduler.stop().expect("Operation failed");
    }

    // Submits a slow low-priority task and a fast high-priority task.
    // NOTE(review): the counter's final value is never asserted, so this
    // test only checks that mixed-priority submission does not hang.
    #[test]
    fn test_priority_scheduling() {
        let mut scheduler = create_work_stealing_scheduler().expect("Operation failed");
        scheduler.start().expect("Operation failed");
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        scheduler
            .submit(TaskPriority::Low, move || {
                std::thread::sleep(Duration::from_millis(50));
                counter_clone.store(1, Ordering::Relaxed);
            })
            .expect("Operation failed");
        let counter_clone = counter.clone();
        scheduler
            .submit(TaskPriority::High, move || {
                counter_clone.store(2, Ordering::Relaxed);
            })
            .expect("Operation failed");
        std::thread::sleep(Duration::from_millis(200));
        scheduler.stop().expect("Operation failed");
    }

    // The submission counter reflects every accepted task.
    #[test]
    fn test_scheduler_stats() {
        let mut scheduler = create_work_stealing_scheduler().expect("Operation failed");
        scheduler.start().expect("Operation failed");
        for i in 0..10 {
            scheduler
                .submit(TaskPriority::Normal, move || i * 2)
                .expect("Operation failed");
        }
        std::thread::sleep(Duration::from_millis(100));
        let stats = scheduler.stats();
        assert!(stats.tasks_submitted >= 10);
        scheduler.stop().expect("Operation failed");
    }

    // Each builder setter lands in the corresponding config field.
    #[test]
    fn test_config_builder() {
        let config = WorkStealingConfigBuilder::new()
            .num_workers(8)
            .max_queue_size(5000)
            .numa_aware(true)
            .priority_scheduling(false)
            .adaptive_balancing(true)
            .build();
        assert_eq!(config.num_workers, Some(8));
        assert_eq!(config.max_queue_size, 5000);
        assert!(config.numa_aware);
        assert!(!config.priority_scheduling);
        assert!(config.adaptive_balancing);
    }
}