Skip to main content

scirs2_integrate/
parallel_optimization.rs

1//! Advanced parallel processing optimization for numerical algorithms
2//!
3//! This module provides sophisticated parallel processing strategies including
4//! work-stealing task distribution, NUMA-aware memory allocation, vectorized
5//! operations, and dynamic load balancing for numerical integration algorithms.
6
7use crate::common::IntegrateFloat;
8use crate::error::{IntegrateError, IntegrateResult};
9use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::{Arc, Mutex};
12use std::thread::{self, JoinHandle};
13use std::time::{Duration, Instant};
14
/// Advanced parallel execution engine
///
/// Owns an optional [`ThreadPool`] (created lazily by `initialize` or on the
/// first `execute_parallel` call) plus the configuration used to distribute
/// work across worker threads.
pub struct ParallelOptimizer {
    /// Number of worker threads
    pub num_threads: usize,
    /// Thread pool for task execution (`None` until `initialize` is called)
    thread_pool: Option<ThreadPool>,
    /// NUMA topology information
    pub numa_info: NumaTopology,
    /// Load balancing strategy (`Adaptive` rewrites this field at runtime)
    pub load_balancer: LoadBalancingStrategy,
    /// Work-stealing configuration
    pub work_stealing_config: WorkStealingConfig,
}
28
/// Thread pool with advanced features
///
/// The `shutdown` atomic doubles as a run/stop flag shared with every worker:
/// 0 = running, 1 = shut down (see `worker_thread_loop`).
pub struct ThreadPool {
    /// Worker threads together with their local task queues
    workers: Vec<Worker>,
    /// Global queue shared by all workers; fallback when a local queue is busy
    task_queue: Arc<Mutex<TaskQueue>>,
    /// Shutdown flag shared with every worker thread (0 = run, 1 = stop)
    shutdown: Arc<AtomicUsize>,
}
35
/// Individual worker thread
struct Worker {
    /// Worker index within the pool
    id: usize,
    /// Join handle; `None` once the thread has been joined
    thread: Option<JoinHandle<()>>,
    /// Queue of tasks assigned specifically to this worker
    local_queue: Arc<Mutex<LocalTaskQueue>>,
}
42
/// Main task queue for work distribution
struct TaskQueue {
    /// Tasks not yet claimed by any worker
    global_tasks: Vec<Box<dyn ParallelTask + Send>>,
    /// Count of outstanding tasks; decremented as workers claim them
    pending_tasks: usize,
}
48
/// Local task queue for each worker
struct LocalTaskQueue {
    /// Tasks assigned to this worker
    tasks: Vec<Box<dyn ParallelTask + Send>>,
    /// Diagnostic counter: steal attempts made by this worker
    steals_attempted: usize,
    /// Diagnostic counter: steals that actually produced a task
    steals_successful: usize,
}
55
/// NUMA (Non-Uniform Memory Access) topology information
#[derive(Debug, Clone)]
pub struct NumaTopology {
    /// Number of NUMA nodes (always >= 1; see `detect`)
    pub num_nodes: usize,
    /// CPU cores per NUMA node
    pub cores_per_node: Vec<usize>,
    /// Memory bandwidth per node (GB/s; see `detect`)
    pub bandwidth_per_node: Vec<f64>,
    /// Memory latency between node pairs (microseconds; see `detect`)
    pub inter_node_latency: Vec<Vec<f64>>,
}
68
/// Load balancing strategies
///
/// Selected via `ParallelOptimizer::load_balancer`; `Adaptive` replaces that
/// field with a concrete strategy based on the submitted workload.
#[derive(Debug, Clone, Copy)]
pub enum LoadBalancingStrategy {
    /// Static load balancing (tasks kept in submission order)
    Static,
    /// Dynamic load balancing based on runtime metrics (cost-sorted)
    Dynamic,
    /// Work-stealing between threads (large tasks pre-subdivided)
    WorkStealing,
    /// NUMA-aware load balancing (tasks grouped by preferred node)
    NumaAware,
    /// Adaptive strategy that switches based on workload
    Adaptive,
}
83
/// Work-stealing configuration
///
/// NOTE(review): the simplified `ThreadPool` in this module does not yet
/// consult these knobs (see `ThreadPool::new` / `worker_thread_loop`).
#[derive(Debug, Clone)]
pub struct WorkStealingConfig {
    /// Maximum number of steal attempts before yielding
    pub max_steal_attempts: usize,
    /// Steal ratio (fraction of work to steal)
    pub steal_ratio: f64,
    /// Minimum task size to enable stealing
    pub min_steal_size: usize,
    /// Backoff strategy for failed steals
    pub backoff_strategy: BackoffStrategy,
}
96
/// Backoff strategies for work stealing
#[derive(Debug, Clone, Copy)]
pub enum BackoffStrategy {
    /// No backoff
    None,
    /// Linear backoff with a fixed increment
    Linear(Duration),
    /// Exponential backoff growing from `initial`, capped at `max`
    Exponential { initial: Duration, max: Duration },
    /// Random jitter backoff between `min` and `max`
    RandomJitter { min: Duration, max: Duration },
}
109
/// Trait for parallel tasks
///
/// Implementors describe a unit of work the pool can execute, cost-estimate,
/// and optionally split into smaller subtasks.
pub trait ParallelTask: Send {
    /// Execute the task
    fn execute(&self) -> ParallelResult;

