use super::configure_workers;
use scirs2_core::parallel_ops::*;
use std::sync::{Arc, Mutex, Once};
// One-time initialization guard for the process-wide pool manager.
static INIT: Once = Once::new();
// Lazily-created global manager, written exactly once under `INIT`.
// NOTE(review): taking references to a `static mut` trips the
// `static_mut_refs` lint and is slated to become a hard error in a future
// edition; `std::sync::OnceLock` is the modern replacement for this pattern.
static mut GLOBAL_POOL: Option<Arc<Mutex<ThreadPoolManager>>> = None;
/// Named presets for sizing a worker thread pool (see `num_threads`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadPoolProfile {
/// One worker per logical CPU.
Default,
/// One worker per logical CPU, intended for compute-heavy workloads.
CpuBound,
/// Half the logical CPUs (at least one), for bandwidth-limited workloads.
MemoryBound,
/// 1.5x the logical CPUs: oversubscribes so blocked threads don't idle cores.
LatencySensitive,
/// Exact, caller-chosen worker count.
Custom(usize),
}
impl ThreadPoolProfile {
    /// Resolve this profile to a concrete worker-thread count.
    ///
    /// Falls back to a small fixed count when the OS cannot report the
    /// available parallelism.
    pub fn num_threads(&self) -> usize {
        // Logical-CPU count as reported by the OS, if it can be queried.
        let cpus = std::thread::available_parallelism().map(|n| n.get());
        match *self {
            // Both presets use one worker per logical CPU.
            ThreadPoolProfile::Default | ThreadPoolProfile::CpuBound => cpus.unwrap_or(4),
            // Half the CPUs (minimum one) to leave memory-bandwidth headroom.
            ThreadPoolProfile::MemoryBound => cpus.map(|n| (n / 2).max(1)).unwrap_or(2),
            // Oversubscribe by 50% so blocked threads don't leave cores idle.
            ThreadPoolProfile::LatencySensitive => cpus.map(|n| n + n / 2).unwrap_or(6),
            // Caller-specified fixed count, used verbatim.
            ThreadPoolProfile::Custom(n) => n,
        }
    }
}
/// Builder-style configuration used to construct the global worker pool.
pub struct ThreadPoolManager {
/// Strategy for choosing the worker count.
profile: ThreadPoolProfile,
/// Optional per-thread stack size in bytes, passed to the pool builder.
stacksize: Option<usize>,
/// Prefix for worker thread names (threads are named "<prefix>-<index>").
thread_name_prefix: String,
/// Requested CPU affinity. NOTE(review): stored but never read by
/// `initialize` in this file — confirm whether anything consumes it.
cpu_affinity: bool,
}
impl ThreadPoolManager {
pub fn new() -> Self {
Self {
profile: ThreadPoolProfile::Default,
stacksize: None,
thread_name_prefix: "linalg-worker".to_string(),
cpu_affinity: false,
}
}
pub fn with_profile(mut self, profile: ThreadPoolProfile) -> Self {
self.profile = profile;
self
}
pub fn with_stacksize(mut self, size: usize) -> Self {
self.stacksize = Some(size);
self
}
pub fn with_thread_name_prefix(mut self, prefix: String) -> Self {
self.thread_name_prefix = prefix;
self
}
pub fn with_cpu_affinity(mut self, enabled: bool) -> Self {
self.cpu_affinity = enabled;
self
}
pub fn initialize(&self) -> Result<(), String> {
let num_threads = self.profile.num_threads();
let thread_prefix = self.thread_name_prefix.clone();
let mut pool_builder = ThreadPoolBuilder::new()
.num_threads(num_threads)
.thread_name(move |idx| format!("{thread_prefix}-{idx}"));
if let Some(stacksize) = self.stacksize {
pool_builder = pool_builder.stack_size(stacksize);
}
pool_builder
.build_global()
.map_err(|e| format!("Failed to initialize thread pool: {e}"))?;
std::env::set_var("OMP_NUM_THREADS", num_threads.to_string());
std::env::set_var("MKL_NUM_THREADS", num_threads.to_string());
Ok(())
}
pub fn statistics(&self) -> ThreadPoolStats {
ThreadPoolStats {
num_threads: self.profile.num_threads(),
current_parallelism: num_threads(),
profile: self.profile,
stacksize: self.stacksize,
}
}
}
impl Default for ThreadPoolManager {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time summary returned by `ThreadPoolManager::statistics`.
#[derive(Debug, Clone)]
pub struct ThreadPoolStats {
/// Worker count implied by the configured profile.
pub num_threads: usize,
/// Parallelism currently reported by the active pool.
pub current_parallelism: usize,
/// Profile the manager was configured with.
pub profile: ThreadPoolProfile,
/// Configured per-thread stack size in bytes, if any.
pub stacksize: Option<usize>,
}
/// Return the process-wide, lazily-created [`ThreadPoolManager`].
///
/// Rewritten to use a function-local `OnceLock` instead of the
/// `static mut` + `Once` pair: taking references to a `static mut`
/// (the `static_mut_refs` lint) is deprecated and slated to become a hard
/// error, while `OnceLock` provides the same once-only initialization with
/// no `unsafe` code at all.
pub fn global_pool() -> Arc<Mutex<ThreadPoolManager>> {
    static POOL: std::sync::OnceLock<Arc<Mutex<ThreadPoolManager>>> =
        std::sync::OnceLock::new();
    POOL.get_or_init(|| Arc::new(Mutex::new(ThreadPoolManager::new())))
        .clone()
}
/// Set the global manager's profile and build the global thread pool.
///
/// # Errors
/// Returns an error if the pool mutex is poisoned or if the underlying
/// thread pool fails to build.
pub fn initialize_global_pool(profile: ThreadPoolProfile) -> Result<(), String> {
    let pool = global_pool();
    // Bug fix: this function returns `Result`, yet previously panicked on a
    // poisoned lock via `expect`. Report the condition as an error instead.
    let mut manager = pool
        .lock()
        .map_err(|_| "global thread pool mutex is poisoned".to_string())?;
    manager.profile = profile;
    manager.initialize()
}
/// Thread pool whose worker count is adjusted at runtime between a lower
/// and an upper bound, driven by utilization samples fed to `adapt`.
pub struct AdaptiveThreadPool {
/// Lower bound on the worker count.
min_threads: usize,
/// Upper bound on the worker count.
max_threads: usize,
/// Current worker count, shared behind a mutex.
current_threads: Arc<Mutex<usize>>,
/// Most recent utilization sample passed to `adapt`.
cpu_utilization: Arc<Mutex<f64>>,
}
impl AdaptiveThreadPool {
    /// Create a pool that adapts its worker count within
    /// `[min_threads, max_threads]`.
    ///
    /// The starting count is the detected parallelism (4 when unknown),
    /// clamped into the configured range. Previously the detected value was
    /// used as-is and could start outside the pool's own bounds.
    /// `max_threads` wins if the bounds are inverted.
    pub fn new(min_threads: usize, max_threads: usize) -> Self {
        let detected = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4);
        // Bug fix: clamp so the invariant min <= current <= max holds from
        // the start (`.max` then `.min` so max_threads takes precedence).
        let initial = detected.max(min_threads).min(max_threads);
        Self {
            min_threads,
            max_threads,
            current_threads: Arc::new(Mutex::new(initial)),
            cpu_utilization: Arc::new(Mutex::new(0.0)),
        }
    }

