use oxieml::compile::compile_to_rust_batch;
use oxieml::{EvalCtx, SymRegConfig, SymRegEngine};
/// Gravitational acceleration used to build the reference targets, in m/s².
const G: f64 = 9.81;

/// Reference pendulum period `T = 2π · √(L / G)` (small-angle formula).
///
/// `length` is the pendulum length in metres; the result is the period in seconds.
/// A negative `length` produces NaN because of the square root.
fn truth(length: f64) -> f64 {
    let length_over_g = length / G;
    // TAU is exactly 2π in f64, so this matches `2.0 * PI * sqrt(..)` bit for bit.
    std::f64::consts::TAU * length_over_g.sqrt()
}
/// Builds a deterministic synthetic dataset of `n` pendulum samples.
///
/// Each sample is a single-feature input `[length]`, with `length` drawn from
/// `[0.1, 2.0]`, paired with the target `truth(length)`. Randomness comes from a
/// 64-bit linear congruential generator seeded with `seed` (mixed with a fixed
/// constant). Its upper 32 bits are scaled to `[0, 1]`, so the same `seed` always
/// yields the same dataset.
///
/// Returns `(inputs, targets)`, both of length `n`.
fn generate_dataset(n: usize, seed: u64) -> (Vec<Vec<f64>>, Vec<f64>) {
    const SEED_MIX: u64 = 0x9E37_79B9_7F4A_7C15;
    const LCG_MUL: u64 = 6_364_136_223_846_793_005;
    const LCG_INC: u64 = 1_442_695_040_888_963_407;

    let mut rng_state = seed ^ SEED_MIX;
    let mut uniform = move || -> f64 {
        rng_state = rng_state.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
        // The high half of an LCG state is better distributed than the low half.
        let high_bits = (rng_state >> 32) as u32;
        f64::from(high_bits) / f64::from(u32::MAX)
    };

    (0..n)
        .map(|_| {
            let length = 0.1 + 1.9 * uniform();
            (vec![length], truth(length))
        })
        .unzip()
}
/// Entry point of the pendulum-period demo.
///
/// Steps, in order:
/// 1. generate 150 synthetic `(length, period)` training samples and print their target range;
/// 2. run symbolic regression (`max_depth = 4`) and print the best formula, its lowered and
///    simplified IR, and the Rust source produced by `compile_to_rust_batch`;
/// 3. evaluate the formula on 50 held-out samples and print MSE, max absolute error and the
///    first five rows;
/// 4. probe a single point (`L = 1 m`) through `eval_real_lowered`;
/// 5. cross-check the `EmlTree` batch evaluator against the simplified lowered evaluator.
///
/// # Errors
///
/// Propagates any error returned by `SymRegEngine::discover`. If discovery succeeds but
/// yields no formula, a message is printed to stderr and the demo ends with `Ok(())`.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== OxiEML pendulum-period demo ===\n");

    let (inputs, targets) = generate_dataset(150, 0xBADC_0FFE);
    let min_t = targets.iter().copied().fold(f64::INFINITY, f64::min);
    let max_t = targets.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    println!(
        "Generated {} training samples: target range [{:.3}, {:.3}] s",
        inputs.len(),
        min_t,
        max_t
    );

    let config = SymRegConfig {
        max_depth: 4,
        ..Default::default()
    };
    let engine = SymRegEngine::new(config);
    println!("Running symbolic regression (this may take a few seconds)...");
    let formulas = engine.discover(&inputs, &targets, 1)?;
    let best = match formulas.first() {
        Some(f) => f,
        None => {
            // Best-effort demo: report and stop gracefully rather than erroring out.
            eprintln!("no formulas discovered");
            return Ok(());
        }
    };
    println!("\nBest discovered formula (rank 1 of {}):", formulas.len());
    println!(" score: {:.6}", best.score);
    println!(" MSE: {:.6}", best.mse);
    println!(" complexity: {}", best.complexity);
    println!(" pretty: {}", best.pretty);
    println!(" params: {:?}", best.params);

    let lowered = best.eml_tree.lower();
    let simplified = lowered.simplify();
    println!("\nLowered IR: {lowered}");
    println!("Lowered + simplified: {simplified}");

    let rust_src = compile_to_rust_batch(&best.eml_tree, "pendulum_period");
    println!("\nCompiled Rust (single-point + batch):\n----------------------------");
    print!("{rust_src}");
    println!("----------------------------");

    let (test_inputs, test_targets) = generate_dataset(50, 0xC0DE_F00D);
    // Remember whether the fallback was taken: in that case `preds` already comes from the
    // lowered evaluator, so the cross-check below would only compare it with itself.
    let (preds, used_fallback) = match best.eml_tree.eval_batch(&test_inputs) {
        Ok(preds) => {
            println!("\nUsing EmlTree::eval_batch for held-out evaluation.");
            (preds, false)
        }
        Err(e) => {
            println!(
                "\nEmlTree::eval_batch produced {e:?}; falling back to LoweredOp::eval_batch."
            );
            (simplified.eval_batch(&test_inputs), true)
        }
    };

    let (valid, mse, max_abs_err) = finite_error_stats(&preds, &test_targets);
    let total = preds.len();
    println!("\nHeld-out test set ({total} points, {valid} finite):");
    println!(" MSE: {mse:.6}");
    println!(" max |error|: {max_abs_err:.6}");
    println!("\n idx L truth predicted |error|");
    println!(" ---- ------ --------- ----------- --------");
    for (i, (input, (pred, truth_val))) in test_inputs
        .iter()
        .zip(preds.iter().zip(test_targets.iter()))
        .take(5)
        .enumerate()
    {
        let err = (pred - truth_val).abs();
        println!(
            " {:>3} {:>5.3} {:>8.4} {:>10.4} {:>8.5}",
            i, input[0], truth_val, pred, err
        );
    }

    let probe_length = 1.0_f64;
    let probe_ctx = EvalCtx::new(&[probe_length]);
    match best.eml_tree.eval_real_lowered(&probe_ctx) {
        Ok(value) => {
            let reference = truth(probe_length);
            println!(
                "\nSingle-point probe (eval_real_lowered): L = {:.3} m -> T = {:.6} s \
                 (truth {:.6} s, |delta| = {:.2e})",
                probe_length,
                value,
                reference,
                (value - reference).abs()
            );
        }
        Err(e) => {
            println!("\nSingle-point probe via eval_real_lowered failed: {e:?}");
        }
    }

    // Cross-check the tree evaluator against the simplified lowered evaluator. This is only
    // meaningful when `preds` actually came from `EmlTree::eval_batch`.
    if used_fallback {
        println!(
            "\nSkipping EmlTree vs LoweredOp cross-check: held-out predictions already came \
             from LoweredOp::eval_batch."
        );
    } else {
        let lowered_preds = simplified.eval_batch(&test_inputs);
        match max_finite_delta(&preds, &lowered_preds) {
            Some((max_delta, compared)) => println!(
                "\nEmlTree vs LoweredOp batch evaluator agree to within {max_delta:.2e} \
                 (max abs delta over {compared} finite points)."
            ),
            None => println!(
                "\nEmlTree vs LoweredOp cross-check skipped: no point was finite under both \
                 evaluators."
            ),
        }
    }

    println!("\nDone.");
    Ok(())
}

