1use crate::polish::solve_dense;
19use crate::rng::SplitMix64;
20use scirs2_core::ndarray::{Array1, Array2};
21
/// Symmetric bound applied to the argument of `exp` inside an `Eml` node so the
/// exponential stays finite (`e^50` is roughly `5.2e21`).
const EXP_CLAMP: f64 = 50.0;
/// Floor applied to an input before taking its logarithm, keeping `ln` finite for
/// zero or negative values.
const LN_EPS: f64 = 1.0e-12;
/// Coefficients whose magnitude is at or below this threshold count as inactive:
/// they are neither counted in complexity nor printed.
const ACTIVE_EPS: f64 = 1.0e-6;
28
/// A node of the affine-leaf EML expression tree.
///
/// Leaves are constants or (log-)affine combinations of the input columns;
/// internal nodes apply the `eml` operator to two subtrees (see `eval`).
#[derive(Clone, Debug)]
pub enum ANode {
    /// Constant leaf, broadcast to every row.
    Const(f64),
    /// Affine leaf: `b + sum_j coeffs[j] * x_j`, where `x_j` is input column `j`.
    Linear {
        /// Coefficient for each input column.
        coeffs: Vec<f64>,
        /// Additive bias.
        b: f64,
    },
    /// Log-affine leaf: `b + sum_j coeffs[j] * ln(max(x_j, LN_EPS))`.
    LogLinear {
        /// Coefficient (exponent, in the monomial view) for each input column.
        coeffs: Vec<f64>,
        /// Additive bias.
        b: f64,
    },
    /// Internal node `exp(clamp(l, ±EXP_CLAMP)) - ln(max(r, LN_EPS))`.
    Eml(Box<ANode>, Box<ANode>),
}
51
52fn active(coeffs: &[f64]) -> usize {
53 coeffs.iter().filter(|c| c.abs() > ACTIVE_EPS).count()
54}
55
56impl ANode {
57 #[must_use]
60 pub fn nodes(&self) -> usize {
61 match self {
62 ANode::Const(_) => 1,
63 ANode::Linear { coeffs, .. } | ANode::LogLinear { coeffs, .. } => 1 + active(coeffs),
64 ANode::Eml(l, r) => 1 + l.nodes() + r.nodes(),
65 }
66 }
67
68 #[must_use]
70 pub fn depth(&self) -> usize {
71 match self {
72 ANode::Eml(l, r) => 1 + l.depth().max(r.depth()),
73 _ => 0,
74 }
75 }
76
77 #[must_use]
79 pub fn pretty(&self) -> String {
80 match self {
81 ANode::Const(c) => format!("{c:.4}"),
82 ANode::Linear { coeffs, b } => format!("({})", combo(coeffs, *b, "x")),
83 ANode::LogLinear { coeffs, b } => format!("({})", combo(coeffs, *b, "ln x")),
84 ANode::Eml(l, r) => match const_value(r) {
87 Some(c) if (c - 1.0).abs() < 1e-6 => match l.as_ref() {
88 ANode::LogLinear { coeffs, b } => monomial(coeffs, *b, false),
89 _ => format!("exp({})", l.pretty()),
90 },
91 _ => format!("eml({}, {})", l.pretty(), r.pretty()),
92 },
93 }
94 }
95}
96
97fn const_value(node: &ANode) -> Option<f64> {
101 match node {
102 ANode::Const(c) => Some(*c),
103 ANode::Linear { coeffs, b } | ANode::LogLinear { coeffs, b } => {
104 coeffs.iter().all(|c| c.abs() <= ACTIVE_EPS).then_some(*b)
105 }
106 ANode::Eml(_, _) => None,
107 }
108}
109
/// Formats an exponent as a bare integer when it is within `1e-9` of one,
/// otherwise with three decimals.
fn fmt_exp(a: f64) -> String {
    let nearest = a.round();
    if (a - nearest).abs() < 1e-9 {
        (nearest as i64).to_string()
    } else {
        format!("{a:.3}")
    }
}
118
119fn monomial(coeffs: &[f64], b: f64, latex: bool) -> String {
122 let mut parts: Vec<String> = coeffs
123 .iter()
124 .enumerate()
125 .filter(|(_, a)| a.abs() > ACTIVE_EPS)
126 .map(|(i, a)| match (latex, (a - 1.0).abs() < 1e-9) {
127 (true, true) => format!("x_{{{i}}}"),
128 (true, false) => format!("x_{{{i}}}^{{{}}}", fmt_exp(*a)),
129 (false, true) => format!("x{i}"),
130 (false, false) => format!("x{i}^{}", fmt_exp(*a)),
131 })
132 .collect();
133 if b.abs() > ACTIVE_EPS {
134 parts.push(if latex {
135 format!("e^{{{b:.3}}}")
136 } else {
137 format!("exp({b:.3})")
138 });
139 }
140 match (parts.is_empty(), latex) {
141 (true, _) => "1".to_string(),
142 (false, true) => parts.join(" \\cdot "),
143 (false, false) => parts.join("*"),
144 }
145}
146
147fn combo(coeffs: &[f64], b: f64, sym: &str) -> String {
149 let mut parts: Vec<String> = coeffs
150 .iter()
151 .enumerate()
152 .filter(|(_, c)| c.abs() > ACTIVE_EPS)
153 .map(|(i, c)| format!("{c:.3}*{sym}{i}"))
154 .collect();
155 if b.abs() > ACTIVE_EPS || parts.is_empty() {
156 parts.push(format!("{b:.3}"));
157 }
158 parts.join(" + ")
159}
160
161#[must_use]
163pub fn eval(node: &ANode, x: &Array2<f64>) -> Array1<f64> {
164 let n = x.nrows();
165 match node {
166 ANode::Const(c) => Array1::from_elem(n, *c),
167 ANode::Linear { coeffs, b } => {
168 let mut out = Array1::from_elem(n, *b);
169 for (j, &cf) in coeffs.iter().enumerate() {
170 if cf != 0.0 {
171 for i in 0..n {
172 out[i] += cf * x[[i, j]];
173 }
174 }
175 }
176 out
177 }
178 ANode::LogLinear { coeffs, b } => {
179 let mut out = Array1::from_elem(n, *b);
180 for (j, &cf) in coeffs.iter().enumerate() {
181 if cf != 0.0 {
182 for i in 0..n {
183 out[i] += cf * x[[i, j]].max(LN_EPS).ln();
184 }
185 }
186 }
187 out
188 }
189 ANode::Eml(l, r) => {
190 let la = eval(l, x);
191 let rb = eval(r, x);
192 let mut out = Array1::zeros(n);
193 for i in 0..n {
194 let ea = la[i].clamp(-EXP_CLAMP, EXP_CLAMP).exp();
195 let lb = rb[i].max(LN_EPS).ln();
196 out[i] = ea - lb;
197 }
198 out
199 }
200 }
201}
202
203fn collect(node: &ANode, out: &mut Vec<f64>) {
205 match node {
206 ANode::Const(c) => out.push(*c),
207 ANode::Linear { coeffs, b } | ANode::LogLinear { coeffs, b } => {
208 out.extend_from_slice(coeffs);
209 out.push(*b);
210 }
211 ANode::Eml(l, r) => {
212 collect(l, out);
213 collect(r, out);
214 }
215 }
216}
217
218fn apply(node: &ANode, p: &[f64], idx: &mut usize) -> ANode {
220 match node {
221 ANode::Const(_) => {
222 let c = p[*idx];
223 *idx += 1;
224 ANode::Const(c)
225 }
226 ANode::Linear { coeffs, .. } => {
227 let n = coeffs.len();
228 let cs = p[*idx..*idx + n].to_vec();
229 let b = p[*idx + n];
230 *idx += n + 1;
231 ANode::Linear { coeffs: cs, b }
232 }
233 ANode::LogLinear { coeffs, .. } => {
234 let n = coeffs.len();
235 let cs = p[*idx..*idx + n].to_vec();
236 let b = p[*idx + n];
237 *idx += n + 1;
238 ANode::LogLinear { coeffs: cs, b }
239 }
240 ANode::Eml(l, r) => ANode::Eml(Box::new(apply(l, p, idx)), Box::new(apply(r, p, idx))),
241 }
242}
243
244fn reinit(node: &ANode, coeff_init: f64) -> ANode {
247 match node {
248 ANode::Const(_) => ANode::Const(1.0),
249 ANode::Linear { coeffs, .. } => ANode::Linear {
250 coeffs: vec![coeff_init; coeffs.len()],
251 b: 0.0,
252 },
253 ANode::LogLinear { coeffs, .. } => ANode::LogLinear {
254 coeffs: vec![coeff_init; coeffs.len()],
255 b: 0.0,
256 },
257 ANode::Eml(l, r) => ANode::Eml(
258 Box::new(reinit(l, coeff_init)),
259 Box::new(reinit(r, coeff_init)),
260 ),
261 }
262}
263
264fn mse(pred: &Array1<f64>, y: &Array1<f64>) -> f64 {
266 let n = y.len().max(1) as f64;
267 let mut s = 0.0;
268 for (p, t) in pred.iter().zip(y.iter()) {
269 if !p.is_finite() {
270 return f64::INFINITY;
271 }
272 s += (p - t) * (p - t);
273 }
274 s / n
275}
276
/// Snap tolerance: a value snaps to a candidate only when strictly closer than
/// this, and values with smaller magnitude snap to zero.
const SNAP_ABS: f64 = 0.03;
/// Minimum R² a snapped tree must keep for a snap to be accepted.
const SYMBOLIC_R2: f64 = 0.999;
/// Iterations of the damped least-squares refit run after each individual snap.
const SNAP_REFIT_ITERS: usize = 60;
283
284fn snap_rational(v: f64) -> Option<f64> {
287 if v.abs() < SNAP_ABS {
288 return Some(0.0);
289 }
290 for d in [1.0, 2.0, 3.0, 4.0, 6.0] {
291 let k = (v * d).round();
292 let cand = k / d;
293 if cand.abs() <= 12.0 && (v - cand).abs() < SNAP_ABS {
294 return Some(cand);
295 }
296 }
297 None
298}
299
300fn snap_value(v: f64) -> Option<f64> {
302 oxieml::symreg::snap_to_named_const(v)
303 .map(|nc| nc.value())
304 .or_else(|| snap_rational(v))
305}
306
/// Snapping category of one entry of a flattened parameter vector (built by `tag`).
#[derive(Clone, Copy, PartialEq)]
enum Kind {
    /// Coefficient of a `LogLinear` leaf (an exponent in the monomial view);
    /// snapped to simple rationals only.
    Exp,
    /// Coefficient of a `Linear` leaf; snapped to zero, a named constant, or a
    /// simple rational.
    Lin,
    /// A `Const` value or a leaf bias; only snapped opportunistically after all
    /// coefficients have been snapped.
    Other,
}
316
317fn tag(node: &ANode, out: &mut Vec<Kind>) {
318 match node {
319 ANode::Const(_) => out.push(Kind::Other),
320 ANode::Linear { coeffs, .. } => {
321 out.extend(std::iter::repeat_n(Kind::Lin, coeffs.len()));
322 out.push(Kind::Other);
323 }
324 ANode::LogLinear { coeffs, .. } => {
325 out.extend(std::iter::repeat_n(Kind::Exp, coeffs.len()));
326 out.push(Kind::Other);
327 }
328 ANode::Eml(l, r) => {
329 tag(l, out);
330 tag(r, out);
331 }
332 }
333}
334
335fn snap_residual(v: f64, k: Kind) -> f64 {
338 let target = match k {
339 Kind::Exp => snap_rational(v),
340 Kind::Lin => {
341 if v.abs() < SNAP_ABS {
342 Some(0.0)
343 } else {
344 snap_value(v)
345 }
346 }
347 Kind::Other => return f64::INFINITY,
348 };
349 target.map_or(f64::INFINITY, |t| (v - t).abs())
350}
351
352fn try_snap(tree: &ANode, x: &Array2<f64>, y: &Array1<f64>) -> Option<ANode> {
363 let mut theta = Vec::new();
364 collect(tree, &mut theta);
365 let mut kinds = Vec::new();
366 tag(tree, &mut kinds);
367
368 let mut order: Vec<usize> = (0..theta.len())
370 .filter(|&i| kinds[i] != Kind::Other)
371 .collect();
372 if order.is_empty() {
373 return None;
374 }
375 order.sort_by(|&a, &b| {
376 snap_residual(theta[a], kinds[a])
377 .partial_cmp(&snap_residual(theta[b], kinds[b]))
378 .unwrap_or(std::cmp::Ordering::Equal)
379 });
380
381 let mut fixed = vec![false; theta.len()];
382 for i in order {
383 let snapped_val = match kinds[i] {
386 Kind::Exp => snap_rational(theta[i])?,
387 Kind::Lin => {
388 if theta[i].abs() < SNAP_ABS {
389 0.0
390 } else {
391 snap_value(theta[i])?
392 }
393 }
394 Kind::Other => continue,
395 };
396 theta[i] = snapped_val;
397 fixed[i] = true;
398 let (refit, _) = lm_fit_masked(tree, x, y, SNAP_REFIT_ITERS, &theta, &fixed);
400 theta = refit;
401 theta[i] = snapped_val;
402 let pred = {
403 let mut idx = 0;
404 eval(&apply(tree, &theta, &mut idx), x)
405 };
406 if r2(&pred, y) < SYMBOLIC_R2 {
407 return None;
408 }
409 }
410
411 for i in 0..theta.len() {
416 if fixed[i] || kinds[i] != Kind::Other {
417 continue;
418 }
419 if let Some(cv) = snap_rational(theta[i]) {
420 if cv == theta[i] {
421 continue; }
423 let saved = theta[i];
424 theta[i] = cv;
425 let pred = {
426 let mut idx = 0;
427 eval(&apply(tree, &theta, &mut idx), x)
428 };
429 if r2(&pred, y) < SYMBOLIC_R2 {
430 theta[i] = saved; }
432 }
433 }
434
435 let mut idx = 0;
436 Some(apply(tree, &theta, &mut idx))
437}
438
439#[must_use]
441pub fn r2(pred: &Array1<f64>, y: &Array1<f64>) -> f64 {
442 let mean = y.sum() / y.len().max(1) as f64;
443 let (mut sr, mut st) = (0.0, 0.0);
444 for (p, t) in pred.iter().zip(y.iter()) {
445 sr += (t - p) * (t - p);
446 st += (t - mean) * (t - mean);
447 }
448 if st == 0.0 {
449 return f64::NAN;
450 }
451 1.0 - sr / st
452}
453
/// Damped Gauss–Newton (Levenberg–Marquardt-style) least-squares fit of the
/// parameters of `skel`, starting from `theta0`. Entries with `fixed[j] == true`
/// are held constant.
///
/// The Jacobian of the predictions is approximated by forward differences. Each
/// outer iteration tries up to 12 damping values. A step is accepted only if it
/// lowers the MSE (then `lambda` is halved, floored at `1e-12`); otherwise `lambda`
/// is multiplied by 4. The loop stops early when no step is accepted or when a
/// perturbed evaluation is non-finite.
///
/// Returns the final parameter vector and its MSE. If `theta0` already evaluates
/// to non-finite predictions, it is returned unchanged with an MSE of `+inf`.
///
/// NOTE(review): the Jacobian and residual loops run over `0..y.len()` while the
/// predictions have `x.nrows()` entries. This assumes `y.len() == x.nrows()`; a
/// longer `y` would index out of bounds.
fn lm_fit_masked(
    skel: &ANode,
    x: &Array2<f64>,
    y: &Array1<f64>,
    iters: usize,
    theta0: &[f64],
    fixed: &[bool],
) -> (Vec<f64>, f64) {
    let mut theta = theta0.to_vec();
    // Indices the solver may move; the linear systems below are p x p.
    let free: Vec<usize> = (0..theta.len()).filter(|&j| !fixed[j]).collect();
    let p = free.len();
    // Evaluates the tree with parameters `th`; `None` if any prediction is non-finite.
    let eval_at = |th: &[f64]| -> Option<Array1<f64>> {
        let mut idx = 0;
        let pred = eval(&apply(skel, th, &mut idx), x);
        pred.iter().all(|v| v.is_finite()).then_some(pred)
    };
    let Some(mut pred) = eval_at(&theta) else {
        return (theta, f64::INFINITY);
    };
    let mut cost = mse(&pred, y);
    if p == 0 {
        // Nothing is free: report the cost of the given parameters.
        return (theta, cost);
    }
    let n = y.len();
    let mut lambda = 1e-2_f64;

    for _ in 0..iters {
        // Residuals r = pred - y.
        let r: Vec<f64> = pred.iter().zip(y.iter()).map(|(p, t)| p - t).collect();
        // Forward-difference Jacobian: jac[i][jc] ~ d pred_i / d theta[free[jc]].
        let mut jac = vec![vec![0.0; p]; n];
        let mut ok = true;
        for (jc, &j) in free.iter().enumerate() {
            // Step size scaled with the parameter's magnitude.
            let h = 1e-6 * (theta[j].abs() + 1.0);
            let mut th = theta.clone();
            th[j] += h;
            let Some(pj) = eval_at(&th) else {
                ok = false;
                break;
            };
            for i in 0..n {
                jac[i][jc] = (pj[i] - pred[i]) / h;
            }
        }
        if !ok {
            break;
        }
        // Normal equations: a = J^T J (symmetric, filled both halves), grad = J^T r.
        let mut a = vec![vec![0.0; p]; p];
        let mut grad = vec![0.0; p];
        for col in 0..p {
            for (row, jr) in jac.iter().enumerate() {
                grad[col] += jr[col] * r[row];
            }
            for col2 in col..p {
                let s: f64 = jac.iter().map(|jr| jr[col] * jr[col2]).sum();
                a[col][col2] = s;
                a[col2][col] = s;
            }
        }
        let mut accepted = false;
        for _ in 0..12 {
            // Damping proportional to the diagonal (floored at 1e-12).
            let mut ad = a.clone();
            for d in 0..p {
                ad[d][d] += lambda * a[d][d].max(1e-12);
            }
            // Solve (a + damping) * delta = -grad.
            let rhs: Vec<f64> = grad.iter().map(|g| -g).collect();
            let Some(delta) = solve_dense(ad, rhs) else {
                lambda *= 4.0;
                continue;
            };
            let mut cand = theta.clone();
            for (jc, &j) in free.iter().enumerate() {
                cand[j] = theta[j] + delta[jc];
            }
            if let Some(pc) = eval_at(&cand) {
                let cc = mse(&pc, y);
                if cc < cost {
                    theta = cand;
                    pred = pc;
                    cost = cc;
                    lambda = (lambda * 0.5).max(1e-12);
                    accepted = true;
                    break;
                }
            }
            // Non-finite or no improvement: damp harder and retry.
            lambda *= 4.0;
        }
        if !accepted {
            break;
        }
    }
    (theta, cost)
}
550
551fn lm_fit(skel: &ANode, x: &Array2<f64>, y: &Array1<f64>, iters: usize) -> (ANode, f64) {
554 let mut theta0 = Vec::new();
555 collect(skel, &mut theta0);
556 let fixed = vec![false; theta0.len()];
557 let (theta, cost) = lm_fit_masked(skel, x, y, iters, &theta0, &fixed);
558 let mut idx = 0;
559 (apply(skel, &theta, &mut idx), cost)
560}
561
562fn lm_fit_best(skel: &ANode, x: &Array2<f64>, y: &Array1<f64>, iters: usize) -> (ANode, f64) {
564 let mut best = lm_fit(skel, x, y, iters);
565 let alt = lm_fit(&reinit(skel, 0.0), x, y, iters);
566 if alt.1 < best.1 {
567 best = alt;
568 }
569 best
570}
571
/// Shape-only binary tree used to enumerate `Eml` topologies before leaf types
/// are chosen (see `materialize`).
#[derive(Clone)]
enum Skel {
    /// A leaf slot, later materialised as `Const`, `Linear` or `LogLinear`.
    Leaf,
    /// An internal node, materialised as `ANode::Eml`.
    Node(Box<Skel>, Box<Skel>),
}
578
579impl Skel {
580 fn leaves(&self) -> usize {
581 match self {
582 Skel::Leaf => 1,
583 Skel::Node(l, r) => l.leaves() + r.leaves(),
584 }
585 }
586}
587
588fn skeletons(max_internal: usize) -> Vec<Skel> {
590 let mut by_k: Vec<Vec<Skel>> = vec![vec![Skel::Leaf]];
591 for k in 1..=max_internal {
592 let mut here = Vec::new();
593 for i in 0..k {
594 for l in &by_k[i] {
595 for r in &by_k[k - 1 - i] {
596 here.push(Skel::Node(Box::new(l.clone()), Box::new(r.clone())));
597 }
598 }
599 }
600 by_k.push(here);
601 }
602 by_k.into_iter().flatten().collect()
603}
604
605fn materialize(
607 skel: &Skel,
608 types: &[usize],
609 idx: &mut usize,
610 n_vars: usize,
611 coeff_init: f64,
612) -> ANode {
613 match skel {
614 Skel::Leaf => {
615 let t = types[*idx];
616 *idx += 1;
617 match t {
618 0 => ANode::Const(1.0),
619 1 => ANode::Linear {
620 coeffs: vec![coeff_init; n_vars],
621 b: 0.0,
622 },
623 _ => ANode::LogLinear {
624 coeffs: vec![coeff_init; n_vars],
625 b: 0.0,
626 },
627 }
628 }
629 Skel::Node(l, r) => ANode::Eml(
630 Box::new(materialize(l, types, idx, n_vars, coeff_init)),
631 Box::new(materialize(r, types, idx, n_vars, coeff_init)),
632 ),
633 }
634}
635
/// A fitted expression tree together with its fit statistics on the fitting data.
#[derive(Clone, Debug)]
pub struct AffineSolution {
    /// The fitted (and, when `symbolic`, snapped) tree.
    pub tree: ANode,
    /// Mean squared error of `tree` on the fitting data.
    pub mse: f64,
    /// Coefficient of determination on the fitting data (NaN when the target has
    /// zero total variance).
    pub r2: f64,
    /// Plain-text rendering produced by [`ANode::pretty`].
    pub expr: String,
    /// Complexity as computed by [`ANode::nodes`].
    pub nodes: usize,
    /// Depth as computed by [`ANode::depth`].
    pub depth: usize,
    /// `true` when the tree's coefficients were successfully snapped to simple
    /// values while keeping the required accuracy.
    pub symbolic: bool,
}
655
656impl AffineSolution {
657 fn from_tree(tree: ANode, x: &Array2<f64>, y: &Array1<f64>, mse: f64) -> Self {
658 let pred = eval(&tree, x);
659 Self {
660 r2: r2(&pred, y),
661 expr: tree.pretty(),
662 nodes: tree.nodes(),
663 depth: tree.depth(),
664 symbolic: false,
665 mse,
666 tree,
667 }
668 }
669
670 #[must_use]
673 fn with_snap(mut self, x: &Array2<f64>, y: &Array1<f64>) -> Self {
674 if let Some(snapped) = try_snap(&self.tree, x, y) {
675 let pred = eval(&snapped, x);
676 self.mse = mse(&pred, y);
677 self.r2 = r2(&pred, y);
678 self.expr = snapped.pretty();
679 self.nodes = snapped.nodes();
680 self.depth = snapped.depth();
681 self.tree = snapped;
682 self.symbolic = true;
683 }
684 self
685 }
686
687 #[must_use]
689 pub fn predict(&self, x: &Array2<f64>) -> Array1<f64> {
690 eval(&self.tree, x)
691 }
692
693 #[must_use]
695 pub fn latex(&self) -> String {
696 to_latex(&self.tree)
697 }
698}
699
700#[must_use]
702pub fn to_latex(node: &ANode) -> String {
703 match node {
704 ANode::Const(c) => format!("{c:.4}"),
705 ANode::Linear { coeffs, b } => combo_latex(coeffs, *b, false),
706 ANode::LogLinear { coeffs, b } => combo_latex(coeffs, *b, true),
707 ANode::Eml(l, r) => match const_value(r) {
710 Some(c) if (c - 1.0).abs() < 1e-6 => match l.as_ref() {
711 ANode::LogLinear { coeffs, b } => monomial(coeffs, *b, true),
712 _ => format!("e^{{{}}}", to_latex(l)),
713 },
714 _ => format!(
715 "\\left(e^{{{}}} - \\ln\\left({}\\right)\\right)",
716 to_latex(l),
717 to_latex(r)
718 ),
719 },
720 }
721}
722
723fn combo_latex(coeffs: &[f64], b: f64, log: bool) -> String {
724 let mut parts: Vec<String> = coeffs
725 .iter()
726 .enumerate()
727 .filter(|(_, c)| c.abs() > ACTIVE_EPS)
728 .map(|(i, c)| {
729 if log {
730 format!("{c:.3}\\,\\ln x_{{{i}}}")
731 } else {
732 format!("{c:.3}\\,x_{{{i}}}")
733 }
734 })
735 .collect();
736 if b.abs() > ACTIVE_EPS || parts.is_empty() {
737 parts.push(format!("{b:.3}"));
738 }
739 parts.join(" + ")
740}
741
742fn build_pool(n_vars: usize, max_internal: usize, cand_cap: usize) -> Vec<ANode> {
745 const RADIX: usize = 3; const EXHAUSTIVE_MAX: u128 = 256;
747 const SAMPLES_PER_SKEL: usize = 200;
748
749 let mut rng = SplitMix64::new(0xA5F1_C0DE ^ n_vars as u64);
750 let mut pool: Vec<ANode> = Vec::new();
751
752 'outer: for skel in skeletons(max_internal) {
753 let leaves = skel.leaves();
754 let total = (RADIX as u128)
755 .checked_pow(leaves as u32)
756 .unwrap_or(u128::MAX);
757 if total <= EXHAUSTIVE_MAX {
758 for code in 0..total {
759 if pool.len() >= cand_cap {
760 break 'outer;
761 }
762 let mut types = vec![0usize; leaves];
763 let mut c = code;
764 for slot in types.iter_mut() {
765 *slot = (c % RADIX as u128) as usize;
766 c /= RADIX as u128;
767 }
768 let mut idx = 0;
769 pool.push(materialize(&skel, &types, &mut idx, n_vars, 1.0));
770 }
771 } else {
772 for _ in 0..SAMPLES_PER_SKEL {
773 if pool.len() >= cand_cap {
774 break 'outer;
775 }
776 let types: Vec<usize> = (0..leaves).map(|_| rng.below(RADIX)).collect();
777 let mut idx = 0;
778 pool.push(materialize(&skel, &types, &mut idx, n_vars, 1.0));
779 }
780 }
781 }
782 pool
783}
784
785fn fit_pool(pool: &[ANode], x: &Array2<f64>, y: &Array1<f64>) -> Vec<AffineSolution> {
787 const QUICK: usize = 10;
788 const REFIT: usize = 50;
789 const REFIT_K: usize = 40;
790
791 let mut scored: Vec<(usize, f64)> = pool
792 .iter()
793 .enumerate()
794 .map(|(i, c)| (i, lm_fit(c, x, y, QUICK).1))
795 .filter(|(_, m)| m.is_finite())
796 .collect();
797 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
798
799 scored
800 .iter()
801 .take(REFIT_K)
802 .filter_map(|(i, _)| {
803 let (fitted, m) = lm_fit_best(&pool[*i], x, y, REFIT);
804 m.is_finite()
805 .then(|| AffineSolution::from_tree(fitted, x, y, m))
806 })
807 .collect()
808}
809
810#[must_use]
813pub fn discover_affine(
814 x: &Array2<f64>,
815 y: &Array1<f64>,
816 max_internal: usize,
817 cand_cap: usize,
818) -> Option<AffineSolution> {
819 if x.nrows() == 0 || x.ncols() == 0 {
820 return None;
821 }
822 let pool = build_pool(x.ncols(), max_internal, cand_cap);
823 fit_pool(&pool, x, y)
824 .into_iter()
825 .min_by(|a, b| {
826 a.mse
827 .partial_cmp(&b.mse)
828 .unwrap_or(std::cmp::Ordering::Equal)
829 })
830 .map(|s| s.with_snap(x, y))
831}
832
833#[must_use]
836pub fn discover_affine_pareto(
837 x: &Array2<f64>,
838 y: &Array1<f64>,
839 max_internal: usize,
840 cand_cap: usize,
841) -> Vec<AffineSolution> {
842 if x.nrows() == 0 || x.ncols() == 0 {
843 return Vec::new();
844 }
845 let pool = build_pool(x.ncols(), max_internal, cand_cap);
846 let cands: Vec<AffineSolution> = fit_pool(&pool, x, y)
847 .into_iter()
848 .map(|s| s.with_snap(x, y))
849 .collect();
850
851 let mut front: Vec<AffineSolution> = Vec::new();
852 for c in cands {
853 let dominated = front
854 .iter()
855 .any(|s| s.nodes <= c.nodes && s.mse <= c.mse && (s.nodes < c.nodes || s.mse < c.mse));
856 if dominated {
857 continue;
858 }
859 front.retain(|s| {
860 !(c.nodes <= s.nodes && c.mse <= s.mse && (c.nodes < s.nodes || c.mse < s.mse))
861 });
862 front.push(c);
863 }
864 front.sort_by(|a, b| {
865 a.nodes.cmp(&b.nodes).then(
866 a.mse
867 .partial_cmp(&b.mse)
868 .unwrap_or(std::cmp::Ordering::Equal),
869 )
870 });
871 front
872}
873
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic `n`-row dataset without an RNG.
    ///
    /// Column `j` takes values in `[lo, hi]` from `cols[j]`, at grid position
    /// `(i * (2j + 1) + 7j) % n` out of `n - 1` steps. Different columns therefore
    /// follow different index orders. `y[i] = f(row i)`.
    fn ds(f: impl Fn(&[f64]) -> f64, cols: &[(f64, f64)], n: usize) -> (Array2<f64>, Array1<f64>) {
        let nv = cols.len();
        let mut xv = Vec::with_capacity(n * nv);
        let mut yv = Vec::with_capacity(n);
        for i in 0..n {
            let row: Vec<f64> = cols
                .iter()
                .enumerate()
                .map(|(j, (lo, hi))| {
                    // Per-column index stride so columns do not move in lockstep.
                    let idx = (i * (2 * j + 1) + 7 * j) % n;
                    lo + (hi - lo) * (idx as f64) / (n as f64 - 1.0)
                })
                .collect();
            yv.push(f(&row));
            xv.extend(&row);
        }
        (
            Array2::from_shape_vec((n, nv), xv).expect("shape"),
            Array1::from(yv),
        )
    }

    /// An affine target `3 x0 - 2 x1 + 1` should be fit almost exactly.
    #[test]
    fn recovers_linear_combination() {
        let (x, y) = ds(
            |r| 3.0 * r[0] - 2.0 * r[1] + 1.0,
            &[(0.5, 5.0), (1.0, 4.0)],
            50,
        );
        let s = discover_affine(&x, &y, 3, 2000).expect("solution");
        assert!(
            s.r2 > 0.9999,
            "linear combo not recovered: r2={} expr={}",
            s.r2,
            s.expr
        );
    }

    /// `exp(2 x0)` is reachable as `eml(Linear, 1)`.
    #[test]
    fn recovers_scaled_exponential() {
        let (x, y) = ds(|r| (2.0 * r[0]).exp(), &[(0.0, 2.0)], 40);
        let s = discover_affine(&x, &y, 3, 2000).expect("solution");
        assert!(
            s.r2 > 0.999,
            "scaled exp not recovered: r2={} expr={}",
            s.r2,
            s.expr
        );
    }

    /// A product of positive inputs is a monomial, i.e. `eml(LogLinear, 1)`.
    #[test]
    fn recovers_product() {
        let (x, y) = ds(|r| r[0] * r[1], &[(0.5, 5.0), (0.5, 5.0)], 50);
        let s = discover_affine(&x, &y, 3, 2000).expect("solution");
        assert!(
            s.r2 > 0.999,
            "product not recovered: r2={} expr={}",
            s.r2,
            s.expr
        );
    }

    /// `x0^2 / x1` is a monomial with exponents (2, -1).
    #[test]
    fn recovers_power_and_ratio() {
        let (x, y) = ds(|r| r[0] * r[0] / r[1], &[(0.5, 5.0), (0.5, 5.0)], 50);
        let s = discover_affine(&x, &y, 3, 2000).expect("solution");
        assert!(
            s.r2 > 0.999,
            "power/ratio not recovered: r2={} expr={}",
            s.r2,
            s.expr
        );
    }

    /// The exponents of `x0^2 / x1` should snap to integers without losing accuracy;
    /// the second dataset checks that the exponential target also stays accurate.
    #[test]
    fn symbolic_recovery_snaps_exponents() {
        let (x, y) = ds(|r| r[0] * r[0] / r[1], &[(0.5, 5.0), (0.5, 5.0)], 50);
        let s = discover_affine(&x, &y, 3, 2000).expect("solution");
        assert!(
            s.symbolic,
            "x0^2/x1 should be a symbolic recovery: expr={}",
            s.expr
        );
        assert!(s.r2 >= 0.999, "snapped form lost accuracy: r2={}", s.r2);

        let (x2, y2) = ds(|r| (2.0 * r[0]).exp(), &[(0.0, 2.0)], 40);
        let s2 = discover_affine(&x2, &y2, 3, 2000).expect("solution");
        assert!(s2.r2 > 0.999);
    }

    /// The front must be non-empty, sorted by complexity, and contain an accurate fit.
    #[test]
    fn pareto_front_is_non_dominated_and_sorted() {
        let (x, y) = ds(|r| r[0] * r[1], &[(0.5, 5.0), (0.5, 5.0)], 40);
        let front = discover_affine_pareto(&x, &y, 3, 2000);
        assert!(!front.is_empty(), "empty pareto front");
        for w in front.windows(2) {
            assert!(w[0].nodes <= w[1].nodes, "front not sorted by complexity");
        }
        assert!(
            front.iter().any(|s| s.r2 > 0.999),
            "no accurate solution on the front"
        );
    }
}