Skip to main content

phop_core/
optimize.rs

//! Library-backed constant refinement via `scirs2-optimize`.
//!
//! phop's default polish ([`crate::polish::polish_constants`]) is a small, panic-free, well-tested
//! Levenberg–Marquardt with named-constant snapping. This module offers two **alternative**
//! refinement backends drawn from the cool-japan `scirs2-optimize` crate, so the constant fit can be
//! swapped without touching the discovery pipeline:
//!
//! - [`ScirsPolish::Lm`] — `scirs2_optimize::least_squares` (Levenberg–Marquardt), and
//! - [`ScirsPolish::Lbfgs`] — `scirs2_optimize::unconstrained::minimize_lbfgs` minimizing the sum
//!   of squared residuals.
//!
//! They are offered as opt-in backends rather than as a drop-in replacement: phop's hand-rolled LM
//! makes no contiguity assumptions and never panics, whereas `scirs2`'s least-squares uses internal
//! `expect`s on `as_slice()`. We feed it only contiguous [`Array1`]s built with `from_vec`, and treat
//! any solver error, non-finite result, or non-improving fit as a no-op (returning the starting
//! constants), so this module upholds phop's no-panic / no-`unwrap` policy.
17
18use crate::dataset::DataSet;
19use crate::fit::{collect_consts, mse, substitute_consts};
20use crate::forest::eval_tree;
21use oxieml::EmlTree;
22use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
23use scirs2_optimize::least_squares::{least_squares, Method as LsqMethod, Options as LsqOptions};
24use scirs2_optimize::unconstrained::{minimize_lbfgs, Options as UncOptions};
25
/// Finite stand-in for a non-finite residual or objective value.
///
/// Handing the solvers this large-but-finite number instead of `NaN`/`inf` keeps their arithmetic
/// well-defined, so they keep iterating rather than being poisoned by a non-finite value.
const PENALTY: f64 = 1.0e12;
29
/// Selects the `scirs2-optimize` algorithm used to polish a tree's constants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScirsPolish {
    /// Levenberg–Marquardt nonlinear least squares via `scirs2_optimize::least_squares`.
    Lm,
    /// L-BFGS quasi-Newton minimization (`minimize_lbfgs`) of the sum of squared residuals.
    Lbfgs,
}
38
39/// Rebuild a tree from a flat constant vector (pre-order), mirroring [`crate::polish`].
40fn tree_with(template: &EmlTree, consts: &[f64]) -> EmlTree {
41    let mut idx = 0;
42    EmlTree::from_node(substitute_consts(&template.root, consts, &mut idx))
43}
44
45/// Refine the constant leaves of `tree` against `ds` using a `scirs2-optimize` backend.
46///
47/// `budget` bounds solver work (max function evaluations for LM, max iterations for L-BFGS). Returns
48/// the refined tree and its MSE; if the tree has no constants, or the solver fails / does not
49/// improve on the starting fit, the original tree (and its MSE) is returned unchanged.
50#[must_use]
51pub fn polish_constants_scirs(
52    tree: &EmlTree,
53    ds: &DataSet,
54    budget: usize,
55    method: ScirsPolish,
56) -> (EmlTree, f64) {
57    let mut theta0 = Vec::new();
58    collect_consts(&tree.root, &mut theta0);
59
60    let base_mse = eval_tree(tree, &ds.x).map_or(f64::INFINITY, |p| mse(&p, &ds.y));
61    if theta0.is_empty() {
62        return (tree.clone(), base_mse);
63    }
64
65    let refined = match method {
66        ScirsPolish::Lm => refine_lm(tree, ds, &theta0, budget),
67        ScirsPolish::Lbfgs => refine_lbfgs(tree, ds, &theta0, budget),
68    };
69    let Some(consts) = refined else {
70        return (tree.clone(), base_mse);
71    };
72    if consts.iter().any(|c| !c.is_finite()) {
73        return (tree.clone(), base_mse);
74    }
75
76    let candidate = tree_with(tree, &consts);
77    match eval_tree(&candidate, &ds.x) {
78        Ok(pred) => {
79            let m = mse(&pred, &ds.y);
80            // Accept only a finite, non-worse fit; otherwise keep the input.
81            if m.is_finite() && m <= base_mse {
82                (candidate, m)
83            } else {
84                (tree.clone(), base_mse)
85            }
86        }
87        Err(_) => (tree.clone(), base_mse),
88    }
89}
90
91/// Levenberg–Marquardt least-squares refinement. The residual vector is `pred(θ) − y`.
92fn refine_lm(tree: &EmlTree, ds: &DataSet, theta0: &[f64], max_nfev: usize) -> Option<Vec<f64>> {
93    let x = &ds.x;
94    let residual = |params: &[f64], data: &[f64]| -> Array1<f64> {
95        let t = tree_with(tree, params);
96        match eval_tree(&t, x) {
97            Ok(pred) => Array1::from_iter(pred.iter().zip(data.iter()).map(|(p, d)| {
98                let v = p - d;
99                if v.is_finite() {
100                    v
101                } else {
102                    PENALTY
103                }
104            })),
105            Err(_) => Array1::from_elem(data.len(), PENALTY),
106        }
107    };
108
109    let x0 = Array1::from_vec(theta0.to_vec());
110    let data = Array1::from_vec(ds.y.to_vec());
111    let opts = LsqOptions {
112        max_nfev: Some(max_nfev),
113        xtol: Some(1e-12),
114        ftol: Some(1e-12),
115        gtol: Some(1e-12),
116        ..Default::default()
117    };
118    // No analytic Jacobian: pass `None` (turbofish fixes the unused generic), LM finite-differences.
119    let res = least_squares(
120        residual,
121        &x0,
122        LsqMethod::LevenbergMarquardt,
123        None::<fn(&[f64], &[f64]) -> Array2<f64>>,
124        &data,
125        Some(opts),
126    )
127    .ok()?;
128    Some(res.x.to_vec())
129}
130
131/// Sum of squared residuals of `tree` (with constants `theta`) against `ds`, or [`PENALTY`] if the
132/// forward pass fails or is non-finite.
133fn ssr(tree: &EmlTree, ds: &DataSet, theta: &[f64]) -> f64 {
134    let t = tree_with(tree, theta);
135    match eval_tree(&t, &ds.x) {
136        Ok(pred) => {
137            let s: f64 = pred
138                .iter()
139                .zip(ds.y.iter())
140                .map(|(p, d)| (p - d) * (p - d))
141                .sum();
142            if s.is_finite() {
143                s
144            } else {
145                PENALTY
146            }
147        }
148        Err(_) => PENALTY,
149    }
150}
151
152/// L-BFGS refinement minimizing the sum of squared residuals over the constant vector.
153///
154/// We supply an explicit forward-difference gradient with a parameter-scaled step (`h = 1e-6·(|θ|+1)`,
155/// as in the hand-rolled LM) rather than relying on `scirs2`'s tiny default ε, which converges
156/// poorly on the ill-scaled constants that sit inside `exp`/`ln`.
157fn refine_lbfgs(tree: &EmlTree, ds: &DataSet, theta0: &[f64], max_iter: usize) -> Option<Vec<f64>> {
158    let objective = |params: &ArrayView1<f64>| -> f64 {
159        let theta: Vec<f64> = params.iter().copied().collect();
160        ssr(tree, ds, &theta)
161    };
162    let gradient = |params: &ArrayView1<f64>| -> Array1<f64> {
163        let theta: Vec<f64> = params.iter().copied().collect();
164        let f0 = ssr(tree, ds, &theta);
165        let mut g = Array1::zeros(theta.len());
166        for (k, gk) in g.iter_mut().enumerate() {
167            let h = 1e-6 * (theta[k].abs() + 1.0);
168            let mut tp = theta.clone();
169            tp[k] += h;
170            *gk = (ssr(tree, ds, &tp) - f0) / h;
171        }
172        g
173    };
174
175    let opts = UncOptions {
176        max_iter,
177        gtol: 1e-12,
178        ftol: 1e-14,
179        use_gpu: false,
180        ..Default::default()
181    };
182    let res = minimize_lbfgs(
183        objective,
184        Some(gradient),
185        Array1::from_vec(theta0.to_vec()),
186        &opts,
187    )
188    .ok()?;
189    Some(res.x.to_vec())
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fit::{collect_consts, mse};
    use crate::forest::eval_tree;
    use scirs2_core::ndarray::Array2;

    /// Build a one-feature dataset from parallel `x`/`y` slices.
    fn ds_from(xs: &[f64], ys: &[f64]) -> DataSet {
        let x = Array2::from_shape_vec((xs.len(), 1), xs.to_vec()).expect("shape");
        DataSet::from_arrays(x, Array1::from(ys.to_vec())).expect("dataset")
    }

    /// Polish `eml(x, c)` from `start_c` against `y = exp(x) - ln(3)` and report the fitted
    /// constant together with the returned MSE.
    fn recovered_c(method: ScirsPolish, start_c: f64) -> (f64, f64) {
        // y = exp(x) - ln(c), true c = 3.0.
        let true_c: f64 = 3.0;
        let xs: Vec<f64> = (1..=20).map(|i| f64::from(i) * 0.1).collect();
        let ys: Vec<f64> = xs.iter().map(|&x| x.exp() - true_c.ln()).collect();
        let ds = ds_from(&xs, &ys);
        let start = EmlTree::eml(&EmlTree::var(0), &EmlTree::const_val(start_c));
        let (refined, m) = polish_constants_scirs(&start, &ds, 200, method);
        let mut consts = Vec::new();
        collect_consts(&refined.root, &mut consts);
        (consts[0], m)
    }

    #[test]
    fn scirs_lm_recovers_constant() {
        // LM is robust enough to cold-start far from the truth.
        let (c, m) = recovered_c(ScirsPolish::Lm, 1.0);
        assert!((c - 3.0).abs() < 1e-3, "LM c = {c}, mse = {m}");
        assert!(m < 1e-6, "LM mse not tight: {m}");
    }

    #[test]
    fn scirs_lbfgs_polishes_a_coarse_fit() {
        // L-BFGS is offered as a *polish* — it refines an already-good (e.g. post-Adam) fit. From a
        // coarse constant it sharpens to the truth; it is not a cold-start solver.
        let (c, m) = recovered_c(ScirsPolish::Lbfgs, 2.7);
        assert!((c - 3.0).abs() < 1e-2, "L-BFGS c = {c}, mse = {m}");
        assert!(m < 1e-4, "L-BFGS mse not tight: {m}");
    }

    #[test]
    fn constant_free_tree_is_returned_unchanged() {
        // A constant-free tree is returned unchanged with its base MSE.
        let xs: Vec<f64> = (1..=10).map(|i| f64::from(i) * 0.1).collect();
        let ys: Vec<f64> = xs.iter().map(|&x| x.exp()).collect();
        let ds = ds_from(&xs, &ys);
        let tree = EmlTree::eml(&EmlTree::var(0), &EmlTree::one());
        let (_, m) = polish_constants_scirs(&tree, &ds, 50, ScirsPolish::Lm);
        assert!(m < 1e-9, "constant-free exp recovered exactly: mse = {m}");
    }

    #[test]
    fn never_worsens_the_fit() {
        // Starting from the exact constant, neither backend may report an MSE worse than the
        // starting one, whatever the solver does internally.
        let true_c: f64 = 3.0;
        let xs: Vec<f64> = (1..=10).map(|i| f64::from(i) * 0.1).collect();
        let ys: Vec<f64> = xs.iter().map(|&x| x.exp() - true_c.ln()).collect();
        let ds = ds_from(&xs, &ys);
        let exact = EmlTree::eml(&EmlTree::var(0), &EmlTree::const_val(true_c));
        let base = eval_tree(&exact, &ds.x).map_or(f64::INFINITY, |p| mse(&p, &ds.y));
        for method in [ScirsPolish::Lm, ScirsPolish::Lbfgs] {
            let (_, m) = polish_constants_scirs(&exact, &ds, 50, method);
            assert!(m <= base, "{method:?} worsened the fit: {m} > {base}");
        }
    }
}
239        let tree = EmlTree::eml(&EmlTree::var(0), &EmlTree::one());
240        let (_, m) = polish_constants_scirs(&tree, &ds, 50, ScirsPolish::Lm);
241        assert!(m < 1e-9, "constant-free exp recovered exactly: mse = {m}");
242    }
243}