use crate::solution::Solution;
use oxieml::EmlNode;
use scirs2_optimize::multiobjective::pareto::{crowding_distance, non_dominated_sort};
/// A collection of candidate solutions intended to be mutually non-dominated.
///
/// `ParetoFront::from_candidates` builds the collection by filtering and sorting.
/// The field is public, so a value constructed directly is neither filtered nor sorted.
#[derive(Debug, Clone, Default)]
pub struct ParetoFront {
/// Solutions on the front. `from_candidates` leaves them sorted by ascending
/// complexity, with ties broken by ascending `mse`.
pub solutions: Vec<Solution>,
}
impl ParetoFront {
    /// Orders two `f64` keys ascending, placing every NaN after every non-NaN value.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` treats NaN as equal to everything, which is
    /// not a consistent ordering: a sort may leave elements misplaced (and is allowed
    /// to panic on recent toolchains), and `min_by` can return a NaN entry. This
    /// comparator agrees with `partial_cmp` on all non-NaN inputs and stays consistent
    /// when NaN is present.
    fn cmp_nan_last(a: f64, b: f64) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
        }
    }

    /// Builds a front from `candidates`, keeping only mutually non-dominated solutions.
    ///
    /// Candidates are processed in order. For each one:
    /// 1. it is skipped when its `mse` is not finite;
    /// 2. it is skipped when a solution already on the front dominates it
    ///    (per `Solution::dominates`);
    /// 3. every front member it dominates is evicted;
    /// 4. it is skipped when the front still holds a solution with the same
    ///    complexity and an `mse` within `1e-12` of its own;
    /// 5. otherwise it is added.
    ///
    /// The result is sorted by ascending complexity, ties broken by ascending `mse`.
    #[must_use]
    pub fn from_candidates(candidates: Vec<Solution>) -> Self {
        let mut front: Vec<Solution> = Vec::new();
        for cand in candidates {
            if !cand.mse.is_finite() {
                continue;
            }
            if front.iter().any(|s| s.dominates(&cand)) {
                continue;
            }
            front.retain(|s| !cand.dominates(s));
            // Near-identical entries that do not dominate each other are kept only once.
            let duplicate = front
                .iter()
                .any(|s| s.complexity == cand.complexity && (s.mse - cand.mse).abs() < 1e-12);
            if duplicate {
                continue;
            }
            front.push(cand);
        }
        front.sort_by(|a, b| {
            a.complexity
                .cmp(&b.complexity)
                .then_with(|| Self::cmp_nan_last(a.mse, b.mse))
        });
        Self { solutions: front }
    }

    /// Returns up to `k` solutions with the lowest `mse`, best first.
    ///
    /// Solutions whose `mse` is NaN sort after all others. The sort is stable, so
    /// solutions with equal `mse` keep their relative order from `self.solutions`.
    #[must_use]
    pub fn pareto_top(&self, k: usize) -> Vec<&Solution> {
        let mut by_mse: Vec<&Solution> = self.solutions.iter().collect();
        by_mse.sort_by(|a, b| Self::cmp_nan_last(a.mse, b.mse));
        by_mse.truncate(k);
        by_mse
    }

    /// Returns the solution with the lowest `mse`, or `None` when the front is empty.
    ///
    /// A solution with a NaN `mse` is returned only when every solution's `mse` is NaN.
    #[must_use]
    pub fn best(&self) -> Option<&Solution> {
        self.solutions
            .iter()
            .min_by(|a, b| Self::cmp_nan_last(a.mse, b.mse))
    }

    /// Number of solutions on the front.
    #[must_use]
    pub fn len(&self) -> usize {
        self.solutions.len()
    }

    /// Whether the front holds no solutions.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.solutions.is_empty()
    }

    /// Ranks every solution on four objectives and returns one `MultiRank` per solution.
    ///
    /// Each solution is turned into the vector
    /// `[complexity, mse, -interpretability, -elegance]` and handed to
    /// `non_dominated_sort`; `crowding_distance` is then evaluated separately for each
    /// front it returns. The last two objectives are negated, presumably because the
    /// sorter minimizes every component (confirm against the `scirs2_optimize` docs).
    ///
    /// The result is ordered by ascending front index, then by descending crowding
    /// distance, with NaN distances last within their front. Returns an empty vector
    /// when the front is empty.
    #[must_use]
    pub fn rank_multiobjective(&self) -> Vec<MultiRank> {
        if self.solutions.is_empty() {
            return Vec::new();
        }
        let objs_struct: Vec<MultiObjective> = self.solutions.iter().map(objectives).collect();
        let objs: Vec<Vec<f64>> = objs_struct
            .iter()
            .map(|o| vec![o.complexity, o.mse, -o.interpretability, -o.elegance])
            .collect();
        let fronts = non_dominated_sort(&objs);
        let mut ranks: Vec<MultiRank> = Vec::with_capacity(self.solutions.len());
        for (front_rank, front_idx) in fronts.iter().enumerate() {
            let front_objs: Vec<Vec<f64>> = front_idx.iter().map(|&i| objs[i].clone()).collect();
            let cd = crowding_distance(&front_objs);
            for (slot, &i) in front_idx.iter().enumerate() {
                ranks.push(MultiRank {
                    index: i,
                    front: front_rank,
                    // Fall back to 0.0 if no distance was returned for this slot.
                    crowding: cd.get(slot).copied().unwrap_or(0.0),
                    objectives: objs_struct[i],
                });
            }
        }
        ranks.sort_by(|a, b| {
            a.front.cmp(&b.front).then_with(|| {
                // Negating both keys makes the ascending comparator order by descending
                // crowding; the negation of NaN is still NaN, so NaN keeps sorting last.
                Self::cmp_nan_last(-a.crowding, -b.crowding)
            })
        });
        ranks
    }
}
/// The four objective values computed for one solution by `objectives`.
///
/// `ParetoFront::rank_multiobjective` feeds them to the non-dominated sort as
/// `[complexity, mse, -interpretability, -elegance]`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MultiObjective {
/// The solution's `complexity`, converted to `f64`.
pub complexity: f64,
/// The solution's `mse`, copied unchanged.
pub mse: f64,
/// `1 / (1 + depth)` of the solution's tree: `1.0` for a single leaf, smaller for deeper trees.
pub interpretability: f64,
/// Fraction of the tree's constants that `is_simple_constant` accepts; `1.0` when the tree has none.
pub elegance: f64,
}
/// Ranking record for one solution, produced by `ParetoFront::rank_multiobjective`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MultiRank {
/// Index of the ranked solution within `ParetoFront::solutions`.
pub index: usize,
/// Zero-based position of the solution's front in the list returned by `non_dominated_sort`.
pub front: usize,
/// Crowding distance reported for the solution within its own front
/// (`0.0` if `crowding_distance` returned no entry for it).
pub crowding: f64,
/// The objective values the ranking was computed from.
pub objectives: MultiObjective,
}
/// Computes the four ranking objectives for `sol`.
///
/// Complexity and `mse` are taken from the solution itself. Interpretability is
/// `1 / (1 + depth)` of the expression tree, and elegance is whatever the
/// `elegance` helper reports for the solution.
#[must_use]
pub fn objectives(sol: &Solution) -> MultiObjective {
    let tree_depth = depth(&sol.tree.root) as f64;
    let interpretability = 1.0 / (1.0 + tree_depth);
    let elegance_score = elegance(sol);
    MultiObjective {
        complexity: sol.complexity as f64,
        mse: sol.mse,
        interpretability,
        elegance: elegance_score,
    }
}
/// Height of the expression tree rooted at `node`.
///
/// Leaves (`One`, `Var`, `Const`) have depth 0; an `Eml` node is one level
/// above the deeper of its two children.
fn depth(node: &EmlNode) -> usize {
    match node {
        EmlNode::Eml { left, right } => {
            let left_depth = depth(left);
            let right_depth = depth(right);
            1 + left_depth.max(right_depth)
        }
        EmlNode::One | EmlNode::Var(_) | EmlNode::Const(_) => 0,
    }
}
/// Fraction of the constants in `sol`'s tree that `is_simple_constant` accepts.
///
/// Constants are gathered with `crate::fit::collect_consts`. A tree without
/// any constants scores `1.0`.
fn elegance(sol: &Solution) -> f64 {
    let mut constants = Vec::new();
    crate::fit::collect_consts(&sol.tree.root, &mut constants);
    match constants.len() {
        0 => 1.0,
        total => {
            let simple_count = constants
                .iter()
                .copied()
                .filter(|&value| is_simple_constant(value))
                .count();
            simple_count as f64 / total as f64
        }
    }
}
/// Reports whether `c` counts as a "simple" constant.
///
/// A constant is simple when it lies within `1e-4` of an integer whose
/// magnitude is at most 12, or, failing that, when
/// `oxieml::symreg::snap_to_named_const` returns a value for it.
fn is_simple_constant(c: f64) -> bool {
    let nearest = c.round();
    let near_small_int = (c - nearest).abs() < 1e-4 && nearest.abs() <= 12.0;
    if near_small_int {
        return true;
    }
    // Only consulted when the small-integer test fails.
    oxieml::symreg::snap_to_named_const(c).is_some()
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxieml::{Canonical, EmlTree};

    /// Builds a solution with the given error and complexity around the fixed
    /// tree returned by `Canonical::exp(&EmlTree::var(0))`.
    fn sol(mse: f64, complexity: usize) -> Solution {
        let tree = Canonical::exp(&EmlTree::var(0));
        Solution {
            tree,
            mse,
            complexity,
        }
    }

    /// A shallow, constant-free, accurate solution must outrank a deeper,
    /// less accurate one built from two non-simple constants.
    #[test]
    fn ranks_on_four_objectives() {
        let clean_tree = EmlTree::eml(&EmlTree::var(0), &EmlTree::one());
        let s_simple = Solution::new(clean_tree, 1e-9);

        let inner = EmlTree::eml(&EmlTree::var(0), &EmlTree::const_val(0.7234));
        let messy_tree = EmlTree::eml(&inner, &EmlTree::const_val(0.1119));
        let s_messy = Solution::new(messy_tree, 0.5);

        let front = ParetoFront {
            solutions: vec![s_simple, s_messy],
        };
        let ranks = front.rank_multiobjective();
        assert_eq!(ranks.len(), 2, "one rank per solution");

        // The simple solution (index 0) leads the ranking on the first front.
        let leader = &ranks[0];
        assert_eq!(leader.front, 0);
        assert_eq!(leader.index, 0);

        let messy_rank = ranks
            .iter()
            .find(|r| r.index == 1)
            .expect("messy solution ranked");
        assert!(
            messy_rank.front >= 1,
            "dominated solution should be a later front"
        );

        assert!((leader.objectives.elegance - 1.0).abs() < 1e-12);
        assert!(messy_rank.objectives.elegance < 1e-12);
        assert!(leader.objectives.interpretability > messy_rank.objectives.interpretability);
    }

    /// Of four candidates, the one that is worse on both axes than another
    /// candidate is expected to be dropped, leaving three.
    #[test]
    fn keeps_only_non_dominated() {
        let candidates = vec![sol(0.5, 1), sol(0.1, 5), sol(0.6, 6), sol(0.3, 3)];
        let front = ParetoFront::from_candidates(candidates);

        assert_eq!(front.len(), 3);

        let lowest_error = front.best().unwrap().mse;
        assert!((lowest_error - 0.1).abs() < 1e-12);

        // The front comes back ordered by complexity.
        let first = &front.solutions[0];
        let second = &front.solutions[1];
        assert!(first.complexity <= second.complexity);
    }
}