1use crate::error::FdarError;
15use crate::iter_maybe_parallel;
16use crate::matrix::FdMatrix;
17use crate::regression::fdata_to_pc_1d;
18#[cfg(feature = "parallel")]
19use rayon::iter::ParallelIterator;
20
/// Lower-triangular Cholesky factorization `A = L L'` of a symmetric matrix
/// stored row-major in `a` (p x p).
///
/// Returns `None` when a pivot drops to (or below) the 1e-12 tolerance, i.e.
/// the matrix is not numerically positive definite.
fn cholesky_factor(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut fac = vec![0.0; p * p];
    for col in 0..p {
        // Pivot: a[col,col] minus the squares already placed in row `col`.
        let pivot = (0..col).fold(a[col * p + col], |acc, k| {
            acc - fac[col * p + k] * fac[col * p + k]
        });
        if pivot <= 1e-12 {
            // Not positive definite to working precision.
            return None;
        }
        let d = pivot.sqrt();
        fac[col * p + col] = d;
        // Sub-diagonal entries of this column.
        for row in (col + 1)..p {
            let off = (0..col).fold(a[row * p + col], |acc, k| {
                acc - fac[row * p + k] * fac[col * p + k]
            });
            fac[row * p + col] = off / d;
        }
    }
    Some(fac)
}
47
/// Solves `A x = b` given the Cholesky factor `L` of `A` (row-major, p x p):
/// forward substitution for `L z = b`, then back substitution for `L' x = z`.
fn cholesky_forward_back(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut sol = b.to_vec();
    // Forward pass: L z = b.
    for row in 0..p {
        let acc = (0..row).fold(sol[row], |s, k| s - l[row * p + k] * sol[k]);
        sol[row] = acc / l[row * p + row];
    }
    // Backward pass: L' x = z (reads column `row` of L).
    for row in (0..p).rev() {
        let acc = ((row + 1)..p).fold(sol[row], |s, k| s - l[k * p + row] * sol[k]);
        sol[row] = acc / l[row * p + row];
    }
    sol
}
65
66pub(crate) fn compute_xtx(x: &FdMatrix) -> Vec<f64> {
68 let (n, p) = x.shape();
69 let mut xtx = vec![0.0; p * p];
70 for k in 0..p {
71 for j in k..p {
72 let mut s = 0.0;
73 for i in 0..n {
74 s += x[(i, k)] * x[(i, j)];
75 }
76 xtx[k * p + j] = s;
77 xtx[j * p + k] = s;
78 }
79 }
80 xtx
81}
82
/// Result of a penalized function-on-scalar regression fit ([`fosr`]).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FosrResult {
    /// Estimated intercept curve (length m).
    pub intercept: Vec<f64>,
    /// Coefficient curves, p x m — one row per scalar predictor
    /// (the intercept row has already been removed).
    pub beta: FdMatrix,
    /// Fitted curves, n x m.
    pub fitted: FdMatrix,
    /// Residual curves (data minus fitted), n x m.
    pub residuals: FdMatrix,
    /// Pointwise R^2 at each grid point (length m).
    pub r_squared_t: Vec<f64>,
    /// Mean of the pointwise R^2 values.
    pub r_squared: f64,
    /// Pointwise standard errors matching `beta`, p x m.
    pub beta_se: FdMatrix,
    /// Smoothing parameter actually used (GCV-selected when the input was negative).
    pub lambda: f64,
    /// Generalized cross-validation score at `lambda`.
    pub gcv: f64,
}
110
/// Result of a function-on-scalar regression fit through an FPC basis ([`fosr_fpc`]).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FosrFpcResult {
    /// Estimated intercept curve (length m).
    pub intercept: Vec<f64>,
    /// Coefficient curves reconstructed from the FPC basis, p x m.
    pub beta: FdMatrix,
    /// Fitted curves, n x m.
    pub fitted: FdMatrix,
    /// Residual curves (data minus fitted), n x m.
    pub residuals: FdMatrix,
    /// Pointwise R^2 at each grid point (length m).
    pub r_squared_t: Vec<f64>,
    /// Mean of the pointwise R^2 values.
    pub r_squared: f64,
    /// Per-predictor coefficient loadings on each retained component
    /// (outer length p, inner length `ncomp`).
    pub beta_scores: Vec<Vec<f64>>,
    /// Number of FPC components actually retained (may be fewer than requested).
    pub ncomp: usize,
}
132
/// Result of a permutation-based one-way functional ANOVA ([`fanova`]).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FanovaResult {
    /// Mean curve of each group, k x m (rows ordered like `group_labels`).
    pub group_means: FdMatrix,
    /// Grand mean curve over all observations (length m).
    pub overall_mean: Vec<f64>,
    /// Pointwise F statistic at each grid point (length m).
    pub f_statistic_t: Vec<f64>,
    /// Global test statistic: mean of the pointwise F values.
    pub global_statistic: f64,
    /// Permutation p-value (computed with the standard +1 correction).
    pub p_value: f64,
    /// Number of permutations actually performed (at least 1).
    pub n_perm: usize,
    /// Number of distinct groups.
    pub n_groups: usize,
    /// Sorted, deduplicated group labels.
    pub group_labels: Vec<usize>,
}
154
/// Second-difference roughness penalty `P = D'D` (m x m, row-major), where
/// `D` is the (m-2) x m second-difference operator.
///
/// Returns the zero matrix when `m < 3`, since no second difference exists.
pub(crate) fn penalty_matrix(m: usize) -> Vec<f64> {
    if m < 3 {
        return vec![0.0; m * m];
    }
    let mut dtd = vec![0.0; m * m];
    // Each row of D contributes the outer product of the [1, -2, 1] stencil.
    let stencil = [1.0, -2.0, 1.0];
    for row in 0..(m - 2) {
        for (dr, &wr) in stencil.iter().enumerate() {
            for (dc, &wc) in stencil.iter().enumerate() {
                dtd[(row + dr) * m + (row + dc)] += wr * wc;
            }
        }
    }
    dtd
}
178
179fn penalized_solve(
183 xtx: &[f64],
184 xty: &FdMatrix,
185 penalty: &[f64],
186 lambda: f64,
187) -> Result<FdMatrix, FdarError> {
188 let p = xty.nrows();
189 let m = xty.ncols();
190
191 let mut a = vec![0.0; p * p];
193 for i in 0..p * p {
194 a[i] = xtx[i] + lambda * penalty[i];
195 }
196
197 let l = cholesky_factor(&a, p).ok_or_else(|| FdarError::ComputationFailed {
199 operation: "penalized_solve",
200 detail: format!(
201 "Cholesky factorization of (X'X + {lambda:.4}*P) failed; matrix is singular — try increasing lambda or removing collinear basis columns"
202 ),
203 })?;
204
205 let mut beta = FdMatrix::zeros(p, m);
207 for t in 0..m {
208 let b: Vec<f64> = (0..p).map(|j| xty[(j, t)]).collect();
209 let x = cholesky_forward_back(&l, &b, p);
210 for j in 0..p {
211 beta[(j, t)] = x[j];
212 }
213 }
214 Ok(beta)
215}
216
217pub(crate) fn pointwise_r_squared(data: &FdMatrix, fitted: &FdMatrix) -> Vec<f64> {
219 let (n, m) = data.shape();
220 (0..m)
221 .map(|t| {
222 let mean_t: f64 = (0..n).map(|i| data[(i, t)]).sum::<f64>() / n as f64;
223 let ss_tot: f64 = (0..n).map(|i| (data[(i, t)] - mean_t).powi(2)).sum();
224 let ss_res: f64 = (0..n)
225 .map(|i| (data[(i, t)] - fitted[(i, t)]).powi(2))
226 .sum();
227 if ss_tot > 1e-15 {
228 1.0 - ss_res / ss_tot
229 } else {
230 0.0
231 }
232 })
233 .collect()
234}
235
236fn compute_fosr_gcv(residuals: &FdMatrix, trace_h: f64) -> f64 {
238 let (n, m) = residuals.shape();
239 let denom = (1.0 - trace_h / n as f64).max(1e-10);
240 let ss_res: f64 = (0..n)
241 .flat_map(|i| (0..m).map(move |t| residuals[(i, t)].powi(2)))
242 .sum();
243 ss_res / (n as f64 * m as f64 * denom * denom)
244}
245
246pub(crate) fn build_fosr_design(predictors: &FdMatrix, n: usize) -> FdMatrix {
271 let p = predictors.ncols();
272 let p_total = p + 1;
273 let mut design = FdMatrix::zeros(n, p_total);
274 for i in 0..n {
275 design[(i, 0)] = 1.0;
276 for j in 0..p {
277 design[(i, 1 + j)] = predictors[(i, j)];
278 }
279 }
280 design
281}
282
283pub(crate) fn compute_xty_matrix(design: &FdMatrix, data: &FdMatrix) -> FdMatrix {
285 let (n, m) = data.shape();
286 let p_total = design.ncols();
287 let mut xty = FdMatrix::zeros(p_total, m);
288 for j in 0..p_total {
289 for t in 0..m {
290 let mut s = 0.0;
291 for i in 0..n {
292 s += design[(i, j)] * data[(i, t)];
293 }
294 xty[(j, t)] = s;
295 }
296 }
297 xty
298}
299
300fn drop_intercept_rows(full: &FdMatrix, p: usize, m: usize) -> FdMatrix {
302 let mut out = FdMatrix::zeros(p, m);
303 for j in 0..p {
304 for t in 0..m {
305 out[(j, t)] = full[(j + 1, t)];
306 }
307 }
308 out
309}
310
/// Penalized function-on-scalar regression.
///
/// Fits `y_i(t) = beta_0(t) + sum_j z_ij * beta_j(t) + eps_i(t)` by solving
/// the ridge-type system `(X'X + lambda * P) B = X'Y`, where `X` is
/// `predictors` augmented with an intercept column and `P` is a
/// second-difference penalty over the coefficient vector.
///
/// * `data` — n x m functional responses on a shared grid.
/// * `predictors` — n x p scalar covariates (the intercept is added internally).
/// * `lambda` — smoothing parameter; pass a negative value to select it by
///   GCV over a fixed candidate grid.
///
/// # Errors
///
/// Returns [`FdarError::InvalidDimension`] when the grid is empty, the
/// predictor rows do not match `data`, or `n < p + 2`; propagates
/// [`FdarError::ComputationFailed`] when the penalized normal equations are
/// singular.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn fosr(data: &FdMatrix, predictors: &FdMatrix, lambda: f64) -> Result<FosrResult, FdarError> {
    let (n, m) = data.shape();
    let p = predictors.ncols();
    if m == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "at least 1 column (grid points)".to_string(),
            actual: "0 columns".to_string(),
        });
    }
    if predictors.nrows() != n {
        return Err(FdarError::InvalidDimension {
            parameter: "predictors",
            expected: format!("{n} rows (matching data)"),
            actual: format!("{} rows", predictors.nrows()),
        });
    }
    if n < p + 2 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: format!("at least {} observations (p + 2)", p + 2),
            actual: format!("{n} observations"),
        });
    }

    let design = build_fosr_design(predictors, n);
    let p_total = design.ncols();
    let xtx = compute_xtx(&design);
    let xty = compute_xty_matrix(&design, data);
    let penalty = penalty_matrix(p_total);

    // A negative lambda is the sentinel requesting automatic GCV selection.
    let lambda = if lambda < 0.0 {
        select_lambda_gcv(&xtx, &xty, &penalty, data, &design)
    } else {
        lambda
    };

    let beta = penalized_solve(&xtx, &xty, &penalty, lambda)?;
    let (fitted, residuals) = compute_fosr_fitted(&design, &beta, data);

    let r_squared_t = pointwise_r_squared(data, &fitted);
    let r_squared = r_squared_t.iter().sum::<f64>() / m as f64;
    let beta_se = compute_beta_se(&xtx, &penalty, lambda, &residuals, p_total, n);
    let trace_h = compute_trace_hat(&xtx, &penalty, lambda, p_total, n);
    let gcv = compute_fosr_gcv(&residuals, trace_h);

    // Row 0 of beta is the intercept curve; it is exposed separately and
    // dropped from the public `beta` / `beta_se` matrices.
    let intercept: Vec<f64> = (0..m).map(|t| beta[(0, t)]).collect();

    Ok(FosrResult {
        intercept,
        beta: drop_intercept_rows(&beta, p, m),
        fitted,
        residuals,
        r_squared_t,
        r_squared,
        beta_se: drop_intercept_rows(&beta_se, p, m),
        lambda,
        gcv,
    })
}
401
/// Computes fitted values `X * B` and residuals `Y - X * B`.
///
/// Rows are processed independently — in parallel when the `parallel`
/// feature is enabled, via `iter_maybe_parallel!` — and then copied back
/// into dense output matrices.
fn compute_fosr_fitted(
    design: &FdMatrix,
    beta: &FdMatrix,
    data: &FdMatrix,
) -> (FdMatrix, FdMatrix) {
    let (n, m) = data.shape();
    let p_total = design.ncols();
    // Each observation yields its (fitted, residual) row pair.
    let rows: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
        .map(|i| {
            let mut fitted_row = vec![0.0; m];
            let mut resid_row = vec![0.0; m];
            for t in 0..m {
                let mut yhat = 0.0;
                for j in 0..p_total {
                    yhat += design[(i, j)] * beta[(j, t)];
                }
                fitted_row[t] = yhat;
                resid_row[t] = data[(i, t)] - yhat;
            }
            (fitted_row, resid_row)
        })
        .collect();
    // Scatter the per-row results into the output matrices.
    let mut fitted = FdMatrix::zeros(n, m);
    let mut residuals = FdMatrix::zeros(n, m);
    for (i, (fr, rr)) in rows.into_iter().enumerate() {
        for t in 0..m {
            fitted[(i, t)] = fr[t];
            residuals[(i, t)] = rr[t];
        }
    }
    (fitted, residuals)
}
435
436fn select_lambda_gcv(
438 xtx: &[f64],
439 xty: &FdMatrix,
440 penalty: &[f64],
441 data: &FdMatrix,
442 design: &FdMatrix,
443) -> f64 {
444 let lambdas = [0.0, 1e-6, 1e-4, 1e-2, 0.1, 1.0, 10.0, 100.0, 1000.0];
445 let p_total = design.ncols();
446 let n = design.nrows();
447
448 let mut best_lambda = 0.0;
449 let mut best_gcv = f64::INFINITY;
450
451 for &lam in &lambdas {
452 let Ok(beta) = penalized_solve(xtx, xty, penalty, lam) else {
453 continue;
454 };
455 let (_, residuals) = compute_fosr_fitted(design, &beta, data);
456 let trace_h = compute_trace_hat(xtx, penalty, lam, p_total, n);
457 let gcv = compute_fosr_gcv(&residuals, trace_h);
458 if gcv < best_gcv {
459 best_gcv = gcv;
460 best_lambda = lam;
461 }
462 }
463 best_lambda
464}
465
466fn compute_trace_hat(xtx: &[f64], penalty: &[f64], lambda: f64, p: usize, n: usize) -> f64 {
468 let mut a = vec![0.0; p * p];
469 for i in 0..p * p {
470 a[i] = xtx[i] + lambda * penalty[i];
471 }
472 let Some(l) = cholesky_factor(&a, p) else {
475 return p as f64; };
477
478 let mut trace = 0.0;
480 for j in 0..p {
481 let col: Vec<f64> = (0..p).map(|i| xtx[i * p + j]).collect();
482 let z = cholesky_forward_back(&l, &col, p);
483 trace += z[j]; }
485 trace.min(n as f64)
486}
487
488fn compute_beta_se(
490 xtx: &[f64],
491 penalty: &[f64],
492 lambda: f64,
493 residuals: &FdMatrix,
494 p: usize,
495 n: usize,
496) -> FdMatrix {
497 let m = residuals.ncols();
498 let mut a = vec![0.0; p * p];
499 for i in 0..p * p {
500 a[i] = xtx[i] + lambda * penalty[i];
501 }
502 let Some(l) = cholesky_factor(&a, p) else {
503 return FdMatrix::zeros(p, m);
504 };
505
506 let a_inv_diag: Vec<f64> = (0..p)
508 .map(|j| {
509 let mut ej = vec![0.0; p];
510 ej[j] = 1.0;
511 let v = cholesky_forward_back(&l, &ej, p);
512 v[j]
513 })
514 .collect();
515
516 let df = (n - p).max(1) as f64;
517 let mut se = FdMatrix::zeros(p, m);
518 for t in 0..m {
519 let sigma2_t: f64 = (0..n).map(|i| residuals[(i, t)].powi(2)).sum::<f64>() / df;
520 for j in 0..p {
521 se[(j, t)] = (sigma2_t * a_inv_diag[j]).max(0.0).sqrt();
522 }
523 }
524 se
525}
526
527fn regress_scores_on_design(
535 design: &FdMatrix,
536 scores: &FdMatrix,
537 n: usize,
538 k: usize,
539 p_total: usize,
540) -> Result<Vec<Vec<f64>>, FdarError> {
541 let xtx = compute_xtx(design);
542 let l = cholesky_factor(&xtx, p_total).ok_or_else(|| FdarError::ComputationFailed {
543 operation: "regress_scores_on_design",
544 detail: "Cholesky factorization of X'X failed; design matrix is rank-deficient — remove constant or collinear predictors, or add regularization".to_string(),
545 })?;
546
547 let gamma_all: Vec<Vec<f64>> = (0..k)
548 .map(|comp| {
549 let mut xts = vec![0.0; p_total];
550 for j in 0..p_total {
551 for i in 0..n {
552 xts[j] += design[(i, j)] * scores[(i, comp)];
553 }
554 }
555 cholesky_forward_back(&l, &xts, p_total)
556 })
557 .collect();
558 Ok(gamma_all)
559}
560
561fn reconstruct_beta_fpc(
563 gamma_all: &[Vec<f64>],
564 rotation: &FdMatrix,
565 p: usize,
566 k: usize,
567 m: usize,
568) -> FdMatrix {
569 let mut beta = FdMatrix::zeros(p, m);
570 for j in 0..p {
571 for t in 0..m {
572 let mut val = 0.0;
573 for comp in 0..k {
574 val += gamma_all[comp][1 + j] * rotation[(t, comp)];
575 }
576 beta[(j, t)] = val;
577 }
578 }
579 beta
580}
581
582fn compute_intercept_fpc(
584 mean: &[f64],
585 gamma_all: &[Vec<f64>],
586 rotation: &FdMatrix,
587 k: usize,
588 m: usize,
589) -> Vec<f64> {
590 let mut intercept = mean.to_vec();
591 for t in 0..m {
592 for comp in 0..k {
593 intercept[t] += gamma_all[comp][0] * rotation[(t, comp)];
594 }
595 }
596 intercept
597}
598
/// Per-predictor component loadings, rescaled by `sqrt(h)` where `h` is the
/// grid spacing implied by `m` points (1 when the grid has a single point).
///
/// Output: outer length p, inner length k; entry [j][comp] comes from
/// `gamma_all[comp][1 + j]` (entry 0 is the intercept and is skipped).
fn extract_beta_scores(gamma_all: &[Vec<f64>], p: usize, k: usize, m: usize) -> Vec<Vec<f64>> {
    let step = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
    let scale = step.sqrt();
    let mut out = Vec::with_capacity(p);
    for j in 0..p {
        let row: Vec<f64> = (0..k).map(|comp| gamma_all[comp][1 + j] * scale).collect();
        out.push(row);
    }
    out
}
611
/// Function-on-scalar regression through a functional principal component basis.
///
/// Projects `data` onto its leading principal components (via
/// [`fdata_to_pc_1d`]), regresses the component scores on the scalar
/// predictors, and maps the estimated coefficients back to curves through
/// the FPC rotation. The returned `ncomp` is the number of components the
/// FPCA actually retained, which may differ from the number requested.
///
/// # Errors
///
/// Returns [`FdarError::InvalidDimension`] for an empty grid, mismatched
/// predictor rows, or `n < p + 2`, [`FdarError::InvalidParameter`] when
/// `ncomp == 0`, and propagates failures from the FPCA and score-regression
/// steps.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn fosr_fpc(
    data: &FdMatrix,
    predictors: &FdMatrix,
    ncomp: usize,
) -> Result<FosrFpcResult, FdarError> {
    let (n, m) = data.shape();
    let p = predictors.ncols();
    if m == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "at least 1 column (grid points)".to_string(),
            actual: "0 columns".to_string(),
        });
    }
    if predictors.nrows() != n {
        return Err(FdarError::InvalidDimension {
            parameter: "predictors",
            expected: format!("{n} rows (matching data)"),
            actual: format!("{} rows", predictors.nrows()),
        });
    }
    if n < p + 2 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: format!("at least {} observations (p + 2)", p + 2),
            actual: format!("{n} observations"),
        });
    }
    if ncomp == 0 {
        return Err(FdarError::InvalidParameter {
            parameter: "ncomp",
            message: "number of FPC components must be at least 1".to_string(),
        });
    }

    // Uniform grid on [0, 1]; `(m - 1).max(1)` guards the single-point case.
    let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1).max(1) as f64).collect();
    let fpca = fdata_to_pc_1d(data, ncomp, &argvals)?;
    // k is the number of components the FPCA actually returned.
    let k = fpca.scores.ncols();
    let p_total = p + 1;
    let design = build_fosr_design(predictors, n);

    let gamma_all = regress_scores_on_design(&design, &fpca.scores, n, k, p_total)?;
    let beta = reconstruct_beta_fpc(&gamma_all, &fpca.rotation, p, k, m);
    let intercept = compute_intercept_fpc(&fpca.mean, &gamma_all, &fpca.rotation, k, m);

    let (fitted, residuals) = compute_fosr_fpc_fitted(data, &intercept, &beta, predictors);
    let r_squared_t = pointwise_r_squared(data, &fitted);
    let r_squared = r_squared_t.iter().sum::<f64>() / m as f64;
    let beta_scores = extract_beta_scores(&gamma_all, p, k, m);

    Ok(FosrFpcResult {
        intercept,
        beta,
        fitted,
        residuals,
        r_squared_t,
        r_squared,
        beta_scores,
        ncomp: k,
    })
}
692
693fn compute_fosr_fpc_fitted(
695 data: &FdMatrix,
696 intercept: &[f64],
697 beta: &FdMatrix,
698 predictors: &FdMatrix,
699) -> (FdMatrix, FdMatrix) {
700 let (n, m) = data.shape();
701 let p = predictors.ncols();
702 let mut fitted = FdMatrix::zeros(n, m);
703 let mut residuals = FdMatrix::zeros(n, m);
704 for i in 0..n {
705 for t in 0..m {
706 let mut yhat = intercept[t];
707 for j in 0..p {
708 yhat += predictors[(i, j)] * beta[(j, t)];
709 }
710 fitted[(i, t)] = yhat;
711 residuals[(i, t)] = data[(i, t)] - yhat;
712 }
713 }
714 (fitted, residuals)
715}
716
717#[must_use = "prediction result should not be discarded"]
723pub fn predict_fosr(result: &FosrResult, new_predictors: &FdMatrix) -> FdMatrix {
724 let n_new = new_predictors.nrows();
725 let m = result.intercept.len();
726 let p = result.beta.nrows();
727
728 let mut predicted = FdMatrix::zeros(n_new, m);
729 for i in 0..n_new {
730 for t in 0..m {
731 let mut yhat = result.intercept[t];
732 for j in 0..p {
733 yhat += new_predictors[(i, j)] * result.beta[(j, t)];
734 }
735 predicted[(i, t)] = yhat;
736 }
737 }
738 predicted
739}
740
741fn compute_group_means(
747 data: &FdMatrix,
748 groups: &[usize],
749 labels: &[usize],
750) -> (FdMatrix, Vec<f64>) {
751 let (n, m) = data.shape();
752 let k = labels.len();
753 let mut group_means = FdMatrix::zeros(k, m);
754 let mut counts = vec![0usize; k];
755
756 for i in 0..n {
757 let g = labels.iter().position(|&l| l == groups[i]).unwrap_or(0);
758 counts[g] += 1;
759 for t in 0..m {
760 group_means[(g, t)] += data[(i, t)];
761 }
762 }
763 for g in 0..k {
764 if counts[g] > 0 {
765 for t in 0..m {
766 group_means[(g, t)] /= counts[g] as f64;
767 }
768 }
769 }
770
771 let overall_mean: Vec<f64> = (0..m)
772 .map(|t| (0..n).map(|i| data[(i, t)]).sum::<f64>() / n as f64)
773 .collect();
774
775 (group_means, overall_mean)
776}
777
778fn pointwise_f_statistic(
780 data: &FdMatrix,
781 groups: &[usize],
782 labels: &[usize],
783 group_means: &FdMatrix,
784 overall_mean: &[f64],
785) -> Vec<f64> {
786 let (n, m) = data.shape();
787 let k = labels.len();
788 let mut counts = vec![0usize; k];
789 for &g in groups {
790 let idx = labels.iter().position(|&l| l == g).unwrap_or(0);
791 counts[idx] += 1;
792 }
793
794 (0..m)
795 .map(|t| {
796 let ss_between: f64 = (0..k)
797 .map(|g| counts[g] as f64 * (group_means[(g, t)] - overall_mean[t]).powi(2))
798 .sum();
799 let ss_within: f64 = (0..n)
800 .map(|i| {
801 let g = labels.iter().position(|&l| l == groups[i]).unwrap_or(0);
802 (data[(i, t)] - group_means[(g, t)]).powi(2)
803 })
804 .sum();
805 let ms_between = ss_between / (k as f64 - 1.0).max(1.0);
806 let ms_within = ss_within / (n as f64 - k as f64).max(1.0);
807 if ms_within > 1e-15 {
808 ms_between / ms_within
809 } else {
810 0.0
811 }
812 })
813 .collect()
814}
815
/// Global test statistic: the mean of the pointwise F values.
///
/// Returns 0.0 for an empty slice instead of propagating the NaN that
/// `0.0 / 0.0` would produce (callers reject m == 0, but the permutation
/// comparison `perm_stat >= observed_stat` would silently misbehave on NaN).
fn global_f_statistic(f_t: &[f64]) -> f64 {
    if f_t.is_empty() {
        return 0.0;
    }
    f_t.iter().sum::<f64>() / f_t.len() as f64
}
820
/// One-way functional ANOVA with a permutation test.
///
/// Computes the pointwise F statistic across groups at every grid point,
/// summarizes it by its mean over the grid, and obtains a p-value by
/// comparing against `n_perm` random permutations of the group labels.
/// The shuffle is driven by a fixed-seed LCG, so results are deterministic
/// across runs.
///
/// # Errors
///
/// Returns [`FdarError::InvalidDimension`] for an empty grid, a `groups`
/// length mismatch, or fewer than 3 observations, and
/// [`FdarError::InvalidParameter`] when fewer than 2 distinct groups are
/// present.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn fanova(data: &FdMatrix, groups: &[usize], n_perm: usize) -> Result<FanovaResult, FdarError> {
    let (n, m) = data.shape();
    if m == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "at least 1 column (grid points)".to_string(),
            actual: "0 columns".to_string(),
        });
    }
    if groups.len() != n {
        return Err(FdarError::InvalidDimension {
            parameter: "groups",
            expected: format!("{n} elements (matching data rows)"),
            actual: format!("{} elements", groups.len()),
        });
    }
    if n < 3 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "at least 3 observations".to_string(),
            actual: format!("{n} observations"),
        });
    }

    // Sorted, deduplicated label set defines the group ordering.
    let mut labels: Vec<usize> = groups.to_vec();
    labels.sort_unstable();
    labels.dedup();
    let n_groups = labels.len();
    if n_groups < 2 {
        return Err(FdarError::InvalidParameter {
            parameter: "groups",
            message: format!("at least 2 distinct groups required, but only {n_groups} found"),
        });
    }

    let (group_means, overall_mean) = compute_group_means(data, groups, &labels);
    let f_t = pointwise_f_statistic(data, groups, &labels, &group_means, &overall_mean);
    let observed_stat = global_f_statistic(&f_t);

    // At least one permutation keeps the p-value denominator positive.
    let n_perm = n_perm.max(1);
    let mut n_ge = 0usize;
    let mut perm_groups = groups.to_vec();

    // Fixed-seed 64-bit LCG driving a Fisher-Yates shuffle: reproducible.
    let mut rng_state: u64 = 42;
    for _ in 0..n_perm {
        for i in (1..n).rev() {
            rng_state = rng_state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1);
            // Upper bits of the LCG state are better distributed.
            let j = (rng_state >> 33) as usize % (i + 1);
            perm_groups.swap(i, j);
        }

        let (perm_means, perm_overall) = compute_group_means(data, &perm_groups, &labels);
        let perm_f = pointwise_f_statistic(data, &perm_groups, &labels, &perm_means, &perm_overall);
        let perm_stat = global_f_statistic(&perm_f);
        if perm_stat >= observed_stat {
            n_ge += 1;
        }
    }

    // Add-one correction keeps the p-value strictly positive.
    let p_value = (n_ge as f64 + 1.0) / (n_perm as f64 + 1.0);

    Ok(FanovaResult {
        group_means,
        overall_mean,
        f_statistic_t: f_t,
        global_statistic: observed_stat,
        p_value,
        n_perm,
        n_groups,
        group_labels: labels,
    })
}
917
impl FosrResult {
    /// Predicts response curves for `new_predictors`; convenience wrapper
    /// around [`predict_fosr`].
    pub fn predict(&self, new_predictors: &FdMatrix) -> FdMatrix {
        predict_fosr(self, new_predictors)
    }
}
924
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;
    use std::f64::consts::PI;

    // Deterministic synthetic FOSR dataset: two scalar predictors
    // (age, group) with known coefficient curves plus a small deterministic
    // "noise" term, so tests are reproducible without an RNG.
    fn generate_fosr_data(n: usize, m: usize) -> (FdMatrix, FdMatrix) {
        let t = uniform_grid(m);
        let mut y = FdMatrix::zeros(n, m);
        let mut z = FdMatrix::zeros(n, 2);

        for i in 0..n {
            let age = (i as f64) / (n as f64);
            let group = if i % 2 == 0 { 1.0 } else { 0.0 };
            z[(i, 0)] = age;
            z[(i, 1)] = group;
            for j in 0..m {
                let mu = (2.0 * PI * t[j]).sin();
                let beta1 = t[j];
                let beta2 = (4.0 * PI * t[j]).cos();
                y[(i, j)] = mu
                    + age * beta1
                    + group * beta2
                    + 0.05 * ((i * 13 + j * 7) % 100) as f64 / 100.0;
            }
        }
        (y, z)
    }

    // Unpenalized fit succeeds and all result components have the right shapes.
    #[test]
    fn test_fosr_basic() {
        let (y, z) = generate_fosr_data(30, 50);
        let result = fosr(&y, &z, 0.0);
        assert!(result.is_ok());
        let fit = result.unwrap();
        assert_eq!(fit.intercept.len(), 50);
        assert_eq!(fit.beta.shape(), (2, 50));
        assert_eq!(fit.fitted.shape(), (30, 50));
        assert_eq!(fit.residuals.shape(), (30, 50));
        assert!(fit.r_squared >= 0.0);
    }

    // A positive smoothing parameter still yields correctly shaped coefficients.
    #[test]
    fn test_fosr_with_penalty() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit0 = fosr(&y, &z, 0.0).unwrap();
        let fit1 = fosr(&y, &z, 1.0).unwrap();
        assert_eq!(fit0.beta.shape(), (2, 50));
        assert_eq!(fit1.beta.shape(), (2, 50));
    }

    // A negative lambda triggers GCV selection, which reports a non-negative value.
    #[test]
    fn test_fosr_auto_lambda() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, -1.0).unwrap();
        assert!(fit.lambda >= 0.0);
    }

    // Decomposition identity: fitted + residuals reproduces the data.
    #[test]
    fn test_fosr_fitted_plus_residuals_equals_y() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        for i in 0..30 {
            for t in 0..50 {
                let reconstructed = fit.fitted[(i, t)] + fit.residuals[(i, t)];
                assert!(
                    (reconstructed - y[(i, t)]).abs() < 1e-10,
                    "ŷ + r should equal y at ({}, {})",
                    i,
                    t
                );
            }
        }
    }

    // Pointwise R^2 stays in a sane range (small negative slack for round-off).
    #[test]
    fn test_fosr_pointwise_r_squared_valid() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        for &r2 in &fit.r_squared_t {
            assert!(
                (-0.01..=1.0 + 1e-10).contains(&r2),
                "R²(t) out of range: {}",
                r2
            );
        }
    }

    // Standard errors are non-negative and finite everywhere.
    #[test]
    fn test_fosr_se_positive() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        for j in 0..2 {
            for t in 0..50 {
                assert!(
                    fit.beta_se[(j, t)] >= 0.0 && fit.beta_se[(j, t)].is_finite(),
                    "SE should be non-negative finite"
                );
            }
        }
    }

    // Too few observations (n < p + 2) must be rejected.
    #[test]
    fn test_fosr_invalid_input() {
        let y = FdMatrix::zeros(2, 50);
        let z = FdMatrix::zeros(2, 1);
        assert!(fosr(&y, &z, 0.0).is_err());
    }

    // Predicting on the training predictors must reproduce the fitted values.
    #[test]
    fn test_predict_fosr_on_training_data() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        let preds = predict_fosr(&fit, &z);
        assert_eq!(preds.shape(), (30, 50));
        for i in 0..30 {
            for t in 0..50 {
                assert!(
                    (preds[(i, t)] - fit.fitted[(i, t)]).abs() < 1e-8,
                    "Prediction on training data should match fitted"
                );
            }
        }
    }

    // Two groups with a real mean difference: the test should reject.
    #[test]
    fn test_fanova_two_groups() {
        let n = 40;
        let m = 50;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        let mut groups = vec![0usize; n];
        for i in 0..n {
            groups[i] = if i < n / 2 { 0 } else { 1 };
            for j in 0..m {
                let base = (2.0 * PI * t[j]).sin();
                let effect = if groups[i] == 1 { 0.5 * t[j] } else { 0.0 };
                data[(i, j)] = base + effect + 0.01 * (i as f64 * 0.1).sin();
            }
        }

        let result = fanova(&data, &groups, 200);
        assert!(result.is_ok());
        let res = result.unwrap();
        assert_eq!(res.n_groups, 2);
        assert_eq!(res.group_means.shape(), (2, m));
        assert_eq!(res.f_statistic_t.len(), m);
        assert!(res.p_value >= 0.0 && res.p_value <= 1.0);
        assert!(
            res.p_value < 0.1,
            "Should detect group effect, got p={}",
            res.p_value
        );
    }

    // Identical group distributions: the test should not reject.
    #[test]
    fn test_fanova_no_effect() {
        let n = 40;
        let m = 50;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        let mut groups = vec![0usize; n];
        for i in 0..n {
            groups[i] = if i < n / 2 { 0 } else { 1 };
            for j in 0..m {
                data[(i, j)] =
                    (2.0 * PI * t[j]).sin() + 0.1 * ((i * 7 + j * 3) % 100) as f64 / 100.0;
            }
        }

        let result = fanova(&data, &groups, 200);
        assert!(result.is_ok());
        let res = result.unwrap();
        assert!(
            res.p_value > 0.05,
            "Should not detect effect, got p={}",
            res.p_value
        );
    }

    // Three distinct groups are counted correctly.
    #[test]
    fn test_fanova_three_groups() {
        let n = 30;
        let m = 50;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        let mut groups = vec![0usize; n];
        for i in 0..n {
            groups[i] = i % 3;
            for j in 0..m {
                let effect = match groups[i] {
                    0 => 0.0,
                    1 => 0.5 * t[j],
                    _ => -0.3 * (2.0 * PI * t[j]).cos(),
                };
                data[(i, j)] = (2.0 * PI * t[j]).sin() + effect + 0.01 * (i as f64 * 0.1).sin();
            }
        }

        let result = fanova(&data, &groups, 200);
        assert!(result.is_ok());
        let res = result.unwrap();
        assert_eq!(res.n_groups, 3);
    }

    // A single distinct group and a length mismatch are both rejected.
    #[test]
    fn test_fanova_invalid_input() {
        let data = FdMatrix::zeros(10, 50);
        let groups = vec![0; 10];
        assert!(fanova(&data, &groups, 100).is_err());

        let groups = vec![0; 5];
        assert!(fanova(&data, &groups, 100).is_err());
    }
}
1157}