1use crate::error::FdarError;
15use crate::iter_maybe_parallel;
16use crate::matrix::FdMatrix;
17use crate::regression::fdata_to_pc_1d;
18#[cfg(feature = "parallel")]
19use rayon::iter::ParallelIterator;
20
/// Computes the lower-triangular Cholesky factor `L` of a symmetric matrix.
///
/// `a` is read as a row-major `p × p` matrix, and only its lower triangle
/// (including the diagonal) is accessed. The returned vector holds `L` in the
/// same row-major layout, with zeros above the diagonal, such that `L Lᵀ = a`.
///
/// Returns `None` as soon as a pivot (the value whose square root becomes a
/// diagonal entry of `L`) is `<= 1e-12`, i.e. when the matrix is not
/// numerically positive definite.
fn cholesky_factor(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut lower = vec![0.0; p * p];
    for col in 0..p {
        // Pivot: a[col][col] minus the squares already placed in row `col` of L.
        let pivot = (0..col).fold(a[col * p + col], |acc, k| {
            acc - lower[col * p + k] * lower[col * p + k]
        });
        if pivot <= 1e-12 {
            return None;
        }
        let root = pivot.sqrt();
        lower[col * p + col] = root;

        // Fill the rest of column `col` below the diagonal.
        for row in col + 1..p {
            let dot = (0..col).fold(a[row * p + col], |acc, k| {
                acc - lower[row * p + k] * lower[col * p + k]
            });
            lower[row * p + col] = dot / root;
        }
    }
    Some(lower)
}
47
/// Solves `L Lᵀ x = b`, given the row-major lower-triangular factor `l`
/// produced by [`cholesky_factor`].
///
/// Runs a forward substitution (`L z = b`) and then a backward substitution
/// (`Lᵀ x = z`) in one working buffer. Entries of `b` beyond index `p` are
/// copied through unchanged, so the result has the same length as `b`.
fn cholesky_forward_back(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut x = b.to_vec();

    // Forward pass: L z = b.
    for row in 0..p {
        let mut acc = x[row];
        for k in 0..row {
            acc -= l[row * p + k] * x[k];
        }
        x[row] = acc / l[row * p + row];
    }

    // Backward pass: Lᵀ x = z (column `row` of L is row `row` of Lᵀ).
    for row in (0..p).rev() {
        let mut acc = x[row];
        for k in row + 1..p {
            acc -= l[k * p + row] * x[k];
        }
        x[row] = acc / l[row * p + row];
    }
    x
}
65
66pub(crate) fn compute_xtx(x: &FdMatrix) -> Vec<f64> {
68 let (n, p) = x.shape();
69 let mut xtx = vec![0.0; p * p];
70 for k in 0..p {
71 for j in k..p {
72 let mut s = 0.0;
73 for i in 0..n {
74 s += x[(i, k)] * x[(i, j)];
75 }
76 xtx[k * p + j] = s;
77 xtx[j * p + k] = s;
78 }
79 }
80 xtx
81}
82
/// Result of [`fosr`]: a penalized function-on-scalar regression fit.
///
/// Functional quantities are evaluated on the `m` grid points of the response
/// data. `n` is the number of observations and `p` the number of scalar
/// predictors; the intercept is reported separately from `beta`.
#[derive(Debug, Clone, PartialEq)]
pub struct FosrResult {
    /// Intercept function `β₀(t)`, length `m`.
    pub intercept: Vec<f64>,
    /// Coefficient functions `β_j(t)`, `p × m` (row `j` = predictor column `j`).
    pub beta: FdMatrix,
    /// Fitted curves, `n × m`.
    pub fitted: FdMatrix,
    /// Residual curves `data - fitted`, `n × m`.
    pub residuals: FdMatrix,
    /// Pointwise coefficient of determination `R²(t)`, length `m`.
    pub r_squared_t: Vec<f64>,
    /// Average of `r_squared_t` over the grid.
    pub r_squared: f64,
    /// Pointwise standard errors of `beta`, `p × m`.
    pub beta_se: FdMatrix,
    /// Penalty weight actually used: the supplied value, or the GCV-selected one
    /// when a negative value was passed.
    pub lambda: f64,
    /// GCV score of the final fit.
    pub gcv: f64,
}
109
/// Result of [`fosr_fpc`]: function-on-scalar regression through FPC scores.
///
/// Functional quantities are evaluated on the `m` grid points of the response
/// data. `n` is the number of observations and `p` the number of scalar
/// predictors.
#[derive(Debug, Clone, PartialEq)]
pub struct FosrFpcResult {
    /// Intercept function: the FPCA mean plus the intercept coefficients mapped
    /// through the eigenfunctions, length `m`.
    pub intercept: Vec<f64>,
    /// Coefficient functions `β_j(t)`, `p × m`.
    pub beta: FdMatrix,
    /// Fitted curves, `n × m`.
    pub fitted: FdMatrix,
    /// Residual curves `data - fitted`, `n × m`.
    pub residuals: FdMatrix,
    /// Pointwise coefficient of determination `R²(t)`, length `m`.
    pub r_squared_t: Vec<f64>,
    /// Average of `r_squared_t` over the grid.
    pub r_squared: f64,
    /// Per-predictor coefficients in FPC-score space (`p` vectors of `ncomp`
    /// values), scaled by `sqrt(1 / (m - 1))`.
    pub beta_scores: Vec<Vec<f64>>,
    /// Number of components actually used (columns of the FPCA scores). It may
    /// differ from the requested count depending on `fdata_to_pc_1d` — TODO confirm.
    pub ncomp: usize,
}
130
/// Result of [`fanova`]: one-way functional ANOVA with a permutation test.
#[derive(Debug, Clone, PartialEq)]
pub struct FanovaResult {
    /// Group mean curves, `n_groups × m`; row `g` belongs to `group_labels[g]`.
    pub group_means: FdMatrix,
    /// Mean curve over all observations, length `m`.
    pub overall_mean: Vec<f64>,
    /// Pointwise F statistic `F(t)`, length `m`.
    pub f_statistic_t: Vec<f64>,
    /// Global test statistic: the grid average of `f_statistic_t`.
    pub global_statistic: f64,
    /// Permutation p-value `(#{permuted >= observed} + 1) / (n_perm + 1)`.
    pub p_value: f64,
    /// Number of permutations performed (at least 1).
    pub n_perm: usize,
    /// Number of distinct group labels.
    pub n_groups: usize,
    /// Distinct group labels in ascending order.
    pub group_labels: Vec<usize>,
}
151
/// Builds the second-order difference penalty `DᵀD` as a row-major `m × m`
/// vector, where each row of `D` applies the stencil `[1, -2, 1]` to three
/// consecutive indices.
///
/// When `m < 3` no second difference exists, and an all-zero matrix is returned.
pub(crate) fn penalty_matrix(m: usize) -> Vec<f64> {
    const STENCIL: [f64; 3] = [1.0, -2.0, 1.0];

    let mut dtd = vec![0.0; m * m];
    if m < 3 {
        return dtd;
    }
    for start in 0..m - 2 {
        // Add the outer product of this difference row with itself.
        for (dr, &wr) in STENCIL.iter().enumerate() {
            for (dc, &wc) in STENCIL.iter().enumerate() {
                dtd[(start + dr) * m + start + dc] += wr * wc;
            }
        }
    }
    dtd
}
175
176fn penalized_solve(
180 xtx: &[f64],
181 xty: &FdMatrix,
182 penalty: &[f64],
183 lambda: f64,
184) -> Result<FdMatrix, FdarError> {
185 let p = xty.nrows();
186 let m = xty.ncols();
187
188 let mut a = vec![0.0; p * p];
190 for i in 0..p * p {
191 a[i] = xtx[i] + lambda * penalty[i];
192 }
193
194 let l = cholesky_factor(&a, p).ok_or_else(|| FdarError::ComputationFailed {
196 operation: "penalized_solve",
197 detail: format!(
198 "Cholesky factorization of (X'X + {lambda:.4}*P) failed; matrix is singular or near-singular"
199 ),
200 })?;
201
202 let mut beta = FdMatrix::zeros(p, m);
204 for t in 0..m {
205 let b: Vec<f64> = (0..p).map(|j| xty[(j, t)]).collect();
206 let x = cholesky_forward_back(&l, &b, p);
207 for j in 0..p {
208 beta[(j, t)] = x[j];
209 }
210 }
211 Ok(beta)
212}
213
214pub(crate) fn pointwise_r_squared(data: &FdMatrix, fitted: &FdMatrix) -> Vec<f64> {
216 let (n, m) = data.shape();
217 (0..m)
218 .map(|t| {
219 let mean_t: f64 = (0..n).map(|i| data[(i, t)]).sum::<f64>() / n as f64;
220 let ss_tot: f64 = (0..n).map(|i| (data[(i, t)] - mean_t).powi(2)).sum();
221 let ss_res: f64 = (0..n)
222 .map(|i| (data[(i, t)] - fitted[(i, t)]).powi(2))
223 .sum();
224 if ss_tot > 1e-15 {
225 1.0 - ss_res / ss_tot
226 } else {
227 0.0
228 }
229 })
230 .collect()
231}
232
233fn compute_fosr_gcv(residuals: &FdMatrix, trace_h: f64) -> f64 {
235 let (n, m) = residuals.shape();
236 let denom = (1.0 - trace_h / n as f64).max(1e-10);
237 let ss_res: f64 = (0..n)
238 .flat_map(|i| (0..m).map(move |t| residuals[(i, t)].powi(2)))
239 .sum();
240 ss_res / (n as f64 * m as f64 * denom * denom)
241}
242
243pub(crate) fn build_fosr_design(predictors: &FdMatrix, n: usize) -> FdMatrix {
268 let p = predictors.ncols();
269 let p_total = p + 1;
270 let mut design = FdMatrix::zeros(n, p_total);
271 for i in 0..n {
272 design[(i, 0)] = 1.0;
273 for j in 0..p {
274 design[(i, 1 + j)] = predictors[(i, j)];
275 }
276 }
277 design
278}
279
280pub(crate) fn compute_xty_matrix(design: &FdMatrix, data: &FdMatrix) -> FdMatrix {
282 let (n, m) = data.shape();
283 let p_total = design.ncols();
284 let mut xty = FdMatrix::zeros(p_total, m);
285 for j in 0..p_total {
286 for t in 0..m {
287 let mut s = 0.0;
288 for i in 0..n {
289 s += design[(i, j)] * data[(i, t)];
290 }
291 xty[(j, t)] = s;
292 }
293 }
294 xty
295}
296
297fn drop_intercept_rows(full: &FdMatrix, p: usize, m: usize) -> FdMatrix {
299 let mut out = FdMatrix::zeros(p, m);
300 for j in 0..p {
301 for t in 0..m {
302 out[(j, t)] = full[(j + 1, t)];
303 }
304 }
305 out
306}
307
308#[must_use = "expensive computation whose result should not be discarded"]
309pub fn fosr(data: &FdMatrix, predictors: &FdMatrix, lambda: f64) -> Result<FosrResult, FdarError> {
310 let (n, m) = data.shape();
311 let p = predictors.ncols();
312 if m == 0 {
313 return Err(FdarError::InvalidDimension {
314 parameter: "data",
315 expected: "at least 1 column (grid points)".to_string(),
316 actual: "0 columns".to_string(),
317 });
318 }
319 if predictors.nrows() != n {
320 return Err(FdarError::InvalidDimension {
321 parameter: "predictors",
322 expected: format!("{n} rows (matching data)"),
323 actual: format!("{} rows", predictors.nrows()),
324 });
325 }
326 if n < p + 2 {
327 return Err(FdarError::InvalidDimension {
328 parameter: "data",
329 expected: format!("at least {} observations (p + 2)", p + 2),
330 actual: format!("{n} observations"),
331 });
332 }
333
334 let design = build_fosr_design(predictors, n);
335 let p_total = design.ncols();
336 let xtx = compute_xtx(&design);
337 let xty = compute_xty_matrix(&design, data);
338 let penalty = penalty_matrix(p_total);
339
340 let lambda = if lambda < 0.0 {
341 select_lambda_gcv(&xtx, &xty, &penalty, data, &design)
342 } else {
343 lambda
344 };
345
346 let beta = penalized_solve(&xtx, &xty, &penalty, lambda)?;
347 let (fitted, residuals) = compute_fosr_fitted(&design, &beta, data);
348
349 let r_squared_t = pointwise_r_squared(data, &fitted);
350 let r_squared = r_squared_t.iter().sum::<f64>() / m as f64;
351 let beta_se = compute_beta_se(&xtx, &penalty, lambda, &residuals, p_total, n);
352 let trace_h = compute_trace_hat(&xtx, &penalty, lambda, p_total, n);
353 let gcv = compute_fosr_gcv(&residuals, trace_h);
354
355 let intercept: Vec<f64> = (0..m).map(|t| beta[(0, t)]).collect();
356
357 Ok(FosrResult {
358 intercept,
359 beta: drop_intercept_rows(&beta, p, m),
360 fitted,
361 residuals,
362 r_squared_t,
363 r_squared,
364 beta_se: drop_intercept_rows(&beta_se, p, m),
365 lambda,
366 gcv,
367 })
368}
369
370fn compute_fosr_fitted(
372 design: &FdMatrix,
373 beta: &FdMatrix,
374 data: &FdMatrix,
375) -> (FdMatrix, FdMatrix) {
376 let (n, m) = data.shape();
377 let p_total = design.ncols();
378 let rows: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
379 .map(|i| {
380 let mut fitted_row = vec![0.0; m];
381 let mut resid_row = vec![0.0; m];
382 for t in 0..m {
383 let mut yhat = 0.0;
384 for j in 0..p_total {
385 yhat += design[(i, j)] * beta[(j, t)];
386 }
387 fitted_row[t] = yhat;
388 resid_row[t] = data[(i, t)] - yhat;
389 }
390 (fitted_row, resid_row)
391 })
392 .collect();
393 let mut fitted = FdMatrix::zeros(n, m);
394 let mut residuals = FdMatrix::zeros(n, m);
395 for (i, (fr, rr)) in rows.into_iter().enumerate() {
396 for t in 0..m {
397 fitted[(i, t)] = fr[t];
398 residuals[(i, t)] = rr[t];
399 }
400 }
401 (fitted, residuals)
402}
403
404fn select_lambda_gcv(
406 xtx: &[f64],
407 xty: &FdMatrix,
408 penalty: &[f64],
409 data: &FdMatrix,
410 design: &FdMatrix,
411) -> f64 {
412 let lambdas = [0.0, 1e-6, 1e-4, 1e-2, 0.1, 1.0, 10.0, 100.0, 1000.0];
413 let p_total = design.ncols();
414 let n = design.nrows();
415
416 let mut best_lambda = 0.0;
417 let mut best_gcv = f64::INFINITY;
418
419 for &lam in &lambdas {
420 let beta = match penalized_solve(xtx, xty, penalty, lam) {
421 Ok(b) => b,
422 Err(_) => continue,
423 };
424 let (_, residuals) = compute_fosr_fitted(design, &beta, data);
425 let trace_h = compute_trace_hat(xtx, penalty, lam, p_total, n);
426 let gcv = compute_fosr_gcv(&residuals, trace_h);
427 if gcv < best_gcv {
428 best_gcv = gcv;
429 best_lambda = lam;
430 }
431 }
432 best_lambda
433}
434
435fn compute_trace_hat(xtx: &[f64], penalty: &[f64], lambda: f64, p: usize, n: usize) -> f64 {
437 let mut a = vec![0.0; p * p];
438 for i in 0..p * p {
439 a[i] = xtx[i] + lambda * penalty[i];
440 }
441 let l = match cholesky_factor(&a, p) {
444 Some(l) => l,
445 None => return p as f64, };
447
448 let mut trace = 0.0;
450 for j in 0..p {
451 let col: Vec<f64> = (0..p).map(|i| xtx[i * p + j]).collect();
452 let z = cholesky_forward_back(&l, &col, p);
453 trace += z[j]; }
455 trace.min(n as f64)
456}
457
458fn compute_beta_se(
460 xtx: &[f64],
461 penalty: &[f64],
462 lambda: f64,
463 residuals: &FdMatrix,
464 p: usize,
465 n: usize,
466) -> FdMatrix {
467 let m = residuals.ncols();
468 let mut a = vec![0.0; p * p];
469 for i in 0..p * p {
470 a[i] = xtx[i] + lambda * penalty[i];
471 }
472 let l = match cholesky_factor(&a, p) {
473 Some(l) => l,
474 None => return FdMatrix::zeros(p, m),
475 };
476
477 let a_inv_diag: Vec<f64> = (0..p)
479 .map(|j| {
480 let mut ej = vec![0.0; p];
481 ej[j] = 1.0;
482 let v = cholesky_forward_back(&l, &ej, p);
483 v[j]
484 })
485 .collect();
486
487 let df = (n - p).max(1) as f64;
488 let mut se = FdMatrix::zeros(p, m);
489 for t in 0..m {
490 let sigma2_t: f64 = (0..n).map(|i| residuals[(i, t)].powi(2)).sum::<f64>() / df;
491 for j in 0..p {
492 se[(j, t)] = (sigma2_t * a_inv_diag[j]).max(0.0).sqrt();
493 }
494 }
495 se
496}
497
498fn regress_scores_on_design(
506 design: &FdMatrix,
507 scores: &FdMatrix,
508 n: usize,
509 k: usize,
510 p_total: usize,
511) -> Result<Vec<Vec<f64>>, FdarError> {
512 let xtx = compute_xtx(design);
513 let l = cholesky_factor(&xtx, p_total).ok_or_else(|| FdarError::ComputationFailed {
514 operation: "regress_scores_on_design",
515 detail: "Cholesky factorization of X'X failed; design matrix is rank-deficient".to_string(),
516 })?;
517
518 let gamma_all: Vec<Vec<f64>> = (0..k)
519 .map(|comp| {
520 let mut xts = vec![0.0; p_total];
521 for j in 0..p_total {
522 for i in 0..n {
523 xts[j] += design[(i, j)] * scores[(i, comp)];
524 }
525 }
526 cholesky_forward_back(&l, &xts, p_total)
527 })
528 .collect();
529 Ok(gamma_all)
530}
531
532fn reconstruct_beta_fpc(
534 gamma_all: &[Vec<f64>],
535 rotation: &FdMatrix,
536 p: usize,
537 k: usize,
538 m: usize,
539) -> FdMatrix {
540 let mut beta = FdMatrix::zeros(p, m);
541 for j in 0..p {
542 for t in 0..m {
543 let mut val = 0.0;
544 for comp in 0..k {
545 val += gamma_all[comp][1 + j] * rotation[(t, comp)];
546 }
547 beta[(j, t)] = val;
548 }
549 }
550 beta
551}
552
553fn compute_intercept_fpc(
555 mean: &[f64],
556 gamma_all: &[Vec<f64>],
557 rotation: &FdMatrix,
558 k: usize,
559 m: usize,
560) -> Vec<f64> {
561 let mut intercept = mean.to_vec();
562 for t in 0..m {
563 for comp in 0..k {
564 intercept[t] += gamma_all[comp][0] * rotation[(t, comp)];
565 }
566 }
567 intercept
568}
569
/// Per-predictor regression coefficients in FPC-score space, rescaled by
/// `sqrt(h)` with `h = 1 / (m - 1)` (`h = 1` when `m <= 1`).
///
/// Returns `p` vectors (one per predictor), each holding the `k` values
/// `γ_comp[1 + j] · sqrt(h)`.
///
/// NOTE(review): `h` is the spacing of a uniform `m`-point grid on `[0, 1]`.
/// Presumably this converts discrete-vector scores to L² scale — confirm.
fn extract_beta_scores(gamma_all: &[Vec<f64>], p: usize, k: usize, m: usize) -> Vec<Vec<f64>> {
    let spacing = match m {
        0 | 1 => 1.0,
        _ => 1.0 / (m - 1) as f64,
    };
    let scale = spacing.sqrt();

    let mut scores = Vec::with_capacity(p);
    for j in 0..p {
        scores.push((0..k).map(|comp| gamma_all[comp][1 + j] * scale).collect());
    }
    scores
}
582
583#[must_use = "expensive computation whose result should not be discarded"]
602pub fn fosr_fpc(
603 data: &FdMatrix,
604 predictors: &FdMatrix,
605 ncomp: usize,
606) -> Result<FosrFpcResult, FdarError> {
607 let (n, m) = data.shape();
608 let p = predictors.ncols();
609 if m == 0 {
610 return Err(FdarError::InvalidDimension {
611 parameter: "data",
612 expected: "at least 1 column (grid points)".to_string(),
613 actual: "0 columns".to_string(),
614 });
615 }
616 if predictors.nrows() != n {
617 return Err(FdarError::InvalidDimension {
618 parameter: "predictors",
619 expected: format!("{n} rows (matching data)"),
620 actual: format!("{} rows", predictors.nrows()),
621 });
622 }
623 if n < p + 2 {
624 return Err(FdarError::InvalidDimension {
625 parameter: "data",
626 expected: format!("at least {} observations (p + 2)", p + 2),
627 actual: format!("{n} observations"),
628 });
629 }
630 if ncomp == 0 {
631 return Err(FdarError::InvalidParameter {
632 parameter: "ncomp",
633 message: "number of FPC components must be at least 1".to_string(),
634 });
635 }
636
637 let fpca = fdata_to_pc_1d(data, ncomp)?;
638 let k = fpca.scores.ncols();
639 let p_total = p + 1;
640 let design = build_fosr_design(predictors, n);
641
642 let gamma_all = regress_scores_on_design(&design, &fpca.scores, n, k, p_total)?;
643 let beta = reconstruct_beta_fpc(&gamma_all, &fpca.rotation, p, k, m);
644 let intercept = compute_intercept_fpc(&fpca.mean, &gamma_all, &fpca.rotation, k, m);
645
646 let (fitted, residuals) = compute_fosr_fpc_fitted(data, &intercept, &beta, predictors);
647 let r_squared_t = pointwise_r_squared(data, &fitted);
648 let r_squared = r_squared_t.iter().sum::<f64>() / m as f64;
649 let beta_scores = extract_beta_scores(&gamma_all, p, k, m);
650
651 Ok(FosrFpcResult {
652 intercept,
653 beta,
654 fitted,
655 residuals,
656 r_squared_t,
657 r_squared,
658 beta_scores,
659 ncomp: k,
660 })
661}
662
663fn compute_fosr_fpc_fitted(
665 data: &FdMatrix,
666 intercept: &[f64],
667 beta: &FdMatrix,
668 predictors: &FdMatrix,
669) -> (FdMatrix, FdMatrix) {
670 let (n, m) = data.shape();
671 let p = predictors.ncols();
672 let mut fitted = FdMatrix::zeros(n, m);
673 let mut residuals = FdMatrix::zeros(n, m);
674 for i in 0..n {
675 for t in 0..m {
676 let mut yhat = intercept[t];
677 for j in 0..p {
678 yhat += predictors[(i, j)] * beta[(j, t)];
679 }
680 fitted[(i, t)] = yhat;
681 residuals[(i, t)] = data[(i, t)] - yhat;
682 }
683 }
684 (fitted, residuals)
685}
686
687#[must_use = "prediction result should not be discarded"]
693pub fn predict_fosr(result: &FosrResult, new_predictors: &FdMatrix) -> FdMatrix {
694 let n_new = new_predictors.nrows();
695 let m = result.intercept.len();
696 let p = result.beta.nrows();
697
698 let mut predicted = FdMatrix::zeros(n_new, m);
699 for i in 0..n_new {
700 for t in 0..m {
701 let mut yhat = result.intercept[t];
702 for j in 0..p {
703 yhat += new_predictors[(i, j)] * result.beta[(j, t)];
704 }
705 predicted[(i, t)] = yhat;
706 }
707 }
708 predicted
709}
710
711fn compute_group_means(
717 data: &FdMatrix,
718 groups: &[usize],
719 labels: &[usize],
720) -> (FdMatrix, Vec<f64>) {
721 let (n, m) = data.shape();
722 let k = labels.len();
723 let mut group_means = FdMatrix::zeros(k, m);
724 let mut counts = vec![0usize; k];
725
726 for i in 0..n {
727 let g = labels.iter().position(|&l| l == groups[i]).unwrap_or(0);
728 counts[g] += 1;
729 for t in 0..m {
730 group_means[(g, t)] += data[(i, t)];
731 }
732 }
733 for g in 0..k {
734 if counts[g] > 0 {
735 for t in 0..m {
736 group_means[(g, t)] /= counts[g] as f64;
737 }
738 }
739 }
740
741 let overall_mean: Vec<f64> = (0..m)
742 .map(|t| (0..n).map(|i| data[(i, t)]).sum::<f64>() / n as f64)
743 .collect();
744
745 (group_means, overall_mean)
746}
747
748fn pointwise_f_statistic(
750 data: &FdMatrix,
751 groups: &[usize],
752 labels: &[usize],
753 group_means: &FdMatrix,
754 overall_mean: &[f64],
755) -> Vec<f64> {
756 let (n, m) = data.shape();
757 let k = labels.len();
758 let mut counts = vec![0usize; k];
759 for &g in groups {
760 let idx = labels.iter().position(|&l| l == g).unwrap_or(0);
761 counts[idx] += 1;
762 }
763
764 (0..m)
765 .map(|t| {
766 let ss_between: f64 = (0..k)
767 .map(|g| counts[g] as f64 * (group_means[(g, t)] - overall_mean[t]).powi(2))
768 .sum();
769 let ss_within: f64 = (0..n)
770 .map(|i| {
771 let g = labels.iter().position(|&l| l == groups[i]).unwrap_or(0);
772 (data[(i, t)] - group_means[(g, t)]).powi(2)
773 })
774 .sum();
775 let ms_between = ss_between / (k as f64 - 1.0).max(1.0);
776 let ms_within = ss_within / (n as f64 - k as f64).max(1.0);
777 if ms_within > 1e-15 {
778 ms_between / ms_within
779 } else {
780 0.0
781 }
782 })
783 .collect()
784}
785
/// Collapses the pointwise F statistics into one global statistic: their
/// arithmetic mean over the grid. An empty slice gives `NaN` (0 / 0).
fn global_f_statistic(f_t: &[f64]) -> f64 {
    let count = f_t.len() as f64;
    let total: f64 = f_t.iter().copied().sum();
    total / count
}
790
791#[must_use = "expensive computation whose result should not be discarded"]
810pub fn fanova(data: &FdMatrix, groups: &[usize], n_perm: usize) -> Result<FanovaResult, FdarError> {
811 let (n, m) = data.shape();
812 if m == 0 {
813 return Err(FdarError::InvalidDimension {
814 parameter: "data",
815 expected: "at least 1 column (grid points)".to_string(),
816 actual: "0 columns".to_string(),
817 });
818 }
819 if groups.len() != n {
820 return Err(FdarError::InvalidDimension {
821 parameter: "groups",
822 expected: format!("{n} elements (matching data rows)"),
823 actual: format!("{} elements", groups.len()),
824 });
825 }
826 if n < 3 {
827 return Err(FdarError::InvalidDimension {
828 parameter: "data",
829 expected: "at least 3 observations".to_string(),
830 actual: format!("{n} observations"),
831 });
832 }
833
834 let mut labels: Vec<usize> = groups.to_vec();
835 labels.sort();
836 labels.dedup();
837 let n_groups = labels.len();
838 if n_groups < 2 {
839 return Err(FdarError::InvalidParameter {
840 parameter: "groups",
841 message: format!("at least 2 distinct groups required, but only {n_groups} found"),
842 });
843 }
844
845 let (group_means, overall_mean) = compute_group_means(data, groups, &labels);
846 let f_t = pointwise_f_statistic(data, groups, &labels, &group_means, &overall_mean);
847 let observed_stat = global_f_statistic(&f_t);
848
849 let n_perm = n_perm.max(1);
851 let mut n_ge = 0usize;
852 let mut perm_groups = groups.to_vec();
853
854 let mut rng_state: u64 = 42;
856 for _ in 0..n_perm {
857 for i in (1..n).rev() {
859 rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
860 let j = (rng_state >> 33) as usize % (i + 1);
861 perm_groups.swap(i, j);
862 }
863
864 let (perm_means, perm_overall) = compute_group_means(data, &perm_groups, &labels);
865 let perm_f = pointwise_f_statistic(data, &perm_groups, &labels, &perm_means, &perm_overall);
866 let perm_stat = global_f_statistic(&perm_f);
867 if perm_stat >= observed_stat {
868 n_ge += 1;
869 }
870 }
871
872 let p_value = (n_ge as f64 + 1.0) / (n_perm as f64 + 1.0);
873
874 Ok(FanovaResult {
875 group_means,
876 overall_mean,
877 f_statistic_t: f_t,
878 global_statistic: observed_stat,
879 p_value,
880 n_perm,
881 n_groups,
882 group_labels: labels,
883 })
884}
885
886impl FosrResult {
887 pub fn predict(&self, new_predictors: &FdMatrix) -> FdMatrix {
889 predict_fosr(self, new_predictors)
890 }
891}
892
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `m` equally spaced points on `[0, 1]`. A single point is placed at 0,
    /// because the plain formula would divide by zero when `m == 1`.
    fn uniform_grid(m: usize) -> Vec<f64> {
        if m <= 1 {
            return vec![0.0; m];
        }
        (0..m).map(|j| j as f64 / (m - 1) as f64).collect()
    }

    /// Synthetic FoSR data with two predictors and known coefficient functions.
    ///
    /// The predictors are a continuous "age" in `[0, 1)` and an alternating 0/1
    /// "group". The true functions are `β₀(t) = sin(2πt)`, `β₁(t) = t` and
    /// `β₂(t) = cos(4πt)`. A small deterministic perturbation in `[0, 0.05)` is
    /// added.
    fn generate_fosr_data(n: usize, m: usize) -> (FdMatrix, FdMatrix) {
        let t = uniform_grid(m);
        let mut y = FdMatrix::zeros(n, m);
        let mut z = FdMatrix::zeros(n, 2);

        for i in 0..n {
            let age = (i as f64) / (n as f64);
            let group = if i % 2 == 0 { 1.0 } else { 0.0 };
            z[(i, 0)] = age;
            z[(i, 1)] = group;
            for j in 0..m {
                let mu = (2.0 * PI * t[j]).sin();
                let beta1 = t[j];
                let beta2 = (4.0 * PI * t[j]).cos();
                y[(i, j)] = mu
                    + age * beta1
                    + group * beta2
                    + 0.05 * ((i * 13 + j * 7) % 100) as f64 / 100.0;
            }
        }
        (y, z)
    }

    #[test]
    fn test_fosr_basic() {
        let (y, z) = generate_fosr_data(30, 50);
        let result = fosr(&y, &z, 0.0);
        assert!(result.is_ok());
        let fit = result.unwrap();
        assert_eq!(fit.intercept.len(), 50);
        assert_eq!(fit.beta.shape(), (2, 50));
        assert_eq!(fit.fitted.shape(), (30, 50));
        assert_eq!(fit.residuals.shape(), (30, 50));
        assert!(fit.r_squared >= 0.0);
    }

    #[test]
    fn test_fosr_with_penalty() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit0 = fosr(&y, &z, 0.0).unwrap();
        let fit1 = fosr(&y, &z, 1.0).unwrap();
        assert_eq!(fit0.beta.shape(), (2, 50));
        assert_eq!(fit1.beta.shape(), (2, 50));
        // A non-negative lambda is used as given, not re-selected by GCV.
        assert_eq!(fit0.lambda, 0.0);
        assert_eq!(fit1.lambda, 1.0);
    }

    #[test]
    fn test_fosr_auto_lambda() {
        // A negative lambda requests GCV selection from a non-negative grid.
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, -1.0).unwrap();
        assert!(fit.lambda >= 0.0);
    }

    #[test]
    fn test_fosr_fitted_plus_residuals_equals_y() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        for i in 0..30 {
            for t in 0..50 {
                let reconstructed = fit.fitted[(i, t)] + fit.residuals[(i, t)];
                assert!(
                    (reconstructed - y[(i, t)]).abs() < 1e-10,
                    "ŷ + r should equal y at ({}, {})",
                    i,
                    t
                );
            }
        }
    }

    #[test]
    fn test_fosr_pointwise_r_squared_valid() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        for &r2 in &fit.r_squared_t {
            assert!(
                (-0.01..=1.0 + 1e-10).contains(&r2),
                "R²(t) out of range: {}",
                r2
            );
        }
    }

    #[test]
    fn test_fosr_se_positive() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        for j in 0..2 {
            for t in 0..50 {
                assert!(
                    fit.beta_se[(j, t)] >= 0.0 && fit.beta_se[(j, t)].is_finite(),
                    "SE should be non-negative finite"
                );
            }
        }
    }

    #[test]
    fn test_fosr_invalid_input() {
        // n = 2 observations with p = 1 predictor violates n >= p + 2.
        let y = FdMatrix::zeros(2, 50);
        let z = FdMatrix::zeros(2, 1);
        assert!(fosr(&y, &z, 0.0).is_err());
    }

    #[test]
    fn test_predict_fosr_on_training_data() {
        // Predicting at the training covariates must reproduce the fitted curves.
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        let preds = predict_fosr(&fit, &z);
        assert_eq!(preds.shape(), (30, 50));
        for i in 0..30 {
            for t in 0..50 {
                assert!(
                    (preds[(i, t)] - fit.fitted[(i, t)]).abs() < 1e-8,
                    "Prediction on training data should match fitted"
                );
            }
        }
    }

    #[test]
    fn test_fanova_two_groups() {
        // Group 1 gets an extra linear effect 0.5·t over a shared sin(2πt) curve.
        let n = 40;
        let m = 50;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        let mut groups = vec![0usize; n];
        for i in 0..n {
            groups[i] = if i < n / 2 { 0 } else { 1 };
            for j in 0..m {
                let base = (2.0 * PI * t[j]).sin();
                let effect = if groups[i] == 1 { 0.5 * t[j] } else { 0.0 };
                data[(i, j)] = base + effect + 0.01 * (i as f64 * 0.1).sin();
            }
        }

        let result = fanova(&data, &groups, 200);
        assert!(result.is_ok());
        let res = result.unwrap();
        assert_eq!(res.n_groups, 2);
        assert_eq!(res.group_means.shape(), (2, m));
        assert_eq!(res.f_statistic_t.len(), m);
        assert!(res.p_value >= 0.0 && res.p_value <= 1.0);
        assert!(
            res.p_value < 0.1,
            "Should detect group effect, got p={}",
            res.p_value
        );
    }

    #[test]
    fn test_fanova_no_effect() {
        // Both groups share the same mean curve; only a small perturbation differs.
        let n = 40;
        let m = 50;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        let mut groups = vec![0usize; n];
        for i in 0..n {
            groups[i] = if i < n / 2 { 0 } else { 1 };
            for j in 0..m {
                data[(i, j)] =
                    (2.0 * PI * t[j]).sin() + 0.1 * ((i * 7 + j * 3) % 100) as f64 / 100.0;
            }
        }

        let result = fanova(&data, &groups, 200);
        assert!(result.is_ok());
        let res = result.unwrap();
        assert!(
            res.p_value > 0.05,
            "Should not detect effect, got p={}",
            res.p_value
        );
    }

    #[test]
    fn test_fanova_three_groups() {
        // Three clearly separated groups: no effect, +0.5·t, and -0.3·cos(2πt).
        let n = 30;
        let m = 50;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        let mut groups = vec![0usize; n];
        for i in 0..n {
            groups[i] = i % 3;
            for j in 0..m {
                let effect = match groups[i] {
                    0 => 0.0,
                    1 => 0.5 * t[j],
                    _ => -0.3 * (2.0 * PI * t[j]).cos(),
                };
                data[(i, j)] = (2.0 * PI * t[j]).sin() + effect + 0.01 * (i as f64 * 0.1).sin();
            }
        }

        let result = fanova(&data, &groups, 200);
        assert!(result.is_ok());
        let res = result.unwrap();
        assert_eq!(res.n_groups, 3);
        assert_eq!(res.group_means.shape(), (3, m));
        assert!(
            res.p_value < 0.1,
            "Should detect group effect, got p={}",
            res.p_value
        );
    }

    #[test]
    fn test_fanova_invalid_input() {
        let data = FdMatrix::zeros(10, 50);

        // Only one distinct group.
        let groups = vec![0; 10];
        assert!(fanova(&data, &groups, 100).is_err());

        // Group vector shorter than the number of observations.
        let groups = vec![0; 5];
        assert!(fanova(&data, &groups, 100).is_err());
    }
}