use crate::{Module, ModuleResult};
use rayon::prelude::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, MutexGuard};
/// Settings that decide whether, and how widely, compilation work is parallelized.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParallelConfig {
    /// Number of worker threads for rayon's global pool; `0` means the pool is
    /// left unconfigured by `init_thread_pool`.
    pub num_threads: usize,
    /// Whether module-level jobs may run on the thread pool.
    pub parallel_modules: bool,
    /// Whether optimization passes over a batch of items may run on the thread pool.
    pub parallel_optimization: bool,
    /// Smallest batch of modules that is worth parallelizing; smaller batches
    /// are processed sequentially on the calling thread.
    pub min_modules_for_parallel: usize,
}

impl Default for ParallelConfig {
    /// Enables both module and optimization parallelism, leaves the thread pool
    /// size alone (`num_threads == 0`), and parallelizes batches of two or more modules.
    fn default() -> Self {
        Self {
            num_threads: 0,
            parallel_modules: true,
            parallel_optimization: true,
            min_modules_for_parallel: 2,
        }
    }
}
impl ParallelConfig {
    /// Creates a configuration identical to `ParallelConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of rayon worker threads (`0` = do not configure the pool).
    #[must_use]
    pub fn with_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// Enables or disables running module-level jobs on the thread pool.
    #[must_use]
    pub fn with_parallel_modules(mut self, enabled: bool) -> Self {
        self.parallel_modules = enabled;
        self
    }

    /// Enables or disables running optimization passes on the thread pool.
    #[must_use]
    pub fn with_parallel_optimization(mut self, enabled: bool) -> Self {
        self.parallel_optimization = enabled;
        self
    }

    /// Sets the smallest module batch that is processed in parallel; batches
    /// with fewer entries run sequentially.
    #[must_use]
    pub fn with_min_modules_for_parallel(mut self, min_modules: usize) -> Self {
        self.min_modules_for_parallel = min_modules;
        self
    }

    /// Builds rayon's global thread pool with `num_threads` workers.
    ///
    /// Does nothing and returns `Ok(())` when `num_threads` is `0`.
    ///
    /// # Errors
    ///
    /// Returns a message when rayon refuses to build the global pool. According
    /// to rayon's documentation this happens if the global pool was already
    /// initialized — explicitly, or implicitly by earlier parallel work — so this
    /// should be called once, before any parallel compilation starts.
    pub fn init_thread_pool(&self) -> Result<(), String> {
        if self.num_threads > 0 {
            rayon::ThreadPoolBuilder::new()
                .num_threads(self.num_threads)
                .build_global()
                .map_err(|e| format!("Failed to initialize thread pool: {}", e))?;
        }
        Ok(())
    }
}
/// Drives module compilation, optimization, and PPU loading, fanning work out
/// over rayon's thread pool when the [`ParallelConfig`] allows it.
#[derive(Debug, Clone)]
pub struct ParallelCompiler {
    /// Settings that decide when each kind of work runs in parallel.
    config: ParallelConfig,
}

impl ParallelCompiler {
    /// Creates a compiler driver governed by `config`.
    pub fn new(config: ParallelConfig) -> Self {
        Self { config }
    }

    /// Returns `true` when a batch of `count` module-level jobs should be spread
    /// across the thread pool instead of running on the calling thread.
    fn use_parallel_modules(&self, count: usize) -> bool {
        self.config.parallel_modules && count >= self.config.min_modules_for_parallel
    }

    /// Compiles every name in `module_names` with `compile_fn`.
    ///
    /// Returns one result per name, in the same order as `module_names`; a
    /// failing module does not stop the others. Runs sequentially when
    /// `parallel_modules` is disabled or the batch is smaller than
    /// `min_modules_for_parallel`.
    pub fn compile_modules_parallel<F>(
        &self,
        module_names: Vec<String>,
        compile_fn: F,
    ) -> Vec<ModuleResult<Module>>
    where
        F: Fn(&str) -> ModuleResult<Module> + Sync + Send,
    {
        if !self.use_parallel_modules(module_names.len()) {
            return module_names.iter().map(|name| compile_fn(name)).collect();
        }
        module_names
            .par_iter()
            .map(|name| compile_fn(name))
            .collect()
    }

    /// Applies `optimize_fn` to every item and returns the results in input order.
    ///
    /// Runs sequentially when `parallel_optimization` is disabled or there are
    /// fewer than two items, since there is nothing to split.
    pub fn optimize_parallel<T, F>(&self, items: Vec<T>, optimize_fn: F) -> Vec<T>
    where
        T: Send,
        F: Fn(T) -> T + Sync + Send,
    {
        if !self.config.parallel_optimization || items.len() < 2 {
            return items.into_iter().map(optimize_fn).collect();
        }
        items.into_par_iter().map(optimize_fn).collect()
    }