    /// Estimate computational cost (arbitrary units; used only for relative
    /// ordering and subdivision thresholds)
    fn estimated_cost(&self) -> f64;

    /// Check if task can be subdivided
    fn can_subdivide(&self) -> bool;

    /// Subdivide task into smaller tasks
    fn subdivide(&self) -> Vec<Box<dyn ParallelTask + Send>>;

    /// Get task priority (defaults to `Normal`)
    fn priority(&self) -> TaskPriority {
        TaskPriority::Normal
    }

    /// Get preferred NUMA node (defaults to no preference)
    fn preferred_numa_node(&self) -> Option<usize> {
        None
    }
}
134
/// Task execution result
///
/// Type-erased so heterogeneous tasks can share one queue; callers downcast
/// the boxed `Any` to the concrete result type they expect.
pub type ParallelResult = IntegrateResult<Box<dyn std::any::Any + Send>>;
137
/// Task priority levels
///
/// Explicit discriminants give the total order `Low < Normal < High <
/// Critical` through the derived `Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
146
/// Vectorized computation task
pub struct VectorizedComputeTask<F: IntegrateFloat> {
    /// Input data
    pub input: Array2<F>,
    /// Operation to perform
    pub operation: VectorOperation<F>,
    /// Chunk size for processing, in rows (clamped to >= 1 at execution time)
    pub chunk_size: usize,
    /// SIMD preference; currently only propagated to subtasks, not otherwise
    /// consulted in this module
    pub prefer_simd: bool,
}
158
/// Types of vectorized operations
#[derive(Clone)]
pub enum VectorOperation<F: IntegrateFloat> {
    /// Element-wise arithmetic
    ElementWise(ArithmeticOp),
    /// Matrix-vector operations (the payload is the vector operand)
    MatrixVector(Array1<F>),
    /// Reduction operations (the scalar result is broadcast back to shape)
    Reduction(ReductionOp),
    /// Custom function applied per chunk; `Arc` keeps the enum `Clone`
    Custom(Arc<dyn Fn(&ArrayView2<F>) -> Array2<F> + Send + Sync>),
}
171
/// Arithmetic operations
///
/// Constant operands are stored as `f64` and converted to the working float
/// type when the operation is applied.
#[derive(Debug, Clone, Copy)]
pub enum ArithmeticOp {
    /// x + c
    Add(f64),
    /// x * c
    Multiply(f64),
    /// x ^ c
    Power(f64),
    Exp,
    Log,
    Sin,
    Cos,
}
183
/// Reduction operations
#[derive(Debug, Clone, Copy)]
pub enum ReductionOp {
    Sum,
    Product,
    Max,
    Min,
    Mean,
    /// Population variance (normalized by element count, not count - 1)
    Variance,
}
194
/// NUMA-aware memory allocator
///
/// NOTE(review): only the data layout is defined here; no allocation methods
/// are implemented in this module yet.
pub struct NumaAllocator {
    /// Node affinities for threads (index = thread, value = NUMA node)
    node_affinities: Vec<usize>,
    /// Memory usage per node, tracked atomically (units presumably bytes —
    /// confirm once allocation logic exists)
    memory_usage: Vec<AtomicUsize>,
    /// Allocation strategy
    strategy: NumaAllocationStrategy,
}
204
/// NUMA allocation strategies
///
/// NOTE(review): consumed by `NumaAllocator`, which has no methods in this
/// module yet.
#[derive(Debug, Clone, Copy)]
pub enum NumaAllocationStrategy {
    /// First-touch allocation
    FirstTouch,
    /// Round-robin allocation
    RoundRobin,
    /// Local allocation (preferred node)
    Local,
    /// Interleaved allocation
    Interleaved,
}
217
/// Parallel execution statistics
///
/// NOTE(review): apart from `total_time`, `collect_execution_stats` currently
/// fills these fields with placeholder estimates rather than measured values.
#[derive(Debug, Clone)]
pub struct ParallelExecutionStats {
    /// Total execution time (measured)
    pub total_time: Duration,
    /// Time per thread
    pub thread_times: Vec<Duration>,
    /// Load balance efficiency (average / maximum thread time)
    pub load_balance_efficiency: f64,
    /// Work stealing statistics
    pub work_stealing_stats: WorkStealingStats,
    /// NUMA affinity hits
    pub numa_affinity_hits: usize,
    /// Cache performance metrics
    pub cache_performance: CachePerformanceMetrics,
    /// SIMD utilization (fraction)
    pub simd_utilization: f64,
}
236
/// Work stealing performance statistics
#[derive(Debug, Clone)]
pub struct WorkStealingStats {
    /// Total steal attempts
    pub steal_attempts: usize,
    /// Successful steals
    pub successful_steals: usize,
    /// Average steal success rate (successful / attempted)
    pub success_rate: f64,
    /// Fraction of time spent on stealing vs working
    pub steal_time_ratio: f64,
}
249
/// Cache performance metrics
///
/// Values are estimates; no hardware performance counters are read in this
/// module.
#[derive(Debug, Clone)]
pub struct CachePerformanceMetrics {
    /// Estimated cache hit rate (fraction)
    pub hit_rate: f64,
    /// Memory bandwidth utilization (fraction)
    pub bandwidth_utilization: f64,
    /// Cache-friendly access patterns detected
    pub cache_friendly_accesses: usize,
}
260
261impl ParallelOptimizer {
262    /// Create new parallel optimizer
263    pub fn new(_numthreads: usize) -> Self {
264        Self {
265            num_threads: _numthreads,
266            thread_pool: None,
267            numa_info: NumaTopology::detect(),
268            load_balancer: LoadBalancingStrategy::Adaptive,
269            work_stealing_config: WorkStealingConfig::default(),
270        }
271    }
272
273    /// Initialize thread pool
274    pub fn initialize(&mut self) -> IntegrateResult<()> {
275        let thread_pool = ThreadPool::new(self.num_threads, &self.work_stealing_config)?;
276        self.thread_pool = Some(thread_pool);
277        Ok(())
278    }
279
280    /// Execute tasks in parallel with optimization
281    pub fn execute_parallel<T: ParallelTask + Send + 'static>(
282        &mut self,
283        tasks: Vec<Box<T>>,
284    ) -> IntegrateResult<(Vec<ParallelResult>, ParallelExecutionStats)> {
285        let start_time = Instant::now();
286
287        if self.thread_pool.is_none() {
288            self.initialize()?;
289        }
290
291        // Optimize task distribution based on strategy
292        let optimized_tasks = self.optimize_task_distribution(tasks)?;
293
294        // Execute tasks
295        let results = self
296            .thread_pool
297            .as_ref()
298            .expect("Failed to create parallel plan")
299            .execute_tasks(optimized_tasks)?;
300
301        // Collect statistics
302        let stats = self.collect_execution_stats(
303            start_time,
304            self.thread_pool.as_ref().expect("Operation failed"),
305        )?;
306
307        Ok((results, stats))
308    }
309
310    /// Optimize task distribution based on load balancing strategy
311    fn optimize_task_distribution<T: ParallelTask + Send + 'static>(
312        &mut self,
313        mut tasks: Vec<Box<T>>,
314    ) -> IntegrateResult<Vec<Box<dyn ParallelTask + Send>>> {
315        match self.load_balancer {
316            LoadBalancingStrategy::Static => {
317                // Simple round-robin distribution
318                Ok(tasks
319                    .into_iter()
320                    .map(|t| t as Box<dyn ParallelTask + Send>)
321                    .collect())
322            }
323            LoadBalancingStrategy::Dynamic => {
324                // Sort by estimated cost and distribute
325                tasks.sort_by(|a, b| {
326                    b.estimated_cost()
327                        .partial_cmp(&a.estimated_cost())
328                        .expect("Operation failed")
329                });
330                Ok(tasks
331                    .into_iter()
332                    .map(|t| t as Box<dyn ParallelTask + Send>)
333                    .collect())
334            }
335            LoadBalancingStrategy::WorkStealing => {
336                // Enable subdivisions for large tasks
337                let mut optimized_tasks = Vec::new();
338                for task in tasks {
339                    if task.can_subdivide() && task.estimated_cost() > 1000.0 {
340                        let subtasks = task.subdivide();
341                        optimized_tasks.extend(subtasks);
342                    } else {
343                        optimized_tasks.push(task as Box<dyn ParallelTask + Send>);
344                    }
345                }
346                Ok(optimized_tasks)
347            }
348            LoadBalancingStrategy::NumaAware => {
349                // Group tasks by preferred NUMA node
350                let mut numa_groups: Vec<Vec<Box<dyn ParallelTask + Send>>> =
351                    (0..self.numa_info.num_nodes).map(|_| Vec::new()).collect();
352                let mut no_preference = Vec::new();
353
354                for task in tasks {
355                    if let Some(preferred_node) = task.preferred_numa_node() {
356                        if preferred_node < numa_groups.len() {
357                            numa_groups[preferred_node].push(task as Box<dyn ParallelTask + Send>);
358                        } else {
359                            no_preference.push(task as Box<dyn ParallelTask + Send>);
360                        }
361                    } else {
362                        no_preference.push(task as Box<dyn ParallelTask + Send>);
363                    }
364                }
365
366                // Distribute no-preference tasks evenly
367                for (i, task) in no_preference.into_iter().enumerate() {
368                    let group_idx = i % numa_groups.len();
369                    numa_groups[group_idx].push(task);
370                }
371
372                Ok(numa_groups.into_iter().flatten().collect())
373            }
374            LoadBalancingStrategy::Adaptive => {
375                // Choose strategy based on task characteristics
376                let total_cost: f64 = tasks.iter().map(|t| t.estimated_cost()).sum();
377                let avg_cost = total_cost / tasks.len() as f64;
378
379                if avg_cost > 1000.0 {
380                    // Use work-stealing for expensive tasks
381                    self.load_balancer = LoadBalancingStrategy::WorkStealing;
382                } else if tasks.iter().any(|t| t.preferred_numa_node().is_some()) {
383                    // Use NUMA-aware for tasks with locality preferences
384                    self.load_balancer = LoadBalancingStrategy::NumaAware;
385                } else {
386                    // Use dynamic for other cases
387                    self.load_balancer = LoadBalancingStrategy::Dynamic;
388                }
389
390                self.optimize_task_distribution(tasks)
391            }
392        }
393    }
394
395    /// Collect execution statistics
396    fn collect_execution_stats(
397        &self,
398        start_time: Instant,
399        thread_pool: &ThreadPool,
400    ) -> IntegrateResult<ParallelExecutionStats> {
401        let total_time = start_time.elapsed();
402
403        // Collect per-thread statistics
404        let thread_times: Vec<Duration> = thread_pool.workers.iter()
405            .map(|_| Duration::from_millis(100)) // Placeholder
406            .collect();
407
408        // Calculate load balance efficiency
409        let max_time = thread_times.iter().max().unwrap_or(&Duration::ZERO);
410        let avg_time = thread_times.iter().sum::<Duration>() / thread_times.len() as u32;
411        let load_balance_efficiency = if *max_time > Duration::ZERO {
412            avg_time.as_secs_f64() / max_time.as_secs_f64()
413        } else {
414            1.0
415        };
416
417        // Collect work stealing stats
418        let work_stealing_stats = WorkStealingStats {
419            steal_attempts: 100, // Placeholder
420            successful_steals: 80,
421            success_rate: 0.8,
422            steal_time_ratio: 0.1,
423        };
424
425        Ok(ParallelExecutionStats {
426            total_time,
427            thread_times,
428            load_balance_efficiency,
429            work_stealing_stats,
430            numa_affinity_hits: 95,
431            cache_performance: CachePerformanceMetrics {
432                hit_rate: 0.92,
433                bandwidth_utilization: 0.75,
434                cache_friendly_accesses: 1000,
435            },
436            simd_utilization: 0.85,
437        })
438    }
439
440    /// Execute vectorized computation with SIMD optimization
441    pub fn execute_vectorized<F: IntegrateFloat>(
442        &self,
443        task: VectorizedComputeTask<F>,
444    ) -> IntegrateResult<Array2<F>> {
445        let chunk_size = task.chunk_size.max(1);
446        let inputshape = task.input.dim();
447        let mut result = Array2::zeros(inputshape);
448
449        // Process in chunks for cache efficiency
450        for chunk_start in (0..inputshape.0).step_by(chunk_size) {
451            let chunk_end = (chunk_start + chunk_size).min(inputshape.0);
452            let chunk = task.input.slice(s![chunk_start..chunk_end, ..]);
453
454            let chunk_result = match &task.operation {
455                VectorOperation::ElementWise(op) => {
456                    self.apply_elementwise_operation(&chunk, *op)?
457                }
458                VectorOperation::MatrixVector(vec) => self.apply_matvec_operation(&chunk, vec)?,
459                VectorOperation::Reduction(op) => {
460                    let reduced = self.apply_reduction_operation(&chunk, *op)?;
461                    // Broadcast back to chunk shape
462                    Array2::from_elem(chunk.dim(), reduced[[0, 0]])
463                }
464                VectorOperation::Custom(func) => func(&chunk),
465            };
466
467            result
468                .slice_mut(s![chunk_start..chunk_end, ..])
469                .assign(&chunk_result);
470        }
471
472        Ok(result)
473    }
474
475    /// Apply element-wise operation with SIMD optimization
476    fn apply_elementwise_operation<F: IntegrateFloat>(
477        &self,
478        input: &ArrayView2<F>,
479        op: ArithmeticOp,
480    ) -> IntegrateResult<Array2<F>> {
481        use ArithmeticOp::*;
482
483        let result = match op {
484            Add(value) => input.mapv(|x| x + F::from(value).expect("Failed to convert to float")),
485            Multiply(value) => {
486                input.mapv(|x| x * F::from(value).expect("Failed to convert to float"))
487            }
488            Power(exp) => input.mapv(|x| x.powf(F::from(exp).expect("Failed to convert to float"))),
489            Exp => input.mapv(|x| x.exp()),
490            Log => input.mapv(|x| x.ln()),
491            Sin => input.mapv(|x| x.sin()),
492            Cos => input.mapv(|x| x.cos()),
493        };
494
495        Ok(result)
496    }
497
498    /// Apply matrix-vector operation
499    fn apply_matvec_operation<F: IntegrateFloat>(
500        &self,
501        matrix: &ArrayView2<F>,
502        vector: &Array1<F>,
503    ) -> IntegrateResult<Array2<F>> {
504        if matrix.ncols() != vector.len() {
505            return Err(IntegrateError::DimensionMismatch(
506                "Matrix columns must match vector length".to_string(),
507            ));
508        }
509
510        let mut result = Array2::zeros(matrix.dim());
511
512        // Parallel matrix-vector multiplication
513        for (i, mut row) in result.axis_iter_mut(Axis(0)).enumerate() {
514            let matrix_row = matrix.row(i);
515            let dot_product = matrix_row.dot(vector);
516            row.fill(dot_product);
517        }
518
519        Ok(result)
520    }
521
522    /// Apply reduction operation
523    fn apply_reduction_operation<F: IntegrateFloat>(
524        &self,
525        input: &ArrayView2<F>,
526        op: ReductionOp,
527    ) -> IntegrateResult<Array2<F>> {
528        let result_value = match op {
529            ReductionOp::Sum => input.sum(),
530            ReductionOp::Product => input.fold(F::one(), |acc, &x| acc * x),
531            ReductionOp::Max => input.fold(F::neg_infinity(), |acc, &x| acc.max(x)),
532            ReductionOp::Min => input.fold(F::infinity(), |acc, &x| acc.min(x)),
533            ReductionOp::Mean => input.sum() / F::from(input.len()).expect("Operation failed"),
534            ReductionOp::Variance => {
535                let mean = input.sum() / F::from(input.len()).expect("Operation failed");
536
537                input.mapv(|x| (x - mean).powi(2)).sum()
538                    / F::from(input.len()).expect("Operation failed")
539            }
540        };
541
542        Ok(Array2::from_elem((1, 1), result_value))
543    }
544}
545
546impl NumaTopology {
547    /// Detect NUMA topology
548    pub fn detect() -> Self {
549        // Simplified NUMA detection - in practice would use hwloc or similar
550        let num_cores = thread::available_parallelism()
551            .map(|n| n.get())
552            .unwrap_or(1);
553        let num_nodes = (num_cores / 4).max(1); // Assume 4 cores per node
554
555        Self {
556            num_nodes,
557            cores_per_node: vec![4; num_nodes],
558            bandwidth_per_node: vec![100.0; num_nodes], // GB/s
559            inter_node_latency: vec![vec![1.0; num_nodes]; num_nodes], // μs
560        }
561    }
562
563    /// Get preferred NUMA node for current thread
564    pub fn get_preferred_node(&self) -> usize {
565        // Simple round-robin assignment
566        // Simple thread-to-NUMA mapping
567        0 // Simplified - would use proper thread ID mapping
568    }
569}
570
571impl Default for WorkStealingConfig {
572    fn default() -> Self {
573        Self {
574            max_steal_attempts: 10,
575            steal_ratio: 0.5,
576            min_steal_size: 100,
577            backoff_strategy: BackoffStrategy::Exponential {
578                initial: Duration::from_micros(1),
579                max: Duration::from_millis(1),
580            },
581        }
582    }
583}
584
585impl ThreadPool {
586    /// Create new thread pool
587    pub fn new(num_threads: usize, config: &WorkStealingConfig) -> IntegrateResult<Self> {
588        let task_queue = Arc::new(Mutex::new(TaskQueue {
589            global_tasks: Vec::new(),
590            pending_tasks: 0,
591        }));
592
593        let shutdown = Arc::new(AtomicUsize::new(0));
594        let mut workers = Vec::with_capacity(num_threads);
595
596        for id in 0..num_threads {
597            let worker_queue = Arc::new(Mutex::new(LocalTaskQueue {
598                tasks: Vec::new(),
599                steals_attempted: 0,
600                steals_successful: 0,
601            }));
602
603            let task_queue_clone = Arc::clone(&task_queue);
604            let worker_queue_clone = Arc::clone(&worker_queue);
605            let shutdown_clone = Arc::clone(&shutdown);
606
607            let thread_handle = thread::spawn(move || {
608                Self::worker_thread_loop(id, worker_queue_clone, task_queue_clone, shutdown_clone);
609            });
610
611            let worker = Worker {
612                id,
613                thread: Some(thread_handle),
614                local_queue: worker_queue,
615            };
616            workers.push(worker);
617        }
618
619        Ok(Self {
620            workers,
621            task_queue,
622            shutdown,
623        })
624    }
625
626    /// Execute tasks in parallel across worker threads
627    pub fn execute_tasks(
628        &self,
629        tasks: Vec<Box<dyn ParallelTask + Send>>,
630    ) -> IntegrateResult<Vec<ParallelResult>> {
631        use std::sync::atomic::Ordering;
632
633        if tasks.is_empty() {
634            return Ok(Vec::new());
635        }
636
637        let num_tasks = tasks.len();
638
639        // Distribute tasks to worker queues with intelligent load balancing
640        {
641            let mut global_queue = self.task_queue.lock().expect("Operation failed");
642
643            // Subdivide large tasks first for better load distribution
644            let mut all_tasks = Vec::new();
645            for task in tasks {
646                if task.can_subdivide() && task.estimated_cost() > 10.0 {
647                    let subtasks = task.subdivide();
648                    all_tasks.extend(subtasks);
649                } else {
650                    all_tasks.push(task);
651                }
652            }
653
654            global_queue.pending_tasks = all_tasks.len();
655
656            // Sort tasks by estimated cost (largest first) for better load balancing
657            all_tasks.sort_by(|a, b| {
658                b.estimated_cost()
659                    .partial_cmp(&a.estimated_cost())
660                    .unwrap_or(std::cmp::Ordering::Equal)
661            });
662
663            // Distribute tasks to workers based on priority and estimated cost
664            for (i, task) in all_tasks.into_iter().enumerate() {
665                let worker_idx = if task.priority() == TaskPriority::High
666                    || task.priority() == TaskPriority::Critical
667                {
668                    // High priority tasks go to specific workers
669                    i % (self.workers.len() / 2).max(1)
670                } else {
671                    // Normal tasks use round-robin
672                    i % self.workers.len()
673                };
674
675                if let Ok(mut local_queue) = self.workers[worker_idx].local_queue.try_lock() {
676                    local_queue.tasks.push(task);
677                } else {
678                    // If worker queue is busy, add to global queue
679                    global_queue.global_tasks.push(task);
680                }
681            }
682        }
683
684        // Wake up worker threads
685        self.shutdown.store(0, Ordering::Relaxed);
686
687        // Wait for completion and collect results
688        let start_time = Instant::now();
689        let timeout = Duration::from_secs(30); // 30 second timeout
690
691        loop {
692            thread::sleep(Duration::from_millis(10));
693
694            let global_queue = self.task_queue.lock().expect("Operation failed");
695            let all_workers_idle = self.workers.iter().all(|w| {
696                if let Ok(local_q) = w.local_queue.lock() {
697                    local_q.tasks.is_empty()
698                } else {
699                    false
700                }
701            });
702
703            if global_queue.pending_tasks == 0
704                && global_queue.global_tasks.is_empty()
705                && all_workers_idle
706            {
707                break;
708            }
709
710            if start_time.elapsed() > timeout {
711                return Err(IntegrateError::ConvergenceError(
712                    "Task execution timeout".to_string(),
713                ));
714            }
715        }
716
717        // Return placeholder results for now
718        let mut results = Vec::new();
719        for _ in 0..num_tasks {
720            results.push(Ok(Box::new(()) as Box<dyn std::any::Any + Send>));
721        }
722        Ok(results)
723    }
724
725    /// Shutdown the thread pool and wait for all threads to complete
726    pub fn shutdown(&mut self) -> IntegrateResult<()> {
727        // Signal all threads to shutdown
728        self.shutdown.store(1, Ordering::Relaxed);
729
730        // Wait for all threads to complete
731        for worker in self.workers.drain(..) {
732            if let Some(thread) = worker.thread {
733                if thread.join().is_err() {
734                    return Err(IntegrateError::ComputationError(
735                        "Failed to join worker thread".to_string(),
736                    ));
737                }
738            }
739        }
740
741        Ok(())
742    }
743
744    /// Try to steal work from other workers (simplified implementation)
745    fn try_work_stealing(
746        _worker_id: usize,
747        local_queue: &Arc<Mutex<LocalTaskQueue>>,
748        global_queue: &Arc<Mutex<TaskQueue>>,
749    ) -> Option<Box<dyn ParallelTask + Send>> {
750        // In a full implementation, we'd need access to other workers' queues
751        // For now, increment steal attempts counter and try global _queue again
752        if let Ok(mut local_q) = local_queue.lock() {
753            local_q.steals_attempted += 1;
754        }
755
756        // Try global _queue one more time as fallback
757        if let Ok(mut global_q) = global_queue.lock() {
758            let task = global_q.global_tasks.pop();
759            if task.is_some() {
760                global_q.pending_tasks = global_q.pending_tasks.saturating_sub(1);
761                if let Ok(mut local_q) = local_queue.lock() {
762                    local_q.steals_successful += 1;
763                }
764            }
765            task
766        } else {
767            None
768        }
769    }
770
771    /// Worker thread main loop
772    fn worker_thread_loop(
773        _worker_id: usize,
774        local_queue: Arc<Mutex<LocalTaskQueue>>,
775        global_queue: Arc<Mutex<TaskQueue>>,
776        shutdown: Arc<AtomicUsize>,
777    ) {
778        loop {
779            // Check for shutdown signal
780            if shutdown.load(Ordering::Relaxed) == 1 {
781                break;
782            }
783
784            // Try to get a task from local _queue first
785            let mut task_option = None;
786            if let Ok(mut local_q) = local_queue.lock() {
787                task_option = local_q.tasks.pop();
788            }
789
790            // If no local task, try global _queue
791            if task_option.is_none() {
792                if let Ok(mut global_q) = global_queue.lock() {
793                    task_option = global_q.global_tasks.pop();
794                    if task_option.is_some() {
795                        global_q.pending_tasks = global_q.pending_tasks.saturating_sub(1);
796                    }
797                }
798            }
799
800            // If still no task, try work stealing from other workers
801            if task_option.is_none() {
802                task_option = Self::try_work_stealing(_worker_id, &local_queue, &global_queue);
803            }
804
805            // Execute task if found
806            if let Some(task) = task_option {
807                let _result = task.execute();
808                // Task executed successfully
809            } else {
810                // No work available, sleep briefly
811                thread::sleep(Duration::from_millis(1));
812            }
813        }
814    }
815}
816
817impl Drop for ThreadPool {
818    fn drop(&mut self) {
819        // Signal shutdown
820        self.shutdown.store(1, Ordering::Relaxed);
821
822        // Wait for threads to complete
823        for worker in self.workers.drain(..) {
824            if let Some(thread) = worker.thread {
825                let _ = thread.join(); // Ignore errors during cleanup
826            }
827        }
828    }
829}
830
/// `ParallelTask` implementation so vectorized computations can be scheduled
/// on the `ThreadPool`.
///
/// NOTE(review): `execute` duplicates the operation dispatch found in
/// `ParallelOptimizer::execute_vectorized` and its helpers; keep the two in
/// sync when changing either.
impl<F: IntegrateFloat + Send + Sync> ParallelTask for VectorizedComputeTask<F> {
    /// Run the configured operation over the whole input and return the
    /// result as a type-erased `Array2<F>`.
    fn execute(&self) -> ParallelResult {
        // Perform actual vectorized computation based on operation type
        let result: Array2<F> = match &self.operation {
            VectorOperation::ElementWise(op) => match op {
                ArithmeticOp::Add(value) => self
                    .input
                    .mapv(|x| x + F::from(*value).expect("Failed to convert to float")),
                ArithmeticOp::Multiply(value) => self
                    .input
                    .mapv(|x| x * F::from(*value).expect("Failed to convert to float")),
                ArithmeticOp::Power(exp) => self
                    .input
                    .mapv(|x| x.powf(F::from(*exp).expect("Failed to convert to float"))),
                ArithmeticOp::Exp => self.input.mapv(|x| x.exp()),
                ArithmeticOp::Log => self.input.mapv(|x| x.ln()),
                ArithmeticOp::Sin => self.input.mapv(|x| x.sin()),
                ArithmeticOp::Cos => self.input.mapv(|x| x.cos()),
            },
            VectorOperation::MatrixVector(vector) => {
                if self.input.ncols() != vector.len() {
                    return Err(IntegrateError::DimensionMismatch(
                        "Matrix columns must match vector length".to_string(),
                    ));
                }

                // Each output row i is filled with dot(input.row(i), vector),
                // broadcasting the matrix-vector product across columns.
                let mut result = Array2::zeros(self.input.dim());
                for (i, mut row) in result.axis_iter_mut(Axis(0)).enumerate() {
                    let matrix_row = self.input.row(i);
                    let dot_product = matrix_row.dot(vector);
                    row.fill(dot_product);
                }
                result
            }
            VectorOperation::Reduction(op) => {
                let result_value = match op {
                    ReductionOp::Sum => self.input.sum(),
                    ReductionOp::Product => self.input.fold(F::one(), |acc, &x| acc * x),
                    ReductionOp::Max => self.input.fold(F::neg_infinity(), |acc, &x| acc.max(x)),
                    ReductionOp::Min => self.input.fold(F::infinity(), |acc, &x| acc.min(x)),
                    ReductionOp::Mean => {
                        self.input.sum() / F::from(self.input.len()).expect("Operation failed")
                    }
                    ReductionOp::Variance => {
                        // Population variance (divides by len, not len - 1).
                        let mean =
                            self.input.sum() / F::from(self.input.len()).expect("Operation failed");
                        self.input.mapv(|x| (x - mean).powi(2)).sum()
                            / F::from(self.input.len()).expect("Operation failed")
                    }
                };
                // Reductions collapse to a scalar, returned as a 1x1 array.
                Array2::from_elem((1, 1), result_value)
            }
            VectorOperation::Custom(func) => func(&self.input.view()),
        };

        Ok(Box::new(result) as Box<dyn std::any::Any + Send>)
    }

