// veda-rs 1.0.0
//
// High-performance parallel runtime for Rust with work-stealing and adaptive scheduling.
// Documentation
use super::{LoadStatistics, WorkerId, WorkerState};
use crate::config::Config;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use parking_lot::RwLock;

/// Tracks per-worker state and decides when load should be rebalanced.
///
/// Statistics are derived from the shared [`WorkerState`] entries; see
/// `collect_statistics` and `maybe_rebalance`.
pub struct AdaptiveScheduler {
    /// One shared state per worker; the vector index equals the worker's `WorkerId.0`.
    worker_states: Vec<Arc<RwLock<WorkerState>>>,
    /// Supplies the task arrival rate reported in collected statistics.
    load_estimator: LoadEstimator,
    /// Thresholds and worker-count bounds used by the scheduling decisions.
    config: SchedulerConfig,
    /// Set at construction and again each time `maybe_rebalance` performs a rebalance.
    last_rebalance: RwLock<Instant>,
}

/// Tunable parameters for [`AdaptiveScheduler`].
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Coefficient-of-variation limit above which load counts as imbalanced;
    /// also used as the relative band around mean utilization when classifying
    /// workers as overloaded or underloaded.
    pub imbalance_threshold: f64,
    /// Target latency in nanoseconds.
    // NOTE(review): not read by any code in this file — confirm where it is consumed.
    pub target_latency_ns: u64,
    /// Lower bound applied by `compute_optimal_workers`.
    pub min_workers: usize,
    /// Upper bound applied by `compute_optimal_workers`; also the number of
    /// worker states `AdaptiveScheduler::new` creates.
    pub max_workers: usize,
}

impl From<&Config> for SchedulerConfig {
    fn from(config: &Config) -> Self {
        let num_threads = config.worker_threads();
        Self {
            imbalance_threshold: config.imbalance_threshold,
            target_latency_ns: 1_000_000, // 1ms default
            min_workers: 1,
            max_workers: num_threads,
        }
    }
}

impl AdaptiveScheduler {
    /// Creates a scheduler tracking `config.max_workers` workers.
    ///
    /// One [`WorkerState`] is created per worker with ids `0..max_workers`,
    /// so a worker's id doubles as its index into the internal state vector.
    pub fn new(config: SchedulerConfig) -> Self {
        let worker_states = (0..config.max_workers)
            .map(|i| Arc::new(RwLock::new(WorkerState::new(WorkerId(i)))))
            .collect();

        Self {
            worker_states,
            load_estimator: LoadEstimator::new(),
            config,
            last_rebalance: RwLock::new(Instant::now()),
        }
    }

    /// Collects fresh statistics and rebalances if they show an imbalance.
    ///
    /// Returns `true` when a rebalance pass ran (and the last-rebalance
    /// timestamp was refreshed), `false` otherwise.
    pub fn maybe_rebalance(&self) -> bool {
        let stats = self.collect_statistics();
        if !self.detect_imbalance(&stats) {
            return false;
        }

        self.rebalance(&stats);
        *self.last_rebalance.write() = Instant::now();
        true
    }

    /// Builds a [`LoadStatistics`] snapshot from the current worker states.
    ///
    /// Load is measured as each worker's `tasks_executed`; the reported
    /// standard deviation is the population standard deviation of those
    /// values. With no workers, mean, deviation and utilization are all `0.0`.
    /// `avg_queue_wait_time_ns` is always reported as `0` by this function.
    pub fn collect_statistics(&self) -> LoadStatistics {
        // Sample load and utilization under a single read lock per worker so
        // both values for a worker come from the same moment.
        let samples: Vec<(f64, f64)> = self
            .worker_states
            .iter()
            .map(|state| {
                let s = state.read();
                (s.tasks_executed as f64, s.utilization())
            })
            .collect();

        let (mean_load, std_dev, avg_utilization) = if samples.is_empty() {
            (0.0, 0.0, 0.0)
        } else {
            let n = samples.len() as f64;
            let mean = samples.iter().map(|&(load, _)| load).sum::<f64>() / n;
            let variance = samples
                .iter()
                .map(|&(load, _)| {
                    let diff = load - mean;
                    diff * diff
                })
                .sum::<f64>()
                / n;
            let avg_util = samples.iter().map(|&(_, util)| util).sum::<f64>() / n;
            (mean, variance.sqrt(), avg_util)
        };

        LoadStatistics {
            mean_load,
            std_dev,
            avg_utilization,
            task_arrival_rate: self.load_estimator.arrival_rate(),
            avg_queue_wait_time_ns: 0,
            timestamp: Instant::now(),
        }
    }

