// fdars_core/function_on_scalar.rs
1//! Function-on-scalar regression and functional ANOVA.
2//!
3//! Predicts a **functional response** from scalar/categorical predictors:
4//! ```text
5//! X_i(t) = μ(t) + Σⱼ βⱼ(t) · z_ij + ε_i(t)
6//! ```
7//!
8//! # Methods
9//!
10//! - [`fosr`]: Penalized function-on-scalar regression (pointwise OLS + smoothing)
11//! - [`fanova`]: Functional ANOVA with permutation-based global test
12//! - [`predict_fosr`]: Predict new curves from fitted model
13
14use crate::error::FdarError;
15use crate::iter_maybe_parallel;
16use crate::linalg::{cholesky_factor, cholesky_forward_back, compute_xtx};
17use crate::matrix::FdMatrix;
18use crate::regression::fdata_to_pc_1d;
19#[cfg(feature = "parallel")]
20use rayon::iter::ParallelIterator;
21
22// ---------------------------------------------------------------------------
23// Result types
24// ---------------------------------------------------------------------------
25
/// Result of function-on-scalar regression (see [`fosr`]).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FosrResult {
    /// Intercept function μ(t) (length m)
    pub intercept: Vec<f64>,
    /// Coefficient functions β_j(t), one per predictor (p × m matrix, row j = βⱼ(t));
    /// the intercept row is excluded — see `intercept`
    pub beta: FdMatrix,
    /// Fitted functional values (n × m matrix)
    pub fitted: FdMatrix,
    /// Residual functions (n × m matrix); `fitted + residuals == data`
    pub residuals: FdMatrix,
    /// Pointwise R² across the domain (length m)
    pub r_squared_t: Vec<f64>,
    /// Global R² (mean of the pointwise R² values)
    pub r_squared: f64,
    /// Pointwise standard errors for each βⱼ(t) (p × m matrix)
    pub beta_se: FdMatrix,
    /// Smoothing parameter λ used (as supplied, or GCV-selected when a
    /// negative λ was requested)
    pub lambda: f64,
    /// GCV value of the final fit
    pub gcv: f64,
}
49
/// Result of FPC-based function-on-scalar regression (see [`fosr_fpc`]).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FosrFpcResult {
    /// Intercept function μ(t) (length m)
    pub intercept: Vec<f64>,
    /// Coefficient functions β_j(t), one per predictor (p × m matrix, row j = βⱼ(t))
    pub beta: FdMatrix,
    /// Fitted functional values (n × m matrix)
    pub fitted: FdMatrix,
    /// Residual functions (n × m matrix)
    pub residuals: FdMatrix,
    /// Pointwise R² across the domain (length m)
    pub r_squared_t: Vec<f64>,
    /// Global R² (mean of the pointwise R² values)
    pub r_squared: f64,
    /// FPC-space regression coefficients gamma\[j\]\[k\] (one `Vec<f64>` per predictor)
    pub beta_scores: Vec<Vec<f64>>,
    /// Number of FPC components actually used — taken from the FPCA result,
    /// so it may be fewer than the number requested
    pub ncomp: usize,
}
71
/// Result of functional ANOVA (see [`fanova`]).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FanovaResult {
    /// Group mean functions (k × m matrix, row g = mean curve of group g)
    pub group_means: FdMatrix,
    /// Overall mean function (length m)
    pub overall_mean: Vec<f64>,
    /// Pointwise F-statistic across the domain (length m)
    pub f_statistic_t: Vec<f64>,
    /// Global test statistic (integrated F)
    pub global_statistic: f64,
    /// P-value from permutation test (add-one corrected, so always in (0, 1])
    pub p_value: f64,
    /// Number of permutations performed
    pub n_perm: usize,
    /// Number of groups
    pub n_groups: usize,
    /// Group labels (sorted unique values); row order of `group_means`
    pub group_labels: Vec<usize>,
}
93
94// ---------------------------------------------------------------------------
95// Shared helpers
96// ---------------------------------------------------------------------------
97
/// Construct the m×m second-order difference penalty matrix D'D (flat,
/// row-major), where D is the (m-2)×m operator taking second differences.
///
/// For `m < 3` no second difference exists and the zero matrix is returned.
pub(crate) fn penalty_matrix(m: usize) -> Vec<f64> {
    let mut dtd = vec![0.0; m * m];
    if m < 3 {
        return dtd;
    }
    // Accumulate D'D = Σ_i d_i d_i', where row d_i carries the stencil
    // [1, -2, 1] at columns i, i+1, i+2.
    let stencil = [1.0, -2.0, 1.0];
    for i in 0..m - 2 {
        for (a, &ca) in stencil.iter().enumerate() {
            for (b, &cb) in stencil.iter().enumerate() {
                dtd[(i + a) * m + (i + b)] += ca * cb;
            }
        }
    }
    dtd
}
117
118/// Solve (A + λP)x = b for each column of B (pointwise regression at each t).
119/// A is X'X (p×p), B is X'Y (p×m), P is penalty matrix (p×p).
120/// Returns coefficient matrix (p×m).
121fn penalized_solve(
122    xtx: &[f64],
123    xty: &FdMatrix,
124    penalty: &[f64],
125    lambda: f64,
126) -> Result<FdMatrix, FdarError> {
127    let p = xty.nrows();
128    let m = xty.ncols();
129
130    // Build (X'X + λP)
131    let mut a = vec![0.0; p * p];
132    for i in 0..p * p {
133        a[i] = xtx[i] + lambda * penalty[i];
134    }
135
136    // Cholesky factor
137    let l = cholesky_factor(&a, p)?;
138
139    // Solve for each grid point
140    let mut beta = FdMatrix::zeros(p, m);
141    for t in 0..m {
142        let b: Vec<f64> = (0..p).map(|j| xty[(j, t)]).collect();
143        let x = cholesky_forward_back(&l, &b, p);
144        for j in 0..p {
145            beta[(j, t)] = x[j];
146        }
147    }
148    Ok(beta)
149}
150
151/// Compute pointwise R² at each grid point.
152pub(crate) fn pointwise_r_squared(data: &FdMatrix, fitted: &FdMatrix) -> Vec<f64> {
153    let (n, m) = data.shape();
154    (0..m)
155        .map(|t| {
156            let mean_t: f64 = (0..n).map(|i| data[(i, t)]).sum::<f64>() / n as f64;
157            let ss_tot: f64 = (0..n).map(|i| (data[(i, t)] - mean_t).powi(2)).sum();
158            let ss_res: f64 = (0..n)
159                .map(|i| (data[(i, t)] - fitted[(i, t)]).powi(2))
160                .sum();
161            if ss_tot > 1e-15 {
162                1.0 - ss_res / ss_tot
163            } else {
164                0.0
165            }
166        })
167        .collect()
168}
169
170/// GCV for penalized regression: (1/nm) Σ_{i,t} (r_{it} / (1 - tr(H)/n))²
171fn compute_fosr_gcv(residuals: &FdMatrix, trace_h: f64) -> f64 {
172    let (n, m) = residuals.shape();
173    let denom = (1.0 - trace_h / n as f64).max(1e-10);
174    let ss_res: f64 = (0..n)
175        .flat_map(|i| (0..m).map(move |t| residuals[(i, t)].powi(2)))
176        .sum();
177    ss_res / (n as f64 * m as f64 * denom * denom)
178}
179
180// ---------------------------------------------------------------------------
181// fosr: Function-on-scalar regression
182// ---------------------------------------------------------------------------
183
184/// Penalized function-on-scalar regression.
185///
186/// Fits pointwise OLS at each grid point t, then smooths the coefficient
187/// functions β_j(t) using a second-order roughness penalty.
188///
189/// # Arguments
190/// * `data` - Functional response matrix (n × m)
191/// * `predictors` - Scalar predictor matrix (n × p)
192/// * `lambda` - Smoothing parameter (0 for no smoothing, negative for GCV selection)
193///
194/// # Returns
195/// [`FosrResult`] with coefficient functions, fitted values, and diagnostics
196///
197/// # Errors
198///
199/// Returns [`FdarError::InvalidDimension`] if `data` has zero columns,
200/// `predictors` row count does not match `data`, or `n < p + 2`.
201/// Returns [`FdarError::ComputationFailed`] if the penalized Cholesky solve
202/// is singular.
203/// Build design matrix with intercept: \[1, z_1, ..., z_p\].
204pub(crate) fn build_fosr_design(predictors: &FdMatrix, n: usize) -> FdMatrix {
205    let p = predictors.ncols();
206    let p_total = p + 1;
207    let mut design = FdMatrix::zeros(n, p_total);
208    for i in 0..n {
209        design[(i, 0)] = 1.0;
210        for j in 0..p {
211            design[(i, 1 + j)] = predictors[(i, j)];
212        }
213    }
214    design
215}
216
217/// Compute X'Y (p_total × m).
218pub(crate) fn compute_xty_matrix(design: &FdMatrix, data: &FdMatrix) -> FdMatrix {
219    let (n, m) = data.shape();
220    let p_total = design.ncols();
221    let mut xty = FdMatrix::zeros(p_total, m);
222    for j in 0..p_total {
223        for t in 0..m {
224            let mut s = 0.0;
225            for i in 0..n {
226                s += design[(i, j)] * data[(i, t)];
227            }
228            xty[(j, t)] = s;
229        }
230    }
231    xty
232}
233
234/// Extract rows 1..p+1 from a (p+1)×m matrix, dropping the intercept row.
235fn drop_intercept_rows(full: &FdMatrix, p: usize, m: usize) -> FdMatrix {
236    let mut out = FdMatrix::zeros(p, m);
237    for j in 0..p {
238        for t in 0..m {
239            out[(j, t)] = full[(j + 1, t)];
240        }
241    }
242    out
243}
244
245/// Penalized function-on-scalar regression.
246///
247/// Fits the model `X_i(t) = mu(t) + sum_j beta_j(t) * z_ij + epsilon_i(t)`
248/// using pointwise OLS with second-order roughness penalty smoothing.
249///
250/// # Arguments
251/// * `data` - Functional response matrix (n x m)
252/// * `predictors` - Scalar predictor matrix (n x p)
253/// * `lambda` - Roughness penalty (negative for automatic GCV selection)
254///
255/// # Examples
256///
257/// ```
258/// use fdars_core::matrix::FdMatrix;
259/// use fdars_core::function_on_scalar::fosr;
260///
261/// // 10 curves at 15 grid points, 2 scalar predictors
262/// let data = FdMatrix::from_column_major(
263///     (0..150).map(|i| (i as f64 * 0.1).sin()).collect(),
264///     10, 15,
265/// ).unwrap();
266/// let predictors = FdMatrix::from_column_major(
267///     (0..20).map(|i| i as f64 / 19.0).collect(),
268///     10, 2,
269/// ).unwrap();
270/// let result = fosr(&data, &predictors, 0.1).unwrap();
271/// assert_eq!(result.fitted.shape(), (10, 15));
272/// assert_eq!(result.intercept.len(), 15);
273/// ```
274#[must_use = "expensive computation whose result should not be discarded"]
275pub fn fosr(data: &FdMatrix, predictors: &FdMatrix, lambda: f64) -> Result<FosrResult, FdarError> {
276    let (n, m) = data.shape();
277    let p = predictors.ncols();
278    if m == 0 {
279        return Err(FdarError::InvalidDimension {
280            parameter: "data",
281            expected: "at least 1 column (grid points)".to_string(),
282            actual: "0 columns".to_string(),
283        });
284    }
285    if predictors.nrows() != n {
286        return Err(FdarError::InvalidDimension {
287            parameter: "predictors",
288            expected: format!("{n} rows (matching data)"),
289            actual: format!("{} rows", predictors.nrows()),
290        });
291    }
292    if n < p + 2 {
293        return Err(FdarError::InvalidDimension {
294            parameter: "data",
295            expected: format!("at least {} observations (p + 2)", p + 2),
296            actual: format!("{n} observations"),
297        });
298    }
299
300    let design = build_fosr_design(predictors, n);
301    let p_total = design.ncols();
302    let xtx = compute_xtx(&design);
303    let xty = compute_xty_matrix(&design, data);
304    let penalty = penalty_matrix(p_total);
305
306    let lambda = if lambda < 0.0 {
307        select_lambda_gcv(&xtx, &xty, &penalty, data, &design)
308    } else {
309        lambda
310    };
311
312    let beta = penalized_solve(&xtx, &xty, &penalty, lambda)?;
313    let (fitted, residuals) = compute_fosr_fitted(&design, &beta, data);
314
315    let r_squared_t = pointwise_r_squared(data, &fitted);
316    let r_squared = r_squared_t.iter().sum::<f64>() / m as f64;
317    let beta_se = compute_beta_se(&xtx, &penalty, lambda, &residuals, p_total, n);
318    let trace_h = compute_trace_hat(&xtx, &penalty, lambda, p_total, n);
319    let gcv = compute_fosr_gcv(&residuals, trace_h);
320
321    let intercept: Vec<f64> = (0..m).map(|t| beta[(0, t)]).collect();
322
323    Ok(FosrResult {
324        intercept,
325        beta: drop_intercept_rows(&beta, p, m),
326        fitted,
327        residuals,
328        r_squared_t,
329        r_squared,
330        beta_se: drop_intercept_rows(&beta_se, p, m),
331        lambda,
332        gcv,
333    })
334}
335
/// Compute fitted values Ŷ = X β and residuals.
///
/// Rows are computed independently — in parallel when the `parallel` feature
/// is enabled (via `iter_maybe_parallel!`) — then copied back into the output
/// matrices in row order, so results are deterministic either way.
fn compute_fosr_fitted(
    design: &FdMatrix,
    beta: &FdMatrix,
    data: &FdMatrix,
) -> (FdMatrix, FdMatrix) {
    let (n, m) = data.shape();
    let p_total = design.ncols();
    // Per-row (fitted, residual) pairs, collected to restore row order.
    let rows: Vec<(Vec<f64>, Vec<f64>)> = iter_maybe_parallel!(0..n)
        .map(|i| {
            let mut fitted_row = vec![0.0; m];
            let mut resid_row = vec![0.0; m];
            for t in 0..m {
                // ŷ_i(t) = Σ_j X[i, j] · β_j(t)
                let mut yhat = 0.0;
                for j in 0..p_total {
                    yhat += design[(i, j)] * beta[(j, t)];
                }
                fitted_row[t] = yhat;
                resid_row[t] = data[(i, t)] - yhat;
            }
            (fitted_row, resid_row)
        })
        .collect();
    let mut fitted = FdMatrix::zeros(n, m);
    let mut residuals = FdMatrix::zeros(n, m);
    for (i, (fr, rr)) in rows.into_iter().enumerate() {
        for t in 0..m {
            fitted[(i, t)] = fr[t];
            residuals[(i, t)] = rr[t];
        }
    }
    (fitted, residuals)
}
369
370/// Select smoothing parameter λ via GCV on a grid.
371fn select_lambda_gcv(
372    xtx: &[f64],
373    xty: &FdMatrix,
374    penalty: &[f64],
375    data: &FdMatrix,
376    design: &FdMatrix,
377) -> f64 {
378    let lambdas = [0.0, 1e-6, 1e-4, 1e-2, 0.1, 1.0, 10.0, 100.0, 1000.0];
379    let p_total = design.ncols();
380    let n = design.nrows();
381
382    let mut best_lambda = 0.0;
383    let mut best_gcv = f64::INFINITY;
384
385    for &lam in &lambdas {
386        let Ok(beta) = penalized_solve(xtx, xty, penalty, lam) else {
387            continue;
388        };
389        let (_, residuals) = compute_fosr_fitted(design, &beta, data);
390        let trace_h = compute_trace_hat(xtx, penalty, lam, p_total, n);
391        let gcv = compute_fosr_gcv(&residuals, trace_h);
392        if gcv < best_gcv {
393            best_gcv = gcv;
394            best_lambda = lam;
395        }
396    }
397    best_lambda
398}
399
400/// Compute trace of hat matrix: tr(H) = tr(X (X'X + λP)^{-1} X') = Σ_j h_jj.
401fn compute_trace_hat(xtx: &[f64], penalty: &[f64], lambda: f64, p: usize, n: usize) -> f64 {
402    let mut a = vec![0.0; p * p];
403    for i in 0..p * p {
404        a[i] = xtx[i] + lambda * penalty[i];
405    }
406    // tr(H) = tr(X A^{-1} X') = Σ_{j=0..p} a^{-1}_{jj} * xtx_{jj}
407    // More precisely: tr(X (X'X+λP)^{-1} X') = tr((X'X+λP)^{-1} X'X)
408    let Ok(l) = cholesky_factor(&a, p) else {
409        return p as f64; // fallback
410    };
411
412    // Compute A^{-1} X'X via solving A Z = X'X column by column, then trace
413    let mut trace = 0.0;
414    for j in 0..p {
415        let col: Vec<f64> = (0..p).map(|i| xtx[i * p + j]).collect();
416        let z = cholesky_forward_back(&l, &col, p);
417        trace += z[j]; // diagonal element of A^{-1} X'X
418    }
419    trace.min(n as f64)
420}
421
422/// Compute pointwise standard errors for β(t).
423fn compute_beta_se(
424    xtx: &[f64],
425    penalty: &[f64],
426    lambda: f64,
427    residuals: &FdMatrix,
428    p: usize,
429    n: usize,
430) -> FdMatrix {
431    let m = residuals.ncols();
432    let mut a = vec![0.0; p * p];
433    for i in 0..p * p {
434        a[i] = xtx[i] + lambda * penalty[i];
435    }
436    let Ok(l) = cholesky_factor(&a, p) else {
437        return FdMatrix::zeros(p, m);
438    };
439
440    // Diagonal of A^{-1}
441    let a_inv_diag: Vec<f64> = (0..p)
442        .map(|j| {
443            let mut ej = vec![0.0; p];
444            ej[j] = 1.0;
445            let v = cholesky_forward_back(&l, &ej, p);
446            v[j]
447        })
448        .collect();
449
450    let df = (n - p).max(1) as f64;
451    let mut se = FdMatrix::zeros(p, m);
452    for t in 0..m {
453        let sigma2_t: f64 = (0..n).map(|i| residuals[(i, t)].powi(2)).sum::<f64>() / df;
454        for j in 0..p {
455            se[(j, t)] = (sigma2_t * a_inv_diag[j]).max(0.0).sqrt();
456        }
457    }
458    se
459}
460
461// ---------------------------------------------------------------------------
462// fosr_fpc: FPC-based function-on-scalar regression (matches R's fda.usc approach)
463// ---------------------------------------------------------------------------
464
465/// OLS regression of each FPC score on the design matrix.
466///
467/// Returns gamma_all\[comp\]\[coef\] = (X'X)^{-1} X' scores\[:,comp\].
468fn regress_scores_on_design(
469    design: &FdMatrix,
470    scores: &FdMatrix,
471    n: usize,
472    k: usize,
473    p_total: usize,
474) -> Result<Vec<Vec<f64>>, FdarError> {
475    let xtx = compute_xtx(design);
476    let l = cholesky_factor(&xtx, p_total)?;
477
478    let gamma_all: Vec<Vec<f64>> = (0..k)
479        .map(|comp| {
480            let mut xts = vec![0.0; p_total];
481            for j in 0..p_total {
482                for i in 0..n {
483                    xts[j] += design[(i, j)] * scores[(i, comp)];
484                }
485            }
486            cholesky_forward_back(&l, &xts, p_total)
487        })
488        .collect();
489    Ok(gamma_all)
490}
491
492/// Reconstruct β_j(t) = Σ_k gamma\[comp\]\[1+j\] · φ_k(t) for each predictor j.
493fn reconstruct_beta_fpc(
494    gamma_all: &[Vec<f64>],
495    rotation: &FdMatrix,
496    p: usize,
497    k: usize,
498    m: usize,
499) -> FdMatrix {
500    let mut beta = FdMatrix::zeros(p, m);
501    for j in 0..p {
502        for t in 0..m {
503            let mut val = 0.0;
504            for comp in 0..k {
505                val += gamma_all[comp][1 + j] * rotation[(t, comp)];
506            }
507            beta[(j, t)] = val;
508        }
509    }
510    beta
511}
512
513/// Compute intercept function: μ(t) + Σ_k γ_intercept\[k\] · φ_k(t).
514fn compute_intercept_fpc(
515    mean: &[f64],
516    gamma_all: &[Vec<f64>],
517    rotation: &FdMatrix,
518    k: usize,
519    m: usize,
520) -> Vec<f64> {
521    let mut intercept = mean.to_vec();
522    for t in 0..m {
523        for comp in 0..k {
524            intercept[t] += gamma_all[comp][0] * rotation[(t, comp)];
525        }
526    }
527    intercept
528}
529
/// Rescale the FPC-space regression coefficients onto an L²-normalized score
/// basis (multiply by sqrt(h), h being the uniform grid spacing 1/(m-1)).
fn extract_beta_scores(gamma_all: &[Vec<f64>], p: usize, k: usize, m: usize) -> Vec<Vec<f64>> {
    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
    let scale = h.sqrt();
    let mut out = Vec::with_capacity(p);
    for j in 0..p {
        let mut row = Vec::with_capacity(k);
        for gamma in gamma_all.iter().take(k) {
            row.push(gamma[1 + j] * scale);
        }
        out.push(row);
    }
    out
}
542
543/// FPC-based function-on-scalar regression.
544///
545/// Reduces the functional response to FPC scores, regresses each score on the
546/// scalar predictors via OLS, then reconstructs β(t) from the loadings.
547/// This matches R's `fdata2pc` + `lm(scores ~ x)` approach.
548///
549/// # Arguments
550/// * `data` - Functional response matrix (n × m)
551/// * `predictors` - Scalar predictor matrix (n × p)
552/// * `ncomp` - Number of FPC components to use
553///
554/// # Errors
555///
556/// Returns [`FdarError::InvalidDimension`] if `data` has zero columns,
557/// `predictors` row count does not match `data`, or `n < p + 2`.
558/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
559/// Returns [`FdarError::ComputationFailed`] if FPCA fails or the OLS
560/// Cholesky factorization of X'X is singular.
561#[must_use = "expensive computation whose result should not be discarded"]
562pub fn fosr_fpc(
563    data: &FdMatrix,
564    predictors: &FdMatrix,
565    ncomp: usize,
566) -> Result<FosrFpcResult, FdarError> {
567    let (n, m) = data.shape();
568    let p = predictors.ncols();
569    if m == 0 {
570        return Err(FdarError::InvalidDimension {
571            parameter: "data",
572            expected: "at least 1 column (grid points)".to_string(),
573            actual: "0 columns".to_string(),
574        });
575    }
576    if predictors.nrows() != n {
577        return Err(FdarError::InvalidDimension {
578            parameter: "predictors",
579            expected: format!("{n} rows (matching data)"),
580            actual: format!("{} rows", predictors.nrows()),
581        });
582    }
583    if n < p + 2 {
584        return Err(FdarError::InvalidDimension {
585            parameter: "data",
586            expected: format!("at least {} observations (p + 2)", p + 2),
587            actual: format!("{n} observations"),
588        });
589    }
590    if ncomp == 0 {
591        return Err(FdarError::InvalidParameter {
592            parameter: "ncomp",
593            message: "number of FPC components must be at least 1".to_string(),
594        });
595    }
596
597    let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1).max(1) as f64).collect();
598    let fpca = fdata_to_pc_1d(data, ncomp, &argvals)?;
599    let k = fpca.scores.ncols();
600    let p_total = p + 1;
601    let design = build_fosr_design(predictors, n);
602
603    let gamma_all = regress_scores_on_design(&design, &fpca.scores, n, k, p_total)?;
604    let beta = reconstruct_beta_fpc(&gamma_all, &fpca.rotation, p, k, m);
605    let intercept = compute_intercept_fpc(&fpca.mean, &gamma_all, &fpca.rotation, k, m);
606
607    let (fitted, residuals) = compute_fosr_fpc_fitted(data, &intercept, &beta, predictors);
608    let r_squared_t = pointwise_r_squared(data, &fitted);
609    let r_squared = r_squared_t.iter().sum::<f64>() / m as f64;
610    let beta_scores = extract_beta_scores(&gamma_all, p, k, m);
611
612    Ok(FosrFpcResult {
613        intercept,
614        beta,
615        fitted,
616        residuals,
617        r_squared_t,
618        r_squared,
619        beta_scores,
620        ncomp: k,
621    })
622}
623
624/// Compute fitted values and residuals for FPC-based FOSR.
625fn compute_fosr_fpc_fitted(
626    data: &FdMatrix,
627    intercept: &[f64],
628    beta: &FdMatrix,
629    predictors: &FdMatrix,
630) -> (FdMatrix, FdMatrix) {
631    let (n, m) = data.shape();
632    let p = predictors.ncols();
633    let mut fitted = FdMatrix::zeros(n, m);
634    let mut residuals = FdMatrix::zeros(n, m);
635    for i in 0..n {
636        for t in 0..m {
637            let mut yhat = intercept[t];
638            for j in 0..p {
639                yhat += predictors[(i, j)] * beta[(j, t)];
640            }
641            fitted[(i, t)] = yhat;
642            residuals[(i, t)] = data[(i, t)] - yhat;
643        }
644    }
645    (fitted, residuals)
646}
647
648/// Predict new functional responses from a fitted FOSR model.
649///
650/// # Arguments
651/// * `result` - Fitted [`FosrResult`]
652/// * `new_predictors` - New scalar predictors (n_new × p)
653#[must_use = "prediction result should not be discarded"]
654pub fn predict_fosr(result: &FosrResult, new_predictors: &FdMatrix) -> FdMatrix {
655    let n_new = new_predictors.nrows();
656    let m = result.intercept.len();
657    let p = result.beta.nrows();
658
659    let mut predicted = FdMatrix::zeros(n_new, m);
660    for i in 0..n_new {
661        for t in 0..m {
662            let mut yhat = result.intercept[t];
663            for j in 0..p {
664                yhat += new_predictors[(i, j)] * result.beta[(j, t)];
665            }
666            predicted[(i, t)] = yhat;
667        }
668    }
669    predicted
670}
671
672// ---------------------------------------------------------------------------
673// fanova: Functional ANOVA
674// ---------------------------------------------------------------------------
675
676/// Compute group means and overall mean.
677fn compute_group_means(
678    data: &FdMatrix,
679    groups: &[usize],
680    labels: &[usize],
681) -> (FdMatrix, Vec<f64>) {
682    let (n, m) = data.shape();
683    let k = labels.len();
684    let mut group_means = FdMatrix::zeros(k, m);
685    let mut counts = vec![0usize; k];
686
687    for i in 0..n {
688        let g = labels.iter().position(|&l| l == groups[i]).unwrap_or(0);
689        counts[g] += 1;
690        for t in 0..m {
691            group_means[(g, t)] += data[(i, t)];
692        }
693    }
694    for g in 0..k {
695        if counts[g] > 0 {
696            for t in 0..m {
697                group_means[(g, t)] /= counts[g] as f64;
698            }
699        }
700    }
701
702    let overall_mean: Vec<f64> = (0..m)
703        .map(|t| (0..n).map(|i| data[(i, t)]).sum::<f64>() / n as f64)
704        .collect();
705
706    (group_means, overall_mean)
707}
708
709/// Compute pointwise F-statistic.
710fn pointwise_f_statistic(
711    data: &FdMatrix,
712    groups: &[usize],
713    labels: &[usize],
714    group_means: &FdMatrix,
715    overall_mean: &[f64],
716) -> Vec<f64> {
717    let (n, m) = data.shape();
718    let k = labels.len();
719    let mut counts = vec![0usize; k];
720    for &g in groups {
721        let idx = labels.iter().position(|&l| l == g).unwrap_or(0);
722        counts[idx] += 1;
723    }
724
725    (0..m)
726        .map(|t| {
727            let ss_between: f64 = (0..k)
728                .map(|g| counts[g] as f64 * (group_means[(g, t)] - overall_mean[t]).powi(2))
729                .sum();
730            let ss_within: f64 = (0..n)
731                .map(|i| {
732                    let g = labels.iter().position(|&l| l == groups[i]).unwrap_or(0);
733                    (data[(i, t)] - group_means[(g, t)]).powi(2)
734                })
735                .sum();
736            let ms_between = ss_between / (k as f64 - 1.0).max(1.0);
737            let ms_within = ss_within / (n as f64 - k as f64).max(1.0);
738            if ms_within > 1e-15 {
739                ms_between / ms_within
740            } else {
741                0.0
742            }
743        })
744        .collect()
745}
746
/// Integrated global statistic: the mean of the pointwise F values.
fn global_f_statistic(f_t: &[f64]) -> f64 {
    let total: f64 = f_t.iter().copied().sum();
    total / f_t.len() as f64
}
751
/// Functional ANOVA: test whether groups have different mean curves.
///
/// Uses a permutation-based global test with the integrated F-statistic.
/// The permutation stream is produced by a fixed-seed LCG, so p-values are
/// reproducible across runs.
///
/// # Arguments
/// * `data` - Functional response matrix (n × m)
/// * `groups` - Group labels for each observation (length n, integer-coded)
/// * `n_perm` - Number of permutations for the global test
///
/// # Returns
/// [`FanovaResult`] with group means, F-statistics, and permutation p-value
///
/// # Errors
///
/// Returns [`FdarError::InvalidDimension`] if `data` has zero columns,
/// `groups.len()` does not match the number of rows in `data`, or `n < 3`.
/// Returns [`FdarError::InvalidParameter`] if fewer than 2 distinct groups
/// are present.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn fanova(data: &FdMatrix, groups: &[usize], n_perm: usize) -> Result<FanovaResult, FdarError> {
    let (n, m) = data.shape();
    // --- Input validation -------------------------------------------------
    if m == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "at least 1 column (grid points)".to_string(),
            actual: "0 columns".to_string(),
        });
    }
    if groups.len() != n {
        return Err(FdarError::InvalidDimension {
            parameter: "groups",
            expected: format!("{n} elements (matching data rows)"),
            actual: format!("{} elements", groups.len()),
        });
    }
    if n < 3 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "at least 3 observations".to_string(),
            actual: format!("{n} observations"),
        });
    }

    // Sorted, de-duplicated labels define the group index order used by all
    // downstream helpers.
    let mut labels: Vec<usize> = groups.to_vec();
    labels.sort_unstable();
    labels.dedup();
    let n_groups = labels.len();
    if n_groups < 2 {
        return Err(FdarError::InvalidParameter {
            parameter: "groups",
            message: format!("at least 2 distinct groups required, but only {n_groups} found"),
        });
    }

    // Observed statistic: pointwise F integrated (averaged) over the grid.
    let (group_means, overall_mean) = compute_group_means(data, groups, &labels);
    let f_t = pointwise_f_statistic(data, groups, &labels, &group_means, &overall_mean);
    let observed_stat = global_f_statistic(&f_t);

    // Permutation test
    let n_perm = n_perm.max(1);
    let mut n_ge = 0usize;
    let mut perm_groups = groups.to_vec();

    // Simple LCG for reproducibility without requiring rand
    // (fixed seed 42 → identical permutation stream, hence identical p-value,
    // on every run).
    let mut rng_state: u64 = 42;
    for _ in 0..n_perm {
        // Fisher-Yates shuffle with LCG
        for i in (1..n).rev() {
            rng_state = rng_state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1);
            // Use the high bits of the LCG state for the swap index.
            let j = (rng_state >> 33) as usize % (i + 1);
            perm_groups.swap(i, j);
        }

        // Recompute the global statistic under the permuted labels and count
        // how often it reaches the observed one.
        let (perm_means, perm_overall) = compute_group_means(data, &perm_groups, &labels);
        let perm_f = pointwise_f_statistic(data, &perm_groups, &labels, &perm_means, &perm_overall);
        let perm_stat = global_f_statistic(&perm_f);
        if perm_stat >= observed_stat {
            n_ge += 1;
        }
    }

    // Add-one correction keeps the p-value strictly in (0, 1].
    let p_value = (n_ge as f64 + 1.0) / (n_perm as f64 + 1.0);

    Ok(FanovaResult {
        group_means,
        overall_mean,
        f_statistic_t: f_t,
        global_statistic: observed_stat,
        p_value,
        n_perm,
        n_groups,
        group_labels: labels,
    })
}
848
impl FosrResult {
    /// Predict functional responses for new predictors (n_new × p).
    ///
    /// Method-style convenience that delegates to [`predict_fosr`].
    pub fn predict(&self, new_predictors: &FdMatrix) -> FdMatrix {
        predict_fosr(self, new_predictors)
    }
}
855
856// ---------------------------------------------------------------------------
857// Tests
858// ---------------------------------------------------------------------------
859
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;
    use std::f64::consts::PI;

    /// Builds a deterministic synthetic data set for the FOSR tests.
    ///
    /// Each curve follows the true model
    /// `y_i(t) = μ(t) + age_i·β₁(t) + group_i·β₂(t) + noise`, where
    /// μ(t) = sin(2πt), β₁(t) = t, β₂(t) = cos(4πt). The "noise" term is a
    /// deterministic pseudo-random perturbation in [0, 0.05) so every run of
    /// the suite is reproducible.
    ///
    /// Returns `(y, z)`: an n × m functional response matrix and an n × 2
    /// design matrix whose columns are (age, group).
    fn generate_fosr_data(n: usize, m: usize) -> (FdMatrix, FdMatrix) {
        let t = uniform_grid(m);
        let mut y = FdMatrix::zeros(n, m);
        let mut z = FdMatrix::zeros(n, 2);

        for i in 0..n {
            let age = (i as f64) / (n as f64);
            let group = if i % 2 == 0 { 1.0 } else { 0.0 };
            z[(i, 0)] = age;
            z[(i, 1)] = group;
            for j in 0..m {
                // True model: μ(t) + age * β₁(t) + group * β₂(t)
                let mu = (2.0 * PI * t[j]).sin();
                let beta1 = t[j]; // Linear coefficient for age
                let beta2 = (4.0 * PI * t[j]).cos(); // Oscillating for group
                // Deterministic "noise": ((i·13 + j·7) mod 100) scaled into [0, 0.05).
                y[(i, j)] = mu
                    + age * beta1
                    + group * beta2
                    + 0.05 * ((i * 13 + j * 7) % 100) as f64 / 100.0;
            }
        }
        (y, z)
    }

    // ----- FOSR tests -----

    /// Unpenalized fit succeeds and all output shapes match the inputs.
    #[test]
    fn test_fosr_basic() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).expect("fosr should succeed on valid input");
        assert_eq!(fit.intercept.len(), 50);
        assert_eq!(fit.beta.shape(), (2, 50));
        assert_eq!(fit.fitted.shape(), (30, 50));
        assert_eq!(fit.residuals.shape(), (30, 50));
        assert!(fit.r_squared >= 0.0);
    }

    /// Penalized and unpenalized fits both produce coefficient matrices of
    /// the expected shape.
    #[test]
    fn test_fosr_with_penalty() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit0 = fosr(&y, &z, 0.0).expect("λ = 0 fit should succeed");
        let fit1 = fosr(&y, &z, 1.0).expect("λ = 1 fit should succeed");
        // Both should produce valid results
        assert_eq!(fit0.beta.shape(), (2, 50));
        assert_eq!(fit1.beta.shape(), (2, 50));
    }

    /// A negative λ requests GCV-based automatic selection; the chosen λ
    /// must be non-negative.
    #[test]
    fn test_fosr_auto_lambda() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, -1.0).expect("auto-λ fit should succeed");
        assert!(fit.lambda >= 0.0);
    }

    /// The decomposition y = fitted + residuals must hold pointwise (an
    /// algebraic identity of least squares, up to rounding).
    #[test]
    fn test_fosr_fitted_plus_residuals_equals_y() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).expect("fosr should succeed on valid input");
        for i in 0..30 {
            for t in 0..50 {
                let reconstructed = fit.fitted[(i, t)] + fit.residuals[(i, t)];
                assert!(
                    (reconstructed - y[(i, t)]).abs() < 1e-10,
                    "ŷ + r should equal y at ({}, {})",
                    i,
                    t
                );
            }
        }
    }

    /// Pointwise R²(t) stays within [0, 1] up to a small numerical tolerance.
    #[test]
    fn test_fosr_pointwise_r_squared_valid() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).expect("fosr should succeed on valid input");
        for &r2 in &fit.r_squared_t {
            assert!(
                (-0.01..=1.0 + 1e-10).contains(&r2),
                "R²(t) out of range: {}",
                r2
            );
        }
    }

    /// Standard errors of every βⱼ(t) are finite and non-negative.
    #[test]
    fn test_fosr_se_positive() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).expect("fosr should succeed on valid input");
        for j in 0..2 {
            for t in 0..50 {
                assert!(
                    fit.beta_se[(j, t)] >= 0.0 && fit.beta_se[(j, t)].is_finite(),
                    "SE should be non-negative finite"
                );
            }
        }
    }

    /// Too few observations for the number of predictors must be rejected.
    #[test]
    fn test_fosr_invalid_input() {
        let y = FdMatrix::zeros(2, 50);
        let z = FdMatrix::zeros(2, 1);
        assert!(fosr(&y, &z, 0.0).is_err());
    }

    // ----- predict_fosr tests -----

    /// Predicting on the training design matrix reproduces the fitted values.
    #[test]
    fn test_predict_fosr_on_training_data() {
        let (y, z) = generate_fosr_data(30, 50);
        let fit = fosr(&y, &z, 0.0).expect("fosr should succeed on valid input");
        let preds = predict_fosr(&fit, &z);
        assert_eq!(preds.shape(), (30, 50));
        for i in 0..30 {
            for t in 0..50 {
                assert!(
                    (preds[(i, t)] - fit.fitted[(i, t)]).abs() < 1e-8,
                    "Prediction on training data should match fitted"
                );
            }
        }
    }

    // ----- FANOVA tests -----

    /// Two groups with a genuine mean-function difference: the permutation
    /// test should report a small p-value.
    #[test]
    fn test_fanova_two_groups() {
        let n = 40;
        let m = 50;
        let t = uniform_grid(m);

        // First half of the sample is group 0, second half is group 1.
        let groups: Vec<usize> = (0..n).map(|i| usize::from(i >= n / 2)).collect();
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                let base = (2.0 * PI * t[j]).sin();
                let effect = if groups[i] == 1 { 0.5 * t[j] } else { 0.0 };
                data[(i, j)] = base + effect + 0.01 * (i as f64 * 0.1).sin();
            }
        }

        let res = fanova(&data, &groups, 200).expect("fanova should succeed on valid input");
        assert_eq!(res.n_groups, 2);
        assert_eq!(res.group_means.shape(), (2, m));
        assert_eq!(res.f_statistic_t.len(), m);
        assert!((0.0..=1.0).contains(&res.p_value));
        // With a real group effect, p should be small
        assert!(
            res.p_value < 0.1,
            "Should detect group effect, got p={}",
            res.p_value
        );
    }

    /// Two groups drawn from the same generating process: the test should
    /// not reject (large p-value).
    #[test]
    fn test_fanova_no_effect() {
        let n = 40;
        let m = 50;
        let t = uniform_grid(m);

        let groups: Vec<usize> = (0..n).map(|i| usize::from(i >= n / 2)).collect();
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                // Same distribution for both groups
                data[(i, j)] =
                    (2.0 * PI * t[j]).sin() + 0.1 * ((i * 7 + j * 3) % 100) as f64 / 100.0;
            }
        }

        let res = fanova(&data, &groups, 200).expect("fanova should succeed on valid input");
        // Without group effect, p should be large
        assert!(
            res.p_value > 0.05,
            "Should not detect effect, got p={}",
            res.p_value
        );
    }

    /// Three interleaved groups with distinct effects are handled and the
    /// group count is reported correctly.
    #[test]
    fn test_fanova_three_groups() {
        let n = 30;
        let m = 50;
        let t = uniform_grid(m);

        // Round-robin assignment: observation i belongs to group i mod 3.
        let groups: Vec<usize> = (0..n).map(|i| i % 3).collect();
        let mut data = FdMatrix::zeros(n, m);
        for i in 0..n {
            for j in 0..m {
                let effect = match groups[i] {
                    0 => 0.0,
                    1 => 0.5 * t[j],
                    _ => -0.3 * (2.0 * PI * t[j]).cos(),
                };
                data[(i, j)] = (2.0 * PI * t[j]).sin() + effect + 0.01 * (i as f64 * 0.1).sin();
            }
        }

        let res = fanova(&data, &groups, 200).expect("fanova should succeed on valid input");
        assert_eq!(res.n_groups, 3);
    }

    /// Degenerate group specifications are rejected: a single group, or a
    /// label vector whose length disagrees with the data.
    #[test]
    fn test_fanova_invalid_input() {
        let data = FdMatrix::zeros(10, 50);
        let groups = vec![0; 10]; // Only one group
        assert!(fanova(&data, &groups, 100).is_err());

        let groups = vec![0; 5]; // Wrong length
        assert!(fanova(&data, &groups, 100).is_err());
    }
}