Skip to main content

graphmind_optimization/algorithms/
nsga2.rs

1use crate::common::{
2    MultiObjectiveIndividual, MultiObjectiveProblem, MultiObjectiveResult, SolverConfig,
3};
4use ndarray::Array1;
5use rand::prelude::*;
6
7pub struct NSGA2Solver {
8    pub config: SolverConfig,
9    pub mutation_rate: f64,
10}
11
12impl NSGA2Solver {
13    pub fn new(config: SolverConfig) -> Self {
14        Self {
15            config,
16            mutation_rate: 0.1,
17        }
18    }
19
20    pub fn solve<P: MultiObjectiveProblem>(&self, problem: &P) -> MultiObjectiveResult {
21        let mut rng = thread_rng();
22        let dim = problem.dim();
23        let (lower, upper) = problem.bounds();
24        let pop_size = self.config.population_size;
25
26        // 1. Initialize Population
27        let mut population: Vec<MultiObjectiveIndividual> = (0..pop_size)
28            .map(|_| {
29                let mut vars = Array1::zeros(dim);
30                for i in 0..dim {
31                    vars[i] = rng.gen_range(lower[i]..upper[i]);
32                }
33                let fitness = problem.objectives(&vars);
34                MultiObjectiveIndividual::new(vars, fitness)
35            })
36            .collect();
37
38        self.evaluate_population(&mut population);
39
40        let mut history = Vec::with_capacity(self.config.max_iterations);
41
42        for _iter in 0..self.config.max_iterations {
43            // 2. Create Offspring (Crossover + Mutation)
44            let mut offspring = Vec::with_capacity(pop_size);
45            while offspring.len() < pop_size {
46                let p1 = self.tournament_select(&population);
47                let p2 = self.tournament_select(&population);
48
49                let (mut c1_vars, mut c2_vars) = self.crossover(&p1.variables, &p2.variables);
50                self.mutate(&mut c1_vars, &lower, &upper);
51                self.mutate(&mut c2_vars, &lower, &upper);
52
53                offspring.push(MultiObjectiveIndividual::new(
54                    c1_vars.clone(),
55                    problem.objectives(&c1_vars),
56                ));
57                if offspring.len() < pop_size {
58                    offspring.push(MultiObjectiveIndividual::new(
59                        c2_vars.clone(),
60                        problem.objectives(&c2_vars),
61                    ));
62                }
63            }
64
65            // 3. Merge Population and Offspring (2N)
66            let mut combined = population;
67            combined.extend(offspring);
68
69            // 4. Non-dominated Sort + Crowding Distance
70            self.evaluate_population(&mut combined);
71
72            // 5. Select Best N
73            combined.sort_by(|a, b| {
74                if a.rank != b.rank {
75                    a.rank.cmp(&b.rank)
76                } else {
77                    // Larger crowding distance is better
78                    b.crowding_distance
79                        .partial_cmp(&a.crowding_distance)
80                        .unwrap()
81                }
82            });
83
84            combined.truncate(pop_size);
85            population = combined;
86
87            history.push(population[0].fitness[0]); // Track first objective of best-ranked
88        }
89
90        MultiObjectiveResult {
91            pareto_front: population.into_iter().filter(|ind| ind.rank == 0).collect(),
92            history,
93        }
94    }
95
96    fn evaluate_population(&self, population: &mut [MultiObjectiveIndividual]) {
97        // Fast Non-dominated Sort
98        self.non_dominated_sort(population);
99
100        // Crowding Distance per rank
101        let mut rank = 0;
102        loop {
103            let current_rank_indices: Vec<usize> = population
104                .iter()
105                .enumerate()
106                .filter(|(_, ind)| ind.rank == rank)
107                .map(|(i, _)| i)
108                .collect();
109
110            if current_rank_indices.is_empty() {
111                break;
112            }
113
114            self.calculate_crowding_distance(population, &current_rank_indices);
115            rank += 1;
116        }
117    }
118
119    fn non_dominated_sort(&self, population: &mut [MultiObjectiveIndividual]) {
120        let n = population.len();
121        let mut dominance_counts = vec![0; n];
122        let mut dominated_sets = vec![Vec::new(); n];
123        let mut fronts = vec![Vec::new()];
124
125        for i in 0..n {
126            for j in 0..n {
127                if i == j {
128                    continue;
129                }
130                if self.dominates(&population[i].fitness, &population[j].fitness) {
131                    dominated_sets[i].push(j);
132                } else if self.dominates(&population[j].fitness, &population[i].fitness) {
133                    dominance_counts[i] += 1;
134                }
135            }
136            if dominance_counts[i] == 0 {
137                population[i].rank = 0;
138                fronts[0].push(i);
139            }
140        }
141
142        let mut i = 0;
143        while !fronts[i].is_empty() {
144            let mut next_front = Vec::new();
145            for &p in &fronts[i] {
146                for &q in &dominated_sets[p] {
147                    dominance_counts[q] -= 1;
148                    if dominance_counts[q] == 0 {
149                        population[q].rank = i + 1;
150                        next_front.push(q);
151                    }
152                }
153            }
154            i += 1;
155            fronts.push(next_front);
156        }
157    }
158
159    fn dominates(&self, f1: &[f64], f2: &[f64]) -> bool {
160        let mut better = false;
161        for i in 0..f1.len() {
162            if f1[i] > f2[i] {
163                return false;
164            }
165            if f1[i] < f2[i] {
166                better = true;
167            }
168        }
169        better
170    }
171
172    fn calculate_crowding_distance(
173        &self,
174        population: &mut [MultiObjectiveIndividual],
175        indices: &[usize],
176    ) {
177        let num_objectives = population[0].fitness.len();
178        for &idx in indices {
179            population[idx].crowding_distance = 0.0;
180        }
181
182        for m in 0..num_objectives {
183            let mut sorted_indices = indices.to_vec();
184            sorted_indices.sort_by(|&a, &b| {
185                population[a].fitness[m]
186                    .partial_cmp(&population[b].fitness[m])
187                    .unwrap_or(std::cmp::Ordering::Equal)
188            });
189
190            let min_val = population[*sorted_indices.first().unwrap()].fitness[m];
191            let max_val = population[*sorted_indices.last().unwrap()].fitness[m];
192            let range = max_val - min_val;
193
194            population[*sorted_indices.first().unwrap()].crowding_distance = f64::INFINITY;
195            population[*sorted_indices.last().unwrap()].crowding_distance = f64::INFINITY;
196
197            if range > 1e-9 {
198                for i in 1..(sorted_indices.len() - 1) {
199                    let prev = population[sorted_indices[i - 1]].fitness[m];
200                    let next = population[sorted_indices[i + 1]].fitness[m];
201                    population[sorted_indices[i]].crowding_distance += (next - prev) / range;
202                }
203            }
204        }
205    }
206
207    fn tournament_select<'a>(
208        &self,
209        population: &'a [MultiObjectiveIndividual],
210    ) -> &'a MultiObjectiveIndividual {
211        let mut rng = thread_rng();
212        let i1 = rng.gen_range(0..population.len());
213        let i2 = rng.gen_range(0..population.len());
214
215        let p1 = &population[i1];
216        let p2 = &population[i2];
217
218        if p1.rank < p2.rank {
219            p1
220        } else if p2.rank < p1.rank {
221            p2
222        } else if p1.crowding_distance > p2.crowding_distance {
223            p1
224        } else {
225            p2
226        }
227    }
228
229    fn crossover(&self, p1: &Array1<f64>, p2: &Array1<f64>) -> (Array1<f64>, Array1<f64>) {
230        let mut rng = thread_rng();
231        let dim = p1.len();
232        let mut c1 = p1.clone();
233        let mut c2 = p2.clone();
234
235        // BLX-alpha crossover or similar for continuous
236        for i in 0..dim {
237            if rng.gen_bool(0.5) {
238                let alpha = 0.5;
239                let min = p1[i].min(p2[i]);
240                let max = p1[i].max(p2[i]);
241                let range = max - min;
242
243                let lower = min - alpha * range;
244                let upper = max + alpha * range;
245
246                if (upper - lower).abs() > 1e-9 {
247                    c1[i] = rng.gen_range(lower..upper);
248                    c2[i] = rng.gen_range(lower..upper);
249                }
250            }
251        }
252        (c1, c2)
253    }
254
255    fn mutate(&self, vars: &mut Array1<f64>, lower: &Array1<f64>, upper: &Array1<f64>) {
256        let mut rng = thread_rng();
257        for i in 0..vars.len() {
258            if rng.gen::<f64>() < self.mutation_rate {
259                let range = upper[i] - lower[i];
260                vars[i] =
261                    (vars[i] + (rng.gen::<f64>() - 0.5) * range * 0.1).clamp(lower[i], upper[i]);
262            }
263        }
264    }
265}