use crate::error::EmlError;
use crate::eval::EvalCtx;
use crate::grad::ParameterizedEmlTree;
use crate::tree::EmlTree;
use crate::units::Units;
use rand::RngExt;
use rand::SeedableRng;
mod constants;
mod mcts;
mod numerics;
mod topology;
use constants::{bake_params_into_lowered, extract_named_constants};
use topology::{
compute_mse_direct, compute_mse_parameterized, topology_interval_feasible, try_integer_rounding,
};
pub use topology::{dedupe_by_semantics, enumerate_topologies};
/// Random number generator type used for parameter initialisation.
type Rng = rand::rngs::StdRng;

/// Builds the generator for one optimisation task.
///
/// With a master `seed`, the generator is seeded from `derive_seed(seed, salt)`,
/// so the same `(seed, salt)` pair always yields the same stream. Without a
/// seed, construction is delegated to `rand::make_rng`.
fn make_rng(seed: Option<u64>, salt: u64) -> Rng {
    if let Some(master) = seed {
        return Rng::seed_from_u64(derive_seed(master, salt));
    }
    rand::make_rng::<Rng>()
}
/// Mean Huber loss over `residuals`.
///
/// Each residual `r` contributes `0.5 * r^2` when `|r| <= delta` and
/// `delta * (|r| - 0.5 * delta)` otherwise. Returns `0.0` for an empty slice.
fn huber_loss(residuals: &[f64], delta: f64) -> f64 {
    let count = residuals.len();
    if count == 0 {
        return 0.0;
    }
    // Per-sample penalty: quadratic near zero, linear in the tails.
    let penalty = |r: f64| -> f64 {
        let magnitude = r.abs();
        if magnitude <= delta {
            0.5 * r * r
        } else {
            delta * (magnitude - 0.5 * delta)
        }
    };
    let total: f64 = residuals.iter().copied().map(penalty).sum();
    total / count as f64
}
/// Derivative of the Huber penalty with respect to the residual `r`.
///
/// Returns `r` itself when `|r| <= delta`, and `delta * signum(r)` otherwise.
fn huber_grad_factor(r: f64, delta: f64) -> f64 {
    if r.abs() <= delta {
        return r;
    }
    // Outside the quadratic zone the slope is constant, only the sign matters.
    delta * r.signum()
}
/// Mean of the smallest squared residuals, discarding roughly the largest
/// `alpha` fraction.
///
/// The number of squared residuals kept is `ceil((1 - alpha) * n)`, forced
/// into `1..=n`. Returns `0.0` for an empty slice.
fn trimmed_mse(residuals: &[f64], alpha: f64) -> f64 {
    let n = residuals.len();
    if n == 0 {
        return 0.0;
    }
    let mut squares: Vec<f64> = Vec::with_capacity(n);
    squares.extend(residuals.iter().map(|r| r * r));
    squares.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));

    // The float-to-usize cast saturates, so out-of-range `alpha` still lands in 1..=n.
    let requested = ((1.0 - alpha) * n as f64).ceil() as usize;
    let kept = requested.clamp(1, n);

    let total: f64 = squares.iter().take(kept).sum();
    total / kept as f64
}
/// Soft-trimmed gradient factor for residual `r`.
///
/// When `residuals` is empty or `alpha <= 0`, returns `r` unchanged.
/// Otherwise `r` is scaled by a weight in `[0, 1]` that shrinks as `|r|`
/// grows relative to `q`, the absolute residual found at rank
/// `round((1 - alpha) * (n - 1))` of the sorted magnitudes (floored at `1e-12`).
///
/// Note: every call sorts a copy of `residuals`.
fn trimmed_mse_grad_factor(r: f64, residuals: &[f64], alpha: f64) -> f64 {
    if alpha <= 0.0 || residuals.is_empty() {
        return r;
    }

    let mut magnitudes: Vec<f64> = residuals.iter().map(|x| x.abs()).collect();
    magnitudes.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));

    let last = magnitudes.len() - 1;
    let rank = ((1.0 - alpha) * last as f64).round() as usize;
    let scale = magnitudes[rank.min(last)].max(1e-12);

    const SHARPNESS: f64 = 3.0;
    let damping = (r.abs() / scale - (1.0 - alpha)).exp() * SHARPNESS.exp();
    let weight = 1.0 / (1.0 + damping);
    weight.clamp(0.0, 1.0) * r
}
/// Derives a per-topology seed from the master seed.
///
/// Both inputs are scrambled independently with the same bit-mixing
/// function and the results are combined with a wrapping addition. Because
/// the combination is commutative, `derive_seed(a, b) == derive_seed(b, a)`.
fn derive_seed(master: u64, topology_idx: u64) -> u64 {
    /// Adds a fixed increment, then applies two xor-shift/multiply rounds
    /// and a final xor-shift.
    fn scramble(input: u64) -> u64 {
        let stage0 = input.wrapping_add(0x9e37_79b9_7f4a_7c15);
        let stage1 = (stage0 ^ (stage0 >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
        let stage2 = (stage1 ^ (stage1 >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
        stage2 ^ (stage2 >> 31)
    }

    scramble(master).wrapping_add(scramble(topology_idx))
}
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Loss minimised while fitting the parameters of a candidate topology.
///
/// Derives `PartialEq` like the sibling configuration enums
/// (`MultiOutputStrategy`, `SymRegStrategy`) so configurations can be compared.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum SymRegLoss {
    /// Plain mean squared error (the default).
    #[default]
    Mse,
    /// Huber loss: quadratic for residuals with `|r| <= delta`, linear beyond.
    Huber {
        /// Residual magnitude at which the loss switches from quadratic to linear.
        delta: f64,
    },
    /// Trimmed mean squared error that discounts the largest residuals.
    TrimmedMse {
        /// Fraction of the largest squared residuals to discard.
        alpha: f64,
    },
}
/// How `SymRegEngine::discover_multi` handles several target columns.
#[derive(Debug, Clone, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum MultiOutputStrategy {
/// Run a separate `discover` call for every target column (the default).
#[default]
Independent,
}
/// Search strategy selected by `SymRegEngine::discover`.
#[derive(Debug, Clone, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum SymRegStrategy {
/// Fully optimise every enumerated topology (dispatches to `discover_exhaustive`).
#[default]
Exhaustive,
/// Rank topologies with a cheap surrogate fit and fully optimise only the
/// best `width` of them (dispatches to `discover_beam`).
Beam {
/// Number of topologies kept after the surrogate ranking (treated as at least 1).
width: usize,
},
/// Monte-Carlo tree search (dispatches to `discover_mcts`).
Mcts {
/// Forwarded to `mcts::run_mcts` as the iteration count.
iterations: usize,
/// Forwarded to `mcts::run_mcts` as the exploration parameter.
exploration: f64,
},
}
/// Configuration for `SymRegEngine`.
///
/// With the `serde` feature, missing fields fall back to `Default` on
/// deserialisation (`serde(default)`).
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(default))]
pub struct SymRegConfig {
/// Maximum depth passed to `enumerate_topologies`.
pub max_depth: usize,
/// Step size of the Adam parameter updates.
pub learning_rate: f64,
/// Optimisation stops as soon as the per-sample loss drops below this value.
pub tolerance: f64,
/// Maximum number of Adam iterations per restart.
pub max_iter: usize,
/// Weight of the tree size in `score = mse + complexity_penalty * complexity`.
pub complexity_penalty: f64,
/// Number of randomly initialised optimisation runs per topology.
pub num_restarts: usize,
/// Whether to try the parameters produced by `try_integer_rounding` and keep
/// them when the MSE is at most 1% worse.
pub integer_rounding: bool,
/// Number of cross-validation folds; `None` disables cross-validation.
pub cv_folds: Option<usize>,
/// Master seed for reproducible parameter initialisation; `None` leaves
/// generator construction to `rand::make_rng`.
pub seed: Option<u64>,
/// Loss minimised during parameter fitting.
pub loss: SymRegLoss,
/// When `Some(eps)`, `extract_named_constants` is run on the final expression
/// with `eps` as its tolerance argument.
pub constant_extraction: Option<f64>,
/// Whether exhaustive search filters topologies with `topology_interval_feasible`.
pub interval_pruning: bool,
/// Topologies shallower than this depth are never interval-pruned.
pub interval_pruning_depth_threshold: usize,
/// Strategy used by `discover_multi`.
pub multi_output_strategy: MultiOutputStrategy,
/// Search strategy used by `discover`.
pub strategy: SymRegStrategy,
/// `Some(w)` with `w >= 5` makes `discover_ode` use the Savitzky-Golay
/// derivative; any other value selects central differences.
pub ode_sg_window: Option<usize>,
/// Optional `(variable units, target units)` pair; when set, only topologies
/// whose `check_units` result equals the target units are kept.
pub unit_filter: Option<(Vec<Units>, Units)>,
}
impl Default for SymRegConfig {
    /// Baseline configuration: exhaustive search up to depth 4 with MSE loss,
    /// three restarts of at most 10 000 iterations, integer rounding enabled,
    /// and every optional feature (cross-validation, seeding, constant
    /// extraction, interval pruning, unit filtering) switched off.
    fn default() -> Self {
        Self {
            // Search space and strategy.
            max_depth: 4,
            strategy: SymRegStrategy::default(),
            multi_output_strategy: MultiOutputStrategy::default(),
            interval_pruning: false,
            interval_pruning_depth_threshold: 2,
            unit_filter: None,

            // Parameter optimisation.
            loss: SymRegLoss::Mse,
            learning_rate: 1e-3,
            tolerance: 1e-10,
            max_iter: 10_000,
            num_restarts: 3,
            seed: None,

            // Scoring and post-processing.
            complexity_penalty: 1e-4,
            integer_rounding: true,
            constant_extraction: None,
            cv_folds: None,
            ode_sg_window: None,
        }
    }
}
impl SymRegConfig {
    /// Cheap preset: depth 2, at most 200 iterations and 2 restarts; all other
    /// fields keep their default values.
    pub fn quick() -> Self {
        let base = Self::default();
        Self {
            max_depth: 2,
            max_iter: 200,
            num_restarts: 2,
            ..base
        }
    }

