Skip to main content

scirs2_optimize/multiobjective/
nsga3.rs

1//! NSGA-III: Non-dominated Sorting Genetic Algorithm III
2//!
3//! Implements the many-objective evolutionary algorithm by Deb & Jain (2014).
4//! NSGA-III extends NSGA-II with structured reference points (Das-Dennis simplex
5//! lattice) instead of crowding distance for diversity preservation. It is
6//! particularly effective for problems with more than 3 objectives where
7//! NSGA-II's crowding distance degrades.
8//!
9//! # Algorithm outline
10//!
11//! 1. **Reference point generation**: Construct a set of structured reference
12//!    points on the unit hyperplane using the Das-Dennis method.
13//! 2. **Fast non-dominated sorting**: Same as NSGA-II — assign each individual
14//!    a Pareto-front rank.
15//! 3. **Reference-point-based niching**: For the critical (last-included) front,
16//!    use association + niche counting to maintain diversity with respect to
17//!    structured reference points.
18//! 4. **Elitist survivor selection**: Deterministically fill the next population
19//!    by adding complete fronts, then applying niching to the final partial front.
20//!
21//! # References
22//!
23//! - Deb, K., & Jain, H. (2014). An evolutionary many-objective optimization
24//!   algorithm using reference-point-based nondominated sorting approach, Part I:
25//!   Solving problems with box constraints. *IEEE TEC*, 18(4), 577–601.
26//! - Jain, H., & Deb, K. (2014). An evolutionary many-objective optimization
27//!   algorithm using reference-point based nondominated sorting approach, Part II:
28//!   Handling constraints and extending to an adaptive approach. *IEEE TEC*,
29//!   18(4), 602–622.
30
31use crate::error::{OptimizeError, OptimizeResult};
32use crate::multiobjective::indicators::{dominates, non_dominated_sort};
33use crate::multiobjective::nsga2::{Individual, Nsga2Config};
34use scirs2_core::random::rngs::StdRng;
35use scirs2_core::random::{Rng, SeedableRng};
36use scirs2_core::RngExt;
37
38// ─────────────────────────────────────────────────────────────────────────────
39// Public types
40// ─────────────────────────────────────────────────────────────────────────────
41
/// Configuration for the NSGA-III algorithm.
///
/// Most fields mirror [`Nsga2Config`]. The NSGA-III specific fields are
/// `n_divisions`, which controls the density of the primary Das-Dennis
/// reference point lattice, and `n_divisions_inner`, which optionally adds a
/// second (inner) lattice to form a two-layer reference point set.
#[derive(Debug, Clone)]
pub struct Nsga3Config {
    /// Requested population size.  The effective size is
    /// `max(population_size, number of reference points)`, rounded up to the
    /// next even number so offspring can be produced in pairs.  Default 100.
    pub population_size: usize,
    /// Number of generations.  Default 200.
    pub n_generations: usize,
    /// Number of divisions on each objective axis for the primary (outer)
    /// Das-Dennis reference point lattice.  `0` (the default) selects it
    /// automatically from the number of objectives: 12 for 2–3 objectives,
    /// 6 for 4–5, 4 for 6–8 and 3 for 9 or more.
    pub n_divisions: usize,
    /// Optional divisions for an inner lattice (two-layer reference points
    /// for many-objective problems).  If `Some(d)`, a lattice with `d`
    /// divisions, shrunk toward the simplex centroid, is generated and
    /// appended to the outer lattice.  Default `None`.
    pub n_divisions_inner: Option<usize>,
    /// Probability of applying simulated binary crossover to a pair of
    /// parents.  Default 0.9.
    pub crossover_rate: f64,
    /// Polynomial mutation probability per variable.  Values `<= 0.0` (the
    /// default is `0.0`) are replaced by `1 / n_vars` at runtime.
    pub mutation_rate: f64,
    /// SBX distribution index η_c.  Default 20.
    pub eta_c: f64,
    /// Polynomial mutation distribution index η_m.  Default 20.
    pub eta_m: f64,
    /// RNG seed for reproducibility.  Default 12345.
    pub seed: u64,
}

