use std::f64::consts::PI;
/// Builds a polynomial (Vandermonde-style) design matrix.
///
/// Row `i` is `[x[i]^0, x[i]^1, ..., x[i]^degree]`, computed with
/// [`f64::powi`]. The result has `x.len()` rows and `degree + 1` columns.
///
/// # Returns
///
/// `(mat, ncols)`, where `mat` is the matrix flattened in row-major order
/// (length `x.len() * ncols`) and `ncols == degree + 1`.
pub fn polynomial(x: &[f64], degree: usize) -> (Vec<f64>, usize) {
    let width = degree + 1;
    // One row per sample; within a row the exponent runs 0..=degree.
    let mat: Vec<f64> = x
        .iter()
        .flat_map(|&xi| (0..width).map(move |power| xi.powi(power as i32)))
        .collect();
    (mat, width)
}
/// Builds a Gaussian radial-basis-function design matrix.
///
/// Entry `(i, j)` is `exp(-0.5 * ((x[i] - centers[j]) / width)^2)`. The
/// result has `x.len()` rows and `centers.len()` columns.
///
/// A `width` of `0.0` yields NaN wherever `x[i] == centers[j]` (0/0) and
/// `0.0` everywhere else (exp of -inf). Pass a non-zero width to avoid this.
///
/// # Returns
///
/// `(mat, ncols)`, where `mat` is stored row-major and
/// `ncols == centers.len()`.
pub fn rbf(x: &[f64], centers: &[f64], width: f64) -> (Vec<f64>, usize) {
    let ncols = centers.len();
    let mut mat = Vec::with_capacity(x.len() * ncols);
    for &xi in x {
        // Append one full row: the kernel value against every center.
        mat.extend(centers.iter().map(|&c| {
            let z = (xi - c) / width;
            (-0.5 * z * z).exp()
        }));
    }
    (mat, ncols)
}
/// Builds a trigonometric (Fourier) design matrix.
///
/// Each row is `[1, sin(1·π·x), cos(1·π·x), ..., sin(n·π·x), cos(n·π·x)]`
/// with `n = n_freq`, so the result has `x.len()` rows and
/// `2 * n_freq + 1` columns. With `n_freq == 0` every row is just `[1]`.
///
/// # Returns
///
/// `(mat, ncols)`, where `mat` is stored row-major and
/// `ncols == 2 * n_freq + 1`.
pub fn trig(x: &[f64], n_freq: usize) -> (Vec<f64>, usize) {
    let ncols = 2 * n_freq + 1;
    let mut mat = Vec::with_capacity(x.len() * ncols);
    for &xi in x {
        // Leading constant (intercept) column.
        mat.push(1.0);
        for k in 1..=n_freq {
            let angle = (k as f64) * PI * xi;
            let (s, c) = angle.sin_cos();
            mat.push(s);
            mat.push(c);
        }
    }
    (mat, ncols)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Degree-3 polynomial features for three samples, checked row by row.
    #[test]
    fn test_polynomial_parity() {
        let (mat, ncols) = polynomial(&[0.5, 1.0, 1.5], 3);
        assert_eq!(ncols, 4);

        // Row 0 (x = 0.5): 0.5^0 through 0.5^3.
        let expected = [1.0_f64, 0.5, 0.25, 0.125];
        for (a, b) in mat[..4].iter().zip(&expected) {
            assert!((a - b).abs() < 1e-10, "poly row0 mismatch: {a} vs {b}");
        }

        // Row 1 (x = 1.0): every power of one is one.
        mat[4..8]
            .iter()
            .for_each(|v| assert!((v - 1.0).abs() < 1e-10));

        // Row 2 (x = 1.5): 1.5^0 through 1.5^3.
        let expected2 = [1.0_f64, 1.5, 2.25, 3.375];
        for (a, b) in mat[8..12].iter().zip(&expected2) {
            assert!((a - b).abs() < 1e-10, "poly row2 mismatch: {a} vs {b}");
        }
    }

    /// One sample sitting on the first center: kernel values exp(0) and exp(-0.5).
    #[test]
    fn test_rbf_parity() {
        let (mat, ncols) = rbf(&[0.0], &[0.0, 1.0], 1.0);
        assert_eq!(ncols, 2);
        assert!((mat[0] - 1.0).abs() < 1e-10);
        assert!((mat[1] - (-0.5_f64).exp()).abs() < 1e-10);
    }

    /// Two frequencies at x = 1: layout is [1, sin π, cos π, sin 2π, cos 2π].
    #[test]
    fn test_trig_parity() {
        let (mat, ncols) = trig(&[1.0], 2);
        assert_eq!(ncols, 5);
        let expected: [f64; 5] = [
            1.0,
            PI.sin(),
            PI.cos(),
            (2.0 * PI).sin(),
            (2.0 * PI).cos(),
        ];
        for (i, (a, b)) in mat.iter().zip(&expected).enumerate() {
            assert!((a - b).abs() < 1e-10, "trig col {i}: {a} vs {b}");
        }
    }

    /// Column count is 2 * n_freq + 1 regardless of the number of samples.
    #[test]
    fn test_trig_ncols() {
        let ncols = trig(&[0.0; 10], 5).1;
        assert_eq!(ncols, 11);
    }
}