//! Function-on-scalar regression and functional ANOVA.
//!
//! Predicts a **functional response** from scalar/categorical predictors:
//! ```text
//! X_i(t) = μ(t) + Σⱼ βⱼ(t) · z_ij + ε_i(t)
//! ```
//!
//! # Methods
//!
//! - [`fosr`]: Penalized function-on-scalar regression (pointwise OLS + smoothing)
//! - [`fanova`]: Functional ANOVA with permutation-based global test
//! - [`predict_fosr`]: Predict new curves from fitted model
use crate::error::FdarError;
use crate::iter_maybe_parallel;
use crate::matrix::FdMatrix;
use crate::regression::fdata_to_pc_1d;
#[cfg(feature = "parallel")]
use rayon::iter::ParallelIterator;
20
// ---------------------------------------------------------------------------
// Linear algebra helpers (self-contained)
// ---------------------------------------------------------------------------

/// Cholesky factorization A = LL' of a symmetric positive-definite matrix.
///
/// `a` is p×p flat row-major; returns the lower-triangular factor L in the
/// same layout, or `None` if a pivot falls at or below the 1e-12 tolerance
/// (singular or not positive definite).
fn cholesky_factor(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut fac = vec![0.0; p * p];
    for col in 0..p {
        // Pivot: a_jj minus the squares already placed in this row of L.
        let mut pivot = a[col * p + col];
        for k in 0..col {
            let v = fac[col * p + k];
            pivot -= v * v;
        }
        if pivot <= 1e-12 {
            return None;
        }
        let diag = pivot.sqrt();
        fac[col * p + col] = diag;
        // Fill the sub-diagonal entries of this column.
        for row in (col + 1)..p {
            let mut dot = a[row * p + col];
            for k in 0..col {
                dot -= fac[row * p + k] * fac[col * p + k];
            }
            fac[row * p + col] = dot / diag;
        }
    }
    Some(fac)
}
47
/// Solve LL'x = b: forward substitution (Lz = b) followed by back
/// substitution (L'x = z). `l` is the lower-triangular Cholesky factor.
fn cholesky_forward_back(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut x = b.to_vec();
    // Forward pass against L.
    for row in 0..p {
        for col in 0..row {
            x[row] -= l[row * p + col] * x[col];
        }
        x[row] /= l[row * p + row];
    }
    // Backward pass against L' (transpose indexing).
    for row in (0..p).rev() {
        for below in (row + 1)..p {
            x[row] -= l[below * p + row] * x[below];
        }
        x[row] /= l[row * p + row];
    }
    x
}
65
66/// Compute X'X (symmetric, p×p stored flat row-major).
67pub(crate) fn compute_xtx(x: &FdMatrix) -> Vec<f64> {
68    let (n, p) = x.shape();
69    let mut xtx = vec![0.0; p * p];
70    for k in 0..p {
71        for j in k..p {
72            let mut s = 0.0;
73            for i in 0..n {
74                s += x[(i, k)] * x[(i, j)];
75            }
76            xtx[k * p + j] = s;
77            xtx[j * p + k] = s;
78        }
79    }
80    xtx
81}
82
// ---------------------------------------------------------------------------
// Result types
// ---------------------------------------------------------------------------

/// Result of penalized function-on-scalar regression, produced by [`fosr`].
///
/// Dimensions: n = observations, m = grid points, p = scalar predictors
/// (the intercept is reported separately and excluded from `beta`).
#[derive(Debug, Clone, PartialEq)]
pub struct FosrResult {
    /// Intercept function μ(t) (length m)
    pub intercept: Vec<f64>,
    /// Coefficient functions β_j(t), one per predictor (p × m matrix, row j = βⱼ(t))
    pub beta: FdMatrix,
    /// Fitted functional values (n × m matrix)
    pub fitted: FdMatrix,
    /// Residual functions (n × m matrix)
    pub residuals: FdMatrix,
    /// Pointwise R² across the domain (length m)
    pub r_squared_t: Vec<f64>,
    /// Global R² (average of the pointwise R² values over the grid)
    pub r_squared: f64,
    /// Pointwise standard errors for each βⱼ(t) (p × m matrix)
    pub beta_se: FdMatrix,
    /// Smoothing parameter λ used (the GCV-selected value when λ < 0 was requested)
    pub lambda: f64,
    /// GCV value of the fit at the chosen λ
    pub gcv: f64,
}
109
/// Result of FPC-based function-on-scalar regression, produced by [`fosr_fpc`].
///
/// Dimensions: n = observations, m = grid points, p = scalar predictors,
/// k = FPC components actually used.
#[derive(Debug, Clone, PartialEq)]
pub struct FosrFpcResult {
    /// Intercept function μ(t) (length m)
    pub intercept: Vec<f64>,
    /// Coefficient functions β_j(t), one per predictor (p × m matrix, row j = βⱼ(t))
    pub beta: FdMatrix,
    /// Fitted functional values (n × m matrix)
    pub fitted: FdMatrix,
    /// Residual functions (n × m matrix)
    pub residuals: FdMatrix,
    /// Pointwise R² across the domain (length m)
    pub r_squared_t: Vec<f64>,
    /// Global R² (average of the pointwise R² values over the grid)
    pub r_squared: f64,
    /// FPC-space regression coefficients gamma\[j\]\[k\] (one `Vec<f64>` per predictor)
    pub beta_scores: Vec<Vec<f64>>,
    /// Number of FPC components used (may be less than requested)
    pub ncomp: usize,
}
130
/// Result of functional ANOVA, produced by [`fanova`].
///
/// Dimensions: k = groups, m = grid points.
#[derive(Debug, Clone, PartialEq)]
pub struct FanovaResult {
    /// Group mean functions (k × m matrix, row g = mean curve of group g)
    pub group_means: FdMatrix,
    /// Overall mean function (length m)
    pub overall_mean: Vec<f64>,
    /// Pointwise F-statistic across the domain (length m)
    pub f_statistic_t: Vec<f64>,
    /// Global test statistic (integrated, i.e. grid-averaged, F)
    pub global_statistic: f64,
    /// P-value from the permutation test (add-one estimator, never exactly 0)
    pub p_value: f64,
    /// Number of permutations performed
    pub n_perm: usize,
    /// Number of distinct groups
    pub n_groups: usize,
    /// Group labels (sorted unique values)
    pub group_labels: Vec<usize>,
}
151
// ---------------------------------------------------------------------------
// Shared helpers
// ---------------------------------------------------------------------------

