non_convex_opt/algorithms/cem/cem_opt.rs

use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OMatrix, OVector, RealField, U1};
use num_traits::Float;
use rand::{self, rngs::StdRng, Rng, SeedableRng};
use rand_distr::{Distribution, Normal};
use rayon::prelude::*;
use std::iter::Sum;

use crate::utils::config::CEMConf;
use crate::utils::opt_prob::{FloatNumber as FloatNum, OptProb, OptimizationAlgorithm, State};

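/// Cross-Entropy Method (CEM) optimizer.
///
/// Maintains a multivariate Gaussian sampling distribution (`mean`, `covariance`),
/// draws a population each iteration, evaluates it in parallel against `opt_prob`,
/// and refits the distribution to the elite subset of feasible samples. Antithetic
/// sampling, covariance regularization, restarts, and early stopping are controlled
/// through `CEMConf`.
///
/// A minimal driver loop sketch, assuming `conf`, `init_pop`, and `opt_prob` have
/// already been built elsewhere (their construction is crate-specific and not shown):
///
/// ```ignore
/// let mut cem = CEM::new(conf, init_pop, opt_prob, /* stagnation_window */ 20, /* seed */ 42);
/// for _ in 0..100 {
///     cem.step();
/// }
/// let best_x = cem.state().best_x.clone();
/// ```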
pub struct CEM<T, N, D>
where
    T: FloatNum + Send + Sync + nalgebra::ComplexField,
    N: Dim + Send + Sync,
    D: Dim + Send + Sync,
    OVector<T, D>: Send + Sync,
    OVector<T, N>: Send + Sync,
    OVector<bool, N>: Send + Sync,
    OMatrix<T, N, D>: Send + Sync,
    DefaultAllocator: Allocator<D> + Allocator<N, D> + Allocator<N> + Allocator<D, D>,
{
    pub conf: CEMConf,
    pub opt_prob: OptProb<T, D>,
    pub st: State<T, N, D>,

    pub mean: OVector<T, D>,
    pub covariance: OMatrix<T, D, D>,
    pub std_dev: OVector<T, D>,
    pub cached_cholesky: Option<nalgebra::Cholesky<T, D>>,
    pub covariance_changed: bool,

    pub improvement_history: Vec<T>,
    pub diversity_history: Vec<T>,
    pub stagnation_counter: usize,
    pub last_improvement: T,
    pub last_improvement_iter: usize,
    pub restart_counter: usize,
    pub last_restart_iter: usize,
    pub stagnation_window: usize,

    rng: StdRng,
}

impl<T, N, D> CEM<T, N, D>
where
    T: FloatNum + RealField + Send + Sync + Sum,
    N: Dim + Send + Sync,
    D: Dim + Send + Sync,
    OVector<T, D>: Send + Sync,
    OVector<T, N>: Send + Sync,
    OVector<bool, N>: Send + Sync,
    OMatrix<T, N, D>: Send + Sync,
    DefaultAllocator:
        Allocator<D> + Allocator<N, D> + Allocator<N> + Allocator<U1, D> + Allocator<D, D>,
{
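    /// Builds the optimizer from an initial population: the distribution mean is the
    /// column-wise average of `init_pop`, the covariance starts as a diagonal matrix
    /// with variance `initial_std^2`, and the initial population is evaluated in
    /// parallel so that `best_x`/`best_f` are seeded from the best feasible row.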
    pub fn new(
        conf: CEMConf,
        init_pop: OMatrix<T, N, D>,
        opt_prob: OptProb<T, D>,
        stagnation_window: usize,
        seed: u64,
    ) -> Self {
        let n = init_pop.ncols();
        let population_size = init_pop.nrows();

        let mean = if population_size > 0 {
            let mut mean_vec = OVector::<T, D>::zeros_generic(D::from_usize(n), U1);
            for i in 0..n {
                let sum: T = (0..population_size).map(|j| init_pop[(j, i)]).sum();
                mean_vec[i] = sum / T::from_usize(population_size).unwrap();
            }
            mean_vec
        } else {
            OVector::<T, D>::zeros_generic(D::from_usize(n), U1)
        };

        let initial_std = T::from_f64(conf.common.initial_std).unwrap();
        let std_dev = OVector::<T, D>::from_element_generic(D::from_usize(n), U1, initial_std);
        let mut covariance = OMatrix::<T, D, D>::zeros_generic(D::from_usize(n), D::from_usize(n));

        // Init cov as diagonal matrix
        for i in 0..n {
            covariance[(i, i)] = std_dev[i] * std_dev[i];
        }

        let mut st = State {
            best_x: mean.clone(),
            best_f: T::neg_infinity(),
            pop: init_pop.clone(),
            fitness: OVector::<T, N>::zeros_generic(N::from_usize(population_size), U1),
            constraints: OVector::<bool, N>::from_element_generic(
                N::from_usize(population_size),
                U1,
                true,
            ),
            iter: 0,
        };

        let (fitness, constraints): (Vec<T>, Vec<bool>) = (0..population_size)
            .into_par_iter()
            .map(|i| {
                let x = init_pop.row(i).transpose();
                let fit = opt_prob.evaluate(&x);
                let constr = opt_prob.is_feasible(&x);
                (fit, constr)
            })
            .unzip();

        st.fitness = OVector::<T, N>::from_vec_generic(N::from_usize(population_size), U1, fitness);
        st.constraints =
            OVector::<bool, N>::from_vec_generic(N::from_usize(population_size), U1, constraints);

        if let Some((best_idx, _)) = st
            .fitness
            .iter()
            .enumerate()
            .filter(|(i, _)| st.constraints[*i])
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
        {
            st.best_x = st.pop.row(best_idx).transpose();
            st.best_f = st.fitness[best_idx];
        }

        Self {
            conf,
            opt_prob,
            st,
            mean,
            covariance,
            std_dev,
            cached_cholesky: None,
            covariance_changed: true,
            improvement_history: Vec::new(),
            diversity_history: Vec::new(),
            stagnation_counter: 0,
            last_improvement: T::neg_infinity(),
            last_improvement_iter: 0,
            restart_counter: 0,
            last_restart_iter: 0,
            stagnation_window,
            rng: StdRng::seed_from_u64(seed),
        }
    }

    fn get_bounds(&self, candidate: &OVector<T, D>) -> (OVector<T, D>, OVector<T, D>) {
        let lower_bounds = self
            .opt_prob
            .objective
            .x_lower_bound(candidate)
            .unwrap_or_else(|| {
                OVector::<T, D>::from_element_generic(
                    D::from_usize(candidate.len()),
                    U1,
                    T::from_f64(-10.0).unwrap(),
                )
            });
        let upper_bounds = self
            .opt_prob
            .objective
            .x_upper_bound(candidate)
            .unwrap_or_else(|| {
                OVector::<T, D>::from_element_generic(
                    D::from_usize(candidate.len()),
                    U1,
                    T::from_f64(10.0).unwrap(),
                )
            });

        (lower_bounds, upper_bounds)
    }

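    /// Restart heuristic: fires when the stagnation counter exceeds half the configured
    /// restart frequency, when more than `restart_frequency` iterations have passed
    /// since the last restart, or (after 20 iterations) when the average of the last
    /// five diversity measurements drops below 1e-4.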
    fn should_restart(&self) -> bool {
        if !self.conf.advanced.use_restart_strategy {
            return false;
        }

        let restart_freq = self.conf.advanced.restart_frequency;
        let stagnation_threshold = restart_freq / 2;

        // Frequency-based criterion: stagnation or time since the last restart
        let basic_restart = self.stagnation_counter > stagnation_threshold
            || (self.st.iter - self.last_restart_iter) > restart_freq;

        // Diversity-based criterion, checked after a 20-iteration warm-up
        if self.st.iter > 20 {
            let recent_diversity = self
                .diversity_history
                .iter()
                .rev()
                .take(5)
                .fold(T::zero(), |acc, &x| acc + x)
                / T::from_f64(5.0).unwrap();

            let diversity_threshold = T::from_f64(1e-4).unwrap();
            let diversity_restart = recent_diversity < diversity_threshold;

            return basic_restart || diversity_restart;
        }

        basic_restart
    }

    fn should_early_stop(&self) -> bool {
        if !self.conf.advanced.use_restart_strategy {
            return false;
        }

        if self.stagnation_counter > self.stagnation_window * 2 {
            return true;
        }

        let threshold_window = self.conf.advanced.improvement_threshold_window;
        if self.improvement_history.len() >= threshold_window {
            let recent_improvements =
                &self.improvement_history[self.improvement_history.len() - threshold_window..];
            let avg_improvement: T = recent_improvements.iter().cloned().sum::<T>()
                / T::from_usize(recent_improvements.len()).unwrap();

            if avg_improvement < T::from_f64(1e-8).unwrap() {
                return true;
            }
        }

        false
    }

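    /// Resets the search distribution: the mean is re-drawn uniformly within the
    /// variable bounds, the standard deviations are re-inflated with multiplicative
    /// noise in [0.5, 1.5), the covariance is rebuilt as a diagonal matrix, and the
    /// improvement/diversity histories are truncated.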
    fn perform_restart(&mut self) {
        let n = self.mean.len();

        let mean_copy = self.mean.clone();
        let (lb, ub) = self.get_bounds(&mean_copy);
        for i in 0..n {
            let range = ub[i] - lb[i];
            self.mean[i] = lb[i] + T::from_f64(self.rng.random::<f64>()).unwrap() * range;
        }

        let initial_std = T::from_f64(self.conf.common.initial_std).unwrap();
        self.std_dev = OVector::<T, D>::from_element_generic(D::from_usize(n), U1, initial_std);

        // Inject diversity into std
        for i in 0..n {
            let noise = T::from_f64(self.rng.random_range(0.5..1.5)).unwrap();
            self.std_dev[i] *= noise;
        }

        for i in 0..n {
            for j in 0..n {
                if i == j {
                    self.covariance[(i, j)] = self.std_dev[i] * self.std_dev[i];
                } else {
                    self.covariance[(i, j)] = T::zero();
                }
            }
        }
        self.covariance_changed = true;
        self.stagnation_counter = 0;
        self.last_restart_iter = self.st.iter;
        self.restart_counter += 1;

        if self.improvement_history.len() > 10 {
            self.improvement_history
                .drain(0..self.improvement_history.len() - 10);
        }
        if self.diversity_history.len() > 10 {
            self.diversity_history
                .drain(0..self.diversity_history.len() - 10);
        }

        eprintln!(
            "CEM restart triggered after {} iterations without improvement (restart #{})",
            self.st.iter - self.last_improvement_iter,
            self.restart_counter
        );
    }

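    /// Draws the next population from the current Gaussian and clamps each sample to
    /// the variable bounds. When antithetic sampling is enabled, part of the population
    /// consists of mirrored samples `2 * mean - x`, which are negatively correlated
    /// with their regular partners and reduce the variance of the elite estimate.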
    fn sample_population(&mut self) -> Vec<OVector<T, D>> {
        let n = self.mean.len();
        let population_size = self.conf.common.population_size;

        let antithetic_ratio = if self.conf.sampling.use_antithetic {
            self.conf.sampling.antithetic_ratio
        } else {
            0.0
        };

        let regular_sample_count = (population_size as f64 / (1.0 + antithetic_ratio)) as usize;
        let antithetic_count = population_size - regular_sample_count;

        let mut population = Vec::with_capacity(population_size);

        for _ in 0..regular_sample_count {
            let sample = self.sample_multivariate_normal();
            population.push(sample);
        }

        for sample in &mut population {
            let (lb, ub) = self.get_bounds(sample);
            for i in 0..n {
                sample[i] = Float::min(Float::max(sample[i], lb[i]), ub[i]);
            }
        }

        // Variance reduction: generate negatively correlated (antithetic) pairs
        if self.conf.sampling.use_antithetic && antithetic_count > 0 {
            for i in 0..antithetic_count.min(regular_sample_count) {
                let antithetic = &self.mean - (&population[i] - &self.mean); // Reflect about the mean

                let (lb, ub) = self.get_bounds(&antithetic);
                let mut bounded_antithetic = antithetic;
                for j in 0..n {
                    bounded_antithetic[j] =
                        Float::min(Float::max(bounded_antithetic[j], lb[j]), ub[j]);
                }

                population.push(bounded_antithetic);
            }
        }

        population
    }

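    /// Draws a single sample `x = mean + L * z` with `z ~ N(0, I)`, where `L` is the
    /// Cholesky factor of the diagonally regularized covariance. The factorization is
    /// cached and only recomputed when the covariance has changed.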
    fn sample_multivariate_normal(&mut self) -> OVector<T, D> {
        let n = self.mean.len();

        let mut z = OVector::<T, D>::zeros_generic(D::from_usize(n), U1);
        let normal = Normal::new(0.0, 1.0).unwrap();
        for i in 0..n {
            z[i] = T::from_f64(normal.sample(&mut self.rng)).unwrap();
        }

        if self.cached_cholesky.is_none() || self.covariance_changed {
            let mut cov_matrix = self.covariance.clone();
            let reg_factor = T::from_f64(1e-6).unwrap();
            for i in 0..n {
                cov_matrix[(i, i)] += reg_factor;
            }

            let new_cholesky = cov_matrix
                .cholesky()
                .expect("Covariance matrix should be positive definite after regularization");
            self.cached_cholesky = Some(new_cholesky);
            self.covariance_changed = false;
        }

        // Transform standard normal: x = μ + L * z
        let l_matrix = self.cached_cholesky.as_ref().unwrap();
        let transformed = l_matrix.l() * z;
        &self.mean + &transformed
    }

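    /// Refits the Gaussian to the elite set: the empirical mean and covariance of the
    /// elite samples, with optional diagonal regularization for positive definiteness,
    /// are blended with the previous parameters via an exponential moving average
    /// (`new = alpha * elite_estimate + (1 - alpha) * old`). The per-dimension standard
    /// deviations are then clamped to `[min_std, max_std]`.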
    fn update_distribution(&mut self, elite_samples: &[OVector<T, D>]) {
        if elite_samples.is_empty() {
            return;
        }

        let n = self.mean.len();
        let elite_size = elite_samples.len();

        // Update mean with elite samples
        let mut new_mean = OVector::<T, D>::zeros_generic(D::from_usize(n), U1);
        for sample in elite_samples {
            new_mean += sample;
        }
        new_mean /= T::from_usize(elite_size).unwrap();

        // Update covariance with elite samples
        let mut new_covariance =
            OMatrix::<T, D, D>::zeros_generic(D::from_usize(n), D::from_usize(n));

        for sample in elite_samples {
            let diff = sample - &new_mean;
            // Rank-1 update: cov += diff * diff^T
            for i in 0..n {
                for j in 0..n {
                    new_covariance[(i, j)] += diff[i] * diff[j];
                }
            }
        }
        new_covariance /= T::from_usize(elite_size).unwrap();

        // Ensure positive definiteness in cov
        if self.conf.advanced.use_covariance_adaptation {
            let reg = T::from_f64(self.conf.advanced.covariance_regularization).unwrap();
            for i in 0..n {
                new_covariance[(i, i)] += reg;
            }

            // Additional numerical stability: keep every diagonal entry at least `reg`
            let min_diag = new_covariance
                .diagonal()
                .iter()
                .fold(T::infinity(), |acc, &x| Float::min(acc, x));
            if min_diag < reg {
                let additional_reg = reg - min_diag;
                for i in 0..n {
                    new_covariance[(i, i)] += additional_reg;
                }
            }
        }

        // Smooth updates with EMA
        let alpha = T::from_f64(self.conf.adaptation.smoothing_factor).unwrap();

        // Update mean with EMA
        let one_minus_alpha = T::one() - alpha;
        for i in 0..n {
            self.mean[i] = alpha * new_mean[i] + one_minus_alpha * self.mean[i];
        }

        // Update covariance with EMA
        for i in 0..n {
            for j in 0..n {
                self.covariance[(i, j)] =
                    alpha * new_covariance[(i, j)] + one_minus_alpha * self.covariance[(i, j)];
            }
        }
        self.covariance_changed = true;

        // Update std from diagonal of cov
        for i in 0..n {
            let new_std = Float::sqrt(self.covariance[(i, i)]);
            let min_std = T::from_f64(self.conf.common.min_std).unwrap();
            let max_std = T::from_f64(self.conf.common.max_std).unwrap();
            self.std_dev[i] = Float::min(Float::max(new_std, min_std), max_std);
        }
    }

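    /// Population diversity measure: the square root of the total variance of the
    /// current population about the distribution mean, divided by the problem
    /// dimension.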
    fn compute_diversity(&self) -> T {
        let n = self.mean.len();
        let population_size = self.st.pop.nrows();

        if population_size == 0 {
            return T::zero();
        }

        // Accumulate the per-dimension variance of the population about the mean
        let mut total_variance = T::zero();
        for i in 0..n {
            let mean_val = self.mean[i];
            let variance: T = (0..population_size)
                .map(|j| {
                    let diff = self.st.pop[(j, i)] - mean_val;
                    diff * diff
                })
                .sum();
            total_variance += variance;
        }

        Float::sqrt(total_variance) / T::from_usize(n).unwrap()
    }
}

impl<T, N, D> OptimizationAlgorithm<T, N, D> for CEM<T, N, D>
where
    T: FloatNum + RealField + Send + Sync + Sum,
    N: Dim + Send + Sync,
    D: Dim + Send + Sync,
    OVector<T, D>: Send + Sync,
    OVector<T, N>: Send + Sync,
    OVector<bool, N>: Send + Sync,
    OMatrix<T, N, D>: Send + Sync,
    DefaultAllocator:
        Allocator<D> + Allocator<N, D> + Allocator<N> + Allocator<U1, D> + Allocator<D, D>,
{
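    /// One CEM iteration: check the early-stop and restart conditions, sample and
    /// evaluate a new population in parallel, update the incumbent best feasible
    /// solution, refit the distribution to the elite feasible samples, and record the
    /// improvement and diversity histories.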
    fn step(&mut self) {
        if self.should_early_stop() {
            eprintln!("CEM early stopping triggered due to stagnation");
            return;
        }

        if self.should_restart() {
            self.perform_restart();
            return;
        }

        let population = self.sample_population();
        let population_size = population.len();

        let (fitness, constraints): (Vec<T>, Vec<bool>) = population
            .par_iter()
            .map(|x| {
                let fit = self.opt_prob.evaluate(x);
                let constr = self.opt_prob.is_feasible(x);
                (fit, constr)
            })
            .unzip();

        let sample_dimension = if !population.is_empty() {
            population[0].len()
        } else {
            0
        };

        let mut new_pop = OMatrix::<T, N, D>::zeros_generic(
            N::from_usize(population_size),
            D::from_usize(sample_dimension),
        );

        for (i, sample) in population.iter().enumerate() {
            for (j, &val) in sample.iter().enumerate() {
                new_pop[(i, j)] = val;
            }
        }

        self.st.pop = new_pop;

        self.st.fitness =
            OVector::<T, N>::from_vec_generic(N::from_usize(population_size), U1, fitness.clone());

        self.st.constraints = OVector::<bool, N>::from_vec_generic(
            N::from_usize(population_size),
            U1,
            constraints.clone(),
        );

        let mut best_fitness = T::neg_infinity();
        let mut best_idx = 0;

        for (i, (fit, constr)) in fitness.iter().zip(constraints.iter()).enumerate() {
            if *constr && *fit > best_fitness {
                best_fitness = *fit;
                best_idx = i;
            }
        }

        if best_fitness > self.st.best_f {
            self.st.best_f = best_fitness;
            self.st.best_x = population[best_idx].clone();
            self.last_improvement = best_fitness;
            self.last_improvement_iter = self.st.iter;
            self.stagnation_counter = 0;
        } else {
            self.stagnation_counter += 1;
        }

        // Elite samples: the top `elite_size` feasible solutions by fitness
        let mut elite_indices: Vec<usize> =
            (0..population_size).filter(|&i| constraints[i]).collect();

        elite_indices.sort_by(|&a, &b| fitness[b].partial_cmp(&fitness[a]).unwrap());
        let elite_size = self.conf.common.elite_size.min(elite_indices.len());
        let elite_samples: Vec<OVector<T, D>> = elite_indices[..elite_size]
            .iter()
            .map(|&i| population[i].clone())
            .collect();

        self.update_distribution(&elite_samples);
        self.improvement_history.push(best_fitness);
        self.diversity_history.push(self.compute_diversity());

        let max_history = self.conf.advanced.improvement_history_size;
        if self.improvement_history.len() > max_history {
            self.improvement_history
                .drain(0..self.improvement_history.len() - max_history);
            self.diversity_history
                .drain(0..self.diversity_history.len() - max_history);
        }

        self.st.iter += 1;
    }

    fn state(&self) -> &State<T, N, D> {
        &self.st
    }
}