    /// Loads every unit in `unit_names` with `load_fn`, returning one result per
    /// name in input order.
    ///
    /// Uses the same rule as `compile_modules_parallel`: it runs sequentially when
    /// `parallel_modules` is disabled or the batch is smaller than
    /// `min_modules_for_parallel`.
    pub fn load_ppu_files_parallel<F>(
        &self,
        unit_names: Vec<String>,
        load_fn: F,
    ) -> Vec<ModuleResult<crate::ast::Unit>>
    where
        F: Fn(&str) -> ModuleResult<crate::ast::Unit> + Sync + Send,
    {
        // Previously only the batch size was checked here, so disabling
        // `parallel_modules` did not stop PPU loading from running in parallel.
        if !self.use_parallel_modules(unit_names.len()) {
            return unit_names.iter().map(|name| load_fn(name)).collect();
        }
        unit_names.par_iter().map(|name| load_fn(name)).collect()
    }
}
/// Thread-safe record of how many jobs have finished and which errors were reported.
///
/// Cloning a tracker produces another handle onto the same shared counters, so
/// clones can be moved into worker threads while the original reports progress.
#[derive(Debug, Clone)]
pub struct ProgressTracker {
    /// Number of jobs expected in total.
    total: usize,
    /// Jobs reported finished so far; shared by all clones.
    completed: Arc<AtomicUsize>,
    /// Error messages in the order they were reported; shared by all clones.
    errors: Arc<Mutex<Vec<String>>>,
}