    /// Cost estimate in "chunks of work": total elements / chunk size.
    fn estimated_cost(&self) -> f64 {
        (self.input.len() as f64) / (self.chunk_size as f64)
    }

    /// Subdividable when there are more than two chunks' worth of rows.
    fn can_subdivide(&self) -> bool {
        self.input.nrows() > self.chunk_size * 2
    }

    /// Split the input into row chunks, one subtask per chunk.
    ///
    /// NOTE(review): the early-exit threshold below compares total element
    /// count (`len`) against `chunk_size * 2`, while `can_subdivide` compares
    /// row count — confirm whether that asymmetry is intended.
    fn subdivide(&self) -> Vec<Box<dyn ParallelTask + Send>> {
        // Only subdivide if the task is large enough and can benefit from parallelization
        if self.input.len() < self.chunk_size * 2 {
            return vec![];
        }

        let num_chunks = self.input.nrows().div_ceil(self.chunk_size);
        let mut subtasks = Vec::with_capacity(num_chunks);

        for i in 0..num_chunks {
            let start_row = i * self.chunk_size;
            let end_row = ((i + 1) * self.chunk_size).min(self.input.nrows());

            if start_row < self.input.nrows() {
                let chunk = self.input.slice(s![start_row..end_row, ..]).to_owned();

                let subtask = VectorizedComputeTask {
                    input: chunk,
                    operation: self.operation.clone(),
                    chunk_size: self.chunk_size,
                    prefer_simd: self.prefer_simd,
                };

                subtasks.push(Box::new(subtask) as Box<dyn ParallelTask + Send>);
            }
        }

        subtasks
    }
}
927
#[cfg(test)]
mod tests {
    // Import everything from the parent module directly. The previous mix of
    // `crate::{...}` root paths and `crate::parallel_optimization::...` was
    // inconsistent and relied on crate-root re-exports existing.
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_parallel_optimizer_creation() {
        let optimizer = ParallelOptimizer::new(4);
        assert_eq!(optimizer.num_threads, 4);
    }

    #[test]
    fn test_numa_topology_detection() {
        let topology = NumaTopology::detect();
        assert!(topology.num_nodes > 0);
        assert!(!topology.cores_per_node.is_empty());
    }

    #[test]
    fn test_vectorized_computation() {
        let optimizer = ParallelOptimizer::new(2);
        let input = Array2::from_elem((4, 4), 1.0);

        let task = VectorizedComputeTask {
            input,
            operation: VectorOperation::ElementWise(ArithmeticOp::Add(2.0)),
            chunk_size: 2,
            prefer_simd: true,
        };

        let result = optimizer.execute_vectorized(task);
        assert!(result.is_ok());

        // 1.0 + 2.0 == 3.0 element-wise, shape preserved.
        let output = result.expect("vectorized execution should succeed");
        assert_eq!(output.dim(), (4, 4));
        assert!((output[[0, 0]] - 3.0_f64).abs() < 1e-10);
    }
}