// fdars_core/famm.rs
1//! Functional Additive Mixed Models (FAMM).
2//!
3//! Implements functional mixed effects models for repeated functional
4//! measurements with subject-level covariates.
5//!
6//! Model: `Y_ij(t) = μ(t) + X_i'β(t) + b_i(t) + ε_ij(t)`
7//!
8//! Key functions:
9//! - [`fmm`] — Fit a functional mixed model via FPC decomposition
10//! - [`fmm_predict`] — Predict curves for new subjects
11//! - [`fmm_test_fixed`] — Hypothesis test on fixed effects
12
13use crate::error::FdarError;
14use crate::iter_maybe_parallel;
15use crate::matrix::FdMatrix;
16use crate::regression::fdata_to_pc_1d;
17#[cfg(feature = "parallel")]
18use rayon::iter::ParallelIterator;
19
/// Result of a functional mixed model fit.
///
/// Produced by [`fmm`]; can be passed to [`fmm_predict`] for new-subject prediction.
#[derive(Debug, Clone, PartialEq)]
pub struct FmmResult {
    /// Overall mean function μ̂(t) (length m)
    pub mean_function: Vec<f64>,
    /// Fixed effect coefficient functions β̂_j(t) (p × m matrix, one row per covariate)
    pub beta_functions: FdMatrix,
    /// Random effect functions b̂_i(t) per subject (n_subjects × m)
    pub random_effects: FdMatrix,
    /// Fitted values for all observations (n_total × m): mean + fixed + random parts
    pub fitted: FdMatrix,
    /// Residuals (n_total × m): observed curves minus fitted values
    pub residuals: FdMatrix,
    /// Variance of random effects at each time point (length m)
    pub random_variance: Vec<f64>,
    /// Residual variance estimate (averaged across FPC components)
    pub sigma2_eps: f64,
    /// Random effect variance estimate (per-component)
    pub sigma2_u: Vec<f64>,
    /// Number of FPC components used (may be fewer than requested if FPCA retains less)
    pub ncomp: usize,
    /// Number of subjects
    pub n_subjects: usize,
    /// FPC eigenvalues (singular values squared / n)
    pub eigenvalues: Vec<f64>,
}
46
/// Result of fixed effect hypothesis test.
///
/// Produced by [`fmm_test_fixed`]. Despite the field name, `f_statistics`
/// holds the integrated squared norm ∫ β̂_j(t)² dt used as the permutation
/// test statistic.
#[derive(Debug, Clone, PartialEq)]
pub struct FmmTestResult {
    /// Test statistic per covariate (length p): ∫ β̂_j(t)² dt
    pub f_statistics: Vec<f64>,
    /// P-values per covariate (via permutation, length p)
    pub p_values: Vec<f64>,
}
55
56// ---------------------------------------------------------------------------
57// Core FMM algorithm
58// ---------------------------------------------------------------------------
59
60/// Fit a functional mixed model via FPC decomposition.
61///
62/// # Arguments
63/// * `data` — All observed curves (n_total × m), stacked across subjects and visits
64/// * `subject_ids` — Subject identifier for each curve (length n_total)
65/// * `covariates` — Subject-level covariates (n_total × p).
66///   Each row corresponds to the same curve in `data`.
67///   If a covariate is subject-level, its value should be repeated across visits.
68/// * `ncomp` — Number of FPC components
69///
70/// # Algorithm
71/// 1. Pool curves, compute FPCA
72/// 2. For each FPC score, fit scalar mixed model: ξ_ijk = x_i'γ_k + u_ik + e_ijk
73/// 3. Recover β̂(t) and b̂_i(t) from component coefficients
74///
75/// # Errors
76///
77/// Returns [`FdarError::InvalidDimension`] if `data` is empty (zero rows or
78/// columns), or if `subject_ids.len()` does not match the number of rows.
79/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
80/// Returns [`FdarError::ComputationFailed`] if the underlying FPCA fails.
81#[must_use = "expensive computation whose result should not be discarded"]
82pub fn fmm(
83    data: &FdMatrix,
84    subject_ids: &[usize],
85    covariates: Option<&FdMatrix>,
86    ncomp: usize,
87) -> Result<FmmResult, FdarError> {
88    let n_total = data.nrows();
89    let m = data.ncols();
90    if n_total == 0 || m == 0 {
91        return Err(FdarError::InvalidDimension {
92            parameter: "data",
93            expected: "non-empty matrix".to_string(),
94            actual: format!("{n_total} x {m}"),
95        });
96    }
97    if subject_ids.len() != n_total {
98        return Err(FdarError::InvalidDimension {
99            parameter: "subject_ids",
100            expected: format!("length {n_total}"),
101            actual: format!("length {}", subject_ids.len()),
102        });
103    }
104    if ncomp == 0 {
105        return Err(FdarError::InvalidParameter {
106            parameter: "ncomp",
107            message: "must be >= 1".to_string(),
108        });
109    }
110
111    // Determine unique subjects
112    let (subject_map, n_subjects) = build_subject_map(subject_ids);
113
114    // Step 1: FPCA on pooled data
115    let fpca = fdata_to_pc_1d(data, ncomp)?;
116    let k = fpca.scores.ncols(); // actual number of components
117
118    // Step 2: For each FPC score, fit scalar mixed model (parallelized)
119    let p = covariates.map_or(0, super::matrix::FdMatrix::ncols);
120    let ComponentResults {
121        gamma,
122        u_hat,
123        sigma2_u,
124        sigma2_eps,
125    } = fit_all_components(
126        &fpca.scores,
127        &subject_map,
128        n_subjects,
129        covariates,
130        p,
131        k,
132        n_total,
133        m,
134    );
135
136    // Step 3: Recover functional coefficients (using gamma in original scale)
137    let beta_functions = recover_beta_functions(&gamma, &fpca.rotation, p, m, k);
138    let random_effects = recover_random_effects(&u_hat, &fpca.rotation, n_subjects, m, k);
139
140    // Compute random variance function: Var(b_i(t)) across subjects
141    let random_variance = compute_random_variance(&random_effects, n_subjects, m);
142
143    // Compute fitted and residuals
144    let (fitted, residuals) = compute_fitted_residuals(
145        data,
146        &fpca.mean,
147        &beta_functions,
148        &random_effects,
149        covariates,
150        &subject_map,
151        n_total,
152        m,
153        p,
154    );
155
156    let eigenvalues: Vec<f64> = fpca
157        .singular_values
158        .iter()
159        .map(|&sv| sv * sv / n_total as f64)
160        .collect();
161
162    Ok(FmmResult {
163        mean_function: fpca.mean,
164        beta_functions,
165        random_effects,
166        fitted,
167        residuals,
168        random_variance,
169        sigma2_eps,
170        sigma2_u,
171        ncomp: k,
172        n_subjects,
173        eigenvalues,
174    })
175}
176
/// Build mapping from observation index to subject index (0..n_subjects-1).
///
/// Subject indices are assigned in ascending order of the original ids, so
/// the mapping is deterministic regardless of observation order.
///
/// Uses `binary_search` on the sorted, deduplicated id list — O(n log u)
/// overall instead of the previous O(n·u) linear scan per observation.
fn build_subject_map(subject_ids: &[usize]) -> (Vec<usize>, usize) {
    let mut unique_ids: Vec<usize> = subject_ids.to_vec();
    unique_ids.sort_unstable();
    unique_ids.dedup();
    let n_subjects = unique_ids.len();

    // Every id is guaranteed to appear in `unique_ids`, so the search always
    // succeeds; the 0 fallback is kept purely as a defensive default.
    let map: Vec<usize> = subject_ids
        .iter()
        .map(|id| unique_ids.binary_search(id).unwrap_or(0))
        .collect();

    (map, n_subjects)
}
191
/// Aggregated results from fitting all FPC components.
struct ComponentResults {
    // gamma[j][k] = fixed effect coeff j for component k,
    // rescaled back to the original (unnormalized) score space.
    gamma: Vec<Vec<f64>>,
    // u_hat[i][k] = random effect BLUP for subject i, component k.
    u_hat: Vec<Vec<f64>>,
    // Per-component random effect variance.
    sigma2_u: Vec<f64>,
    // Residual variance averaged across components.
    sigma2_eps: f64,
}
199
200/// Fit scalar mixed models for all FPC components (parallelized across components).
201///
202/// For each component k, scales FPC scores to L²-normalized space, fits a scalar
203/// mixed model, then scales coefficients back to the original score space.
204#[allow(clippy::too_many_arguments)]
205fn fit_all_components(
206    scores: &FdMatrix,
207    subject_map: &[usize],
208    n_subjects: usize,
209    covariates: Option<&FdMatrix>,
210    p: usize,
211    k: usize,
212    n_total: usize,
213    m: usize,
214) -> ComponentResults {
215    // Normalize scores by sqrt(h) to match R's L²-weighted FPCA convention.
216    // This ensures variance components are on the same scale as R's lmer().
217    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
218    let score_scale = h.sqrt();
219
220    // Fit each component independently — parallelized when the feature is enabled
221    let per_comp: Vec<ScalarMixedResult> = iter_maybe_parallel!(0..k)
222        .map(|comp| {
223            let comp_scores: Vec<f64> = (0..n_total)
224                .map(|i| scores[(i, comp)] * score_scale)
225                .collect();
226            fit_scalar_mixed_model(&comp_scores, subject_map, n_subjects, covariates, p)
227        })
228        .collect();
229
230    // Unpack per-component results into the aggregate structure
231    let mut gamma = vec![vec![0.0; k]; p];
232    let mut u_hat = vec![vec![0.0; k]; n_subjects];
233    let mut sigma2_u = vec![0.0; k];
234    let mut sigma2_eps_total = 0.0;
235
236    for (comp, result) in per_comp.iter().enumerate() {
237        for j in 0..p {
238            gamma[j][comp] = result.gamma[j] / score_scale;
239        }
240        for s in 0..n_subjects {
241            u_hat[s][comp] = result.u_hat[s] / score_scale;
242        }
243        sigma2_u[comp] = result.sigma2_u;
244        sigma2_eps_total += result.sigma2_eps;
245    }
246    let sigma2_eps = sigma2_eps_total / k as f64;
247
248    ComponentResults {
249        gamma,
250        u_hat,
251        sigma2_u,
252        sigma2_eps,
253    }
254}
255
/// Scalar mixed model result for one FPC component.
struct ScalarMixedResult {
    // Fixed effect estimates γ̂ (length p; empty when there are no covariates).
    gamma: Vec<f64>,
    // Random effect BLUPs û_i per subject (length n_subjects).
    u_hat: Vec<f64>,
    // Random effect variance σ²_u.
    sigma2_u: f64,
    // Residual variance σ²_ε.
    sigma2_eps: f64,
}
263
/// Precomputed subject structure for the mixed model.
struct SubjectStructure {
    // Number of observations per subject.
    counts: Vec<usize>,
    // Observation indices grouped by subject.
    obs: Vec<Vec<usize>>,
}

