1use nalgebra::{DMatrix, DVector};
2
3fn build_basis(
8 x: &[f64],
9 b_list: &[f64],
10 poly_deg: usize,
11 log_lambdas: &[f64],
12) -> DMatrix<f64> {
13 let n = x.len();
14 let n_cols = b_list.len() + (poly_deg + 1) + log_lambdas.len();
15 let mut mat = DMatrix::zeros(n, n_cols);
16
17 for (i, &xi) in x.iter().enumerate() {
18 let mut col = 0;
19
20 for &b in b_list {
22 mat[(i, col)] = (-((b * xi).sqrt())).exp();
23 col += 1;
24 }
25
26 for j in 0..=poly_deg {
28 mat[(i, col)] = xi.powi(j as i32);
29 col += 1;
30 }
31
32 for &lambda in log_lambdas {
34 mat[(i, col)] = (1.0 + lambda * xi).ln();
35 col += 1;
36 }
37 }
38
39 mat
40}
41
42pub fn approx_exp_root_log(
45 x: &[f64],
46 y: &[f64],
47 b_list: &[f64],
48 poly_deg: usize,
49 log_lambdas: &[f64],
50) -> impl Fn(f64) -> f64 {
51 let design = build_basis(x, b_list, poly_deg, log_lambdas);
52 let y_vec = DVector::from_column_slice(y);
53
54 let coeffs = design
56 .svd(true, true)
57 .solve(&y_vec, 1e-10)
58 .expect("SVD solve failed");
59
60
61 let coeffs = coeffs.data.as_vec().clone();
62 let b_list = b_list.to_vec();
63 let log_lambdas = log_lambdas.to_vec();
64
65 move |x: f64| {
66 let mut result = 0.0;
67 let mut idx = 0;
68
69 for &b in &b_list {
70 result += coeffs[idx] * (-((b * x).sqrt())).exp();
71 idx += 1;
72 }
73
74 for j in 0..=poly_deg {
75 result += coeffs[idx] * x.powi(j as i32);
76 idx += 1;
77 }
78
79 for &lambda in &log_lambdas {
80 result += coeffs[idx] * (1.0 + lambda * x).ln();
81 idx += 1;
82 }
83
84 result
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Mean squared error between two equally long series.
    fn mse(y_true: &[f64], y_pred: &[f64]) -> f64 {
        assert_eq!(y_true.len(), y_pred.len(), "series lengths differ");
        y_true
            .iter()
            .zip(y_pred)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            / y_true.len() as f64
    }

    /// `n` evenly spaced samples on `[0, 1)`.
    fn unit_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / n as f64).collect()
    }

    #[test]
    fn test_sin_approximation() {
        let x = unit_grid(100);
        let y: Vec<f64> = x
            .iter()
            .map(|&x| (2.0 * std::f64::consts::PI * x).sin())
            .collect();

        let approx_fn = approx_exp_root_log(&x, &y, &[0.5, 2.0, 5.0], 3, &[1.0, 5.0, 10.0]);

        let y_pred: Vec<f64> = x.iter().map(|&xi| approx_fn(xi)).collect();
        let error = mse(&y, &y_pred);

        println!("MSE for sin approximation: {:.2e}", error);
        assert!(error < 1e-3, "MSE too high: {}", error);
    }

    /// A target inside the span of the basis must be reproduced to rounding error, both at
    /// the training samples and between them.
    #[test]
    fn test_exact_polynomial_recovery() {
        let x = unit_grid(50);
        let target = |t: f64| 1.0 + 2.0 * t - 3.0 * t * t;
        let y: Vec<f64> = x.iter().map(|&t| target(t)).collect();

        let approx_fn = approx_exp_root_log(&x, &y, &[], 2, &[]);

        let y_pred: Vec<f64> = x.iter().map(|&t| approx_fn(t)).collect();
        let error = mse(&y, &y_pred);
        assert!(error < 1e-18, "MSE too high: {}", error);
        assert!((approx_fn(0.505) - target(0.505)).abs() < 1e-9);
    }

    #[test]
    #[should_panic]
    fn test_mismatched_lengths_panic() {
        let _ = approx_exp_root_log(&[0.0, 0.5, 1.0], &[1.0, 2.0], &[], 1, &[]);
    }
}