impl Default for Nsga3Config {
    fn default() -> Self {
        Self {
            population_size: 100,
            n_generations: 200,
            n_divisions: 0, // auto-select from the number of objectives
            n_divisions_inner: None,
            crossover_rate: 0.9,
            mutation_rate: 0.0, // resolved to 1 / n_vars at runtime
            eta_c: 20.0,
            eta_m: 20.0,
            seed: 12345,
        }
    }
}
91
92/// Result returned by [`nsga3`].
93#[derive(Debug)]
94pub struct Nsga3Result {
95    /// Individuals on the first (best) Pareto front after the final generation.
96    pub pareto_front: Vec<Individual>,
97    /// All fronts from the final generation (front 0 = Pareto optimal).
98    pub all_fronts: Vec<Vec<Individual>>,
99    /// Reference points used during the run.
100    pub reference_points: Vec<Vec<f64>>,
101    /// Number of generations executed.
102    pub n_generations: usize,
103    /// Total number of objective evaluations.
104    pub n_evaluations: usize,
105}
106
107// ─────────────────────────────────────────────────────────────────────────────
108// Main entry point
109// ─────────────────────────────────────────────────────────────────────────────
110
111/// Run NSGA-III on a many-objective optimisation problem.
112///
113/// # Arguments
114/// * `n_objectives` - Number of objectives (must be ≥ 2; designed for ≥ 4).
115/// * `bounds`       - Decision-variable bounds `[(lo, hi); n_vars]`.
116/// * `objectives`   - Closure mapping a gene vector to objective values
117///   (all minimised).
118/// * `config`       - Algorithm hyper-parameters.
119///
120/// # Errors
121/// Returns an error for empty bounds, degenerate bounds, or < 2 objectives.
122///
123/// # Examples
124/// ```
125/// use scirs2_optimize::multiobjective::nsga3::{nsga3, Nsga3Config};
126///
127/// // DTLZ2 benchmark: 3 variables, 3 objectives
128/// let bounds: Vec<(f64, f64)> = vec![(0.0, 1.0); 7];
129/// let mut cfg = Nsga3Config::default();
130/// cfg.population_size = 20;
131/// cfg.n_generations   = 10;
132///
133/// let result = nsga3(3, &bounds, |x| {
134///     let n = x.len();
135///     let k = n - 3 + 1;
136///     let g: f64 = x[n-k..].iter().map(|&xi| (xi - 0.5).powi(2)).sum();
137///     let f1 = (1.0 + g) * x[0].cos() * x[1].cos();
138///     let f2 = (1.0 + g) * x[0].cos() * x[1].sin();
139///     let f3 = (1.0 + g) * x[0].sin();
140///     vec![f1, f2, f3]
141/// }, cfg).expect("valid input");
142///
143/// assert!(!result.pareto_front.is_empty());
144/// ```
145pub fn nsga3<F>(
146    n_objectives: usize,
147    bounds: &[(f64, f64)],
148    objectives: F,
149    config: Nsga3Config,
150) -> OptimizeResult<Nsga3Result>
151where
152    F: Fn(&[f64]) -> Vec<f64>,
153{
154    if n_objectives < 2 {
155        return Err(OptimizeError::InvalidInput(
156            "n_objectives must be >= 2".to_string(),
157        ));
158    }
159    if bounds.is_empty() {
160        return Err(OptimizeError::InvalidInput(
161            "bounds must be non-empty".to_string(),
162        ));
163    }
164    for (i, &(lo, hi)) in bounds.iter().enumerate() {
165        if lo >= hi {
166            return Err(OptimizeError::InvalidInput(format!(
167                "bound[{i}]: lo ({lo}) must be < hi ({hi})"
168            )));
169        }
170    }
171
172    let n_vars = bounds.len();
173    let mutation_rate = if config.mutation_rate > 0.0 {
174        config.mutation_rate
175    } else {
176        1.0 / n_vars as f64
177    };
178
179    // ── Reference point generation ───────────────────────────────────────────
180    let n_div = if config.n_divisions > 0 {
181        config.n_divisions
182    } else {
183        // Auto-select number of divisions based on number of objectives
184        match n_objectives {
185            2..=3 => 12,
186            4..=5 => 6,
187            6..=8 => 4,
188            _ => 3,
189        }
190    };
191
192    let mut ref_points = generate_reference_points(n_objectives, n_div);
193
194    // Optional inner (second) lattice for two-layer reference points
195    if let Some(n_div_inner) = config.n_divisions_inner {
196        let inner = generate_reference_points_inner(n_objectives, n_div_inner);
197        ref_points.extend(inner);
198    }
199
200    // Determine population size: must be >= number of reference points for
201    // good coverage, and even for pairing in reproduction
202    let n_ref = ref_points.len();
203    let pop_size = {
204        let desired = config.population_size.max(n_ref);
205        if desired % 2 == 0 {
206            desired
207        } else {
208            desired + 1
209        }
210    };
211
212    let mut rng = StdRng::seed_from_u64(config.seed);
213    let mut n_evaluations = 0usize;
214
215    // ── Initialise population ────────────────────────────────────────────────
216    let mut population: Vec<Individual> = (0..pop_size)
217        .map(|_| {
218            let genes = random_genes(bounds, &mut rng);
219            let objs = objectives(&genes);
220            n_evaluations += 1;
221            Individual::new(genes, objs)
222        })
223        .collect();
224
225    assign_ranks(&mut population);
226
227    // ── Main evolutionary loop ───────────────────────────────────────────────
228    for _ in 0..config.n_generations {
229        // Generate offspring via SBX + polynomial mutation
230        let offspring: Vec<Individual> = (0..pop_size / 2)
231            .flat_map(|_| {
232                let p1 = tournament_select_by_rank(&population, &mut rng);
233                let p2 = tournament_select_by_rank(&population, &mut rng);
234
235                let (c1_genes, c2_genes) = if rng.random::<f64>() < config.crossover_rate {
236                    sbx_crossover(
237                        &population[p1].genes,
238                        &population[p2].genes,
239                        config.eta_c,
240                        bounds,
241                        &mut rng,
242                    )
243                } else {
244                    (population[p1].genes.clone(), population[p2].genes.clone())
245                };
246
247                let c1_genes =
248                    polynomial_mutation(c1_genes, mutation_rate, config.eta_m, bounds, &mut rng);
249                let c2_genes =
250                    polynomial_mutation(c2_genes, mutation_rate, config.eta_m, bounds, &mut rng);
251
252                let objs1 = objectives(&c1_genes);
253                let objs2 = objectives(&c2_genes);
254                n_evaluations += 2;
255
256                vec![
257                    Individual::new(c1_genes, objs1),
258                    Individual::new(c2_genes, objs2),
259                ]
260            })
261            .collect();
262
263        // Combine parent + offspring
264        let mut combined = population;
265        combined.extend(offspring);
266        assign_ranks(&mut combined);
267
268        // NSGA-III survivor selection using reference-point niching
269        population = nsga3_select(&mut combined, &ref_points, pop_size, &mut rng);
270    }
271
272    // ── Build result ─────────────────────────────────────────────────────────
273    assign_ranks(&mut population);
274    let obj_vecs: Vec<Vec<f64>> = population
275        .iter()
276        .map(|ind| ind.objectives.clone())
277        .collect();
278    let front_indices = non_dominated_sort(&obj_vecs);
279
280    let all_fronts: Vec<Vec<Individual>> = front_indices
281        .iter()
282        .map(|idx_vec| idx_vec.iter().map(|&i| population[i].clone()).collect())
283        .collect();
284
285    let pareto_front = if all_fronts.is_empty() {
286        population.clone()
287    } else {
288        all_fronts[0].clone()
289    };
290
291    Ok(Nsga3Result {
292        pareto_front,
293        all_fronts,
294        reference_points: ref_points,
295        n_generations: config.n_generations,
296        n_evaluations,
297    })
298}
299
300// ─────────────────────────────────────────────────────────────────────────────
301// Reference point generation (Das-Dennis simplex lattice)
302// ─────────────────────────────────────────────────────────────────────────────
303
304/// Generate structured reference points on the unit hyperplane using the
305/// Das-Dennis lattice (simplex lattice design).
306///
307/// For `n_obj` objectives and `n_divisions` H, generates all points
308/// (a_1/H, ..., a_M/H) where each a_i is a non-negative integer and their
309/// sum equals H.  The result lies on the M-dimensional unit simplex.
310///
311/// The total number of points is C(H + M - 1, M - 1).
312pub fn generate_reference_points(n_obj: usize, n_divisions: usize) -> Vec<Vec<f64>> {
313    let mut points: Vec<Vec<f64>> = Vec::new();
314    let mut current = vec![0.0f64; n_obj];
315    enumerate_simplex(
316        &mut points,
317        &mut current,
318        n_obj,
319        n_divisions,
320        0,
321        n_divisions,
322    );
323
324    // Normalise by dividing by n_divisions
325    for p in &mut points {
326        for x in p.iter_mut() {
327            *x /= n_divisions as f64;
328        }
329    }
330    points
331}
332
333/// Generate inner reference points for the two-layer reference point approach.
334///
335/// The inner lattice is scaled to lie inside the simplex, avoiding boundary
336/// degeneracy for certain problem types.  Points are shifted toward the
337/// centroid: p' = p * (1 - 1/M) + 1/M^2.
338pub fn generate_reference_points_inner(n_obj: usize, n_divisions: usize) -> Vec<Vec<f64>> {
339    let mut points = generate_reference_points(n_obj, n_divisions);
340    let scale = 1.0 - 1.0 / n_obj as f64;
341    let offset = 1.0 / (n_obj * n_obj) as f64;
342    for p in &mut points {
343        for x in p.iter_mut() {
344            *x = *x * scale + offset;
345        }
346    }
347    points
348}
349
/// Recursively enumerate integer compositions of `remaining` over the axes
/// `index..n_obj`, pushing each completed vector of `current` onto `out`.
///
/// Earlier axes take values `0..=remaining` in increasing order; the last axis
/// absorbs whatever is left. `n_divisions` is only threaded through the
/// recursion and does not affect the result.
fn enumerate_simplex(
    out: &mut Vec<Vec<f64>>,
    current: &mut Vec<f64>,
    n_obj: usize,
    n_divisions: usize,
    index: usize,
    remaining: usize,
) {
    let last_axis = n_obj - 1;
    if index != last_axis {
        for share in 0..=remaining {
            current[index] = share as f64;
            enumerate_simplex(out, current, n_obj, n_divisions, index + 1, remaining - share);
        }
    } else {
        current[index] = remaining as f64;
        out.push(current.clone());
    }
}
368
369// ─────────────────────────────────────────────────────────────────────────────
370// Normalization utilities
371// ─────────────────────────────────────────────────────────────────────────────
372
373/// Compute the ideal point (component-wise minimum) of a population.
374fn ideal_point(population: &[Individual]) -> Vec<f64> {
375    if population.is_empty() {
376        return vec![];
377    }
378    let n_obj = population[0].objectives.len();
379    let mut ideal = vec![f64::INFINITY; n_obj];
380    for ind in population {
381        for (k, &v) in ind.objectives.iter().enumerate() {
382            if k < ideal.len() && v < ideal[k] {
383                ideal[k] = v;
384            }
385        }
386    }
387    ideal
388}
389
390/// Compute the nadir (approximate worst-boundary) point using extreme points
391/// from each objective axis.  Uses the achievement scalarization function (ASF)
392/// approach: for each objective k, find the individual minimizing
393/// max_i { (f_i - z_i*) / w_i } with weight vector e_k (axis vector).
394fn nadir_estimate(population: &[Individual], ideal: &[f64]) -> Vec<f64> {
395    let n_obj = ideal.len();
396    let mut nadir = vec![f64::NEG_INFINITY; n_obj];
397
398    for k in 0..n_obj {
399        // Construct weight vector: 1e-6 everywhere, 1.0 on axis k
400        let mut w = vec![1e-6f64; n_obj];
401        w[k] = 1.0;
402
403        // Find individual minimizing ASF (axis-aligned achievement)
404        let best_ind = population.iter().min_by(|a, b| {
405            let asf_a = asf(&a.objectives, ideal, &w);
406            let asf_b = asf(&b.objectives, ideal, &w);
407            asf_a
408                .partial_cmp(&asf_b)
409                .unwrap_or(std::cmp::Ordering::Equal)
410        });
411
412        if let Some(ind) = best_ind {
413            for (j, &v) in ind.objectives.iter().enumerate() {
414                if v > nadir[j] {
415                    nadir[j] = v;
416                }
417            }
418        }
419    }
420
421    // Fallback: ensure nadir > ideal everywhere
422    for k in 0..n_obj {
423        if nadir[k] <= ideal[k] {
424            nadir[k] = ideal[k] + 1.0;
425        }
426    }
427
428    nadir
429}
430
/// Achievement scalarizing function (ASF):
/// `ASF(f | z*, w) = max_i { (f_i - z_i*) / w_i }`.
///
/// Components beyond the shortest of the three slices are ignored; with no
/// components the result is `f64::NEG_INFINITY`.
fn asf(objectives: &[f64], ideal: &[f64], weights: &[f64]) -> f64 {
    let mut worst = f64::NEG_INFINITY;
    for ((&f, &z), &w) in objectives.iter().zip(ideal).zip(weights) {
        worst = worst.max((f - z) / w);
    }
    worst
}
441
/// Translate and scale an objective vector: `(f - ideal) / (nadir - ideal)`.
///
/// Components whose span `nadir - ideal` has magnitude below `1e-10` map to
/// `0.0` rather than dividing by (nearly) zero. The output has the length of
/// the shortest input slice.
fn normalize_objectives(objectives: &[f64], ideal: &[f64], nadir: &[f64]) -> Vec<f64> {
    let mut scaled = Vec::with_capacity(objectives.len());
    for ((&f, &low), &high) in objectives.iter().zip(ideal).zip(nadir) {
        let span = high - low;
        let value = if span.abs() < 1e-10 {
            0.0
        } else {
            (f - low) / span
        };
        scaled.push(value);
    }
    scaled
}
459
460// ─────────────────────────────────────────────────────────────────────────────
461// Reference-point association and niching
462// ─────────────────────────────────────────────────────────────────────────────
463
/// Perpendicular distance from a normalized objective vector to the reference
/// line that runs from the origin through `ref_point`:
///
/// `dist(f, r) = || f - (f·r / ||r||²) · r ||`
///
/// If `||r||²` is below `1e-14` the line is undefined, and the Euclidean norm
/// of `f_norm` is returned instead. Both slices are expected to have the same
/// length.
pub fn reference_line_distance(f_norm: &[f64], ref_point: &[f64]) -> f64 {
    let norm_sq: f64 = ref_point.iter().map(|r| r * r).sum();
    if norm_sq < 1e-14 {
        return f_norm.iter().map(|x| x * x).sum::<f64>().sqrt();
    }

    let dot: f64 = f_norm.iter().zip(ref_point).map(|(a, b)| a * b).sum();
    let t = dot / norm_sq;

    let residual_sq: f64 = f_norm
        .iter()
        .zip(ref_point)
        .map(|(f, r)| (f - t * r).powi(2))
        .sum();
    residual_sq.sqrt()
}
493
494/// For each individual, find the nearest reference point and compute the
495/// distance to its reference line.
496///
497/// Returns a vector of `(ref_idx, distance)` tuples, one per individual.
498pub fn associate_to_reference_points(
499    population: &[Individual],
500    ref_points: &[Vec<f64>],
501    ideal: &[f64],
502    nadir: &[f64],
503) -> Vec<(usize, f64)> {
504    population
505        .iter()
506        .map(|ind| {
507            let f_norm = normalize_objectives(&ind.objectives, ideal, nadir);
508
509            let mut best_ref = 0usize;
510            let mut best_dist = f64::INFINITY;
511
512            for (r_idx, rp) in ref_points.iter().enumerate() {
513                let d = reference_line_distance(&f_norm, rp);
514                if d < best_dist {
515                    best_dist = d;
516                    best_ref = r_idx;
517                }
518            }
519
520            (best_ref, best_dist)
521        })
522        .collect()
523}
524
525/// NSGA-III survivor selection using reference-point-based niching.
526///
527/// Selects `target_size` survivors from `combined` (size ≈ 2N):
528/// 1. Greedily fill with front 0, front 1, ... until adding the next front
529///    would exceed `target_size`.
530/// 2. From the "critical" (last partial) front, use niche preservation:
531///    count how many individuals from already-selected fronts are associated
532///    with each reference point, then repeatedly pick the individual from the
533///    reference point with the smallest niche count (breaking ties by distance).
534fn nsga3_select(
535    combined: &mut Vec<Individual>,
536    ref_points: &[Vec<f64>],
537    target_size: usize,
538    rng: &mut StdRng,
539) -> Vec<Individual> {
540    let obj_vecs: Vec<Vec<f64>> = combined.iter().map(|ind| ind.objectives.clone()).collect();
541    let fronts = non_dominated_sort(&obj_vecs);
542
543    // Compute ideal + nadir over the combined population
544    let ideal = ideal_point(combined);
545    let nadir = nadir_estimate(combined, &ideal);
546
547    // Association: for each individual, find nearest ref point and distance
548    let assoc = associate_to_reference_points(combined, ref_points, &ideal, &nadir);
549
550    // Greedily fill complete fronts
551    let mut survivors: Vec<usize> = Vec::with_capacity(target_size);
552    let mut critical_front: &[usize] = &[];
553
554    for front in &fronts {
555        if survivors.len() + front.len() <= target_size {
556            survivors.extend_from_slice(front);
557        } else {
558            critical_front = front;
559            break;
560        }
561    }
562
563    let remaining = target_size - survivors.len();
564
565    if remaining == 0 || critical_front.is_empty() {
566        // Selection complete without niching
567        return survivors.iter().map(|&i| combined[i].clone()).collect();
568    }
569
570    // Niche counting: count how many survivors are associated with each ref point
571    let n_ref = ref_points.len();
572    let mut niche_count = vec![0usize; n_ref];
573    for &s in &survivors {
574        let (ref_idx, _) = assoc[s];
575        if ref_idx < niche_count.len() {
576            niche_count[ref_idx] += 1;
577        }
578    }
579
580    // Add `remaining` individuals from `critical_front` using niche preservation
581    let mut available: Vec<usize> = critical_front.to_vec();
582    let mut selected_from_critical: Vec<usize> = Vec::with_capacity(remaining);
583
584    for _ in 0..remaining {
585        if available.is_empty() {
586            break;
587        }
588
589        // Find the minimum niche count among reference points that have candidates
590        let min_niche = available
591            .iter()
592            .filter_map(|&idx| {
593                let (ref_idx, _) = assoc[idx];
594                Some(niche_count[ref_idx])
595            })
596            .min()
597            .unwrap_or(0);
598
599        // Collect all candidates associated with reference points at min_niche
600        let candidates: Vec<usize> = available
601            .iter()
602            .copied()
603            .filter(|&idx| {
604                let (ref_idx, _) = assoc[idx];
605                niche_count[ref_idx] == min_niche
606            })
607            .collect();
608
609        if candidates.is_empty() {
610            break;
611        }
612
613        // Among tied candidates, if niche_count == 0, pick the closest to the
614        // reference line; otherwise pick a random one (uniform random selection)
615        let chosen = if min_niche == 0 {
616            // Pick the candidate with the smallest distance to its reference line
617            *candidates
618                .iter()
619                .min_by(|&&a, &&b| {
620                    let da = assoc[a].1;
621                    let db = assoc[b].1;
622                    da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
623                })
624                .unwrap_or(&candidates[0])
625        } else {
626            // Random selection among tied candidates
627            candidates[rng.random_range(0..candidates.len())]
628        };
629
630        // Update niche count for the chosen individual's reference point
631        let (chosen_ref, _) = assoc[chosen];
632        if chosen_ref < niche_count.len() {
633            niche_count[chosen_ref] += 1;
634        }
635
636        selected_from_critical.push(chosen);
637
638        // Remove chosen from available
639        available.retain(|&x| x != chosen);
640    }
641
642    // Build final survivor list
643    survivors.extend(selected_from_critical);
644    survivors.iter().map(|&i| combined[i].clone()).collect()
645}
646
647// ─────────────────────────────────────────────────────────────────────────────
648// Rank assignment (without crowding distance — NSGA-III uses reference points)
649// ─────────────────────────────────────────────────────────────────────────────
650
651fn assign_ranks(population: &mut Vec<Individual>) {
652    if population.is_empty() {
653        return;
654    }
655    let obj_vecs: Vec<Vec<f64>> = population
656        .iter()
657        .map(|ind| ind.objectives.clone())
658        .collect();
659    let fronts = non_dominated_sort(&obj_vecs);
660
661    for (rank, front_idx) in fronts.iter().enumerate() {
662        for &i in front_idx {
663            population[i].rank = rank;
664        }
665    }
666}
667
668// ─────────────────────────────────────────────────────────────────────────────
669// Tournament selection (rank only, for NSGA-III)
670// ─────────────────────────────────────────────────────────────────────────────
671
672fn tournament_select_by_rank(population: &[Individual], rng: &mut StdRng) -> usize {
673    let n = population.len();
674    let a = rng.random_range(0..n);
675    let mut b = rng.random_range(0..n);
676    if b == a && n > 1 {
677        b = (a + 1) % n;
678    }
679
680    if population[a].rank <= population[b].rank {
681        a
682    } else {
683        b
684    }
685}
686
687// ─────────────────────────────────────────────────────────────────────────────
688// Genetic operators (SBX + polynomial mutation, same as NSGA-II)
689// ─────────────────────────────────────────────────────────────────────────────
690
691fn sbx_crossover(
692    parent1: &[f64],
693    parent2: &[f64],
694    eta_c: f64,
695    bounds: &[(f64, f64)],
696    rng: &mut StdRng,
697) -> (Vec<f64>, Vec<f64>) {
698    let n = parent1.len();
699    let mut child1 = parent1.to_vec();
700    let mut child2 = parent2.to_vec();
701
702    for i in 0..n {
703        if rng.random::<f64>() > 0.5 {
704            continue;
705        }
706
707        let (lo, hi) = bounds[i];
708        let x1 = parent1[i].min(parent2[i]);
709        let x2 = parent1[i].max(parent2[i]);
710
711        if (x2 - x1).abs() < 1e-14 {
712            continue;
713        }
714
715        let u: f64 = rng.random();
716
717        let beta_q = if u <= 0.5 {
718            let alpha = 2.0 - (1.0 / sbx_beta(x1, x2, lo, eta_c)).powf(eta_c + 1.0);
719            let alpha = alpha.max(0.0);
720            (2.0 * u * alpha).powf(1.0 / (eta_c + 1.0))
721        } else {
722            let alpha = 2.0 - (1.0 / sbx_beta(x1, x2, hi - x2 + x1, eta_c)).powf(eta_c + 1.0);
723            let alpha_inv = 2.0 * (1.0 - u) * alpha.max(0.0);
724            if alpha_inv < f64::EPSILON {
725                1.0
726            } else {
727                (1.0 / alpha_inv).powf(1.0 / (eta_c + 1.0))
728            }
729        };
730
731        let mid = 0.5 * (x1 + x2);
732        let half_diff = 0.5 * (x2 - x1);
733
734        let c1 = (mid - beta_q * half_diff).clamp(lo, hi);
735        let c2 = (mid + beta_q * half_diff).clamp(lo, hi);
736
737        if parent1[i] < parent2[i] {
738            child1[i] = c1;
739            child2[i] = c2;
740        } else {
741            child1[i] = c2;
742            child2[i] = c1;
743        }
744    }
745
746    (child1, child2)
747}
748
/// Bound-aware spread term for SBX: `(1 + 2 * dist / diff)^(eta + 1)`, where
/// `diff = |x2 - x1|` and `dist = |bound - x1|`, each floored at `1e-14` so
/// the ratio stays finite. Its reciprocal equals `base^-(eta + 1)`.
fn sbx_beta(x1: f64, x2: f64, bound: f64, eta: f64) -> f64 {
    const FLOOR: f64 = 1e-14;
    let parent_gap = f64::max((x2 - x1).abs(), FLOOR);
    let gap_to_bound = f64::max((bound - x1).abs(), FLOOR);
    let base = 1.0 + 2.0 * gap_to_bound / parent_gap;
    base.powf(eta + 1.0)
}
754
755fn polynomial_mutation(
756    mut genes: Vec<f64>,
757    mutation_rate: f64,
758    eta_m: f64,
759    bounds: &[(f64, f64)],
760    rng: &mut StdRng,
761) -> Vec<f64> {
762    for (i, gene) in genes.iter_mut().enumerate() {
763        if rng.random::<f64>() >= mutation_rate {
764            continue;
765        }
766
767        let (lo, hi) = bounds[i];
768        let delta = hi - lo;
769        if delta < f64::EPSILON {
770            continue;
771        }
772
773        let u: f64 = rng.random();
774        let delta_q = if u < 0.5 {
775            let delta_l = (*gene - lo) / delta;
776            let base = 2.0 * u + (1.0 - 2.0 * u) * (1.0 - delta_l).powf(eta_m + 1.0);
777            base.powf(1.0 / (eta_m + 1.0)) - 1.0
778        } else {
779            let delta_r = (hi - *gene) / delta;
780            let base = 2.0 * (1.0 - u) + 2.0 * (u - 0.5) * (1.0 - delta_r).powf(eta_m + 1.0);
781            1.0 - base.powf(1.0 / (eta_m + 1.0))
782        };
783
784        *gene = (*gene + delta_q * delta).clamp(lo, hi);
785    }
786    genes
787}
788
789// ─────────────────────────────────────────────────────────────────────────────
790// Random initialisation
791// ─────────────────────────────────────────────────────────────────────────────
792
793fn random_genes(bounds: &[(f64, f64)], rng: &mut StdRng) -> Vec<f64> {
794    bounds
795        .iter()
796        .map(|&(lo, hi)| lo + rng.random::<f64>() * (hi - lo))
797        .collect()
798}
799
800// ─────────────────────────────────────────────────────────────────────────────
801// Adaptive reference points
802// ─────────────────────────────────────────────────────────────────────────────
803
804/// Adaptively update reference points based on the current Pareto front.
805///
806/// This implements the adaptive reference point mechanism from the A-NSGA-III
807/// variant: after observing the current approximation, reference points that
808/// have no associated solutions are moved toward the centroid of the front.
809///
810/// # Arguments
811/// * `ref_points`    - Current reference points (modified in place).
812/// * `pareto_front`  - Current Pareto front solutions' normalized objectives.
813/// * `learning_rate` - Step size for reference point update (typically 0.1).
814pub fn adapt_reference_points(
815    ref_points: &mut Vec<Vec<f64>>,
816    pareto_front_norm: &[Vec<f64>],
817    learning_rate: f64,
818) {
819    if pareto_front_norm.is_empty() || ref_points.is_empty() {
820        return;
821    }
822
823    let n_obj = ref_points[0].len();
824
825    // Compute centroid of the normalized Pareto front
826    let mut centroid = vec![0.0f64; n_obj];
827    for pt in pareto_front_norm {
828        for (k, &v) in pt.iter().enumerate() {
829            if k < n_obj {
830                centroid[k] += v;
831            }
832        }
833    }
834    let n = pareto_front_norm.len() as f64;
835    for c in &mut centroid {
836        *c /= n;
837    }
838
839    // For each reference point, check if it has any associated solution
840    for rp in ref_points.iter_mut() {
841        let has_association = pareto_front_norm.iter().any(|pt| {
842            let d = reference_line_distance(pt, rp);
843            d < 0.1 // threshold for "close enough"
844        });
845
846        if !has_association {
847            // Move reference point toward centroid
848            for k in 0..n_obj {
849                rp[k] += learning_rate * (centroid[k] - rp[k]);
850            }
851
852            // Re-normalise to unit simplex: project back
853            let sum: f64 = rp.iter().sum();
854            if sum > 1e-10 {
855                for x in rp.iter_mut() {
856                    *x /= sum;
857                }
858            }
859        }
860    }
861}
862
863// ─────────────────────────────────────────────────────────────────────────────
864// Tests
865// ─────────────────────────────────────────────────────────────────────────────
866
#[cfg(test)]
mod tests {
    use super::*;