    /// Balanced preset; identical to [`SymRegConfig::default`].
    pub fn balanced() -> Self {
        <Self as Default>::default()
    }

    /// Thorough preset: depth 4, at most 20 000 iterations, 8 restarts and
    /// 5-fold cross-validation; all other fields keep their default values.
    pub fn exhaustive() -> Self {
        let base = Self::default();
        Self {
            max_depth: 4,
            max_iter: 20_000,
            num_restarts: 8,
            cv_folds: Some(5),
            ..base
        }
    }
}
/// One candidate formula produced by the search, with its fit statistics.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DiscoveredFormula {
/// Topology the formula was fitted on.
pub eml_tree: EmlTree,
/// Training loss of the fitted formula.
pub mse: f64,
/// Size of the topology (`EmlTree::size`).
pub complexity: usize,
/// `mse + complexity_penalty * complexity`; lower is better.
pub score: f64,
/// Human-readable rendering of the final expression.
pub pretty: String,
/// Fitted parameter values (empty for parameter-free topologies).
pub params: Vec<f64>,
/// Cross-validated MSE, set only when cross-validation is enabled.
pub cv_mse: Option<f64>,
}
impl DiscoveredFormula {
    /// Pareto dominance test on `(mse, complexity)`.
    ///
    /// Returns `true` when `self` is no worse than `other` on both criteria
    /// and strictly better on at least one.
    pub fn dominates(&self, other: &DiscoveredFormula) -> bool {
        let no_worse = self.mse <= other.mse && self.complexity <= other.complexity;
        let strictly_better = self.mse < other.mse || self.complexity < other.complexity;
        no_worse && strictly_better
    }

