use super::*;
use crate::matrix::FdMatrix;
use crate::test_helpers::uniform_grid;
use nalgebra::DVector;
use std::f64::consts::PI;
/// Samples `sin(2π · freq · t)` at every point of `t`, returning one value per point.
fn sine_wave(t: &[f64], freq: f64) -> Vec<f64> {
    // Angular frequency; multiplying it by each sample keeps the same
    // evaluation order as `2.0 * PI * freq * t`.
    let omega = 2.0 * PI * freq;
    t.iter().map(|&x| (omega * x).sin()).collect()
}
/// `bspline_basis` yields `len(t) * (nknots + order)` values.
#[test]
fn test_bspline_basis_dimensions() {
    let grid = uniform_grid(50);
    let (nknots, order) = (10, 4);
    let values = bspline_basis(&grid, nknots, order);
    assert_eq!(values.len(), grid.len() * (nknots + order));
}
/// At every grid point the B-spline basis values must sum to one.
#[test]
fn test_bspline_basis_partition_of_unity() {
    let grid = uniform_grid(50);
    let (nknots, order) = (8, 4);
    let values = bspline_basis(&grid, nknots, order);
    let npts = grid.len();
    let nbasis = nknots + order;
    for row in 0..npts {
        // Entry for point `row` and basis function `col` sits at `row + col * npts`.
        let sum: f64 = (0..nbasis).map(|col| values[row + col * npts]).sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "B-spline partition of unity failed at point {}: sum = {}",
            row,
            sum
        );
    }
}
/// B-spline basis values are non-negative (up to rounding).
#[test]
fn test_bspline_basis_non_negative() {
    let grid = uniform_grid(50);
    let values = bspline_basis(&grid, 8, 4);
    for v in values.iter().copied() {
        assert!(v >= -1e-10, "B-spline values should be non-negative");
    }
}
/// Evaluating at the interval endpoints must not produce NaN or infinity.
#[test]
fn test_bspline_basis_boundary() {
    // Both ends of the grid plus the midpoint.
    let pts = vec![0.0, 0.5, 1.0];
    let values = bspline_basis(&pts, 5, 4);
    for v in values.iter() {
        assert!(v.is_finite(), "B-spline should produce finite values");
    }
}
/// `fourier_basis` yields `len(t) * nbasis` values.
#[test]
fn test_fourier_basis_dimensions() {
    let grid = uniform_grid(50);
    let n_funcs = 7;
    let values = fourier_basis(&grid, n_funcs);
    assert_eq!(values.len(), grid.len() * n_funcs);
}
/// The first Fourier basis function (the first `len(t)` entries) is constant.
#[test]
fn test_fourier_basis_constant_first_column() {
    let grid = uniform_grid(50);
    let values = fourier_basis(&grid, 7);
    let reference = values[0];
    for v in (0..grid.len()).map(|i| values[i]) {
        assert!(
            (v - reference).abs() < 1e-10,
            "First Fourier column should be constant"
        );
    }
}
/// Every Fourier basis value lies in `[-1, 1]` (with a tiny tolerance).
#[test]
fn test_fourier_basis_sin_cos_range() {
    let grid = uniform_grid(100);
    let values = fourier_basis(&grid, 11);
    let bounds = -1.0 - 1e-10..=1.0 + 1e-10;
    for v in values.iter() {
        assert!(bounds.contains(v));
    }
}
/// A custom period keeps the output size and the constant first basis function.
#[test]
fn test_fourier_basis_with_period() {
    let grid = uniform_grid(100);
    let npts = grid.len();
    let (nbasis, period) = (5, 0.5);
    let values = fourier_basis_with_period(&grid, nbasis, period);
    assert_eq!(values.len(), npts * nbasis);
    let reference = values[0];
    for v in (0..npts).map(|i| values[i]) {
        assert!((v - reference).abs() < 1e-10);
    }
}
/// Changing the period must change the second basis function's values.
#[test]
fn test_fourier_basis_period_affects_frequency() {
    let grid = uniform_grid(100);
    let npts = grid.len();
    let unit_period = fourier_basis_with_period(&grid, 5, 1.0);
    let half_period = fourier_basis_with_period(&grid, 5, 0.5);
    // The second basis function occupies entries `npts..2 * npts`.
    let any_different =
        (npts..2 * npts).any(|k| (unit_period[k] - half_period[k]).abs() > 1e-10);
    assert!(
        any_different,
        "Different periods should produce different bases"
    );
}
/// A zeroth-order difference matrix is the identity.
#[test]
fn test_difference_matrix_order_zero() {
    let d = difference_matrix(5, 0);
    assert_eq!(d.nrows(), 5);
    assert_eq!(d.ncols(), 5);
    for row in 0..5 {
        for col in 0..5 {
            let want = if row == col { 1.0 } else { 0.0 };
            let got = d[(row, col)];
            assert!((got - want).abs() < 1e-10);
        }
    }
}
/// First differences of the ramp 1..=5 are all 1, and the matrix is 4x5.
#[test]
fn test_difference_matrix_first_order() {
    let d = difference_matrix(5, 1);
    assert_eq!(d.nrows(), 4);
    assert_eq!(d.ncols(), 5);
    let ramp = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    let diffs = &d * ramp;
    for k in 0..4 {
        assert!((diffs[k] - 1.0).abs() < 1e-10);
    }
}
/// Second differences annihilate a linear sequence; the matrix is 3x5.
#[test]
fn test_difference_matrix_second_order() {
    let d = difference_matrix(5, 2);
    assert_eq!(d.nrows(), 3);
    assert_eq!(d.ncols(), 5);
    let linear = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    let second_diff = &d * linear;
    for k in 0..3 {
        assert!(
            second_diff[k].abs() < 1e-10,
            "Second diff of linear should be zero"
        );
    }
}
/// The squares 1, 4, 9, 16, 25 have a constant second difference of 2.
#[test]
fn test_difference_matrix_quadratic() {
    let d = difference_matrix(5, 2);
    let squares: Vec<f64> = (1..=5).map(|k| (k * k) as f64).collect();
    let second_diff = &d * DVector::from_vec(squares);
    for k in 0..3 {
        assert!(
            (second_diff[k] - 2.0).abs() < 1e-10,
            "Second diff of squares should be 2"
        );
    }
}
/// Builds an `n x m` [`FdMatrix`] from row-major test data.
///
/// `flat_row_major[i * m + j]` is the value for row `i`, column `j`. It is
/// transposed into the column-major layout that
/// `FdMatrix::from_column_major` takes (entry `(i, j)` at `i + j * n`).
///
/// # Panics
/// Panics if `flat_row_major.len() != n * m`, so that a mis-sized fixture
/// fails loudly instead of being silently truncated. Also panics if
/// `FdMatrix::from_column_major` rejects the dimensions.
fn make_matrix(flat_row_major: &[f64], n: usize, m: usize) -> FdMatrix {
    assert_eq!(
        flat_row_major.len(),
        n * m,
        "make_matrix: expected {} row-major values for a {}x{} matrix",
        n * m,
        n,
        m
    );
    // Walk columns outermost so the output is already column-major.
    let col_major: Vec<f64> = (0..m)
        .flat_map(|j| (0..n).map(move |i| flat_row_major[i * m + j]))
        .collect();
    FdMatrix::from_column_major(col_major, n, m)
        .expect("make_matrix: length was checked against n * m")
}
/// Projecting five shifted lines onto basis type 0 (B-spline, per this
/// test's name) gives an `n x n_basis` coefficient matrix.
#[test]
fn test_fdata_to_basis_1d_bspline() {
    let t = uniform_grid(50);
    let n = 5;
    let m = t.len();
    // Row i is the line t + 0.1 * i.
    let flat: Vec<f64> = (0..n)
        .flat_map(|i| t.iter().map(move |&ti| ti + i as f64 * 0.1))
        .collect();
    let data = make_matrix(&flat, n, m);
    let res = fdata_to_basis_1d(&data, &t, 10, 0)
        .expect("B-spline projection of valid data should succeed");
    assert!(res.n_basis > 0);
    assert_eq!(res.coefficients.nrows(), n);
    assert_eq!(res.coefficients.ncols(), res.n_basis);
}
/// Projecting five identical sines onto basis type 1 (Fourier, per this
/// test's name) keeps the requested 7 basis functions and one coefficient row per curve.
#[test]
fn test_fdata_to_basis_1d_fourier() {
    let t = uniform_grid(50);
    let n = 5;
    let m = t.len();
    let flat: Vec<f64> = (0..n).flat_map(|_| sine_wave(&t, 2.0)).collect();
    let data = make_matrix(&flat, n, m);
    let res = fdata_to_basis_1d(&data, &t, 7, 1)
        .expect("Fourier projection of valid data should succeed");
    assert_eq!(res.n_basis, 7);
    assert_eq!(res.coefficients.nrows(), n);
}
/// `fdata_to_basis_1d` returns `None` for zero curves and for a single basis function.
#[test]
fn test_fdata_to_basis_1d_invalid_input() {
    let t = uniform_grid(50);
    let no_rows = FdMatrix::zeros(0, 50);
    assert!(fdata_to_basis_1d(&no_rows, &t, 10, 0).is_none());
    let one_row = FdMatrix::zeros(1, 50);
    assert!(fdata_to_basis_1d(&one_row, &t, 1, 0).is_none());
}
/// Projecting a sine onto 5 Fourier functions and reconstructing it keeps
/// the pointwise error below 0.5.
#[test]
fn test_basis_roundtrip() {
    let t = uniform_grid(100);
    let m = t.len();
    let raw = sine_wave(&t, 1.0);
    let data = FdMatrix::from_column_major(raw.clone(), 1, m).unwrap();
    let proj = fdata_to_basis_1d(&data, &t, 5, 1).unwrap();
    let reconstructed = basis_to_fdata_1d(&proj.coefficients, &t, proj.n_basis, 1);
    // Running maximum of the absolute error (NaN errors never replace it).
    let max_error = (0..m)
        .map(|j| (raw[j] - reconstructed[(0, j)]).abs())
        .fold(0.0, |acc, e| if e > acc { e } else { acc });
    assert!(max_error < 0.5, "Roundtrip error too large: {}", max_error);
}
/// Reconstructing from an empty coefficient matrix on an empty grid yields nothing.
#[test]
fn test_basis_to_fdata_empty_input() {
    let no_coefs = FdMatrix::zeros(0, 0);
    let out = basis_to_fdata_1d(&no_coefs, &[], 5, 0);
    assert!(out.is_empty());
}
/// A P-spline fit of three perturbed sines returns a fitted matrix with the
/// input's shape and sane RSS / EDF / GCV values.
#[test]
fn test_pspline_fit_1d_basic() {
    let t = uniform_grid(50);
    let n = 3;
    let m = t.len();
    // Row i, column j: sin(2πt_j) plus the fractional part of 0.1·i·j.
    // The parentheses spell out how `as`, `*` and `%` already bind, so the
    // perturbation is not misread as `0.1 * ((i * j) % 1)` (which would be 0).
    let flat: Vec<f64> = (0..n)
        .flat_map(|i| {
            t.iter()
                .enumerate()
                .map(move |(j, &ti)| (2.0 * PI * ti).sin() + (0.1 * (i * j) as f64) % 1.0)
        })
        .collect();
    let data = make_matrix(&flat, n, m);
    let res = pspline_fit_1d(&data, &t, 15, 1.0, 2)
        .expect("P-spline fit of valid data should succeed");
    assert!(res.n_basis > 0);
    assert_eq!(res.fitted.nrows(), n);
    assert_eq!(res.fitted.ncols(), m);
    assert!(res.rss >= 0.0);
    assert!(res.edf > 0.0);
    assert!(res.gcv.is_finite());
}
/// A larger smoothing penalty lowers the effective degrees of freedom.
#[test]
fn test_pspline_fit_1d_smoothness() {
    let t = uniform_grid(50);
    let m = t.len();
    let mut raw = Vec::with_capacity(m);
    for (i, &ti) in t.iter().enumerate() {
        // Deterministic pseudo-noise in [0, 0.3).
        let jitter = 0.3 * ((i * 17) % 100) as f64 / 100.0;
        raw.push((2.0 * PI * ti).sin() + jitter);
    }
    let data = FdMatrix::from_column_major(raw, 1, m).unwrap();
    let rough = pspline_fit_1d(&data, &t, 15, 0.01, 2).unwrap();
    let smooth = pspline_fit_1d(&data, &t, 15, 100.0, 2).unwrap();
    assert!(smooth.edf < rough.edf);
}
/// `pspline_fit_1d` returns `None` when there are no curves.
#[test]
fn test_pspline_fit_1d_invalid_input() {
    let grid = uniform_grid(50);
    let no_curves = FdMatrix::zeros(0, 50);
    assert!(pspline_fit_1d(&no_curves, &grid, 15, 1.0, 2).is_none());
}
/// A pure sine is fitted almost exactly by 11 Fourier basis functions.
#[test]
fn test_fourier_fit_1d_sine_wave() {
    let t = uniform_grid(100);
    let n = 1;
    let m = t.len();
    let raw = sine_wave(&t, 2.0);
    let data = FdMatrix::from_column_major(raw, n, m).unwrap();
    // `expect` surfaces the error value if the fit fails, unlike `is_ok()`.
    let res = fourier_fit_1d(&data, &t, 11)
        .expect("Fourier fit of a pure sine should succeed");
    assert!(res.rss < 1e-6, "Pure sine should have near-zero RSS");
}
/// An even `nbasis` request is accepted and adjusted to an odd count.
#[test]
fn test_fourier_fit_1d_makes_nbasis_odd() {
    let t = uniform_grid(50);
    let raw = sine_wave(&t, 1.0);
    let data = FdMatrix::from_column_major(raw, 1, t.len()).unwrap();
    let res = fourier_fit_1d(&data, &t, 6)
        .expect("an even nbasis should be accepted, not rejected");
    assert!(res.n_basis % 2 == 1);
}
/// GCV, AIC and BIC of a Fourier fit are all finite.
#[test]
fn test_fourier_fit_1d_criteria() {
    let t = uniform_grid(50);
    let data = FdMatrix::from_column_major(sine_wave(&t, 2.0), 1, t.len()).unwrap();
    let fit = fourier_fit_1d(&data, &t, 9).unwrap();
    for score in [fit.gcv, fit.aic, fit.bic] {
        assert!(score.is_finite());
    }
}
/// `fourier_fit_1d` returns an error for `nbasis = 2`.
#[test]
fn test_fourier_fit_1d_invalid_nbasis() {
    let grid = uniform_grid(50);
    let curve = sine_wave(&grid, 1.0);
    let data = FdMatrix::from_column_major(curve, 1, grid.len()).unwrap();
    assert!(fourier_fit_1d(&data, &grid, 2).is_err());
}
/// The GCV-selected Fourier `nbasis` lies in the search range and is odd.
#[test]
fn test_select_fourier_nbasis_gcv_range() {
    let grid = uniform_grid(100);
    let data = FdMatrix::from_column_major(sine_wave(&grid, 3.0), 1, grid.len()).unwrap();
    let best = select_fourier_nbasis_gcv(&data, &grid, 3, 15);
    assert!((3..=15).contains(&best));
    assert!(best % 2 == 1, "Selected nbasis should be odd");
}
/// The GCV search never picks fewer basis functions than the requested minimum.
#[test]
fn test_select_fourier_nbasis_gcv_respects_min() {
    let grid = uniform_grid(50);
    let data = FdMatrix::from_column_major(sine_wave(&grid, 1.0), 1, grid.len()).unwrap();
    let chosen = select_fourier_nbasis_gcv(&data, &grid, 7, 15);
    assert!(chosen >= 7);
}
/// Automatic basis selection returns one populated selection per curve.
#[test]
fn test_select_basis_auto_1d_returns_results() {
    let t = uniform_grid(50);
    let (n, m) = (3, t.len());
    // Row i is a sine of frequency i + 1.
    let flat: Vec<f64> = (0..n)
        .flat_map(|i| sine_wave(&t, 1.0 + i as f64))
        .collect();
    let data = make_matrix(&flat, n, m);
    let result = select_basis_auto_1d(&data, &t, 0, 5, 15, 1.0, false);
    assert_eq!(result.selections.len(), n);
    result.selections.iter().for_each(|sel| {
        assert!(sel.nbasis >= 3);
        assert!(!sel.coefficients.is_empty());
        assert_eq!(sel.fitted.len(), m);
    });
}
/// A 5-cycle sine is flagged as seasonal when seasonal detection is enabled.
#[test]
fn test_select_basis_auto_1d_seasonal_hint() {
    let t = uniform_grid(100);
    let wave = sine_wave(&t, 5.0);
    let data = FdMatrix::from_column_major(wave, 1, t.len()).unwrap();
    let result = select_basis_auto_1d(&data, &t, 0, 0, 0, -1.0, true);
    assert_eq!(result.selections.len(), 1);
    assert!(result.selections[0].seasonal_detected);
}
/// A constant curve is not flagged as seasonal.
#[test]
fn test_select_basis_auto_1d_non_seasonal() {
    let t = uniform_grid(50);
    let m = t.len();
    let flat_line: Vec<f64> = vec![1.0; m];
    let data = FdMatrix::from_column_major(flat_line, 1, m).unwrap();
    let result = select_basis_auto_1d(&data, &t, 0, 0, 0, -1.0, true);
    let first = &result.selections[0];
    assert!(!first.seasonal_detected);
}
/// The criterion code passed in (0, 1, 2 — named GCV/AIC/BIC in the original
/// test) is echoed back in `result.criterion`.
#[test]
fn test_select_basis_auto_1d_criterion_options() {
    let t = uniform_grid(50);
    let data = FdMatrix::from_column_major(sine_wave(&t, 2.0), 1, t.len()).unwrap();
    let [gcv, aic, bic] =
        [0, 1, 2].map(|code| select_basis_auto_1d(&data, &t, code, 0, 0, 1.0, false));
    assert_eq!(gcv.criterion, 0);
    assert_eq!(aic.criterion, 1);
    assert_eq!(bic.criterion, 2);
}
/// A NaN in the input must not make `pspline_fit_1d` panic.
///
/// The test does not prescribe whether the result is `None` or a fit. If a
/// fit is returned, it must still have the input's `1 x m` shape.
#[test]
fn test_nan_pspline_no_panic() {
    let t = uniform_grid(50);
    let mut y = sine_wave(&t, 2.0);
    y[10] = f64::NAN;
    let data = FdMatrix::from_column_major(y, 1, t.len()).unwrap();
    // Reaching past this call without a panic is the primary check; an
    // `is_some() || is_none()` assertion would be a tautology.
    if let Some(res) = pspline_fit_1d(&data, &t, 10, 1.0, 2) {
        assert_eq!(res.fitted.nrows(), 1);
        assert_eq!(res.fitted.ncols(), t.len());
    }
}
/// Projecting a single curve produces exactly one coefficient row.
#[test]
fn test_n1_fit() {
    let t = uniform_grid(50);
    let y = sine_wave(&t, 1.0);
    let data = FdMatrix::from_column_major(y, 1, t.len()).unwrap();
    let res = fdata_to_basis_1d(&data, &t, 7, 1)
        .expect("single-curve Fourier projection should succeed");
    assert_eq!(res.coefficients.nrows(), 1);
}
/// A one-point grid still yields `nbasis` values, and the constant term is 1.
#[test]
fn test_single_point_basis() {
    let single = vec![0.5];
    let values = fourier_basis(&single, 3);
    assert_eq!(values.len(), 3);
    let constant_term = values[0];
    assert!(
        (constant_term - 1.0).abs() < 1e-12,
        "constant basis should be 1.0"
    );
}
fn sine_data(t: &[f64]) -> FdMatrix {
let n = 3;
let m = t.len();
let flat: Vec<f64> = (0..n)
.flat_map(|i| {
t.iter()
.map(move |&ti| (2.0 * PI * (1.0 + i as f64) * ti).sin())
})
.collect();
make_matrix(&flat, n, m)
}
/// GCV-based lambda selection yields a finite, positive GCV and a 3x100 fit.
#[test]
fn pspline_fit_gcv_selects_lambda() {
    let grid: Vec<f64> = (0..100).map(|k| k as f64 / 99.0).collect();
    let curves = sine_data(&grid);
    let fit = pspline_fit_gcv(&curves, &grid, 15, 2).unwrap();
    assert!(fit.gcv.is_finite());
    assert!(fit.gcv > 0.0);
    assert_eq!(fit.fitted.shape(), (3, 100));
}
/// The GCV-chosen lambda scores no worse than a tiny or a huge fixed lambda.
#[test]
fn pspline_fit_gcv_beats_extremes() {
    let grid: Vec<f64> = (0..60).map(|k| k as f64 / 59.0).collect();
    let curves = sine_data(&grid);
    let best = pspline_fit_gcv(&curves, &grid, 12, 2).unwrap();
    let tol = 1e-10;
    for extreme_lambda in [1e-8, 1e8] {
        let fixed = pspline_fit_1d(&curves, &grid, 12, extreme_lambda, 2).unwrap();
        assert!(best.gcv <= fixed.gcv + tol);
    }
}