    /// DTLZ2 benchmark problem.
    ///
    /// With `k = x.len() - n_obj + 1` distance variables, Pareto-optimal
    /// solutions (all distance variables equal to 0.5, so `g = 0`) lie on the
    /// unit hypersphere in the positive orthant of objective space.
    /// Expects `x.len() >= n_obj`; intended for `x` in `[0, 1]^n`.
    fn dtlz2(x: &[f64], n_obj: usize) -> Vec<f64> {
        let n = x.len();
        let k = n - n_obj + 1;
        let g: f64 = x[n - k..].iter().map(|&xi| (xi - 0.5).powi(2)).sum();
        let half_pi = std::f64::consts::FRAC_PI_2;

        (0..n_obj)
            .map(|i| {
                // Objective `i` multiplies the cosines of the first
                // `n_obj - 1 - i` position variables and, for i > 0, the sine
                // of the next one.
                let n_cos = n_obj - 1 - i;
                let val = x[..n_cos]
                    .iter()
                    .fold(1.0 + g, |acc, &xj| acc * (xj * half_pi).cos());
                if i > 0 {
                    val * (x[n_cos] * half_pi).sin()
                } else {
                    val
                }
            })
            .collect()
    }

    // ── Reference point generation ───────────────────────────────────────────

    #[test]
    fn test_reference_points_sum_to_one() {
        for p in &generate_reference_points(3, 4) {
            assert_eq!(p.len(), 3);
            let s: f64 = p.iter().sum();
            assert!((s - 1.0).abs() < 1e-10, "sum = {s}");
        }
    }