impl SubjectStructure {
    /// Group the first `n` observation indices by their subject index.
    fn new(subject_map: &[usize], n_subjects: usize, n: usize) -> Self {
        let mut counts = vec![0usize; n_subjects];
        let mut obs: Vec<Vec<usize>> = vec![Vec::new(); n_subjects];
        for idx in 0..n {
            let subj = subject_map[idx];
            counts[subj] += 1;
            obs[subj].push(idx);
        }
        Self { counts, obs }
    }
}
282
283/// Compute shrinkage weights: w_s = σ²_u / (σ²_u + σ²_e / n_s).
284fn shrinkage_weights(ss: &SubjectStructure, sigma2_u: f64, sigma2_e: f64) -> Vec<f64> {
285    ss.counts
286        .iter()
287        .map(|&c| {
288            let ns = c as f64;
289            if ns < 1.0 {
290                0.0
291            } else {
292                sigma2_u / (sigma2_u + sigma2_e / ns)
293            }
294        })
295        .collect()
296}
297
/// GLS fixed effect update using block-diagonal V^{-1}.
///
/// Computes γ = (X'V⁻¹X)⁻¹ X'V⁻¹y exploiting the balanced random intercept structure.
///
/// `weights[s]` is the shrinkage weight for subject `s` (see [`shrinkage_weights`]);
/// subjects with zero observations are skipped. Returns `None` when the
/// (ridge-regularized) normal-equation matrix fails the Cholesky factorization.
fn gls_update_gamma(
    cov: &FdMatrix,
    p: usize,
    ss: &SubjectStructure,
    weights: &[f64],
    y: &[f64],
    sigma2_e: f64,
) -> Option<Vec<f64>> {
    let n_subjects = ss.counts.len();
    // Flat row-major p×p accumulator for X'V⁻¹X and length-p vector for X'V⁻¹y.
    let mut xtvinvx = vec![0.0; p * p];
    let mut xtvinvy = vec![0.0; p];
    let inv_e = 1.0 / sigma2_e;

    // Each subject contributes one block-diagonal term to the normal equations.
    for s in 0..n_subjects {
        let ns = ss.counts[s] as f64;
        if ns < 1.0 {
            continue; // empty subject contributes nothing
        }
        let (x_sum, y_sum) = subject_sums(cov, y, &ss.obs[s], p);
        accumulate_gls_terms(
            cov,
            y,
            &ss.obs[s],
            &x_sum,
            y_sum,
            weights[s],
            ns,
            inv_e,
            p,
            &mut xtvinvx,
            &mut xtvinvy,
        );
    }

    // Tiny ridge on the diagonal keeps the Cholesky factorization well-posed.
    for j in 0..p {
        xtvinvx[j * p + j] += 1e-10;
    }
    cholesky_solve(&xtvinvx, &xtvinvy, p)
}
340
341/// Compute subject-level covariate sums and response sum.
342fn subject_sums(cov: &FdMatrix, y: &[f64], obs: &[usize], p: usize) -> (Vec<f64>, f64) {
343    let mut x_sum = vec![0.0; p];
344    let mut y_sum = 0.0;
345    for &i in obs {
346        for r in 0..p {
347            x_sum[r] += cov[(i, r)];
348        }
349        y_sum += y[i];
350    }
351    (x_sum, y_sum)
352}
353
/// Accumulate X'V^{-1}X and X'V^{-1}y for one subject.
///
/// For a random-intercept block, V⁻¹ applied to a vector z over this subject's
/// observations is `inv_e * (z_i - w_s * z̄_s)`, where z̄_s is the subject mean
/// (here computed as `z_sum / ns`); the code applies this columnwise using the
/// precomputed sums `x_sum` and `y_sum`.
fn accumulate_gls_terms(
    cov: &FdMatrix,
    y: &[f64],
    obs: &[usize],
    x_sum: &[f64],
    y_sum: f64,
    w_s: f64,
    ns: f64,
    inv_e: f64,
    p: usize,
    xtvinvx: &mut [f64],
    xtvinvy: &mut [f64],
) {
    for &i in obs {
        // Element of V⁻¹y: observation minus the shrunken subject mean, scaled.
        let vinv_y = inv_e * (y[i] - w_s * y_sum / ns);
        for r in 0..p {
            xtvinvy[r] += cov[(i, r)] * vinv_y;
            // Fill the upper triangle and mirror into the lower triangle.
            for c in r..p {
                let vinv_xc = inv_e * (cov[(i, c)] - w_s * x_sum[c] / ns);
                let val = cov[(i, r)] * vinv_xc;
                xtvinvx[r * p + c] += val;
                if r != c {
                    xtvinvx[c * p + r] += val;
                }
            }
        }
    }
}
383
384/// REML EM update for variance components.
385///
386/// Returns (σ²_u_new, σ²_e_new) from the conditional expectations.
387/// Uses n - p divisor for σ²_e (REML correction where p = number of fixed effects).
388fn reml_variance_update(
389    residuals: &[f64],
390    ss: &SubjectStructure,
391    weights: &[f64],
392    sigma2_u: f64,
393    p: usize,
394) -> (f64, f64) {
395    let n_subjects = ss.counts.len();
396    let n: usize = ss.counts.iter().sum();
397    let mut sigma2_u_new = 0.0;
398    let mut sigma2_e_new = 0.0;
399
400    for s in 0..n_subjects {
401        let ns = ss.counts[s] as f64;
402        if ns < 1.0 {
403            continue;
404        }
405        let w_s = weights[s];
406        let mean_r_s: f64 = ss.obs[s].iter().map(|&i| residuals[i]).sum::<f64>() / ns;
407        let u_hat_s = w_s * mean_r_s;
408        let cond_var_s = sigma2_u * (1.0 - w_s);
409
410        sigma2_u_new += u_hat_s * u_hat_s + cond_var_s;
411        for &i in &ss.obs[s] {
412            sigma2_e_new += (residuals[i] - u_hat_s).powi(2);
413        }
414        sigma2_e_new += ns * cond_var_s;
415    }
416
417    // REML divisor: n - p for residual variance (matches R's lmer)
418    let denom_e = (n.saturating_sub(p)).max(1) as f64;
419
420    (
421        (sigma2_u_new / n_subjects as f64).max(1e-15),
422        (sigma2_e_new / denom_e).max(1e-15),
423    )
424}
425
426/// Fit scalar mixed model: y_ij = x_i'γ + u_i + e_ij.
427///
428/// Uses iterative GLS for fixed effects + REML EM for variance components,
429/// matching R's lmer() behavior. Initializes from Henderson's ANOVA, then
430/// iterates until convergence.
431fn fit_scalar_mixed_model(
432    y: &[f64],
433    subject_map: &[usize],
434    n_subjects: usize,
435    covariates: Option<&FdMatrix>,
436    p: usize,
437) -> ScalarMixedResult {
438    let n = y.len();
439    let ss = SubjectStructure::new(subject_map, n_subjects, n);
440
441    // Initialize from OLS + Henderson's ANOVA
442    let gamma_init = estimate_fixed_effects(y, covariates, p, n);
443    let residuals_init = compute_ols_residuals(y, covariates, &gamma_init, p, n);
444    let (mut sigma2_u, mut sigma2_e) =
445        estimate_variance_components(&residuals_init, subject_map, n_subjects, n);
446
447    if sigma2_e < 1e-15 {
448        sigma2_e = 1e-6;
449    }
450    if sigma2_u < 1e-15 {
451        sigma2_u = sigma2_e * 0.1;
452    }
453
454    let mut gamma = gamma_init;
455
456    for _iter in 0..50 {
457        let sigma2_u_old = sigma2_u;
458        let sigma2_e_old = sigma2_e;
459
460        let weights = shrinkage_weights(&ss, sigma2_u, sigma2_e);
461
462        if let Some(cov) = covariates.filter(|_| p > 0) {
463            if let Some(g) = gls_update_gamma(cov, p, &ss, &weights, y, sigma2_e) {
464                gamma = g;
465            }
466        }
467
468        let r = compute_ols_residuals(y, covariates, &gamma, p, n);
469        (sigma2_u, sigma2_e) = reml_variance_update(&r, &ss, &weights, sigma2_u, p);
470
471        let delta = (sigma2_u - sigma2_u_old).abs() + (sigma2_e - sigma2_e_old).abs();
472        if delta < 1e-10 * (sigma2_u_old + sigma2_e_old) {
473            break;
474        }
475    }
476
477    let final_residuals = compute_ols_residuals(y, covariates, &gamma, p, n);
478    let u_hat = compute_blup(
479        &final_residuals,
480        subject_map,
481        n_subjects,
482        sigma2_u,
483        sigma2_e,
484    );
485
486    ScalarMixedResult {
487        gamma,
488        u_hat,
489        sigma2_u,
490        sigma2_eps: sigma2_e,
491    }
492}
493
494/// OLS estimation of fixed effects.
495fn estimate_fixed_effects(
496    y: &[f64],
497    covariates: Option<&FdMatrix>,
498    p: usize,
499    n: usize,
500) -> Vec<f64> {
501    if p == 0 || covariates.is_none() {
502        return Vec::new();
503    }
504    let cov = covariates.expect("checked: covariates is Some");
505
506    // Solve (X'X)γ = X'y via Cholesky
507    let mut xtx = vec![0.0; p * p];
508    let mut xty = vec![0.0; p];
509    for i in 0..n {
510        for r in 0..p {
511            xty[r] += cov[(i, r)] * y[i];
512            for s in r..p {
513                let val = cov[(i, r)] * cov[(i, s)];
514                xtx[r * p + s] += val;
515                if r != s {
516                    xtx[s * p + r] += val;
517                }
518            }
519        }
520    }
521    // Regularize
522    for j in 0..p {
523        xtx[j * p + j] += 1e-8;
524    }
525
526    cholesky_solve(&xtx, &xty, p).unwrap_or(vec![0.0; p])
527}
528
/// Cholesky factorization: A = LL'. Returns L (p×p flat row-major) or None if singular.
fn cholesky_factor_famm(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut l = vec![0.0; p * p];
    for j in 0..p {
        // Diagonal: a_jj minus the squared entries already placed in row j.
        let row_sq: f64 = (0..j).map(|k| l[j * p + k] * l[j * p + k]).sum();
        let diag = a[j * p + j] - row_sq;
        if diag <= 0.0 {
            return None; // not positive definite
        }
        l[j * p + j] = diag.sqrt();
        // Below-diagonal entries of column j.
        for i in (j + 1)..p {
            let dot: f64 = (0..j).map(|k| l[i * p + k] * l[j * p + k]).sum();
            l[i * p + j] = (a[i * p + j] - dot) / l[j * p + j];
        }
    }
    Some(l)
}

