// non_convex_opt/algorithms/multi_swarm/mspo.rs

1use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OMatrix, OVector, U1};
2use rayon::prelude::*;
3
4use crate::algorithms::multi_swarm::information_exchange::InformationExchange;
5use crate::algorithms::multi_swarm::population::{
6    find_best_solution, get_population, update_population_state,
7};
8use crate::algorithms::multi_swarm::stagnation_monitor::StagnationMonitor;
9use crate::algorithms::multi_swarm::swarm::{initialize_swarms, Swarm};
10use crate::utils::config::MSPOConf;
11use crate::utils::opt_prob::{FloatNumber as FloatNum, OptProb, OptimizationAlgorithm, State};
12
13pub struct MSPO<T, N, D>
14where
15    T: FloatNum + Send + Sync,
16    D: Dim + Send + Sync,
17    N: Dim + Send + Sync,
18    OVector<T, D>: Send + Sync,
19    OMatrix<T, N, D>: Send + Sync,
20    DefaultAllocator:
21        Allocator<D> + Allocator<N, D> + Allocator<N> + Allocator<U1, D> + Allocator<U1>,
22{
23    pub conf: MSPOConf,
24    pub st: State<T, N, D>,
25    pub swarms: Vec<Swarm<T, D>>,
26    pub opt_prob: OptProb<T, D>,
27    stagnation_monitor: StagnationMonitor<T>,
28    information_exchange: InformationExchange<T, D>,
29}
30
31impl<T, N, D> MSPO<T, N, D>
32where
33    T: FloatNum + Send + Sync,
34    D: Dim + Send + Sync,
35    N: Dim + Send + Sync,
36    OVector<T, D>: Send + Sync,
37    OMatrix<T, N, D>: Send + Sync,
38    DefaultAllocator: Allocator<D>
39        + Allocator<N, D>
40        + Allocator<N>
41        + Allocator<U1, D>
42        + Allocator<D, D>
43        + Allocator<U1>,
44{
45    pub fn new(
46        conf: MSPOConf,
47        init_pop: OMatrix<T, N, D>,
48        opt_prob: OptProb<T, D>,
49        max_iter: usize,
50        seed: u64,
51    ) -> Self {
52        let dim = init_pop.ncols();
53        let total_particles = init_pop.nrows();
54        assert!(
55            total_particles >= conf.num_swarms * conf.swarm_size,
56            "Initial population size must be at least num_swarms * swarm_size"
57        );
58
59        let (best_x, best_fitness) = find_best_solution(&init_pop, &opt_prob);
60
61        let swarms = initialize_swarms(&conf, dim, &init_pop, &opt_prob, max_iter, seed);
62        let (fitness, constraints): (Vec<T>, Vec<bool>) = (0..init_pop.nrows())
63            .into_par_iter()
64            .map(|i| {
65                let x = init_pop.row(i).transpose();
66                let fit = opt_prob.evaluate(&x);
67                let constr = opt_prob.is_feasible(&x);
68                (fit, constr)
69            })
70            .unzip();
71
72        let fitness =
73            OVector::<T, N>::from_vec_generic(N::from_usize(init_pop.nrows()), U1, fitness);
74        let constraints =
75            OVector::<bool, N>::from_vec_generic(N::from_usize(init_pop.nrows()), U1, constraints);
76
77        let st = State {
78            best_x,
79            best_f: best_fitness,
80            pop: init_pop,
81            fitness,
82            constraints,
83            iter: 1,
84        };
85
86        let improvement_threshold = T::from_f64(conf.improvement_threshold).unwrap();
87        let stagnation_monitor = StagnationMonitor::new(improvement_threshold, best_fitness);
88        let information_exchange = InformationExchange::new(conf.clone(), opt_prob.clone());
89
90        Self {
91            conf,
92            st,
93            swarms,
94            opt_prob,
95            stagnation_monitor,
96            information_exchange,
97        }
98    }
99
100    pub fn stagnation_counter(&self) -> usize {
101        self.stagnation_monitor.stagnation_counter()
102    }
103
104    pub fn is_stagnated(&self) -> bool {
105        self.stagnation_monitor.is_stagnated()
106    }
107
108    pub fn get_swarm_diversity(&self) -> Vec<f64> {
109        self.swarms.iter().map(|s| s.current_diversity()).collect()
110    }
111
112    pub fn get_average_improvement(&self, window_size: usize) -> Vec<T> {
113        self.swarms
114            .iter()
115            .map(|s| s.average_improvement(window_size))
116            .collect()
117    }
118
119    pub fn get_performance_stats(&self) -> (f64, f64, f64) {
120        self.stagnation_monitor.get_performance_stats()
121    }
122
123    pub fn get_population(&self) -> OMatrix<T, N, D> {
124        get_population(&self.swarms)
125    }
126}
127
128impl<T, N, D> OptimizationAlgorithm<T, N, D> for MSPO<T, N, D>
129where
130    T: FloatNum + Send + Sync,
131    N: Dim + Send + Sync,
132    D: Dim + Send + Sync,
133    OVector<T, D>: Send + Sync,
134    OMatrix<T, N, D>: Send + Sync,
135    DefaultAllocator: Allocator<D>
136        + Allocator<N, D>
137        + Allocator<N>
138        + Allocator<U1, D>
139        + Allocator<D, D>
140        + Allocator<U1>,
141{
142    fn step(&mut self) {
143        let results: Vec<_> = self
144            .swarms
145            .par_iter_mut()
146            .map(|swarm| {
147                swarm.update(&self.opt_prob);
148                (
149                    swarm.global_best_position.clone(),
150                    swarm.global_best_fitness,
151                )
152            })
153            .collect();
154
155        for (pos, fitness) in results {
156            if fitness > self.st.best_f && self.opt_prob.is_feasible(&pos) {
157                self.st.best_f = fitness;
158                self.st.best_x = pos;
159            }
160        }
161
162        self.stagnation_monitor.check_stagnation(self.st.best_f);
163
164        let exchange_interval = if self.stagnation_monitor.stagnation_counter() > 10 {
165            self.conf.exchange_interval / 2 // More frequent exchange when stagnated
166        } else {
167            self.conf.exchange_interval
168        };
169
170        if self.st.iter % exchange_interval == 0 {
171            self.information_exchange.exchange_information(
172                &mut self.swarms,
173                self.stagnation_monitor.stagnation_counter(),
174            );
175        }
176
177        update_population_state(&mut self.st, &self.swarms, &self.opt_prob);
178        self.st.iter += 1;
179    }
180
181    fn state(&self) -> &State<T, N, D> {
182        &self.st
183    }
184}