Skip to main content

phop_core/
affine.rs

//! Affine- and log-affine-leaf EML discovery (M6 root-cause fix, steps 1 + 2).
//!
//! The Feynman post-mortem showed baseline phop (depth-3, bare var/const leaves) recovers ~nothing:
//! physics laws need scalings, linear combinations, products and powers, which in pure EML cost depth
//! the bounded search can't afford. This engine makes those forms *primitive*:
//!
//! - [`ANode::Linear`] `= Σ aᵢ·xᵢ + b` — a full affine combination of the inputs (step 1).
//! - [`ANode::LogLinear`] `= Σ aᵢ·ln xᵢ + b` — whose `exp` is a **monomial** `e^b·∏ xᵢ^{aᵢ}`, so any
//!   product / ratio / power-law is `eml(LogLinear, 1)` at **depth 1** (step 2 — the multiplicative
//!   laws the affine-only leaf couldn't reach).
//!
//! Because each leaf is a full-width combination fitted by Levenberg–Marquardt, the structure search
//! is just over the `eml` skeleton + a *leaf type* per slot (`Const`/`Linear`/`LogLinear`) — no
//! per-variable enumeration explosion. It is its own engine (real-affine/-log leaves + guarded `eml`),
//! deliberately NOT folded into phop's EmlTree/autograd path: `a·x+b` (or `Σaᵢ ln xᵢ`) with negative
//! coefficients is not guarded real EML (it would need the complex branch).
17
18use crate::polish::solve_dense;
19use crate::rng::SplitMix64;
20use scirs2_core::ndarray::{Array1, Array2};
21
/// Symmetric clamp on `exp` arguments (matches [`crate::forest`]): [`eval`] restricts the left
/// argument of every `eml` node to `[-EXP_CLAMP, EXP_CLAMP]` before exponentiating.
const EXP_CLAMP: f64 = 50.0;
/// Lower clamp on `ln` arguments (matches [`crate::forest`]): [`eval`] raises every value passed to
/// `ln` to at least this floor.
const LN_EPS: f64 = 1e-12;
/// A coefficient is "active" (counts toward complexity) above this magnitude. The same threshold
/// decides which terms the text / LaTeX renderers print.
const ACTIVE_EPS: f64 = 1e-6;
28
/// An EML tree with rich leaves: constants, affine combinations, and log-affine (monomial) combinations.
///
/// Leaves carry the fitted parameters directly; the only internal node is [`ANode::Eml`].
#[derive(Clone, Debug)]
pub enum ANode {
    /// A free constant.
    Const(f64),
    /// Affine combination `Σ coeffs[i]·x_i + b`.
    Linear {
        /// Per-variable slopes (`coeffs[i]` multiplies input column `i`).
        coeffs: Vec<f64>,
        /// Intercept.
        b: f64,
    },
    /// Log-affine combination `Σ coeffs[i]·ln(x_i) + b`; `exp` of it is a monomial `e^b·∏ x_i^{coeffs[i]}`.
    LogLinear {
        /// Per-variable log-exponents (`coeffs[i]` multiplies `ln` of input column `i`).
        coeffs: Vec<f64>,
        /// Log-intercept.
        b: f64,
    },
    /// The EML primitive `exp(left) − ln(right)`.
    Eml(Box<ANode>, Box<ANode>),
}
51
52fn active(coeffs: &[f64]) -> usize {
53    coeffs.iter().filter(|c| c.abs() > ACTIVE_EPS).count()
54}
55
56impl ANode {
57    /// Complexity: structural nodes plus the number of *active* coefficients (so a monomial in 3
58    /// variables is "bigger" than a constant).
59    #[must_use]
60    pub fn nodes(&self) -> usize {
61        match self {
62            ANode::Const(_) => 1,
63            ANode::Linear { coeffs, .. } | ANode::LogLinear { coeffs, .. } => 1 + active(coeffs),
64            ANode::Eml(l, r) => 1 + l.nodes() + r.nodes(),
65        }
66    }
67
68    /// Maximum `eml` depth (a leaf has depth 0).
69    #[must_use]
70    pub fn depth(&self) -> usize {
71        match self {
72            ANode::Eml(l, r) => 1 + l.depth().max(r.depth()),
73            _ => 0,
74        }
75    }
76
77    /// A readable rendering.
78    #[must_use]
79    pub fn pretty(&self) -> String {
80        match self {
81            ANode::Const(c) => format!("{c:.4}"),
82            ANode::Linear { coeffs, b } => format!("({})", combo(coeffs, *b, "x")),
83            ANode::LogLinear { coeffs, b } => format!("({})", combo(coeffs, *b, "ln x")),
84            // eml(z, 1) = exp(z) − ln(1) = exp(z): fold the trivial denominator, and render a
85            // LogLinear numerator as the monomial it represents (e.g. x0·x1, x0^1.5).
86            ANode::Eml(l, r) => match const_value(r) {
87                Some(c) if (c - 1.0).abs() < 1e-6 => match l.as_ref() {
88                    ANode::LogLinear { coeffs, b } => monomial(coeffs, *b, false),
89                    _ => format!("exp({})", l.pretty()),
90                },
91                _ => format!("eml({}, {})", l.pretty(), r.pretty()),
92            },
93        }
94    }
95}
96
97/// If `node` is constant-valued — a `Const`, or a combination leaf whose coefficients are all
98/// inactive (so it reduces to its intercept) — return that constant; else `None`. Used to fold a
99/// trivial `eml(z, 1)` denominator for display.
100fn const_value(node: &ANode) -> Option<f64> {
101    match node {
102        ANode::Const(c) => Some(*c),
103        ANode::Linear { coeffs, b } | ANode::LogLinear { coeffs, b } => {
104            coeffs.iter().all(|c| c.abs() <= ACTIVE_EPS).then_some(*b)
105        }
106        ANode::Eml(_, _) => None,
107    }
108}
109
/// Format an exponent for display: values within `1e-9` of a whole number print as that integer,
/// everything else prints with three decimals.
fn fmt_exp(a: f64) -> String {
    let nearest = a.round();
    let is_integral = (a - nearest).abs() < 1e-9;
    if !is_integral {
        return format!("{a:.3}");
    }
    (nearest as i64).to_string()
}
118
119/// `exp(Σ aᵢ ln xᵢ + b) = exp(b)·∏ xᵢ^{aᵢ}` — the clean monomial form of `eml(LogLinear, 1)`. With
120/// `latex`, variables/exponents use subscript/superscript markup; otherwise an ASCII `x{i}^{a}` form.
121fn monomial(coeffs: &[f64], b: f64, latex: bool) -> String {
122    let mut parts: Vec<String> = coeffs
123        .iter()
124        .enumerate()
125        .filter(|(_, a)| a.abs() > ACTIVE_EPS)
126        .map(|(i, a)| match (latex, (a - 1.0).abs() < 1e-9) {
127            (true, true) => format!("x_{{{i}}}"),
128            (true, false) => format!("x_{{{i}}}^{{{}}}", fmt_exp(*a)),
129            (false, true) => format!("x{i}"),
130            (false, false) => format!("x{i}^{}", fmt_exp(*a)),
131        })
132        .collect();
133    if b.abs() > ACTIVE_EPS {
134        parts.push(if latex {
135            format!("e^{{{b:.3}}}")
136        } else {
137            format!("exp({b:.3})")
138        });
139    }
140    match (parts.is_empty(), latex) {
141        (true, _) => "1".to_string(),
142        (false, true) => parts.join(" \\cdot "),
143        (false, false) => parts.join("*"),
144    }
145}
146
147/// Format `Σ coeffs[i]·<sym>i + b`, keeping only active terms.
148fn combo(coeffs: &[f64], b: f64, sym: &str) -> String {
149    let mut parts: Vec<String> = coeffs
150        .iter()
151        .enumerate()
152        .filter(|(_, c)| c.abs() > ACTIVE_EPS)
153        .map(|(i, c)| format!("{c:.3}*{sym}{i}"))
154        .collect();
155    if b.abs() > ACTIVE_EPS || parts.is_empty() {
156        parts.push(format!("{b:.3}"));
157    }
158    parts.join(" + ")
159}
160
161/// Guarded forward evaluation over `x` (`[n_rows, n_vars]`).
162#[must_use]
163pub fn eval(node: &ANode, x: &Array2<f64>) -> Array1<f64> {
164    let n = x.nrows();
165    match node {
166        ANode::Const(c) => Array1::from_elem(n, *c),
167        ANode::Linear { coeffs, b } => {
168            let mut out = Array1::from_elem(n, *b);
169            for (j, &cf) in coeffs.iter().enumerate() {
170                if cf != 0.0 {
171                    for i in 0..n {
172                        out[i] += cf * x[[i, j]];
173                    }
174                }
175            }
176            out
177        }
178        ANode::LogLinear { coeffs, b } => {
179            let mut out = Array1::from_elem(n, *b);
180            for (j, &cf) in coeffs.iter().enumerate() {
181                if cf != 0.0 {
182                    for i in 0..n {
183                        out[i] += cf * x[[i, j]].max(LN_EPS).ln();
184                    }
185                }
186            }
187            out
188        }
189        ANode::Eml(l, r) => {
190            let la = eval(l, x);
191            let rb = eval(r, x);
192            let mut out = Array1::zeros(n);
193            for i in 0..n {
194                let ea = la[i].clamp(-EXP_CLAMP, EXP_CLAMP).exp();
195                let lb = rb[i].max(LN_EPS).ln();
196                out[i] = ea - lb;
197            }
198            out
199        }
200    }
201}
202
203/// Collect free parameters in pre-order (`coeffs…, b` per combination leaf; `c` per constant).
204fn collect(node: &ANode, out: &mut Vec<f64>) {
205    match node {
206        ANode::Const(c) => out.push(*c),
207        ANode::Linear { coeffs, b } | ANode::LogLinear { coeffs, b } => {
208            out.extend_from_slice(coeffs);
209            out.push(*b);
210        }
211        ANode::Eml(l, r) => {
212            collect(l, out);
213            collect(r, out);
214        }
215    }
216}
217
218/// Rebuild a tree with parameters taken from `p` in pre-order.
219fn apply(node: &ANode, p: &[f64], idx: &mut usize) -> ANode {
220    match node {
221        ANode::Const(_) => {
222            let c = p[*idx];
223            *idx += 1;
224            ANode::Const(c)
225        }
226        ANode::Linear { coeffs, .. } => {
227            let n = coeffs.len();
228            let cs = p[*idx..*idx + n].to_vec();
229            let b = p[*idx + n];
230            *idx += n + 1;
231            ANode::Linear { coeffs: cs, b }
232        }
233        ANode::LogLinear { coeffs, .. } => {
234            let n = coeffs.len();
235            let cs = p[*idx..*idx + n].to_vec();
236            let b = p[*idx + n];
237            *idx += n + 1;
238            ANode::LogLinear { coeffs: cs, b }
239        }
240        ANode::Eml(l, r) => ANode::Eml(Box::new(apply(l, p, idx)), Box::new(apply(r, p, idx))),
241    }
242}
243
244/// Reset every leaf's coefficients to `coeff_init` (and intercept/const to a neutral start) for a
245/// fresh Levenberg–Marquardt start.
246fn reinit(node: &ANode, coeff_init: f64) -> ANode {
247    match node {
248        ANode::Const(_) => ANode::Const(1.0),
249        ANode::Linear { coeffs, .. } => ANode::Linear {
250            coeffs: vec![coeff_init; coeffs.len()],
251            b: 0.0,
252        },
253        ANode::LogLinear { coeffs, .. } => ANode::LogLinear {
254            coeffs: vec![coeff_init; coeffs.len()],
255            b: 0.0,
256        },
257        ANode::Eml(l, r) => ANode::Eml(
258            Box::new(reinit(l, coeff_init)),
259            Box::new(reinit(r, coeff_init)),
260        ),
261    }
262}
263
264/// Mean-squared error, or `INFINITY` on any non-finite prediction.
265fn mse(pred: &Array1<f64>, y: &Array1<f64>) -> f64 {
266    let n = y.len().max(1) as f64;
267    let mut s = 0.0;
268    for (p, t) in pred.iter().zip(y.iter()) {
269        if !p.is_finite() {
270            return f64::INFINITY;
271        }
272        s += (p - t) * (p - t);
273    }
274    s / n
275}
276
/// Largest absolute distance between a fitted coefficient and the clean value it may snap to.
const SNAP_ABS: f64 = 0.03;
/// Minimum R² a snapped form has to keep to count as a *symbolic* recovery.
const SYMBOLIC_R2: f64 = 0.999;
/// Levenberg–Marquardt iterations spent re-fitting the still-free parameters after each snap.
const SNAP_REFIT_ITERS: usize = 60;