    #[test]
    fn test_reference_points_count() {
        // The Das-Dennis lattice has C(H + M - 1, M - 1) points.
        // M=3, H=4: C(6, 2) = 15
        assert_eq!(
            generate_reference_points(3, 4).len(),
            15,
            "Expected 15 reference points for M=3, H=4"
        );

        // M=2, H=5: C(6, 1) = 6
        assert_eq!(generate_reference_points(2, 5).len(), 6);
    }

    #[test]
    fn test_reference_points_non_negative() {
        for &v in generate_reference_points(4, 3).iter().flatten() {
            assert!(v >= 0.0, "Reference point component {v} is negative");
        }
    }

    #[test]
    fn test_inner_reference_points_inside_simplex() {
        for p in &generate_reference_points_inner(3, 3) {
            let s: f64 = p.iter().sum();
            assert!((s - 1.0).abs() < 0.01, "inner sum = {s}");
            for &v in p {
                assert!(v >= 0.0, "negative inner component {v}");
            }
        }
    }

    // ── reference_line_distance ──────────────────────────────────────────────

    #[test]
    fn test_ref_line_distance_on_line() {
        // A point along the reference direction has distance 0.
        let f_norm = vec![0.5, 0.5];
        let ref_pt = vec![1.0, 1.0]; // direction (1, 1)
        let d = reference_line_distance(&f_norm, &ref_pt);
        assert!(d < 1e-10, "point on ref line should have d≈0, got {d}");
    }

