pub mod load_balancer;
pub mod parallel_algorithms;
pub mod parallel_allocator;
pub mod scheduler;
pub mod thread_pool;
pub mod work_stealing;
pub use load_balancer::{
BalancingStrategy, LoadBalancer, LoadBalancingAdvisor, LoadBalancingAnalysis, WorkloadMetrics,
};
pub use parallel_algorithms::{
parallel_prefix_sum, parallel_scan, ParallelArrayOps, ParallelConfig, ParallelFFT,
ParallelMatrixOps, ParallelPipeline, ParallelQuickSort, ScanMode,
};
pub use parallel_allocator::{ParallelAllocator, ParallelAllocatorConfig, ThreadLocalAllocator};
pub use scheduler::{ParallelScheduler, SchedulerConfig, SchedulerStats, TaskPriority};
pub use thread_pool::{Priority, ThreadPool, ThreadPoolConfig, ThreadPoolStats};
pub use work_stealing::{task, PoolStats, Task, TaskResult, WorkStealingConfig, WorkStealingPool};
use crate::error::{NumRs2Error, Result};
use std::sync::Arc;
use std::time::Duration;
/// Bundles the components used for parallel execution: a task scheduler, a load
/// balancer, and a work-stealing pool.
///
/// Each component is held behind an [`Arc`]; the accessor methods hand out references
/// to those `Arc`s so callers can clone and share them.
// NOTE: struct fields are dropped in declaration order (scheduler, then load_balancer,
// then work_stealing_pool). Keep this order unless the components' teardown
// requirements have been confirmed.
pub struct ParallelContext {
    // Task scheduler; shut down by `ParallelContext::shutdown` after the pool.
    scheduler: Arc<ParallelScheduler>,
    // Load balancer; source of the metrics returned by `workload_stats`.
    load_balancer: Arc<LoadBalancer>,
    // Work-stealing pool; shut down first by `ParallelContext::shutdown`.
    work_stealing_pool: Arc<WorkStealingPool>,
}
impl ParallelContext {
    /// Creates a context sized to the number of available CPU cores.
    ///
    /// The core count comes from `std::thread::available_parallelism`, falling back to 4
    /// when it cannot be determined. The load balancer uses
    /// [`BalancingStrategy::Adaptive`].
    ///
    /// # Errors
    ///
    /// Returns an error if the scheduler, load balancer, or work-stealing pool cannot be
    /// created.
    pub fn new() -> Result<Self> {
        let num_cores = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4);
        Self::with_config(
            SchedulerConfig::optimal_for_cores(num_cores),
            BalancingStrategy::Adaptive,
            num_cores,
        )
    }

    /// Creates a context from explicit settings.
    ///
    /// `scheduler_config` configures the scheduler. `num_threads` sizes both the load
    /// balancer and the work-stealing pool. Components are built in the order scheduler,
    /// load balancer, work-stealing pool.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while constructing any of the components.
    pub fn with_config(
        scheduler_config: SchedulerConfig,
        balancing_strategy: BalancingStrategy,
        num_threads: usize,
    ) -> Result<Self> {
        let scheduler = Arc::new(ParallelScheduler::new(scheduler_config)?);
        let load_balancer = Arc::new(LoadBalancer::new(balancing_strategy, num_threads)?);
        let work_stealing_pool = Arc::new(WorkStealingPool::new(num_threads)?);
        Ok(Self {
            scheduler,
            load_balancer,
            work_stealing_pool,
        })
    }

    /// Returns the shared task scheduler.
    pub fn scheduler(&self) -> &Arc<ParallelScheduler> {
        &self.scheduler
    }

    /// Returns the shared load balancer.
    pub fn load_balancer(&self) -> &Arc<LoadBalancer> {
        &self.load_balancer
    }

    /// Returns the shared work-stealing pool.
    pub fn work_stealing_pool(&self) -> &Arc<WorkStealingPool> {
        &self.work_stealing_pool
    }

    /// Shuts down the work-stealing pool and then the scheduler.
    ///
    /// Both shutdowns are always attempted, so a failure in the pool does not skip the
    /// scheduler's shutdown.
    ///
    /// # Errors
    ///
    /// Returns the pool's error if its shutdown failed, otherwise the scheduler's error
    /// (if any).
    pub fn shutdown(&self) -> Result<()> {
        // Run both before propagating, so one failing component cannot leave the other
        // un-shut-down.
        let pool_result = self.work_stealing_pool.shutdown();
        let scheduler_result = self.scheduler.shutdown();
        pool_result?;
        scheduler_result?;
        Ok(())
    }

    /// Returns the load balancer's current workload metrics.
    pub fn workload_stats(&self) -> WorkloadMetrics {
        self.load_balancer.current_metrics()
    }
}
impl Default for ParallelContext {
    /// Creates a context via [`ParallelContext::new`], falling back to a single-worker
    /// context if that fails.
    ///
    /// For the fallback scheduler, `SchedulerConfig::optimal_for_cores(1)` is tried
    /// first, then a hand-written minimal configuration.
    ///
    /// # Panics
    ///
    /// Panics if even the single-worker fallback components cannot be created. The
    /// panic message includes the underlying error.
    fn default() -> Self {
        Self::new().unwrap_or_else(|_| {
            let num_cores = 1;

            let scheduler = ParallelScheduler::new(SchedulerConfig::optimal_for_cores(num_cores))
                .or_else(|_| {
                    // Last resort: a conservative single-threaded configuration.
                    ParallelScheduler::new(SchedulerConfig {
                        num_threads: 1,
                        max_queue_size: 100,
                        enable_thread_affinity: false,
                        enable_adaptive_scheduling: false,
                        time_slice_ms: 10,
                        work_stealing_threshold: 5,
                        cache_aware_scheduling: false,
                    })
                })
                .map(Arc::new)
                .unwrap_or_else(|e| {
                    panic!(
                        "Cannot create even minimal ParallelScheduler - system unusable: {}",
                        e
                    )
                });

            let load_balancer = LoadBalancer::new(BalancingStrategy::Adaptive, num_cores)
                .map(Arc::new)
                .unwrap_or_else(|e| panic!("Cannot create LoadBalancer - system unusable: {}", e));

            let work_stealing_pool = WorkStealingPool::new(num_cores)
                .map(Arc::new)
                .unwrap_or_else(|e| {
                    panic!("Cannot create WorkStealingPool - system unusable: {}", e)
                });

            Self {
                scheduler,
                load_balancer,
                work_stealing_pool,
            }
        })
    }
}
lazy_static::lazy_static! {
    /// Process-wide [`ParallelContext`] slot: filled by `initialize_parallel_context`,
    /// read by `global_parallel_context`, and emptied by `shutdown_parallel_context`.
    static ref GLOBAL_PARALLEL_CONTEXT: std::sync::Mutex<Option<Arc<ParallelContext>>> =
        std::sync::Mutex::new(None);

    /// Process-wide thread pool, initialized on first access.
    ///
    /// The thread count is taken from the `NUMRS2_THREAD_COUNT` environment variable when
    /// it holds a positive integer, otherwise from `std::thread::available_parallelism`,
    /// and finally defaults to 2. If the configured pool cannot be built, a
    /// single-threaded pool is created instead.
    static ref GLOBAL_THREAD_POOL: Arc<ThreadPool> = {
        let num_threads = std::env::var("NUMRS2_THREAD_COUNT")
            .ok()
            .and_then(|s| s.trim().parse::<usize>().ok())
            // A count of 0 would make max_threads (0) < min_threads (1); treat it as unset.
            .filter(|&n| n > 0)
            .or_else(|| std::thread::available_parallelism().ok().map(|n| n.get()))
            .unwrap_or(2);
        let config = ThreadPoolConfig {
            num_threads: Some(num_threads),
            enable_thread_pinning: false,
            adaptive_threads: false,
            min_threads: 1,
            max_threads: num_threads,
            queue_capacity: 10000,
            steal_interval: Duration::from_micros(100),
            idle_timeout: Duration::from_millis(100),
        };
        Arc::new(ThreadPool::with_config(config).unwrap_or_else(|_| {
            ThreadPool::with_config(ThreadPoolConfig {
                num_threads: Some(1),
                ..Default::default()
            })
            .expect("Failed to create fallback thread pool")
        }))
    };
}
/// Guard over the global parallel-context slot.
type GlobalContextGuard = std::sync::MutexGuard<'static, Option<Arc<ParallelContext>>>;