    /// Renders the lowered, simplified form of the tree as LaTeX.
    pub fn to_latex(&self) -> String {
        let lowered = self.eml_tree.lower();
        lowered.simplify().to_latex()
    }
}
/// Serialisation helpers, available with the `serde` feature.
#[cfg(feature = "serde")]
impl DiscoveredFormula {
/// Serialises the formula to a JSON string via `serde_json::to_string`.
pub fn to_json(&self) -> Result<String, serde_json::Error> {
serde_json::to_string(self)
}
/// Parses a formula from a JSON string via `serde_json::from_str`.
pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
serde_json::from_str(json)
}
/// Encodes the formula to bytes via `oxicode::serde::encode_serde`.
pub fn to_binary(&self) -> Result<Vec<u8>, oxicode::Error> {
oxicode::serde::encode_serde(self)
}
/// Decodes a formula from bytes via `oxicode::serde::decode_serde`.
pub fn from_binary(bytes: &[u8]) -> Result<Self, oxicode::Error> {
oxicode::serde::decode_serde(bytes)
}
}
/// TensorLogic adapters, available with the `tensorlogic` feature.
#[cfg(feature = "tensorlogic")]
impl DiscoveredFormula {
    /// Converts the lowered, simplified formula into a TensorLogic expression.
    pub fn to_tlexpr(&self) -> tensorlogic_ir::TLExpr {
        let simplified = self.eml_tree.lower().simplify();
        crate::tensorlogic::to_tlexpr(&simplified)
    }

    /// Wraps the formula expression in a `WeightedRule` carrying `weight` unchanged.
    pub fn to_tl_weighted_rule(&self, weight: f64) -> tensorlogic_ir::TLExpr {
        let rule = Box::new(self.to_tlexpr());
        tensorlogic_ir::TLExpr::WeightedRule { weight, rule }
    }