    /// Returns `true` when the coefficient of variation of `stats` exceeds the
    /// configured imbalance threshold.
    fn detect_imbalance(&self, stats: &LoadStatistics) -> bool {
        stats.coefficient_of_variation() > self.config.imbalance_threshold
    }

    /// Records rebalancing intent for the most heavily utilized workers.
    ///
    /// Workers are classified against `stats.avg_utilization`: overloaded when
    /// above both `mean * (1 + threshold)` and `0.7`, underloaded when below
    /// both `mean * (1 - threshold)` and `0.5`. If both groups are non-empty,
    /// the `tasks_stolen` counter of the `min(overloaded, underloaded)` most
    /// utilized overloaded workers is incremented. No tasks are moved here.
    fn rebalance(&self, stats: &LoadStatistics) {
        // Per-worker utilization as busy / (busy + idle); 0.0 when no time
        // has been recorded yet.
        let mut utilizations: Vec<(usize, f64)> = self
            .worker_states
            .iter()
            .enumerate()
            .map(|(idx, state)| {
                let s = state.read();
                let total_time = s.idle_time_ns + s.busy_time_ns;
                let utilization = if total_time > 0 {
                    s.busy_time_ns as f64 / total_time as f64
                } else {
                    0.0
                };
                (idx, utilization)
            })
            .collect();

        // Highest utilization first, so the busiest workers are marked first.
        // Incomparable values (NaN) are treated as equal.
        utilizations
            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let mean_util = stats.avg_utilization;
        let threshold_high = mean_util * (1.0 + self.config.imbalance_threshold);
        let threshold_low = mean_util * (1.0 - self.config.imbalance_threshold);

        let overloaded: Vec<usize> = utilizations
            .iter()
            .filter(|&&(_, util)| util > threshold_high && util > 0.7)
            .map(|&(idx, _)| idx)
            .collect();

        // Only the number of underloaded workers is needed below.
        let underloaded_count = utilizations
            .iter()
            .filter(|&&(_, util)| util < threshold_low && util < 0.5)
            .count();

        if overloaded.is_empty() || underloaded_count == 0 {
            return;
        }

        let rebalance_count = overloaded.len().min(underloaded_count);

        // Update worker states to reflect rebalancing intent.
        // The actual task migration happens via work-stealing in the executor.
        for &worker_idx in overloaded.iter().take(rebalance_count) {
            if let Some(from_state) = self.worker_states.get(worker_idx) {
                let mut state = from_state.write();
                state.tasks_stolen = state.tasks_stolen.saturating_add(1);
            }
        }

        if cfg!(debug_assertions) {
            eprintln!("[VEDA Rebalance] {} overloaded -> {} underloaded workers (mean util: {:.2}%)",
                overloaded.len(), underloaded_count, mean_util * 100.0);
        }
    }

    /// Suggests a worker count for the given statistics.
    ///
    /// * utilization below `0.5`: the current count, raised to `min_workers`
    ///   if that is larger;
    /// * utilization above `0.9`: the current count scaled by 5/4 (integer
    ///   arithmetic), capped at `max_workers`;
    /// * otherwise: the current count.
    // NOTE(review): the low-utilization branch never suggests fewer workers
    // than are currently tracked — confirm whether scaling down was intended.
    pub fn compute_optimal_workers(&self, stats: &LoadStatistics) -> usize {
        let current = self.worker_states.len();

        if stats.avg_utilization < 0.5 {
            current.max(self.config.min_workers)
        } else if stats.avg_utilization > 0.9 {
            (current * 5 / 4).min(self.config.max_workers)
        } else {
            current
        }
    }

    /// Returns a shared handle to the state of worker `id`, or `None` when the
    /// id is out of range.
    pub fn worker_state(&self, id: WorkerId) -> Option<Arc<RwLock<WorkerState>>> {
        self.worker_states.get(id.0).cloned()
    }

    /// Number of workers tracked by this scheduler.
    pub fn num_workers(&self) -> usize {
        self.worker_states.len()
    }
}