/// Round `v` to a nearby small rational `k/d` with `d ∈ {1,2,3,4,6}` and `|k/d| ≤ 12`, provided the
/// rational lies within [`SNAP_ABS`] of `v`. Denominators are tried in increasing order, so the
/// simplest match wins. Values closer to zero than [`SNAP_ABS`] snap to `0`; `None` means no
/// candidate was close enough.
fn snap_rational(v: f64) -> Option<f64> {
    if v.abs() < SNAP_ABS {
        return Some(0.0);
    }
    [1.0_f64, 2.0, 3.0, 4.0, 6.0]
        .iter()
        .map(|&d| (v * d).round() / d)
        .find(|&cand| cand.abs() <= 12.0 && (v - cand).abs() < SNAP_ABS)
}
299
300/// Snap a linear slope: a recognisable named constant (π, e, √2, …) or a small rational.
301fn snap_value(v: f64) -> Option<f64> {
302    oxieml::symreg::snap_to_named_const(v)
303        .map(|nc| nc.value())
304        .or_else(|| snap_rational(v))
305}
306
/// Per-parameter kind, parallel to [`collect`]'s pre-order: `Exp` = a log-linear exponent, `Lin` = a
/// linear slope (both define the *symbolic structure*); `Other` = an intercept/constant (a fitted
/// scale/offset, left untouched).
#[derive(Clone, Copy, PartialEq)]
enum Kind {
    /// A coefficient of a `LogLinear` leaf, i.e. an exponent of the monomial.
    Exp,
    /// A coefficient of a `Linear` leaf.
    Lin,
    /// A leaf intercept or a `Const` value.
    Other,
}
316
317fn tag(node: &ANode, out: &mut Vec<Kind>) {
318    match node {
319        ANode::Const(_) => out.push(Kind::Other),
320        ANode::Linear { coeffs, .. } => {
321            out.extend(std::iter::repeat_n(Kind::Lin, coeffs.len()));
322            out.push(Kind::Other);
323        }
324        ANode::LogLinear { coeffs, .. } => {
325            out.extend(std::iter::repeat_n(Kind::Exp, coeffs.len()));
326            out.push(Kind::Other);
327        }
328        ANode::Eml(l, r) => {
329            tag(l, out);
330            tag(r, out);
331        }
332    }
333}
334
335/// Snap residual: how far `v` is from its nearest clean target (smaller ⇒ snap with more
336/// confidence). `INFINITY` if no clean target is within tolerance, or the parameter is not structural.
337fn snap_residual(v: f64, k: Kind) -> f64 {
338    let target = match k {
339        Kind::Exp => snap_rational(v),
340        Kind::Lin => {
341            if v.abs() < SNAP_ABS {
342                Some(0.0)
343            } else {
344                snap_value(v)
345            }
346        }
347        Kind::Other => return f64::INFINITY,
348    };
349    target.map_or(f64::INFINITY, |t| (v - t).abs())
350}
351
352/// Iterative rational-rounding (AI-Feynman style): snap structural exponents/slopes to small
353/// rationals / named constants **one at a time** — most-confident first — re-fitting the remaining
354/// free parameters (intercepts, scales, not-yet-snapped coefficients) after each snap so they can
355/// absorb the perturbation. Returns the snapped tree only if **every** structural coefficient snaps
356/// while the form retains R² ≥ [`SYMBOLIC_R2`].
357///
358/// This is strictly stronger than snapping all coefficients at once with the intercepts frozen: the
359/// per-snap refit both *rescues* fits that landed just outside the snap window (coupling holds a clean
360/// exponent off its integer until its neighbours are pinned) and *rejects* snaps that genuinely break
361/// the fit (the R² gate runs after each refit).
362fn try_snap(tree: &ANode, x: &Array2<f64>, y: &Array1<f64>) -> Option<ANode> {
363    let mut theta = Vec::new();
364    collect(tree, &mut theta);
365    let mut kinds = Vec::new();
366    tag(tree, &mut kinds);
367
368    // Structural parameters only, ordered by snap confidence (closest to a clean target first).
369    let mut order: Vec<usize> = (0..theta.len())
370        .filter(|&i| kinds[i] != Kind::Other)
371        .collect();
372    if order.is_empty() {
373        return None;
374    }
375    order.sort_by(|&a, &b| {
376        snap_residual(theta[a], kinds[a])
377            .partial_cmp(&snap_residual(theta[b], kinds[b]))
378            .unwrap_or(std::cmp::Ordering::Equal)
379    });
380
381    let mut fixed = vec![false; theta.len()];
382    for i in order {
383        // Re-test eligibility on the CURRENT value: a prior refit may have pulled it onto (or off) a
384        // clean target.
385        let snapped_val = match kinds[i] {
386            Kind::Exp => snap_rational(theta[i])?,
387            Kind::Lin => {
388                if theta[i].abs() < SNAP_ABS {
389                    0.0
390                } else {
391                    snap_value(theta[i])?
392                }
393            }
394            Kind::Other => continue,
395        };
396        theta[i] = snapped_val;
397        fixed[i] = true;
398        // Re-fit everything still free with this coefficient pinned, then hold the pin exactly.
399        let (refit, _) = lm_fit_masked(tree, x, y, SNAP_REFIT_ITERS, &theta, &fixed);
400        theta = refit;
401        theta[i] = snapped_val;
402        let pred = {
403            let mut idx = 0;
404            eval(&apply(tree, &theta, &mut idx), x)
405        };
406        if r2(&pred, y) < SYMBOLIC_R2 {
407            return None;
408        }
409    }
410
411    // Best-effort second pass: snap the remaining "Other" constants (eml-argument constants and leaf
412    // intercepts) to clean values too — e.g. an `eml(·, 1.0000…)` denominator collapses to `eml(·, 1)`
413    // so the `ln(1)=0` residue vanishes. This does not affect the symbolic *criterion* (only structural
414    // Exp/Lin coeffs decide that); accept a snap only if R² is retained.
415    for i in 0..theta.len() {
416        if fixed[i] || kinds[i] != Kind::Other {
417            continue;
418        }
419        if let Some(cv) = snap_rational(theta[i]) {
420            if cv == theta[i] {
421                continue; // already exactly clean; a near value must still snap (removes ε residue)
422            }
423            let saved = theta[i];
424            theta[i] = cv;
425            let pred = {
426                let mut idx = 0;
427                eval(&apply(tree, &theta, &mut idx), x)
428            };
429            if r2(&pred, y) < SYMBOLIC_R2 {
430                theta[i] = saved; // snap would break the fit — leave the fitted value
431            }
432        }
433    }
434
435    let mut idx = 0;
436    Some(apply(tree, &theta, &mut idx))
437}
438
439/// Coefficient of determination of `pred` against `y`.
440#[must_use]
441pub fn r2(pred: &Array1<f64>, y: &Array1<f64>) -> f64 {
442    let mean = y.sum() / y.len().max(1) as f64;
443    let (mut sr, mut st) = (0.0, 0.0);
444    for (p, t) in pred.iter().zip(y.iter()) {
445        sr += (t - p) * (t - p);
446        st += (t - mean) * (t - mean);
447    }
448    if st == 0.0 {
449        return f64::NAN;
450    }
451    1.0 - sr / st
452}
453
/// Finite-difference Levenberg–Marquardt fit of the **free** parameters of `skel` — those whose index
/// is `false` in `fixed` — holding the rest at their values in `theta0`. Returns the full updated
/// parameter vector (pinned entries unchanged) and the achieved MSE. With an all-`false` mask this
/// fits every parameter (see [`lm_fit`]); with some entries pinned it is the constrained refit the
/// iterative snapper ([`try_snap`]) relies on.
///
/// `theta0` is laid out in [`collect`]'s pre-order and `fixed` is indexed in parallel, so it must
/// have an entry for every parameter (it is indexed at each `j < theta0.len()`).
///
/// Early exits: if the starting point produces a non-finite prediction the result is
/// `(theta0, INFINITY)`; if nothing is free the starting point and its MSE are returned as-is. The
/// outer loop runs at most `iters` times and stops early when a Jacobian probe evaluates to a
/// non-finite prediction or when no damped step lowers the MSE.
fn lm_fit_masked(
    skel: &ANode,
    x: &Array2<f64>,
    y: &Array1<f64>,
    iters: usize,
    theta0: &[f64],
    fixed: &[bool],
) -> (Vec<f64>, f64) {
    let mut theta = theta0.to_vec();
    // Indices (into the full parameter vector) of the parameters being optimised.
    let free: Vec<usize> = (0..theta.len()).filter(|&j| !fixed[j]).collect();
    let p = free.len();
    // Forward pass at a parameter vector; `None` if any prediction is non-finite.
    let eval_at = |th: &[f64]| -> Option<Array1<f64>> {
        let mut idx = 0;
        let pred = eval(&apply(skel, th, &mut idx), x);
        pred.iter().all(|v| v.is_finite()).then_some(pred)
    };
    let Some(mut pred) = eval_at(&theta) else {
        return (theta, f64::INFINITY);
    };
    let mut cost = mse(&pred, y);
    if p == 0 {
        return (theta, cost);
    }
    let n = y.len();
    // Damping factor: quartered-up on every rejected step, halved (floored at 1e-12) on acceptance.
    let mut lambda = 1e-2_f64;

    for _ in 0..iters {
        // Residuals at the current point.
        let r: Vec<f64> = pred.iter().zip(y.iter()).map(|(p, t)| p - t).collect();
        // Forward-difference Jacobian, one column per free parameter.
        let mut jac = vec![vec![0.0; p]; n];
        let mut ok = true;
        for (jc, &j) in free.iter().enumerate() {
            // Step scaled to the parameter's magnitude.
            let h = 1e-6 * (theta[j].abs() + 1.0);
            let mut th = theta.clone();
            th[j] += h;
            let Some(pj) = eval_at(&th) else {
                ok = false;
                break;
            };
            for i in 0..n {
                jac[i][jc] = (pj[i] - pred[i]) / h;
            }
        }
        if !ok {
            // A probe left the finite region: give up and keep the best point so far.
            break;
        }
        // Normal equations: `a = JᵀJ` (filled symmetrically) and `grad = Jᵀr`.
        let mut a = vec![vec![0.0; p]; p];
        let mut grad = vec![0.0; p];
        for col in 0..p {
            for (row, jr) in jac.iter().enumerate() {
                grad[col] += jr[col] * r[row];
            }
            for col2 in col..p {
                let s: f64 = jac.iter().map(|jr| jr[col] * jr[col2]).sum();
                a[col][col2] = s;
                a[col2][col] = s;
            }
        }
        // Try up to 12 damping levels for this linearisation; accept the first that lowers the MSE.
        let mut accepted = false;
        for _ in 0..12 {
            let mut ad = a.clone();
            for d in 0..p {
                // Damp each diagonal entry in proportion to itself, with a floor so a zero
                // diagonal still receives some damping.
                ad[d][d] += lambda * a[d][d].max(1e-12);
            }
            let rhs: Vec<f64> = grad.iter().map(|g| -g).collect();
            // No solution from `solve_dense` (presumably a singular system): damp harder and retry.
            let Some(delta) = solve_dense(ad, rhs) else {
                lambda *= 4.0;
                continue;
            };
            // Only free parameters move; pinned entries keep their value.
            let mut cand = theta.clone();
            for (jc, &j) in free.iter().enumerate() {
                cand[j] = theta[j] + delta[jc];
            }
            if let Some(pc) = eval_at(&cand) {
                let cc = mse(&pc, y);
                if cc < cost {
                    theta = cand;
                    pred = pc;
                    cost = cc;
                    lambda = (lambda * 0.5).max(1e-12);
                    accepted = true;
                    break;
                }
            }
            // Step rejected (non-finite or no improvement): increase damping.
            lambda *= 4.0;
        }
        if !accepted {
            break;
        }
    }
    (theta, cost)
}
550
551/// Finite-difference Levenberg–Marquardt fit of *all* of a skeleton's parameters. Returns fitted
552/// tree + MSE.
553fn lm_fit(skel: &ANode, x: &Array2<f64>, y: &Array1<f64>, iters: usize) -> (ANode, f64) {
554    let mut theta0 = Vec::new();
555    collect(skel, &mut theta0);
556    let fixed = vec![false; theta0.len()];
557    let (theta, cost) = lm_fit_masked(skel, x, y, iters, &theta0, &fixed);
558    let mut idx = 0;
559    (apply(skel, &theta, &mut idx), cost)
560}
561
562/// Multi-start fit: try coefficient inits `1.0` (favours monomials/products) and `0.0`, keep the best.
563fn lm_fit_best(skel: &ANode, x: &Array2<f64>, y: &Array1<f64>, iters: usize) -> (ANode, f64) {
564    let mut best = lm_fit(skel, x, y, iters);
565    let alt = lm_fit(&reinit(skel, 0.0), x, y, iters);
566    if alt.1 < best.1 {
567        best = alt;
568    }
569    best
570}
571
/// Shape of a binary tree: internal nodes with two children, and untyped placeholder leaves.
#[derive(Clone)]
enum Skel {
    Leaf,
    Node(Box<Skel>, Box<Skel>),
}