    /// Builds the weighted equation `target_var(target_var) = formula`.
    ///
    /// The left-hand side is a predicate named `target_var` whose single
    /// argument is a variable term of the same name.
    pub fn to_tl_weighted_equation(&self, target_var: &str, weight: f64) -> tensorlogic_ir::TLExpr {
        use tensorlogic_ir::{TLExpr, Term};

        let predicate = TLExpr::Pred {
            name: target_var.to_string(),
            args: vec![Term::var(target_var)],
        };
        let equation = TLExpr::Eq(Box::new(predicate), Box::new(self.to_tlexpr()));
        TLExpr::WeightedRule {
            weight,
            rule: Box::new(equation),
        }
    }
}
/// Tests for the TensorLogic adapters on `DiscoveredFormula`
/// (compiled only for test builds with the `tensorlogic` feature).
#[cfg(test)]
#[cfg(feature = "tensorlogic")]
mod tl_adapter_tests {
use super::*;
use crate::canonical::Canonical;
use crate::tensorlogic;
use tensorlogic_ir::{TLExpr, Term};
/// Builds a minimal formula fixture around the tree returned by `Canonical::nat(1)`.
fn make_formula() -> DiscoveredFormula {
let tree = Canonical::nat(1);
DiscoveredFormula {
eml_tree: tree,
mse: 0.0,
complexity: 1,
score: 0.0,
pretty: "1".to_string(),
params: vec![],
cv_mse: None,
}
}
/// `to_tlexpr` must equal converting the lowered, simplified tree directly.
#[test]
fn discoveredformula_to_tlexpr_matches_lowered_simplified_path() {
let f = make_formula();
let expected = tensorlogic::to_tlexpr(&f.eml_tree.lower().simplify());
assert_eq!(f.to_tlexpr(), expected);
}
/// The weight passed to `to_tl_weighted_rule` is stored unchanged.
#[test]
fn to_tl_weighted_rule_shape_carries_weight_verbatim() {
let f = make_formula();
let tl = f.to_tl_weighted_rule(0.42);
match tl {
TLExpr::WeightedRule { weight, .. } => {
assert!((weight - 0.42).abs() < f64::EPSILON);
}
other => panic!("expected WeightedRule, got {other:?}"),
}
}
/// The weighted equation is `WeightedRule { Eq(Pred(y; [var y]), formula) }`.
#[test]
fn to_tl_weighted_equation_shape_lhs_pred_eq_rhs_formula() {
let f = make_formula();
let tl = f.to_tl_weighted_equation("y", 1.0);
match tl {
TLExpr::WeightedRule { weight, rule } => {
assert!((weight - 1.0).abs() < f64::EPSILON);
match *rule {
TLExpr::Eq(lhs, rhs) => {
// Left-hand side: predicate named after the target with one variable argument.
match *lhs {
TLExpr::Pred { name, ref args } => {
assert_eq!(name, "y");
assert_eq!(args.len(), 1);
assert_eq!(args[0], Term::var("y"));
}
other => panic!("expected Pred on lhs, got {other:?}"),
}
// Right-hand side: the formula's own TensorLogic expression.
assert_eq!(*rhs, f.to_tlexpr());
}
other => panic!("expected Eq inside WeightedRule, got {other:?}"),
}
}
other => panic!("expected WeightedRule, got {other:?}"),
}
}
}
/// Returns the Pareto-optimal subset of `formulas` with respect to
/// `(mse, complexity)`.
///
/// A formula is kept when no other formula in the slice dominates it. The
/// result is sorted by ascending complexity, ties broken by ascending MSE
/// (incomparable MSE values are treated as equal).
pub fn pareto_front(formulas: &[DiscoveredFormula]) -> Vec<DiscoveredFormula> {
    let mut front: Vec<DiscoveredFormula> = Vec::new();
    for candidate in formulas {
        let dominated = formulas.iter().any(|rival| rival.dominates(candidate));
        if !dominated {
            front.push(candidate.clone());
        }
    }

    front.sort_by(|lhs, rhs| match lhs.complexity.cmp(&rhs.complexity) {
        std::cmp::Ordering::Equal => lhs
            .mse
            .partial_cmp(&rhs.mse)
            .unwrap_or(std::cmp::Ordering::Equal),
        unequal => unequal,
    });
    front
}
/// Symbolic-regression engine: enumerates candidate topologies, fits their
/// parameters and ranks the resulting formulas.
pub struct SymRegEngine {
// Settings applied to every search run by this engine.
config: SymRegConfig,
}
impl SymRegEngine {
/// Creates an engine that uses `config` for every search it runs.
pub fn new(config: SymRegConfig) -> Self {
Self { config }
}
/// Runs [`Self::discover`] and reduces its result to the Pareto front over
/// `(mse, complexity)`.
///
/// # Errors
/// Propagates any error returned by `discover`.
pub fn discover_pareto(
    &self,
    inputs: &[Vec<f64>],
    targets: &[f64],
    num_vars: usize,
) -> Result<Vec<DiscoveredFormula>, EmlError> {
    self.discover(inputs, targets, num_vars)
        .map(|candidates| pareto_front(&candidates))
}
/// Searches for formulas mapping `inputs` to `targets`, using the strategy
/// selected in the configuration.
///
/// # Errors
/// Returns whatever error the selected strategy reports.
pub fn discover(
    &self,
    inputs: &[Vec<f64>],
    targets: &[f64],
    num_vars: usize,
) -> Result<Vec<DiscoveredFormula>, EmlError> {
    match &self.config.strategy {
        SymRegStrategy::Exhaustive => self.discover_exhaustive(inputs, targets, num_vars),
        SymRegStrategy::Beam { width } => {
            self.discover_beam(inputs, targets, num_vars, *width)
        }
        SymRegStrategy::Mcts {
            iterations,
            exploration,
        } => self.discover_mcts(inputs, targets, num_vars, *iterations, *exploration),
    }
}
/// Exhaustive search: enumerates every topology up to `max_depth`, removes
/// semantic duplicates, applies the optional interval and unit filters, and
/// fully optimises everything that remains.
///
/// # Errors
/// * `EmlError::EmptyData` when `inputs` or `targets` is empty.
/// * `EmlError::DimensionMismatch` when their lengths differ.
pub fn discover_exhaustive(
    &self,
    inputs: &[Vec<f64>],
    targets: &[f64],
    num_vars: usize,
) -> Result<Vec<DiscoveredFormula>, EmlError> {
    if inputs.is_empty() || targets.is_empty() {
        return Err(EmlError::EmptyData);
    }
    let (n_rows, n_targets) = (inputs.len(), targets.len());
    if n_rows != n_targets {
        return Err(EmlError::DimensionMismatch(n_rows, n_targets));
    }

    let mut candidates =
        dedupe_by_semantics(enumerate_topologies(self.config.max_depth, num_vars));

    if self.config.interval_pruning {
        use crate::lower_interval::IntervalLO;

        // Observed [min, max] of every input variable; rows that lack the
        // column are skipped, and a variable with no finite range gets the
        // full interval.
        let mut input_intervals: Vec<IntervalLO> = Vec::with_capacity(num_vars);
        for var in 0..num_vars {
            let mut lo = f64::INFINITY;
            let mut hi = f64::NEG_INFINITY;
            for &value in inputs.iter().filter_map(|row| row.get(var)) {
                if value < lo {
                    lo = value;
                }
                if value > hi {
                    hi = value;
                }
            }
            let bounds = if lo.is_finite() && hi.is_finite() {
                IntervalLO::new(lo, hi)
            } else {
                IntervalLO::full()
            };
            input_intervals.push(bounds);
        }

        // Observed range of the targets.
        let mut target_lo = f64::INFINITY;
        let mut target_hi = f64::NEG_INFINITY;
        for &t in targets {
            target_lo = target_lo.min(t);
            target_hi = target_hi.max(t);
        }

        // Shallow topologies are always kept; deeper ones must pass the
        // interval feasibility check.
        let threshold = self.config.interval_pruning_depth_threshold;
        candidates.retain(|topo| {
            topo.depth() < threshold
                || topology_interval_feasible(topo, &input_intervals, target_lo, target_hi)
        });
    }

    if let Some((ref var_units, target_units)) = self.config.unit_filter {
        // Keep only topologies whose unit check succeeds with the target units.
        candidates.retain(|topo| {
            let lowered = topo.lower().simplify();
            matches!(lowered.check_units(var_units), Ok(u) if u == target_units)
        });
    }

    self.optimize_and_finalize(candidates, inputs, targets)
}
/// Beam search: ranks every (deduplicated, unit-filtered) topology with a
/// cheap surrogate fit, keeps the best `width` and fully optimises only those.
///
/// The surrogate uses a copy of the configuration with `max_iter` clamped to
/// `10..=50`, a single restart and no cross-validation. A `width` of 0 is
/// treated as 1. Interval pruning is not applied by this method.
///
/// # Errors
/// * `EmlError::EmptyData` when `inputs` or `targets` is empty.
/// * `EmlError::DimensionMismatch` when their lengths differ.
pub fn discover_beam(
    &self,
    inputs: &[Vec<f64>],
    targets: &[f64],
    num_vars: usize,
    width: usize,
) -> Result<Vec<DiscoveredFormula>, EmlError> {
    if inputs.is_empty() || targets.is_empty() {
        return Err(EmlError::EmptyData);
    }
    let (n_rows, n_targets) = (inputs.len(), targets.len());
    if n_rows != n_targets {
        return Err(EmlError::DimensionMismatch(n_rows, n_targets));
    }

    let mut topologies =
        dedupe_by_semantics(enumerate_topologies(self.config.max_depth, num_vars));
    if let Some((ref var_units, target_units)) = self.config.unit_filter {
        topologies.retain(|topo| {
            let lowered = topo.lower().simplify();
            matches!(lowered.check_units(var_units), Ok(u) if u == target_units)
        });
    }

    // Cheap engine used only to rank the candidates.
    let surrogate_engine = SymRegEngine::new(SymRegConfig {
        max_iter: self.config.max_iter.clamp(10, 50),
        num_restarts: 1,
        cv_folds: None,
        ..self.config.clone()
    });
    let quick_fit = |(idx, topo): (usize, &EmlTree)| -> Option<(usize, f64)> {
        let fitted = surrogate_engine.optimize_topology(topo, inputs, targets, idx)?;
        Some((idx, fitted.mse))
    };

    #[cfg(feature = "parallel")]
    let mut ranked: Vec<(usize, f64)> = topologies
        .par_iter()
        .enumerate()
        .filter_map(quick_fit)
        .collect();
    #[cfg(not(feature = "parallel"))]
    let mut ranked: Vec<(usize, f64)> = topologies
        .iter()
        .enumerate()
        .filter_map(quick_fit)
        .collect();

    // Best surrogate MSE first; incomparable values are treated as equal.
    ranked.sort_by(|lhs, rhs| {
        lhs.1
            .partial_cmp(&rhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked.truncate(width.max(1));

    // Restore enumeration order among the survivors before the full fit.
    let mut survivors: Vec<usize> = ranked.into_iter().map(|(idx, _)| idx).collect();
    survivors.sort_unstable();

    let mut beam: Vec<EmlTree> = Vec::with_capacity(survivors.len());
    for idx in survivors {
        if let Some(topo) = topologies.get(idx) {
            beam.push(topo.clone());
        }
    }
    self.optimize_and_finalize(beam, inputs, targets)
}
/// Monte-Carlo tree search strategy; forwards all arguments, plus the engine
/// itself, to `mcts::run_mcts`.
fn discover_mcts(
&self,
inputs: &[Vec<f64>],
targets: &[f64],
num_vars: usize,
iterations: usize,
exploration: f64,
) -> Result<Vec<DiscoveredFormula>, EmlError> {
mcts::run_mcts(self, inputs, targets, num_vars, iterations, exploration)
}
/// Borrows the engine's configuration.
pub(super) fn config(&self) -> &SymRegConfig {
&self.config
}
/// `pub(super)` wrapper that forwards to the private `optimize_topology`.
pub(super) fn optimize_topology_pub(
&self,
topology: &EmlTree,
inputs: &[Vec<f64>],
targets: &[f64],
topology_idx: usize,
) -> Option<DiscoveredFormula> {
self.optimize_topology(topology, inputs, targets, topology_idx)
}
/// `pub(super)` wrapper that forwards to the private `optimize_and_finalize`.
pub(super) fn optimize_and_finalize_pub(
&self,
topologies: Vec<EmlTree>,
inputs: &[Vec<f64>],
targets: &[f64],
) -> Result<Vec<DiscoveredFormula>, EmlError> {
self.optimize_and_finalize(topologies, inputs, targets)
}
/// Multi-output discovery: `targets` holds one column per output, each with
/// one value per input row. Returns one formula list per column, in order.
///
/// # Errors
/// * `EmlError::EmptyData` when `inputs` or `targets` is empty.
/// * `EmlError::DimensionMismatch` for the first column whose length differs
///   from the number of input rows.
/// * Any error from the per-column `discover` call (the first one encountered).
pub fn discover_multi(
    &self,
    inputs: &[Vec<f64>],
    targets: &[Vec<f64>],
    num_vars: usize,
) -> Result<Vec<Vec<DiscoveredFormula>>, EmlError> {
    if inputs.is_empty() || targets.is_empty() {
        return Err(EmlError::EmptyData);
    }
    let n_rows = inputs.len();
    if let Some(bad) = targets.iter().find(|col| col.len() != n_rows) {
        return Err(EmlError::DimensionMismatch(n_rows, bad.len()));
    }

    match self.config.multi_output_strategy {
        MultiOutputStrategy::Independent => {
            let mut per_output = Vec::with_capacity(targets.len());
            for column in targets {
                per_output.push(self.discover(inputs, column, num_vars)?);
            }
            Ok(per_output)
        }
    }
}
/// Discovers right-hand sides of an ODE system from sampled trajectories.
///
/// `trajectory` holds one time series per state variable, all of the same
/// length and sampled with step `dt`. Time derivatives are estimated
/// numerically; the interior samples (first and last time step dropped)
/// become the regression features and the matching derivative values the
/// targets of a `discover_multi` call.
///
/// The Savitzky-Golay estimator is used when `ode_sg_window` is `Some(w)`
/// with `w >= 5`, central differences otherwise. The window value only
/// selects the estimator; it is not passed on to it.
///
/// # Errors
/// * `EmlError::EmptyData` when `trajectory` is empty.
/// * `EmlError::DimensionMismatch` when the series differ in length or have
///   fewer than 3 samples.
/// * Any error from `discover_multi`.
pub fn discover_ode(
    &self,
    trajectory: &[Vec<f64>],
    dt: f64,
) -> Result<Vec<Vec<DiscoveredFormula>>, EmlError> {
    let n_timesteps = match trajectory.first() {
        Some(series) => series.len(),
        None => return Err(EmlError::EmptyData),
    };
    if let Some(bad) = trajectory.iter().find(|s| s.len() != n_timesteps) {
        return Err(EmlError::DimensionMismatch(n_timesteps, bad.len()));
    }
    if n_timesteps < 3 {
        return Err(EmlError::DimensionMismatch(3, n_timesteps));
    }
    let n_vars = trajectory.len();

    let use_savgol = matches!(self.config.ode_sg_window, Some(w) if w >= 5);
    let derivatives: Vec<Vec<f64>> = trajectory
        .iter()
        .map(|series| {
            if use_savgol {
                numerics::savitzky_golay_derivative(series, dt)
            } else {
                numerics::central_differences(series, dt)
            }
        })
        .collect();

    // One feature row per interior time step: the state vector at that step.
    let features: Vec<Vec<f64>> = (1..n_timesteps - 1)
        .map(|t| trajectory.iter().map(|series| series[t]).collect())
        .collect();

    // One target column per variable: its derivative at the interior steps.
    let mut targets: Vec<Vec<f64>> = Vec::with_capacity(n_vars);
    for dx in &derivatives {
        targets.push(dx[1..n_timesteps - 1].to_vec());
    }

    self.discover_multi(&features, &targets, n_vars)
}
/// Fits every topology, ranks the resulting formulas and optionally
/// cross-validates them.
///
/// Topologies whose fit fails (`optimize_topology` returns `None`) are
/// dropped. Formulas are sorted by ascending score, then complexity, then a
/// structural hash of the simplified expression as the final tie-break. When
/// `cv_folds` is set, each formula gets a `cv_mse` and the list is re-sorted
/// (stably) by that value.
fn optimize_and_finalize(
    &self,
    topologies: Vec<EmlTree>,
    inputs: &[Vec<f64>],
    targets: &[f64],
) -> Result<Vec<DiscoveredFormula>, EmlError> {
    #[cfg(feature = "parallel")]
    let mut formulas: Vec<DiscoveredFormula> = topologies
        .par_iter()
        .enumerate()
        .filter_map(|(i, topology)| self.optimize_topology(topology, inputs, targets, i))
        .collect();
    #[cfg(not(feature = "parallel"))]
    let mut formulas: Vec<DiscoveredFormula> = topologies
        .iter()
        .enumerate()
        .filter_map(|(i, topology)| self.optimize_topology(topology, inputs, targets, i))
        .collect();

    formulas.sort_by(|a, b| {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;
        a.score
            .partial_cmp(&b.score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.complexity.cmp(&b.complexity))
            .then_with(|| {
                // Only reached on exact score/complexity ties.
                let hash_of = |f: &DiscoveredFormula| {
                    let mut h = DefaultHasher::new();
                    f.eml_tree.lower().simplify().structural_hash(&mut h);
                    h.finish()
                };
                hash_of(a).cmp(&hash_of(b))
            })
    });

    if let Some(requested) = self.config.cv_folds {
        // `requested.clamp(2, inputs.len())` would panic for fewer than two
        // samples because `clamp` requires `min <= max`. Bounding from above
        // first and then from below yields the same fold count whenever
        // `inputs.len() >= 2`, and lets `k_fold_cv` apply its own fallback
        // for tiny data sets instead of crashing.
        let k = requested.min(inputs.len()).max(2);
        for formula in &mut formulas {
            formula.cv_mse =
                Some(self.k_fold_cv(&formula.eml_tree, &formula.params, inputs, targets, k));
        }
        formulas.sort_by(|a, b| {
            let a_score = a.cv_mse.unwrap_or(a.score);
            let b_score = b.cv_mse.unwrap_or(b.score);
            a_score
                .partial_cmp(&b_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
    }
    Ok(formulas)
}
/// Estimates the out-of-sample MSE of `topology` with `k` contiguous folds.
///
/// For each fold, a parameterised copy of the topology (initialised from
/// `params` when the lengths match) is refined with Adam on the other folds
/// for `(max_iter / k)` iterations, clamped to `1..=200`, and then evaluated
/// on the held-out rows. Non-finite predictions and failed evaluations are
/// skipped. The result is the mean held-out MSE over all folds that produced
/// at least one valid prediction.
///
/// Falls back to `compute_mse_direct` on the whole data set (or infinity if
/// that fails) when `n < 2`, `k <= 1`, or no fold produced a usable value.
fn k_fold_cv(
    &self,
    topology: &EmlTree,
    params: &[f64],
    inputs: &[Vec<f64>],
    targets: &[f64],
    k: usize,
) -> f64 {
    // Adam hyper-parameters.
    const BETA1: f64 = 0.9;
    const BETA2: f64 = 0.999;
    const EPSILON: f64 = 1e-8;

    let n = inputs.len();
    let whole_set_mse =
        || compute_mse_direct(topology, inputs, targets).unwrap_or(f64::INFINITY);
    if n < 2 || k <= 1 {
        return whole_set_mse();
    }

    let fold_iters = (self.config.max_iter / k).clamp(1, 200);
    let lr = self.config.learning_rate;

    let mut total_cv_mse = 0.0;
    let mut valid_folds = 0usize;

    for fold in 0..k {
        let start = (fold * n) / k;
        let end = ((fold + 1) * n) / k;
        if start >= end {
            continue;
        }

        // Training pairs: everything before and after the held-out window.
        let train: Vec<(&Vec<f64>, f64)> = inputs[..start]
            .iter()
            .zip(&targets[..start])
            .chain(inputs[end..].iter().zip(&targets[end..]))
            .map(|(row, &y)| (row, y))
            .collect();
        let held_out_rows = &inputs[start..end];
        let held_out_targets = &targets[start..end];
        if train.is_empty() || held_out_rows.is_empty() {
            continue;
        }

        let mut ptree = ParameterizedEmlTree::from_topology(topology, 1.0);
        if ptree.params.len() == params.len() {
            ptree.params.copy_from_slice(params);
        }
        let n_params = ptree.num_params();

        if n_params > 0 {
            let mut m = vec![0.0_f64; n_params];
            let mut v = vec![0.0_f64; n_params];
            for t in 1..=fold_iters {
                // Sum the gradients of every training sample with a finite loss.
                let mut grad_sum = vec![0.0_f64; n_params];
                let mut used = 0usize;
                for &(row, y) in &train {
                    let ctx = EvalCtx::new(row);
                    if let Ok((loss, grads)) = ptree.forward_backward(&ctx, y) {
                        if loss.is_finite() {
                            for (acc, g) in grad_sum.iter_mut().zip(&grads) {
                                if g.is_finite() {
                                    *acc += g;
                                }
                            }
                            used += 1;
                        }
                    }
                }
                if used == 0 {
                    break;
                }

                // Bias-corrected Adam step on the mean gradient.
                let scale = used as f64;
                let step = t as i32;
                let bias1 = 1.0 - BETA1.powi(step);
                let bias2 = 1.0 - BETA2.powi(step);
                for i in 0..n_params {
                    let g = grad_sum[i] / scale;
                    m[i] = BETA1 * m[i] + (1.0 - BETA1) * g;
                    v[i] = BETA2 * v[i] + (1.0 - BETA2) * g * g;
                    let m_hat = m[i] / bias1;
                    let v_hat = v[i] / bias2;
                    ptree.params[i] -= lr * m_hat / (v_hat.sqrt() + EPSILON);
                }
            }
        }

        // Squared errors on the held-out rows. Parameter-free topologies are
        // evaluated directly, all others through the fitted tree.
        let squared_errors: Vec<f64> = held_out_rows
            .iter()
            .zip(held_out_targets)
            .filter_map(|(row, &y)| {
                let ctx = EvalCtx::new(row);
                let prediction = if n_params == 0 {
                    topology.eval_real(&ctx).ok()
                } else {
                    ptree.forward(&ctx).ok()
                };
                prediction
                    .filter(|p| p.is_finite())
                    .map(|p| (p - y).powi(2))
            })
            .collect();

        if !squared_errors.is_empty() {
            let sum = squared_errors.iter().fold(0.0, |acc, e| acc + e);
            total_cv_mse += sum / squared_errors.len() as f64;
            valid_folds += 1;
        }
    }

    if valid_folds == 0 {
        whole_set_mse()
    } else {
        total_cv_mse / valid_folds as f64
    }
}
fn optimize_topology(
&self,
topology: &EmlTree,
inputs: &[Vec<f64>],
targets: &[f64],
topology_idx: usize,
) -> Option<DiscoveredFormula> {
let mut best_mse = f64::INFINITY;
let mut best_params = Vec::new();
let mut rng = make_rng(self.config.seed, topology_idx as u64);
for _ in 0..self.config.num_restarts {
let mut ptree = ParameterizedEmlTree::from_topology(topology, 1.0);
for p in &mut ptree.params {
*p = 1.0 + rng.random_range(-0.5..0.5);
}
let n_params = ptree.num_params();
if n_params == 0 {
let mse = compute_mse_direct(topology, inputs, targets);
if let Some(mse) = mse {
if mse < best_mse {
best_mse = mse;
best_params = vec![];
}
}
break;
}
let mut m = vec![0.0_f64; n_params]; let mut v = vec![0.0_f64; n_params]; let beta1 = 0.9;
let beta2 = 0.999;
let epsilon = 1e-8;
let lr = self.config.learning_rate;
let mut converged = false;
for t in 1..=self.config.max_iter {
let mut outputs_and_jacs: Vec<(f64, Vec<f64>)> = Vec::with_capacity(inputs.len());
let mut residuals: Vec<f64> = Vec::with_capacity(inputs.len());
for (input, &target) in inputs.iter().zip(targets) {
let ctx = EvalCtx::new(input);
match ptree.forward_with_jacobian(&ctx) {
Ok((out, jac)) if out.is_finite() => {
residuals.push(out - target);
outputs_and_jacs.push((out, jac));
}
_ => {}
}
}
let valid_count = residuals.len();
if valid_count == 0 {
break;
}
let (total_loss, total_grads) = match &self.config.loss {
SymRegLoss::Mse => {
let tloss: f64 = residuals.iter().map(|r| r * r).sum();
let mut tg = vec![0.0_f64; n_params];
for (r, (_, jac)) in residuals.iter().zip(&outputs_and_jacs) {
for (tg_i, &j) in tg.iter_mut().zip(jac.iter()) {
if j.is_finite() {
*tg_i += 2.0 * r * j;
}
}
}
(tloss, tg)
}
SymRegLoss::Huber { delta } => {
let d = *delta;
let tloss = huber_loss(&residuals, d) * valid_count as f64;
let mut tg = vec![0.0_f64; n_params];
for (r, (_, jac)) in residuals.iter().zip(&outputs_and_jacs) {
let gf = 2.0 * huber_grad_factor(*r, d);
for (tg_i, &j) in tg.iter_mut().zip(jac.iter()) {
if j.is_finite() {
*tg_i += gf * j;
}
}
}
(tloss, tg)
}
SymRegLoss::TrimmedMse { alpha } => {
let a = *alpha;
let tloss = trimmed_mse(&residuals, a) * valid_count as f64;
let mut tg = vec![0.0_f64; n_params];
for (r, (_, jac)) in residuals.iter().zip(&outputs_and_jacs) {
let gf = 2.0 * trimmed_mse_grad_factor(*r, &residuals, a);
for (tg_i, &j) in tg.iter_mut().zip(jac.iter()) {
if j.is_finite() {
*tg_i += gf * j;
}
}
}
(tloss, tg)
}
};
let n_f = valid_count as f64;
let mse = total_loss / n_f;
if mse < self.config.tolerance {
best_mse = mse;
best_params = ptree.params.clone();
converged = true;
break;
}
for i in 0..n_params {
let g = total_grads[i] / n_f;
m[i] = beta1 * m[i] + (1.0 - beta1) * g;
v[i] = beta2 * v[i] + (1.0 - beta2) * g * g;
let m_hat = m[i] / (1.0 - beta1.powi(t as i32));
let v_hat = v[i] / (1.0 - beta2.powi(t as i32));
ptree.params[i] -= lr * m_hat / (v_hat.sqrt() + epsilon);
}
if mse < best_mse {
best_mse = mse;
best_params = ptree.params.clone();
}
}
if converged {
break;
}
}
if self.config.integer_rounding && !best_params.is_empty() {
let rounded = try_integer_rounding(&best_params);
let mut ptree_rounded = ParameterizedEmlTree::from_topology(topology, 1.0);
ptree_rounded.params = rounded;
let rounded_mse = compute_mse_parameterized(&ptree_rounded, inputs, targets);
if let Some(rmse) = rounded_mse {
if rmse <= best_mse * 1.01 {
best_mse = rmse;
best_params = ptree_rounded.params;
}
}
}
if !best_mse.is_finite() || best_mse > 1e10 {
return None;
}
let complexity = topology.size();
let baked = bake_params_into_lowered(topology, &best_params);
let baked_simplified = baked.simplify();
let (final_op, final_mse) = if let Some(eps) = self.config.constant_extraction {
extract_named_constants(baked_simplified, best_mse, eps, inputs, targets)
} else {
(baked_simplified, best_mse)
};
let pretty = final_op.to_pretty();
Some(DiscoveredFormula {
eml_tree: topology.clone(),
mse: final_mse,
complexity,
score: final_mse + self.config.complexity_penalty * complexity as f64,
pretty,
params: best_params,
cv_mse: None,
})
}
}
/// Tests for topology enumeration, deduplication and end-to-end discovery.
#[cfg(test)]
mod tests {
    use super::*;

    /// Depth 0 with one variable yields exactly two topologies.
    #[test]
    fn test_enumerate_depth0() {
        let topos = enumerate_topologies(0, 1);
        assert_eq!(topos.len(), 2);
    }

    /// Depth 1 with one variable yields at least six topologies.
    #[test]
    fn test_enumerate_depth1() {
        let topos = enumerate_topologies(1, 1);
        assert!(topos.len() >= 6);
    }

    /// The engine finds a formula with MSE below 1 for `y = exp(x)`.
    #[test]
    fn test_symreg_exp() {
        let inputs: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64 * 0.25]).collect();
        let targets: Vec<f64> = inputs.iter().map(|x| x[0].exp()).collect();
        let config = SymRegConfig {
            max_depth: 1,
            learning_rate: 1e-2,
            tolerance: 1e-6,
            max_iter: 1000,
            complexity_penalty: 1e-4,
            num_restarts: 2,
            integer_rounding: true,
            ..SymRegConfig::default()
        };
        let engine = SymRegEngine::new(config);
        let formulas = engine
            .discover(&inputs, &targets, 1)
            .expect("symreg discover exp should succeed");
        assert!(!formulas.is_empty());
        assert!(formulas[0].mse < 1.0);
    }

    /// Values close to an integer are rounded; 1.51 is left untouched.
    #[test]
    fn test_integer_rounding() {
        let params = vec![0.98, 2.03, 1.51, -0.99];
        let rounded = try_integer_rounding(&params);
        assert!((rounded[0] - 1.0).abs() < 1e-15);
        assert!((rounded[1] - 2.0).abs() < 1e-15);
        assert!((rounded[2] - 1.51).abs() < 1e-15);
        assert!((rounded[3] - (-1.0)).abs() < 1e-15);
    }

    /// Smoke test for the discovery pipeline under whichever iteration mode
    /// (parallel or sequential) the build enables. It does not compare the
    /// two modes against each other.
    #[test]
    fn test_symreg_parallel_matches_sequential() {
        let inputs: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64 * 0.25]).collect();
        let targets: Vec<f64> = inputs.iter().map(|x| x[0].exp()).collect();
        let config = SymRegConfig {
            max_depth: 1,
            learning_rate: 1e-2,
            tolerance: 1e-6,
            max_iter: 1000,
            complexity_penalty: 1e-4,
            num_restarts: 2,
            integer_rounding: true,
            ..SymRegConfig::default()
        };
        let engine = SymRegEngine::new(config);
        let formulas = engine
            .discover(&inputs, &targets, 1)
            .expect("parallel symreg discover should succeed");
        assert!(!formulas.is_empty());
        assert!(formulas[0].mse < 1.0);
    }

    /// Empty inputs and targets are rejected with `EmptyData`.
    #[test]
    fn test_empty_data() {
        let engine = SymRegEngine::new(SymRegConfig::default());
        let result = engine.discover(&[], &[], 1);
        assert!(matches!(result, Err(EmlError::EmptyData)));
    }

    /// Differing input/target lengths are rejected with `DimensionMismatch`.
    #[test]
    fn test_dimension_mismatch() {
        let engine = SymRegEngine::new(SymRegConfig::default());
        let result = engine.discover(&[vec![1.0]], &[1.0, 2.0], 1);
        assert!(matches!(result, Err(EmlError::DimensionMismatch(1, 2))));
    }

    /// Deduplication never increases the number of topologies.
    #[test]
    fn test_dedupe_reduces_topology_count() {
        let topologies = enumerate_topologies(2, 1);
        let before = topologies.len();
        let after = dedupe_by_semantics(topologies).len();
        assert!(
            after <= before,
            "dedup must not grow the set: before={before}, after={after}"
        );
    }

    /// Same property as above at depth 4; ignored by default because it is slow.
    #[test]
    #[ignore = "slow: depth-4 enumerates 2M topologies, ~38s wall-clock"]
    fn test_dedupe_depth_four_stress() {
        let topologies = enumerate_topologies(4, 1);
        let before = topologies.len();
        let after = dedupe_by_semantics(topologies).len();
        assert!(after <= before);
    }

    /// After deduplication no two trees share the structural hash of their
    /// simplified lowered form.
    #[test]
    fn test_dedupe_preserves_uniqueness() {
        use std::collections::HashSet;
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;
        let topologies = enumerate_topologies(3, 1);
        let deduped = dedupe_by_semantics(topologies);
        let mut hashes: HashSet<u64> = HashSet::new();
        for tree in &deduped {
            let eml_simplified = crate::simplify::simplify(tree);
            let simplified = eml_simplified.lower().simplify();
            let mut h = DefaultHasher::new();
            simplified.structural_hash(&mut h);
            let inserted = hashes.insert(h.finish());
            assert!(inserted, "duplicate structural hash found in deduped set");
        }
        assert_eq!(hashes.len(), deduped.len());
    }

    /// Deduplication must not remove the topology needed to fit `exp(x)`.
    #[test]
    fn test_dedupe_preserves_discovery_exp() {
        let inputs: Vec<Vec<f64>> = (0..30).map(|i| vec![i as f64 * 0.2]).collect();
        let targets: Vec<f64> = inputs.iter().map(|x| x[0].exp()).collect();
        let config = SymRegConfig {
            max_depth: 2,
            learning_rate: 1e-2,
            tolerance: 1e-5,
            max_iter: 1000,
            complexity_penalty: 1e-4,
            num_restarts: 2,
            integer_rounding: false,
            ..SymRegConfig::default()
        };
        let engine = SymRegEngine::new(config);
        let formulas = engine
            .discover(&inputs, &targets, 1)
            .expect("discover should succeed");
        assert!(!formulas.is_empty(), "should discover at least one formula");
        let best = &formulas[0];
        assert!(
            best.mse < 0.1,
            "best formula MSE too high after dedup: {} (pretty={})",
            best.mse,
            best.pretty
        );
    }
}
/// Tests for the `SymRegConfig` presets (`quick`, `balanced`, `exhaustive`).
#[cfg(test)]
mod preset_tests {
use super::*;
/// `balanced()` must agree with `default()` on the compared fields.
#[test]
fn balanced_equals_default() {
let bal = SymRegConfig::balanced();
let def = SymRegConfig::default();
assert_eq!(bal.max_depth, def.max_depth);
assert_eq!(bal.max_iter, def.max_iter);
assert_eq!(bal.num_restarts, def.num_restarts);
assert_eq!(bal.integer_rounding, def.integer_rounding);
assert_eq!(bal.seed, def.seed);
assert_eq!(bal.constant_extraction, def.constant_extraction);
// Floats are compared bit-for-bit.
assert_eq!(bal.learning_rate.to_bits(), def.learning_rate.to_bits());
assert_eq!(bal.tolerance.to_bits(), def.tolerance.to_bits());
assert_eq!(
bal.complexity_penalty.to_bits(),
def.complexity_penalty.to_bits()
);
}
/// `quick()` never exceeds `balanced()` in iterations, restarts or depth.
#[test]
fn quick_is_faster_than_balanced() {
let q = SymRegConfig::quick();
let b = SymRegConfig::balanced();
assert!(q.max_iter <= b.max_iter);
assert!(q.num_restarts <= b.num_restarts);
assert!(q.max_depth <= b.max_depth);
}
/// `exhaustive()` is never below `balanced()` in iterations, restarts or depth.
#[test]
fn exhaustive_is_slower_than_balanced() {
let e = SymRegConfig::exhaustive();
let b = SymRegConfig::balanced();
assert!(e.max_iter >= b.max_iter);
assert!(e.num_restarts >= b.num_restarts);
assert!(e.max_depth >= b.max_depth);
}
/// Every preset can be used to construct an engine.
#[test]
fn engine_constructs_from_preset() {
let _ = SymRegEngine::new(SymRegConfig::quick());
let _ = SymRegEngine::new(SymRegConfig::balanced());
let _ = SymRegEngine::new(SymRegConfig::exhaustive());
}
}