    #[test]
    fn test_ref_line_distance_perpendicular() {
        // (1, 0) has distance 1/sqrt(2) from direction (1, 1)/sqrt(2).
        let f_norm = vec![1.0, 0.0];
        let ref_pt = vec![1.0, 1.0];
        let d = reference_line_distance(&f_norm, &ref_pt);
        let expected = (0.5f64).sqrt();
        assert!((d - expected).abs() < 1e-10, "expected {expected}, got {d}");
    }

    // ── associate_to_reference_points ─────────────────────────────────────────

    #[test]
    fn test_association_nearest_ref() {
        // Two reference points, (1,0) and (0,1); an individual at (0.9, 0.1)
        // belongs to reference point 0.
        let ref_points = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let pop = vec![Individual::new(vec![0.0], vec![0.9, 0.1])];
        let ideal = vec![0.0, 0.0];
        let nadir = vec![1.0, 1.0];

        let assoc = associate_to_reference_points(&pop, &ref_points, &ideal, &nadir);
        assert_eq!(assoc.len(), 1);
        assert_eq!(assoc[0].0, 0, "Should be associated with reference point 0");
    }

    // ── nsga3 on DTLZ2 ───────────────────────────────────────────────────────

    #[test]
    fn test_nsga3_returns_pareto_front() {
        let n_obj = 3;
        let bounds: Vec<(f64, f64)> = vec![(0.0, 1.0); n_obj + 3];
        let cfg = Nsga3Config {
            population_size: 20,
            n_generations: 10,
            n_divisions: 3,
            ..Default::default()
        };

        let result = nsga3(n_obj, &bounds, |x| dtlz2(x, n_obj), cfg)
            .expect("nsga3 should succeed on a valid DTLZ2 setup");

        assert!(!result.pareto_front.is_empty());
        assert!(!result.reference_points.is_empty());
    }

