// pso/optimizer/swarm.rs

1mod minima;
2mod particle;
3mod transient_model;
4
5/// Configure behavior of Particle Swarms
6pub mod swarm_config;
7
pub use minima::GlobalRecord;

use std::sync::Arc;

use rand::rngs::SmallRng;
use rand::SeedableRng;

use super::{Bounds, JobConfig, StopCondition};
use minima::{OwnedRecord, RecordBeatResult};
use particle::Particle;
use swarm_config::{GlobalBehavior, SwarmConfig, TransientBehavior};
use transient_model::TransientModel;

/// Number of iterations between global-record synchronizations for swarms that
/// do not configure their own rate (i.e. `GlobalBehavior::Solitary` swarms).
const DEFAULT_UPDATE_RATE: usize = 16;

21#[derive(Clone, Debug)]
22pub struct Swarm {
23    name: String,
24    particles: Vec<Particle>,
25    transient_model: TransientModel,
26    bound_state: BoundState,
27    stop_state: StopState,
28    global_updater: GlobalUpdater,
29    synergic: bool,
30    update_console: Option<usize>,
31    num_variables: usize,
32}
33
34#[allow(dead_code)]
35impl Swarm {
36    pub fn new(name: String, sc: SwarmConfig, jc: &JobConfig) -> Self {
37        let particles = Self::initialize_particles(&sc, &jc);
38
39        let motion_coefficients = match sc.global_behavior {
40            GlobalBehavior::Synergic(global_coeff, _) => {
41                vec![sc.local_motion, sc.tribal_motion, global_coeff]
42            }
43            GlobalBehavior::Solitary => vec![sc.local_motion, sc.tribal_motion],
44        };
45
46        let transient_model = TransientModel::new(
47            &sc.transient_behavior,
48            &jc.stop_condition,
49            jc.num_variables,
50            motion_coefficients,
51            sc.momentum,
52        );
53
54        let bound_state = BoundState::new(&sc, &jc);
55        let stop_state = StopState::new(&jc.stop_condition);
56        let (global_updater, synergic) = GlobalUpdater::new(&sc.global_behavior);
57        let update_console = jc.update_console;
58        let num_variables = jc.num_variables;
59
60        Self {
61            name,
62            particles,
63            transient_model,
64            bound_state,
65            stop_state,
66            global_updater,
67            synergic,
68            update_console,
69            num_variables,
70        }
71    }
72
73    pub fn minimize_fn_mut(
74        &mut self,
75        mut objective: impl FnMut(&[f64]) -> f64,
76        global_record: &mut GlobalRecord,
77    ) {
78        let mut tribal_record = OwnedRecord::new(self.num_variables);
79        let mut owned_global_record = OwnedRecord::new(self.num_variables);
80
81        let mut rng = SmallRng::from_entropy();
82
83        while let Some(iteration) = self.stop_state.next_iteration(&owned_global_record) {
84            let m = self.transient_model.momentum(iteration);
85            let motion_coeffs = self.transient_model.motion_coeffs(iteration);
86
87            let mut record_beat = false;
88            let mut min_cost = tribal_record.cost;
89            let mut min_particle_index = 0;
90
91            for (i, p) in self.particles.iter_mut().enumerate() {
92                let rand_vecs = self.transient_model.rand_vecs(iteration, &mut rng);
93
94                // test cost function at current location
95                if let Some(new_min_cost) = p.cost_fn_mut(&mut objective) {
96                    if new_min_cost < min_cost {
97                        record_beat = true;
98                        min_particle_index = i;
99                        min_cost = new_min_cost;
100                    }
101                }
102
103                // update position and velocity
104                if self.synergic {
105                    p.update_syn(
106                        m,
107                        rand_vecs,
108                        motion_coeffs,
109                        tribal_record.pos(),
110                        owned_global_record.pos(),
111                    );
112                } else {
113                    p.update_sol(m, rand_vecs, motion_coeffs, tribal_record.pos());
114                }
115
116                // enforce search space and velocity bounds
117                p.enforce_bounds(self.bound_state.values());
118            }
119
120            // bubble new records up the chain to tribal and global
121            if record_beat {
122                tribal_record = self.particles[min_particle_index].get_record();
123                if tribal_record < owned_global_record {
124                    owned_global_record = tribal_record.clone();
125                }
126            }
127
128            // update or download global record as necessary
129            self.global_updater
130                .update_global(iteration, &mut owned_global_record, global_record);
131
132            self.print_progress(iteration, &owned_global_record, &record_beat);
133        }
134
135        self.global_updater
136            .update_global(0, &mut owned_global_record, global_record)
137    }
138
139    pub fn minimize_fn<T>(&mut self, objective: Arc<T>, global_record: &mut GlobalRecord)
140    where
141        T: Fn(&[f64]) -> f64,
142    {
143        let mut tribal_record = OwnedRecord::new(self.num_variables);
144        let mut owned_global_record = OwnedRecord::new(self.num_variables);
145
146        let mut rng = SmallRng::from_entropy();
147
148        while let Some(iteration) = self.stop_state.next_iteration(&owned_global_record) {
149            let m = self.transient_model.momentum(iteration);
150            let motion_coeffs = self.transient_model.motion_coeffs(iteration);
151
152            let mut record_beat = false;
153            let mut min_cost = tribal_record.cost;
154            let mut min_particle_index = 0;
155
156            for (i, p) in self.particles.iter_mut().enumerate() {
157                let rand_vecs = self.transient_model.rand_vecs(iteration, &mut rng);
158
159                // test cost function at current location
160                if let Some(new_min_cost) = p.cost_fn(&objective) {
161                    if new_min_cost < min_cost {
162                        record_beat = true;
163                        min_particle_index = i;
164                        min_cost = new_min_cost;
165                    }
166                }
167
168                // update position and velocity
169                if self.synergic {
170                    p.update_syn(
171                        m,
172                        rand_vecs,
173                        motion_coeffs,
174                        tribal_record.pos(),
175                        owned_global_record.pos(),
176                    );
177                } else {
178                    p.update_sol(m, rand_vecs, motion_coeffs, tribal_record.pos());
179                }
180
181                // enforce search space and velocity bounds
182                p.enforce_bounds(self.bound_state.values());
183            }
184
185            // bubble new records up the chain to tribal and global
186            if record_beat {
187                tribal_record = self.particles[min_particle_index].get_record();
188                if tribal_record < owned_global_record {
189                    owned_global_record = tribal_record.clone();
190                }
191            }
192
193            // update or download global record as necessary
194            self.global_updater
195                .update_global(iteration, &mut owned_global_record, global_record);
196
197            self.print_progress(iteration, &owned_global_record, &record_beat);
198        }
199
200        self.global_updater
201            .update_global(0, &mut owned_global_record, global_record)
202    }
203
204    fn initialize_particles(sc: &SwarmConfig, jc: &JobConfig) -> Vec<Particle> {
205        let mut rng = rand::thread_rng();
206
207        // position and velocity vectors
208        let num_points = sc.num_particles * jc.num_variables;
209        let mut positions = Vec::with_capacity(num_points);
210        let mut velocities = Vec::with_capacity(num_points);
211
212        // fill vectors randomly
213        for _ in 0..num_points {
214            positions.extend_from_slice(&jc.variable_bounds.sample_bounds(&mut rng));
215            velocities.extend_from_slice(&jc.velocity_bounds.sample_bounds(&mut rng));
216        }
217
218        // partition random values into individual particles
219        positions
220            .chunks_exact(jc.num_variables)
221            .zip(velocities.chunks_exact(jc.num_variables))
222            .map(|(pos, vel)| Particle::new(pos, vel))
223            .collect()
224    }
225
226    fn print_progress(&self, iteration: usize, record: &OwnedRecord, record_beat: &bool) {
227        if let Some(console_update_rate) = self.update_console {
228            if iteration % console_update_rate == 0 {
229                println!(
230                    "{} {} \t iter {} \t {}",
231                    self.name, record_beat, iteration, record
232                );
233            }
234        }
235    }
236}
237
238#[derive(Clone, Debug)]
239struct GlobalUpdater {
240    update_rate: usize,
241    retry: bool,
242}
243
244impl GlobalUpdater {
245    pub fn new(global_behavior: &GlobalBehavior) -> (Self, bool) {
246        let (update_rate, synergic) = match global_behavior {
247            GlobalBehavior::Synergic(_, rate) => (*rate, true),
248            GlobalBehavior::Solitary => (DEFAULT_UPDATE_RATE, false),
249        };
250
251        (
252            Self {
253                update_rate,
254                retry: false,
255            },
256            synergic,
257        )
258    }
259
260    pub fn update_global(
261        &mut self,
262        iteration: usize,
263        local_record: &mut OwnedRecord,
264        global_record: &mut GlobalRecord,
265    ) {
266        if iteration % self.update_rate == 0 || self.retry {
267            self.retry = false;
268
269            match global_record.beats_global(&local_record) {
270                RecordBeatResult::Won => {
271                    if let RecordBeatResult::Retry = global_record.update_global(&local_record) {
272                        self.retry = true;
273                    }
274                }
275                RecordBeatResult::Lost(new_global_rec) => {
276                    *local_record = new_global_rec;
277                }
278                RecordBeatResult::Retry => {
279                    self.retry = true;
280                }
281                RecordBeatResult::Same => (),
282            }
283        }
284    }
285}
286
287#[derive(Clone, Debug)]
288struct BoundState {
289    pos_bounds: Vec<[f64; 2]>,
290    vel_bounds: Vec<[f64; 2]>,
291    wall_bounce: f64,
292}
293
294impl BoundState {
295    pub fn new(sc: &SwarmConfig, jc: &JobConfig) -> Self {
296        Self {
297            pos_bounds: jc.variable_bounds.bound_array(),
298            vel_bounds: jc.velocity_bounds.bound_array(),
299            wall_bounce: sc.wall_bounce_factor,
300        }
301    }
302
303    pub fn values(&self) -> (&[[f64; 2]], &[[f64; 2]], f64) {
304        (&self.pos_bounds, &self.vel_bounds, self.wall_bounce)
305    }
306}
307
308#[derive(Clone, Debug)]
309struct StopState {
310    iteration: usize,
311    stop_condition: StopCondition,
312}
313
314impl StopState {
315    pub fn new(stop_condition: &StopCondition) -> Self {
316        Self {
317            iteration: 0,
318            stop_condition: stop_condition.clone(),
319        }
320    }
321
322    pub fn next_iteration(&mut self, global_record: &OwnedRecord) -> Option<usize> {
323        match self.stop_condition {
324            StopCondition::Iterations(max_iter) => {
325                if self.iteration <= max_iter {
326                    self.iteration += 1;
327                    return Some(self.iteration);
328                }
329            }
330            StopCondition::Cost(min_cost) => {
331                if global_record.cost > min_cost {
332                    self.iteration += 1;
333                    return Some(self.iteration);
334                }
335            }
336            StopCondition::Both(max_iter, min_cost) => {
337                if self.iteration <= max_iter && global_record.cost > min_cost {
338                    self.iteration += 1;
339                    return Some(self.iteration);
340                }
341            }
342        }
343        None
344    }
345}