1use crate::model::*;
2use indicatif::{ProgressBar, ProgressStyle};
3use rand::{thread_rng, Rng};
4
5use std::error::Error;
6use std::fs::File;
7use std::io::Write;
8
9pub struct PSO {
13 chi: f64,
14 v_max: f64,
15 pub model: Model,
16 neighborhoods: Vec<Vec<usize>>,
17 velocities: Population,
18 pub neigh_population: Population,
19 pub best_f_values: Vec<f64>,
20 pub best_f_trajectory: Vec<f64>,
21 pub best_x_trajectory: Vec<Particle>,
22}
23
24impl PSO {
25 pub fn new(model: Model) -> PSO {
27 let phi = model.config.c1 + model.config.c2;
28 let phi_squared = phi.powf(2.0);
29 let tmp = phi_squared - (4.0 * phi);
30 let tmp = tmp.sqrt();
31 let chi = 2.0 / (2.0 - phi - tmp).abs();
32 let v_max = model.config.alpha * 5.0;
33 let neighborhoods = PSO::create_neighborhoods(&model);
34
35 let mut rng = thread_rng();
37 let mut velocities = vec![];
38 for _ in 0..model.config.population_size {
39 let mut tmp = vec![];
40 for _ in 0..model.flat_dim {
41 tmp.push(rng.gen_range(v_max * -1.0..v_max * 1.0));
42 }
43 velocities.push(tmp);
44 }
45
46 let best_f_values = model.population_f_scores.clone();
47 let neigh_population = model.population.clone();
48 let best_f_trajectory = vec![model.f_best];
49 let best_x_trajectory = vec![model.x_best.clone()];
50
51 PSO {
52 chi,
53 v_max,
54 model,
55 neighborhoods,
56 velocities,
57 best_f_values,
58 neigh_population,
59 best_f_trajectory,
60 best_x_trajectory,
61 }
62 }
63
64 pub fn run(&mut self, terminate: fn(f64) -> bool) -> usize {
70 let mut bar: Option<ProgressBar> = None;
71 if self.model.config.progress_bar {
72 bar = Some(ProgressBar::new(self.model.config.t_max as u64));
73 match bar {
74 Some(ref bar) => {
75 bar.set_style(ProgressStyle::default_bar().template(
76 "{msg} [{elapsed}] {bar:20.cyan/blue} {pos:>7}/{len:7} ETA: {eta}",
77 ));
78 }
79 None => {}
80 }
81 }
82 let mut k = 0;
83 let pop_size = self.model.config.population_size;
84 loop {
85 self.update_velocity_and_pos();
87
88 self.model.get_f_values();
90 self.update_best_positions();
91
92 self.model.population = self.model.population.clone();
93 k += pop_size;
94 match bar {
95 Some(ref bar) => {
96 bar.inc(pop_size as u64);
97 bar.set_message(format!("{:.6}", self.model.f_best));
98 }
99 None => {}
100 }
101 if k > self.model.config.t_max || terminate(self.model.f_best) {
102 break;
103 }
104 }
105 match bar {
106 Some(ref bar) => {
107 bar.finish_and_clear();
108 }
109 None => {}
110 }
111 k
112 }
113
114 fn update_velocity_and_pos(&mut self) {
116 let mut rng = thread_rng();
117
118 for i in 0..self.model.config.population_size {
119 let lbest = &self.neigh_population[self.local_best(i)];
120 for j in 0..self.model.flat_dim {
121 let r1 = rng.gen_range(-1.0..1.0);
122 let r2 = rng.gen_range(-1.0..1.0);
123 let cog = self.model.config.c1
124 * r1
125 * (self.neigh_population[i][j] - self.model.population[i][j]);
126
127 let soc = self.model.config.c2 * r2 * (lbest[j] - self.model.population[i][j]);
128 let v = self.chi * (self.velocities[i][j] + cog + soc);
129
130 self.velocities[i][j] = if v.abs() > self.v_max {
132 v.signum() * self.v_max
133 } else {
134 v
135 };
136
137 let x = self.model.population[i][j] + self.model.config.lr * self.velocities[i][j];
138
139 let bound_index =
140 j % self.model.config.dimensions[self.model.config.dimensions.len() - 1];
141 let (lower_bound, upper_bound) = self.model.config.bounds[bound_index];
142 if x > upper_bound {
144 self.model.population[i][j] = upper_bound;
145 } else if x < lower_bound {
146 self.model.population[i][j] = lower_bound;
147 } else {
148 self.model.population[i][j] = x;
149 }
150 if x.is_nan() {
151 panic!("A coefficient became NaN!");
152 }
153 }
154 }
155 }
156
157 fn update_best_positions(&mut self) {
159 for i in 0..self.best_f_values.len() {
160 let new = self.model.population_f_scores[i];
161 let old = self.best_f_values[i];
162
163 if new < old {
164 self.best_f_values[i] = new;
165 self.neigh_population[i] = self.model.population[i].clone();
166 }
167 }
168 self.best_f_trajectory.push(self.model.f_best);
169 self.best_x_trajectory.push(self.model.x_best.clone());
170 }
171
172 fn local_best(&self, i: usize) -> usize {
174 let best = PSO::argsort(&self.best_f_values);
175 for b in best {
176 if self.neighborhoods[i].iter().any(|&n| n == b) {
177 return b;
178 }
179 }
180 0
181 }
182
183 fn create_neighborhoods(model: &Model) -> Vec<Vec<usize>> {
185 let mut neighborhoods;
186 match model.config.neighborhood_type {
187 NeighborhoodType::Lbest => {
188 neighborhoods = vec![];
189 for i in 0..model.config.population_size {
190 let mut neighbor = vec![];
191 let first_neighbor = i as i32 - model.config.rho as i32;
192 let last_neighbor = i as i32 + model.config.rho as i32;
193
194 for neighbor_i in first_neighbor..last_neighbor {
195 neighbor.push(if neighbor_i < 0 {
196 (model.config.population_size as i32 - neighbor_i) as usize
197 } else {
198 neighbor_i as usize
199 });
200 }
201 neighborhoods.push(neighbor)
202 }
203 }
204 NeighborhoodType::Gbest => {
205 neighborhoods = vec![];
206 for _ in 0..model.config.population_size {
207 let mut tmp = vec![];
208 for j in 0..model.config.population_size {
209 tmp.push(j);
210 }
211 neighborhoods.push(tmp);
212 }
213 }
214 }
215 neighborhoods
216 }
217
218 fn argsort(v: &Vec<f64>) -> Vec<usize> {
220 let mut idx = (0..v.len()).collect::<Vec<_>>();
221 idx.sort_by(|&i, &j| v[i].partial_cmp(&v[j]).expect("NaN"));
222 idx
223 }
224
225 pub fn write_f_to_file(&self, filepath: &str) -> Result<(), Box<dyn Error>> {
227 let best_f_str: Vec<String> = self
228 .best_f_trajectory
229 .iter()
230 .map(|n| n.to_string())
231 .collect();
232
233 let mut file = File::create(filepath)?;
234 writeln!(file, "{}", best_f_str.join("\n"))?;
235
236 Ok(())
237 }
238
239 pub fn write_x_to_file(&self, filepath: &str) -> Result<(), Box<dyn Error>> {
243 let best_x_str: Vec<String> = self
244 .best_x_trajectory
245 .iter()
246 .map(|x| {
247 x.iter()
248 .map(|coef: &f64| coef.to_string())
249 .collect::<Vec<String>>()
250 .join(", ")
251 })
252 .collect();
253
254 let mut file = File::create(filepath)?;
255 writeln!(file, "{}", best_x_str.join("\n"))?;
256
257 Ok(())
258 }
259}