    /// Record a utilization sample and grow or shrink the pool accordingly:
    /// above 0.9 adds one worker (up to `max_threads`); below 0.5 removes
    /// one (down to `min_threads`).
    pub fn adapt(&self, utilization: f64) {
        let mut current = self
            .current_threads
            .lock()
            .expect("current_threads mutex poisoned");
        let mut cpu_util = self
            .cpu_utilization
            .lock()
            .expect("cpu_utilization mutex poisoned");
        *cpu_util = utilization;
        if utilization > 0.9 && *current < self.max_threads {
            *current = (*current + 1).min(self.max_threads);
            self.apply_thread_count(*current);
        } else if utilization < 0.5 && *current > self.min_threads {
            *current = (*current - 1).max(self.min_threads);
            self.apply_thread_count(*current);
        }
    }

    /// Push the new worker count down to the worker configuration layer.
    fn apply_thread_count(&self, count: usize) {
        configure_workers(Some(count));
    }

    /// Current worker count.
    pub fn current_thread_count(&self) -> usize {
        *self
            .current_threads
            .lock()
            .expect("current_threads mutex poisoned")
    }
}
pub mod benchmark {
    use super::*;
    use std::time::{Duration, Instant};

    /// Result of timing one workload under one thread-pool profile.
    #[derive(Debug, Clone)]
    pub struct BenchmarkResult {
        /// Profile the global pool was configured with for this run.
        pub profile: ThreadPoolProfile,
        /// Worker count implied by that profile.
        pub num_threads: usize,
        /// Wall-clock time of the timed iterations.
        pub execution_time: Duration,
        /// Work units (summed workload return values) per second.
        pub throughput: f64,
    }

