exp_root_log/
lib.rs

1use nalgebra::{DMatrix, DVector};
2
3/// Constructs the design matrix using basis functions:
4/// - exp(-sqrt(b * x))
5/// - x^j (polynomials)
6/// - log(1 + λ * x)
7fn build_basis(
8    x: &[f64],
9    b_list: &[f64],
10    poly_deg: usize,
11    log_lambdas: &[f64],
12) -> DMatrix<f64> {
13    let n = x.len();
14    let n_cols = b_list.len() + (poly_deg + 1) + log_lambdas.len();
15    let mut mat = DMatrix::zeros(n, n_cols);
16
17    for (i, &xi) in x.iter().enumerate() {
18        let mut col = 0;
19
20        // Exponential-root terms
21        for &b in b_list {
22            mat[(i, col)] = (-((b * xi).sqrt())).exp();
23            col += 1;
24        }
25
26        // Polynomial terms
27        for j in 0..=poly_deg {
28            mat[(i, col)] = xi.powi(j as i32);
29            col += 1;
30        }
31
32        // Logarithmic terms
33        for &lambda in log_lambdas {
34            mat[(i, col)] = (1.0 + lambda * xi).ln();
35            col += 1;
36        }
37    }
38
39    mat
40}
41
42/// Fits an ExpRoot+Log approximation to the (x, y) data,
43/// and returns a closure representing the approximate function.
44pub fn approx_exp_root_log(
45    x: &[f64],
46    y: &[f64],
47    b_list: &[f64],
48    poly_deg: usize,
49    log_lambdas: &[f64],
50) -> impl Fn(f64) -> f64 {
51    let design = build_basis(x, b_list, poly_deg, log_lambdas);
52    let y_vec = DVector::from_column_slice(y);
53
54    // Solve least squares using QR decomposition
55    let coeffs = design
56    .svd(true, true)
57    .solve(&y_vec, 1e-10)
58    .expect("SVD solve failed");
59
60
61    let coeffs = coeffs.data.as_vec().clone();
62    let b_list = b_list.to_vec();
63    let log_lambdas = log_lambdas.to_vec();
64
65    move |x: f64| {
66        let mut result = 0.0;
67        let mut idx = 0;
68
69        for &b in &b_list {
70            result += coeffs[idx] * (-((b * x).sqrt())).exp();
71            idx += 1;
72        }
73
74        for j in 0..=poly_deg {
75            result += coeffs[idx] * x.powi(j as i32);
76            idx += 1;
77        }
78
79        for &lambda in &log_lambdas {
80            result += coeffs[idx] * (1.0 + lambda * x).ln();
81            idx += 1;
82        }
83
84        result
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Mean squared error between two equally long sample vectors.
    fn mse(y_true: &[f64], y_pred: &[f64]) -> f64 {
        // `zip` would silently truncate to the shorter slice and hide a length bug.
        assert_eq!(y_true.len(), y_pred.len(), "mse: length mismatch");
        y_true
            .iter()
            .zip(y_pred)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            / y_true.len() as f64
    }

    /// Evenly spaced grid `0, 1/n, …, (n-1)/n` on `[0, 1)`.
    fn unit_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / n as f64).collect()
    }

    #[test]
    fn test_sin_approximation() {
        let x = unit_grid(100);
        let y: Vec<f64> = x
            .iter()
            .map(|&x| (2.0 * std::f64::consts::PI * x).sin())
            .collect();

        let approx_fn = approx_exp_root_log(
            &x,
            &y,
            &[0.5, 2.0, 5.0], // b_i
            3,                // polynomial degree
            &[1.0, 5.0, 10.0], // λ_k
        );

        let y_pred: Vec<f64> = x.iter().map(|&xi| approx_fn(xi)).collect();
        let error = mse(&y, &y_pred);

        println!("MSE for sin approximation: {:.2e}", error);
        assert!(error < 1e-3, "MSE too high: {}", error);
    }

    /// A target lying exactly in the span of the basis must be reproduced. This also
    /// pins that the fitted coefficients and the returned closure agree on column order.
    #[test]
    fn test_exact_recovery_of_basis_combination() {
        let b_list = [2.0];
        let log_lambdas = [5.0];
        let target = |t: f64| {
            0.7 * (-((2.0 * t).sqrt())).exp() + 0.3 - t + 0.25 * (1.0 + 5.0 * t).ln()
        };

        let x = unit_grid(50);
        let y: Vec<f64> = x.iter().map(|&t| target(t)).collect();
        let approx_fn = approx_exp_root_log(&x, &y, &b_list, 1, &log_lambdas);

        for &t in &[0.0, 0.013, 0.5, 0.975] {
            let err = (approx_fn(t) - target(t)).abs();
            assert!(err < 1e-8, "error {:e} at x = {}", err, t);
        }
    }

    /// With no exp-root or log terms the model reduces to ordinary polynomial fitting.
    #[test]
    fn test_polynomial_only_basis() {
        let x = unit_grid(20);
        let y: Vec<f64> = x.iter().map(|&t| 1.5 - 0.5 * t + 2.0 * t * t).collect();

        let approx_fn = approx_exp_root_log(&x, &y, &[], 2, &[]);
        let y_pred: Vec<f64> = x.iter().map(|&t| approx_fn(t)).collect();

        let error = mse(&y, &y_pred);
        assert!(error < 1e-18, "MSE too high: {}", error);
    }
}