    #[test]
    fn test_nsga3_pareto_front_non_dominated() {
        let n_obj = 3;
        let bounds: Vec<(f64, f64)> = vec![(0.0, 1.0); n_obj + 3];
        let cfg = Nsga3Config {
            population_size: 20,
            n_generations: 15,
            n_divisions: 3,
            seed: 77,
            ..Default::default()
        };

        let result = nsga3(n_obj, &bounds, |x| dtlz2(x, n_obj), cfg)
            .expect("nsga3 should succeed on a valid DTLZ2 setup");
        let front = &result.pareto_front;

        for (i, a) in front.iter().enumerate() {
            for (j, b) in front.iter().enumerate() {
                if i != j {
                    assert!(
                        !dominates(&a.objectives, &b.objectives),
                        "front[{i}] dominates front[{j}]"
                    );
                }
            }
        }
    }

    #[test]
    fn test_nsga3_four_objectives() {
        // Many-objective case: 4 objectives, where NSGA-II degrades.
        let n_obj = 4;
        let bounds: Vec<(f64, f64)> = vec![(0.0, 1.0); n_obj + 3];
        let cfg = Nsga3Config {
            population_size: 30,
            n_generations: 10,
            n_divisions: 3,
            ..Default::default()
        };

        let result = nsga3(n_obj, &bounds, |x| dtlz2(x, n_obj), cfg)
            .expect("nsga3 should succeed on a valid DTLZ2 setup");
        assert!(!result.pareto_front.is_empty());
    }