    /// Time `workload` under each profile in `profiles`.
    ///
    /// For every profile the global pool is re-initialized, the workload is
    /// warmed up 3 times, then timed over 10 iterations. Profiles whose
    /// pool fails to initialize are reported on stderr and skipped.
    ///
    /// The unused `Clone` bound on `F` was dropped (backward-compatible
    /// relaxation — callers satisfying the old bound still compile).
    ///
    /// NOTE(review): re-initializing a *global* pool may fail after the
    /// first successful build, depending on the pool implementation —
    /// confirm repeated initialization is supported.
    pub fn benchmark_configurations<F>(
        profiles: &[ThreadPoolProfile],
        workload: F,
    ) -> Vec<BenchmarkResult>
    where
        F: Fn() -> f64,
    {
        let mut results = Vec::new();
        for &profile in profiles {
            if let Err(e) = initialize_global_pool(profile) {
                eprintln!("Failed to initialize pool for {profile:?}: {e}");
                continue;
            }
            // Warm-up runs: let caches and worker threads settle.
            for _ in 0..3 {
                workload();
            }
            let operations = 10;
            let start = Instant::now();
            let mut total_work = 0.0;
            for _ in 0..operations {
                total_work += workload();
            }
            let elapsed = start.elapsed();
            results.push(BenchmarkResult {
                profile,
                num_threads: profile.num_threads(),
                execution_time: elapsed,
                throughput: total_work / elapsed.as_secs_f64(),
            });
        }
        results
    }

