scirs2_optimize/multiobjective/
moead.rs1use crate::error::OptimizeResult;
24use crate::multiobjective::indicators::{dominates, non_dominated_sort};
25use crate::multiobjective::nsga2::Individual;
26use scirs2_core::random::rngs::StdRng;
27use scirs2_core::random::{Rng, SeedableRng};
28use scirs2_core::RngExt;
29
/// Tunable parameters consumed by [`moead`].
#[derive(Debug, Clone)]
pub struct MoeadConfig {
    /// Requested number of sub-problems. `moead` raises values below 4 to 4, and the
    /// realised population size equals the number of weight vectors actually generated.
    pub population_size: usize,
    /// Number of generations; each generation visits every sub-problem once.
    pub n_generations: usize,
    /// Neighbourhood size. `moead` caps it at the population size and then raises it
    /// to at least 2.
    pub n_neighbors: usize,
    /// Threshold compared against a uniform random draw for every sub-problem visit to
    /// decide whether the neighbourhood is chosen as the mating pool.
    pub delta: f64,
    /// Number of objectives; `moead` rejects values below 2.
    pub n_objectives: usize,
    /// Seed for the random number generator created inside `moead`.
    pub seed: u64,
}
53
54impl Default for MoeadConfig {
55 fn default() -> Self {
56 Self {
57 population_size: 100,
58 n_generations: 200,
59 n_neighbors: 20,
60 delta: 0.9,
61 n_objectives: 2,
62 seed: 12345,
63 }
64 }
65}
66
67#[derive(Debug)]
69pub struct MoeadResult {
70 pub pareto_front: Vec<Individual>,
72 pub weight_vectors: Vec<Vec<f64>>,
74 pub n_generations: usize,
76}
77
78pub fn moead<F>(
113 bounds: &[(f64, f64)],
114 objectives: F,
115 config: MoeadConfig,
116) -> OptimizeResult<MoeadResult>
117where
118 F: Fn(&[f64]) -> Vec<f64>,
119{
120 use crate::error::OptimizeError;
121
122 if config.n_objectives < 2 {
123 return Err(OptimizeError::InvalidInput(
124 "n_objectives must be >= 2".to_string(),
125 ));
126 }
127 if bounds.is_empty() {
128 return Err(OptimizeError::InvalidInput(
129 "bounds must be non-empty".to_string(),
130 ));
131 }
132 for (i, &(lo, hi)) in bounds.iter().enumerate() {
133 if lo >= hi {
134 return Err(OptimizeError::InvalidInput(format!(
135 "bound[{i}]: lo ({lo}) must be < hi ({hi})"
136 )));
137 }
138 }
139
140 let n_vars = bounds.len();
141 let pop_size = config.population_size.max(4);
142 let n_obj = config.n_objectives;
143 let t = config.n_neighbors.min(pop_size).max(2);
144
145 let mut rng = StdRng::seed_from_u64(config.seed);
146
147 let weight_vectors = generate_weight_vectors(n_obj, pop_size, &mut rng);
149 let actual_pop_size = weight_vectors.len();
150
151 let neighborhoods = build_neighborhood(&weight_vectors, t);
153
154 let mut population: Vec<Individual> = (0..actual_pop_size)
156 .map(|_| {
157 let genes = random_genes(bounds, &mut rng);
158 let objs = objectives(&genes);
159 Individual::new(genes, objs)
160 })
161 .collect();
162
163 let mut ideal_point: Vec<f64> = vec![f64::INFINITY; n_obj];
165 for ind in &population {
166 for (k, &v) in ind.objectives.iter().enumerate() {
167 if v < ideal_point[k] {
168 ideal_point[k] = v;
169 }
170 }
171 }
172
173 for _ in 0..config.n_generations {
175 for i in 0..actual_pop_size {
176 let use_neighborhood = rng.random::<f64>() < config.delta;
178 let mating_pool = if use_neighborhood {
179 &neighborhoods[i]
180 } else {
181 &neighborhoods[i] };
184
185 let offspring_genes = de_offspring(&population, mating_pool, bounds, &mut rng);
187 let offspring_objs = objectives(&offspring_genes);
188 let offspring = Individual::new(offspring_genes, offspring_objs);
189
190 for (k, &v) in offspring.objectives.iter().enumerate() {
192 if v < ideal_point[k] {
193 ideal_point[k] = v;
194 }
195 }
196
197 for &j in mating_pool {
199 let w = &weight_vectors[j];
200 let g_offspring = tchebycheff_scalarization(&offspring.objectives, w, &ideal_point);
201 let g_current =
202 tchebycheff_scalarization(&population[j].objectives, w, &ideal_point);
203 if g_offspring <= g_current {
204 population[j] = offspring.clone();
205 }
206 }
207 }
208 }
209
210 let obj_vecs: Vec<Vec<f64>> = population
212 .iter()
213 .map(|ind| ind.objectives.clone())
214 .collect();
215 let fronts = non_dominated_sort(&obj_vecs);
216
217 let pareto_front: Vec<Individual> = if fronts.is_empty() {
218 population.clone()
219 } else {
220 fronts[0].iter().map(|&i| population[i].clone()).collect()
221 };
222
223 Ok(MoeadResult {
224 pareto_front,
225 weight_vectors,
226 n_generations: config.n_generations,
227 })
228}
229
/// Tchebycheff scalarisation: the largest weighted absolute deviation of `objectives`
/// from `ideal_point`, i.e. `max_k weights[k] * |objectives[k] - ideal_point[k]|`.
///
/// Only the leading components shared by all three slices are considered; when there
/// are none the result is `f64::NEG_INFINITY`.
pub fn tchebycheff_scalarization(objectives: &[f64], weights: &[f64], ideal_point: &[f64]) -> f64 {
    let shared_len = objectives.len().min(weights.len()).min(ideal_point.len());
    let mut worst = f64::NEG_INFINITY;
    for k in 0..shared_len {
        let deviation = (objectives[k] - ideal_point[k]).abs();
        worst = worst.max(weights[k] * deviation);
    }
    worst
}
248
249pub fn generate_weight_vectors(n_obj: usize, target_n: usize, rng: &mut StdRng) -> Vec<Vec<f64>> {
265 if n_obj == 1 {
266 return vec![vec![1.0]];
267 }
268
269 let mut h = 1usize;
271 while combinations(h + n_obj, n_obj - 1) <= target_n {
272 h += 1;
273 }
274 h -= 1;
275 if h == 0 {
276 h = 1;
277 }
278
279 let mut vectors: Vec<Vec<f64>> = Vec::new();
280 enumerate_simplex(&mut vectors, n_obj, h);
281
282 for v in &mut vectors {
284 for x in v.iter_mut() {
285 *x /= h as f64;
286 }
287 }
288
289 while vectors.len() < 2 {
291 vectors.push(random_simplex_point(n_obj, rng));
292 }
293
294 vectors
295}
296
297fn enumerate_simplex(out: &mut Vec<Vec<f64>>, n_obj: usize, h: usize) {
299 let mut current = vec![0.0f64; n_obj];
300 enumerate_simplex_rec(out, &mut current, n_obj, h, 0, h);
301}
302
/// Recursive worker for the simplex-lattice enumeration.
///
/// Slots before `index` in `current` are already fixed; `remaining` units are still to
/// be distributed over slots `index..n_obj`. Every completed assignment is pushed onto
/// `out`. `h` is only forwarded to the recursive calls.
fn enumerate_simplex_rec(
    out: &mut Vec<Vec<f64>>,
    current: &mut Vec<f64>,
    n_obj: usize,
    h: usize,
    index: usize,
    remaining: usize,
) {
    let last_slot = n_obj - 1;
    if index != last_slot {
        // Try every share for this slot and hand the rest to the following slots.
        for share in 0..=remaining {
            current[index] = share as f64;
            enumerate_simplex_rec(out, current, n_obj, h, index + 1, remaining - share);
        }
    } else {
        // The final slot has no freedom: it absorbs whatever is left.
        current[index] = remaining as f64;
        out.push(current.to_vec());
    }
}
321
/// Binomial coefficient `C(n, k)`, saturating at `usize::MAX`.
///
/// Returns 0 when `k > n`. Intermediate products are carried in `u128`, so every value
/// that fits in `usize` is exact; anything larger yields `usize::MAX`.
fn combinations(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }
    // Use the smaller of the two symmetric arguments to keep the loop short.
    let k = k.min(n - k);
    let limit = usize::MAX as u128;

    let mut acc: u128 = 1;
    for i in 0..k {
        // Invariant: `acc == C(n, i)` before this step, so the product is divisible by
        // `i + 1` and the division is exact. Both factors fit in 64 bits, hence the
        // product cannot overflow `u128`.
        acc = acc * (n - i) as u128 / (i + 1) as u128;
        // C(n, j) grows with j for j <= n / 2, so once the limit is passed the final
        // value is out of range as well.
        if acc > limit {
            return usize::MAX;
        }
    }
    acc as usize
}
337
338fn random_simplex_point(n: usize, rng: &mut StdRng) -> Vec<f64> {
340 let exps: Vec<f64> = (0..n)
342 .map(|_| -rng.random::<f64>().ln().max(1e-15))
343 .collect();
344 let sum: f64 = exps.iter().sum();
345 exps.iter().map(|&e| e / sum).collect()
346}
347
/// For every weight vector, lists the indices of its `t` closest weight vectors by
/// Euclidean distance, nearest first.
///
/// `t` is capped at the number of weight vectors. Ties keep their original index order
/// because the sort is stable.
pub fn build_neighborhood(weight_vectors: &[Vec<f64>], t: usize) -> Vec<Vec<usize>> {
    let total = weight_vectors.len();
    let keep = t.min(total);

    let mut neighborhoods: Vec<Vec<usize>> = Vec::with_capacity(total);
    for anchor in weight_vectors {
        // Distance from `anchor` to every vector (including itself), tagged with the index.
        let mut ranked: Vec<(f64, usize)> = Vec::with_capacity(total);
        for (idx, other) in weight_vectors.iter().enumerate() {
            let mut squared = 0.0f64;
            for (a, b) in anchor.iter().zip(other.iter()) {
                squared += (a - b).powi(2);
            }
            ranked.push((squared.sqrt(), idx));
        }

        // Incomparable distances (NaN) are treated as equal.
        ranked.sort_by(|lhs, rhs| {
            lhs.0
                .partial_cmp(&rhs.0)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        ranked.truncate(keep);
        neighborhoods.push(ranked.into_iter().map(|(_, idx)| idx).collect());
    }
    neighborhoods
}
383
384fn de_offspring(
394 population: &[Individual],
395 pool: &[usize],
396 bounds: &[(f64, f64)],
397 rng: &mut StdRng,
398) -> Vec<f64> {
399 let n_pool = pool.len();
400 let n_vars = bounds.len();
401
402 let r1 = pool[rng.random_range(0..n_pool)];
404 let mut r2 = pool[rng.random_range(0..n_pool)];
405 let mut r3 = pool[rng.random_range(0..n_pool)];
406
407 for _ in 0..3 {
409 if r2 != r1 {
410 break;
411 }
412 r2 = pool[rng.random_range(0..n_pool)];
413 }
414 for _ in 0..3 {
415 if r3 != r1 && r3 != r2 {
416 break;
417 }
418 r3 = pool[rng.random_range(0..n_pool)];
419 }
420
421 let x1 = &population[r1].genes;
422 let x2 = &population[r2].genes;
423 let x3 = &population[r3].genes;
424
425 let f_scale = 0.5 + rng.random::<f64>() * 0.4;
427 let cr = 0.5;
429
430 let j_rand = rng.random_range(0..n_vars);
432
433 let (lo_v, hi_v): (Vec<f64>, Vec<f64>) = bounds.iter().map(|&(lo, hi)| (lo, hi)).unzip();
434
435 (0..n_vars)
436 .map(|j| {
437 if j == j_rand || rng.random::<f64>() < cr {
438 (x1[j] + f_scale * (x2[j] - x3[j])).clamp(lo_v[j], hi_v[j])
439 } else {
440 x1[j]
441 }
442 })
443 .collect()
444}
445
446fn random_genes(bounds: &[(f64, f64)], rng: &mut StdRng) -> Vec<f64> {
451 bounds
452 .iter()
453 .map(|&(lo, hi)| lo + rng.random::<f64>() * (hi - lo))
454 .collect()
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-objective ZDT1 benchmark used as the objective function in these tests.
    fn zdt1(x: &[f64]) -> Vec<f64> {
        let (f1, tail) = (x[0], &x[1..]);
        let g = 1.0 + 9.0 * tail.iter().sum::<f64>() / (x.len() - 1) as f64;
        vec![f1, g * (1.0 - (f1 / g).sqrt())]
    }

    /// Default configuration with two objectives and the given population / generations.
    fn two_objective_config(population_size: usize, n_generations: usize) -> MoeadConfig {
        MoeadConfig {
            population_size,
            n_generations,
            n_objectives: 2,
            ..MoeadConfig::default()
        }
    }

    // --- weight vectors ---------------------------------------------------------------

    #[test]
    fn test_weight_vectors_sum_to_one() {
        let mut rng = StdRng::seed_from_u64(0);
        for w in generate_weight_vectors(2, 10, &mut rng).iter() {
            let s: f64 = w.iter().sum();
            assert!((s - 1.0).abs() < 1e-10, "weight vector sum = {s}");
            assert_eq!(w.len(), 2);
        }
    }

    #[test]
    fn test_weight_vectors_3obj() {
        let mut rng = StdRng::seed_from_u64(1);
        for w in generate_weight_vectors(3, 15, &mut rng).iter() {
            let total: f64 = w.iter().sum();
            assert!((total - 1.0).abs() < 1e-10);
            assert_eq!(w.len(), 3);
            for x in w {
                assert!((0.0..=1.0).contains(x));
            }
        }
    }

    // --- neighbourhoods ---------------------------------------------------------------

    #[test]
    fn test_neighborhood_includes_self() {
        let mut rng = StdRng::seed_from_u64(0);
        let weights = generate_weight_vectors(2, 10, &mut rng);
        let neighborhoods = build_neighborhood(&weights, 3);
        for (i, members) in neighborhoods.iter().enumerate() {
            assert!(
                members.contains(&i),
                "neighbourhood of {i} should contain itself"
            );
        }
    }

    #[test]
    fn test_neighborhood_size() {
        let mut rng = StdRng::seed_from_u64(0);
        let weights = generate_weight_vectors(2, 10, &mut rng);
        let t = 4;
        for members in build_neighborhood(&weights, t).iter() {
            assert_eq!(members.len(), t);
        }
    }

    // --- scalarisation ----------------------------------------------------------------

    #[test]
    fn test_tchebycheff_basic() {
        let val = tchebycheff_scalarization(&[1.0, 2.0], &[0.5, 0.5], &[0.0, 0.0]);
        assert!((val - 1.0).abs() < 1e-10, "Expected 1.0, got {val}");
    }

    #[test]
    fn test_tchebycheff_with_ideal_shift() {
        let val = tchebycheff_scalarization(&[3.0, 3.0], &[0.5, 0.5], &[1.0, 1.0]);
        assert!((val - 1.0).abs() < 1e-10);
    }

    // --- full algorithm ---------------------------------------------------------------

    #[test]
    fn test_moead_returns_pareto_front() {
        let bounds: Vec<(f64, f64)> = vec![(0.0, 1.0); 5];
        let cfg = MoeadConfig {
            seed: 1,
            ..two_objective_config(20, 10)
        };

        let result = moead(&bounds, zdt1, cfg).expect("moead should succeed");
        assert!(!result.pareto_front.is_empty());
    }

    #[test]
    fn test_moead_weight_vectors_returned() {
        let bounds: Vec<(f64, f64)> = vec![(0.0, 1.0); 3];
        let cfg = two_objective_config(10, 5);

        let result = moead(&bounds, zdt1, cfg).expect("failed to create result");
        assert!(!result.weight_vectors.is_empty());
        for w in result.weight_vectors.iter() {
            assert_eq!(w.len(), 2);
            let s: f64 = w.iter().sum();
            assert!((s - 1.0).abs() < 1e-10, "weight sum = {s}");
        }
    }

    #[test]
    fn test_moead_pareto_front_non_dominated() {
        let bounds: Vec<(f64, f64)> = vec![(0.0, 1.0); 5];
        let cfg = MoeadConfig {
            seed: 42,
            ..two_objective_config(20, 20)
        };

        let result = moead(&bounds, zdt1, cfg).expect("failed to create result");

        // No member of the reported front may dominate another member.
        let front = &result.pareto_front;
        for (i, a) in front.iter().enumerate() {
            for (j, b) in front.iter().enumerate() {
                if i == j {
                    continue;
                }
                assert!(
                    !dominates(&a.objectives, &b.objectives),
                    "front[{i}] dominates front[{j}] in MOEA/D result"
                );
            }
        }
    }

    #[test]
    fn test_moead_bounds_respected() {
        let bounds = vec![(0.2, 0.8); 3];
        let cfg = two_objective_config(10, 5);

        let result =
            moead(&bounds, |x| vec![x[0], 1.0 - x[0]], cfg).expect("failed to create result");

        for ind in result.pareto_front.iter() {
            for (i, &g) in ind.genes.iter().enumerate() {
                let (lo, hi) = bounds[i];
                assert!(
                    g >= lo - 1e-9 && g <= hi + 1e-9,
                    "gene[{i}]={g} outside bounds"
                );
            }
        }
    }

    #[test]
    fn test_moead_invalid_input() {
        // Empty bounds.
        let result = moead(&[], |x| vec![x[0]], MoeadConfig::default());
        assert!(result.is_err());

        // Lower bound above upper bound.
        let result = moead(&[(1.0, 0.0)], |x| vec![x[0]], MoeadConfig::default());
        assert!(result.is_err());

        // Fewer than two objectives.
        let cfg = MoeadConfig {
            n_objectives: 1,
            ..MoeadConfig::default()
        };
        let result = moead(&[(0.0, 1.0)], |x| vec![x[0]], cfg);
        assert!(result.is_err());
    }

    #[test]
    fn test_moead_generation_count() {
        let bounds = vec![(0.0, 1.0); 3];
        let cfg = two_objective_config(10, 7);

        let result = moead(&bounds, zdt1, cfg).expect("failed to create result");
        assert_eq!(result.n_generations, 7);
    }

    #[test]
    fn test_moead_diverse_objectives() {
        let bounds: Vec<(f64, f64)> = vec![(0.0, 1.0); 6];
        let cfg = MoeadConfig {
            seed: 7,
            ..two_objective_config(30, 30)
        };

        let result = moead(&bounds, zdt1, cfg).expect("failed to create result");

        assert!(result.pareto_front.len() >= 2);

        for ind in result.pareto_front.iter() {
            assert!(ind.objectives[0] >= 0.0, "f1 must be >= 0");
            assert!(ind.objectives[1] >= 0.0, "f2 must be >= 0");
        }
    }
}