/// Error statistics over the prediction/target pairs whose prediction is finite.
///
/// Returns `(valid, mse, max_abs_err)`, where `valid` is the number of pairs used. When no
/// pair qualifies, both `mse` and `max_abs_err` are NaN, so an all-non-finite run is not
/// reported as a perfect fit. Pairs beyond the shorter of the two slices are ignored.
fn finite_error_stats(preds: &[f64], targets: &[f64]) -> (usize, f64, f64) {
    let (valid, sum_sq, max_abs) = preds
        .iter()
        .zip(targets)
        .filter(|(p, _)| p.is_finite())
        .fold((0usize, 0.0_f64, 0.0_f64), |(n, sum, max), (p, y)| {
            let d = p - y;
            (n + 1, sum + d * d, max.max(d.abs()))
        });
    if valid == 0 {
        (0, f64::NAN, f64::NAN)
    } else {
        (valid, sum_sq / valid as f64, max_abs)
    }
}

/// Largest absolute difference between two prediction vectors, over positions where both
/// values are finite.
///
/// Returns `Some((max_delta, compared))`, or `None` when no position was comparable.
/// Positions beyond the shorter of the two slices are ignored.
fn max_finite_delta(a: &[f64], b: &[f64]) -> Option<(f64, usize)> {
    let (max_delta, compared) = a
        .iter()
        .zip(b)
        .filter(|(x, y)| x.is_finite() && y.is_finite())
        .fold((0.0_f64, 0usize), |(max, n), (x, y)| {
            (max.max((x - y).abs()), n + 1)
        });
    if compared > 0 {
        Some((max_delta, compared))
    } else {
        None
    }
}