    /// Benchmark the three built-in profiles and return the one with the
    /// highest throughput, falling back to `Default` when nothing could be
    /// measured.
    pub fn find_optimal_configuration<F>(workload: F) -> ThreadPoolProfile
    where
        F: Fn() -> f64,
    {
        let profiles = [
            ThreadPoolProfile::CpuBound,
            ThreadPoolProfile::MemoryBound,
            ThreadPoolProfile::LatencySensitive,
        ];
        let results = benchmark_configurations(&profiles, workload);
        results
            .into_iter()
            // Bug fix: `partial_cmp(..).expect(..)` panicked whenever a
            // throughput was NaN (e.g. 0.0 / 0.0 from a zero-length timing).
            // `f64::total_cmp` provides a total order, so no panic path.
            .max_by(|a, b| a.throughput.total_cmp(&b.throughput))
            .map(|r| r.profile)
            .unwrap_or(ThreadPoolProfile::Default)
    }
}
/// Pool wrapper that records per-task metrics and consults a scaling
/// policy after each monitored task.
pub struct EnhancedThreadPool {
/// Underlying pool manager; currently retained but not used here.
#[allow(dead_code)]
base_pool: Arc<Mutex<ThreadPoolManager>>,
/// Shared task-count / latency counters.
monitoring: Arc<Mutex<ThreadPoolMonitoring>>,
/// How `check_and_scale` reacts to the collected metrics.
scaling_policy: ScalingPolicy,
/// Selected balancing strategy. NOTE(review): not consulted anywhere in
/// this file — confirm its intended use.
load_balancer: LoadBalancer,
}
impl EnhancedThreadPool {
    /// Build an enhanced pool around a manager configured with `profile`,
    /// starting with the conservative scaling policy and round-robin
    /// balancing.
    pub fn new(profile: ThreadPoolProfile) -> Self {
        Self {
            base_pool: Arc::new(Mutex::new(ThreadPoolManager::new().with_profile(profile))),
            monitoring: Arc::new(Mutex::new(ThreadPoolMonitoring::new())),
            scaling_policy: ScalingPolicy::Conservative,
            load_balancer: LoadBalancer::RoundRobin,
        }
    }

    /// Choose how the pool reacts to load (builder-style).
    pub fn with_scaling_policy(mut self, policy: ScalingPolicy) -> Self {
        self.scaling_policy = policy;
        self
    }

    /// Choose the load-balancing strategy (builder-style).
    pub fn with_load_balancer(mut self, balancer: LoadBalancer) -> Self {
        self.load_balancer = balancer;
        self
    }

    /// Snapshot of the current monitoring counters.
    pub fn get_metrics(&self) -> ThreadPoolMetrics {
        self.monitoring
            .lock()
            .expect("Operation failed")
            .get_metrics()
    }

    /// Run `task`, recording its start and completion in the monitor, then
    /// consult the scaling policy.
    pub fn execute_monitored<F, R>(&self, task: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        let started = std::time::Instant::now();
        // Temporary lock guards drop at the end of each statement, so the
        // monitor is never held while the task itself runs.
        self.monitoring
            .lock()
            .expect("Operation failed")
            .record_task_start();
        let result = task();
        self.monitoring
            .lock()
            .expect("Operation failed")
            .record_task_completion(started.elapsed());
        self.check_and_scale();
        result
    }

    /// Apply the configured scaling policy to the latest metrics.
    fn check_and_scale(&self) {
        let m = self.get_metrics();
        match self.scaling_policy {
            ScalingPolicy::Conservative => {
                if m.average_utilization > 0.9 && m.queue_length > 10 {
                    self.scale_up();
                } else if m.average_utilization < 0.3 && m.active_threads > 2 {
                    self.scale_down();
                }
            }
            ScalingPolicy::Aggressive => {
                if m.average_utilization > 0.7 {
                    self.scale_up();
                } else if m.average_utilization < 0.5 && m.active_threads > 1 {
                    self.scale_down();
                }
            }
            ScalingPolicy::LatencyOptimized => {
                if m.average_latency_ms > 10.0 {
                    self.scale_up();
                } else if m.average_latency_ms < 2.0 && m.active_threads > 2 {
                    self.scale_down();
                }
            }
            // Fixed-size pools never scale.
            ScalingPolicy::Fixed => {}
        }
    }

    /// Announce an upward scaling decision (no actual resizing yet).
    fn scale_up(&self) {
        println!("Scaling up thread pool due to high utilization");
    }