impl Skel {
    /// Number of leaves in the skeleton.
    fn leaves(&self) -> usize {
        let mut pending: Vec<&Skel> = vec![self];
        let mut count = 0;
        while let Some(current) = pending.pop() {
            match current {
                Skel::Leaf => count += 1,
                Skel::Node(l, r) => {
                    pending.push(&**l);
                    pending.push(&**r);
                }
            }
        }
        count
    }
}
587
588/// All binary-tree skeletons with `0..=max_internal` internal (`eml`) nodes.
589fn skeletons(max_internal: usize) -> Vec<Skel> {
590    let mut by_k: Vec<Vec<Skel>> = vec![vec![Skel::Leaf]];
591    for k in 1..=max_internal {
592        let mut here = Vec::new();
593        for i in 0..k {
594            for l in &by_k[i] {
595                for r in &by_k[k - 1 - i] {
596                    here.push(Skel::Node(Box::new(l.clone()), Box::new(r.clone())));
597                }
598            }
599        }
600        by_k.push(here);
601    }
602    by_k.into_iter().flatten().collect()
603}
604
605/// Materialize a skeleton from per-leaf *type* codes (0 = const, 1 = linear, 2 = log-linear).
606fn materialize(
607    skel: &Skel,
608    types: &[usize],
609    idx: &mut usize,
610    n_vars: usize,
611    coeff_init: f64,
612) -> ANode {
613    match skel {
614        Skel::Leaf => {
615            let t = types[*idx];
616            *idx += 1;
617            match t {
618                0 => ANode::Const(1.0),
619                1 => ANode::Linear {
620                    coeffs: vec![coeff_init; n_vars],
621                    b: 0.0,
622                },
623                _ => ANode::LogLinear {
624                    coeffs: vec![coeff_init; n_vars],
625                    b: 0.0,
626                },
627            }
628        }
629        Skel::Node(l, r) => ANode::Eml(
630            Box::new(materialize(l, types, idx, n_vars, coeff_init)),
631            Box::new(materialize(r, types, idx, n_vars, coeff_init)),
632        ),
633    }
634}
635
636/// A discovered law and its quality.
637#[derive(Clone, Debug)]
638pub struct AffineSolution {
639    /// The fitted tree (its forward IS the prediction; see [`Self::predict`]).
640    pub tree: ANode,
641    /// Mean-squared error on the fitting data.
642    pub mse: f64,
643    /// Coefficient of determination.
644    pub r2: f64,
645    /// Readable expression.
646    pub expr: String,
647    /// Complexity (nodes + active coefficients).
648    pub nodes: usize,
649    /// Tree depth.
650    pub depth: usize,
651    /// Whether the law is a **symbolic** recovery: every structural exponent/slope snaps to a small
652    /// rational / named constant while retaining R² ≥ 0.999 (set by `Self::with_snap`).
653    pub symbolic: bool,
654}
655
656impl AffineSolution {
657    fn from_tree(tree: ANode, x: &Array2<f64>, y: &Array1<f64>, mse: f64) -> Self {
658        let pred = eval(&tree, x);
659        Self {
660            r2: r2(&pred, y),
661            expr: tree.pretty(),
662            nodes: tree.nodes(),
663            depth: tree.depth(),
664            symbolic: false,
665            mse,
666            tree,
667        }
668    }
669
670    /// If the law snaps to a clean rational-exponent / linear form (keeping R² ≥ 0.999), replace the
671    /// tree with the snapped one and mark it `symbolic`; otherwise return it unchanged.
672    #[must_use]
673    fn with_snap(mut self, x: &Array2<f64>, y: &Array1<f64>) -> Self {
674        if let Some(snapped) = try_snap(&self.tree, x, y) {
675            let pred = eval(&snapped, x);
676            self.mse = mse(&pred, y);
677            self.r2 = r2(&pred, y);
678            self.expr = snapped.pretty();
679            self.nodes = snapped.nodes();
680            self.depth = snapped.depth();
681            self.tree = snapped;
682            self.symbolic = true;
683        }
684        self
685    }
686
687    /// Evaluate the discovered law on new data `x` (`[n_rows, n_vars]`).
688    #[must_use]
689    pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
690        eval(&self.tree, x)
691    }
692
693    /// LaTeX rendering of the law.
694    #[must_use]
695    pub fn latex(&self) -> String {
696        to_latex(&self.tree)
697    }
698}
699
700/// LaTeX for a tree (combination leaves render directly).
701#[must_use]
702pub fn to_latex(node: &ANode) -> String {
703    match node {
704        ANode::Const(c) => format!("{c:.4}"),
705        ANode::Linear { coeffs, b } => combo_latex(coeffs, *b, false),
706        ANode::LogLinear { coeffs, b } => combo_latex(coeffs, *b, true),
707        // eml(z, 1) = exp(z) − ln(1) = exp(z): fold the trivial denominator; render a LogLinear
708        // numerator as its monomial (x_0 x_1, x_0^{3/2}, …).
709        ANode::Eml(l, r) => match const_value(r) {
710            Some(c) if (c - 1.0).abs() < 1e-6 => match l.as_ref() {
711                ANode::LogLinear { coeffs, b } => monomial(coeffs, *b, true),
712                _ => format!("e^{{{}}}", to_latex(l)),
713            },
714            _ => format!(
715                "\\left(e^{{{}}} - \\ln\\left({}\\right)\\right)",
716                to_latex(l),
717                to_latex(r)
718            ),
719        },
720    }
721}
722
723fn combo_latex(coeffs: &[f64], b: f64, log: bool) -> String {
724    let mut parts: Vec<String> = coeffs
725        .iter()
726        .enumerate()
727        .filter(|(_, c)| c.abs() > ACTIVE_EPS)
728        .map(|(i, c)| {
729            if log {
730                format!("{c:.3}\\,\\ln x_{{{i}}}")
731            } else {
732                format!("{c:.3}\\,x_{{{i}}}")
733            }
734        })
735        .collect();
736    if b.abs() > ACTIVE_EPS || parts.is_empty() {
737        parts.push(format!("{b:.3}"));
738    }
739    parts.join(" + ")
740}
741
742/// Build the candidate pool: every `eml` skeleton (`0..=max_internal` nodes) × every per-leaf type
743/// assignment (const/linear/log-linear). Bounded by `cand_cap`.
744fn build_pool(n_vars: usize, max_internal: usize, cand_cap: usize) -> Vec<ANode> {
745    const RADIX: usize = 3; // const | linear | log-linear
746    const EXHAUSTIVE_MAX: u128 = 256;
747    const SAMPLES_PER_SKEL: usize = 200;
748
749    let mut rng = SplitMix64::new(0xA5F1_C0DE ^ n_vars as u64);
750    let mut pool: Vec<ANode> = Vec::new();
751
752    'outer: for skel in skeletons(max_internal) {
753        let leaves = skel.leaves();
754        let total = (RADIX as u128)
755            .checked_pow(leaves as u32)
756            .unwrap_or(u128::MAX);
757        if total <= EXHAUSTIVE_MAX {
758            for code in 0..total {
759                if pool.len() >= cand_cap {
760                    break 'outer;
761                }
762                let mut types = vec![0usize; leaves];
763                let mut c = code;
764                for slot in types.iter_mut() {
765                    *slot = (c % RADIX as u128) as usize;
766                    c /= RADIX as u128;
767                }
768                let mut idx = 0;
769                pool.push(materialize(&skel, &types, &mut idx, n_vars, 1.0));
770            }
771        } else {
772            for _ in 0..SAMPLES_PER_SKEL {
773                if pool.len() >= cand_cap {
774                    break 'outer;
775                }
776                let types: Vec<usize> = (0..leaves).map(|_| rng.below(RADIX)).collect();
777                let mut idx = 0;
778                pool.push(materialize(&skel, &types, &mut idx, n_vars, 1.0));
779            }
780        }
781    }
782    pool
783}
784
785/// Quick-fit every candidate, multi-start-refit the most promising, return their fitted solutions.
786fn fit_pool(pool: &[ANode], x: &Array2<f64>, y: &Array1<f64>) -> Vec<AffineSolution> {
787    const QUICK: usize = 10;
788    const REFIT: usize = 50;
789    const REFIT_K: usize = 40;
790
791    let mut scored: Vec<(usize, f64)> = pool
792        .iter()
793        .enumerate()
794        .map(|(i, c)| (i, lm_fit(c, x, y, QUICK).1))
795        .filter(|(_, m)| m.is_finite())
796        .collect();
797    scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
798
799    scored
800        .iter()
801        .take(REFIT_K)
802        .filter_map(|(i, _)| {
803            let (fitted, m) = lm_fit_best(&pool[*i], x, y, REFIT);
804            m.is_finite()
805                .then(|| AffineSolution::from_tree(fitted, x, y, m))
806        })
807        .collect()
808}
809
810/// Discover a rich-leaf EML law for `(x, y)`. `max_internal` bounds `eml` nodes; `cand_cap` bounds
811/// candidates. Returns the best fit by MSE, or `None` if `x` is empty.
812#[must_use]
813pub fn discover_affine(
814    x: &Array2<f64>,
815    y: &Array1<f64>,
816    max_internal: usize,
817    cand_cap: usize,
818) -> Option<AffineSolution> {
819    if x.nrows() == 0 || x.ncols() == 0 {
820        return None;
821    }
822    let pool = build_pool(x.ncols(), max_internal, cand_cap);
823    fit_pool(&pool, x, y)
824        .into_iter()
825        .min_by(|a, b| {
826            a.mse
827                .partial_cmp(&b.mse)
828                .unwrap_or(std::cmp::Ordering::Equal)
829        })
830        .map(|s| s.with_snap(x, y))
831}
832
833/// Discover a rich-leaf EML **Pareto front** (non-dominated over complexity and MSE), sorted by
834/// increasing complexity.
835#[must_use]
836pub fn discover_affine_pareto(
837    x: &Array2<f64>,
838    y: &Array1<f64>,
839    max_internal: usize,
840    cand_cap: usize,
841) -> Vec<AffineSolution> {
842    if x.nrows() == 0 || x.ncols() == 0 {
843        return Vec::new();
844    }
845    let pool = build_pool(x.ncols(), max_internal, cand_cap);
846    let cands: Vec<AffineSolution> = fit_pool(&pool, x, y)
847        .into_iter()
848        .map(|s| s.with_snap(x, y))
849        .collect();
850
851    let mut front: Vec<AffineSolution> = Vec::new();
852    for c in cands {
853        let dominated = front
854            .iter()
855            .any(|s| s.nodes <= c.nodes && s.mse <= c.mse && (s.nodes < c.nodes || s.mse < c.mse));
856        if dominated {
857            continue;
858        }
859        front.retain(|s| {
860            !(c.nodes <= s.nodes && c.mse <= s.mse && (c.nodes < s.nodes || c.mse < s.mse))
861        });
862        front.push(c);
863    }
864    front.sort_by(|a, b| {
865        a.nodes.cmp(&b.nodes).then(
866            a.mse
867                .partial_cmp(&b.mse)
868                .unwrap_or(std::cmp::Ordering::Equal),
869        )
870    });
871    front
872}
873
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a deterministic `n`-row dataset with one column per `(lo, hi)` range in `cols`,
    /// and targets `y = f(row)`.
    ///
    /// Each column is sampled on an evenly spaced grid over `[lo, hi]`, but column `j` visits
    /// the grid in a different order (stride `2j + 1`, offset `7j`, modulo `n`) so that the
    /// features are not collinear — collinear columns make a multivariate linear fit's normal
    /// equations singular and don't test real recovery. The visiting order is a permutation of
    /// the grid whenever `2j + 1` is coprime with `n`, which holds for the one- and two-column
    /// datasets (`n` = 40 or 50) used in this module.
    fn ds(f: impl Fn(&[f64]) -> f64, cols: &[(f64, f64)], n: usize) -> (Array2<f64>, Array1<f64>) {
        let nv = cols.len();
        // Number of grid intervals. Clamped to 1 so a single-row dataset sits at `lo` instead of
        // evaluating 0/0 = NaN.
        let intervals = n.saturating_sub(1).max(1) as f64;
        let mut xv = Vec::with_capacity(n * nv);
        let mut yv = Vec::with_capacity(n);
        for i in 0..n {
            let row: Vec<f64> = cols
                .iter()
                .enumerate()
                .map(|(j, &(lo, hi))| {
                    let idx = (i * (2 * j + 1) + 7 * j) % n;
                    lo + (hi - lo) * (idx as f64) / intervals
                })
                .collect();
            yv.push(f(&row));
            xv.extend_from_slice(&row);
        }
        (
            Array2::from_shape_vec((n, nv), xv).expect("shape"),
            Array1::from(yv),
        )
    }

    /// Assert that the solution's R² strictly exceeds `min_r2`, printing the fitted expression
    /// on failure. `what` is the leading part of the failure message.
    #[track_caller]
    fn assert_r2_above(s: &AffineSolution, min_r2: f64, what: &str) {
        assert!(s.r2 > min_r2, "{}: r2={} expr={}", what, s.r2, s.expr);
    }

    #[test]
    fn recovers_linear_combination() {
        // y = 3·x0 − 2·x1 + 1: a single Linear leaf over both variables.
        let (x, y) = ds(
            |r| 3.0 * r[0] - 2.0 * r[1] + 1.0,
            &[(0.5, 5.0), (1.0, 4.0)],
            50,
        );
        let s = discover_affine(&x, &y, 3, 2000).expect("solution");
        assert_r2_above(&s, 0.9999, "linear combo not recovered");
    }

    #[test]
    fn recovers_scaled_exponential() {
        // y = e^{2x} = eml(Linear{2x}, 1).
        let (x, y) = ds(|r| (2.0 * r[0]).exp(), &[(0.0, 2.0)], 40);
        let s = discover_affine(&x, &y, 3, 2000).expect("solution");
        assert_r2_above(&s, 0.999, "scaled exp not recovered");
    }

    #[test]
    fn recovers_product() {
        // y = x0·x1 = eml(LogLinear{ln x0 + ln x1}, 1) — the multiplicative case (step 2).
        let (x, y) = ds(|r| r[0] * r[1], &[(0.5, 5.0), (0.5, 5.0)], 50);
        let s = discover_affine(&x, &y, 3, 2000).expect("solution");
        assert_r2_above(&s, 0.999, "product not recovered");
    }

    #[test]
    fn recovers_power_and_ratio() {
        // y = x0² / x1 = eml(LogLinear{2 ln x0 − ln x1}, 1) — a power-law ratio monomial.
        let (x, y) = ds(|r| r[0] * r[0] / r[1], &[(0.5, 5.0), (0.5, 5.0)], 50);
        let s = discover_affine(&x, &y, 3, 2000).expect("solution");
        assert_r2_above(&s, 0.999, "power/ratio not recovered");
    }

    #[test]
    fn symbolic_recovery_snaps_exponents() {
        // y = x0² / x1: the fitted exponents (≈2, ≈−1) must snap to exact rationals → symbolic.
        let (x, y) = ds(|r| r[0] * r[0] / r[1], &[(0.5, 5.0), (0.5, 5.0)], 50);
        let s = discover_affine(&x, &y, 3, 2000).expect("solution");
        assert!(
            s.symbolic,
            "x0^2/x1 should be a symbolic recovery: expr={}",
            s.expr
        );
        assert!(s.r2 >= 0.999, "snapped form lost accuracy: r2={}", s.r2);

        // e^{2x} is not a monomial, but its Linear leaf has a clean slope (2) that is also a
        // snapping candidate. Only the accuracy of the returned solution is asserted for this
        // case; the `symbolic` flag is deliberately not checked here.
        let (x2, y2) = ds(|r| (2.0 * r[0]).exp(), &[(0.0, 2.0)], 40);
        let s2 = discover_affine(&x2, &y2, 3, 2000).expect("solution");
        assert_r2_above(&s2, 0.999, "scaled exp not recovered");
    }

    #[test]
    fn pareto_front_is_non_dominated_and_sorted() {
        let (x, y) = ds(|r| r[0] * r[1], &[(0.5, 5.0), (0.5, 5.0)], 40);
        let front = discover_affine_pareto(&x, &y, 3, 2000);
        assert!(!front.is_empty(), "empty pareto front");

        // Sorted by increasing complexity.
        for w in front.windows(2) {
            assert!(w[0].nodes <= w[1].nodes, "front not sorted by complexity");
        }

        // Non-dominated: no member is at least as good as another on both objectives while being
        // strictly better on one. (A member never dominates itself, so i == j needs no skip.)
        for (i, a) in front.iter().enumerate() {
            for (j, b) in front.iter().enumerate() {
                let a_dominates_b = a.nodes <= b.nodes
                    && a.mse <= b.mse
                    && (a.nodes < b.nodes || a.mse < b.mse);
                assert!(
                    !a_dominates_b,
                    "front member {} dominates front member {}",
                    i,
                    j
                );
            }
        }

        assert!(
            front.iter().any(|s| s.r2 > 0.999),
            "no accurate solution on the front"
        );
    }
}