impl ProgressTracker {
    /// Creates a tracker expecting `total` jobs, with nothing completed and no errors.
    pub fn new(total: usize) -> Self {
        Self {
            total,
            completed: Arc::new(AtomicUsize::new(0)),
            errors: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Returns the number of jobs this tracker was created to expect.
    pub fn total(&self) -> usize {
        self.total
    }

    /// Records that one more job has finished.
    pub fn complete_one(&self) {
        self.completed.fetch_add(1, Ordering::SeqCst);
    }

    /// Appends `error` to the shared error list.
    pub fn add_error(&self, error: String) {
        self.lock_errors().push(error);
    }

    /// Returns the fraction of jobs finished.
    ///
    /// An empty tracker (`total == 0`) counts as fully done (`1.0`). The value
    /// is not clamped, so it exceeds `1.0` if more jobs complete than were expected.
    pub fn progress(&self) -> f64 {
        if self.total == 0 {
            return 1.0;
        }
        self.completed() as f64 / self.total as f64
    }

    /// Returns the number of jobs reported finished so far.
    pub fn completed(&self) -> usize {
        self.completed.load(Ordering::SeqCst)
    }

    /// Returns a snapshot copy of all reported errors, oldest first.
    pub fn errors(&self) -> Vec<String> {
        self.lock_errors().clone()
    }

    /// Returns `true` once at least `total` jobs have been reported finished.
    pub fn is_complete(&self) -> bool {
        self.completed() >= self.total
    }

    /// Locks the error list and keeps working even if the lock is poisoned.
    ///
    /// Poisoning only signals that some holder panicked. Every mutation here is a
    /// single `push`, so the recovered list is still usable, and reporting an error
    /// should not itself turn into a panic.
    fn lock_errors(&self) -> MutexGuard<'_, Vec<String>> {
        self.errors
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}
/// A numbered worker that runs one compilation job at a time and pairs each
/// outcome with the name of the unit it was produced for.
#[derive(Debug)]
pub struct CompilationWorker {
    /// Identifier supplied when the worker was created.
    worker_id: usize,
}

impl CompilationWorker {
    /// Creates a worker tagged with `worker_id`.
    pub fn new(worker_id: usize) -> Self {
        CompilationWorker { worker_id }
    }

    /// Runs `compile_fn` on `unit_name` and returns an owned copy of the name
    /// alongside the outcome, so the caller can tell which unit a result belongs to.
    pub fn process_job<F>(&self, unit_name: &str, compile_fn: F) -> (String, ModuleResult<Module>)
    where
        F: FnOnce(&str) -> ModuleResult<Module>,
    {
        let outcome = compile_fn(unit_name);
        let owned_name = String::from(unit_name);
        (owned_name, outcome)
    }

    /// Returns the identifier this worker was created with.
    pub fn worker_id(&self) -> usize {
        let Self { worker_id } = self;
        *worker_id
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ModuleError;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::Duration;

    /// Builds a `Module` named `name` wrapping an empty unit of the same name.
    fn create_test_module(name: &str) -> Module {
        Module {
            name: name.to_string(),
            unit: crate::ast::Unit {
                name: name.to_string(),
                uses: vec![],
                interface: crate::ast::UnitInterface {
                    uses: vec![],
                    types: vec![],
                    constants: vec![],
                    variables: vec![],
                    procedures: vec![],
                    functions: vec![],
                    classes: vec![],
                    interfaces: vec![],
                },
                implementation: crate::ast::UnitImplementation {
                    uses: vec![],
                    types: vec![],
                    constants: vec![],
                    variables: vec![],
                    procedures: vec![],
                    functions: vec![],
                    classes: vec![],
                    interfaces: vec![],
                    initialization: None,
                    finalization: None,
                },
            },
            dependencies: vec![],
        }
    }

    #[test]
    fn test_parallel_config_default() {
        let config = ParallelConfig::default();
        assert_eq!(config.num_threads, 0);
        assert!(config.parallel_modules);
        assert!(config.parallel_optimization);
        assert_eq!(config.min_modules_for_parallel, 2);
    }

    #[test]
    fn test_parallel_config_builder() {
        let config = ParallelConfig::new()
            .with_threads(4)
            .with_parallel_modules(true)
            .with_parallel_optimization(false);
        assert_eq!(config.num_threads, 4);
        assert!(config.parallel_modules);
        assert!(!config.parallel_optimization);
    }

    #[test]
    fn test_parallel_config_chaining() {
        let config = ParallelConfig::new()
            .with_threads(8)
            .with_parallel_modules(false)
            .with_parallel_optimization(true);
        assert_eq!(config.num_threads, 8);
        assert!(!config.parallel_modules);
        assert!(config.parallel_optimization);
    }

    #[test]
    fn test_progress_tracker_basic() {
        let tracker = ProgressTracker::new(10);
        assert_eq!(tracker.completed(), 0);
        assert_eq!(tracker.progress(), 0.0);
        assert!(!tracker.is_complete());
        tracker.complete_one();
        assert_eq!(tracker.completed(), 1);
        assert_eq!(tracker.progress(), 0.1);
        tracker.add_error("Test error".to_string());
        assert_eq!(tracker.errors().len(), 1);
    }

    #[test]
    fn test_progress_tracker_completion() {
        let tracker = ProgressTracker::new(5);
        for _ in 0..5 {
            tracker.complete_one();
        }
        assert_eq!(tracker.completed(), 5);
        assert_eq!(tracker.progress(), 1.0);
        assert!(tracker.is_complete());
    }

    #[test]
    fn test_progress_tracker_zero_total() {
        let tracker = ProgressTracker::new(0);
        assert_eq!(tracker.progress(), 1.0);
        assert!(tracker.is_complete());
    }

    #[test]
    fn test_progress_tracker_multiple_errors() {
        let tracker = ProgressTracker::new(10);
        tracker.add_error("Error 1".to_string());
        tracker.add_error("Error 2".to_string());
        tracker.add_error("Error 3".to_string());
        let errors = tracker.errors();
        assert_eq!(errors.len(), 3);
        assert_eq!(errors[0], "Error 1");
        assert_eq!(errors[1], "Error 2");
        assert_eq!(errors[2], "Error 3");
    }

    #[test]
    fn test_progress_tracker_thread_safety() {
        let tracker = ProgressTracker::new(100);
        let mut handles = vec![];
        for _ in 0..10 {
            // Hand-built handle sharing the same counters as `tracker`.
            let tracker_clone = ProgressTracker {
                total: tracker.total,
                completed: Arc::clone(&tracker.completed),
                errors: Arc::clone(&tracker.errors),
            };
            let handle = thread::spawn(move || {
                for _ in 0..10 {
                    tracker_clone.complete_one();
                    thread::sleep(Duration::from_micros(1));
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(tracker.completed(), 100);
        assert_eq!(tracker.progress(), 1.0);
    }

    #[test]
    fn test_parallel_compiler_basic() {
        let config = ParallelConfig::new();
        let compiler = ParallelCompiler::new(config);
        let items = vec![1, 2, 3, 4, 5];
        let results = compiler.optimize_parallel(items, |x| x * 2);
        assert_eq!(results, vec![2, 4, 6, 8, 10]);
    }

    #[test]
    fn test_parallel_compiler_large_dataset() {
        let config = ParallelConfig::new();
        let compiler = ParallelCompiler::new(config);
        let items: Vec<i32> = (1..=1000).collect();
        let results = compiler.optimize_parallel(items, |x| x * x);
        assert_eq!(results.len(), 1000);
        assert_eq!(results[0], 1);
        assert_eq!(results[999], 1000000);
    }

    #[test]
    fn test_parallel_compiler_sequential_fallback() {
        let config = ParallelConfig::new().with_parallel_optimization(false);
        let compiler = ParallelCompiler::new(config);
        let items = vec![1, 2, 3];
        let results = compiler.optimize_parallel(items, |x| x + 1);
        assert_eq!(results, vec![2, 3, 4]);
    }

    #[test]
    fn test_parallel_compiler_small_workload_fallback() {
        let config = ParallelConfig::new();
        let compiler = ParallelCompiler::new(config);
        let items = vec![42];
        let results = compiler.optimize_parallel(items, |x| x * 2);
        assert_eq!(results, vec![84]);
    }

    #[test]
    fn test_compile_modules_parallel_success() {
        let config = ParallelConfig::new();
        let compiler = ParallelCompiler::new(config);
        let modules = vec![
            "Module1".to_string(),
            "Module2".to_string(),
            "Module3".to_string(),
        ];
        let results = compiler.compile_modules_parallel(modules, |name| {
            Ok(create_test_module(name))
        });
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    #[test]
    fn test_compile_modules_parallel_with_errors() {
        let config = ParallelConfig::new();
        let compiler = ParallelCompiler::new(config);
        let modules = vec!["Good1".to_string(), "Bad".to_string(), "Good2".to_string()];
        let results = compiler.compile_modules_parallel(modules, |name| {
            if name == "Bad" {
                Err(ModuleError::LoadError(
                    name.to_string(),
                    "Simulated error".to_string(),
                ))
            } else {
                Ok(create_test_module(name))
            }
        });
        assert_eq!(results.len(), 3);
        assert!(results[0].is_ok());
        assert!(results[1].is_err());
        assert!(results[2].is_ok());
    }

    #[test]
    fn test_compile_modules_sequential_for_small_workload() {
        let config = ParallelConfig::new();
        let compiler = ParallelCompiler::new(config);
        let modules = vec!["SingleModule".to_string()];
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = Arc::clone(&call_count);
        let results = compiler.compile_modules_parallel(modules, move |name| {
            call_count_clone.fetch_add(1, Ordering::SeqCst);
            Ok(create_test_module(name))
        });
        assert_eq!(results.len(), 1);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_compile_modules_preserves_order_when_parallel_disabled() {
        let config = ParallelConfig::new().with_parallel_modules(false);
        let compiler = ParallelCompiler::new(config);
        let modules = vec!["A".to_string(), "B".to_string(), "C".to_string()];
        let results = compiler.compile_modules_parallel(modules, |name| {
            Ok(create_test_module(name))
        });
        let names: Vec<&str> = results
            .iter()
            .filter_map(|r| r.as_ref().ok())
            .map(|m| m.name.as_str())
            .collect();
        assert_eq!(names, vec!["A", "B", "C"]);
    }

    #[test]
    fn test_load_ppu_files_parallel_preserves_order() {
        let compiler = ParallelCompiler::new(ParallelConfig::new());
        let units = vec![
            "UnitA".to_string(),
            "UnitB".to_string(),
            "UnitC".to_string(),
        ];
        let results = compiler.load_ppu_files_parallel(units, |name| {
            Ok(create_test_module(name).unit)
        });
        assert_eq!(results.len(), 3);
        let expected = ["UnitA", "UnitB", "UnitC"];
        for (result, name) in results.iter().zip(expected.iter()) {
            assert!(matches!(result, Ok(unit) if unit.name == *name));
        }
    }

    #[test]
    fn test_compilation_worker_process_job() {
        let worker = CompilationWorker::new(3);
        assert_eq!(worker.worker_id(), 3);
        let (name, result) = worker.process_job("UnitA", |n| Ok(create_test_module(n)));
        assert_eq!(name, "UnitA");
        assert!(matches!(&result, Ok(m) if m.name == "UnitA"));
    }

    #[test]
    fn test_parallel_optimization_preserves_order() {
        let config = ParallelConfig::new();
        let compiler = ParallelCompiler::new(config);
        let items: Vec<usize> = (0..100).collect();
        let results = compiler.optimize_parallel(items, |x| x);
        for (i, &val) in results.iter().enumerate() {
            assert_eq!(i, val);
        }
    }

    #[test]
    fn test_parallel_compiler_with_complex_computation() {
        let config = ParallelConfig::new();
        let compiler = ParallelCompiler::new(config);
        let items = vec![10, 20, 30, 40, 50];
        let results = compiler.optimize_parallel(items, |x| {
            let mut sum = 0;
            for i in 0..x {
                sum += i;
            }
            sum
        });
        // Sum of 0..x is x * (x - 1) / 2; pin exact values rather than a vacuous `>= 0`.
        assert_eq!(results, vec![45, 190, 435, 780, 1225]);
    }
}