/// Smoothed load estimate plus a simple task counter for arrival-rate reporting.
pub struct LoadEstimator {
    /// Bit pattern of the current `f64` estimate (stored via `to_bits`, read via `from_bits`).
    estimate: AtomicU64,
    /// Set at construction and on every `update`; `arrival_rate` measures elapsed time from here.
    last_update: RwLock<Instant>,
    /// Number of tasks recorded through `record_task`.
    task_count: AtomicU64,
}

impl LoadEstimator {
    /// Creates an estimator with a zero estimate, no recorded tasks, and the
    /// measurement window starting now.
    pub fn new() -> Self {
        Self {
            estimate: AtomicU64::new(0),
            last_update: RwLock::new(Instant::now()),
            task_count: AtomicU64::new(0),
        }
    }

    /// Folds `current_load` into the estimate with exponential smoothing
    /// (`new = 0.3 * current_load + 0.7 * old`) and starts a new arrival-rate
    /// window.
    pub fn update(&self, current_load: f64) {
        const ALPHA: f64 = 0.3;

        // A single atomic read-modify-write so concurrent callers cannot
        // overwrite each other's contribution. The closure always returns
        // `Some`, so the result is always `Ok` and can be ignored.
        let _ = self
            .estimate
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |bits| {
                let previous = f64::from_bits(bits);
                Some((ALPHA * current_load + (1.0 - ALPHA) * previous).to_bits())
            });

        // Reset the task counter together with the window start (under the
        // write lock) so `arrival_rate` never divides a long-running total by
        // a freshly restarted, very short interval.
        let mut window_start = self.last_update.write();
        self.task_count.store(0, Ordering::Relaxed);
        *window_start = Instant::now();
    }

    /// Returns the current smoothed load estimate.
    pub fn estimate(&self) -> f64 {
        f64::from_bits(self.estimate.load(Ordering::Relaxed))
    }

    /// Records the arrival of one task.
    pub fn record_task(&self) {
        self.task_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Tasks recorded per second since the estimator was created or last
    /// updated, whichever is more recent.
    ///
    /// Returns `0.0` when no measurable time has elapsed.
    pub fn arrival_rate(&self) -> f64 {
        // Hold the read guard while reading the counter so the pair is not
        // split by a concurrent `update`.
        let window_start = self.last_update.read();
        let elapsed = window_start.elapsed().as_secs_f64();
        if elapsed == 0.0 {
            return 0.0;
        }
        self.task_count.load(Ordering::Relaxed) as f64 / elapsed
    }
}

impl Default for LoadEstimator {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration with four workers shared by the scheduler tests.
    fn four_worker_config() -> SchedulerConfig {
        SchedulerConfig {
            imbalance_threshold: 0.2,
            target_latency_ns: 1_000_000,
            min_workers: 1,
            max_workers: 4,
        }
    }

    #[test]
    fn test_load_estimator() {
        let estimator = LoadEstimator::new();
        assert_eq!(estimator.estimate(), 0.0);

        // 0.3 * 1.0 + 0.7 * 0.0
        estimator.update(1.0);
        assert!((estimator.estimate() - 0.3).abs() < 1e-12);

        // 0.3 * 0.5 + 0.7 * 0.3
        estimator.update(0.5);
        assert!((estimator.estimate() - 0.36).abs() < 1e-12);
    }

    #[test]
    fn test_arrival_rate_without_tasks() {
        let estimator = LoadEstimator::new();
        assert_eq!(estimator.arrival_rate(), 0.0);
    }

    #[test]
    fn test_adaptive_scheduler_creation() {
        let scheduler = AdaptiveScheduler::new(four_worker_config());
        assert_eq!(scheduler.num_workers(), 4);
    }

    #[test]
    fn test_worker_state_lookup() {
        let scheduler = AdaptiveScheduler::new(four_worker_config());
        assert!(scheduler.worker_state(WorkerId(0)).is_some());
        assert!(scheduler.worker_state(WorkerId(3)).is_some());
        assert!(scheduler.worker_state(WorkerId(4)).is_none());
    }

    #[test]
    fn test_collect_statistics() {
        let scheduler = AdaptiveScheduler::new(four_worker_config());
        let stats = scheduler.collect_statistics();

        assert_eq!(stats.mean_load, 0.0);
        // All loads are non-negative, so a zero mean implies zero spread.
        assert_eq!(stats.std_dev, 0.0);
        assert!(stats.avg_utilization >= 0.0 && stats.avg_utilization <= 1.0);
    }
}