1use crate::dataset::DataSet;
19use crate::fit::{collect_consts, mse, substitute_consts};
20use crate::forest::eval_tree;
21use oxieml::EmlTree;
22use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
23use scirs2_optimize::least_squares::{least_squares, Method as LsqMethod, Options as LsqOptions};
24use scirs2_optimize::unconstrained::{minimize_lbfgs, Options as UncOptions};
25
/// Stand-in cost for parameter vectors whose tree fails to evaluate or yields
/// non-finite values. It is used both as a per-sample residual (Levenberg–Marquardt)
/// and as a whole sum of squared residuals (L-BFGS objective).
const PENALTY: f64 = 1.0e12;
29
/// Selects which SciRS2 optimizer `polish_constants_scirs` uses to refine constants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScirsPolish {
    /// Levenberg–Marquardt least squares on the per-sample residuals.
    Lm,
    /// L-BFGS on the sum of squared residuals, using a finite-difference gradient.
    Lbfgs,
}
38
39fn tree_with(template: &EmlTree, consts: &[f64]) -> EmlTree {
41 let mut idx = 0;
42 EmlTree::from_node(substitute_consts(&template.root, consts, &mut idx))
43}
44
45#[must_use]
51pub fn polish_constants_scirs(
52 tree: &EmlTree,
53 ds: &DataSet,
54 budget: usize,
55 method: ScirsPolish,
56) -> (EmlTree, f64) {
57 let mut theta0 = Vec::new();
58 collect_consts(&tree.root, &mut theta0);
59
60 let base_mse = eval_tree(tree, &ds.x).map_or(f64::INFINITY, |p| mse(&p, &ds.y));
61 if theta0.is_empty() {
62 return (tree.clone(), base_mse);
63 }
64
65 let refined = match method {
66 ScirsPolish::Lm => refine_lm(tree, ds, &theta0, budget),
67 ScirsPolish::Lbfgs => refine_lbfgs(tree, ds, &theta0, budget),
68 };
69 let Some(consts) = refined else {
70 return (tree.clone(), base_mse);
71 };
72 if consts.iter().any(|c| !c.is_finite()) {
73 return (tree.clone(), base_mse);
74 }
75
76 let candidate = tree_with(tree, &consts);
77 match eval_tree(&candidate, &ds.x) {
78 Ok(pred) => {
79 let m = mse(&pred, &ds.y);
80 if m.is_finite() && m <= base_mse {
82 (candidate, m)
83 } else {
84 (tree.clone(), base_mse)
85 }
86 }
87 Err(_) => (tree.clone(), base_mse),
88 }
89}
90
91fn refine_lm(tree: &EmlTree, ds: &DataSet, theta0: &[f64], max_nfev: usize) -> Option<Vec<f64>> {
93 let x = &ds.x;
94 let residual = |params: &[f64], data: &[f64]| -> Array1<f64> {
95 let t = tree_with(tree, params);
96 match eval_tree(&t, x) {
97 Ok(pred) => Array1::from_iter(pred.iter().zip(data.iter()).map(|(p, d)| {
98 let v = p - d;
99 if v.is_finite() {
100 v
101 } else {
102 PENALTY
103 }
104 })),
105 Err(_) => Array1::from_elem(data.len(), PENALTY),
106 }
107 };
108
109 let x0 = Array1::from_vec(theta0.to_vec());
110 let data = Array1::from_vec(ds.y.to_vec());
111 let opts = LsqOptions {
112 max_nfev: Some(max_nfev),
113 xtol: Some(1e-12),
114 ftol: Some(1e-12),
115 gtol: Some(1e-12),
116 ..Default::default()
117 };
118 let res = least_squares(
120 residual,
121 &x0,
122 LsqMethod::LevenbergMarquardt,
123 None::<fn(&[f64], &[f64]) -> Array2<f64>>,
124 &data,
125 Some(opts),
126 )
127 .ok()?;
128 Some(res.x.to_vec())
129}
130
131fn ssr(tree: &EmlTree, ds: &DataSet, theta: &[f64]) -> f64 {
134 let t = tree_with(tree, theta);
135 match eval_tree(&t, &ds.x) {
136 Ok(pred) => {
137 let s: f64 = pred
138 .iter()
139 .zip(ds.y.iter())
140 .map(|(p, d)| (p - d) * (p - d))
141 .sum();
142 if s.is_finite() {
143 s
144 } else {
145 PENALTY
146 }
147 }
148 Err(_) => PENALTY,
149 }
150}
151
152fn refine_lbfgs(tree: &EmlTree, ds: &DataSet, theta0: &[f64], max_iter: usize) -> Option<Vec<f64>> {
158 let objective = |params: &ArrayView1<f64>| -> f64 {
159 let theta: Vec<f64> = params.iter().copied().collect();
160 ssr(tree, ds, &theta)
161 };
162 let gradient = |params: &ArrayView1<f64>| -> Array1<f64> {
163 let theta: Vec<f64> = params.iter().copied().collect();
164 let f0 = ssr(tree, ds, &theta);
165 let mut g = Array1::zeros(theta.len());
166 for (k, gk) in g.iter_mut().enumerate() {
167 let h = 1e-6 * (theta[k].abs() + 1.0);
168 let mut tp = theta.clone();
169 tp[k] += h;
170 *gk = (ssr(tree, ds, &tp) - f0) / h;
171 }
172 g
173 };
174
175 let opts = UncOptions {
176 max_iter,
177 gtol: 1e-12,
178 ftol: 1e-14,
179 use_gpu: false,
180 ..Default::default()
181 };
182 let res = minimize_lbfgs(
183 objective,
184 Some(gradient),
185 Array1::from_vec(theta0.to_vec()),
186 &opts,
187 )
188 .ok()?;
189 Some(res.x.to_vec())
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fit::collect_consts;
    use scirs2_core::ndarray::Array2;

    /// Single-feature dataset built from parallel `xs` / `ys` slices.
    fn ds_from(xs: &[f64], ys: &[f64]) -> DataSet {
        let x = Array2::from_shape_vec((xs.len(), 1), xs.to_vec()).expect("shape");
        DataSet::from_arrays(x, Array1::from(ys.to_vec())).expect("dataset")
    }

    /// Samples of `exp(x) - ln(3)` at x = 0.1, 0.2, ..., 2.0. The tests below expect
    /// `eml(x, c)` to fit this data at `c = 3`.
    fn exp_minus_ln3() -> DataSet {
        let true_c: f64 = 3.0;
        let xs: Vec<f64> = (1..=20).map(|i| f64::from(i) * 0.1).collect();
        let ys: Vec<f64> = xs.iter().map(|&x| x.exp() - true_c.ln()).collect();
        ds_from(&xs, &ys)
    }

    /// MSE of `tree` on `ds`, scored the same way `polish_constants_scirs` scores its input.
    fn baseline_mse(tree: &EmlTree, ds: &DataSet) -> f64 {
        eval_tree(tree, &ds.x).map_or(f64::INFINITY, |p| mse(&p, &ds.y))
    }

    /// Polish `eml(x, start_c)` against [`exp_minus_ln3`].
    ///
    /// Returns the resulting constant and the MSE.
    fn recovered_c(method: ScirsPolish, start_c: f64) -> (f64, f64) {
        let ds = exp_minus_ln3();
        let start = EmlTree::eml(&EmlTree::var(0), &EmlTree::const_val(start_c));
        let (refined, m) = polish_constants_scirs(&start, &ds, 200, method);
        let mut consts = Vec::new();
        collect_consts(&refined.root, &mut consts);
        (consts[0], m)
    }

    #[test]
    fn scirs_lm_recovers_constant() {
        let (c, m) = recovered_c(ScirsPolish::Lm, 1.0);
        assert!((c - 3.0).abs() < 1e-3, "LM c = {c}, mse = {m}");
        assert!(m < 1e-6, "LM mse not tight: {m}");
    }

    #[test]
    fn scirs_lbfgs_polishes_a_coarse_fit() {
        let (c, m) = recovered_c(ScirsPolish::Lbfgs, 2.7);
        assert!((c - 3.0).abs() < 1e-2, "L-BFGS c = {c}, mse = {m}");
        assert!(m < 1e-4, "L-BFGS mse not tight: {m}");
    }

    #[test]
    fn never_worsens_the_fit() {
        let xs: Vec<f64> = (1..=10).map(|i| f64::from(i) * 0.1).collect();
        let ys: Vec<f64> = xs.iter().map(|&x| x.exp()).collect();
        let ds = ds_from(&xs, &ys);
        let tree = EmlTree::eml(&EmlTree::var(0), &EmlTree::one());
        let (_, m) = polish_constants_scirs(&tree, &ds, 50, ScirsPolish::Lm);
        assert!(m < 1e-9, "constant-free exp recovered exactly: mse = {m}");
    }

    #[test]
    fn never_worsens_an_already_optimal_constant() {
        // Starting at the optimum with a tiny budget, each method may keep or
        // improve the fit, but must never return a larger MSE than the input's.
        let ds = exp_minus_ln3();
        let start = EmlTree::eml(&EmlTree::var(0), &EmlTree::const_val(3.0));
        let base = baseline_mse(&start, &ds);
        for method in [ScirsPolish::Lm, ScirsPolish::Lbfgs] {
            let (_, m) = polish_constants_scirs(&start, &ds, 5, method);
            assert!(m <= base, "{method:?} worsened the fit: {m} > {base}");
        }
    }
}