    #[test]
    fn test_nsga3_two_layer_reference_points() {
        let n_obj = 3;
        let bounds: Vec<(f64, f64)> = vec![(0.0, 1.0); n_obj + 2];
        let cfg = Nsga3Config {
            population_size: 30,
            n_generations: 10,
            n_divisions: 3,
            n_divisions_inner: Some(2),
            ..Default::default()
        };

        let result = nsga3(n_obj, &bounds, |x| dtlz2(x, n_obj), cfg)
            .expect("nsga3 should succeed on a valid DTLZ2 setup");
        // The outer layer alone has C(5, 2) = 10 points; the inner layer adds more.
        assert!(result.reference_points.len() > 10);
        assert!(!result.pareto_front.is_empty());
    }

    #[test]
    fn test_nsga3_bounds_respected() {
        let bounds = vec![(0.2, 0.8); 4];
        let cfg = Nsga3Config {
            population_size: 20,
            n_generations: 10,
            n_divisions: 3,
            ..Default::default()
        };

        let result = nsga3(3, &bounds, |x| vec![x[0], x[1], x[2]], cfg)
            .expect("nsga3 should succeed on a valid bounded problem");

        for ind in &result.pareto_front {
            for (i, &g) in ind.genes.iter().enumerate() {
                let (lo, hi) = bounds[i];
                assert!(
                    ((lo - 1e-9)..=(hi + 1e-9)).contains(&g),
                    "gene[{i}]={g} outside bounds"
                );
            }
        }
    }

