Skip to main content

fdars_core/
function_on_scalar.rs

1//! Function-on-scalar regression and functional ANOVA.
2//!
3//! Predicts a **functional response** from scalar/categorical predictors:
4//! ```text
5//! X_i(t) = μ(t) + Σⱼ βⱼ(t) · z_ij + ε_i(t)
6//! ```
7//!
8//! # Methods
9//!
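//! - [`fosr`]: Penalized function-on-scalar regression (pointwise penalized least squares)
//! - [`fosr_fpc`]: FPC-based function-on-scalar regression (OLS on functional principal component scores)
//! - [`fanova`]: Functional ANOVA with permutation-based global test
//! - [`predict_fosr`]: Predict new curves from a fitted [`fosr`] model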
13
14use crate::matrix::FdMatrix;
15use crate::regression::fdata_to_pc_1d;
16
17// ---------------------------------------------------------------------------
18// Linear algebra helpers (self-contained)
19// ---------------------------------------------------------------------------
20
/// Cholesky factorization A = LL' of a symmetric positive-definite matrix.
///
/// `a` is p×p, flat row-major; only its lower triangle (including the diagonal) is
/// read. Returns the lower-triangular factor L in the same layout (its strict upper
/// triangle is zero), or `None` when A is not numerically positive definite — some
/// pivot is ≤ 1e-12 — or when a pivot is NaN/infinite (e.g. non-finite input).
fn cholesky_factor(a: &[f64], p: usize) -> Option<Vec<f64>> {
    /// Smallest pivot accepted before A is treated as singular.
    const MIN_PIVOT: f64 = 1e-12;
    let mut l = vec![0.0; p * p];
    for j in 0..p {
        let mut diag = a[j * p + j];
        for k in 0..j {
            diag -= l[j * p + k] * l[j * p + k];
        }
        // `diag <= MIN_PIVOT` alone is false for NaN, which would let NaN leak into
        // every downstream solve; reject non-finite pivots explicitly.
        if !diag.is_finite() || diag <= MIN_PIVOT {
            return None;
        }
        l[j * p + j] = diag.sqrt();
        for i in (j + 1)..p {
            let mut s = a[i * p + j];
            for k in 0..j {
                s -= l[i * p + k] * l[j * p + k];
            }
            l[i * p + j] = s / l[j * p + j];
        }
    }
    Some(l)
}
43
/// Solve A x = b given the Cholesky factor L of A (A = LL').
///
/// Runs forward substitution for L y = b, then back substitution for L' x = y.
/// `l` is p×p flat row-major lower-triangular; the solution is returned as a new
/// vector (a copy of `b` overwritten in place).
fn cholesky_forward_back(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut x = b.to_vec();

    // Forward pass: L y = b.
    for row in 0..p {
        let mut acc = x[row];
        for col in 0..row {
            acc -= l[row * p + col] * x[col];
        }
        x[row] = acc / l[row * p + row];
    }

    // Backward pass: L' x = y, reading L' entries as transposed L entries.
    for row in (0..p).rev() {
        let mut acc = x[row];
        for col in (row + 1)..p {
            acc -= l[col * p + row] * x[col];
        }
        x[row] = acc / l[row * p + row];
    }

    x
}
61
62/// Compute X'X (symmetric, p×p stored flat row-major).
63fn compute_xtx(x: &FdMatrix) -> Vec<f64> {
64    let (n, p) = x.shape();
65    let mut xtx = vec![0.0; p * p];
66    for k in 0..p {
67        for j in k..p {
68            let mut s = 0.0;
69            for i in 0..n {
70                s += x[(i, k)] * x[(i, j)];
71            }
72            xtx[k * p + j] = s;
73            xtx[j * p + k] = s;
74        }
75    }
76    xtx
77}
78
79// ---------------------------------------------------------------------------
80// Result types
81// ---------------------------------------------------------------------------
82
/// Result of penalized function-on-scalar regression ([`fosr`]).
///
/// Grid-indexed quantities share the response's `m` evaluation points; `p` is the
/// number of scalar predictors (the intercept is reported separately in `intercept`).
pub struct FosrResult {
    /// Intercept function μ(t) (length m)
    pub intercept: Vec<f64>,
    /// Coefficient functions β_j(t), one per predictor (p × m matrix, row j = βⱼ(t))
    pub beta: FdMatrix,
    /// Fitted functional values (n × m matrix)
    pub fitted: FdMatrix,
    /// Residual functions, `data - fitted` (n × m matrix)
    pub residuals: FdMatrix,
    /// Pointwise R² = 1 − SS_res/SS_tot at each grid point (length m); 0 where the
    /// response has (near-)zero variance across observations
    pub r_squared_t: Vec<f64>,
    /// Global R²: the arithmetic mean of `r_squared_t` over the grid points
    pub r_squared: f64,
    /// Pointwise standard errors for each βⱼ(t) (p × m matrix)
    pub beta_se: FdMatrix,
    /// Smoothing parameter λ actually used; when `fosr` is called with a negative λ
    /// this is the value chosen by GCV
    pub lambda: f64,
    /// Generalized cross-validation score of the final fit at `lambda`
    pub gcv: f64,
}
104
/// Result of FPC-based function-on-scalar regression ([`fosr_fpc`]).
///
/// Grid-indexed quantities share the response's `m` evaluation points; `p` is the
/// number of scalar predictors.
pub struct FosrFpcResult {
    /// Intercept function μ(t) (length m)
    pub intercept: Vec<f64>,
    /// Coefficient functions β_j(t), one per predictor (p × m matrix, row j = βⱼ(t))
    pub beta: FdMatrix,
    /// Fitted functional values (n × m matrix)
    pub fitted: FdMatrix,
    /// Residual functions, `data - fitted` (n × m matrix)
    pub residuals: FdMatrix,
    /// Pointwise R² across the domain (length m)
    pub r_squared_t: Vec<f64>,
    /// Global R²: the arithmetic mean of `r_squared_t` over the grid points
    pub r_squared: f64,
    /// FPC-space regression coefficients: `beta_scores[j][k]` is the OLS coefficient of
    /// predictor j for FPC score k, multiplied by √(1/(m−1)) (√1 when m ≤ 1)
    pub beta_scores: Vec<Vec<f64>>,
    /// Number of FPC components used. Taken from the FPCA score matrix, so it may be
    /// smaller than the requested `ncomp` if the decomposition returned fewer
    /// components (TODO confirm `fdata_to_pc_1d` semantics)
    pub ncomp: usize,
}
124
/// Result of functional ANOVA ([`fanova`]).
pub struct FanovaResult {
    /// Group mean functions (k × m matrix); row g is the mean curve of the group
    /// labelled `group_labels[g]`
    pub group_means: FdMatrix,
    /// Pooled mean function over all n observations (length m)
    pub overall_mean: Vec<f64>,
    /// Pointwise one-way ANOVA F-statistic across the domain (length m)
    pub f_statistic_t: Vec<f64>,
    /// Global test statistic: the arithmetic mean of `f_statistic_t` over the grid
    pub global_statistic: f64,
    /// Permutation p-value, (1 + #{permuted statistics ≥ observed}) / (1 + n_perm)
    pub p_value: f64,
    /// Number of permutations performed (the requested count, clamped to at least 1)
    pub n_perm: usize,
    /// Number of groups
    pub n_groups: usize,
    /// Group labels (sorted unique values)
    pub group_labels: Vec<usize>,
}
144
145// ---------------------------------------------------------------------------
146// Shared helpers
147// ---------------------------------------------------------------------------
148
/// Second-order difference penalty D'D as an m×m flat row-major matrix.
///
/// D is the (m−2)×m operator whose rows apply the stencil [1, −2, 1], so D'D is a
/// symmetric pentadiagonal matrix. For m < 3 no second difference exists and the
/// all-zero m×m matrix is returned.
fn penalty_matrix(m: usize) -> Vec<f64> {
    const STENCIL: [f64; 3] = [1.0, -2.0, 1.0];
    let mut dtd = vec![0.0; m * m];
    if m >= 3 {
        // Accumulate the outer product of each stencil row with itself.
        for row in 0..(m - 2) {
            for (a, &wa) in STENCIL.iter().enumerate() {
                for (b, &wb) in STENCIL.iter().enumerate() {
                    dtd[(row + a) * m + (row + b)] += wa * wb;
                }
            }
        }
    }
    dtd
}
168
169/// Solve (A + λP)x = b for each column of B (pointwise regression at each t).
170/// A is X'X (p×p), B is X'Y (p×m), P is penalty matrix (p×p).
171/// Returns coefficient matrix (p×m).
172fn penalized_solve(xtx: &[f64], xty: &FdMatrix, penalty: &[f64], lambda: f64) -> Option<FdMatrix> {
173    let p = xty.nrows();
174    let m = xty.ncols();
175
176    // Build (X'X + λP)
177    let mut a = vec![0.0; p * p];
178    for i in 0..p * p {
179        a[i] = xtx[i] + lambda * penalty[i];
180    }
181
182    // Cholesky factor
183    let l = cholesky_factor(&a, p)?;
184
185    // Solve for each grid point
186    let mut beta = FdMatrix::zeros(p, m);
187    for t in 0..m {
188        let b: Vec<f64> = (0..p).map(|j| xty[(j, t)]).collect();
189        let x = cholesky_forward_back(&l, &b, p);
190        for j in 0..p {
191            beta[(j, t)] = x[j];
192        }
193    }
194    Some(beta)
195}
196
197/// Compute pointwise R² at each grid point.
198fn pointwise_r_squared(data: &FdMatrix, fitted: &FdMatrix) -> Vec<f64> {
199    let (n, m) = data.shape();
200    (0..m)
201        .map(|t| {
202            let mean_t: f64 = (0..n).map(|i| data[(i, t)]).sum::<f64>() / n as f64;
203            let ss_tot: f64 = (0..n).map(|i| (data[(i, t)] - mean_t).powi(2)).sum();
204            let ss_res: f64 = (0..n)
205                .map(|i| (data[(i, t)] - fitted[(i, t)]).powi(2))
206                .sum();
207            if ss_tot > 1e-15 {
208                1.0 - ss_res / ss_tot
209            } else {
210                0.0
211            }
212        })
213        .collect()
214}
215
216/// GCV for penalized regression: (1/nm) Σ_{i,t} (r_{it} / (1 - tr(H)/n))²
217fn compute_fosr_gcv(residuals: &FdMatrix, trace_h: f64) -> f64 {
218    let (n, m) = residuals.shape();
219    let denom = (1.0 - trace_h / n as f64).max(1e-10);
220    let ss_res: f64 = (0..n)
221        .flat_map(|i| (0..m).map(move |t| residuals[(i, t)].powi(2)))
222        .sum();
223    ss_res / (n as f64 * m as f64 * denom * denom)
224}
225
226// ---------------------------------------------------------------------------
227// fosr: Function-on-scalar regression
228// ---------------------------------------------------------------------------
229
230/// Penalized function-on-scalar regression.
231///
232/// Fits pointwise OLS at each grid point t, then smooths the coefficient
233/// functions β_j(t) using a second-order roughness penalty.
234///
235/// # Arguments
236/// * `data` - Functional response matrix (n × m)
237/// * `predictors` - Scalar predictor matrix (n × p)
238/// * `lambda` - Smoothing parameter (0 for no smoothing, negative for GCV selection)
239///
240/// # Returns
241/// [`FosrResult`] with coefficient functions, fitted values, and diagnostics
242/// Build design matrix with intercept: \[1, z_1, ..., z_p\].
243fn build_fosr_design(predictors: &FdMatrix, n: usize) -> FdMatrix {
244    let p = predictors.ncols();
245    let p_total = p + 1;
246    let mut design = FdMatrix::zeros(n, p_total);
247    for i in 0..n {
248        design[(i, 0)] = 1.0;
249        for j in 0..p {
250            design[(i, 1 + j)] = predictors[(i, j)];
251        }
252    }
253    design
254}
255
256/// Compute X'Y (p_total × m).
257fn compute_xty_matrix(design: &FdMatrix, data: &FdMatrix) -> FdMatrix {
258    let (n, m) = data.shape();
259    let p_total = design.ncols();
260    let mut xty = FdMatrix::zeros(p_total, m);
261    for j in 0..p_total {
262        for t in 0..m {
263            let mut s = 0.0;
264            for i in 0..n {
265                s += design[(i, j)] * data[(i, t)];
266            }
267            xty[(j, t)] = s;
268        }
269    }
270    xty
271}
272
273/// Extract rows 1..p+1 from a (p+1)×m matrix, dropping the intercept row.
274fn drop_intercept_rows(full: &FdMatrix, p: usize, m: usize) -> FdMatrix {
275    let mut out = FdMatrix::zeros(p, m);
276    for j in 0..p {
277        for t in 0..m {
278            out[(j, t)] = full[(j + 1, t)];
279        }
280    }
281    out
282}
283
284pub fn fosr(data: &FdMatrix, predictors: &FdMatrix, lambda: f64) -> Option<FosrResult> {
285    let (n, m) = data.shape();
286    let p = predictors.ncols();
287    if n < p + 2 || m == 0 || predictors.nrows() != n {
288        return None;
289    }
290
291    let design = build_fosr_design(predictors, n);
292    let p_total = design.ncols();
293    let xtx = compute_xtx(&design);
294    let xty = compute_xty_matrix(&design, data);
295    let penalty = penalty_matrix(p_total);
296
297    let lambda = if lambda < 0.0 {
298        select_lambda_gcv(&xtx, &xty, &penalty, data, &design)
299    } else {
300        lambda
301    };
302
303    let beta = penalized_solve(&xtx, &xty, &penalty, lambda)?;
304    let (fitted, residuals) = compute_fosr_fitted(&design, &beta, data);
305
306    let r_squared_t = pointwise_r_squared(data, &fitted);
307    let r_squared = r_squared_t.iter().sum::<f64>() / m as f64;
308    let beta_se = compute_beta_se(&xtx, &penalty, lambda, &residuals, p_total, n);
309    let trace_h = compute_trace_hat(&xtx, &penalty, lambda, p_total, n);
310    let gcv = compute_fosr_gcv(&residuals, trace_h);
311
312    let intercept: Vec<f64> = (0..m).map(|t| beta[(0, t)]).collect();
313
314    Some(FosrResult {
315        intercept,
316        beta: drop_intercept_rows(&beta, p, m),
317        fitted,
318        residuals,
319        r_squared_t,
320        r_squared,
321        beta_se: drop_intercept_rows(&beta_se, p, m),
322        lambda,
323        gcv,
324    })
325}
326
327/// Compute fitted values Ŷ = X β and residuals.
328fn compute_fosr_fitted(
329    design: &FdMatrix,
330    beta: &FdMatrix,
331    data: &FdMatrix,
332) -> (FdMatrix, FdMatrix) {
333    let (n, m) = data.shape();
334    let p_total = design.ncols();
335    let mut fitted = FdMatrix::zeros(n, m);
336    let mut residuals = FdMatrix::zeros(n, m);
337    for i in 0..n {
338        for t in 0..m {
339            let mut yhat = 0.0;
340            for j in 0..p_total {
341                yhat += design[(i, j)] * beta[(j, t)];
342            }
343            fitted[(i, t)] = yhat;
344            residuals[(i, t)] = data[(i, t)] - yhat;
345        }
346    }
347    (fitted, residuals)
348}
349
350/// Select smoothing parameter λ via GCV on a grid.
351fn select_lambda_gcv(
352    xtx: &[f64],
353    xty: &FdMatrix,
354    penalty: &[f64],
355    data: &FdMatrix,
356    design: &FdMatrix,
357) -> f64 {
358    let lambdas = [0.0, 1e-6, 1e-4, 1e-2, 0.1, 1.0, 10.0, 100.0, 1000.0];
359    let p_total = design.ncols();
360    let n = design.nrows();
361
362    let mut best_lambda = 0.0;
363    let mut best_gcv = f64::INFINITY;
364
365    for &lam in &lambdas {
366        let beta = match penalized_solve(xtx, xty, penalty, lam) {
367            Some(b) => b,
368            None => continue,
369        };
370        let (_, residuals) = compute_fosr_fitted(design, &beta, data);
371        let trace_h = compute_trace_hat(xtx, penalty, lam, p_total, n);
372        let gcv = compute_fosr_gcv(&residuals, trace_h);
373        if gcv < best_gcv {
374            best_gcv = gcv;
375            best_lambda = lam;
376        }
377    }
378    best_lambda
379}
380
381/// Compute trace of hat matrix: tr(H) = tr(X (X'X + λP)^{-1} X') = Σ_j h_jj.
382fn compute_trace_hat(xtx: &[f64], penalty: &[f64], lambda: f64, p: usize, n: usize) -> f64 {
383    let mut a = vec![0.0; p * p];
384    for i in 0..p * p {
385        a[i] = xtx[i] + lambda * penalty[i];
386    }
387    // tr(H) = tr(X A^{-1} X') = Σ_{j=0..p} a^{-1}_{jj} * xtx_{jj}
388    // More precisely: tr(X (X'X+λP)^{-1} X') = tr((X'X+λP)^{-1} X'X)
389    let l = match cholesky_factor(&a, p) {
390        Some(l) => l,
391        None => return p as f64, // fallback
392    };
393
394    // Compute A^{-1} X'X via solving A Z = X'X column by column, then trace
395    let mut trace = 0.0;
396    for j in 0..p {
397        let col: Vec<f64> = (0..p).map(|i| xtx[i * p + j]).collect();
398        let z = cholesky_forward_back(&l, &col, p);
399        trace += z[j]; // diagonal element of A^{-1} X'X
400    }
401    trace.min(n as f64)
402}
403
404/// Compute pointwise standard errors for β(t).
405fn compute_beta_se(
406    xtx: &[f64],
407    penalty: &[f64],
408    lambda: f64,
409    residuals: &FdMatrix,
410    p: usize,
411    n: usize,
412) -> FdMatrix {
413    let m = residuals.ncols();
414    let mut a = vec![0.0; p * p];
415    for i in 0..p * p {
416        a[i] = xtx[i] + lambda * penalty[i];
417    }
418    let l = match cholesky_factor(&a, p) {
419        Some(l) => l,
420        None => return FdMatrix::zeros(p, m),
421    };
422
423    // Diagonal of A^{-1}
424    let a_inv_diag: Vec<f64> = (0..p)
425        .map(|j| {
426            let mut ej = vec![0.0; p];
427            ej[j] = 1.0;
428            let v = cholesky_forward_back(&l, &ej, p);
429            v[j]
430        })
431        .collect();
432
433    let df = (n - p).max(1) as f64;
434    let mut se = FdMatrix::zeros(p, m);
435    for t in 0..m {
436        let sigma2_t: f64 = (0..n).map(|i| residuals[(i, t)].powi(2)).sum::<f64>() / df;
437        for j in 0..p {
438            se[(j, t)] = (sigma2_t * a_inv_diag[j]).max(0.0).sqrt();
439        }
440    }
441    se
442}
443
444// ---------------------------------------------------------------------------
445// fosr_fpc: FPC-based function-on-scalar regression (matches R's fda.usc approach)
446// ---------------------------------------------------------------------------
447
448/// OLS regression of each FPC score on the design matrix.
449///
450/// Returns gamma_all\[comp\]\[coef\] = (X'X)^{-1} X' scores\[:,comp\].
451fn regress_scores_on_design(
452    design: &FdMatrix,
453    scores: &FdMatrix,
454    n: usize,
455    k: usize,
456    p_total: usize,
457) -> Option<Vec<Vec<f64>>> {
458    let xtx = compute_xtx(design);
459    let l = cholesky_factor(&xtx, p_total)?;
460
461    let gamma_all: Vec<Vec<f64>> = (0..k)
462        .map(|comp| {
463            let mut xts = vec![0.0; p_total];
464            for j in 0..p_total {
465                for i in 0..n {
466                    xts[j] += design[(i, j)] * scores[(i, comp)];
467                }
468            }
469            cholesky_forward_back(&l, &xts, p_total)
470        })
471        .collect();
472    Some(gamma_all)
473}
474
475/// Reconstruct β_j(t) = Σ_k gamma\[comp\]\[1+j\] · φ_k(t) for each predictor j.
476fn reconstruct_beta_fpc(
477    gamma_all: &[Vec<f64>],
478    rotation: &FdMatrix,
479    p: usize,
480    k: usize,
481    m: usize,
482) -> FdMatrix {
483    let mut beta = FdMatrix::zeros(p, m);
484    for j in 0..p {
485        for t in 0..m {
486            let mut val = 0.0;
487            for comp in 0..k {
488                val += gamma_all[comp][1 + j] * rotation[(t, comp)];
489            }
490            beta[(j, t)] = val;
491        }
492    }
493    beta
494}
495
496/// Compute intercept function: μ(t) + Σ_k γ_intercept\[k\] · φ_k(t).
497fn compute_intercept_fpc(
498    mean: &[f64],
499    gamma_all: &[Vec<f64>],
500    rotation: &FdMatrix,
501    k: usize,
502    m: usize,
503) -> Vec<f64> {
504    let mut intercept = mean.to_vec();
505    for t in 0..m {
506        for comp in 0..k {
507            intercept[t] += gamma_all[comp][0] * rotation[(t, comp)];
508        }
509    }
510    intercept
511}
512
/// Per-predictor FPC-space coefficients, scaled for an L²-normalized basis.
///
/// Returns `out[j][comp] = gamma_all[comp][1 + j] · √h`, where h = 1/(m − 1) is the
/// spacing of a uniform grid on \[0, 1\] (h = 1 when m ≤ 1). The intercept
/// coefficient (index 0) is excluded.
fn extract_beta_scores(gamma_all: &[Vec<f64>], p: usize, k: usize, m: usize) -> Vec<Vec<f64>> {
    let spacing = match m {
        0 | 1 => 1.0,
        _ => 1.0 / (m - 1) as f64,
    };
    let scale = spacing.sqrt();

    let mut out = Vec::with_capacity(p);
    for predictor in 0..p {
        let row: Vec<f64> = (0..k)
            .map(|comp| gamma_all[comp][1 + predictor] * scale)
            .collect();
        out.push(row);
    }
    out
}
525
526/// FPC-based function-on-scalar regression.
527///
528/// Reduces the functional response to FPC scores, regresses each score on the
529/// scalar predictors via OLS, then reconstructs β(t) from the loadings.
530/// This matches R's `fdata2pc` + `lm(scores ~ x)` approach.
531///
532/// # Arguments
533/// * `data` - Functional response matrix (n × m)
534/// * `predictors` - Scalar predictor matrix (n × p)
535/// * `ncomp` - Number of FPC components to use
536pub fn fosr_fpc(data: &FdMatrix, predictors: &FdMatrix, ncomp: usize) -> Option<FosrFpcResult> {
537    let (n, m) = data.shape();
538    let p = predictors.ncols();
539    if n < p + 2 || m == 0 || predictors.nrows() != n || ncomp == 0 {
540        return None;
541    }
542
543    let fpca = fdata_to_pc_1d(data, ncomp)?;
544    let k = fpca.scores.ncols();
545    let p_total = p + 1;
546    let design = build_fosr_design(predictors, n);
547
548    let gamma_all = regress_scores_on_design(&design, &fpca.scores, n, k, p_total)?;
549    let beta = reconstruct_beta_fpc(&gamma_all, &fpca.rotation, p, k, m);
550    let intercept = compute_intercept_fpc(&fpca.mean, &gamma_all, &fpca.rotation, k, m);
551
552    let (fitted, residuals) = compute_fosr_fpc_fitted(data, &intercept, &beta, predictors);
553    let r_squared_t = pointwise_r_squared(data, &fitted);
554    let r_squared = r_squared_t.iter().sum::<f64>() / m as f64;
555    let beta_scores = extract_beta_scores(&gamma_all, p, k, m);
556
557    Some(FosrFpcResult {
558        intercept,
559        beta,
560        fitted,
561        residuals,
562        r_squared_t,
563        r_squared,
564        beta_scores,
565        ncomp: k,
566    })
567}
568
569/// Compute fitted values and residuals for FPC-based FOSR.
570fn compute_fosr_fpc_fitted(
571    data: &FdMatrix,
572    intercept: &[f64],
573    beta: &FdMatrix,
574    predictors: &FdMatrix,
575) -> (FdMatrix, FdMatrix) {
576    let (n, m) = data.shape();
577    let p = predictors.ncols();
578    let mut fitted = FdMatrix::zeros(n, m);
579    let mut residuals = FdMatrix::zeros(n, m);
580    for i in 0..n {
581        for t in 0..m {
582            let mut yhat = intercept[t];
583            for j in 0..p {
584                yhat += predictors[(i, j)] * beta[(j, t)];
585            }
586            fitted[(i, t)] = yhat;
587            residuals[(i, t)] = data[(i, t)] - yhat;
588        }
589    }
590    (fitted, residuals)
591}
592
593/// Predict new functional responses from a fitted FOSR model.
594///
595/// # Arguments
596/// * `result` - Fitted [`FosrResult`]
597/// * `new_predictors` - New scalar predictors (n_new × p)
598pub fn predict_fosr(result: &FosrResult, new_predictors: &FdMatrix) -> FdMatrix {
599    let n_new = new_predictors.nrows();
600    let m = result.intercept.len();
601    let p = result.beta.nrows();
602
603    let mut predicted = FdMatrix::zeros(n_new, m);
604    for i in 0..n_new {
605        for t in 0..m {
606            let mut yhat = result.intercept[t];
607            for j in 0..p {
608                yhat += new_predictors[(i, j)] * result.beta[(j, t)];
609            }
610            predicted[(i, t)] = yhat;
611        }
612    }
613    predicted
614}
615
616// ---------------------------------------------------------------------------
617// fanova: Functional ANOVA
618// ---------------------------------------------------------------------------
619
620/// Compute group means and overall mean.
621fn compute_group_means(
622    data: &FdMatrix,
623    groups: &[usize],
624    labels: &[usize],
625) -> (FdMatrix, Vec<f64>) {
626    let (n, m) = data.shape();
627    let k = labels.len();
628    let mut group_means = FdMatrix::zeros(k, m);
629    let mut counts = vec![0usize; k];
630
631    for i in 0..n {
632        let g = labels.iter().position(|&l| l == groups[i]).unwrap_or(0);
633        counts[g] += 1;
634        for t in 0..m {
635            group_means[(g, t)] += data[(i, t)];
636        }
637    }
638    for g in 0..k {
639        if counts[g] > 0 {
640            for t in 0..m {
641                group_means[(g, t)] /= counts[g] as f64;
642            }
643        }
644    }
645
646    let overall_mean: Vec<f64> = (0..m)
647        .map(|t| (0..n).map(|i| data[(i, t)]).sum::<f64>() / n as f64)
648        .collect();
649
650    (group_means, overall_mean)
651}
652
653/// Compute pointwise F-statistic.
654fn pointwise_f_statistic(
655    data: &FdMatrix,
656    groups: &[usize],
657    labels: &[usize],
658    group_means: &FdMatrix,
659    overall_mean: &[f64],
660) -> Vec<f64> {
661    let (n, m) = data.shape();
662    let k = labels.len();
663    let mut counts = vec![0usize; k];
664    for &g in groups {
665        let idx = labels.iter().position(|&l| l == g).unwrap_or(0);
666        counts[idx] += 1;
667    }
668
669    (0..m)
670        .map(|t| {
671            let ss_between: f64 = (0..k)
672                .map(|g| counts[g] as f64 * (group_means[(g, t)] - overall_mean[t]).powi(2))
673                .sum();
674            let ss_within: f64 = (0..n)
675                .map(|i| {
676                    let g = labels.iter().position(|&l| l == groups[i]).unwrap_or(0);
677                    (data[(i, t)] - group_means[(g, t)]).powi(2)
678                })
679                .sum();
680            let ms_between = ss_between / (k as f64 - 1.0).max(1.0);
681            let ms_within = ss_within / (n as f64 - k as f64).max(1.0);
682            if ms_within > 1e-15 {
683                ms_between / ms_within
684            } else {
685                0.0
686            }
687        })
688        .collect()
689}
690
/// Global test statistic: the arithmetic mean of the pointwise F values, which
/// approximates the integral of F(t) over a uniform grid on \[0, 1\].
///
/// An empty slice yields NaN (0/0).
fn global_f_statistic(f_t: &[f64]) -> f64 {
    let total: f64 = f_t.iter().sum();
    total / f_t.len() as f64
}
695
696/// Functional ANOVA: test whether groups have different mean curves.
697///
698/// Uses a permutation-based global test with the integrated F-statistic.
699///
700/// # Arguments
701/// * `data` - Functional response matrix (n × m)
702/// * `groups` - Group labels for each observation (length n, integer-coded)
703/// * `n_perm` - Number of permutations for the global test
704///
705/// # Returns
706/// [`FanovaResult`] with group means, F-statistics, and permutation p-value
707pub fn fanova(data: &FdMatrix, groups: &[usize], n_perm: usize) -> Option<FanovaResult> {
708    let (n, m) = data.shape();
709    if n < 3 || m == 0 || groups.len() != n {
710        return None;
711    }
712
713    let mut labels: Vec<usize> = groups.to_vec();
714    labels.sort();
715    labels.dedup();
716    let n_groups = labels.len();
717    if n_groups < 2 {
718        return None;
719    }
720
721    let (group_means, overall_mean) = compute_group_means(data, groups, &labels);
722    let f_t = pointwise_f_statistic(data, groups, &labels, &group_means, &overall_mean);
723    let observed_stat = global_f_statistic(&f_t);
724
725    // Permutation test
726    let n_perm = n_perm.max(1);
727    let mut n_ge = 0usize;
728    let mut perm_groups = groups.to_vec();
729
730    // Simple LCG for reproducibility without requiring rand
731    let mut rng_state: u64 = 42;
732    for _ in 0..n_perm {
733        // Fisher-Yates shuffle with LCG
734        for i in (1..n).rev() {
735            rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
736            let j = (rng_state >> 33) as usize % (i + 1);
737            perm_groups.swap(i, j);
738        }
739
740        let (perm_means, perm_overall) = compute_group_means(data, &perm_groups, &labels);
741        let perm_f = pointwise_f_statistic(data, &perm_groups, &labels, &perm_means, &perm_overall);
742        let perm_stat = global_f_statistic(&perm_f);
743        if perm_stat >= observed_stat {
744            n_ge += 1;
745        }
746    }
747
748    let p_value = (n_ge as f64 + 1.0) / (n_perm as f64 + 1.0);
749
750    Some(FanovaResult {
751        group_means,
752        overall_mean,
753        f_statistic_t: f_t,
754        global_statistic: observed_stat,
755        p_value,
756        n_perm,
757        n_groups,
758        group_labels: labels,
759    })
760}
761
762// ---------------------------------------------------------------------------
763// Tests
764// ---------------------------------------------------------------------------
765
// Unit tests. Fixtures are deterministic: "noise" terms are derived from modular
// arithmetic on the indices, so every run sees identical data.
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `m` equally spaced points on [0, 1]; expects `m >= 2` (m = 1 divides by zero).
    fn uniform_grid(m: usize) -> Vec<f64> {
        (0..m).map(|j| j as f64 / (m - 1) as f64).collect()
    }

    /// Synthetic FOSR data: n curves on an m-point grid with two predictors, a
    /// continuous "age" i/n in [0, 1) and an alternating 0/1 "group" indicator, plus
    /// deterministic pseudo-noise in [0, 0.05).
    fn generate_fosr_data(n: usize, m: usize) -> (FdMatrix, FdMatrix) {
        let t = uniform_grid(m);
        let mut y = FdMatrix::zeros(n, m);
        let mut z = FdMatrix::zeros(n, 2);

        for i in 0..n {
            let age = (i as f64) / (n as f64);
            let group = if i % 2 == 0 { 1.0 } else { 0.0 };
            z[(i, 0)] = age;
            z[(i, 1)] = group;
            for j in 0..m {
                // True model: μ(t) + age * β₁(t) + group * β₂(t)
                let mu = (2.0 * PI * t[j]).sin();
                let beta1 = t[j]; // Linear coefficient for age
                let beta2 = (4.0 * PI * t[j]).cos(); // Oscillating for group
                y[(i, j)] = mu
                    + age * beta1
                    + group * beta2
                    + 0.05 * ((i * 13 + j * 7) % 100) as f64 / 100.0;
            }
        }
        (y, z)
    }

    // ----- FOSR tests -----

    // Unpenalized fit: output shapes and a non-negative global R².
    #[test]
    fn test_fosr_basic() {
        let (y, z) = generate_fosr_data(30, 50);
        let result = fosr(&y, &z, 0.0);
        assert!(result.is_some());
        let fit = result.unwrap();
        assert_eq!(fit.intercept.len(), 50);
        assert_eq!(fit.beta.shape(), (2, 50));
        assert_eq!(fit.fitted.shape(), (30, 50));
        assert_eq!(fit.residuals.shape(), (30, 50));
        assert!(fit.r_squared >= 0.0);
    }

    // λ = 0 and λ = 1 must both produce a fit of the right shape.
    #[test]
    fn test_fosr_with_penalty() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit0 = fosr(&y, &z, 0.0).unwrap();
        let fit1 = fosr(&y, &z, 1.0).unwrap();
        // Both should produce valid results
        assert_eq!(fit0.beta.shape(), (2, 50));
        assert_eq!(fit1.beta.shape(), (2, 50));
    }

    // A negative λ requests GCV selection from a grid of non-negative candidates.
    #[test]
    fn test_fosr_auto_lambda() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, -1.0).unwrap();
        assert!(fit.lambda >= 0.0);
    }

    // Residuals are defined as y − ŷ, so the decomposition is exact up to rounding.
    #[test]
    fn test_fosr_fitted_plus_residuals_equals_y() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        for i in 0..30 {
            for t in 0..50 {
                let reconstructed = fit.fitted[(i, t)] + fit.residuals[(i, t)];
                assert!(
                    (reconstructed - y[(i, t)]).abs() < 1e-10,
                    "ŷ + r should equal y at ({}, {})",
                    i,
                    t
                );
            }
        }
    }

    // OLS with an intercept keeps R²(t) in [0, 1]; the range allows small rounding slack.
    #[test]
    fn test_fosr_pointwise_r_squared_valid() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        for &r2 in &fit.r_squared_t {
            assert!(
                (-0.01..=1.0 + 1e-10).contains(&r2),
                "R²(t) out of range: {}",
                r2
            );
        }
    }

    // Standard errors must be finite and non-negative everywhere.
    #[test]
    fn test_fosr_se_positive() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        for j in 0..2 {
            for t in 0..50 {
                assert!(
                    fit.beta_se[(j, t)] >= 0.0 && fit.beta_se[(j, t)].is_finite(),
                    "SE should be non-negative finite"
                );
            }
        }
    }

    // n = 2 < p + 2 = 3 observations is rejected.
    #[test]
    fn test_fosr_invalid_input() {
        let y = FdMatrix::zeros(2, 50);
        let z = FdMatrix::zeros(2, 1);
        assert!(fosr(&y, &z, 0.0).is_none());
    }

    // ----- predict_fosr tests -----

    // Predicting on the training predictors must reproduce the fitted curves.
    #[test]
    fn test_predict_fosr_on_training_data() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        let preds = predict_fosr(&fit, &z);
        assert_eq!(preds.shape(), (30, 50));
        for i in 0..30 {
            for t in 0..50 {
                assert!(
                    (preds[(i, t)] - fit.fitted[(i, t)]).abs() < 1e-8,
                    "Prediction on training data should match fitted"
                );
            }
        }
    }

    // ----- FANOVA tests -----

    // Group 1 adds a linear effect 0.5·t; the permutation test should detect it.
    #[test]
    fn test_fanova_two_groups() {
        let n = 40;
        let m = 50;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        let mut groups = vec![0usize; n];
        for i in 0..n {
            groups[i] = if i < n / 2 { 0 } else { 1 };
            for j in 0..m {
                let base = (2.0 * PI * t[j]).sin();
                let effect = if groups[i] == 1 { 0.5 * t[j] } else { 0.0 };
                data[(i, j)] = base + effect + 0.01 * (i as f64 * 0.1).sin();
            }
        }

        let result = fanova(&data, &groups, 200);
        assert!(result.is_some());
        let res = result.unwrap();
        assert_eq!(res.n_groups, 2);
        assert_eq!(res.group_means.shape(), (2, m));
        assert_eq!(res.f_statistic_t.len(), m);
        assert!(res.p_value >= 0.0 && res.p_value <= 1.0);
        // With a real group effect, p should be small
        assert!(
            res.p_value < 0.1,
            "Should detect group effect, got p={}",
            res.p_value
        );
    }

    // Both groups share the same generating curve. The p > 0.05 threshold depends on
    // the fixed LCG seed used inside `fanova`.
    #[test]
    fn test_fanova_no_effect() {
        let n = 40;
        let m = 50;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        let mut groups = vec![0usize; n];
        for i in 0..n {
            groups[i] = if i < n / 2 { 0 } else { 1 };
            for j in 0..m {
                // Same distribution for both groups
                data[(i, j)] =
                    (2.0 * PI * t[j]).sin() + 0.1 * ((i * 7 + j * 3) % 100) as f64 / 100.0;
            }
        }

        let result = fanova(&data, &groups, 200);
        assert!(result.is_some());
        let res = result.unwrap();
        // Without group effect, p should be large
        assert!(
            res.p_value > 0.05,
            "Should not detect effect, got p={}",
            res.p_value
        );
    }

    // Three interleaved groups: checks group bookkeeping with more than two levels.
    #[test]
    fn test_fanova_three_groups() {
        let n = 30;
        let m = 50;
        let t = uniform_grid(m);

        let mut data = FdMatrix::zeros(n, m);
        let mut groups = vec![0usize; n];
        for i in 0..n {
            groups[i] = i % 3;
            for j in 0..m {
                let effect = match groups[i] {
                    0 => 0.0,
                    1 => 0.5 * t[j],
                    _ => -0.3 * (2.0 * PI * t[j]).cos(),
                };
                data[(i, j)] = (2.0 * PI * t[j]).sin() + effect + 0.01 * (i as f64 * 0.1).sin();
            }
        }

        let result = fanova(&data, &groups, 200);
        assert!(result.is_some());
        let res = result.unwrap();
        assert_eq!(res.n_groups, 3);
    }

    // A single group and a length mismatch are both rejected.
    #[test]
    fn test_fanova_invalid_input() {
        let data = FdMatrix::zeros(10, 50);
        let groups = vec![0; 10]; // Only one group
        assert!(fanova(&data, &groups, 100).is_none());

        let groups = vec![0; 5]; // Wrong length
        assert!(fanova(&data, &groups, 100).is_none());
    }
}