/// Solve L z = b (forward) then L' x = z (back).
fn cholesky_triangular_solve(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut x = vec![0.0; p];
    // Forward substitution: L z = b.
    for i in 0..p {
        let partial: f64 = (0..i).map(|j| l[i * p + j] * x[j]).sum();
        x[i] = (b[i] - partial) / l[i * p + i];
    }
    // Back substitution: L' x = z (transpose read by swapping indices).
    for i in (0..p).rev() {
        let partial: f64 = ((i + 1)..p).map(|j| l[j * p + i] * x[j]).sum();
        x[i] = (x[i] - partial) / l[i * p + i];
    }
    x
}

/// Cholesky solve: A x = b where A is p×p symmetric positive definite.
fn cholesky_solve(a: &[f64], b: &[f64], p: usize) -> Option<Vec<f64>> {
    cholesky_factor_famm(a, p).map(|l| cholesky_triangular_solve(&l, b, p))
}
578
579/// Compute OLS residuals: r = y - X*gamma.
580fn compute_ols_residuals(
581    y: &[f64],
582    covariates: Option<&FdMatrix>,
583    gamma: &[f64],
584    p: usize,
585    n: usize,
586) -> Vec<f64> {
587    let mut residuals = y.to_vec();
588    if p > 0 {
589        if let Some(cov) = covariates {
590            for i in 0..n {
591                for j in 0..p {
592                    residuals[i] -= cov[(i, j)] * gamma[j];
593                }
594            }
595        }
596    }
597    residuals
598}
599
/// Estimate variance components via method of moments.
///
/// σ²_u and σ²_ε from one-way random effects ANOVA.
fn estimate_variance_components(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    n: usize,
) -> (f64, f64) {
    // Per-subject residual totals and observation counts.
    let mut sums = vec![0.0; n_subjects];
    let mut counts = vec![0usize; n_subjects];
    for i in 0..n {
        let s = subject_map[i];
        sums[s] += residuals[i];
        counts[s] += 1;
    }
    let means: Vec<f64> = sums
        .iter()
        .zip(&counts)
        .map(|(&total, &c)| if c > 0 { total / c as f64 } else { 0.0 })
        .collect();

    // Within-subject sum of squares.
    let mut ss_within = 0.0;
    for i in 0..n {
        ss_within += (residuals[i] - means[subject_map[i]]).powi(2);
    }
    let df_within = n.saturating_sub(n_subjects);

    // Between-subject sum of squares around the grand mean.
    let grand_mean = residuals.iter().sum::<f64>() / n as f64;
    let mut ss_between = 0.0;
    for s in 0..n_subjects {
        ss_between += counts[s] as f64 * (means[s] - grand_mean).powi(2);
    }

    let sigma2_eps = if df_within > 0 {
        ss_within / df_within as f64
    } else {
        1e-6 // degenerate design: fall back to a small positive value
    };

    // Moment estimator: MS_between ≈ σ²_ε + n̄·σ²_u, truncated at zero.
    let n_bar = n as f64 / n_subjects.max(1) as f64;
    let df_between = n_subjects.saturating_sub(1).max(1);
    let ms_between = ss_between / df_between as f64;
    let sigma2_u = ((ms_between - sigma2_eps) / n_bar).max(0.0);

    (sigma2_u, sigma2_eps)
}
652
/// Compute BLUP (Best Linear Unbiased Prediction) for random effects.
///
/// û_i = σ²_u / (σ²_u + σ²_ε/n_i) * (ȳ_i - x̄_i'γ)
fn compute_blup(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    sigma2_u: f64,
    sigma2_eps: f64,
) -> Vec<f64> {
    // Per-subject residual totals and counts.
    let mut sums = vec![0.0; n_subjects];
    let mut counts = vec![0usize; n_subjects];
    for (i, &r) in residuals.iter().enumerate() {
        let s = subject_map[i];
        sums[s] += r;
        counts[s] += 1;
    }

    let mut blup = Vec::with_capacity(n_subjects);
    for s in 0..n_subjects {
        let ni = counts[s] as f64;
        if ni < 1.0 {
            // Unobserved subject: predicted random effect is zero.
            blup.push(0.0);
            continue;
        }
        let mean_r = sums[s] / ni;
        // Shrink the subject mean toward zero; max() guards a zero denominator.
        let shrinkage = sigma2_u / (sigma2_u + sigma2_eps / ni).max(1e-15);
        blup.push(shrinkage * mean_r);
    }
    blup
}
683
684// ---------------------------------------------------------------------------
685// Recovery of functional coefficients
686// ---------------------------------------------------------------------------
687
688/// Recover β̂(t) = Σ_k γ̂_jk φ_k(t) for each covariate j.
689fn recover_beta_functions(
690    gamma: &[Vec<f64>],
691    rotation: &FdMatrix,
692    p: usize,
693    m: usize,
694    k: usize,
695) -> FdMatrix {
696    let mut beta = FdMatrix::zeros(p, m);
697    for j in 0..p {
698        for t in 0..m {
699            let mut val = 0.0;
700            for comp in 0..k {
701                val += gamma[j][comp] * rotation[(t, comp)];
702            }
703            beta[(j, t)] = val;
704        }
705    }
706    beta
707}
708
709/// Recover b̂_i(t) = Σ_k û_ik φ_k(t) for each subject i.
710fn recover_random_effects(
711    u_hat: &[Vec<f64>],
712    rotation: &FdMatrix,
713    n_subjects: usize,
714    m: usize,
715    k: usize,
716) -> FdMatrix {
717    let mut re = FdMatrix::zeros(n_subjects, m);
718    for s in 0..n_subjects {
719        for t in 0..m {
720            let mut val = 0.0;
721            for comp in 0..k {
722                val += u_hat[s][comp] * rotation[(t, comp)];
723            }
724            re[(s, t)] = val;
725        }
726    }
727    re
728}
729
730/// Compute random effect variance function: Var_i(b̂_i(t)).
731fn compute_random_variance(random_effects: &FdMatrix, n_subjects: usize, m: usize) -> Vec<f64> {
732    (0..m)
733        .map(|t| {
734            let mean: f64 =
735                (0..n_subjects).map(|s| random_effects[(s, t)]).sum::<f64>() / n_subjects as f64;
736            let var: f64 = (0..n_subjects)
737                .map(|s| (random_effects[(s, t)] - mean).powi(2))
738                .sum::<f64>()
739                / n_subjects.max(1) as f64;
740            var
741        })
742        .collect()
743}
744
745/// Compute fitted values and residuals.
746fn compute_fitted_residuals(
747    data: &FdMatrix,
748    mean_function: &[f64],
749    beta_functions: &FdMatrix,
750    random_effects: &FdMatrix,
751    covariates: Option<&FdMatrix>,
752    subject_map: &[usize],
753    n_total: usize,
754    m: usize,
755    p: usize,
756) -> (FdMatrix, FdMatrix) {
757    let mut fitted = FdMatrix::zeros(n_total, m);
758    let mut residuals = FdMatrix::zeros(n_total, m);
759
760    for i in 0..n_total {
761        let s = subject_map[i];
762        for t in 0..m {
763            let mut val = mean_function[t] + random_effects[(s, t)];
764            if p > 0 {
765                if let Some(cov) = covariates {
766                    for j in 0..p {
767                        val += cov[(i, j)] * beta_functions[(j, t)];
768                    }
769                }
770            }
771            fitted[(i, t)] = val;
772            residuals[(i, t)] = data[(i, t)] - val;
773        }
774    }
775
776    (fitted, residuals)
777}
778
779// ---------------------------------------------------------------------------
780// Prediction
781// ---------------------------------------------------------------------------
782
783/// Predict curves for new subjects.
784///
785/// # Arguments
786/// * `result` — Fitted FMM result
787/// * `new_covariates` — Covariates for new subjects (n_new × p)
788///
789/// Returns predicted curves (n_new × m) using only fixed effects (no random effects for new subjects).
790#[must_use = "prediction result should not be discarded"]
791pub fn fmm_predict(result: &FmmResult, new_covariates: Option<&FdMatrix>) -> FdMatrix {
792    let m = result.mean_function.len();
793    let n_new = new_covariates.map_or(1, super::matrix::FdMatrix::nrows);
794    let p = result.beta_functions.nrows();
795
796    let mut predicted = FdMatrix::zeros(n_new, m);
797    for i in 0..n_new {
798        for t in 0..m {
799            let mut val = result.mean_function[t];
800            if let Some(cov) = new_covariates {
801                for j in 0..p {
802                    val += cov[(i, j)] * result.beta_functions[(j, t)];
803                }
804            }
805            predicted[(i, t)] = val;
806        }
807    }
808    predicted
809}
810
811// ---------------------------------------------------------------------------
812// Hypothesis testing
813// ---------------------------------------------------------------------------
814
815/// Permutation test for fixed effects in functional mixed model.
816///
817/// Tests H₀: β_j(t) = 0 for each covariate j.
818/// Uses integrated squared norm as test statistic: T_j = ∫ β̂_j(t)² dt.
819///
820/// # Arguments
821/// * `data` — All observed curves (n_total × m)
822/// * `subject_ids` — Subject identifiers
823/// * `covariates` — Subject-level covariates (n_total × p)
824/// * `ncomp` — Number of FPC components
825/// * `n_perm` — Number of permutations
826/// * `seed` — Random seed
827///
828/// # Errors
829///
830/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, or
831/// `covariates` has zero columns.
832/// Propagates errors from [`fmm`] (e.g., dimension mismatches or FPCA failure).
833#[must_use = "expensive computation whose result should not be discarded"]
834pub fn fmm_test_fixed(
835    data: &FdMatrix,
836    subject_ids: &[usize],
837    covariates: &FdMatrix,
838    ncomp: usize,
839    n_perm: usize,
840    seed: u64,
841) -> Result<FmmTestResult, FdarError> {
842    let n_total = data.nrows();
843    let m = data.ncols();
844    let p = covariates.ncols();
845    if n_total == 0 {
846        return Err(FdarError::InvalidDimension {
847            parameter: "data",
848            expected: "non-empty matrix".to_string(),
849            actual: format!("{n_total} rows"),
850        });
851    }
852    if p == 0 {
853        return Err(FdarError::InvalidDimension {
854            parameter: "covariates",
855            expected: "at least 1 column".to_string(),
856            actual: "0 columns".to_string(),
857        });
858    }
859
860    // Fit observed model
861    let result = fmm(data, subject_ids, Some(covariates), ncomp)?;
862
863    // Observed test statistics: ∫ β̂_j(t)² dt for each covariate
864    let observed_stats = compute_integrated_beta_sq(&result.beta_functions, p, m);
865
866    // Permutation test
867    let (f_statistics, p_values) = permutation_test(
868        data,
869        subject_ids,
870        covariates,
871        ncomp,
872        n_perm,
873        seed,
874        &observed_stats,
875        p,
876        m,
877    );
878
879    Ok(FmmTestResult {
880        f_statistics,
881        p_values,
882    })
883}
884
885/// Compute ∫ β̂_j(t)² dt for each covariate.
886fn compute_integrated_beta_sq(beta: &FdMatrix, p: usize, m: usize) -> Vec<f64> {
887    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
888    (0..p)
889        .map(|j| {
890            let ss: f64 = (0..m).map(|t| beta[(j, t)].powi(2)).sum();
891            ss * h
892        })
893        .collect()
894}
895
896/// Run permutation test for fixed effects.
897fn permutation_test(
898    data: &FdMatrix,
899    subject_ids: &[usize],
900    covariates: &FdMatrix,
901    ncomp: usize,
902    n_perm: usize,
903    seed: u64,
904    observed_stats: &[f64],
905    p: usize,
906    m: usize,
907) -> (Vec<f64>, Vec<f64>) {
908    use rand::prelude::*;
909    let n_total = data.nrows();
910    let mut rng = StdRng::seed_from_u64(seed);
911    let mut n_ge = vec![0usize; p];
912
913    for _ in 0..n_perm {
914        // Permute covariates across subjects
915        let mut perm_indices: Vec<usize> = (0..n_total).collect();
916        perm_indices.shuffle(&mut rng);
917        let perm_cov = permute_rows(covariates, &perm_indices);
918
919        if let Ok(perm_result) = fmm(data, subject_ids, Some(&perm_cov), ncomp) {
920            let perm_stats = compute_integrated_beta_sq(&perm_result.beta_functions, p, m);
921            for j in 0..p {
922                if perm_stats[j] >= observed_stats[j] {
923                    n_ge[j] += 1;
924                }
925            }
926        }
927    }
928
929    let p_values: Vec<f64> = n_ge
930        .iter()
931        .map(|&count| (count + 1) as f64 / (n_perm + 1) as f64)
932        .collect();
933    let f_statistics = observed_stats.to_vec();
934
935    (f_statistics, p_values)
936}
937
938/// Permute rows of a matrix according to given indices.
939fn permute_rows(mat: &FdMatrix, indices: &[usize]) -> FdMatrix {
940    let n = indices.len();
941    let m = mat.ncols();
942    let mut result = FdMatrix::zeros(n, m);
943    for (new_i, &old_i) in indices.iter().enumerate() {
944        for j in 0..m {
945            result[(new_i, j)] = mat[(old_i, j)];
946        }
947    }
948    result
949}
950
951// ---------------------------------------------------------------------------
952// Tests
953// ---------------------------------------------------------------------------
954
955#[cfg(test)]
956mod tests {
957    use super::*;
958    use crate::test_helpers::uniform_grid;
959    use std::f64::consts::PI;
960
    /// Generate repeated measurements: n_subjects × n_visits curves.
    /// Subject-level covariate z affects the curve amplitude.
    ///
    /// Returns `(data, subject_ids, covariates, grid)` where `data` is
    /// (n_subjects·n_visits) × m, `covariates` is a single column holding z
    /// repeated across visits, and `grid` comes from `uniform_grid(m)`.
    /// Noise is deterministic (index-based), so the fixture is reproducible.
    fn generate_fmm_data(
        n_subjects: usize,
        n_visits: usize,
        m: usize,
    ) -> (FdMatrix, Vec<usize>, FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            let z = s as f64 / n_subjects as f64; // covariate in [0, 1)
            let subject_effect = 0.5 * (s as f64 - n_subjects as f64 / 2.0); // random-like effect

            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = z;
                let noise_scale = 0.05;

                for (j, &tj) in t.iter().enumerate() {
                    // Y_sv(t) = sin(2πt) + z * t + subject_effect * cos(2πt) + noise
                    let mu = (2.0 * PI * tj).sin();
                    let fixed = z * tj * 3.0;
                    let random = subject_effect * (2.0 * PI * tj).cos() * 0.3;
                    // Deterministic pseudo-noise from the observation/grid indices.
                    let noise = noise_scale * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                    // Column-major layout: entry (obs, j) lives at obs + j * n_total.
                    col_major[obs + j * n_total] = mu + fixed + random + noise;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
        (data, subject_ids, covariates, t)
    }
999
1000    #[test]
1001    fn test_fmm_basic() {
1002        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
1003        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1004
1005        assert_eq!(result.mean_function.len(), 50);
1006        assert_eq!(result.beta_functions.nrows(), 1); // 1 covariate
1007        assert_eq!(result.beta_functions.ncols(), 50);
1008        assert_eq!(result.random_effects.nrows(), 10);
1009        assert_eq!(result.fitted.nrows(), 30);
1010        assert_eq!(result.residuals.nrows(), 30);
1011        assert_eq!(result.n_subjects, 10);
1012    }
1013
1014    #[test]
1015    fn test_fmm_fitted_plus_residuals_equals_data() {
1016        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 40);
1017        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1018
1019        let n = data.nrows();
1020        let m = data.ncols();
1021        for i in 0..n {
1022            for t in 0..m {
1023                let reconstructed = result.fitted[(i, t)] + result.residuals[(i, t)];
1024                assert!(
1025                    (reconstructed - data[(i, t)]).abs() < 1e-8,
1026                    "Fitted + residual should equal data at ({}, {}): {} vs {}",
1027                    i,
1028                    t,
1029                    reconstructed,
1030                    data[(i, t)]
1031                );
1032            }
1033        }
1034    }
1035
1036    #[test]
1037    fn test_fmm_random_variance_positive() {
1038        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
1039        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1040
1041        for &v in &result.random_variance {
1042            assert!(v >= 0.0, "Random variance should be non-negative");
1043        }
1044    }
1045
1046    #[test]
1047    fn test_fmm_no_covariates() {
1048        let (data, subject_ids, _cov, _t) = generate_fmm_data(8, 3, 40);
1049        let result = fmm(&data, &subject_ids, None, 3).unwrap();
1050
1051        assert_eq!(result.beta_functions.nrows(), 0);
1052        assert_eq!(result.n_subjects, 8);
1053        assert_eq!(result.fitted.nrows(), 24);
1054    }
1055
1056    #[test]
1057    fn test_fmm_predict() {
1058        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
1059        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1060
1061        // Predict for new subjects with covariate = 0.5
1062        let new_cov = FdMatrix::from_column_major(vec![0.5], 1, 1).unwrap();
1063        let predicted = fmm_predict(&result, Some(&new_cov));
1064
1065        assert_eq!(predicted.nrows(), 1);
1066        assert_eq!(predicted.ncols(), 50);
1067
1068        // Predicted curve should be reasonable (not NaN or extreme)
1069        for t in 0..50 {
1070            assert!(predicted[(0, t)].is_finite());
1071            assert!(
1072                predicted[(0, t)].abs() < 20.0,
1073                "Predicted value too extreme at t={}: {}",
1074                t,
1075                predicted[(0, t)]
1076            );
1077        }
1078    }
1079
1080    #[test]
1081    fn test_fmm_test_fixed_detects_effect() {
1082        let (data, subject_ids, covariates, _t) = generate_fmm_data(15, 3, 40);
1083
1084        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
1085
1086        assert_eq!(result.f_statistics.len(), 1);
1087        assert_eq!(result.p_values.len(), 1);
1088        assert!(
1089            result.p_values[0] < 0.1,
1090            "Should detect covariate effect, got p={}",
1091            result.p_values[0]
1092        );
1093    }
1094
1095    #[test]
1096    fn test_fmm_test_fixed_no_effect() {
1097        let n_subjects = 10;
1098        let n_visits = 3;
1099        let m = 40;
1100        let t = uniform_grid(m);
1101        let n_total = n_subjects * n_visits;
1102
1103        // No covariate effect: Y = sin(2πt) + noise
1104        let mut col_major = vec![0.0; n_total * m];
1105        let mut subject_ids = vec![0usize; n_total];
1106        let mut cov_data = vec![0.0; n_total];
1107
1108        for s in 0..n_subjects {
1109            for v in 0..n_visits {
1110                let obs = s * n_visits + v;
1111                subject_ids[obs] = s;
1112                cov_data[obs] = s as f64 / n_subjects as f64;
1113                for (j, &tj) in t.iter().enumerate() {
1114                    col_major[obs + j * n_total] =
1115                        (2.0 * PI * tj).sin() + 0.1 * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
1116                }
1117            }
1118        }
1119
1120        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
1121        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
1122
1123        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
1124        assert!(
1125            result.p_values[0] > 0.05,
1126            "Should not detect effect, got p={}",
1127            result.p_values[0]
1128        );
1129    }
1130
1131    #[test]
1132    fn test_fmm_invalid_input() {
1133        let data = FdMatrix::zeros(0, 0);
1134        assert!(fmm(&data, &[], None, 1).is_err());
1135
1136        let data = FdMatrix::zeros(10, 50);
1137        let ids = vec![0; 5]; // wrong length
1138        assert!(fmm(&data, &ids, None, 1).is_err());
1139    }
1140
1141    #[test]
1142    fn test_fmm_single_visit_per_subject() {
1143        let n = 10;
1144        let m = 40;
1145        let t = uniform_grid(m);
1146        let mut col_major = vec![0.0; n * m];
1147        let subject_ids: Vec<usize> = (0..n).collect();
1148
1149        for i in 0..n {
1150            for (j, &tj) in t.iter().enumerate() {
1151                col_major[i + j * n] = (2.0 * PI * tj).sin();
1152            }
1153        }
1154        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
1155
1156        // Should still work with 1 visit per subject
1157        let result = fmm(&data, &subject_ids, None, 2).unwrap();
1158        assert_eq!(result.n_subjects, n);
1159        assert_eq!(result.fitted.nrows(), n);
1160    }
1161
1162    #[test]
1163    fn test_build_subject_map() {
1164        let (map, n) = build_subject_map(&[5, 5, 10, 10, 20]);
1165        assert_eq!(n, 3);
1166        assert_eq!(map, vec![0, 0, 1, 1, 2]);
1167    }
1168
1169    #[test]
1170    fn test_variance_components_positive() {
1171        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
1172        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1173
1174        assert!(result.sigma2_eps >= 0.0);
1175        for &s in &result.sigma2_u {
1176            assert!(s >= 0.0);
1177        }
1178    }
1179
1180    // -------------------------------------------------------------------
1181    // Additional tests
1182    // -------------------------------------------------------------------
1183
1184    #[test]
1185    fn test_fmm_ncomp_zero_returns_error() {
1186        let (data, subject_ids, _cov, _t) = generate_fmm_data(5, 2, 20);
1187        let err = fmm(&data, &subject_ids, None, 0).unwrap_err();
1188        match err {
1189            FdarError::InvalidParameter { parameter, .. } => {
1190                assert_eq!(parameter, "ncomp");
1191            }
1192            other => panic!("Expected InvalidParameter, got {:?}", other),
1193        }
1194    }
1195
1196    #[test]
1197    fn test_fmm_single_component() {
1198        // Fit with only 1 FPC component
1199        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 30);
1200        let result = fmm(&data, &subject_ids, Some(&covariates), 1).unwrap();
1201
1202        assert_eq!(result.ncomp, 1);
1203        assert_eq!(result.sigma2_u.len(), 1);
1204        assert_eq!(result.eigenvalues.len(), 1);
1205        assert_eq!(result.mean_function.len(), 30);
1206        // Fitted + residuals = data
1207        for i in 0..data.nrows() {
1208            for t in 0..data.ncols() {
1209                let diff = (result.fitted[(i, t)] + result.residuals[(i, t)] - data[(i, t)]).abs();
1210                assert!(diff < 1e-8);
1211            }
1212        }
1213    }
1214
1215    #[test]
1216    fn test_fmm_two_subjects() {
1217        // Minimal number of subjects (2) with multiple visits
1218        let n_subjects = 2;
1219        let n_visits = 5;
1220        let m = 20;
1221        let t = uniform_grid(m);
1222        let n_total = n_subjects * n_visits;
1223        let mut col_major = vec![0.0; n_total * m];
1224        let mut subject_ids = vec![0usize; n_total];
1225
1226        for s in 0..n_subjects {
1227            for v in 0..n_visits {
1228                let obs = s * n_visits + v;
1229                subject_ids[obs] = s;
1230                for (j, &tj) in t.iter().enumerate() {
1231                    col_major[obs + j * n_total] =
1232                        (2.0 * PI * tj).sin() + (s as f64) * 0.5 + 0.01 * v as f64;
1233                }
1234            }
1235        }
1236        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
1237        let result = fmm(&data, &subject_ids, None, 2).unwrap();
1238
1239        assert_eq!(result.n_subjects, 2);
1240        assert_eq!(result.random_effects.nrows(), 2);
1241        assert_eq!(result.fitted.nrows(), n_total);
1242    }
1243
1244    #[test]
1245    fn test_fmm_predict_no_covariates() {
1246        let (data, subject_ids, _cov, _t) = generate_fmm_data(6, 3, 30);
1247        let result = fmm(&data, &subject_ids, None, 2).unwrap();
1248
1249        // Predict without covariates — should return mean function
1250        let predicted = fmm_predict(&result, None);
1251        assert_eq!(predicted.nrows(), 1);
1252        assert_eq!(predicted.ncols(), 30);
1253        for t in 0..30 {
1254            let diff = (predicted[(0, t)] - result.mean_function[t]).abs();
1255            assert!(
1256                diff < 1e-12,
1257                "Without covariates, prediction should equal mean"
1258            );
1259        }
1260    }
1261
1262    #[test]
1263    fn test_fmm_predict_multiple_new_subjects() {
1264        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 40);
1265        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1266
1267        // Predict for 3 new subjects with different covariate values
1268        let new_cov = FdMatrix::from_column_major(vec![0.1, 0.5, 0.9], 3, 1).unwrap();
1269        let predicted = fmm_predict(&result, Some(&new_cov));
1270
1271        assert_eq!(predicted.nrows(), 3);
1272        assert_eq!(predicted.ncols(), 40);
1273
1274        // All predictions should be finite
1275        for i in 0..3 {
1276            for t in 0..40 {
1277                assert!(predicted[(i, t)].is_finite());
1278            }
1279        }
1280
1281        // Predictions for different covariates should differ
1282        let diff_01: f64 = (0..40)
1283            .map(|t| (predicted[(0, t)] - predicted[(1, t)]).powi(2))
1284            .sum();
1285        assert!(
1286            diff_01 > 1e-10,
1287            "Different covariates should yield different predictions"
1288        );
1289    }
1290
1291    #[test]
1292    fn test_fmm_eigenvalues_decreasing() {
1293        let (data, subject_ids, _cov, _t) = generate_fmm_data(10, 3, 50);
1294        let result = fmm(&data, &subject_ids, None, 5).unwrap();
1295
1296        // Eigenvalues should be in decreasing order (from FPCA)
1297        for i in 1..result.eigenvalues.len() {
1298            assert!(
1299                result.eigenvalues[i] <= result.eigenvalues[i - 1] + 1e-10,
1300                "Eigenvalues should be non-increasing: {} > {}",
1301                result.eigenvalues[i],
1302                result.eigenvalues[i - 1]
1303            );
1304        }
1305    }
1306
1307    #[test]
1308    fn test_fmm_random_effects_sum_near_zero() {
1309        // Random effects should approximately sum to zero across subjects
1310        let (data, subject_ids, covariates, _t) = generate_fmm_data(20, 3, 40);
1311        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1312
1313        let m = result.mean_function.len();
1314        for t in 0..m {
1315            let sum: f64 = (0..result.n_subjects)
1316                .map(|s| result.random_effects[(s, t)])
1317                .sum();
1318            let mean_abs: f64 = (0..result.n_subjects)
1319                .map(|s| result.random_effects[(s, t)].abs())
1320                .sum::<f64>()
1321                / result.n_subjects as f64;
1322            // Relative to the scale of random effects, the sum should be small
1323            if mean_abs > 1e-10 {
1324                assert!(
1325                    (sum / result.n_subjects as f64).abs() < mean_abs * 2.0,
1326                    "Random effects should roughly center around zero at t={}: sum={}, mean_abs={}",
1327                    t,
1328                    sum,
1329                    mean_abs
1330                );
1331            }
1332        }
1333    }
1334
1335    #[test]
1336    fn test_fmm_subject_ids_mismatch_error() {
1337        let data = FdMatrix::zeros(10, 20);
1338        let ids = vec![0; 7]; // wrong length
1339        let err = fmm(&data, &ids, None, 1).unwrap_err();
1340        match err {
1341            FdarError::InvalidDimension { parameter, .. } => {
1342                assert_eq!(parameter, "subject_ids");
1343            }
1344            other => panic!("Expected InvalidDimension, got {:?}", other),
1345        }
1346    }
1347
1348    #[test]
1349    fn test_fmm_test_fixed_empty_data_error() {
1350        let data = FdMatrix::zeros(0, 0);
1351        let covariates = FdMatrix::zeros(0, 1);
1352        let err = fmm_test_fixed(&data, &[], &covariates, 1, 10, 42).unwrap_err();
1353        match err {
1354            FdarError::InvalidDimension { parameter, .. } => {
1355                assert_eq!(parameter, "data");
1356            }
1357            other => panic!("Expected InvalidDimension for data, got {:?}", other),
1358        }
1359    }
1360
1361    #[test]
1362    fn test_fmm_test_fixed_zero_covariates_error() {
1363        let data = FdMatrix::zeros(10, 20);
1364        let ids = vec![0; 10];
1365        let covariates = FdMatrix::zeros(10, 0);
1366        let err = fmm_test_fixed(&data, &ids, &covariates, 1, 10, 42).unwrap_err();
1367        match err {
1368            FdarError::InvalidDimension { parameter, .. } => {
1369                assert_eq!(parameter, "covariates");
1370            }
1371            other => panic!("Expected InvalidDimension for covariates, got {:?}", other),
1372        }
1373    }
1374
1375    #[test]
1376    fn test_build_subject_map_single_subject() {
1377        let (map, n) = build_subject_map(&[42, 42, 42]);
1378        assert_eq!(n, 1);
1379        assert_eq!(map, vec![0, 0, 0]);
1380    }
1381
1382    #[test]
1383    fn test_build_subject_map_non_contiguous_ids() {
1384        let (map, n) = build_subject_map(&[100, 200, 100, 300, 200]);
1385        assert_eq!(n, 3);
1386        // sorted unique: [100, 200, 300] -> indices [0, 1, 2]
1387        assert_eq!(map, vec![0, 1, 0, 2, 1]);
1388    }
1389
1390    #[test]
1391    fn test_fmm_many_components_clamped() {
1392        // Request more components than available; FPCA should clamp
1393        let (data, subject_ids, _cov, _t) = generate_fmm_data(5, 3, 20);
1394        let n_total = data.nrows();
1395        // Request 100 components — should be clamped to min(n_total, m) - 1
1396        let result = fmm(&data, &subject_ids, None, 100).unwrap();
1397        assert!(
1398            result.ncomp <= n_total.min(20),
1399            "ncomp should be clamped: got {}",
1400            result.ncomp
1401        );
1402        assert!(result.ncomp >= 1);
1403    }
1404
1405    #[test]
1406    fn test_fmm_residuals_small_with_enough_components() {
1407        // With enough components, residuals should be small relative to data
1408        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 30);
1409        let result = fmm(&data, &subject_ids, Some(&covariates), 5).unwrap();
1410
1411        let n = data.nrows();
1412        let m = data.ncols();
1413        let mut data_ss = 0.0_f64;
1414        let mut resid_ss = 0.0_f64;
1415        for i in 0..n {
1416            for t in 0..m {
1417                data_ss += data[(i, t)].powi(2);
1418                resid_ss += result.residuals[(i, t)].powi(2);
1419            }
1420        }
1421
1422        // R-squared should be reasonably high for structured data
1423        let r_squared = 1.0 - resid_ss / data_ss;
1424        assert!(
1425            r_squared > 0.5,
1426            "R-squared should be high with enough components: {}",
1427            r_squared
1428        );
1429    }
1430}