phop_core/
any_solution.rs1use crate::affine::AffineSolution;
13use crate::error::Result;
14use crate::solution::Solution;
15use scirs2_core::ndarray::{Array1, Array2};
16
/// A candidate solution from either of the two supported sources, wrapped so
/// both kinds can be scored, compared and merged through one interface.
#[derive(Clone, Debug)]
pub enum AnySolution {
    /// Wraps a [`Solution`]; [`AnySolution::source`] reports it as `"eml"`.
    Eml(Solution),
    /// Wraps an [`AffineSolution`]; [`AnySolution::source`] reports it as
    /// `"affine"`.
    Affine(AffineSolution),
}
27
28impl AnySolution {
29 #[must_use]
33 pub fn complexity(&self) -> usize {
34 match self {
35 Self::Eml(s) => s.complexity,
36 Self::Affine(s) => s.nodes,
37 }
38 }
39
40 #[must_use]
42 pub fn mse(&self) -> f64 {
43 match self {
44 Self::Eml(s) => s.mse,
45 Self::Affine(s) => s.mse,
46 }
47 }
48
49 #[must_use]
51 pub fn source(&self) -> &'static str {
52 match self {
53 Self::Eml(_) => "eml",
54 Self::Affine(_) => "affine",
55 }
56 }
57
58 #[must_use]
62 pub fn is_symbolic(&self) -> bool {
63 match self {
64 Self::Eml(_) => true,
65 Self::Affine(s) => s.symbolic,
66 }
67 }
68
69 #[must_use]
71 pub fn expr(&self) -> String {
72 match self {
73 Self::Eml(s) => s.pretty(),
74 Self::Affine(s) => s.expr.clone(),
75 }
76 }
77
78 #[must_use]
80 pub fn latex(&self) -> String {
81 match self {
82 Self::Eml(s) => s.latex(),
83 Self::Affine(s) => s.latex(),
84 }
85 }
86
87 pub fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
92 match self {
93 Self::Eml(s) => s.predict(x),
94 Self::Affine(s) => Ok(s.predict(x)),
95 }
96 }
97
98 #[must_use]
101 pub fn as_eml(&self) -> Option<&Solution> {
102 match self {
103 Self::Eml(s) => Some(s),
104 Self::Affine(_) => None,
105 }
106 }
107
108 #[must_use]
110 pub fn dominates(&self, other: &Self) -> bool {
111 let (c0, m0) = (self.complexity(), self.mse());
112 let (c1, m1) = (other.complexity(), other.mse());
113 c0 <= c1 && m0 <= m1 && (c0 < c1 || m0 < m1)
114 }
115}
116
117#[must_use]
121pub fn merge_pareto(candidates: Vec<AnySolution>) -> Vec<AnySolution> {
122 let mut front: Vec<AnySolution> = Vec::new();
123 for cand in candidates {
124 if !cand.mse().is_finite() {
125 continue;
126 }
127 if front.iter().any(|s| s.dominates(&cand)) {
128 continue;
129 }
130 front.retain(|s| !cand.dominates(s));
131 if front
132 .iter()
133 .any(|s| s.complexity() == cand.complexity() && (s.mse() - cand.mse()).abs() < 1e-12)
134 {
135 continue;
136 }
137 front.push(cand);
138 }
139 front.sort_by(|a, b| {
140 a.mse()
141 .partial_cmp(&b.mse())
142 .unwrap_or(std::cmp::Ordering::Equal)
143 });
144 front
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::affine::discover_affine_pareto;
    use crate::solution::Solution;
    use oxieml::EmlTree;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Builds an EML candidate with the given scores around the same
    /// `Canonical::exp(&EmlTree::var(0))` tree; only the scores matter to the
    /// domination and merge logic exercised here.
    fn eml_candidate(mse: f64, complexity: usize) -> AnySolution {
        AnySolution::Eml(Solution {
            tree: oxieml::Canonical::exp(&EmlTree::var(0)),
            mse,
            complexity,
        })
    }

    #[test]
    fn domination_uses_complexity_and_mse() {
        let simple_accurate = eml_candidate(0.1, 3);
        let complex_worse = eml_candidate(0.2, 5);

        assert!(simple_accurate.dominates(&complex_worse));
        assert!(!complex_worse.dominates(&simple_accurate));
    }

    #[test]
    fn merge_keeps_non_dominated_and_sorts_by_mse() {
        // The third candidate is worse than both others on complexity and MSE.
        let front = merge_pareto(vec![
            eml_candidate(0.30, 2),
            eml_candidate(0.10, 4),
            eml_candidate(0.40, 6),
        ]);

        assert_eq!(front.len(), 2, "the dominated member must be removed");
        assert!(front[0].mse() <= front[1].mse());
        assert!((front[0].mse() - 0.10).abs() < 1e-12);
    }

    #[test]
    fn merge_includes_affine_members() {
        // Target: y = x0^2 * x1 sampled on a small two-column grid.
        let n = 40usize;
        let mut x = Array2::<f64>::zeros((n, 2));
        let mut y = Array1::<f64>::zeros(n);
        for row in 0..n {
            let step = row as f64;
            let (x0, x1) = (1.0 + 0.05 * step, 0.5 + 0.03 * step);
            x[[row, 0]] = x0;
            x[[row, 1]] = x1;
            y[row] = x0 * x0 * x1;
        }

        let candidates: Vec<AnySolution> = discover_affine_pareto(&x, &y, 2, 500)
            .into_iter()
            .map(AnySolution::Affine)
            .collect();
        let merged = merge_pareto(candidates);

        assert!(!merged.is_empty(), "affine engine should recover x0^2*x1");
        let best = &merged[0];
        assert_eq!(best.source(), "affine");
        assert!(best.mse() < 1e-3, "best affine mse = {}", best.mse());
    }
}