1use crate::error::{StatsError, StatsResult};
33
/// Basis family used to expand the coefficient function `β(t)`.
///
/// Every variant records how many basis functions to use. Each family places its
/// functions on the interval between the first and last sampling-grid points.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FunctionalBasis {
    /// Clamped B-splines.
    BSpline {
        /// Number of basis functions.
        n_basis: usize,
        /// Polynomial degree of each spline piece (3 = cubic).
        degree: usize,
    },
    /// A constant followed by cosine/sine pairs of increasing harmonic frequency.
    Fourier {
        /// Number of basis functions, including the constant term.
        n_basis: usize,
    },
    /// A constant followed by dyadically dilated and translated Haar wavelets.
    Wavelet {
        /// Number of basis functions, including the constant term.
        n_basis: usize,
    },
}
60
61#[non_exhaustive]
63#[derive(Debug, Clone)]
64pub struct FofConfig {
65 pub basis: FunctionalBasis,
67 pub lambda: f64,
69 pub n_grid: usize,
71}
72
73impl Default for FofConfig {
74 fn default() -> Self {
75 Self {
76 basis: FunctionalBasis::BSpline {
77 n_basis: 10,
78 degree: 3,
79 },
80 lambda: 0.01,
81 n_grid: 100,
82 }
83 }
84}
85
/// Output of [`FunctionalRegression::fit`].
#[derive(Clone, Debug)]
pub struct FofResult {
    /// Basis coefficients `c` of the estimated `β(t) = Σ_k c_k φ_k(t)`.
    pub beta_coefs: Vec<f64>,
    /// `β(t)` evaluated at each point of `grid`.
    pub beta_values: Vec<f64>,
    /// Equally spaced evaluation points spanning the fit grid.
    pub grid: Vec<f64>,
    /// Estimated intercept `α`.
    pub intercept: f64,
    /// In-sample coefficient of determination (1.0 when the response is constant).
    pub r_squared: f64,
    /// Generalised cross-validation score (`f64::INFINITY` when undefined).
    pub gcv_score: f64,
}
102
103#[derive(Debug, Clone)]
105pub struct FunctionalRegression {
106 config: FofConfig,
107 beta_coefs: Option<Vec<f64>>,
109 intercept: Option<f64>,
111 fit_grid: Option<Vec<f64>>,
113}
114
115impl FunctionalRegression {
116 pub fn new(config: FofConfig) -> Self {
118 Self {
119 config,
120 beta_coefs: None,
121 intercept: None,
122 fit_grid: None,
123 }
124 }
125
126 pub fn fit(
140 &mut self,
141 data: &[Vec<f64>],
142 response: &[f64],
143 grid: &[f64],
144 ) -> StatsResult<FofResult> {
145 let n_obs = data.len();
146 if n_obs == 0 {
147 return Err(StatsError::InsufficientData(
148 "need at least one observation".to_owned(),
149 ));
150 }
151 let n_time = grid.len();
152 if n_time < 2 {
153 return Err(StatsError::InvalidArgument(
154 "grid must have at least 2 points".to_owned(),
155 ));
156 }
157 if response.len() != n_obs {
158 return Err(StatsError::DimensionMismatch(format!(
159 "response length {} != n_obs {}",
160 response.len(),
161 n_obs
162 )));
163 }
164 for (i, row) in data.iter().enumerate() {
165 if row.len() != n_time {
166 return Err(StatsError::DimensionMismatch(format!(
167 "data[{}] has {} time points, expected {}",
168 i,
169 row.len(),
170 n_time
171 )));
172 }
173 }
174
175 let n_basis = self.n_basis_fns();
176 if n_obs < n_basis + 1 {
177 return Err(StatsError::InsufficientData(format!(
178 "need n_obs >= n_basis+1 = {} but got {}",
179 n_basis + 1,
180 n_obs
181 )));
182 }
183
184 let phi = self.evaluate_basis(grid); let z = build_z_matrix(data, &phi, grid); let omega = self.roughness_penalty(n_basis);
193
194 let y_mean = response.iter().sum::<f64>() / n_obs as f64;
196 let y_centred: Vec<f64> = response.iter().map(|&y| y - y_mean).collect();
197
198 let z_col_means: Vec<f64> = (0..n_basis)
200 .map(|j| z.iter().map(|row| row[j]).sum::<f64>() / n_obs as f64)
201 .collect();
202 let z_centred: Vec<Vec<f64>> = z
203 .iter()
204 .map(|row| {
205 row.iter()
206 .enumerate()
207 .map(|(j, &v)| v - z_col_means[j])
208 .collect()
209 })
210 .collect();
211
212 let coefs = penalized_ls(&z_centred, &y_centred, &omega, self.config.lambda)?;
214
215 let intercept = y_mean
217 - z_col_means
218 .iter()
219 .zip(coefs.iter())
220 .map(|(&zm, &c)| zm * c)
221 .sum::<f64>();
222
223 let y_hat: Vec<f64> = z
225 .iter()
226 .map(|row| {
227 intercept
228 + row
229 .iter()
230 .zip(coefs.iter())
231 .map(|(&z_ij, &c)| z_ij * c)
232 .sum::<f64>()
233 })
234 .collect();
235
236 let ss_res: f64 = response
237 .iter()
238 .zip(y_hat.iter())
239 .map(|(&y, &yh)| (y - yh).powi(2))
240 .sum();
241 let ss_tot: f64 = response.iter().map(|&y| (y - y_mean).powi(2)).sum();
242 let r_squared = if ss_tot > 0.0 {
243 1.0 - ss_res / ss_tot
244 } else {
245 1.0
246 };
247
248 let gcv_score = compute_gcv(&z, &y_hat, response, &omega, self.config.lambda, n_obs);
252
253 let eval_grid = linspace(grid[0], *grid.last().unwrap_or(&1.0), self.config.n_grid);
255 let phi_eval = self.evaluate_basis(&eval_grid);
256 let beta_values: Vec<f64> = eval_grid
257 .iter()
258 .enumerate()
259 .map(|(t, _)| {
260 phi_eval[t]
261 .iter()
262 .zip(coefs.iter())
263 .map(|(&p, &c)| p * c)
264 .sum()
265 })
266 .collect();
267
268 self.beta_coefs = Some(coefs.clone());
270 self.intercept = Some(intercept);
271 self.fit_grid = Some(grid.to_vec());
272
273 Ok(FofResult {
274 beta_coefs: coefs,
275 beta_values,
276 grid: eval_grid,
277 intercept,
278 r_squared,
279 gcv_score,
280 })
281 }
282
283 pub fn predict(&self, new_data: &[Vec<f64>], grid: &[f64]) -> StatsResult<Vec<f64>> {
295 let coefs = self
296 .beta_coefs
297 .as_ref()
298 .ok_or_else(|| StatsError::ComputationError("model not fitted yet".to_owned()))?;
299 let intercept = self.intercept.unwrap_or(0.0);
300
301 let n_basis = self.n_basis_fns();
302 let phi = self.evaluate_basis(grid);
303 let z = build_z_matrix(new_data, &phi, grid);
304
305 let preds = z
306 .iter()
307 .map(|row| {
308 intercept
309 + row
310 .iter()
311 .zip(coefs.iter())
312 .take(n_basis)
313 .map(|(&z_ij, &c)| z_ij * c)
314 .sum::<f64>()
315 })
316 .collect();
317
318 Ok(preds)
319 }
320
321 fn n_basis_fns(&self) -> usize {
327 match self.config.basis {
328 FunctionalBasis::BSpline { n_basis, .. } => n_basis,
329 FunctionalBasis::Fourier { n_basis } => n_basis,
330 FunctionalBasis::Wavelet { n_basis } => n_basis,
331 }
332 }
333
334 fn evaluate_basis(&self, grid: &[f64]) -> Vec<Vec<f64>> {
338 match self.config.basis {
339 FunctionalBasis::BSpline { n_basis, degree } => bspline_basis(grid, n_basis, degree),
340 FunctionalBasis::Fourier { n_basis } => fourier_basis(grid, n_basis),
341 FunctionalBasis::Wavelet { n_basis } => wavelet_basis(grid, n_basis),
342 }
343 }
344
345 fn roughness_penalty(&self, n_basis: usize) -> Vec<Vec<f64>> {
349 roughness_penalty(n_basis)
350 }
351}
352
/// Evaluates a clamped B-spline basis on `grid`.
///
/// The returned matrix has one row per grid point and one column per basis
/// function (`phi[t][k] = B_k(grid[t])`). Knots are placed uniformly between the
/// first and last grid points, with `degree + 1` repeated knots at each end.
///
/// A clamped basis of `n_basis` functions can have degree at most `n_basis − 1`.
/// A larger `degree` is therefore reduced to `n_basis − 1` (Bernstein polynomials),
/// so the functions still form a partition of unity. Returns rows of length zero
/// when `n_basis == 0`, and an empty matrix when `grid` is empty.
pub fn bspline_basis(grid: &[f64], n_basis: usize, degree: usize) -> Vec<Vec<f64>> {
    let n_grid = grid.len();
    if n_basis == 0 || n_grid == 0 {
        return vec![vec![]; n_grid];
    }

    // Without this clamp the knot vector is truncated and the basis no longer sums to one.
    let degree = degree.min(n_basis - 1);
    let t_min = grid[0];
    let t_max = grid[n_grid - 1];

    let n_knots = n_basis + degree + 1;
    let knots = build_clamped_knots(t_min, t_max, n_knots, degree);

    grid.iter()
        .map(|&t| {
            (0..n_basis)
                .map(|k| de_boor_basis(t, k, degree, &knots))
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// Builds a clamped knot vector of length `n_knots` over `[t_min, t_max]`.
///
/// The vector contains `degree + 1` copies of `t_min`, then uniformly spaced
/// interior knots, then `degree + 1` copies of `t_max`. If `n_knots` is too small
/// to hold both end blocks, the result is truncated to `n_knots` entries.
fn build_clamped_knots(t_min: f64, t_max: f64, n_knots: usize, degree: usize) -> Vec<f64> {
    let n_end = degree + 1;
    let n_interior = n_knots.saturating_sub(2 * n_end);

    let mut knots = Vec::with_capacity(n_knots.max(2 * n_end));
    knots.resize(n_end, t_min);
    knots.extend(
        (1..=n_interior)
            .map(|i| t_min + (t_max - t_min) * (i as f64) / (n_interior + 1) as f64),
    );
    knots.resize(knots.len() + n_end, t_max);
    // Truncate (or, defensively, pad with the right endpoint) to exactly `n_knots`.
    knots.resize(n_knots, t_max);
    knots
}

/// Value at `t` of the `k`-th B-spline of degree `p` over `knots`, computed with
/// the Cox–de Boor recursion.
///
/// Spans of zero length contribute zero (the `0/0 := 0` convention). The degree-0
/// indicator is half-open, `[knots[k], knots[k+1])`, except at the final knot,
/// which is treated as closed so the basis still sums to one at the right end of
/// the domain. Returns 0.0 when `k` is too large for the knot vector. The recursion
/// makes up to `2^(p+1) − 1` calls per evaluation.
fn de_boor_basis(t: f64, k: usize, p: usize, knots: &[f64]) -> f64 {
    let n_knots = knots.len();
    if k + p + 1 >= n_knots {
        return 0.0;
    }

    if p == 0 {
        let at_right_end = (t - knots[k + 1]).abs() < 1e-14
            && knots[k + 1] >= knots.last().copied().unwrap_or(f64::INFINITY);
        return if (t >= knots[k] && t < knots[k + 1]) || at_right_end {
            1.0
        } else {
            0.0
        };
    }

    let denom1 = knots[k + p] - knots[k];
    let left = if denom1.abs() > 1e-14 {
        (t - knots[k]) / denom1 * de_boor_basis(t, k, p - 1, knots)
    } else {
        0.0
    };

    let denom2 = knots[k + p + 1] - knots[k + 1];
    let right = if denom2.abs() > 1e-14 {
        (knots[k + p + 1] - t) / denom2 * de_boor_basis(t, k + 1, p - 1, knots)
    } else {
        0.0
    };

    left + right
}
449
/// Fourier basis on the interval spanned by `grid`.
///
/// Column 0 is the constant 1. Columns `1, 2, 3, 4, …` alternate
/// `cos(2πhs), sin(2πhs)` with harmonic `h = 1, 1, 2, 2, …`, where
/// `s = (t − grid[0]) / span` and the span is clamped below at `1e-12`. Returns
/// rows of length zero when `n_basis == 0`, and an empty matrix for an empty grid.
fn fourier_basis(grid: &[f64], n_basis: usize) -> Vec<Vec<f64>> {
    if n_basis == 0 || grid.is_empty() {
        return vec![vec![]; grid.len()];
    }

    let start = grid[0];
    let end = grid[grid.len() - 1];
    let period = (end - start).max(1e-12);

    grid.iter()
        .map(|&t| {
            let s = (t - start) / period;
            (0..n_basis)
                .map(|idx| {
                    if idx == 0 {
                        return 1.0;
                    }
                    // idx 1,2 -> harmonic 1; idx 3,4 -> harmonic 2; ...
                    let harmonic = (idx + 1) / 2;
                    let angle = 2.0 * std::f64::consts::PI * harmonic as f64 * s;
                    if idx % 2 == 1 {
                        angle.cos()
                    } else {
                        angle.sin()
                    }
                })
                .collect::<Vec<f64>>()
        })
        .collect()
}
488
/// Haar wavelet basis on the interval spanned by `grid`.
///
/// Column 0 is the constant 1. The following columns walk through dyadic levels
/// `j = 0, 1, 2, …`, each holding `2^j` translated wavelets scaled by `sqrt(2^j)`.
/// A wavelet is `+scale` on the left half of its support, `−scale` on the right
/// half, and 0 elsewhere, in the rescaled coordinate
/// `s = (t − grid[0]) / span` (span clamped below at `1e-12`).
///
/// The last wavelet of each level has its right half closed at `s = 1`. With
/// half-open supports, every wavelet would vanish at the final grid point and bias
/// the trapezoidal integrals (the B-spline basis closes its right end the same
/// way). Returns rows of length zero when `n_basis == 0`, and an empty matrix for
/// an empty grid.
fn wavelet_basis(grid: &[f64], n_basis: usize) -> Vec<Vec<f64>> {
    let n_grid = grid.len();
    if n_basis == 0 || n_grid == 0 {
        return vec![vec![]; n_grid];
    }

    let t_min = grid[0];
    let t_max = grid[n_grid - 1];
    let span = (t_max - t_min).max(1e-12);

    let mut phi = vec![vec![0.0f64; n_basis]; n_grid];
    for (row, &t) in phi.iter_mut().zip(grid) {
        let s = (t - t_min) / span;
        row[0] = 1.0;

        let mut k = 1usize;
        let mut level = 0usize;
        while k < n_basis {
            let n_at_level = 1usize << level;
            let scale = (n_at_level as f64).sqrt();
            for translate in 0..n_at_level {
                if k >= n_basis {
                    break;
                }
                let t0 = translate as f64 / n_at_level as f64;
                let tmid = (translate as f64 + 0.5) / n_at_level as f64;
                let t1 = (translate + 1) as f64 / n_at_level as f64;
                let is_last = translate + 1 == n_at_level;
                row[k] = if s >= t0 && s < tmid {
                    scale
                } else if s >= tmid && (s < t1 || (is_last && s <= t1)) {
                    -scale
                } else {
                    0.0
                };
                k += 1;
            }
            level += 1;
        }
    }
    phi
}
540
/// Projects each sampled curve onto each basis function.
///
/// Returns `z[i][j] = ∫ x_i(t) φ_j(t) dt`, approximated by the trapezoidal rule on
/// `grid`, where `phi[t][j] = φ_j(grid[t])`. The number of columns is taken from
/// the first row of `phi`. Panics if a curve or `phi` has fewer rows than `grid`.
fn build_z_matrix(data: &[Vec<f64>], phi: &[Vec<f64>], grid: &[f64]) -> Vec<Vec<f64>> {
    let n_cols = phi.first().map_or(0, Vec::len);
    let n_intervals = grid.len().saturating_sub(1);

    data.iter()
        .map(|curve| {
            (0..n_cols)
                .map(|col| {
                    (0..n_intervals).fold(0.0f64, |acc, seg| {
                        let width = grid[seg + 1] - grid[seg];
                        let left = curve[seg] * phi[seg][col];
                        let right = curve[seg + 1] * phi[seg + 1][col];
                        acc + 0.5 * width * (left + right)
                    })
                })
                .collect::<Vec<f64>>()
        })
        .collect()
}
570
/// Second-order difference penalty matrix `Ω = DᵀD` for `n_basis` coefficients.
///
/// `D` is the `(n_basis − 2) × n_basis` second-difference operator with rows
/// `[…, 1, −2, 1, …]`, so `cᵀΩc = Σ_k (c_k − 2c_{k+1} + c_{k+2})²`. Fewer than
/// three coefficients admit no second difference and yield an all-zero matrix.
pub fn roughness_penalty(n_basis: usize) -> Vec<Vec<f64>> {
    const STENCIL: [f64; 3] = [1.0, -2.0, 1.0];

    let mut omega = vec![vec![0.0f64; n_basis]; n_basis];
    if n_basis < 3 {
        return omega;
    }
    // Accumulate the outer product of each shifted stencil row of D.
    for start in 0..n_basis - 2 {
        for (da, &wa) in STENCIL.iter().enumerate() {
            for (db, &wb) in STENCIL.iter().enumerate() {
                omega[start + da][start + db] += wa * wb;
            }
        }
    }
    omega
}
598
599pub fn penalized_ls(
608 z: &[Vec<f64>],
609 y: &[f64],
610 omega: &[Vec<f64>],
611 lambda: f64,
612) -> StatsResult<Vec<f64>> {
613 let n = z.first().map(|r| r.len()).unwrap_or(0);
614 if n == 0 {
615 return Ok(Vec::new());
616 }
617
618 let mut a = vec![vec![0.0f64; n]; n];
620 for obs in z {
621 for j in 0..n {
622 for k in 0..n {
623 a[j][k] += obs[j] * obs[k];
624 }
625 }
626 }
627 for j in 0..n {
628 for k in 0..n {
629 a[j][k] += lambda * omega[j][k];
630 }
631 }
632
633 let mut b = vec![0.0f64; n];
635 for (obs, &yi) in z.iter().zip(y) {
636 for j in 0..n {
637 b[j] += obs[j] * yi;
638 }
639 }
640
641 gauss_solve(&a, &b)
643}
644
645fn gauss_solve(a: &[Vec<f64>], b: &[f64]) -> StatsResult<Vec<f64>> {
647 let n = a.len();
648 if n == 0 {
649 return Ok(Vec::new());
650 }
651
652 let mut m: Vec<Vec<f64>> = a
654 .iter()
655 .zip(b.iter())
656 .map(|(row, &bi)| {
657 let mut r = row.clone();
658 r.push(bi);
659 r
660 })
661 .collect();
662
663 for col in 0..n {
664 let pivot_row = (col..n).max_by(|&r1, &r2| {
666 m[r1][col]
667 .abs()
668 .partial_cmp(&m[r2][col].abs())
669 .unwrap_or(std::cmp::Ordering::Equal)
670 });
671 let pivot_row = pivot_row
672 .ok_or_else(|| StatsError::ComputationError("singular penalised system".to_owned()))?;
673
674 m.swap(col, pivot_row);
675
676 let pivot = m[col][col];
677 if pivot.abs() < 1e-300 {
678 return Err(StatsError::ComputationError(
679 "near-singular penalised normal equations; increase lambda".to_owned(),
680 ));
681 }
682
683 for row in (col + 1)..n {
684 let factor = m[row][col] / pivot;
685 for k in col..=n {
686 let val = m[col][k];
687 m[row][k] -= factor * val;
688 }
689 }
690 }
691
692 let mut x = vec![0.0f64; n];
694 for i in (0..n).rev() {
695 let mut sum = m[i][n];
696 for j in (i + 1)..n {
697 sum -= m[i][j] * x[j];
698 }
699 x[i] = sum / m[i][i];
700 }
701 Ok(x)
702}
703
704fn compute_gcv(
715 z: &[Vec<f64>],
716 y_hat: &[f64],
717 response: &[f64],
718 omega: &[Vec<f64>],
719 lambda: f64,
720 n_obs: usize,
721) -> f64 {
722 let n = z.first().map(|r| r.len()).unwrap_or(0);
723 if n == 0 || n_obs == 0 {
724 return f64::INFINITY;
725 }
726
727 let mut a = vec![vec![0.0f64; n]; n];
729 for obs in z {
730 for j in 0..n {
731 for k in 0..n {
732 a[j][k] += obs[j] * obs[k];
733 }
734 }
735 }
736 for j in 0..n {
737 for k in 0..n {
738 a[j][k] += lambda * omega[j][k];
739 }
740 }
741
742 let a_inv = match invert_matrix(&a) {
744 Ok(inv) => inv,
745 Err(_) => return f64::INFINITY,
746 };
747
748 let tr_h: f64 = z
750 .iter()
751 .map(|zi| {
752 let az: Vec<f64> = (0..n)
754 .map(|j| (0..n).map(|k| a_inv[j][k] * zi[k]).sum::<f64>())
755 .collect();
756 zi.iter().zip(az.iter()).map(|(&v, &w)| v * w).sum::<f64>()
758 })
759 .sum();
760
761 let df_hat = tr_h / n_obs as f64;
762 if (1.0 - df_hat).abs() < 1e-10 {
763 return f64::INFINITY;
764 }
765
766 let ss_res: f64 = response
767 .iter()
768 .zip(y_hat.iter())
769 .map(|(&y, &yh)| (y - yh).powi(2))
770 .sum();
771
772 (ss_res / n_obs as f64) / (1.0 - df_hat).powi(2)
773}
774
775fn invert_matrix(a: &[Vec<f64>]) -> StatsResult<Vec<Vec<f64>>> {
777 let n = a.len();
778 let mut m: Vec<Vec<f64>> = a
780 .iter()
781 .enumerate()
782 .map(|(i, row)| {
783 let mut r = row.clone();
784 r.resize(2 * n, 0.0);
785 r[n + i] = 1.0;
786 r
787 })
788 .collect();
789
790 for col in 0..n {
791 let pivot_row = (col..n).max_by(|&r1, &r2| {
793 m[r1][col]
794 .abs()
795 .partial_cmp(&m[r2][col].abs())
796 .unwrap_or(std::cmp::Ordering::Equal)
797 });
798 let pivot_row =
799 pivot_row.ok_or_else(|| StatsError::ComputationError("singular matrix".to_owned()))?;
800 m.swap(col, pivot_row);
801
802 let pivot = m[col][col];
803 if pivot.abs() < 1e-300 {
804 return Err(StatsError::ComputationError("singular matrix".to_owned()));
805 }
806 let scale = 1.0 / pivot;
807 for k in 0..(2 * n) {
808 m[col][k] *= scale;
809 }
810 for row in 0..n {
811 if row != col {
812 let factor = m[row][col];
813 for k in 0..(2 * n) {
814 let val = m[col][k];
815 m[row][k] -= factor * val;
816 }
817 }
818 }
819 }
820
821 let inv: Vec<Vec<f64>> = m.iter().map(|row| row[n..].to_vec()).collect();
822 Ok(inv)
823}
824
/// `n` equally spaced values from `start` to `end` inclusive.
///
/// Returns an empty vector for `n == 0` and `[start]` for `n == 1`.
fn linspace(start: f64, end: f64, n: usize) -> Vec<f64> {
    match n {
        0 => Vec::new(),
        1 => vec![start],
        _ => {
            let span = end - start;
            let denom = (n - 1) as f64;
            (0..n).map(|i| start + span * i as f64 / denom).collect()
        }
    }
}
841
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal 64-bit linear congruential generator returning values in `[0, 1)`,
    /// keeping the tests deterministic without an external RNG crate.
    fn lcg(s: &mut u64) -> f64 {
        *s = s
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        ((*s >> 11) as f64) / ((1u64 << 53) as f64)
    }

    /// Generates `n_obs` curves `x_i(t) = a_i sin(2πt)` on a uniform grid over `[0, 1]`.
    ///
    /// The responses are `y_i = ∫ t x_i(t) dt` (trapezoidal rule) plus uniform
    /// noise of width 0.01, so the true coefficient function is `β(t) = t`. Every
    /// curve is a scalar multiple of `sin(2πt)`, so the resulting design matrix has
    /// rank one.
    fn smooth_data(n_obs: usize, n_time: usize, seed: u64) -> (Vec<Vec<f64>>, Vec<f64>, Vec<f64>) {
        let grid: Vec<f64> = (0..n_time)
            .map(|i| i as f64 / (n_time - 1) as f64)
            .collect();
        let mut rng = seed;
        let mut data = Vec::with_capacity(n_obs);
        let mut response = Vec::with_capacity(n_obs);
        for _ in 0..n_obs {
            // Amplitude in [-1, 1).
            let a = lcg(&mut rng) * 2.0 - 1.0;
            let curve: Vec<f64> = grid
                .iter()
                .map(|&t| a * (2.0 * std::f64::consts::PI * t).sin())
                .collect();
            let integral: f64 = grid
                .windows(2)
                .map(|w| {
                    let t0 = w[0];
                    let t1 = w[1];
                    let dt = t1 - t0;
                    let f0 = t0 * a * (2.0 * std::f64::consts::PI * t0).sin();
                    let f1 = t1 * a * (2.0 * std::f64::consts::PI * t1).sin();
                    0.5 * dt * (f0 + f1)
                })
                .sum();
            response.push(integral + (lcg(&mut rng) - 0.5) * 0.01);
            data.push(curve);
        }
        (data, response, grid)
    }

    #[test]
    fn test_fof_config_default() {
        let cfg = FofConfig::default();
        assert_eq!(
            cfg.basis,
            FunctionalBasis::BSpline {
                n_basis: 10,
                degree: 3
            }
        );
        assert!((cfg.lambda - 0.01).abs() < 1e-15);
        assert_eq!(cfg.n_grid, 100);
    }

    #[test]
    fn test_bspline_basis_partition_of_unity() {
        let grid: Vec<f64> = (0..20).map(|i| i as f64 / 19.0).collect();
        let phi = bspline_basis(&grid, 8, 3);
        for (t_idx, _) in grid.iter().enumerate() {
            let s: f64 = phi[t_idx].iter().sum();
            assert!(
                (s - 1.0).abs() < 1e-8,
                "partition of unity at t={t_idx}: sum={s}"
            );
        }
    }

    #[test]
    fn test_bspline_basis_non_negative() {
        let grid: Vec<f64> = (0..50).map(|i| i as f64 / 49.0).collect();
        let phi = bspline_basis(&grid, 10, 3);
        for row in &phi {
            for &v in row {
                assert!(v >= -1e-10, "negative B-spline value: {v}");
            }
        }
    }

    #[test]
    fn test_roughness_penalty_symmetry() {
        let omega = roughness_penalty(8);
        let n = omega.len();
        for i in 0..n {
            for j in 0..n {
                assert!(
                    (omega[i][j] - omega[j][i]).abs() < 1e-14,
                    "Omega not symmetric at ({i},{j})"
                );
            }
        }
    }

    #[test]
    fn test_roughness_penalty_psd() {
        let omega = roughness_penalty(6);
        for (i, row) in omega.iter().enumerate() {
            assert!(row[i] >= 0.0, "negative diagonal in Omega");
        }

        // Non-negative diagonals do not imply PSD. Check the quadratic form
        // xᵀΩx = ||Dx||² directly. Integer-valued probes keep the arithmetic exact.
        let probes: [[f64; 6]; 4] = [
            [1.0, -1.0, 2.0, 0.0, -3.0, 1.0],
            [0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
            [1.0, 4.0, 9.0, 16.0, 25.0, 36.0],
            [-2.0, 5.0, -1.0, 3.0, 0.0, -4.0],
        ];
        for x in &probes {
            let q: f64 = omega
                .iter()
                .zip(x.iter())
                .map(|(row, &xi)| {
                    xi * row
                        .iter()
                        .zip(x.iter())
                        .map(|(&w, &xj)| w * xj)
                        .sum::<f64>()
                })
                .sum();
            assert!(q >= 0.0, "quadratic form negative for {x:?}: {q}");
        }

        // Constant and linear coefficient vectors have zero second differences,
        // so they must lie in the null space of Omega.
        let constant = [1.0; 6];
        let linear = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        for v in [constant, linear] {
            for row in &omega {
                let s: f64 = row.iter().zip(v.iter()).map(|(&w, &vi)| w * vi).sum();
                assert_eq!(s, 0.0, "Omega * {v:?} should vanish");
            }
        }
    }

    #[test]
    fn test_penalized_ls_identity() {
        let n = 4;
        // With Z = I and no penalty, the solution is y itself.
        let z: Vec<Vec<f64>> = (0..n)
            .map(|i| (0..n).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
            .collect();
        let omega = vec![vec![0.0; n]; n];
        let y = vec![1.0, 2.0, 3.0, 4.0];
        let c = penalized_ls(&z, &y, &omega, 0.0).expect("penalized_ls failed");
        for (ci, &yi) in c.iter().zip(y.iter()) {
            assert!((ci - yi).abs() < 1e-10, "expected {yi}, got {ci}");
        }
    }

    #[test]
    fn test_fit_r_squared_high_on_clean_data() {
        let (data, response, grid) = smooth_data(50, 40, 42);
        let config = FofConfig {
            basis: FunctionalBasis::BSpline {
                n_basis: 8,
                degree: 3,
            },
            lambda: 1e-4,
            n_grid: 50,
        };
        let mut model = FunctionalRegression::new(config);
        let result = model.fit(&data, &response, &grid).expect("fit failed");
        assert!(result.r_squared > 0.9, "R² too low: {}", result.r_squared);
    }

    #[test]
    fn test_predict_length() {
        let (data, response, grid) = smooth_data(30, 30, 7);
        let config = FofConfig {
            basis: FunctionalBasis::BSpline {
                n_basis: 6,
                degree: 3,
            },
            lambda: 0.01,
            n_grid: 50,
        };
        let mut model = FunctionalRegression::new(config);
        model.fit(&data, &response, &grid).expect("fit failed");

        let (new_data, _, _) = smooth_data(10, 30, 99);
        let preds = model.predict(&new_data, &grid).expect("predict failed");
        assert_eq!(preds.len(), 10, "predict length mismatch");
    }

    #[test]
    fn test_predict_before_fit_returns_error() {
        let config = FofConfig::default();
        let model = FunctionalRegression::new(config);
        let grid: Vec<f64> = (0..10).map(|i| i as f64 / 9.0).collect();
        let data = vec![vec![0.0; 10]];
        let res = model.predict(&data, &grid);
        assert!(res.is_err(), "predict before fit should return error");
    }

    #[test]
    fn test_fit_with_fourier_basis() {
        let (data, response, grid) = smooth_data(40, 40, 123);
        let config = FofConfig {
            basis: FunctionalBasis::Fourier { n_basis: 9 },
            lambda: 0.01,
            n_grid: 50,
        };
        let mut model = FunctionalRegression::new(config);
        let result = model.fit(&data, &response, &grid).expect("fit failed");
        assert!(result.r_squared >= 0.0 && result.r_squared <= 1.0 + 1e-10);
    }

    #[test]
    fn test_fit_with_wavelet_basis() {
        let (data, response, grid) = smooth_data(40, 40, 456);
        let config = FofConfig {
            basis: FunctionalBasis::Wavelet { n_basis: 8 },
            lambda: 0.01,
            n_grid: 50,
        };
        let mut model = FunctionalRegression::new(config);
        let result = model.fit(&data, &response, &grid).expect("fit failed");
        assert!(result.r_squared >= 0.0, "r_squared should be non-negative");
    }

    #[test]
    fn test_gcv_score_finite() {
        let (data, response, grid) = smooth_data(30, 30, 11);
        let config = FofConfig {
            basis: FunctionalBasis::BSpline {
                n_basis: 6,
                degree: 3,
            },
            lambda: 0.1,
            n_grid: 40,
        };
        let mut model = FunctionalRegression::new(config);
        let result = model.fit(&data, &response, &grid).expect("fit failed");
        assert!(result.gcv_score.is_finite(), "GCV should be finite");
        assert!(result.gcv_score >= 0.0, "GCV should be non-negative");
    }

    #[test]
    fn test_beta_values_length() {
        let (data, response, grid) = smooth_data(30, 30, 13);
        let n_grid = 60;
        let config = FofConfig {
            basis: FunctionalBasis::BSpline {
                n_basis: 6,
                degree: 3,
            },
            lambda: 0.01,
            n_grid,
        };
        let mut model = FunctionalRegression::new(config);
        let result = model.fit(&data, &response, &grid).expect("fit failed");
        assert_eq!(result.beta_values.len(), n_grid);
        assert_eq!(result.grid.len(), n_grid);
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let data = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        // Two curves but only one response value.
        let response = vec![1.0];
        let grid = vec![0.0, 0.5, 1.0];
        let config = FofConfig::default();
        let mut model = FunctionalRegression::new(config);
        let res = model.fit(&data, &response, &grid);
        assert!(res.is_err(), "should return dimension mismatch error");
    }

    #[test]
    fn test_gcv_varies_with_lambda() {
        let (data, response, grid) = smooth_data(40, 40, 77);
        let lambdas = [1e-4, 1e-2, 1.0];
        let mut gcv_scores = Vec::new();
        for &lam in &lambdas {
            let config = FofConfig {
                basis: FunctionalBasis::BSpline {
                    n_basis: 6,
                    degree: 3,
                },
                lambda: lam,
                n_grid: 40,
            };
            let mut model = FunctionalRegression::new(config);
            let result = model.fit(&data, &response, &grid).expect("fit failed");
            gcv_scores.push(result.gcv_score);
        }
        let all_same = gcv_scores.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-14);
        assert!(!all_same, "GCV should vary with lambda");
    }
}