/// Locks the global parallel-context slot, mapping a poisoned mutex to
/// [`NumRs2Error::RuntimeError`].
fn lock_global_context() -> Result<GlobalContextGuard> {
    GLOBAL_PARALLEL_CONTEXT.lock().map_err(|e| {
        NumRs2Error::RuntimeError(format!("Failed to acquire global context lock: {}", e))
    })
}

/// Creates a fresh [`ParallelContext`] and installs it as the global context.
///
/// Any previously installed context is replaced but not shut down. It is dropped once
/// its last outstanding `Arc` handle goes away.
///
/// # Errors
///
/// Returns an error if the context cannot be created or the global lock is poisoned.
pub fn initialize_parallel_context() -> Result<()> {
    // Build before locking, so that construction never holds up other accessors.
    let context = Arc::new(ParallelContext::new()?);
    let previous = lock_global_context()?.replace(context);
    // The guard was a temporary of the statement above, so this drop happens with the
    // lock already released.
    drop(previous);
    Ok(())
}

/// Returns a handle to the global [`ParallelContext`].
///
/// # Errors
///
/// Returns an error if the global lock is poisoned or no context has been initialized.
pub fn global_parallel_context() -> Result<Arc<ParallelContext>> {
    lock_global_context()?
        .as_ref()
        .map(Arc::clone)
        .ok_or_else(|| {
            NumRs2Error::RuntimeError("Global parallel context not initialized".to_string())
        })
}

