use crate::dataset::DataSet;
use crate::fit::{collect_consts, mse, substitute_consts};
use crate::forest::eval_tree;
use crate::loss::RobustLoss;
use oxieml::symreg::snap_to_named_const;
use oxieml::EmlTree;
use scirs2_core::ndarray::Array1;
/// Builds a copy of `template` whose constants are replaced by `consts`.
///
/// The replacement is delegated to `substitute_consts`, which walks the
/// template's root node with a running cursor into `consts` (starting at 0);
/// the resulting node is wrapped back into an [`EmlTree`].
fn tree_with_consts(template: &EmlTree, consts: &[f64]) -> EmlTree {
    let mut cursor = 0;
    let root = substitute_consts(&template.root, consts, &mut cursor);
    EmlTree::from_node(root)
}
/// Computes the element-wise residuals `pred - y`.
///
/// Predictions and targets are paired by zipping, so only their common prefix
/// is considered. Returns `None` as soon as a non-finite prediction is met;
/// otherwise returns the residuals in sample order.
fn residuals(pred: &Array1<f64>, y: &Array1<f64>) -> Option<Vec<f64>> {
    pred.iter()
        .zip(y.iter())
        .map(|(&p, &t)| p.is_finite().then_some(p - t))
        .collect()
}
/// Solves the dense linear system `a * x = b` by Gaussian elimination with
/// partial (row) pivoting followed by back substitution.
///
/// The system size `n` is taken from `b.len()`; only the leading `n × n` part
/// of `a` takes part in the solution.
///
/// Returns `None` when:
/// * `a` has fewer than `n` rows, or one of its first `n` rows has fewer than
///   `n` entries (previously this panicked on an out-of-bounds index);
/// * the largest available pivot in some column is smaller than `1e-14` in
///   absolute value (the matrix is treated as singular);
/// * any component of the solution is not finite.
///
/// An empty system (`n == 0`) yields `Some(vec![])`.
pub(crate) fn solve_dense(mut a: Vec<Vec<f64>>, mut b: Vec<f64>) -> Option<Vec<f64>> {
    let n = b.len();
    // Reject undersized input up front instead of panicking mid-elimination.
    if a.len() < n || a.iter().take(n).any(|row| row.len() < n) {
        return None;
    }

    for col in 0..n {
        // Partial pivoting: pick the row (at or below the diagonal) with the
        // largest magnitude in this column; ties keep the earliest row.
        let mut piv = col;
        for r in (col + 1)..n {
            if a[r][col].abs() > a[piv][col].abs() {
                piv = r;
            }
        }
        if a[piv][col].abs() < 1e-14 {
            return None;
        }
        a.swap(col, piv);
        b.swap(col, piv);

        // Split so the pivot row can be read while the rows below it are
        // updated, avoiding a clone of the pivot row for every column.
        let (head, tail) = a.split_at_mut(col + 1);
        let pivot_row = &head[col];
        let pivot_b = b[col];
        // Zipping with the remaining right-hand-side entries limits the
        // update to rows `col + 1 .. n`.
        for (row, rhs) in tail.iter_mut().zip(b[col + 1..].iter_mut()) {
            let f = row[col] / pivot_row[col];
            for (rc, pc) in row.iter_mut().zip(pivot_row.iter()).skip(col) {
                *rc -= f * pc;
            }
            *rhs -= f * pivot_b;
        }
    }

    // Back substitution on the upper-triangular system.
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let mut s = b[i];
        for j in (i + 1)..n {
            s -= a[i][j] * x[j];
        }
        x[i] = s / a[i][i];
    }

    x.iter().all(|v| v.is_finite()).then_some(x)
}
#[must_use]
pub fn polish_constants(tree: &EmlTree, ds: &DataSet, iters: usize) -> (EmlTree, f64) {
lm_refine(tree, ds, iters, RobustLoss::Mse)
}
/// Refines the numeric constants of `tree` against `ds`, weighting residuals
/// according to `loss`.
///
/// This is a direct pass-through to `lm_refine`, which runs at most `iters`
/// outer iterations and obtains per-sample weights from `loss.irls_weight`.
///
/// Returns the tree with refined constants and the plain MSE (not the
/// weighted cost) of its predictions on `ds`. If the tree cannot be evaluated
/// at its starting constants, the original tree is returned with
/// `f64::INFINITY`.
#[must_use]
pub fn polish_constants_robust(
tree: &EmlTree,
ds: &DataSet,
iters: usize,
loss: RobustLoss,
) -> (EmlTree, f64) {
lm_refine(tree, ds, iters, loss)
}
/// Levenberg–Marquardt refinement of the constants in `tree`, with
/// iteratively re-weighted residuals supplied by `loss`.
///
/// Each of the (at most) `iters` outer iterations:
/// 1. computes per-sample weights from the current residuals via
///    `loss.irls_weight`;
/// 2. builds a forward-difference Jacobian of the predictions with respect to
///    every constant (step `1e-6 * (|theta_j| + 1)`);
/// 3. forms the weighted normal equations and tries up to 12 damped solves,
///    multiplying the damping factor by 4 after every failed or rejected
///    attempt and halving it (floored at `1e-12`) after an accepted one.
///
/// A trial step is accepted only if it lowers the weighted squared cost
/// computed with the weights of the current iteration. The loop stops early
/// when the Jacobian cannot be evaluated or when no trial step is accepted.
///
/// Returns the tree rebuilt with the final constants and the plain MSE of the
/// final predictions. If the tree has no constants it is returned unchanged
/// with its MSE (or `f64::INFINITY` if it cannot be evaluated); if the initial
/// evaluation fails or yields a non-finite prediction, the original tree is
/// returned with `f64::INFINITY`.
fn lm_refine(tree: &EmlTree, ds: &DataSet, iters: usize, loss: RobustLoss) -> (EmlTree, f64) {
    /// Sum over samples of `(w_i * r_i)^2`.
    fn weighted_sq_sum(res: &[f64], weights: &[f64]) -> f64 {
        res.iter()
            .zip(weights)
            .map(|(ri, wi)| (wi * ri) * (wi * ri))
            .sum()
    }

    let mut theta = Vec::new();
    collect_consts(&tree.root, &mut theta);
    let n_params = theta.len();
    let targets = &ds.y;

    // Predictions of the template with the given constants, or `None` when
    // evaluation fails.
    let predict = |params: &[f64]| -> Option<Array1<f64>> {
        let candidate = tree_with_consts(tree, params);
        eval_tree(&candidate, &ds.x).ok()
    };

    // Nothing to refine: just report how well the tree already fits.
    if n_params == 0 {
        let err = match predict(&theta) {
            Some(pred) => mse(&pred, targets),
            None => f64::INFINITY,
        };
        return (tree.clone(), err);
    }

    let n_samples = targets.len();
    let Some(mut pred) = predict(&theta) else {
        return (tree.clone(), f64::INFINITY);
    };
    let Some(mut resid) = residuals(&pred, targets) else {
        return (tree.clone(), f64::INFINITY);
    };

    let mut lambda = 1e-3_f64;
    for _ in 0..iters {
        // Weights are frozen for the whole iteration so that the current and
        // trial costs are comparable.
        let weights: Vec<f64> = resid
            .iter()
            .map(|&ri| loss.irls_weight(ri, &resid))
            .collect();
        let cost = weighted_sq_sum(&resid, &weights);

        // Forward-difference Jacobian, one column per constant.
        let mut jac: Vec<Vec<f64>> = vec![vec![0.0; n_params]; n_samples];
        let mut jac_complete = true;
        for (j, &tj) in theta.iter().enumerate() {
            let step = 1e-6 * (tj.abs() + 1.0);
            let mut bumped = theta.clone();
            bumped[j] += step;
            match predict(&bumped) {
                Some(shifted) => {
                    for (i, jrow) in jac.iter_mut().enumerate() {
                        jrow[j] = (shifted[i] - pred[i]) / step;
                    }
                }
                None => {
                    jac_complete = false;
                    break;
                }
            }
        }
        if !jac_complete {
            break;
        }

        // Weighted normal equations: `normal = Jᵀ W² J`, `grad = Jᵀ W² r`.
        let mut normal = vec![vec![0.0; n_params]; n_params];
        let mut grad = vec![0.0; n_params];
        for c1 in 0..n_params {
            for ((jrow, &wi), &ri) in jac.iter().zip(&weights).zip(&resid) {
                grad[c1] += wi * wi * jrow[c1] * ri;
            }
            // Only the upper triangle is computed; the matrix is symmetric.
            for c2 in c1..n_params {
                let dot = jac
                    .iter()
                    .zip(&weights)
                    .fold(0.0, |acc, (jrow, wi)| acc + wi * wi * jrow[c1] * jrow[c2]);
                normal[c1][c2] = dot;
                normal[c2][c1] = dot;
            }
        }

        // Try increasingly damped steps until one lowers the weighted cost.
        let mut accepted = false;
        for _ in 0..12 {
            let mut damped = normal.clone();
            for (d, drow) in damped.iter_mut().enumerate() {
                drow[d] += lambda * normal[d][d].max(1e-12);
            }
            let rhs: Vec<f64> = grad.iter().map(|&g| -g).collect();
            let Some(delta) = solve_dense(damped, rhs) else {
                lambda *= 4.0;
                continue;
            };

            let trial: Vec<f64> = theta.iter().zip(&delta).map(|(t, d)| t + d).collect();
            let improvement = predict(&trial).and_then(|trial_pred| {
                let trial_resid = residuals(&trial_pred, targets)?;
                let trial_cost = weighted_sq_sum(&trial_resid, &weights);
                (trial_cost < cost).then_some((trial_pred, trial_resid))
            });
            if let Some((trial_pred, trial_resid)) = improvement {
                theta = trial;
                pred = trial_pred;
                resid = trial_resid;
                lambda = (lambda * 0.5).max(1e-12);
                accepted = true;
                break;
            }
            lambda *= 4.0;
        }
        if !accepted {
            break;
        }
    }

    (tree_with_consts(tree, &theta), mse(&pred, targets))
}
/// Absolute tolerance used by `snap_candidates`: a rounded integer or
/// small-denominator fraction is only proposed when it lies strictly within
/// this distance of the constant.
const SNAP_CAND_EPS: f64 = 1e-3;
/// Proposes "nice" replacement values for the constant `c`.
///
/// Candidates are produced in this order, each value appearing at most once:
/// 1. the nearest integer, if within `SNAP_CAND_EPS` of `c` and no larger
///    than `1e6` in magnitude;
/// 2. the value reported by `snap_to_named_const`, if any;
/// 3. the nearest multiples of `1/2`, `1/3`, `1/4` and `1/6`, if within
///    `SNAP_CAND_EPS` of `c` and no larger than `12` in magnitude.
///
/// The rules frequently agree (a near-integer is also a near-multiple of
/// every listed fraction), so duplicates are filtered out to spare the caller
/// from re-evaluating the same candidate.
fn snap_candidates(c: f64) -> Vec<f64> {
    /// Appends `v` unless an equal value is already present.
    fn push_unique(out: &mut Vec<f64>, v: f64) {
        if !out.contains(&v) {
            out.push(v);
        }
    }

    let mut out = Vec::new();

    let r = c.round();
    if (c - r).abs() < SNAP_CAND_EPS && r.abs() <= 1e6 {
        push_unique(&mut out, r);
    }

    if let Some(nc) = snap_to_named_const(c) {
        push_unique(&mut out, nc.value());
    }

    for d in [2.0, 3.0, 4.0, 6.0] {
        let q = (c * d).round() / d;
        if (c - q).abs() < SNAP_CAND_EPS && q.abs() <= 12.0 {
            push_unique(&mut out, q);
        }
    }

    out
}
/// Greedily replaces constants of `tree` with nearby "nice" values (see
/// `snap_candidates`) as long as the fit on `ds` does not degrade too much.
///
/// Constants are visited in the order `collect_consts` reports them. For each
/// one, candidates are tried in order and the first whose MSE is at most
/// `base * (1 + rel_tol) + 1e-12` is kept, where `base` is the MSE of the
/// unmodified input tree; later constants are tested with earlier snaps
/// already applied. Candidates equal to the current value, and candidates
/// whose tree fails to evaluate, are skipped.
///
/// Returns the snapped tree and its MSE. If nothing was snapped (or the tree
/// has no constants) the original tree and `base` are returned; if the input
/// tree cannot be evaluated, the original tree and `f64::INFINITY`.
#[must_use]
pub fn snap_constants(tree: &EmlTree, ds: &DataSet, rel_tol: f64) -> (EmlTree, f64) {
    let mut theta = Vec::new();
    collect_consts(&tree.root, &mut theta);
    let y = &ds.y;

    let Ok(base_pred) = eval_tree(tree, &ds.x) else {
        return (tree.clone(), f64::INFINITY);
    };
    let base = mse(&base_pred, y);
    if theta.is_empty() {
        return (tree.clone(), base);
    }

    let tol = base * (1.0 + rel_tol) + 1e-12;
    // Most recently accepted tree with its MSE. Rejected candidates never
    // change `theta`, so this always corresponds to the current `theta` and
    // no final rebuild or re-evaluation is needed.
    let mut best: Option<(EmlTree, f64)> = None;

    for j in 0..theta.len() {
        let current = theta[j];
        for cv in snap_candidates(current) {
            if cv == current {
                continue;
            }
            // Try the candidate in place; restored below if rejected.
            theta[j] = cv;
            let trial = tree_with_consts(tree, &theta);
            let trial_mse = eval_tree(&trial, &ds.x).ok().map(|pred| mse(&pred, y));
            if let Some(m) = trial_mse.filter(|&m| m <= tol) {
                best = Some((trial, m));
                break;
            }
            theta[j] = current;
        }
    }

    best.unwrap_or_else(|| (tree.clone(), base))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds a single-feature dataset from paired sample slices.
    fn ds_from(xs: &[f64], ys: &[f64]) -> DataSet {
        let features = Array2::from_shape_vec((xs.len(), 1), xs.to_vec()).unwrap();
        let targets = Array1::from(ys.to_vec());
        DataSet::from_arrays(features, targets).unwrap()
    }

    /// Sample points `i * step` for every `i` produced by `indices`.
    fn grid(indices: impl Iterator<Item = i32>, step: f64) -> Vec<f64> {
        indices.map(|i| f64::from(i) * step).collect()
    }

    /// Targets `exp(x) - ln(c)` for every `x` in `xs`.
    fn exp_minus_ln(xs: &[f64], c: f64) -> Vec<f64> {
        xs.iter().map(|&x| x.exp() - c.ln()).collect()
    }

    /// The template `eml(x0, c)` used as a starting point by every test.
    fn eml_template(c: f64) -> EmlTree {
        EmlTree::eml(&EmlTree::var(0), &EmlTree::const_val(c))
    }

    /// First constant reported by `collect_consts` for `tree`.
    fn first_const(tree: &EmlTree) -> f64 {
        let mut found = Vec::new();
        collect_consts(&tree.root, &mut found);
        found[0]
    }

    #[test]
    fn lm_recovers_constant_precisely() {
        let true_c: f64 = 3.0;
        let xs = grid(1..=20, 0.1);
        let ys = exp_minus_ln(&xs, true_c);
        let ds = ds_from(&xs, &ys);

        let (refined, m) = polish_constants(&eml_template(1.0), &ds, 50);
        let c = first_const(&refined);

        assert!(
            (c - true_c).abs() < 1e-4,
            "c = {} (want {true_c}), mse = {m}",
            c
        );
        assert!(m < 1e-8, "mse not tight: {m}");
    }

    #[test]
    fn robust_polish_resists_outliers() {
        let true_c: f64 = 3.0;
        let xs = grid(1..=30, 0.1);
        let mut ys = exp_minus_ln(&xs, true_c);
        // Corrupt a handful of samples with large outliers.
        for k in [3usize, 11, 19, 27] {
            ys[k] += 50.0;
        }
        let ds = ds_from(&xs, &ys);
        let start = eml_template(1.0);

        let mse_c = first_const(&polish_constants(&start, &ds, 80).0);
        let hub_c = first_const(
            &polish_constants_robust(&start, &ds, 80, RobustLoss::Huber { delta: 0.2 }).0,
        );
        let trim_c = first_const(
            &polish_constants_robust(&start, &ds, 80, RobustLoss::Trimmed { alpha: 0.2 }).0,
        );

        let mse_err = (mse_c - true_c).abs();
        let hub_err = (hub_c - true_c).abs();
        let trim_err = (trim_c - true_c).abs();

        assert!(
            hub_err < mse_err,
            "Huber ({hub_c}) did not beat MSE ({mse_c})"
        );
        assert!(
            trim_err < mse_err,
            "Trimmed ({trim_c}) did not beat MSE ({mse_c})"
        );
        assert!(
            trim_err < 0.1,
            "Trimmed did not recover c: {trim_c} (want {true_c})"
        );
    }

    #[test]
    fn snaps_near_one_removes_ln_residue() {
        let xs = grid(0..20, 0.15);
        let ys: Vec<f64> = xs.iter().map(|&x| x.exp()).collect();
        let ds = ds_from(&xs, &ys);

        let (snapped, m) = snap_constants(&eml_template(1.000_000_043_8), &ds, 0.02);
        let c = first_const(&snapped);

        assert!(
            (c - 1.0).abs() < 1e-12,
            "did not snap to exactly 1: {}",
            c
        );
        assert!(m < 1e-20, "snapped mse should be machine-zero, got {m}");
    }

    #[test]
    fn snaps_near_integer_constant() {
        let xs = grid(1..=20, 0.1);
        let ys = exp_minus_ln(&xs, 2.0);
        let ds = ds_from(&xs, &ys);

        let (snapped, _) = snap_constants(&eml_template(2.000_000_3), &ds, 0.02);
        let c = first_const(&snapped);

        assert!((c - 2.0).abs() < 1e-12, "did not snap to 2: {}", c);
    }

    #[test]
    fn does_not_snap_a_genuinely_fractional_constant() {
        let xs = grid(1..=20, 0.1);
        let ys = exp_minus_ln(&xs, 1.37);
        let ds = ds_from(&xs, &ys);

        let (snapped, _) = snap_constants(&eml_template(1.37), &ds, 0.02);
        let c = first_const(&snapped);

        assert!((c - 1.37).abs() < 1e-9, "wrongly snapped 1.37 to {}", c);
    }

    #[test]
    fn snaps_to_pi() {
        let xs = grid(1..=20, 0.1);
        let ys = exp_minus_ln(&xs, std::f64::consts::PI);
        let ds = ds_from(&xs, &ys);

        // Polish first so the constant lands close enough to pi to be snapped.
        let (refined, _) = polish_constants(&eml_template(3.0), &ds, 50);
        let (snapped, _) = snap_constants(&refined, &ds, 0.05);
        let c = first_const(&snapped);

        assert!(
            (c - std::f64::consts::PI).abs() < 1e-12,
            "did not snap to pi: {}",
            c
        );
    }
}