    #[test]
    fn test_nsga3_invalid_input() {
        // Empty bounds
        let result = nsga3(3, &[], |x| vec![x[0]], Nsga3Config::default());
        assert!(result.is_err());

        // Lower bound above upper bound
        let result = nsga3(3, &[(1.0, 0.0)], |x| vec![x[0]], Nsga3Config::default());
        assert!(result.is_err());

        // Too few objectives
        let result = nsga3(1, &[(0.0, 1.0)], |x| vec![x[0]], Nsga3Config::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_nsga3_reference_point_coverage() {
        // All reference points should lie on the unit simplex.
        for p in &generate_reference_points(3, 4) {
            let sum: f64 = p.iter().sum();
            assert!((sum - 1.0).abs() < 1e-10, "sum = {sum}");
            for &v in p {
                assert!((0.0..=1.0).contains(&v), "component {v} outside [0, 1]");
            }
        }
    }

    #[test]
    fn test_adapt_reference_points() {
        let mut ref_pts = generate_reference_points(3, 3);
        let before = ref_pts.clone();

        // Simulate a Pareto front concentrated near the first objective's corner.
        let fake_front: Vec<Vec<f64>> = vec![vec![0.9, 0.05, 0.05], vec![0.85, 0.1, 0.05]];
        let centroid = [0.875, 0.075, 0.05];

        adapt_reference_points(&mut ref_pts, &fake_front, 0.1);

        // Count should be unchanged.
        assert_eq!(ref_pts.len(), before.len());

        // Reference points should still approximately sum to 1.
        for p in &ref_pts {
            let sum: f64 = p.iter().sum();
            assert!((sum - 1.0).abs() < 0.01, "adapted ref point sum = {sum}");
        }

        // Each point either stays put (it already has a nearby solution) or
        // moves strictly closer to the centroid of the front.
        let dist = |p: &[f64]| -> f64 {
            p.iter()
                .zip(&centroid)
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f64>()
                .sqrt()
        };
        let mut kept = 0;
        let mut moved = 0;
        for (old, new) in before.iter().zip(&ref_pts) {
            if old == new {
                kept += 1;
            } else {
                assert!(
                    dist(new) < dist(old),
                    "{old:?} moved away from the centroid to {new:?}"
                );
                moved += 1;
            }
        }
        // The (1, 0, 0) corner is within the association threshold of the
        // front; the far corners are not.
        assert!(kept >= 1, "expected at least one associated reference point");
        assert!(moved >= 1, "expected at least one reference point to move");
    }
}