use crate::error::{NumRs2Error, Result};
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
/// Urgency level attached to a scheduled task.
///
/// The derived ordering follows the explicit discriminants, so
/// `Low < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    /// Least urgent level (discriminant `0`).
    Low = 0,
    /// Default level for ordinary work (discriminant `1`).
    Normal = 1,
    /// Ordered ahead of `Normal` and `Low` (discriminant `2`).
    High = 2,
    /// Most urgent level (discriminant `3`).
    Critical = 3,
}
/// Tunable parameters for the parallel scheduler.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Number of worker threads the scheduler spawns.
    pub num_threads: usize,
    /// Maximum number of tasks held in the global queue; submissions beyond
    /// this are rejected.
    pub max_queue_size: usize,
    /// Thread-affinity switch. NOTE(review): not consulted by the scheduling
    /// code visible in this file.
    pub enable_thread_affinity: bool,
    /// When `true`, idle workers attempt work stealing before sleeping.
    pub enable_adaptive_scheduling: bool,
    /// Idle wait timeout, and minimum interval between steal attempts, in
    /// milliseconds.
    pub time_slice_ms: u64,
    /// A worker's local queue must hold more than this many tasks before it
    /// can be chosen as a steal victim.
    pub work_stealing_threshold: usize,
    /// Cache-awareness switch. NOTE(review): not consulted by the scheduling
    /// code visible in this file.
    pub cache_aware_scheduling: bool,
}

impl SchedulerConfig {
    /// Clamps a requested core count to at least one worker.
    ///
    /// A configuration with zero threads would build a scheduler that accepts
    /// tasks but never executes them.
    fn worker_count(num_cores: usize) -> usize {
        num_cores.max(1)
    }

    /// Balanced preset: affinity, adaptive scheduling and cache awareness all
    /// enabled, with a 10 ms time slice.
    ///
    /// `num_cores` is clamped to a minimum of 1.
    pub fn optimal_for_cores(num_cores: usize) -> Self {
        Self {
            num_threads: Self::worker_count(num_cores),
            max_queue_size: 1000,
            enable_thread_affinity: true,
            enable_adaptive_scheduling: true,
            time_slice_ms: 10,
            work_stealing_threshold: 5,
            cache_aware_scheduling: true,
        }
    }

    /// Throughput preset: the largest queue (2000), a low stealing threshold,
    /// and affinity / cache awareness disabled.
    ///
    /// `num_cores` is clamped to a minimum of 1.
    pub fn throughput_optimized(num_cores: usize) -> Self {
        Self {
            num_threads: Self::worker_count(num_cores),
            max_queue_size: 2000,
            enable_thread_affinity: false,
            enable_adaptive_scheduling: true,
            time_slice_ms: 5,
            work_stealing_threshold: 3,
            cache_aware_scheduling: false,
        }
    }

    /// Latency preset: the smallest queue (500), the shortest time slice
    /// (2 ms), and adaptive scheduling disabled.
    ///
    /// `num_cores` is clamped to a minimum of 1.
    pub fn latency_optimized(num_cores: usize) -> Self {
        Self {
            num_threads: Self::worker_count(num_cores),
            max_queue_size: 500,
            enable_thread_affinity: true,
            enable_adaptive_scheduling: false,
            time_slice_ms: 2,
            work_stealing_threshold: 8,
            cache_aware_scheduling: true,
        }
    }
}
/// A unit of work waiting in a scheduler queue, together with the metadata
/// used to order it.
pub struct ScheduledTask {
    /// Identifier assigned at submission time.
    pub id: u64,
    /// Urgency level; higher priorities sort first.
    pub priority: TaskPriority,
    /// Moment the task was submitted; earlier tasks sort first within a priority.
    pub submitted_at: Instant,
    /// Optional caller-supplied runtime estimate. NOTE(review): not consulted
    /// by the scheduling code visible in this file.
    pub estimated_duration: Option<Duration>,
    /// Optional preferred worker index. NOTE(review): not consulted by the
    /// scheduling code visible in this file.
    pub thread_affinity: Option<usize>,
    /// The closure to run; its return value reports the outcome.
    pub task: Box<dyn FnOnce() -> TaskResult + Send + 'static>,
}