    /// Announce a downward scaling decision (no actual resizing yet).
    fn scale_down(&self) {
        println!("Scaling down thread pool due to low utilization");
    }
}
/// Strategies governing when `EnhancedThreadPool` grows or shrinks.
#[derive(Debug, Clone, Copy)]
pub enum ScalingPolicy {
/// Scale up only under heavy load with a long queue; scale down readily.
Conservative,
/// Scale up already at moderate utilization (> 0.7).
Aggressive,
/// Scale based on average task latency rather than utilization.
LatencyOptimized,
/// Never change the thread count.
Fixed,
}
/// Work-distribution strategies. NOTE(review): currently only stored on
/// `EnhancedThreadPool`; no balancing logic in this file consumes it.
#[derive(Debug, Clone, Copy)]
pub enum LoadBalancer {
/// Rotate task assignment across workers.
RoundRobin,
/// Prefer the worker with the least outstanding work.
LeastLoaded,
/// Idle workers steal queued tasks from busy ones.
WorkStealing,
/// Keep tasks near the memory they touch (NUMA-topology aware).
NumaAware,
}
/// Raw counters backing `EnhancedThreadPool`'s metrics.
struct ThreadPoolMonitoring {
/// Total tasks ever started.
task_count: usize,
/// Sum of all completed task durations.
total_execution_time: std::time::Duration,
/// NOTE(review): never written after construction in this file, so it
/// stays 0 — confirm whether a caller elsewhere maintains it.
active_threads: usize,
/// Started-but-not-yet-completed task count.
queue_length: usize,
/// Recorded task start instants (see `record_task_start`).
start_times: Vec<std::time::Instant>,
}
impl ThreadPoolMonitoring {
    /// Fresh monitor with all counters at zero.
    fn new() -> Self {
        Self {
            task_count: 0,
            total_execution_time: std::time::Duration::ZERO,
            active_threads: 0,
            queue_length: 0,
            start_times: Vec::new(),
        }
    }

    /// Record that a task has started: bump the lifetime task count, note
    /// its start instant, and grow the in-flight queue length.
    fn record_task_start(&mut self) {
        self.task_count += 1;
        self.start_times.push(std::time::Instant::now());
        self.queue_length += 1;
    }

    /// Record that a task finished after running for `duration`.
    fn record_task_completion(&mut self, duration: std::time::Duration) {
        self.total_execution_time += duration;
        self.queue_length = self.queue_length.saturating_sub(1);
        // Bug fix: a start instant was pushed for every task but never
        // removed, so `start_times` grew without bound for the life of the
        // pool. Drop one entry per completion so it only tracks in-flight
        // tasks.
        self.start_times.pop();
    }

    /// Derive aggregate metrics from the raw counters.
    ///
    /// NOTE(review): `active_threads` is never updated in this file, so
    /// `average_utilization` currently always evaluates to 0 — confirm
    /// whether a caller elsewhere maintains it.
    fn get_metrics(&self) -> ThreadPoolMetrics {
        ThreadPoolMetrics {
            active_threads: self.active_threads,
            queue_length: self.queue_length,
            total_tasks: self.task_count,
            // Queued tasks per active thread (0 when no threads are known).
            average_utilization: if self.active_threads > 0 {
                self.queue_length as f64 / self.active_threads as f64
            } else {
                0.0
            },
            // Mean task duration in milliseconds over all recorded tasks.
            average_latency_ms: if self.task_count > 0 {
                self.total_execution_time.as_millis() as f64 / self.task_count as f64
            } else {
                0.0
            },
            // Tasks per second of accumulated execution time.
            throughput_tasks_per_sec: if !self.total_execution_time.is_zero() {
                self.task_count as f64 / self.total_execution_time.as_secs_f64()
            } else {
                0.0
            },
        }
    }
}
/// Derived metrics snapshot produced by `ThreadPoolMonitoring::get_metrics`.
#[derive(Debug, Clone)]
pub struct ThreadPoolMetrics {
/// Thread count as tracked by the monitor.
pub active_threads: usize,
/// Tasks started but not yet completed.
pub queue_length: usize,
/// Total tasks started since monitoring began.
pub total_tasks: usize,
/// Queue length divided by active threads (0 when no threads known).
pub average_utilization: f64,
/// Mean recorded task duration in milliseconds.
pub average_latency_ms: f64,
/// Completed tasks per second of accumulated execution time.
pub throughput_tasks_per_sec: f64,
}