autoeq_de/
parallel_eval.rs

1use ndarray::{Array1, Array2};
2use rayon::prelude::*;
3use std::sync::Arc;
4
/// Settings that decide whether fitness evaluation fans out across threads.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// When `false`, the evaluation helpers in this module run every
    /// evaluation sequentially on the calling thread.
    pub enabled: bool,
    /// Desired worker-thread count; `None` leaves the choice to rayon.
    ///
    /// NOTE(review): the evaluation helpers in this module always run on
    /// rayon's global pool and never read this field — presumably the solver
    /// applies it when configuring that pool; confirm.
    pub num_threads: Option<usize>,
}

impl Default for ParallelConfig {
    /// Parallel evaluation switched on, thread count left to rayon's default.
    fn default() -> Self {
        ParallelConfig {
            num_threads: None,
            enabled: true,
        }
    }
}
22
23/// Evaluate a population in parallel
24///
25/// # Arguments
26/// * `population` - 2D array where each row is an individual
27/// * `eval_fn` - Function to evaluate each individual
28/// * `config` - Parallel configuration
29///
30/// # Returns
31/// Array of fitness values for each individual
32pub fn evaluate_population_parallel<F>(
33    population: &Array2<f64>,
34    eval_fn: Arc<F>,
35    config: &ParallelConfig,
36) -> Array1<f64>
37where
38    F: Fn(&Array1<f64>) -> f64 + Send + Sync,
39{
40    let npop = population.nrows();
41
42    if !config.enabled || npop < 4 {
43        // Sequential evaluation for small populations or when disabled
44        let mut energies = Array1::zeros(npop);
45        for i in 0..npop {
46            let individual = population.row(i).to_owned();
47            energies[i] = eval_fn(&individual);
48        }
49        return energies;
50    }
51
52    // Always use global thread pool (configured once in solver)
53    let results = (0..npop)
54        .into_par_iter()
55        .map(|i| {
56            let individual = population.row(i).to_owned();
57            eval_fn(&individual)
58        })
59        .collect::<Vec<f64>>();
60
61    Array1::from_vec(results)
62}
63
64/// Evaluate trials in parallel for differential evolution
65///
66/// This function evaluates multiple trial vectors in parallel, which is useful
67/// during the main DE loop where we generate and evaluate one trial per individual.
68///
69/// # Arguments
70/// * `trials` - Vector of trial vectors to evaluate
71/// * `eval_fn` - Function to evaluate each trial
72/// * `config` - Parallel configuration
73///
74/// # Returns
75/// Vector of fitness values for each trial
76pub fn evaluate_trials_parallel<F>(
77    trials: Vec<Array1<f64>>,
78    eval_fn: Arc<F>,
79    config: &ParallelConfig,
80) -> Vec<f64>
81where
82    F: Fn(&Array1<f64>) -> f64 + Send + Sync,
83{
84    if !config.enabled || trials.len() < 4 {
85        // Sequential evaluation for small batches or when disabled
86        return trials.iter().map(|trial| eval_fn(trial)).collect();
87    }
88
89    // Always use global thread pool (configured once in solver)
90    trials.par_iter().map(|trial| eval_fn(trial)).collect()
91}
92
93/// Structure to batch evaluate individuals with their indices
94pub struct IndexedEvaluation {
95    pub index: usize,
96    pub individual: Array1<f64>,
97    pub fitness: f64,
98}
99
100/// Evaluate population with indices preserved for tracking
101pub fn evaluate_population_indexed<F>(
102    population: &Array2<f64>,
103    eval_fn: Arc<F>,
104    config: &ParallelConfig,
105) -> Vec<IndexedEvaluation>
106where
107    F: Fn(&Array1<f64>) -> f64 + Send + Sync,
108{
109    let npop = population.nrows();
110
111    if !config.enabled || npop < 4 {
112        // Sequential evaluation
113        let mut results = Vec::with_capacity(npop);
114        for i in 0..npop {
115            let individual = population.row(i).to_owned();
116            let fitness = eval_fn(&individual);
117            results.push(IndexedEvaluation {
118                index: i,
119                individual,
120                fitness,
121            });
122        }
123        return results;
124    }
125
126    // Parallel evaluation (global thread pool)
127    (0..npop)
128        .into_par_iter()
129        .map(|i| {
130            let individual = population.row(i).to_owned();
131            let fitness = eval_fn(&individual);
132            IndexedEvaluation {
133                index: i,
134                individual,
135                fitness,
136            }
137        })
138        .collect()
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `rows x cols` population with distinct, deterministic entries.
    fn sample_population(rows: usize, cols: usize) -> Array2<f64> {
        Array2::from_shape_fn((rows, cols), |(i, j)| (i as f64) * 0.1 + (j as f64) * 0.01)
    }

    #[test]
    fn test_parallel_evaluation() {
        // Simple quadratic function
        let eval_fn = Arc::new(|x: &Array1<f64>| -> f64 { x.iter().map(|&xi| xi * xi).sum() });
        let population = sample_population(10, 3);

        // `num_threads` is not read by the evaluation functions (they use the
        // global rayon pool); it is set only to mirror a typical config.
        let config = ParallelConfig {
            enabled: true,
            num_threads: Some(2),
        };
        let energies = evaluate_population_parallel(&population, eval_fn.clone(), &config);

        assert_eq!(energies.len(), 10);
        for (i, &energy) in energies.iter().enumerate() {
            let expected = population.row(i).iter().map(|&x| x * x).sum::<f64>();
            assert!((energy - expected).abs() < 1e-10);
        }

        let config_seq = ParallelConfig {
            enabled: false,
            num_threads: None,
        };
        let energies_seq = evaluate_population_parallel(&population, eval_fn, &config_seq);

        // Both paths run the same per-row computation, so results match exactly.
        assert_eq!(energies, energies_seq);
    }

    #[test]
    fn test_small_and_empty_populations() {
        let eval_fn = Arc::new(|x: &Array1<f64>| -> f64 { x.iter().sum() });
        let config = ParallelConfig::default();

        let empty = Array2::<f64>::zeros((0, 3));
        assert_eq!(evaluate_population_parallel(&empty, eval_fn.clone(), &config).len(), 0);
        assert!(evaluate_population_indexed(&empty, eval_fn.clone(), &config).is_empty());
        assert!(evaluate_trials_parallel(Vec::new(), eval_fn.clone(), &config).is_empty());

        // Fewer rows than the parallel threshold take the sequential path.
        let small = sample_population(2, 3);
        let energies = evaluate_population_parallel(&small, eval_fn, &config);
        assert_eq!(energies.len(), 2);
        for (i, &energy) in energies.iter().enumerate() {
            assert_eq!(energy, small.row(i).iter().sum::<f64>());
        }
    }

    #[test]
    fn test_trials_parallel_matches_sequential() {
        let eval_fn = Arc::new(|x: &Array1<f64>| -> f64 { x.iter().map(|&xi| xi * xi).sum() });
        let trials: Vec<Array1<f64>> = (0..8)
            .map(|i| Array1::from_vec(vec![i as f64, 0.5 * i as f64, 1.0]))
            .collect();

        let parallel = ParallelConfig::default();
        let sequential = ParallelConfig {
            enabled: false,
            num_threads: None,
        };

        let par = evaluate_trials_parallel(trials.clone(), eval_fn.clone(), &parallel);
        let seq = evaluate_trials_parallel(trials.clone(), eval_fn, &sequential);

        assert_eq!(par.len(), trials.len());
        assert_eq!(par, seq);
        for (trial, &fitness) in trials.iter().zip(&par) {
            let expected: f64 = trial.iter().map(|&t| t * t).sum();
            assert_eq!(fitness, expected);
        }
    }

    #[test]
    fn test_indexed_evaluation() {
        let eval_fn = Arc::new(|x: &Array1<f64>| -> f64 { x.iter().sum() });
        // Row i is [i, 2i].
        let population = Array2::from_shape_fn((5, 2), |(i, j)| (i * (j + 1)) as f64);

        for enabled in [true, false] {
            let config = ParallelConfig {
                enabled,
                ..ParallelConfig::default()
            };
            let results = evaluate_population_indexed(&population, eval_fn.clone(), &config);

            assert_eq!(results.len(), 5);
            for (i, result) in results.iter().enumerate() {
                // Results come back in row order on both the parallel and
                // sequential paths.
                assert_eq!(result.index, i);
                assert_eq!(result.individual, population.row(i));
                assert_eq!(result.fitness, population.row(i).sum());
            }
        }
    }
}