1mod minima;
2mod particle;
3mod transient_model;
4
5pub mod swarm_config;
7
8pub use minima::GlobalRecord;
9
10use super::{Bounds, JobConfig, StopCondition};
11use minima::{OwnedRecord, RecordBeatResult};
12use particle::Particle;
13use rand::rngs::SmallRng;
14use rand::SeedableRng;
15use std::sync::Arc;
16use swarm_config::{GlobalBehavior, SwarmConfig, TransientBehavior};
17use transient_model::TransientModel;
18
/// Iteration interval used by `GlobalUpdater` when the swarm is configured as
/// `GlobalBehavior::Solitary`, which carries no update rate of its own.
const DEFAULT_UPDATE_RATE: usize = 16;
20
21#[derive(Clone, Debug)]
22pub struct Swarm {
23 name: String,
24 particles: Vec<Particle>,
25 transient_model: TransientModel,
26 bound_state: BoundState,
27 stop_state: StopState,
28 global_updater: GlobalUpdater,
29 synergic: bool,
30 update_console: Option<usize>,
31 num_variables: usize,
32}
33
34#[allow(dead_code)]
35impl Swarm {
36 pub fn new(name: String, sc: SwarmConfig, jc: &JobConfig) -> Self {
37 let particles = Self::initialize_particles(&sc, &jc);
38
39 let motion_coefficients = match sc.global_behavior {
40 GlobalBehavior::Synergic(global_coeff, _) => {
41 vec![sc.local_motion, sc.tribal_motion, global_coeff]
42 }
43 GlobalBehavior::Solitary => vec![sc.local_motion, sc.tribal_motion],
44 };
45
46 let transient_model = TransientModel::new(
47 &sc.transient_behavior,
48 &jc.stop_condition,
49 jc.num_variables,
50 motion_coefficients,
51 sc.momentum,
52 );
53
54 let bound_state = BoundState::new(&sc, &jc);
55 let stop_state = StopState::new(&jc.stop_condition);
56 let (global_updater, synergic) = GlobalUpdater::new(&sc.global_behavior);
57 let update_console = jc.update_console;
58 let num_variables = jc.num_variables;
59
60 Self {
61 name,
62 particles,
63 transient_model,
64 bound_state,
65 stop_state,
66 global_updater,
67 synergic,
68 update_console,
69 num_variables,
70 }
71 }
72
73 pub fn minimize_fn_mut(
74 &mut self,
75 mut objective: impl FnMut(&[f64]) -> f64,
76 global_record: &mut GlobalRecord,
77 ) {
78 let mut tribal_record = OwnedRecord::new(self.num_variables);
79 let mut owned_global_record = OwnedRecord::new(self.num_variables);
80
81 let mut rng = SmallRng::from_entropy();
82
83 while let Some(iteration) = self.stop_state.next_iteration(&owned_global_record) {
84 let m = self.transient_model.momentum(iteration);
85 let motion_coeffs = self.transient_model.motion_coeffs(iteration);
86
87 let mut record_beat = false;
88 let mut min_cost = tribal_record.cost;
89 let mut min_particle_index = 0;
90
91 for (i, p) in self.particles.iter_mut().enumerate() {
92 let rand_vecs = self.transient_model.rand_vecs(iteration, &mut rng);
93
94 if let Some(new_min_cost) = p.cost_fn_mut(&mut objective) {
96 if new_min_cost < min_cost {
97 record_beat = true;
98 min_particle_index = i;
99 min_cost = new_min_cost;
100 }
101 }
102
103 if self.synergic {
105 p.update_syn(
106 m,
107 rand_vecs,
108 motion_coeffs,
109 tribal_record.pos(),
110 owned_global_record.pos(),
111 );
112 } else {
113 p.update_sol(m, rand_vecs, motion_coeffs, tribal_record.pos());
114 }
115
116 p.enforce_bounds(self.bound_state.values());
118 }
119
120 if record_beat {
122 tribal_record = self.particles[min_particle_index].get_record();
123 if tribal_record < owned_global_record {
124 owned_global_record = tribal_record.clone();
125 }
126 }
127
128 self.global_updater
130 .update_global(iteration, &mut owned_global_record, global_record);
131
132 self.print_progress(iteration, &owned_global_record, &record_beat);
133 }
134
135 self.global_updater
136 .update_global(0, &mut owned_global_record, global_record)
137 }
138
139 pub fn minimize_fn<T>(&mut self, objective: Arc<T>, global_record: &mut GlobalRecord)
140 where
141 T: Fn(&[f64]) -> f64,
142 {
143 let mut tribal_record = OwnedRecord::new(self.num_variables);
144 let mut owned_global_record = OwnedRecord::new(self.num_variables);
145
146 let mut rng = SmallRng::from_entropy();
147
148 while let Some(iteration) = self.stop_state.next_iteration(&owned_global_record) {
149 let m = self.transient_model.momentum(iteration);
150 let motion_coeffs = self.transient_model.motion_coeffs(iteration);
151
152 let mut record_beat = false;
153 let mut min_cost = tribal_record.cost;
154 let mut min_particle_index = 0;
155
156 for (i, p) in self.particles.iter_mut().enumerate() {
157 let rand_vecs = self.transient_model.rand_vecs(iteration, &mut rng);
158
159 if let Some(new_min_cost) = p.cost_fn(&objective) {
161 if new_min_cost < min_cost {
162 record_beat = true;
163 min_particle_index = i;
164 min_cost = new_min_cost;
165 }
166 }
167
168 if self.synergic {
170 p.update_syn(
171 m,
172 rand_vecs,
173 motion_coeffs,
174 tribal_record.pos(),
175 owned_global_record.pos(),
176 );
177 } else {
178 p.update_sol(m, rand_vecs, motion_coeffs, tribal_record.pos());
179 }
180
181 p.enforce_bounds(self.bound_state.values());
183 }
184
185 if record_beat {
187 tribal_record = self.particles[min_particle_index].get_record();
188 if tribal_record < owned_global_record {
189 owned_global_record = tribal_record.clone();
190 }
191 }
192
193 self.global_updater
195 .update_global(iteration, &mut owned_global_record, global_record);
196
197 self.print_progress(iteration, &owned_global_record, &record_beat);
198 }
199
200 self.global_updater
201 .update_global(0, &mut owned_global_record, global_record)
202 }
203
204 fn initialize_particles(sc: &SwarmConfig, jc: &JobConfig) -> Vec<Particle> {
205 let mut rng = rand::thread_rng();
206
207 let num_points = sc.num_particles * jc.num_variables;
209 let mut positions = Vec::with_capacity(num_points);
210 let mut velocities = Vec::with_capacity(num_points);
211
212 for _ in 0..num_points {
214 positions.extend_from_slice(&jc.variable_bounds.sample_bounds(&mut rng));
215 velocities.extend_from_slice(&jc.velocity_bounds.sample_bounds(&mut rng));
216 }
217
218 positions
220 .chunks_exact(jc.num_variables)
221 .zip(velocities.chunks_exact(jc.num_variables))
222 .map(|(pos, vel)| Particle::new(pos, vel))
223 .collect()
224 }
225
226 fn print_progress(&self, iteration: usize, record: &OwnedRecord, record_beat: &bool) {
227 if let Some(console_update_rate) = self.update_console {
228 if iteration % console_update_rate == 0 {
229 println!(
230 "{} {} \t iter {} \t {}",
231 self.name, record_beat, iteration, record
232 );
233 }
234 }
235 }
236}
237
238#[derive(Clone, Debug)]
239struct GlobalUpdater {
240 update_rate: usize,
241 retry: bool,
242}
243
244impl GlobalUpdater {
245 pub fn new(global_behavior: &GlobalBehavior) -> (Self, bool) {
246 let (update_rate, synergic) = match global_behavior {
247 GlobalBehavior::Synergic(_, rate) => (*rate, true),
248 GlobalBehavior::Solitary => (DEFAULT_UPDATE_RATE, false),
249 };
250
251 (
252 Self {
253 update_rate,
254 retry: false,
255 },
256 synergic,
257 )
258 }
259
260 pub fn update_global(
261 &mut self,
262 iteration: usize,
263 local_record: &mut OwnedRecord,
264 global_record: &mut GlobalRecord,
265 ) {
266 if iteration % self.update_rate == 0 || self.retry {
267 self.retry = false;
268
269 match global_record.beats_global(&local_record) {
270 RecordBeatResult::Won => {
271 if let RecordBeatResult::Retry = global_record.update_global(&local_record) {
272 self.retry = true;
273 }
274 }
275 RecordBeatResult::Lost(new_global_rec) => {
276 *local_record = new_global_rec;
277 }
278 RecordBeatResult::Retry => {
279 self.retry = true;
280 }
281 RecordBeatResult::Same => (),
282 }
283 }
284 }
285}
286
287#[derive(Clone, Debug)]
288struct BoundState {
289 pos_bounds: Vec<[f64; 2]>,
290 vel_bounds: Vec<[f64; 2]>,
291 wall_bounce: f64,
292}
293
294impl BoundState {
295 pub fn new(sc: &SwarmConfig, jc: &JobConfig) -> Self {
296 Self {
297 pos_bounds: jc.variable_bounds.bound_array(),
298 vel_bounds: jc.velocity_bounds.bound_array(),
299 wall_bounce: sc.wall_bounce_factor,
300 }
301 }
302
303 pub fn values(&self) -> (&[[f64; 2]], &[[f64; 2]], f64) {
304 (&self.pos_bounds, &self.vel_bounds, self.wall_bounce)
305 }
306}
307
308#[derive(Clone, Debug)]
309struct StopState {
310 iteration: usize,
311 stop_condition: StopCondition,
312}
313
314impl StopState {
315 pub fn new(stop_condition: &StopCondition) -> Self {
316 Self {
317 iteration: 0,
318 stop_condition: stop_condition.clone(),
319 }
320 }
321
322 pub fn next_iteration(&mut self, global_record: &OwnedRecord) -> Option<usize> {
323 match self.stop_condition {
324 StopCondition::Iterations(max_iter) => {
325 if self.iteration <= max_iter {
326 self.iteration += 1;
327 return Some(self.iteration);
328 }
329 }
330 StopCondition::Cost(min_cost) => {
331 if global_record.cost > min_cost {
332 self.iteration += 1;
333 return Some(self.iteration);
334 }
335 }
336 StopCondition::Both(max_iter, min_cost) => {
337 if self.iteration <= max_iter && global_record.cost > min_cost {
338 self.iteration += 1;
339 return Some(self.iteration);
340 }
341 }
342 }
343 None
344 }
345}