impl std::fmt::Debug for ScheduledTask {
    /// Formats every field except the closure, which is shown as `<closure>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScheduledTask")
            .field("id", &self.id)
            .field("priority", &self.priority)
            .field("submitted_at", &self.submitted_at)
            .field("estimated_duration", &self.estimated_duration)
            .field("thread_affinity", &self.thread_affinity)
            .field("task", &"<closure>")
            .finish()
    }
}

impl PartialEq for ScheduledTask {
    /// Two tasks are equal exactly when [`Ord::cmp`] reports `Equal`, which
    /// keeps `Eq` and `Ord` consistent as `BinaryHeap` expects.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for ScheduledTask {}

impl PartialOrd for ScheduledTask {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for ScheduledTask {
    /// Orders tasks so that the "greatest" one is the next to run in a
    /// max-heap: higher priority first, then earlier submission, then lower
    /// `id`.
    ///
    /// The `id` tie-breaker keeps ordering deterministic when two tasks of the
    /// same priority carry the same `submitted_at` instant; without it such
    /// tasks compare equal and their relative heap order is unspecified.
    fn cmp(&self, other: &Self) -> Ordering {
        self.priority
            .cmp(&other.priority)
            // Reversed operands: the older task must compare as greater.
            .then_with(|| other.submitted_at.cmp(&self.submitted_at))
            // Reversed operands: the lower id must compare as greater.
            .then_with(|| other.id.cmp(&self.id))
    }
}
/// Outcome reported by a task closure when it finishes.
#[derive(Debug)]
pub enum TaskResult {
    /// The task finished normally.
    Success,
    /// The task failed; the payload carries a description of the failure.
    Error(String),
    /// The task reported that it was cancelled.
    Cancelled,
}
/// Per-worker bookkeeping: a local task queue plus timing counters.
///
/// The struct is aligned to 64 bytes — presumably so each worker's state sits
/// on its own cache line (TODO confirm intent).
#[derive(Debug)]
#[repr(align(64))]
struct ThreadState {
    /// Index of the worker this state belongs to.
    #[allow(dead_code)]
    id: usize,
    /// Tasks held locally by this worker.
    local_queue: BinaryHeap<ScheduledTask>,
    /// Number of tasks this worker has run.
    tasks_executed: u64,
    /// Accumulated time spent running tasks.
    total_execution_time: Duration,
    /// Accumulated time spent idle.
    idle_time: Duration,
    /// Instant of this worker's most recent steal attempt.
    last_steal_time: Instant,
    /// Zero-length array; occupies no space.
    _padding: [u8; 0],
}

impl ThreadState {
    /// Builds an empty state for worker `id`, with all counters at zero and
    /// the steal timestamp set to the moment of creation.
    fn new(id: usize) -> Self {
        let created_at = Instant::now();
        Self {
            id,
            local_queue: BinaryHeap::new(),
            tasks_executed: 0,
            total_execution_time: Duration::ZERO,
            idle_time: Duration::ZERO,
            last_steal_time: created_at,
            _padding: [],
        }
    }