/// Removes the global [`ParallelContext`] (if any) and shuts it down.
///
/// The context is detached while the lock is held, but shut down only after the lock is
/// released. Concurrent callers are therefore not blocked for the duration of shutdown.
///
/// # Errors
///
/// Returns an error if the global lock is poisoned or the context's shutdown fails. In
/// the latter case the context has already been removed from the global slot.
pub fn shutdown_parallel_context() -> Result<()> {
    let detached = lock_global_context()?.take();
    if let Some(context) = detached {
        context.shutdown()?;
    }
    Ok(())
}
/// Returns a shared handle to the process-wide thread pool.
///
/// The pool is initialized lazily on first access to the global.
pub fn global_thread_pool() -> Arc<ThreadPool> {
    let shared: &Arc<ThreadPool> = &GLOBAL_THREAD_POOL;
    Arc::clone(shared)
}
pub fn global_thread_count() -> usize {
GLOBAL_THREAD_POOL.num_threads()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a new context, failing the test with a uniform message on error.
    fn fresh_context() -> ParallelContext {
        ParallelContext::new().expect("ParallelContext creation should succeed in test")
    }

    #[test]
    fn test_parallel_context_creation() {
        let ctx = fresh_context();
        assert!(ctx.scheduler().num_threads() > 0);
        assert!(ctx.load_balancer().num_workers() > 0);
    }

    #[test]
    fn test_global_context_initialization() {
        initialize_parallel_context().expect("initialize_parallel_context should succeed in test");
        let global =
            global_parallel_context().expect("global_parallel_context should succeed in test");
        assert!(global.scheduler().num_threads() > 0);
        shutdown_parallel_context().expect("shutdown_parallel_context should succeed in test");
    }

    #[test]
    fn test_workload_stats() {
        let ctx = fresh_context();
        let metrics = ctx.workload_stats();
        assert_eq!(metrics.active_tasks, 0);
        assert!(metrics.total_throughput >= 0.0);
    }
}