1use std::time::Instant;
2use exp_root_log::approx_exp_root_log;
3use nalgebra as na;
4
/// Sine with unit period: `sin(2πx)`.
fn sin_func(x: f64) -> f64 {
    // TAU is exactly 2.0 * PI in f64 (doubling only shifts the exponent).
    let angle = std::f64::consts::TAU * x;
    angle.sin()
}
/// Exponential decay with rate 5: `e^(-5x)`.
fn exp_decay(x: f64) -> f64 {
    const RATE: f64 = 5.0;
    f64::exp(-RATE * x)
}
/// Step function: `1.0` strictly below `x = 0.5`, `0.0` otherwise.
///
/// The threshold itself and `NaN` both map to `0.0`, because the comparison
/// `x < 0.5` is false for them.
fn step(x: f64) -> f64 {
    let below_threshold = x < 0.5;
    f64::from(u8::from(below_threshold))
}
/// Narrow Gaussian bump centred at `x = 0.5`: `exp(-100 (x - 0.5)^2)`, peaking at `1.0`.
fn spike(x: f64) -> f64 {
    let offset = x - 0.5;
    let exponent = -100.0 * offset.powi(2);
    exponent.exp()
}
10
/// Mean squared error between two equally long sample slices.
///
/// Returns `NaN` when both slices are empty, since that is `0.0 / 0.0`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths. Without this check, `zip`
/// would silently drop the tail of the longer slice while the sum is still
/// divided by `a.len()`, producing a wrong mean.
fn mse(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "mse: slices must have the same length");
    let sum_sq: f64 = a.iter().zip(b).map(|(u, v)| (u - v).powi(2)).sum();
    sum_sq / a.len() as f64
}
15
16fn poly_ls(x: &[f64], y: &[f64], deg: usize) -> na::DVector<f64> {
17 let n = x.len();
18 let mut m = na::DMatrix::zeros(n, deg + 1);
19 for (i, &xi) in x.iter().enumerate() {
20 for j in 0..=deg {
21 m[(i, j)] = xi.powi(j as i32);
22 }
23 }
24 m.svd(true, true)
25 .solve(&na::DVector::from_column_slice(y), 1e-12)
26 .expect("poly solve")
27}
28
/// Evaluates `coeffs[0] + coeffs[1]*x + ... + coeffs[n-1]*x^(n-1)` at `x`.
///
/// Coefficients are in ascending order of degree, matching the Vandermonde
/// column layout used by `poly_ls`. Uses Horner's scheme, iterating from the
/// highest degree down. This needs one multiply-add per coefficient instead
/// of a separate `powi` per term, and it accumulates fewer rounding errors.
/// An empty slice evaluates to `0.0`.
fn poly_predict(coeffs: &[f64], x: f64) -> f64 {
    coeffs.iter().rev().fold(0.0, |acc, &c| acc * x + c)
}
36
37fn bench_one<F>(name: &str, f: F)
39where
40 F: Fn(f64) -> f64,
41{
42 let x: Vec<f64> = (0..2000).map(|i| i as f64 / 2000.0).collect();
43 let y: Vec<f64> = x.iter().map(|&xi| f(xi)).collect();
44
45 let t0 = Instant::now();
47 let exp_fn = approx_exp_root_log(
48 &x,
49 &y,
50 &[0.5, 2.0, 5.0, 10.0],
51 5,
52 &[1.0, 5.0, 10.0, 20.0],
53 );
54 let y_pred_exp: Vec<f64> = x.iter().map(|&xi| exp_fn(xi)).collect();
55 let mse_exp = mse(&y, &y_pred_exp);
56 let dt_exp = t0.elapsed();
57
58 let t1 = Instant::now();
60 let coeffs = poly_ls(&x, &y, 10);
61 let y_pred_poly: Vec<f64> =
62 x.iter().map(|&xi| poly_predict(coeffs.as_slice(), xi)).collect();
63 let mse_poly = mse(&y, &y_pred_poly);
64 let dt_poly = t1.elapsed();
65
66 println!(
67 "{:<8} | ExpRoot MSE = {:8.2e} {:>6?} || Poly MSE = {:8.2e} {:>6?}",
68 name, mse_exp, dt_exp, mse_poly, dt_poly
69 );
70}
71
72fn main() {
73 println!("Function | ExpRoot+Log | Polynomial(deg=10)");
74 println!("--------------------------------------------------------------------------");
75 bench_one("Sin", sin_func);
76 bench_one("ExpDecay", exp_decay);
77 bench_one("Step", step);
78 bench_one("Spike", spike);
79}