1use ndarray::Array1;
9use rand::rngs::StdRng;
10use rand::{Rng, SeedableRng};
11
12use crate::error::{DEError, Result};
13
/// Selects how the splitting front is truncated during survivor selection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NsgaVariant {
    /// Keep the members of the splitting front with the largest crowding distance.
    Nsga2,
    /// Keep members of the splitting front chosen by reference-direction niching.
    Nsga3,
}
22
/// Settings accepted by [`nsga`], [`nsga2`] and [`nsga3`].
#[derive(Debug, Clone)]
pub struct NsgaConfig {
    /// `(lower, upper)` pair for every decision variable. Must be non-empty
    /// and every pair must satisfy `lower <= upper`.
    pub bounds: Vec<(f64, f64)>,
    /// Optional starting point. Must have one entry per bound; it is clamped
    /// into the bounds and evaluated as the first member of the population.
    pub x0: Option<Array1<f64>>,
    /// Number of individuals per generation. `0` selects
    /// `max(10 * bounds.len(), 40)`; any value is raised to at least 4.
    pub population_size: usize,
    /// Evaluation budget. Offspring generation stops once this many
    /// evaluations were made; the initial population is always evaluated in
    /// full, even if that exceeds the budget.
    pub maxeval: usize,
    /// Probability that a parent pair is recombined with SBX crossover
    /// instead of being copied unchanged.
    pub crossover_prob: f64,
    /// Per-variable mutation probability. `None` selects `1 / bounds.len()`;
    /// the value is clamped to `[0, 1]`.
    pub mutation_prob: Option<f64>,
    /// Distribution index passed to SBX crossover; values below 1 are raised to 1.
    pub eta_c: f64,
    /// Distribution index passed to polynomial mutation; values below 1 are raised to 1.
    pub eta_m: f64,
    /// Survivor-selection scheme. Overwritten by [`nsga2`] and [`nsga3`].
    pub variant: NsgaVariant,
    /// Number of partitions used to build the NSGA-III reference directions.
    /// `0` selects a default that depends on the number of objectives.
    pub reference_partitions: usize,
    /// Seed for the random generator. `None` seeds it from `rand::rng()`.
    pub seed: Option<u64>,
}
51
52impl Default for NsgaConfig {
53 fn default() -> Self {
54 Self {
55 bounds: Vec::new(),
56 x0: None,
57 population_size: 0,
58 maxeval: 10_000,
59 crossover_prob: 0.9,
60 mutation_prob: None,
61 eta_c: 15.0,
62 eta_m: 20.0,
63 variant: NsgaVariant::Nsga2,
64 reference_partitions: 0,
65 seed: None,
66 }
67 }
68}
69
/// One evaluated point as reported to the caller.
#[derive(Debug, Clone)]
pub struct ParetoSolution {
    /// Decision vector.
    pub x: Array1<f64>,
    /// Objective values returned for `x`; non-finite values are stored as
    /// `f64::INFINITY`.
    pub objectives: Vec<f64>,
    /// Index of the non-dominated front this point belongs to; `0` is the
    /// non-dominated set.
    pub rank: usize,
    /// Crowding distance within its front; infinite for boundary points and
    /// for fronts with at most two members.
    pub crowding_distance: f64,
}
82
/// Result of an optimisation run.
#[derive(Debug, Clone)]
pub struct NsgaReport {
    /// Members of the final population with rank `0`.
    pub pareto_front: Vec<ParetoSolution>,
    /// Entire final population, ordered by rank, then by descending crowding
    /// distance, then by the sum of the objectives.
    pub population: Vec<ParetoSolution>,
    /// Number of objective-function evaluations performed.
    pub nfev: usize,
    /// Number of completed generations.
    pub nit: usize,
    /// `true` when at least one generation was completed.
    pub success: bool,
    /// Human-readable reason the run stopped.
    pub message: String,
}
99
/// Internal working record for one evaluated member of the population.
#[derive(Clone)]
struct Individual {
    /// Decision vector.
    x: Array1<f64>,
    /// Sanitised objective values for `x`.
    objectives: Vec<f64>,
    /// Non-dominated front index; `usize::MAX` until the first sort assigns it.
    rank: usize,
    /// Crowding distance within the front; `0.0` until it is first computed.
    crowding_distance: f64,
}
107
108pub fn nsga2<F>(f: &F, mut config: NsgaConfig) -> Result<NsgaReport>
110where
111 F: Fn(&Array1<f64>) -> Vec<f64> + Sync,
112{
113 config.variant = NsgaVariant::Nsga2;
114 nsga(f, config)
115}
116
117pub fn nsga3<F>(f: &F, mut config: NsgaConfig) -> Result<NsgaReport>
119where
120 F: Fn(&Array1<f64>) -> Vec<f64> + Sync,
121{
122 config.variant = NsgaVariant::Nsga3;
123 nsga(f, config)
124}
125
126pub fn nsga<F>(f: &F, config: NsgaConfig) -> Result<NsgaReport>
128where
129 F: Fn(&Array1<f64>) -> Vec<f64> + Sync,
130{
131 validate_config(&config)?;
132 let n = config.bounds.len();
133 let pop_size = if config.population_size == 0 {
134 (10 * n).max(40)
135 } else {
136 config.population_size
137 }
138 .max(4);
139 let mutation_prob = config
140 .mutation_prob
141 .unwrap_or(1.0 / n as f64)
142 .clamp(0.0, 1.0);
143
144 let mut rng: StdRng = match config.seed {
145 Some(s) => StdRng::seed_from_u64(s),
146 None => {
147 let mut thread_rng = rand::rng();
148 StdRng::from_rng(&mut thread_rng)
149 }
150 };
151
152 let mut population = initial_population(f, &config, pop_size, &mut rng);
153 let mut nfev = population.len();
154 assign_rank_and_crowding(&mut population);
155
156 let mut nit = 0usize;
157 while nfev < config.maxeval {
158 let mut offspring: Vec<Individual> = Vec::with_capacity(pop_size);
159 while offspring.len() < pop_size && nfev < config.maxeval {
160 let p1 = tournament_select(&population, &mut rng);
161 let p2 = tournament_select(&population, &mut rng);
162 let (mut c1, mut c2) = if rng.random::<f64>() < config.crossover_prob {
163 sbx_crossover(
164 &p1.x,
165 &p2.x,
166 &config.bounds,
167 config.eta_c.max(1.0),
168 &mut rng,
169 )
170 } else {
171 (p1.x.clone(), p2.x.clone())
172 };
173 polynomial_mutation(
174 &mut c1,
175 &config.bounds,
176 mutation_prob,
177 config.eta_m.max(1.0),
178 &mut rng,
179 );
180 polynomial_mutation(
181 &mut c2,
182 &config.bounds,
183 mutation_prob,
184 config.eta_m.max(1.0),
185 &mut rng,
186 );
187
188 offspring.push(evaluate(f, c1));
189 nfev += 1;
190 if offspring.len() < pop_size && nfev < config.maxeval {
191 offspring.push(evaluate(f, c2));
192 nfev += 1;
193 }
194 }
195
196 if offspring.is_empty() {
197 break;
198 }
199 population.extend(offspring);
200 population = environmental_selection(population, pop_size, &config);
201 nit += 1;
202 }
203
204 assign_rank_and_crowding(&mut population);
205 let mut final_solutions = to_solutions(&population);
206 final_solutions.sort_by(compare_solutions);
207 let pareto_front = final_solutions
208 .iter()
209 .filter(|s| s.rank == 0)
210 .cloned()
211 .collect::<Vec<_>>();
212
213 Ok(NsgaReport {
214 pareto_front,
215 population: final_solutions,
216 nfev,
217 nit,
218 success: nit > 0,
219 message: if nit > 0 {
220 String::from("evaluation budget reached")
221 } else {
222 String::from("initial population evaluated")
223 },
224 })
225}
226
227fn validate_config(config: &NsgaConfig) -> Result<()> {
228 let n = config.bounds.len();
229 if n == 0 {
230 return Err(DEError::BoundsMismatch {
231 lower_len: 0,
232 upper_len: 0,
233 });
234 }
235 for (i, (lo, hi)) in config.bounds.iter().enumerate() {
236 if lo > hi {
237 return Err(DEError::InvalidBounds {
238 index: i,
239 lower: *lo,
240 upper: *hi,
241 });
242 }
243 }
244 if let Some(ref x0) = config.x0
245 && x0.len() != n
246 {
247 return Err(DEError::X0DimensionMismatch {
248 expected: n,
249 got: x0.len(),
250 });
251 }
252 Ok(())
253}
254
255fn initial_population<F>(
256 f: &F,
257 config: &NsgaConfig,
258 pop_size: usize,
259 rng: &mut StdRng,
260) -> Vec<Individual>
261where
262 F: Fn(&Array1<f64>) -> Vec<f64> + Sync,
263{
264 let mut population = Vec::with_capacity(pop_size);
265 if let Some(ref x0) = config.x0 {
266 let x = clamp_to_bounds(x0.clone(), &config.bounds);
267 population.push(evaluate(f, x));
268 }
269 while population.len() < pop_size {
270 let x = random_in_bounds(&config.bounds, rng);
271 population.push(evaluate(f, x));
272 }
273 population
274}
275
276fn evaluate<F>(f: &F, x: Array1<f64>) -> Individual
277where
278 F: Fn(&Array1<f64>) -> Vec<f64> + Sync,
279{
280 let objectives = f(&x)
281 .into_iter()
282 .map(|v| if v.is_finite() { v } else { f64::INFINITY })
283 .collect();
284 Individual {
285 x,
286 objectives,
287 rank: usize::MAX,
288 crowding_distance: 0.0,
289 }
290}
291
292fn random_in_bounds(bounds: &[(f64, f64)], rng: &mut StdRng) -> Array1<f64> {
293 Array1::from(
294 bounds
295 .iter()
296 .map(|(lo, hi)| {
297 if hi > lo {
298 rng.random_range(*lo..=*hi)
299 } else {
300 *lo
301 }
302 })
303 .collect::<Vec<_>>(),
304 )
305}
306
307fn clamp_to_bounds(mut x: Array1<f64>, bounds: &[(f64, f64)]) -> Array1<f64> {
308 for (xi, (lo, hi)) in x.iter_mut().zip(bounds.iter()) {
309 *xi = xi.clamp(*lo, *hi);
310 }
311 x
312}
313
314fn tournament_select<'a>(population: &'a [Individual], rng: &mut StdRng) -> &'a Individual {
315 let a = rng.random_range(0..population.len());
316 let b = rng.random_range(0..population.len());
317 let ia = &population[a];
318 let ib = &population[b];
319 if ia.rank < ib.rank || (ia.rank == ib.rank && ia.crowding_distance > ib.crowding_distance) {
320 ia
321 } else {
322 ib
323 }
324}
325
326fn sbx_crossover(
327 p1: &Array1<f64>,
328 p2: &Array1<f64>,
329 bounds: &[(f64, f64)],
330 eta_c: f64,
331 rng: &mut StdRng,
332) -> (Array1<f64>, Array1<f64>) {
333 let mut c1 = p1.clone();
334 let mut c2 = p2.clone();
335 for i in 0..p1.len() {
336 if rng.random::<f64>() > 0.5 || (p1[i] - p2[i]).abs() <= 1e-14 {
337 continue;
338 }
339 let u = rng
340 .random::<f64>()
341 .clamp(f64::MIN_POSITIVE, 1.0 - f64::EPSILON);
342 let beta = if u <= 0.5 {
343 (2.0 * u).powf(1.0 / (eta_c + 1.0))
344 } else {
345 (1.0 / (2.0 * (1.0 - u))).powf(1.0 / (eta_c + 1.0))
346 };
347 c1[i] = 0.5 * ((1.0 + beta) * p1[i] + (1.0 - beta) * p2[i]);
348 c2[i] = 0.5 * ((1.0 - beta) * p1[i] + (1.0 + beta) * p2[i]);
349 c1[i] = c1[i].clamp(bounds[i].0, bounds[i].1);
350 c2[i] = c2[i].clamp(bounds[i].0, bounds[i].1);
351 }
352 (c1, c2)
353}
354
355fn polynomial_mutation(
356 x: &mut Array1<f64>,
357 bounds: &[(f64, f64)],
358 mutation_prob: f64,
359 eta_m: f64,
360 rng: &mut StdRng,
361) {
362 for i in 0..x.len() {
363 if rng.random::<f64>() > mutation_prob {
364 continue;
365 }
366 let (lo, hi) = bounds[i];
367 let span = hi - lo;
368 if span <= 0.0 {
369 x[i] = lo;
370 continue;
371 }
372 let y = x[i].clamp(lo, hi);
373 let delta1 = (y - lo) / span;
374 let delta2 = (hi - y) / span;
375 let rnd = rng.random::<f64>();
376 let mut_pow = 1.0 / (eta_m + 1.0);
377 let deltaq = if rnd <= 0.5 {
378 let xy = 1.0 - delta1;
379 let val = 2.0 * rnd + (1.0 - 2.0 * rnd) * xy.powf(eta_m + 1.0);
380 val.powf(mut_pow) - 1.0
381 } else {
382 let xy = 1.0 - delta2;
383 let val = 2.0 * (1.0 - rnd) + 2.0 * (rnd - 0.5) * xy.powf(eta_m + 1.0);
384 1.0 - val.powf(mut_pow)
385 };
386 x[i] = (y + deltaq * span).clamp(lo, hi);
387 }
388}
389
390fn environmental_selection(
391 mut combined: Vec<Individual>,
392 pop_size: usize,
393 config: &NsgaConfig,
394) -> Vec<Individual> {
395 let fronts = assign_rank_and_crowding(&mut combined);
396 let mut next = Vec::with_capacity(pop_size);
397
398 for front in fronts {
399 if next.len() + front.len() <= pop_size {
400 next.extend(front.into_iter().map(|idx| combined[idx].clone()));
401 continue;
402 }
403
404 let remaining = pop_size - next.len();
405 let splitting_front = front
406 .into_iter()
407 .map(|idx| combined[idx].clone())
408 .collect::<Vec<_>>();
409 match config.variant {
410 NsgaVariant::Nsga2 => {
411 let mut partial = splitting_front;
412 let partial_indices = (0..partial.len()).collect::<Vec<_>>();
413 assign_crowding_distance_for_indices(&mut partial, &partial_indices);
414 partial.sort_by(compare_individuals);
415 next.extend(partial.into_iter().take(remaining));
416 }
417 NsgaVariant::Nsga3 => {
418 let chosen = nsga3_niching(&next, splitting_front, remaining, config);
419 next.extend(chosen);
420 }
421 }
422 break;
423 }
424
425 assign_rank_and_crowding(&mut next);
426 next
427}
428
429fn assign_rank_and_crowding(population: &mut [Individual]) -> Vec<Vec<usize>> {
430 let fronts = non_dominated_sort(population);
431 for front in &fronts {
432 assign_crowding_distance_for_indices(population, front);
433 }
434 fronts
435}
436
437fn non_dominated_sort(population: &mut [Individual]) -> Vec<Vec<usize>> {
438 let n = population.len();
439 let mut dominates_set = vec![Vec::<usize>::new(); n];
440 let mut domination_count = vec![0usize; n];
441 let mut fronts: Vec<Vec<usize>> = Vec::new();
442 let mut first = Vec::new();
443
444 for p in 0..n {
445 for q in 0..n {
446 if p == q {
447 continue;
448 }
449 if dominates(&population[p].objectives, &population[q].objectives) {
450 dominates_set[p].push(q);
451 } else if dominates(&population[q].objectives, &population[p].objectives) {
452 domination_count[p] += 1;
453 }
454 }
455 if domination_count[p] == 0 {
456 population[p].rank = 0;
457 first.push(p);
458 }
459 }
460 fronts.push(first);
461
462 let mut i = 0usize;
463 while i < fronts.len() {
464 let mut next = Vec::new();
465 for &p in &fronts[i] {
466 for &q in &dominates_set[p] {
467 domination_count[q] = domination_count[q].saturating_sub(1);
468 if domination_count[q] == 0 {
469 population[q].rank = i + 1;
470 next.push(q);
471 }
472 }
473 }
474 if !next.is_empty() {
475 fronts.push(next);
476 }
477 i += 1;
478 }
479
480 fronts
481}
482
/// Pareto dominance for minimisation: `a` dominates `b` when it is nowhere
/// greater and somewhere strictly smaller.
///
/// Vectors of different length, or empty vectors, never dominate each other.
fn dominates(a: &[f64], b: &[f64]) -> bool {
    if a.is_empty() || a.len() != b.len() {
        return false;
    }
    // Abort with `None` on the first component where `a` is worse; otherwise
    // carry along whether a strict improvement has been seen so far.
    a.iter()
        .zip(b)
        .try_fold(false, |strictly_better, (&ai, &bi)| {
            if ai > bi {
                None
            } else {
                Some(strictly_better || ai < bi)
            }
        })
        .unwrap_or(false)
}
498
499fn assign_crowding_distance_for_indices(population: &mut [Individual], front: &[usize]) {
500 if front.is_empty() {
501 return;
502 }
503 for &idx in front {
504 population[idx].crowding_distance = 0.0;
505 }
506 if front.len() <= 2 {
507 for &idx in front {
508 population[idx].crowding_distance = f64::INFINITY;
509 }
510 return;
511 }
512 let m = population[front[0]].objectives.len();
513 for obj in 0..m {
514 let mut sorted = front.to_vec();
515 sorted.sort_by(|&a, &b| {
516 population[a].objectives[obj].total_cmp(&population[b].objectives[obj])
517 });
518 let first = sorted[0];
519 let last = sorted[sorted.len() - 1];
520 population[first].crowding_distance = f64::INFINITY;
521 population[last].crowding_distance = f64::INFINITY;
522 let min_obj = population[first].objectives[obj];
523 let max_obj = population[last].objectives[obj];
524 let span = max_obj - min_obj;
525 if span <= 0.0 || !span.is_finite() {
526 continue;
527 }
528 for window in sorted.windows(3) {
529 let mid = window[1];
530 if population[mid].crowding_distance.is_infinite() {
531 continue;
532 }
533 let prev = population[window[0]].objectives[obj];
534 let next = population[window[2]].objectives[obj];
535 population[mid].crowding_distance += (next - prev) / span;
536 }
537 }
538}
539
540fn nsga3_niching(
541 selected: &[Individual],
542 splitting_front: Vec<Individual>,
543 slots: usize,
544 config: &NsgaConfig,
545) -> Vec<Individual> {
546 if slots == 0 || splitting_front.is_empty() {
547 return Vec::new();
548 }
549 let m = splitting_front[0].objectives.len();
550 if m <= 1 {
551 let mut sorted = splitting_front;
552 sorted.sort_by(compare_individuals);
553 return sorted.into_iter().take(slots).collect();
554 }
555 let partitions = if config.reference_partitions > 0 {
556 config.reference_partitions
557 } else {
558 default_reference_partitions(m)
559 };
560 let references = reference_directions(m, partitions);
561 if references.is_empty() {
562 let mut sorted = splitting_front;
563 sorted.sort_by(compare_individuals);
564 return sorted.into_iter().take(slots).collect();
565 }
566
567 let mut all = selected.to_vec();
568 all.extend(splitting_front.iter().cloned());
569 let (ideal, nadir) = objective_ranges(&all, m);
570 let mut niche_counts = vec![0usize; references.len()];
571 for ind in selected {
572 let norm = normalise_objectives(&ind.objectives, &ideal, &nadir);
573 let (niche, _) = nearest_reference(&norm, &references);
574 niche_counts[niche] += 1;
575 }
576
577 let mut candidates: Vec<(Individual, usize, f64)> = splitting_front
578 .into_iter()
579 .map(|ind| {
580 let norm = normalise_objectives(&ind.objectives, &ideal, &nadir);
581 let (niche, dist) = nearest_reference(&norm, &references);
582 (ind, niche, dist)
583 })
584 .collect();
585
586 let mut chosen = Vec::with_capacity(slots);
587 while chosen.len() < slots && !candidates.is_empty() {
588 let Some(target_niche) = references
589 .iter()
590 .enumerate()
591 .filter(|(idx, _)| candidates.iter().any(|(_, niche, _)| niche == idx))
592 .min_by_key(|(idx, _)| niche_counts[*idx])
593 .map(|(idx, _)| idx)
594 else {
595 break;
596 };
597
598 let best_idx = candidates
599 .iter()
600 .enumerate()
601 .filter(|(_, (_, niche, _))| *niche == target_niche)
602 .min_by(|(_, a), (_, b)| a.2.total_cmp(&b.2))
603 .map(|(idx, _)| idx)
604 .unwrap_or(0);
605 let (ind, niche, _) = candidates.swap_remove(best_idx);
606 niche_counts[niche] += 1;
607 chosen.push(ind);
608 }
609 chosen
610}
611
612fn objective_ranges(population: &[Individual], m: usize) -> (Vec<f64>, Vec<f64>) {
613 let mut ideal = vec![f64::INFINITY; m];
614 let mut nadir = vec![f64::NEG_INFINITY; m];
615 for ind in population {
616 for j in 0..m {
617 ideal[j] = ideal[j].min(ind.objectives[j]);
618 nadir[j] = nadir[j].max(ind.objectives[j]);
619 }
620 }
621 (ideal, nadir)
622}
623
/// Rescales `objectives` so that `ideal` maps to 0 and `nadir` to 1.
///
/// Results below zero are raised to zero. An objective whose range
/// `nadir - ideal` is not a positive finite number is mapped to `0.0`.
fn normalise_objectives(objectives: &[f64], ideal: &[f64], nadir: &[f64]) -> Vec<f64> {
    let mut scaled = Vec::with_capacity(objectives.len());
    for (i, &value) in objectives.iter().enumerate() {
        let span = nadir[i] - ideal[i];
        let usable = span.is_finite() && span > 0.0;
        let component = if usable {
            ((value - ideal[i]) / span).max(0.0)
        } else {
            0.0
        };
        scaled.push(component);
    }
    scaled
}
638
639fn nearest_reference(point: &[f64], references: &[Vec<f64>]) -> (usize, f64) {
640 references
641 .iter()
642 .enumerate()
643 .map(|(idx, reference)| (idx, perpendicular_distance(point, reference)))
644 .min_by(|a, b| a.1.total_cmp(&b.1))
645 .unwrap_or((0, 0.0))
646}
647
/// Distance from `point` to the line through the origin along `reference`.
///
/// When `reference` has zero length the distance from `point` to the origin is
/// returned instead.
fn perpendicular_distance(point: &[f64], reference: &[f64]) -> f64 {
    let squared_length = |vector: &[f64]| vector.iter().map(|c| c * c).sum::<f64>();

    let ref_norm_sq = squared_length(reference);
    if ref_norm_sq <= 0.0 {
        return squared_length(point).sqrt();
    }

    // Length of the projection of `point` onto `reference`, in units of
    // `reference`.
    let dot: f64 = point.iter().zip(reference).map(|(p, r)| p * r).sum();
    let scale = dot / ref_norm_sq;

    let residual_sq: f64 = point
        .iter()
        .zip(reference)
        .map(|(p, r)| {
            let offset = p - scale * r;
            offset * offset
        })
        .sum();
    residual_sq.sqrt()
}
669
/// Default number of partitions for the reference directions, given `m`
/// objectives: 1 for fewer than two, then 24, 12 and 6 for two, three and
/// four objectives, and 3 for anything larger.
fn default_reference_partitions(m: usize) -> usize {
    // Indexed by the number of objectives.
    const BY_OBJECTIVE_COUNT: [usize; 5] = [1, 1, 24, 12, 6];
    BY_OBJECTIVE_COUNT.get(m).copied().unwrap_or(3)
}
679
680fn reference_directions(m: usize, partitions: usize) -> Vec<Vec<f64>> {
681 if m == 0 {
682 return Vec::new();
683 }
684 if m == 1 {
685 return vec![vec![1.0]];
686 }
687 let h = partitions.max(1);
688 let mut out = Vec::new();
689 let mut current = vec![0usize; m];
690 reference_directions_rec(m, h, h, 0, &mut current, &mut out);
691 out
692}
693
/// Recursive worker for `reference_directions`.
///
/// Distributes the `left` remaining units over the components from `depth`
/// onwards. When the last component is reached it receives everything that is
/// left, and the finished vector, divided by `h`, is appended to `out`.
///
/// `m` must be at least 1 and `current` must hold `m` entries.
fn reference_directions_rec(
    m: usize,
    h: usize,
    left: usize,
    depth: usize,
    current: &mut [usize],
    out: &mut Vec<Vec<f64>>,
) {
    let on_last_component = depth == m - 1;
    if !on_last_component {
        (0..=left).for_each(|value| {
            current[depth] = value;
            reference_directions_rec(m, h, left - value, depth + 1, current, out);
        });
        return;
    }

    current[depth] = left;
    let direction = current.iter().map(|&units| units as f64 / h as f64).collect();
    out.push(direction);
}
712
713fn to_solutions(population: &[Individual]) -> Vec<ParetoSolution> {
714 population
715 .iter()
716 .map(|ind| ParetoSolution {
717 x: ind.x.clone(),
718 objectives: ind.objectives.clone(),
719 rank: ind.rank,
720 crowding_distance: ind.crowding_distance,
721 })
722 .collect()
723}
724
725fn compare_individuals(a: &Individual, b: &Individual) -> std::cmp::Ordering {
726 a.rank
727 .cmp(&b.rank)
728 .then_with(|| b.crowding_distance.total_cmp(&a.crowding_distance))
729 .then_with(|| sum_objectives(&a.objectives).total_cmp(&sum_objectives(&b.objectives)))
730}
731
732fn compare_solutions(a: &ParetoSolution, b: &ParetoSolution) -> std::cmp::Ordering {
733 a.rank
734 .cmp(&b.rank)
735 .then_with(|| b.crowding_distance.total_cmp(&a.crowding_distance))
736 .then_with(|| sum_objectives(&a.objectives).total_cmp(&sum_objectives(&b.objectives)))
737}
738
/// Adds up all objective values; used as the final tie-breaker when ordering.
fn sum_objectives(objectives: &[f64]) -> f64 {
    let total: f64 = objectives.iter().copied().sum();
    total
}
742
#[cfg(test)]
mod tests {
    use super::*;

