1use crate::solution::Solution;
14use oxieml::EmlNode;
15use scirs2_optimize::multiobjective::pareto::{crowding_distance, non_dominated_sort};
16
/// A collection of candidate solutions forming a Pareto front.
///
/// When built with [`ParetoFront::from_candidates`], the solutions have finite
/// MSE, are mutually non-dominated according to `Solution::dominates`, and are
/// sorted by ascending complexity, then ascending MSE. The field is public, so
/// values constructed directly carry none of these guarantees.
#[derive(Debug, Clone, Default)]
pub struct ParetoFront {
    /// Solutions on the front.
    pub solutions: Vec<Solution>,
}
23
24impl ParetoFront {
25 #[must_use]
29 pub fn from_candidates(candidates: Vec<Solution>) -> Self {
30 let mut front: Vec<Solution> = Vec::new();
31 for cand in candidates {
32 if !cand.mse.is_finite() {
33 continue;
34 }
35 if front.iter().any(|s| s.dominates(&cand)) {
37 continue;
38 }
39 front.retain(|s| !cand.dominates(s));
41 if front
43 .iter()
44 .any(|s| s.complexity == cand.complexity && (s.mse - cand.mse).abs() < 1e-12)
45 {
46 continue;
47 }
48 front.push(cand);
49 }
50 front.sort_by(|a, b| {
51 a.complexity.cmp(&b.complexity).then(
52 a.mse
53 .partial_cmp(&b.mse)
54 .unwrap_or(std::cmp::Ordering::Equal),
55 )
56 });
57 Self { solutions: front }
58 }
59
60 #[must_use]
62 pub fn pareto_top(&self, k: usize) -> Vec<&Solution> {
63 let mut by_mse: Vec<&Solution> = self.solutions.iter().collect();
64 by_mse.sort_by(|a, b| {
65 a.mse
66 .partial_cmp(&b.mse)
67 .unwrap_or(std::cmp::Ordering::Equal)
68 });
69 by_mse.into_iter().take(k).collect()
70 }
71
72 #[must_use]
74 pub fn best(&self) -> Option<&Solution> {
75 self.solutions.iter().min_by(|a, b| {
76 a.mse
77 .partial_cmp(&b.mse)
78 .unwrap_or(std::cmp::Ordering::Equal)
79 })
80 }
81
82 #[must_use]
84 pub fn len(&self) -> usize {
85 self.solutions.len()
86 }
87
88 #[must_use]
90 pub fn is_empty(&self) -> bool {
91 self.solutions.is_empty()
92 }
93
94 #[must_use]
102 pub fn rank_multiobjective(&self) -> Vec<MultiRank> {
103 if self.solutions.is_empty() {
104 return Vec::new();
105 }
106 let objs_struct: Vec<MultiObjective> = self.solutions.iter().map(objectives).collect();
107 let objs: Vec<Vec<f64>> = objs_struct
109 .iter()
110 .map(|o| vec![o.complexity, o.mse, -o.interpretability, -o.elegance])
111 .collect();
112
113 let fronts = non_dominated_sort(&objs);
114 let mut ranks: Vec<MultiRank> = Vec::with_capacity(self.solutions.len());
115 for (front_rank, front_idx) in fronts.iter().enumerate() {
116 let front_objs: Vec<Vec<f64>> = front_idx.iter().map(|&i| objs[i].clone()).collect();
117 let cd = crowding_distance(&front_objs);
118 for (slot, &i) in front_idx.iter().enumerate() {
119 ranks.push(MultiRank {
120 index: i,
121 front: front_rank,
122 crowding: cd.get(slot).copied().unwrap_or(0.0),
123 objectives: objs_struct[i],
124 });
125 }
126 }
127 ranks.sort_by(|a, b| {
128 a.front.cmp(&b.front).then(
129 b.crowding
130 .partial_cmp(&a.crowding)
131 .unwrap_or(std::cmp::Ordering::Equal),
132 )
133 });
134 ranks
135 }
136}
137
/// Objective values for one solution, as computed by [`objectives`] and used by
/// [`ParetoFront::rank_multiobjective`].
///
/// Values are stored un-negated. The ranking negates `interpretability` and
/// `elegance`, because for those two higher is better.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MultiObjective {
    /// The solution's `complexity` converted to `f64`; lower is better.
    pub complexity: f64,
    /// The solution's `mse`; lower is better.
    pub mse: f64,
    /// `1 / (1 + depth)` of the expression tree, so in `(0, 1]`; higher is better.
    pub interpretability: f64,
    /// Fraction of the tree's constants that are small integers or named
    /// constants, in `[0, 1]` (`1.0` when there are no constants); higher is better.
    pub elegance: f64,
}
151
/// The ranking of one solution, as produced by [`ParetoFront::rank_multiobjective`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MultiRank {
    /// Index of the ranked solution in `ParetoFront::solutions`.
    pub index: usize,
    /// Zero-based position of the solution's front in the output of
    /// `non_dominated_sort` (0 = first front).
    pub front: usize,
    /// Crowding distance within its front. Larger values are ranked first
    /// within a front. `0.0` if no distance was reported for this solution.
    pub crowding: f64,
    /// The solution's raw (un-negated) objective values.
    pub objectives: MultiObjective,
}
164
165#[must_use]
167pub fn objectives(sol: &Solution) -> MultiObjective {
168 MultiObjective {
169 complexity: sol.complexity as f64,
170 mse: sol.mse,
171 interpretability: 1.0 / (1.0 + depth(&sol.tree.root) as f64),
172 elegance: elegance(sol),
173 }
174}
175
176fn depth(node: &EmlNode) -> usize {
178 match node {
179 EmlNode::One | EmlNode::Var(_) | EmlNode::Const(_) => 0,
180 EmlNode::Eml { left, right } => 1 + depth(left).max(depth(right)),
181 }
182}
183
184fn elegance(sol: &Solution) -> f64 {
187 let mut consts = Vec::new();
188 crate::fit::collect_consts(&sol.tree.root, &mut consts);
189 if consts.is_empty() {
190 return 1.0;
191 }
192 let simple = consts.iter().filter(|&&c| is_simple_constant(c)).count();
193 simple as f64 / consts.len() as f64
194}
195
196fn is_simple_constant(c: f64) -> bool {
198 let near_small_int = (c - c.round()).abs() < 1e-4 && c.round().abs() <= 12.0;
199 near_small_int || oxieml::symreg::snap_to_named_const(c).is_some()
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use oxieml::{Canonical, EmlTree};

    /// Builds a solution with the given MSE and complexity. The tree is always
    /// the one produced by `Canonical::exp(x0)`.
    fn sol(mse: f64, complexity: usize) -> Solution {
        Solution {
            tree: Canonical::exp(&EmlTree::var(0)),
            mse,
            complexity,
        }
    }

    #[test]
    fn ranks_on_four_objectives() {
        // No constants and a near-zero error.
        let s_simple = Solution::new(EmlTree::eml(&EmlTree::var(0), &EmlTree::one()), 1e-9);
        // Nested one level deeper, with two constants that are not small
        // integers, and a much worse error.
        let messy = EmlTree::eml(
            &EmlTree::eml(&EmlTree::var(0), &EmlTree::const_val(0.7234)),
            &EmlTree::const_val(0.1119),
        );
        let s_messy = Solution::new(messy, 0.5);

        let front = ParetoFront {
            solutions: vec![s_simple, s_messy],
        };
        let ranks = front.rank_multiobjective();
        assert_eq!(ranks.len(), 2, "one rank per solution");

        // The simple solution is at least as good on every objective, so it
        // leads the first front.
        assert_eq!(ranks[0].front, 0);
        assert_eq!(ranks[0].index, 0);
        let messy_rank = ranks
            .iter()
            .find(|r| r.index == 1)
            .expect("messy solution ranked");
        assert!(
            messy_rank.front >= 1,
            "dominated solution should be a later front"
        );

        // No constants -> elegance 1.0; only non-simple constants -> 0.0.
        assert!((ranks[0].objectives.elegance - 1.0).abs() < 1e-12);
        assert!(messy_rank.objectives.elegance < 1e-12);
        // The messy tree is deeper, hence less interpretable.
        assert!(ranks[0].objectives.interpretability > messy_rank.objectives.interpretability);
    }

    #[test]
    fn keeps_only_non_dominated() {
        let cands = vec![
            sol(0.5, 1),
            sol(0.1, 5),
            // Dominated by (0.1, 5): higher error and higher complexity.
            sol(0.6, 6),
            sol(0.3, 3),
        ];
        let front = ParetoFront::from_candidates(cands);
        assert_eq!(front.len(), 3);
        assert!(
            front.solutions.iter().all(|s| s.complexity != 6),
            "dominated candidate must be dropped"
        );
        assert!((front.best().unwrap().mse - 0.1).abs() < 1e-12);
        assert!(
            front
                .solutions
                .windows(2)
                .all(|w| w[0].complexity <= w[1].complexity),
            "front must be sorted by complexity"
        );
    }

    #[test]
    fn drops_non_finite_and_duplicate_candidates() {
        let cands = vec![
            sol(f64::NAN, 1),
            sol(f64::INFINITY, 2),
            sol(0.3, 3),
            sol(0.3, 3),
        ];
        let front = ParetoFront::from_candidates(cands);
        assert_eq!(front.len(), 1, "non-finite and duplicate candidates dropped");
        assert_eq!(front.solutions[0].complexity, 3);
        assert!((front.solutions[0].mse - 0.3).abs() < 1e-12);
    }
}