/// Build the second-order difference penalty matrix D'D (m×m, flat row-major).
///
/// D is the (m−2)×m second-difference operator, so D'D is a symmetric banded
/// matrix. For m < 3 no second difference exists and the zero matrix is
/// returned (no penalty). (Doc previously said "p×p", which did not match the
/// `m` parameter.)
pub(crate) fn penalty_matrix(m: usize) -> Vec<f64> {
    if m < 3 {
        return vec![0.0; m * m];
    }
    let mut dtd = vec![0.0; m * m];
    for i in 0..m - 2 {
        // Row i of D is [0..0, 1, -2, 1, 0..0] at columns i, i+1, i+2;
        // accumulate its outer product into D'D.
        let coeffs = [(i, 1.0), (i + 1, -2.0), (i + 2, 1.0)];
        for &(r, cr) in &coeffs {
            for &(c, cc) in &coeffs {
                dtd[r * m + c] += cr * cc;
            }
        }
    }
    dtd
}
175
176/// Solve (A + λP)x = b for each column of B (pointwise regression at each t).
177/// A is X'X (p×p), B is X'Y (p×m), P is penalty matrix (p×p).
178/// Returns coefficient matrix (p×m).
179fn penalized_solve(
180    xtx: &[f64],
181    xty: &FdMatrix,
182    penalty: &[f64],
183    lambda: f64,
184) -> Result<FdMatrix, FdarError> {
185    let p = xty.nrows();
186    let m = xty.ncols();
187
188    // Build (X'X + λP)
189    let mut a = vec![0.0; p * p];
190    for i in 0..p * p {
191        a[i] = xtx[i] + lambda * penalty[i];
192    }
193
194    // Cholesky factor
195    let l = cholesky_factor(&a, p).ok_or_else(|| FdarError::ComputationFailed {
196        operation: "penalized_solve",
197        detail: format!(
198            "Cholesky factorization of (X'X + {lambda:.4}*P) failed; matrix is singular or near-singular"
199        ),
200    })?;
201
202    // Solve for each grid point
203    let mut beta = FdMatrix::zeros(p, m);
204    for t in 0..m {
205        let b: Vec<f64> = (0..p).map(|j| xty[(j, t)]).collect();
206        let x = cholesky_forward_back(&l, &b, p);
207        for j in 0..p {
208            beta[(j, t)] = x[j];
209        }
210    }
211    Ok(beta)
212}
213
214/// Compute pointwise R² at each grid point.
215pub(crate) fn pointwise_r_squared(data: &FdMatrix, fitted: &FdMatrix) -> Vec<f64> {
216    let (n, m) = data.shape();
217    (0..m)
218        .map(|t| {
219            let mean_t: f64 = (0..n).map(|i| data[(i, t)]).sum::<f64>() / n as f64;
220            let ss_tot: f64 = (0..n).map(|i| (data[(i, t)] - mean_t).powi(2)).sum();
221            let ss_res: f64 = (0..n)
222                .map(|i| (data[(i, t)] - fitted[(i, t)]).powi(2))
223                .sum();
224            if ss_tot > 1e-15 {
225                1.0 - ss_res / ss_tot
226            } else {
227                0.0
228            }
229        })
230        .collect()
231}
232
233/// GCV for penalized regression: (1/nm) Σ_{i,t} (r_{it} / (1 - tr(H)/n))²
234fn compute_fosr_gcv(residuals: &FdMatrix, trace_h: f64) -> f64 {
235    let (n, m) = residuals.shape();
236    let denom = (1.0 - trace_h / n as f64).max(1e-10);
237    let ss_res: f64 = (0..n)
238        .flat_map(|i| (0..m).map(move |t| residuals[(i, t)].powi(2)))
239        .sum();
240    ss_res / (n as f64 * m as f64 * denom * denom)
241}
242
243// ---------------------------------------------------------------------------
244// fosr: Function-on-scalar regression
245// ---------------------------------------------------------------------------
246
247/// Penalized function-on-scalar regression.
248///
249/// Fits pointwise OLS at each grid point t, then smooths the coefficient
250/// functions β_j(t) using a second-order roughness penalty.
251///
252/// # Arguments
253/// * `data` - Functional response matrix (n × m)
254/// * `predictors` - Scalar predictor matrix (n × p)
255/// * `lambda` - Smoothing parameter (0 for no smoothing, negative for GCV selection)
256///
257/// # Returns
258/// [`FosrResult`] with coefficient functions, fitted values, and diagnostics
259///
260/// # Errors
261///
262/// Returns [`FdarError::InvalidDimension`] if `data` has zero columns,
263/// `predictors` row count does not match `data`, or `n < p + 2`.
264/// Returns [`FdarError::ComputationFailed`] if the penalized Cholesky solve
265/// is singular.
266/// Build design matrix with intercept: \[1, z_1, ..., z_p\].
267pub(crate) fn build_fosr_design(predictors: &FdMatrix, n: usize) -> FdMatrix {
268    let p = predictors.ncols();
269    let p_total = p + 1;
270    let mut design = FdMatrix::zeros(n, p_total);
271    for i in 0..n {
272        design[(i, 0)] = 1.0;
273        for j in 0..p {
274            design[(i, 1 + j)] = predictors[(i, j)];
275        }
276    }
277    design
278}
279
280/// Compute X'Y (p_total × m).
281pub(crate) fn compute_xty_matrix(design: &FdMatrix, data: &FdMatrix) -> FdMatrix {
282    let (n, m) = data.shape();
283    let p_total = design.ncols();
284    let mut xty = FdMatrix::zeros(p_total, m);
285    for j in 0..p_total {
286        for t in 0..m {
287            let mut s = 0.0;
288            for i in 0..n {
289                s += design[(i, j)] * data[(i, t)];
290            }
291            xty[(j, t)] = s;
292        }
293    }
294    xty
295}
296
297/// Extract rows 1..p+1 from a (p+1)×m matrix, dropping the intercept row.
298fn drop_intercept_rows(full: &FdMatrix, p: usize, m: usize) -> FdMatrix {
299    let mut out = FdMatrix::zeros(p, m);
300    for j in 0..p {
301        for t in 0..m {
302            out[(j, t)] = full[(j + 1, t)];
303        }
304    }
305    out
306}
307
308#[must_use = "expensive computation whose result should not be discarded"]
309pub fn fosr(data: &FdMatrix, predictors: &FdMatrix, lambda: f64) -> Result<FosrResult, FdarError> {
310    let (n, m) = data.shape();
311    let p = predictors.ncols();
312    if m == 0 {
313        return Err(FdarError::InvalidDimension {
314            parameter: "data",
315            expected: "at least 1 column (grid points)".to_string(),
316            actual: "0 columns".to_string(),
317        });
318    }
319    if predictors.nrows() != n {
320        return Err(FdarError::InvalidDimension {
321            parameter: "predictors",
322            expected: format!("{n} rows (matching data)"),
323            actual: format!("{} rows", predictors.nrows()),
324        });
325    }
326    if n < p + 2 {
327        return Err(FdarError::InvalidDimension {
328            parameter: "data",
329            expected: format!("at least {} observations (p + 2)", p + 2),
330            actual: format!("{n} observations"),
331        });
332    }
333
334    let design = build_fosr_design(predictors, n);
335    let p_total = design.ncols();
336    let xtx = compute_xtx(&design);
337    let xty = compute_xty_matrix(&design, data);
338    let penalty = penalty_matrix(p_total);
339
340    let lambda = if lambda < 0.0 {
341        select_lambda_gcv(&xtx, &xty, &penalty, data, &design)
342    } else {
343        lambda
344    };
345
346    let beta = penalized_solve(&xtx, &xty, &penalty, lambda)?;
347    let (fitted, residuals) = compute_fosr_fitted(&design, &beta, data);
348
349    let r_squared_t = pointwise_r_squared(data, &fitted);
350    let r_squared = r_squared_t.iter().sum::<f64>() / m as f64;
351    let beta_se = compute_beta_se(&xtx, &penalty, lambda, &residuals, p_total, n);
352    let trace_h = compute_trace_hat(&xtx, &penalty, lambda, p_total, n);
353    let gcv = compute_fosr_gcv(&residuals, trace_h);
354
355    let intercept: Vec<f64> = (0..m).map(|t| beta[(0, t)]).collect();
356
357    Ok(FosrResult {
358        intercept,
359        beta: drop_intercept_rows(&beta, p, m),
360        fitted,
361        residuals,
362        r_squared_t,
363        r_squared,
364        beta_se: drop_intercept_rows(&beta_se, p, m),
365        lambda,
366        gcv,
367    })
368}
369
370/// Compute fitted values Ŷ = X β and residuals.
371fn compute_fosr_fitted(
372    design: &FdMatrix,
373    beta: &FdMatrix,
374    data: &FdMatrix,
375) -> (FdMatrix, FdMatrix) {
376    let (n, m) = data.shape();
377    let p_total = design.ncols();
378    let rows: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
379        .map(|i| {
380            let mut fitted_row = vec![0.0; m];
381            let mut resid_row = vec![0.0; m];
382            for t in 0..m {
383                let mut yhat = 0.0;
384                for j in 0..p_total {
385                    yhat += design[(i, j)] * beta[(j, t)];
386                }
387                fitted_row[t] = yhat;
388                resid_row[t] = data[(i, t)] - yhat;
389            }
390            (fitted_row, resid_row)
391        })
392        .collect();
393    let mut fitted = FdMatrix::zeros(n, m);
394    let mut residuals = FdMatrix::zeros(n, m);
395    for (i, (fr, rr)) in rows.into_iter().enumerate() {
396        for t in 0..m {
397            fitted[(i, t)] = fr[t];
398            residuals[(i, t)] = rr[t];
399        }
400    }
401    (fitted, residuals)
402}
403
404/// Select smoothing parameter λ via GCV on a grid.
405fn select_lambda_gcv(
406    xtx: &[f64],
407    xty: &FdMatrix,
408    penalty: &[f64],
409    data: &FdMatrix,
410    design: &FdMatrix,
411) -> f64 {
412    let lambdas = [0.0, 1e-6, 1e-4, 1e-2, 0.1, 1.0, 10.0, 100.0, 1000.0];
413    let p_total = design.ncols();
414    let n = design.nrows();
415
416    let mut best_lambda = 0.0;
417    let mut best_gcv = f64::INFINITY;
418
419    for &lam in &lambdas {
420        let beta = match penalized_solve(xtx, xty, penalty, lam) {
421            Ok(b) => b,
422            Err(_) => continue,
423        };
424        let (_, residuals) = compute_fosr_fitted(design, &beta, data);
425        let trace_h = compute_trace_hat(xtx, penalty, lam, p_total, n);
426        let gcv = compute_fosr_gcv(&residuals, trace_h);
427        if gcv < best_gcv {
428            best_gcv = gcv;
429            best_lambda = lam;
430        }
431    }
432    best_lambda
433}
434
435/// Compute trace of hat matrix: tr(H) = tr(X (X'X + λP)^{-1} X') = Σ_j h_jj.
436fn compute_trace_hat(xtx: &[f64], penalty: &[f64], lambda: f64, p: usize, n: usize) -> f64 {
437    let mut a = vec![0.0; p * p];
438    for i in 0..p * p {
439        a[i] = xtx[i] + lambda * penalty[i];
440    }
441    // tr(H) = tr(X A^{-1} X') = Σ_{j=0..p} a^{-1}_{jj} * xtx_{jj}
442    // More precisely: tr(X (X'X+λP)^{-1} X') = tr((X'X+λP)^{-1} X'X)
443    let l = match cholesky_factor(&a, p) {
444        Some(l) => l,
445        None => return p as f64, // fallback
446    };
447
448    // Compute A^{-1} X'X via solving A Z = X'X column by column, then trace
449    let mut trace = 0.0;
450    for j in 0..p {
451        let col: Vec<f64> = (0..p).map(|i| xtx[i * p + j]).collect();
452        let z = cholesky_forward_back(&l, &col, p);
453        trace += z[j]; // diagonal element of A^{-1} X'X
454    }
455    trace.min(n as f64)
456}
457
458/// Compute pointwise standard errors for β(t).
459fn compute_beta_se(
460    xtx: &[f64],
461    penalty: &[f64],
462    lambda: f64,
463    residuals: &FdMatrix,
464    p: usize,
465    n: usize,
466) -> FdMatrix {
467    let m = residuals.ncols();
468    let mut a = vec![0.0; p * p];
469    for i in 0..p * p {
470        a[i] = xtx[i] + lambda * penalty[i];
471    }
472    let l = match cholesky_factor(&a, p) {
473        Some(l) => l,
474        None => return FdMatrix::zeros(p, m),
475    };
476
477    // Diagonal of A^{-1}
478    let a_inv_diag: Vec<f64> = (0..p)
479        .map(|j| {
480            let mut ej = vec![0.0; p];
481            ej[j] = 1.0;
482            let v = cholesky_forward_back(&l, &ej, p);
483            v[j]
484        })
485        .collect();
486
487    let df = (n - p).max(1) as f64;
488    let mut se = FdMatrix::zeros(p, m);
489    for t in 0..m {
490        let sigma2_t: f64 = (0..n).map(|i| residuals[(i, t)].powi(2)).sum::<f64>() / df;
491        for j in 0..p {
492            se[(j, t)] = (sigma2_t * a_inv_diag[j]).max(0.0).sqrt();
493        }
494    }
495    se
496}
497
498// ---------------------------------------------------------------------------
499// fosr_fpc: FPC-based function-on-scalar regression (matches R's fda.usc approach)
500// ---------------------------------------------------------------------------
501
502/// OLS regression of each FPC score on the design matrix.
503///
504/// Returns gamma_all\[comp\]\[coef\] = (X'X)^{-1} X' scores\[:,comp\].
505fn regress_scores_on_design(
506    design: &FdMatrix,
507    scores: &FdMatrix,
508    n: usize,
509    k: usize,
510    p_total: usize,
511) -> Result<Vec<Vec<f64>>, FdarError> {
512    let xtx = compute_xtx(design);
513    let l = cholesky_factor(&xtx, p_total).ok_or_else(|| FdarError::ComputationFailed {
514        operation: "regress_scores_on_design",
515        detail: "Cholesky factorization of X'X failed; design matrix is rank-deficient".to_string(),
516    })?;
517
518    let gamma_all: Vec<Vec<f64>> = (0..k)
519        .map(|comp| {
520            let mut xts = vec![0.0; p_total];
521            for j in 0..p_total {
522                for i in 0..n {
523                    xts[j] += design[(i, j)] * scores[(i, comp)];
524                }
525            }
526            cholesky_forward_back(&l, &xts, p_total)
527        })
528        .collect();
529    Ok(gamma_all)
530}
531
532/// Reconstruct β_j(t) = Σ_k gamma\[comp\]\[1+j\] · φ_k(t) for each predictor j.
533fn reconstruct_beta_fpc(
534    gamma_all: &[Vec<f64>],
535    rotation: &FdMatrix,
536    p: usize,
537    k: usize,
538    m: usize,
539) -> FdMatrix {
540    let mut beta = FdMatrix::zeros(p, m);
541    for j in 0..p {
542        for t in 0..m {
543            let mut val = 0.0;
544            for comp in 0..k {
545                val += gamma_all[comp][1 + j] * rotation[(t, comp)];
546            }
547            beta[(j, t)] = val;
548        }
549    }
550    beta
551}
552
553/// Compute intercept function: μ(t) + Σ_k γ_intercept\[k\] · φ_k(t).
554fn compute_intercept_fpc(
555    mean: &[f64],
556    gamma_all: &[Vec<f64>],
557    rotation: &FdMatrix,
558    k: usize,
559    m: usize,
560) -> Vec<f64> {
561    let mut intercept = mean.to_vec();
562    for t in 0..m {
563        for comp in 0..k {
564            intercept[t] += gamma_all[comp][0] * rotation[(t, comp)];
565        }
566    }
567    intercept
568}
569
/// Rescale FPC-space regression coefficients to L²-normalized beta scores,
/// using sqrt(h) with grid step h = 1/(m−1) (h = 1 for a single-point grid).
fn extract_beta_scores(gamma_all: &[Vec<f64>], p: usize, k: usize, m: usize) -> Vec<Vec<f64>> {
    let step = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
    let scale = step.sqrt();
    let mut scores = Vec::with_capacity(p);
    for j in 0..p {
        let row: Vec<f64> = (0..k).map(|comp| gamma_all[comp][1 + j] * scale).collect();
        scores.push(row);
    }
    scores
}
582
583/// FPC-based function-on-scalar regression.
584///
585/// Reduces the functional response to FPC scores, regresses each score on the
586/// scalar predictors via OLS, then reconstructs β(t) from the loadings.
587/// This matches R's `fdata2pc` + `lm(scores ~ x)` approach.
588///
589/// # Arguments
590/// * `data` - Functional response matrix (n × m)
591/// * `predictors` - Scalar predictor matrix (n × p)
592/// * `ncomp` - Number of FPC components to use
593///
594/// # Errors
595///
596/// Returns [`FdarError::InvalidDimension`] if `data` has zero columns,
597/// `predictors` row count does not match `data`, or `n < p + 2`.
598/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
599/// Returns [`FdarError::ComputationFailed`] if FPCA fails or the OLS
600/// Cholesky factorization of X'X is singular.
601#[must_use = "expensive computation whose result should not be discarded"]
602pub fn fosr_fpc(
603    data: &FdMatrix,
604    predictors: &FdMatrix,
605    ncomp: usize,
606) -> Result<FosrFpcResult, FdarError> {
607    let (n, m) = data.shape();
608    let p = predictors.ncols();
609    if m == 0 {
610        return Err(FdarError::InvalidDimension {
611            parameter: "data",
612            expected: "at least 1 column (grid points)".to_string(),
613            actual: "0 columns".to_string(),
614        });
615    }
616    if predictors.nrows() != n {
617        return Err(FdarError::InvalidDimension {
618            parameter: "predictors",
619            expected: format!("{n} rows (matching data)"),
620            actual: format!("{} rows", predictors.nrows()),
621        });
622    }
623    if n < p + 2 {
624        return Err(FdarError::InvalidDimension {
625            parameter: "data",
626            expected: format!("at least {} observations (p + 2)", p + 2),
627            actual: format!("{n} observations"),
628        });
629    }
630    if ncomp == 0 {
631        return Err(FdarError::InvalidParameter {
632            parameter: "ncomp",
633            message: "number of FPC components must be at least 1".to_string(),
634        });
635    }
636
637    let fpca = fdata_to_pc_1d(data, ncomp)?;
638    let k = fpca.scores.ncols();
639    let p_total = p + 1;
640    let design = build_fosr_design(predictors, n);
641
642    let gamma_all = regress_scores_on_design(&design, &fpca.scores, n, k, p_total)?;
643    let beta = reconstruct_beta_fpc(&gamma_all, &fpca.rotation, p, k, m);
644    let intercept = compute_intercept_fpc(&fpca.mean, &gamma_all, &fpca.rotation, k, m);
645
646    let (fitted, residuals) = compute_fosr_fpc_fitted(data, &intercept, &beta, predictors);
647    let r_squared_t = pointwise_r_squared(data, &fitted);
648    let r_squared = r_squared_t.iter().sum::<f64>() / m as f64;
649    let beta_scores = extract_beta_scores(&gamma_all, p, k, m);
650
651    Ok(FosrFpcResult {
652        intercept,
653        beta,
654        fitted,
655        residuals,
656        r_squared_t,
657        r_squared,
658        beta_scores,
659        ncomp: k,
660    })
661}
662
663/// Compute fitted values and residuals for FPC-based FOSR.
664fn compute_fosr_fpc_fitted(
665    data: &FdMatrix,
666    intercept: &[f64],
667    beta: &FdMatrix,
668    predictors: &FdMatrix,
669) -> (FdMatrix, FdMatrix) {
670    let (n, m) = data.shape();
671    let p = predictors.ncols();
672    let mut fitted = FdMatrix::zeros(n, m);
673    let mut residuals = FdMatrix::zeros(n, m);
674    for i in 0..n {
675        for t in 0..m {
676            let mut yhat = intercept[t];
677            for j in 0..p {
678                yhat += predictors[(i, j)] * beta[(j, t)];
679            }
680            fitted[(i, t)] = yhat;
681            residuals[(i, t)] = data[(i, t)] - yhat;
682        }
683    }
684    (fitted, residuals)
685}
686
687/// Predict new functional responses from a fitted FOSR model.
688///
689/// # Arguments
690/// * `result` - Fitted [`FosrResult`]
691/// * `new_predictors` - New scalar predictors (n_new × p)
692#[must_use = "prediction result should not be discarded"]
693pub fn predict_fosr(result: &FosrResult, new_predictors: &FdMatrix) -> FdMatrix {
694    let n_new = new_predictors.nrows();
695    let m = result.intercept.len();
696    let p = result.beta.nrows();
697
698    let mut predicted = FdMatrix::zeros(n_new, m);
699    for i in 0..n_new {
700        for t in 0..m {
701            let mut yhat = result.intercept[t];
702            for j in 0..p {
703                yhat += new_predictors[(i, j)] * result.beta[(j, t)];
704            }
705            predicted[(i, t)] = yhat;
706        }
707    }
708    predicted
709}
710
711// ---------------------------------------------------------------------------
712// fanova: Functional ANOVA
713// ---------------------------------------------------------------------------
714
715/// Compute group means and overall mean.
716fn compute_group_means(
717    data: &FdMatrix,
718    groups: &[usize],
719    labels: &[usize],
720) -> (FdMatrix, Vec<f64>) {
721    let (n, m) = data.shape();
722    let k = labels.len();
723    let mut group_means = FdMatrix::zeros(k, m);
724    let mut counts = vec![0usize; k];
725
726    for i in 0..n {
727        let g = labels.iter().position(|&l| l == groups[i]).unwrap_or(0);
728        counts[g] += 1;
729        for t in 0..m {
730            group_means[(g, t)] += data[(i, t)];
731        }
732    }
733    for g in 0..k {
734        if counts[g] > 0 {
735            for t in 0..m {
736                group_means[(g, t)] /= counts[g] as f64;
737            }
738        }
739    }
740
741    let overall_mean: Vec<f64> = (0..m)
742        .map(|t| (0..n).map(|i| data[(i, t)]).sum::<f64>() / n as f64)
743        .collect();
744
745    (group_means, overall_mean)
746}
747
748/// Compute pointwise F-statistic.
749fn pointwise_f_statistic(
750    data: &FdMatrix,
751    groups: &[usize],
752    labels: &[usize],
753    group_means: &FdMatrix,
754    overall_mean: &[f64],
755) -> Vec<f64> {
756    let (n, m) = data.shape();
757    let k = labels.len();
758    let mut counts = vec![0usize; k];
759    for &g in groups {
760        let idx = labels.iter().position(|&l| l == g).unwrap_or(0);
761        counts[idx] += 1;
762    }
763
764    (0..m)
765        .map(|t| {
766            let ss_between: f64 = (0..k)
767                .map(|g| counts[g] as f64 * (group_means[(g, t)] - overall_mean[t]).powi(2))
768                .sum();
769            let ss_within: f64 = (0..n)
770                .map(|i| {
771                    let g = labels.iter().position(|&l| l == groups[i]).unwrap_or(0);
772                    (data[(i, t)] - group_means[(g, t)]).powi(2)
773                })
774                .sum();
775            let ms_between = ss_between / (k as f64 - 1.0).max(1.0);
776            let ms_within = ss_within / (n as f64 - k as f64).max(1.0);
777            if ms_within > 1e-15 {
778                ms_between / ms_within
779            } else {
780                0.0
781            }
782        })
783        .collect()
784}
785
/// Global test statistic: the mean of the pointwise F-statistic over the grid
/// (a discrete approximation of the integrated F).
fn global_f_statistic(f_t: &[f64]) -> f64 {
    let total: f64 = f_t.iter().sum();
    total / f_t.len() as f64
}
790
791/// Functional ANOVA: test whether groups have different mean curves.
792///
793/// Uses a permutation-based global test with the integrated F-statistic.
794///
795/// # Arguments
796/// * `data` - Functional response matrix (n × m)
797/// * `groups` - Group labels for each observation (length n, integer-coded)
798/// * `n_perm` - Number of permutations for the global test
799///
800/// # Returns
801/// [`FanovaResult`] with group means, F-statistics, and permutation p-value
802///
803/// # Errors
804///
805/// Returns [`FdarError::InvalidDimension`] if `data` has zero columns,
806/// `groups.len()` does not match the number of rows in `data`, or `n < 3`.
807/// Returns [`FdarError::InvalidParameter`] if fewer than 2 distinct groups
808/// are present.
809#[must_use = "expensive computation whose result should not be discarded"]
810pub fn fanova(data: &FdMatrix, groups: &[usize], n_perm: usize) -> Result<FanovaResult, FdarError> {
811    let (n, m) = data.shape();
812    if m == 0 {
813        return Err(FdarError::InvalidDimension {
814            parameter: "data",
815            expected: "at least 1 column (grid points)".to_string(),
816            actual: "0 columns".to_string(),
817        });
818    }
819    if groups.len() != n {
820        return Err(FdarError::InvalidDimension {
821            parameter: "groups",
822            expected: format!("{n} elements (matching data rows)"),
823            actual: format!("{} elements", groups.len()),
824        });
825    }
826    if n < 3 {
827        return Err(FdarError::InvalidDimension {
828            parameter: "data",
829            expected: "at least 3 observations".to_string(),
830            actual: format!("{n} observations"),
831        });
832    }
833
834    let mut labels: Vec<usize> = groups.to_vec();
835    labels.sort();
836    labels.dedup();
837    let n_groups = labels.len();
838    if n_groups < 2 {
839        return Err(FdarError::InvalidParameter {
840            parameter: "groups",
841            message: format!("at least 2 distinct groups required, but only {n_groups} found"),
842        });
843    }
844
845    let (group_means, overall_mean) = compute_group_means(data, groups, &labels);
846    let f_t = pointwise_f_statistic(data, groups, &labels, &group_means, &overall_mean);
847    let observed_stat = global_f_statistic(&f_t);
848
849    // Permutation test
850    let n_perm = n_perm.max(1);
851    let mut n_ge = 0usize;
852    let mut perm_groups = groups.to_vec();
853
854    // Simple LCG for reproducibility without requiring rand
855    let mut rng_state: u64 = 42;
856    for _ in 0..n_perm {
857        // Fisher-Yates shuffle with LCG
858        for i in (1..n).rev() {
859            rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
860            let j = (rng_state >> 33) as usize % (i + 1);
861            perm_groups.swap(i, j);
862        }
863
864        let (perm_means, perm_overall) = compute_group_means(data, &perm_groups, &labels);
865        let perm_f = pointwise_f_statistic(data, &perm_groups, &labels, &perm_means, &perm_overall);
866        let perm_stat = global_f_statistic(&perm_f);
867        if perm_stat >= observed_stat {
868            n_ge += 1;
869        }
870    }
871
872    let p_value = (n_ge as f64 + 1.0) / (n_perm as f64 + 1.0);
873
874    Ok(FanovaResult {
875        group_means,
876        overall_mean,
877        f_statistic_t: f_t,
878        global_statistic: observed_stat,
879        p_value,
880        n_perm,
881        n_groups,
882        group_labels: labels,
883    })
884}
885
impl FosrResult {
    /// Predict functional responses for new predictors. Delegates to [`predict_fosr`].
    ///
    /// `new_predictors` must have the same number of columns (p) the model
    /// was fitted with; returns an n_new × m matrix of predicted curves.
    pub fn predict(&self, new_predictors: &FdMatrix) -> FdMatrix {
        predict_fosr(self, new_predictors)
    }
}
892
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
896
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `m` equally spaced points covering [0, 1] (both endpoints included).
    fn uniform_grid(m: usize) -> Vec<f64> {
        let mut grid = Vec::with_capacity(m);
        for j in 0..m {
            grid.push(j as f64 / (m - 1) as f64);
        }
        grid
    }

    /// Synthetic FOSR data: `n` curves over `m` grid points driven by two
    /// scalar predictors (continuous "age", binary "group"), plus a small
    /// deterministic pseudo-noise term so every run is reproducible.
    fn generate_fosr_data(n: usize, m: usize) -> (FdMatrix, FdMatrix) {
        let grid = uniform_grid(m);
        let mut response = FdMatrix::zeros(n, m);
        let mut design = FdMatrix::zeros(n, 2);

        for row in 0..n {
            let age = (row as f64) / (n as f64);
            let group = if row % 2 == 0 { 1.0 } else { 0.0 };
            design[(row, 0)] = age;
            design[(row, 1)] = group;
            // True model: μ(t) + age · β₁(t) + group · β₂(t) + noise.
            for (col, &t) in grid.iter().enumerate() {
                let mu = (2.0 * PI * t).sin();
                let beta_age = t; // linear coefficient for age
                let beta_group = (4.0 * PI * t).cos(); // oscillating coefficient for group
                response[(row, col)] = mu
                    + age * beta_age
                    + group * beta_group
                    + 0.05 * ((row * 13 + col * 7) % 100) as f64 / 100.0;
            }
        }
        (response, design)
    }

    // ----- FOSR tests -----

    #[test]
    fn test_fosr_basic() {
        let (y, z) = generate_fosr_data(30, 50);
        let result = fosr(&y, &z, 0.0);
        assert!(result.is_ok());
        // All output containers must match the input dimensions.
        let fit = result.unwrap();
        assert_eq!(fit.intercept.len(), 50);
        assert_eq!(fit.beta.shape(), (2, 50));
        assert_eq!(fit.fitted.shape(), (30, 50));
        assert_eq!(fit.residuals.shape(), (30, 50));
        assert!(fit.r_squared >= 0.0);
    }

    #[test]
    fn test_fosr_with_penalty() {
        let (y, z) = generate_fosr_data(30, 50);
        // Unpenalized and penalized fits should both produce valid results.
        for lambda in [0.0, 1.0] {
            let fit = fosr(&y, &z, lambda).unwrap();
            assert_eq!(fit.beta.shape(), (2, 50));
        }
    }

    #[test]
    fn test_fosr_auto_lambda() {
        // A negative lambda requests automatic selection; the selected value
        // must come back non-negative.
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, -1.0).unwrap();
        assert!(fit.lambda >= 0.0);
    }

    #[test]
    fn test_fosr_fitted_plus_residuals_equals_y() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        // The additive decomposition y = ŷ + r must hold pointwise.
        for row in 0..30 {
            for col in 0..50 {
                let sum = fit.fitted[(row, col)] + fit.residuals[(row, col)];
                assert!(
                    (sum - y[(row, col)]).abs() < 1e-10,
                    "ŷ + r should equal y at ({}, {})",
                    row,
                    col
                );
            }
        }
    }

    #[test]
    fn test_fosr_pointwise_r_squared_valid() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        // Each pointwise R²(t) must lie (approximately) inside [0, 1].
        for &r2 in fit.r_squared_t.iter() {
            assert!(
                r2 >= -0.01 && r2 <= 1.0 + 1e-10,
                "R²(t) out of range: {}",
                r2
            );
        }
    }

    #[test]
    fn test_fosr_se_positive() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        for coef in 0..2 {
            for col in 0..50 {
                let se = fit.beta_se[(coef, col)];
                assert!(
                    se >= 0.0 && se.is_finite(),
                    "SE should be non-negative finite"
                );
            }
        }
    }

    #[test]
    fn test_fosr_invalid_input() {
        // Degenerate input (too few curves) must be rejected with an error.
        let y = FdMatrix::zeros(2, 50);
        let z = FdMatrix::zeros(2, 1);
        assert!(fosr(&y, &z, 0.0).is_err());
    }

    // ----- predict_fosr tests -----

    #[test]
    fn test_predict_fosr_on_training_data() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).unwrap();
        // Predicting with the training design matrix must reproduce the
        // fitted values up to numerical tolerance.
        let preds = predict_fosr(&fit, &z);
        assert_eq!(preds.shape(), (30, 50));
        for row in 0..30 {
            for col in 0..50 {
                assert!(
                    (preds[(row, col)] - fit.fitted[(row, col)]).abs() < 1e-8,
                    "Prediction on training data should match fitted"
                );
            }
        }
    }

    // ----- FANOVA tests -----

    #[test]
    fn test_fanova_two_groups() {
        let n = 40;
        let m = 50;
        let grid = uniform_grid(m);

        // First half in group 0, second half in group 1 with a mild linear
        // offset, plus a tiny deterministic per-curve perturbation.
        let groups: Vec<usize> = (0..n).map(|i| usize::from(i >= n / 2)).collect();
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for (j, &t) in grid.iter().enumerate() {
                let base = (2.0 * PI * t).sin();
                let effect = if groups[i] == 1 { 0.5 * t } else { 0.0 };
                data[(i, j)] = base + effect + 0.01 * (i as f64 * 0.1).sin();
            }
        }

        let result = fanova(&data, &groups, 200);
        assert!(result.is_ok());
        let res = result.unwrap();
        assert_eq!(res.n_groups, 2);
        assert_eq!(res.group_means.shape(), (2, m));
        assert_eq!(res.f_statistic_t.len(), m);
        assert!((0.0..=1.0).contains(&res.p_value));
        // With a real group effect, p should be small.
        assert!(
            res.p_value < 0.1,
            "Should detect group effect, got p={}",
            res.p_value
        );
    }

    #[test]
    fn test_fanova_no_effect() {
        let n = 40;
        let m = 50;
        let grid = uniform_grid(m);

        // Both groups share the same deterministic signal, so the permutation
        // test should not flag a difference.
        let groups: Vec<usize> = (0..n).map(|i| usize::from(i >= n / 2)).collect();
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for (j, &t) in grid.iter().enumerate() {
                data[(i, j)] =
                    (2.0 * PI * t).sin() + 0.1 * ((i * 7 + j * 3) % 100) as f64 / 100.0;
            }
        }

        let result = fanova(&data, &groups, 200);
        assert!(result.is_ok());
        let res = result.unwrap();
        assert!(
            res.p_value > 0.05,
            "Should not detect effect, got p={}",
            res.p_value
        );
    }

    #[test]
    fn test_fanova_three_groups() {
        let n = 30;
        let m = 50;
        let grid = uniform_grid(m);

        // Cycle labels 0, 1, 2 and give each group a distinct mean shift.
        let groups: Vec<usize> = (0..n).map(|i| i % 3).collect();
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for (j, &t) in grid.iter().enumerate() {
                let effect = match groups[i] {
                    0 => 0.0,
                    1 => 0.5 * t,
                    _ => -0.3 * (2.0 * PI * t).cos(),
                };
                data[(i, j)] = (2.0 * PI * t).sin() + effect + 0.01 * (i as f64 * 0.1).sin();
            }
        }

        let result = fanova(&data, &groups, 200);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().n_groups, 3);
    }

    #[test]
    fn test_fanova_invalid_input() {
        let data = FdMatrix::zeros(10, 50);
        // A single group cannot be tested.
        assert!(fanova(&data, &vec![0; 10], 100).is_err());
        // Label vector length must match the number of curves.
        assert!(fanova(&data, &vec![0; 5], 100).is_err());
    }
}