    /// Two conflicting objectives on one variable, minimised at `x = 0` and
    /// `x = 2` respectively.
    fn two_objective_tradeoff(x: &Array1<f64>) -> Vec<f64> {
        vec![x[0] * x[0], (x[0] - 2.0).powi(2)]
    }

    /// A seeded NSGA-II run should return a populated front that reaches both
    /// single-objective optima.
    #[test]
    fn nsga2_finds_both_ends_of_simple_front() {
        let config = NsgaConfig {
            bounds: vec![(0.0, 2.0)],
            population_size: 64,
            maxeval: 1_024,
            seed: Some(11),
            ..Default::default()
        };
        let report = nsga2(&two_objective_tradeoff, config).expect("NSGA-II should run");
        let front = &report.pareto_front;

        assert!(front.len() > 20);
        assert!(
            front.iter().any(|s| s.objectives[0] < 1e-3),
            "front should include the f1 optimum"
        );
        assert!(
            front.iter().any(|s| s.objectives[1] < 1e-3),
            "front should include the f2 optimum"
        );
    }

    /// A seeded NSGA-III run on three objectives should keep solutions near
    /// both ends of the variable range.
    #[test]
    fn nsga3_handles_three_objective_front() {
        let f = |x: &Array1<f64>| vec![x[0].powi(2), (x[0] - 1.0).powi(2), (x[0] - 2.0).powi(2)];
        let config = NsgaConfig {
            bounds: vec![(0.0, 2.0)],
            population_size: 72,
            maxeval: 1_200,
            reference_partitions: 8,
            seed: Some(19),
            ..Default::default()
        };
        let report = nsga3(&f, config).expect("NSGA-III should run");
        let front = &report.pareto_front;

        assert!(front.len() > 10);
        assert!(front.iter().any(|s| s.x[0] < 0.1));
        assert!(front.iter().any(|s| s.x[0] > 1.9));
    }

    /// Dominance needs a strict improvement in at least one objective and no
    /// deterioration in any other.
    #[test]
    fn dominance_requires_strict_improvement() {
        let better_in_second = dominates(&[1.0, 2.0], &[1.0, 3.0]);
        let identical = dominates(&[1.0, 2.0], &[1.0, 2.0]);
        let trade_off = dominates(&[2.0, 1.0], &[1.0, 2.0]);

        assert!(better_in_second);
        assert!(!identical);
        assert!(!trade_off);
    }
}