    /// Fraction of observed time (busy + idle) that was spent running tasks.
    ///
    /// Returns `0.0` when no time has been recorded yet.
    fn efficiency(&self) -> f64 {
        let busy = self.total_execution_time;
        let observed = busy + self.idle_time;
        if observed.is_zero() {
            return 0.0;
        }
        busy.as_secs_f64() / observed.as_secs_f64()
    }
}
/// Priority-based task scheduler backed by a fixed set of worker threads.
///
/// Dropping the scheduler signals its workers to stop (see the `Drop` impl),
/// so worker threads do not outlive the value that owns them.
pub struct ParallelScheduler {
    /// Configuration the scheduler was built with.
    config: SchedulerConfig,
    /// Priority queue shared by all workers.
    global_queue: Arc<Mutex<BinaryHeap<ScheduledTask>>>,
    /// Handles of the spawned workers; retained but never joined.
    #[allow(dead_code)]
    worker_threads: Vec<JoinHandle<()>>,
    /// Shutdown flag plus the condition variable idle workers wait on.
    shutdown_signal: Arc<(Mutex<bool>, Condvar)>,
    /// Per-worker bookkeeping, indexed by worker id.
    thread_states: Arc<Mutex<Vec<ThreadState>>>,
    /// Last task id handed out.
    next_task_id: Arc<Mutex<u64>>,
    /// Aggregate counters shared with the workers.
    scheduler_stats: Arc<Mutex<SchedulerStats>>,
}

impl Drop for ParallelScheduler {
    /// Requests worker shutdown so the threads do not keep polling forever
    /// after the scheduler is gone.
    ///
    /// Workers are deliberately not joined: a task may run for a long time,
    /// and the scheduler could be dropped from inside a task, where joining
    /// would block on the current thread. Each worker exits the next time it
    /// checks the flag; tasks still queued at that point are discarded.
    fn drop(&mut self) {
        let (shutdown_flag, condvar) = &*self.shutdown_signal;
        // Never panic inside `drop`: recover the guard even if the mutex was
        // poisoned.
        let mut flag = shutdown_flag
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *flag = true;
        condvar.notify_all();
    }
}
/// Aggregate counters describing scheduler activity.
///
/// `Default` yields all-zero counters, zero durations and an empty
/// efficiency list.
#[derive(Debug, Clone, Default)]
pub struct SchedulerStats {
    /// Tasks accepted into the queue.
    pub tasks_submitted: u64,
    /// Tasks that finished with a success result.
    pub tasks_completed: u64,
    /// Tasks that finished with an error or cancelled result.
    pub tasks_failed: u64,
    /// Average time a task waits in the queue before it starts running.
    pub average_queue_time: Duration,
    /// Average wall-clock time spent executing a task.
    pub average_execution_time: Duration,
    /// Per-worker busy fraction, one entry per worker thread.
    pub thread_efficiency: Vec<f64>,
    /// Number of work-steal operations. NOTE(review): no code visible in this
    /// file increments it.
    pub work_steals: u64,
    /// Submissions rejected because the queue was full.
    pub queue_overflows: u64,
}
impl ParallelScheduler {
pub fn new(config: SchedulerConfig) -> Result<Self> {
let global_queue = Arc::new(Mutex::new(BinaryHeap::new()));
let shutdown_signal = Arc::new((Mutex::new(false), Condvar::new()));
let thread_states = Arc::new(Mutex::new(Vec::new()));
let next_task_id = Arc::new(Mutex::new(0));
let scheduler_stats = Arc::new(Mutex::new(SchedulerStats::default()));
{
let mut states = thread_states.lock().expect("lock should not be poisoned");
for i in 0..config.num_threads {
states.push(ThreadState::new(i));
}
}
let mut worker_threads = Vec::new();
for thread_id in 0..config.num_threads {
let global_queue = Arc::clone(&global_queue);
let shutdown_signal = Arc::clone(&shutdown_signal);
let thread_states = Arc::clone(&thread_states);
let scheduler_stats = Arc::clone(&scheduler_stats);
let config = config.clone();
let handle = thread::spawn(move || {
Self::worker_thread_main(
thread_id,
global_queue,
shutdown_signal,
thread_states,
scheduler_stats,
config,
);
});
worker_threads.push(handle);
}
Ok(Self {
config,
global_queue,
worker_threads,
shutdown_signal,
thread_states,
next_task_id,
scheduler_stats,
})
}
/// Queues `task` for execution and returns the id assigned to it.
///
/// `priority` decides the order in which queued tasks are taken by workers.
/// `estimated_duration` and `thread_affinity` are stored on the task record.
///
/// # Errors
///
/// Returns `NumRs2Error::RuntimeError` when
/// * the scheduler has already been shut down (the workers have stopped, so
///   the task could never run), or
/// * the global queue already holds `max_queue_size` tasks; this case also
///   increments the `queue_overflows` statistic.
pub fn submit_task<F>(
    &self,
    task: F,
    priority: TaskPriority,
    estimated_duration: Option<Duration>,
    thread_affinity: Option<usize>,
) -> Result<u64>
where
    F: FnOnce() -> TaskResult + Send + 'static,
{
    let (shutdown_flag, condvar) = &*self.shutdown_signal;

    // Refuse work once shutdown was requested: the workers exit their loop
    // and the task would sit in the queue forever.
    let already_shut_down = *shutdown_flag.lock().expect("lock should not be poisoned");
    if already_shut_down {
        return Err(NumRs2Error::RuntimeError(
            "Scheduler has been shut down".to_string(),
        ));
    }

    let task_id = {
        let mut id = self
            .next_task_id
            .lock()
            .expect("lock should not be poisoned");
        *id += 1;
        *id
    };

    let scheduled_task = ScheduledTask {
        id: task_id,
        priority,
        submitted_at: Instant::now(),
        estimated_duration,
        thread_affinity,
        task: Box::new(task),
    };

    {
        let mut queue = self
            .global_queue
            .lock()
            .expect("lock should not be poisoned");
        if queue.len() >= self.config.max_queue_size {
            self.scheduler_stats
                .lock()
                .expect("lock should not be poisoned")
                .queue_overflows += 1;
            return Err(NumRs2Error::RuntimeError("Task queue full".to_string()));
        }
        queue.push(scheduled_task);
        // Counted while the queue is still locked so a worker cannot finish
        // the task before it has been recorded as submitted.
        self.scheduler_stats
            .lock()
            .expect("lock should not be poisoned")
            .tasks_submitted += 1;
    }

    // The queue lock is released before waking a worker so the woken thread
    // does not immediately block on it.
    condvar.notify_one();
    Ok(task_id)
}
/// Convenience wrapper that submits `task` at `TaskPriority::Critical` with no
/// duration estimate and no thread affinity.
///
/// Returns whatever `submit_task` returns.
pub fn submit_urgent_task<F>(&self, task: F) -> Result<u64>
where
    F: FnOnce() -> TaskResult + Send + 'static,
{
    let (estimated_duration, thread_affinity) = (None, None);
    self.submit_task(
        task,
        TaskPriority::Critical,
        estimated_duration,
        thread_affinity,
    )
}
pub fn statistics(&self) -> SchedulerStats {
let mut stats = self
.scheduler_stats
.lock()
.expect("lock should not be poisoned");
if let Ok(thread_states) = self.thread_states.try_lock() {
stats.thread_efficiency = thread_states
.iter()
.map(|state| state.efficiency())
.collect();
}
stats.clone()
}
pub fn num_threads(&self) -> usize {
self.config.num_threads
}
/// Number of tasks currently waiting in the global queue.
///
/// # Panics
///
/// Panics if the queue mutex has been poisoned.
pub fn queue_length(&self) -> usize {
    let pending = self
        .global_queue
        .lock()
        .expect("lock should not be poisoned");
    pending.len()
}
/// Asks every worker thread to stop.
///
/// Sets the shared shutdown flag and wakes all workers waiting on the
/// condition variable. The call does not wait for the workers to exit.
/// Always returns `Ok(())`.
///
/// # Panics
///
/// Panics if the shutdown-flag mutex has been poisoned.
pub fn shutdown(&self) -> Result<()> {
    let (stop_requested, wakeup) = &*self.shutdown_signal;
    let mut requested = stop_requested
        .lock()
        .expect("lock should not be poisoned");
    *requested = true;
    // Notify while the flag is still locked, then release it.
    wakeup.notify_all();
    drop(requested);
    Ok(())
}
fn worker_thread_main(
thread_id: usize,
global_queue: Arc<Mutex<BinaryHeap<ScheduledTask>>>,
shutdown_signal: Arc<(Mutex<bool>, Condvar)>,
thread_states: Arc<Mutex<Vec<ThreadState>>>,
scheduler_stats: Arc<Mutex<SchedulerStats>>,
config: SchedulerConfig,
) {
let (shutdown_flag, condvar) = &*shutdown_signal;
loop {
{
let flag = shutdown_flag.lock().expect("lock should not be poisoned");
if *flag {
break;
}
}
let task = {
let mut queue = global_queue.lock().expect("lock should not be poisoned");
queue.pop()
};
match task {
Some(scheduled_task) => {
let start_time = Instant::now();
let result = (scheduled_task.task)();
let execution_time = start_time.elapsed();
{
let mut states = thread_states.lock().expect("lock should not be poisoned");
if let Some(state) = states.get_mut(thread_id) {
state.tasks_executed += 1;
state.total_execution_time += execution_time;
}
}
{
let mut stats =
scheduler_stats.lock().expect("lock should not be poisoned");
match result {
TaskResult::Success => stats.tasks_completed += 1,
TaskResult::Error(_) | TaskResult::Cancelled => stats.tasks_failed += 1,
}
let total_tasks = stats.tasks_completed + stats.tasks_failed;
let old_time_nanos = stats.average_execution_time.as_nanos() as u64;
let new_time_nanos = execution_time.as_nanos() as u64;
if let Some(avg_nanos) = (old_time_nanos * (total_tasks - 1)
+ new_time_nanos)
.checked_div(total_tasks)
{
stats.average_execution_time = Duration::from_nanos(avg_nanos);
}
}
}
None => {
let idle_start = Instant::now();
if config.enable_adaptive_scheduling {
if Self::try_work_stealing(thread_id, &thread_states, &config) {
continue;
}
}
let flag = shutdown_flag.lock().expect("lock should not be poisoned");
if !*flag {
let _ =
condvar.wait_timeout(flag, Duration::from_millis(config.time_slice_ms));
}
let idle_time = idle_start.elapsed();
{
let mut states = thread_states.lock().expect("lock should not be poisoned");
if let Some(state) = states.get_mut(thread_id) {
state.idle_time += idle_time;
}
}
}
}
}
}
/// Tries to move one task from the busiest other worker's local queue into
/// the local queue of worker `thread_id`.
///
/// Returns `true` only when a task was moved. Returns `false` when
/// * the state table is currently locked by someone else (non-blocking
///   `try_lock`),
/// * this worker's previous attempt was less than `config.time_slice_ms` ago,
/// * no other worker's local queue holds more than
///   `config.work_stealing_threshold` tasks.
///
/// Among equally long victim queues the lowest worker index is chosen.
fn try_work_stealing(
    thread_id: usize,
    thread_states: &Arc<Mutex<Vec<ThreadState>>>,
    config: &SchedulerConfig,
) -> bool {
    let mut states = match thread_states.try_lock() {
        Ok(guard) => guard,
        Err(_) => return false,
    };

    // Rate-limit attempts per worker and stamp this one.
    let now = Instant::now();
    let min_interval = Duration::from_millis(config.time_slice_ms);
    if let Some(own) = states.get_mut(thread_id) {
        if now.duration_since(own.last_steal_time) < min_interval {
            return false;
        }
        own.last_steal_time = now;
    }

    // `min_by_key` over the reversed length yields the first worker with the
    // longest qualifying queue.
    let victim_id = states
        .iter()
        .enumerate()
        .filter(|(idx, state)| {
            *idx != thread_id && state.local_queue.len() > config.work_stealing_threshold
        })
        .min_by_key(|(_, state)| std::cmp::Reverse(state.local_queue.len()))
        .map(|(idx, _)| idx);

    let stolen = match victim_id
        .and_then(|idx| states.get_mut(idx))
        .and_then(|victim| victim.local_queue.pop())
    {
        Some(task) => task,
        None => return false,
    };

    match states.get_mut(thread_id) {
        Some(own) => {
            own.local_queue.push(stolen);
            true
        }
        None => false,
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
/// A freshly built scheduler reports its configured thread count and an
/// empty queue.
#[test]
fn test_scheduler_creation() {
    let scheduler = ParallelScheduler::new(SchedulerConfig::optimal_for_cores(2))
        .expect("failed to create parallel scheduler");

    assert_eq!(scheduler.num_threads(), 2);
    assert_eq!(scheduler.queue_length(), 0);
}
/// A submitted task receives a positive id and is executed exactly once.
#[test]
fn test_task_submission() {
    let scheduler = ParallelScheduler::new(SchedulerConfig::optimal_for_cores(2))
        .expect("failed to create parallel scheduler");

    let runs = Arc::new(AtomicU32::new(0));
    let job = {
        let runs = Arc::clone(&runs);
        move || {
            runs.fetch_add(1, Ordering::SeqCst);
            TaskResult::Success
        }
    };

    let task_id = scheduler
        .submit_task(job, TaskPriority::Normal, None, None)
        .expect("failed to submit task");
    assert!(task_id > 0);

    // Give a worker time to pick the task up.
    std::thread::sleep(Duration::from_millis(100));
    assert_eq!(runs.load(Ordering::SeqCst), 1);
}
/// With a single worker held busy by a blocker task, tasks queued behind it
/// must run in strict priority order once the blocker is released.
#[test]
fn test_priority_scheduling() {
    let scheduler = ParallelScheduler::new(SchedulerConfig::optimal_for_cores(1))
        .expect("failed to create parallel scheduler");
    let execution_order = Arc::new(Mutex::new(Vec::new()));

    // Occupy the only worker until the gate is opened.
    let gate = Arc::new(AtomicU32::new(0));
    let worker_gate = Arc::clone(&gate);
    scheduler
        .submit_task(
            move || {
                while worker_gate.load(Ordering::SeqCst) == 0 {
                    std::thread::sleep(Duration::from_millis(10));
                }
                TaskResult::Success
            },
            TaskPriority::Low,
            None,
            None,
        )
        .expect("failed to submit blocker task");
    std::thread::sleep(Duration::from_millis(50));

    // Queue one task per priority, lowest first, while the worker is blocked.
    let priorities = [
        TaskPriority::Low,
        TaskPriority::Normal,
        TaskPriority::High,
        TaskPriority::Critical,
    ];
    for priority in priorities.iter().copied() {
        let recorder = Arc::clone(&execution_order);
        scheduler
            .submit_task(
                move || {
                    recorder
                        .lock()
                        .expect("lock should not be poisoned")
                        .push(priority);
                    std::thread::sleep(Duration::from_millis(10));
                    TaskResult::Success
                },
                priority,
                None,
                None,
            )
            .expect("failed to submit priority task");
    }

    std::thread::sleep(Duration::from_millis(50));
    gate.store(1, Ordering::SeqCst);
    std::thread::sleep(Duration::from_millis(300));

    let order = execution_order.lock().expect("lock should not be poisoned");
    assert_eq!(
        order.len(),
        4,
        "Expected 4 tasks to complete, got {}",
        order.len()
    );
    let expected = vec![
        TaskPriority::Critical,
        TaskPriority::High,
        TaskPriority::Normal,
        TaskPriority::Low,
    ];
    assert_eq!(
        *order, expected,
        "Expected strict priority order, got {:?}",
        *order
    );
}
/// After five short tasks, the statistics report five submissions, at least
/// one completion, and one efficiency entry per worker.
#[test]
fn test_scheduler_statistics() {
    let scheduler = ParallelScheduler::new(SchedulerConfig::optimal_for_cores(2))
        .expect("failed to create parallel scheduler");

    (0..5).for_each(|_| {
        scheduler
            .submit_task(
                || {
                    std::thread::sleep(Duration::from_millis(10));
                    TaskResult::Success
                },
                TaskPriority::Normal,
                None,
                None,
            )
            .expect("failed to submit statistics test task");
    });

    // Allow the workers to drain the queue.
    std::thread::sleep(Duration::from_millis(500));

    let stats = scheduler.statistics();
    assert_eq!(stats.tasks_submitted, 5);
    assert!(stats.tasks_completed > 0);
    assert!(stats.thread_efficiency.len() == 2);
}
/// An urgent task receives a positive id and is executed.
#[test]
fn test_urgent_task_submission() {
    let scheduler = ParallelScheduler::new(SchedulerConfig::optimal_for_cores(1))
        .expect("failed to create parallel scheduler");

    let executed = Arc::new(AtomicU32::new(0));
    let job = {
        let executed = Arc::clone(&executed);
        move || {
            executed.store(1, Ordering::SeqCst);
            TaskResult::Success
        }
    };

    let task_id = scheduler
        .submit_urgent_task(job)
        .expect("failed to submit urgent task");
    assert!(task_id > 0);

    // Give the worker time to run the task.
    std::thread::sleep(Duration::from_millis(100));
    assert_eq!(executed.load(Ordering::SeqCst), 1);
}
}