use crate::affine::AffineSolution;
use crate::error::Result;
use crate::solution::Solution;
use scirs2_core::ndarray::{Array1, Array2};
/// A solution produced by either discovery engine, wrapped so the two kinds
/// can be ranked and merged side by side.
///
/// [`AnySolution::source`] reports which engine a value came from.
#[derive(Debug, Clone)]
pub enum AnySolution {
    /// A solution carrying an EML expression tree ([`Solution`]).
    Eml(Solution),
    /// A solution from the affine engine ([`AffineSolution`]).
    Affine(AffineSolution),
}
impl AnySolution {
#[must_use]
pub fn complexity(&self) -> usize {
match self {
Self::Eml(s) => s.complexity,
Self::Affine(s) => s.nodes,
}
}
#[must_use]
pub fn mse(&self) -> f64 {
match self {
Self::Eml(s) => s.mse,
Self::Affine(s) => s.mse,
}
}
#[must_use]
pub fn source(&self) -> &'static str {
match self {
Self::Eml(_) => "eml",
Self::Affine(_) => "affine",
}
}
#[must_use]
pub fn is_symbolic(&self) -> bool {
match self {
Self::Eml(_) => true,
Self::Affine(s) => s.symbolic,
}
}
#[must_use]
pub fn expr(&self) -> String {
match self {
Self::Eml(s) => s.pretty(),
Self::Affine(s) => s.expr.clone(),
}
}
#[must_use]
pub fn latex(&self) -> String {
match self {
Self::Eml(s) => s.latex(),
Self::Affine(s) => s.latex(),
}
}
pub fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
match self {
Self::Eml(s) => s.predict(x),
Self::Affine(s) => Ok(s.predict(x)),
}
}
#[must_use]
pub fn as_eml(&self) -> Option<&Solution> {
match self {
Self::Eml(s) => Some(s),
Self::Affine(_) => None,
}
}
#[must_use]
pub fn dominates(&self, other: &Self) -> bool {
let (c0, m0) = (self.complexity(), self.mse());
let (c1, m1) = (other.complexity(), other.mse());
c0 <= c1 && m0 <= m1 && (c0 < c1 || m0 < m1)
}
}
/// Combines candidate solutions from any engine into a single Pareto front.
///
/// Candidates are processed in input order:
///
/// * a candidate whose `mse` is NaN or infinite is discarded;
/// * a candidate dominated by a current member (see [`AnySolution::dominates`])
///   is discarded;
/// * otherwise every member the candidate dominates is evicted. The candidate
///   is then appended, unless a remaining member has the same complexity and
///   an `mse` within an absolute `1e-12` of it. In that case the member already
///   on the front is kept.
///
/// The returned front is sorted by ascending `mse`.
#[must_use]
pub fn merge_pareto(candidates: Vec<AnySolution>) -> Vec<AnySolution> {
    // Absolute tolerance for treating two equal-complexity members as duplicates.
    const MSE_TIE_TOLERANCE: f64 = 1e-12;

    let mut pareto: Vec<AnySolution> = Vec::new();
    for candidate in candidates.into_iter().filter(|c| c.mse().is_finite()) {
        let beaten = pareto.iter().any(|member| member.dominates(&candidate));
        if beaten {
            continue;
        }
        pareto.retain(|member| !candidate.dominates(member));
        let duplicate = pareto.iter().any(|member| {
            member.complexity() == candidate.complexity()
                && (member.mse() - candidate.mse()).abs() < MSE_TIE_TOLERANCE
        });
        if !duplicate {
            pareto.push(candidate);
        }
    }
    // Every surviving mse is finite, so the `Equal` fallback is never taken.
    pareto.sort_by(|lhs, rhs| {
        lhs.mse()
            .partial_cmp(&rhs.mse())
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    pareto
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::affine::discover_affine_pareto;
    use crate::solution::Solution;
    use oxieml::EmlTree;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Builds an EML solution over the tree `Canonical::exp(var(0))` with the
    /// given error and complexity.
    fn eml_solution(mse: f64, complexity: usize) -> AnySolution {
        AnySolution::Eml(Solution {
            tree: oxieml::Canonical::exp(&EmlTree::var(0)),
            mse,
            complexity,
        })
    }

    /// A simpler and more accurate solution dominates a more complex, less
    /// accurate one, and never the other way round.
    #[test]
    fn domination_uses_complexity_and_mse() {
        let (better, worse) = (eml_solution(0.1, 3), eml_solution(0.2, 5));
        assert!(better.dominates(&worse));
        assert!(!worse.dominates(&better));
    }

    /// Merging removes the dominated member and orders survivors by MSE.
    #[test]
    fn merge_keeps_non_dominated_and_sorts_by_mse() {
        let front = merge_pareto(vec![
            eml_solution(0.30, 2),
            eml_solution(0.10, 4),
            eml_solution(0.40, 6),
        ]);
        assert_eq!(front.len(), 2, "the dominated member must be removed");
        let (first, second) = (&front[0], &front[1]);
        assert!(first.mse() <= second.mse());
        assert!((first.mse() - 0.10).abs() < 1e-12);
    }

    /// The affine engine should fit `y = x0^2 * x1` on a clean grid, and its
    /// result should lead the merged front.
    #[test]
    fn merge_includes_affine_members() {
        let samples = 40usize;
        let inputs = |row: usize| (1.0 + 0.05 * row as f64, 0.5 + 0.03 * row as f64);
        let x: Array2<f64> = Array2::from_shape_fn((samples, 2), |(row, col)| {
            let (x0, x1) = inputs(row);
            if col == 0 {
                x0
            } else {
                x1
            }
        });
        let y: Array1<f64> = Array1::from_shape_fn(samples, |row| {
            let (x0, x1) = inputs(row);
            x0 * x0 * x1
        });

        let candidates: Vec<AnySolution> = discover_affine_pareto(&x, &y, 2, 500)
            .into_iter()
            .map(AnySolution::Affine)
            .collect();
        let front = merge_pareto(candidates);
        assert!(!front.is_empty(), "affine engine should recover x0^2*x1");
        let leader = &front[0];
        assert_eq!(leader.source(), "affine");
        assert!(leader.mse() < 1e-3, "best affine mse = {}", leader.mse());
    }
}