// fdars_core/famm.rs
//! Functional Additive Mixed Models (FAMM).
//!
//! Implements functional mixed effects models for repeated functional
//! measurements with subject-level covariates.
//!
//! Model: `Y_ij(t) = μ(t) + X_i'β(t) + b_i(t) + ε_ij(t)`
//!
//! Key functions:
//! - [`fmm`] — Fit a functional mixed model via FPC decomposition
//! - [`fmm_predict`] — Predict curves for new subjects
//! - [`fmm_test_fixed`] — Hypothesis test on fixed effects
13use crate::error::FdarError;
14use crate::iter_maybe_parallel;
15use crate::matrix::FdMatrix;
16use crate::regression::fdata_to_pc_1d;
17#[cfg(feature = "parallel")]
18use rayon::iter::ParallelIterator;
19
/// Result of a functional mixed model fit.
///
/// Produced by [`fmm`]; pass to [`fmm_predict`] to predict curves for new subjects.
#[derive(Debug, Clone, PartialEq)]
pub struct FmmResult {
    /// Overall mean function μ̂(t) (length m)
    pub mean_function: Vec<f64>,
    /// Fixed effect coefficient functions β̂_j(t) (p × m matrix, one row per covariate)
    pub beta_functions: FdMatrix,
    /// Random effect functions b̂_i(t) per subject (n_subjects × m)
    pub random_effects: FdMatrix,
    /// Fitted values for all observations (n_total × m)
    pub fitted: FdMatrix,
    /// Residuals (n_total × m): observed minus fitted
    pub residuals: FdMatrix,
    /// Variance of random effects at each time point (length m)
    pub random_variance: Vec<f64>,
    /// Residual variance estimate (averaged across FPC components)
    pub sigma2_eps: f64,
    /// Random effect variance estimate (per-component)
    pub sigma2_u: Vec<f64>,
    /// Number of FPC components actually used (may be fewer than requested)
    pub ncomp: usize,
    /// Number of subjects
    pub n_subjects: usize,
    /// FPC eigenvalues (singular values squared / n)
    pub eigenvalues: Vec<f64>,
}
46
/// Result of fixed effect hypothesis test.
///
/// Produced by [`fmm_test_fixed`]. The "F-statistics" here are integrated
/// squared norms T_j = ∫ β̂_j(t)² dt, compared against a permutation null.
#[derive(Debug, Clone, PartialEq)]
pub struct FmmTestResult {
    /// Test statistic per covariate (length p): T_j = ∫ β̂_j(t)² dt
    pub f_statistics: Vec<f64>,
    /// P-values per covariate (via permutation, length p)
    pub p_values: Vec<f64>,
}
55
56// ---------------------------------------------------------------------------
57// Core FMM algorithm
58// ---------------------------------------------------------------------------
59
/// Fit a functional mixed model via FPC decomposition.
///
/// # Arguments
/// * `data` — All observed curves (n_total × m), stacked across subjects and visits
/// * `subject_ids` — Subject identifier for each curve (length n_total)
/// * `covariates` — Subject-level covariates (n_total × p).
///   Each row corresponds to the same curve in `data`.
///   If a covariate is subject-level, its value should be repeated across visits.
/// * `ncomp` — Number of FPC components
///
/// # Algorithm
/// 1. Pool curves, compute FPCA
/// 2. For each FPC score, fit scalar mixed model: ξ_ijk = x_i'γ_k + u_ik + e_ijk
/// 3. Recover β̂(t) and b̂_i(t) from component coefficients
///
/// # Errors
///
/// Returns [`FdarError::InvalidDimension`] if `data` is empty (zero rows or
/// columns), or if `subject_ids.len()` does not match the number of rows.
/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
/// Returns [`FdarError::ComputationFailed`] if the underlying FPCA fails.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn fmm(
    data: &FdMatrix,
    subject_ids: &[usize],
    covariates: Option<&FdMatrix>,
    ncomp: usize,
) -> Result<FmmResult, FdarError> {
    let n_total = data.nrows();
    let m = data.ncols();
    // --- Input validation ---------------------------------------------------
    if n_total == 0 || m == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "non-empty matrix".to_string(),
            actual: format!("{n_total} x {m}"),
        });
    }
    if subject_ids.len() != n_total {
        return Err(FdarError::InvalidDimension {
            parameter: "subject_ids",
            expected: format!("length {n_total}"),
            actual: format!("length {}", subject_ids.len()),
        });
    }
    if ncomp == 0 {
        return Err(FdarError::InvalidParameter {
            parameter: "ncomp",
            message: "must be >= 1".to_string(),
        });
    }

    // Determine unique subjects (observation index -> dense subject index)
    let (subject_map, n_subjects) = build_subject_map(subject_ids);

    // Step 1: FPCA on pooled data.
    // Grid is assumed equally spaced on [0, 1]; `(m - 1).max(1)` guards m == 1.
    // NOTE(review): presumably fdata_to_pc_1d uses the same grid convention — confirm.
    let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1).max(1) as f64).collect();
    let fpca = fdata_to_pc_1d(data, ncomp, &argvals)?;
    let k = fpca.scores.ncols(); // actual number of components (may be < ncomp)

    // Step 2: For each FPC score, fit scalar mixed model (parallelized)
    let p = covariates.map_or(0, super::matrix::FdMatrix::ncols);
    let ComponentResults {
        gamma,
        u_hat,
        sigma2_u,
        sigma2_eps,
    } = fit_all_components(
        &fpca.scores,
        &subject_map,
        n_subjects,
        covariates,
        p,
        k,
        n_total,
        m,
    );

    // Step 3: Recover functional coefficients (using gamma in original scale)
    let beta_functions = recover_beta_functions(&gamma, &fpca.rotation, p, m, k);
    let random_effects = recover_random_effects(&u_hat, &fpca.rotation, n_subjects, m, k);

    // Compute random variance function: Var(b_i(t)) across subjects
    let random_variance = compute_random_variance(&random_effects, n_subjects, m);

    // Compute fitted values and residuals for every observation
    let (fitted, residuals) = compute_fitted_residuals(
        data,
        &fpca.mean,
        &beta_functions,
        &random_effects,
        covariates,
        &subject_map,
        n_total,
        m,
        p,
    );

    // Eigenvalue scaling matches the struct contract: λ_k = sv_k² / n_total.
    let eigenvalues: Vec<f64> = fpca
        .singular_values
        .iter()
        .map(|&sv| sv * sv / n_total as f64)
        .collect();

    Ok(FmmResult {
        mean_function: fpca.mean,
        beta_functions,
        random_effects,
        fitted,
        residuals,
        random_variance,
        sigma2_eps,
        sigma2_u,
        ncomp: k,
        n_subjects,
        eigenvalues,
    })
}
177
/// Build mapping from observation index to subject index (0..n_subjects-1).
///
/// Subjects are numbered by ascending original id, so the numbering is
/// deterministic regardless of observation order.
///
/// Uses `binary_search` on the sorted, deduplicated id list, making
/// construction O(n log s) instead of the previous O(n·s) linear scan
/// per observation.
fn build_subject_map(subject_ids: &[usize]) -> (Vec<usize>, usize) {
    // Sorted unique ids define the dense subject numbering.
    let mut unique_ids: Vec<usize> = subject_ids.to_vec();
    unique_ids.sort_unstable();
    unique_ids.dedup();
    let n_subjects = unique_ids.len();

    // Every id is present in `unique_ids` by construction, so the search
    // cannot fail; fall back to 0 defensively (matches previous behavior).
    let map: Vec<usize> = subject_ids
        .iter()
        .map(|id| unique_ids.binary_search(id).unwrap_or(0))
        .collect();

    (map, n_subjects)
}
192
/// Aggregated results from fitting all FPC components.
///
/// `gamma` and `u_hat` are stored in the ORIGINAL (unnormalized) score space;
/// see `fit_all_components`, which divides by `score_scale` when unpacking.
struct ComponentResults {
    gamma: Vec<Vec<f64>>, // gamma[j][k] = fixed effect coeff j for component k
    u_hat: Vec<Vec<f64>>, // u_hat[i][k] = random effect for subject i, component k
    sigma2_u: Vec<f64>,   // per-component random effect variance
    sigma2_eps: f64,      // average residual variance across components
}
200
/// Fit scalar mixed models for all FPC components (parallelized across components).
///
/// For each component k, scales FPC scores to L²-normalized space, fits a scalar
/// mixed model, then scales coefficients back to the original score space.
#[allow(clippy::too_many_arguments)]
fn fit_all_components(
    scores: &FdMatrix,
    subject_map: &[usize],
    n_subjects: usize,
    covariates: Option<&FdMatrix>,
    p: usize,
    k: usize,
    n_total: usize,
    m: usize,
) -> ComponentResults {
    // Normalize scores by sqrt(h) to match R's L²-weighted FPCA convention.
    // This ensures variance components are on the same scale as R's lmer().
    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
    let score_scale = h.sqrt();

    // Fit each component independently — parallelized when the feature is enabled
    let per_comp: Vec<ScalarMixedResult> = iter_maybe_parallel!(0..k)
        .map(|comp| {
            // Extract the rescaled score column for this component.
            let comp_scores: Vec<f64> = (0..n_total)
                .map(|i| scores[(i, comp)] * score_scale)
                .collect();
            fit_scalar_mixed_model(&comp_scores, subject_map, n_subjects, covariates, p)
        })
        .collect();

    // Unpack per-component results into the aggregate structure.
    // Effects are divided by score_scale to return to the original score space;
    // variance components stay in the normalized scale.
    let mut gamma = vec![vec![0.0; k]; p];
    let mut u_hat = vec![vec![0.0; k]; n_subjects];
    let mut sigma2_u = vec![0.0; k];
    let mut sigma2_eps_total = 0.0;

    for (comp, result) in per_comp.iter().enumerate() {
        for j in 0..p {
            gamma[j][comp] = result.gamma[j] / score_scale;
        }
        for s in 0..n_subjects {
            u_hat[s][comp] = result.u_hat[s] / score_scale;
        }
        sigma2_u[comp] = result.sigma2_u;
        sigma2_eps_total += result.sigma2_eps;
    }
    // Average residual variance across components.
    // NOTE(review): assumes k >= 1 (fmm validates ncomp); k == 0 would yield NaN — confirm
    // fdata_to_pc_1d never returns zero components.
    let sigma2_eps = sigma2_eps_total / k as f64;

    ComponentResults {
        gamma,
        u_hat,
        sigma2_u,
        sigma2_eps,
    }
}
256
/// Scalar mixed model result for one FPC component.
///
/// All quantities are in the (possibly rescaled) space of the response
/// vector passed to `fit_scalar_mixed_model`.
struct ScalarMixedResult {
    gamma: Vec<f64>, // fixed effects (length p)
    u_hat: Vec<f64>, // random effects per subject (length n_subjects)
    sigma2_u: f64,   // random effect variance
    sigma2_eps: f64, // residual variance
}
264
/// Precomputed subject structure for the mixed model.
///
/// `counts[s]` is the number of observations for subject `s`;
/// `obs[s]` lists the observation indices belonging to subject `s`.
struct SubjectStructure {
    counts: Vec<usize>,
    obs: Vec<Vec<usize>>,
}

impl SubjectStructure {
    /// Group the first `n` observations by their subject index.
    fn new(subject_map: &[usize], n_subjects: usize, n: usize) -> Self {
        let mut counts = vec![0usize; n_subjects];
        let mut obs: Vec<Vec<usize>> = vec![Vec::new(); n_subjects];
        for (i, &subject) in subject_map[..n].iter().enumerate() {
            counts[subject] += 1;
            obs[subject].push(i);
        }
        Self { counts, obs }
    }
}
283
284/// Compute shrinkage weights: w_s = σ²_u / (σ²_u + σ²_e / n_s).
285fn shrinkage_weights(ss: &SubjectStructure, sigma2_u: f64, sigma2_e: f64) -> Vec<f64> {
286    ss.counts
287        .iter()
288        .map(|&c| {
289            let ns = c as f64;
290            if ns < 1.0 {
291                0.0
292            } else {
293                sigma2_u / (sigma2_u + sigma2_e / ns)
294            }
295        })
296        .collect()
297}
298
/// GLS fixed effect update using block-diagonal V^{-1}.
///
/// Computes γ = (X'V⁻¹X)⁻¹ X'V⁻¹y exploiting the balanced random intercept structure.
///
/// Returns `None` when the (ridge-regularized) normal-equation matrix fails the
/// Cholesky factorization (not positive definite).
fn gls_update_gamma(
    cov: &FdMatrix,
    p: usize,
    ss: &SubjectStructure,
    weights: &[f64],
    y: &[f64],
    sigma2_e: f64,
) -> Option<Vec<f64>> {
    let n_subjects = ss.counts.len();
    let mut xtvinvx = vec![0.0; p * p];
    let mut xtvinvy = vec![0.0; p];
    let inv_e = 1.0 / sigma2_e;

    // V is block-diagonal by subject, so accumulate one block at a time.
    for s in 0..n_subjects {
        let ns = ss.counts[s] as f64;
        if ns < 1.0 {
            // Subject with no observations contributes nothing.
            continue;
        }
        let (x_sum, y_sum) = subject_sums(cov, y, &ss.obs[s], p);
        accumulate_gls_terms(
            cov,
            y,
            &ss.obs[s],
            &x_sum,
            y_sum,
            weights[s],
            ns,
            inv_e,
            p,
            &mut xtvinvx,
            &mut xtvinvy,
        );
    }

    // Tiny ridge on the diagonal keeps the factorization stable when
    // covariates are (near-)collinear.
    for j in 0..p {
        xtvinvx[j * p + j] += 1e-10;
    }
    cholesky_solve(&xtvinvx, &xtvinvy, p)
}
341
342/// Compute subject-level covariate sums and response sum.
343fn subject_sums(cov: &FdMatrix, y: &[f64], obs: &[usize], p: usize) -> (Vec<f64>, f64) {
344    let mut x_sum = vec![0.0; p];
345    let mut y_sum = 0.0;
346    for &i in obs {
347        for r in 0..p {
348            x_sum[r] += cov[(i, r)];
349        }
350        y_sum += y[i];
351    }
352    (x_sum, y_sum)
353}
354
/// Accumulate X'V^{-1}X and X'V^{-1}y for one subject.
///
/// Uses the closed form of V⁻¹ for a random-intercept block
/// (V_s = σ²_e I + σ²_u J): V_s⁻¹ x = (x − (w_s/n_s)·Σx) / σ²_e, where
/// w_s = n_s σ²_u / (n_s σ²_u + σ²_e) is the shrinkage weight — so no
/// per-subject matrix inversion is required.
fn accumulate_gls_terms(
    cov: &FdMatrix,
    y: &[f64],
    obs: &[usize],
    x_sum: &[f64],
    y_sum: f64,
    w_s: f64,
    ns: f64,
    inv_e: f64,
    p: usize,
    xtvinvx: &mut [f64],
    xtvinvy: &mut [f64],
) {
    for &i in obs {
        // V⁻¹ applied to the response at this observation.
        let vinv_y = inv_e * (y[i] - w_s * y_sum / ns);
        for r in 0..p {
            xtvinvy[r] += cov[(i, r)] * vinv_y;
            // Compute the upper triangle only, mirroring into the lower.
            for c in r..p {
                let vinv_xc = inv_e * (cov[(i, c)] - w_s * x_sum[c] / ns);
                let val = cov[(i, r)] * vinv_xc;
                xtvinvx[r * p + c] += val;
                if r != c {
                    xtvinvx[c * p + r] += val;
                }
            }
        }
    }
}
384
/// REML EM update for variance components.
///
/// Returns (σ²_u_new, σ²_e_new) from the conditional expectations.
/// Uses n - p divisor for σ²_e (REML correction where p = number of fixed effects).
fn reml_variance_update(
    residuals: &[f64],
    ss: &SubjectStructure,
    weights: &[f64],
    sigma2_u: f64,
    p: usize,
) -> (f64, f64) {
    let n_subjects = ss.counts.len();
    let n: usize = ss.counts.iter().sum();
    let mut sigma2_u_new = 0.0;
    let mut sigma2_e_new = 0.0;

    for s in 0..n_subjects {
        let ns = ss.counts[s] as f64;
        if ns < 1.0 {
            // Subjects without observations contribute nothing to either sum.
            continue;
        }
        let w_s = weights[s];
        // E-step: conditional mean (BLUP) and conditional variance of u_s
        // given the current variance components.
        let mean_r_s: f64 = ss.obs[s].iter().map(|&i| residuals[i]).sum::<f64>() / ns;
        let u_hat_s = w_s * mean_r_s;
        let cond_var_s = sigma2_u * (1.0 - w_s);

        // M-step accumulators: E[u_s²] and E[Σ_j (r_sj − u_s)²].
        sigma2_u_new += u_hat_s * u_hat_s + cond_var_s;
        for &i in &ss.obs[s] {
            sigma2_e_new += (residuals[i] - u_hat_s).powi(2);
        }
        sigma2_e_new += ns * cond_var_s;
    }

    // REML divisor: n - p for residual variance (matches R's lmer)
    let denom_e = (n.saturating_sub(p)).max(1) as f64;

    // Floor both variances so downstream divisions stay well-defined.
    (
        (sigma2_u_new / n_subjects as f64).max(1e-15),
        (sigma2_e_new / denom_e).max(1e-15),
    )
}
426
/// Fit scalar mixed model: y_ij = x_i'γ + u_i + e_ij.
///
/// Uses iterative GLS for fixed effects + REML EM for variance components,
/// matching R's lmer() behavior. Initializes from Henderson's ANOVA, then
/// iterates until convergence.
fn fit_scalar_mixed_model(
    y: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    covariates: Option<&FdMatrix>,
    p: usize,
) -> ScalarMixedResult {
    let n = y.len();
    let ss = SubjectStructure::new(subject_map, n_subjects, n);

    // Initialize from OLS + Henderson's ANOVA (method of moments)
    let gamma_init = estimate_fixed_effects(y, covariates, p, n);
    let residuals_init = compute_ols_residuals(y, covariates, &gamma_init, p, n);
    let (mut sigma2_u, mut sigma2_e) =
        estimate_variance_components(&residuals_init, subject_map, n_subjects, n);

    // Guard against degenerate starting values — zero variances would stall
    // the EM iteration (shrinkage weights become 0/0-like).
    if sigma2_e < 1e-15 {
        sigma2_e = 1e-6;
    }
    if sigma2_u < 1e-15 {
        sigma2_u = sigma2_e * 0.1;
    }

    let mut gamma = gamma_init;

    // Alternate GLS (fixed effects) and REML EM (variances), capped at 50 iterations.
    for _iter in 0..50 {
        let sigma2_u_old = sigma2_u;
        let sigma2_e_old = sigma2_e;

        let weights = shrinkage_weights(&ss, sigma2_u, sigma2_e);

        // GLS update for γ only when covariates are present; keep the previous
        // γ if the solve fails (singular normal equations).
        if let Some(cov) = covariates.filter(|_| p > 0) {
            if let Some(g) = gls_update_gamma(cov, p, &ss, &weights, y, sigma2_e) {
                gamma = g;
            }
        }

        let r = compute_ols_residuals(y, covariates, &gamma, p, n);
        (sigma2_u, sigma2_e) = reml_variance_update(&r, &ss, &weights, sigma2_u, p);

        // Relative convergence criterion on the total variance change.
        let delta = (sigma2_u - sigma2_u_old).abs() + (sigma2_e - sigma2_e_old).abs();
        if delta < 1e-10 * (sigma2_u_old + sigma2_e_old) {
            break;
        }
    }

    // Final BLUPs at the converged variance components.
    let final_residuals = compute_ols_residuals(y, covariates, &gamma, p, n);
    let u_hat = compute_blup(
        &final_residuals,
        subject_map,
        n_subjects,
        sigma2_u,
        sigma2_e,
    );

    ScalarMixedResult {
        gamma,
        u_hat,
        sigma2_u,
        sigma2_eps: sigma2_e,
    }
}
494
495/// OLS estimation of fixed effects.
496fn estimate_fixed_effects(
497    y: &[f64],
498    covariates: Option<&FdMatrix>,
499    p: usize,
500    n: usize,
501) -> Vec<f64> {
502    if p == 0 || covariates.is_none() {
503        return Vec::new();
504    }
505    let cov = covariates.expect("checked: covariates is Some");
506
507    // Solve (X'X)γ = X'y via Cholesky
508    let mut xtx = vec![0.0; p * p];
509    let mut xty = vec![0.0; p];
510    for i in 0..n {
511        for r in 0..p {
512            xty[r] += cov[(i, r)] * y[i];
513            for s in r..p {
514                let val = cov[(i, r)] * cov[(i, s)];
515                xtx[r * p + s] += val;
516                if r != s {
517                    xtx[s * p + r] += val;
518                }
519            }
520        }
521    }
522    // Regularize
523    for j in 0..p {
524        xtx[j * p + j] += 1e-8;
525    }
526
527    cholesky_solve(&xtx, &xty, p).unwrap_or(vec![0.0; p])
528}
529
/// Cholesky factorization: A = LL'. Returns L (p×p flat row-major) or None if singular.
fn cholesky_factor_famm(a: &[f64], p: usize) -> Option<Vec<f64>> {
    let mut l = vec![0.0; p * p];
    for j in 0..p {
        // Diagonal entry: sqrt(a_jj − Σ_k l_jk²); non-positive ⇒ not SPD.
        let sum_sq: f64 = (0..j).map(|k| l[j * p + k] * l[j * p + k]).sum();
        let diag = a[j * p + j] - sum_sq;
        if diag <= 0.0 {
            return None;
        }
        let l_jj = diag.sqrt();
        l[j * p + j] = l_jj;
        // Fill column j below the diagonal.
        for i in (j + 1)..p {
            let dot: f64 = (0..j).map(|k| l[i * p + k] * l[j * p + k]).sum();
            l[i * p + j] = (a[i * p + j] - dot) / l_jj;
        }
    }
    Some(l)
}
553
/// Solve L z = b (forward substitution) then L' x = z (back substitution).
fn cholesky_triangular_solve(l: &[f64], b: &[f64], p: usize) -> Vec<f64> {
    let mut x = vec![0.0; p];
    // Forward pass: L z = b.
    for i in 0..p {
        let partial: f64 = (0..i).map(|j| l[i * p + j] * x[j]).sum();
        x[i] = (b[i] - partial) / l[i * p + i];
    }
    // Backward pass: L' x = z, overwriting in place.
    for i in (0..p).rev() {
        let partial: f64 = ((i + 1)..p).map(|j| l[j * p + i] * x[j]).sum();
        x[i] = (x[i] - partial) / l[i * p + i];
    }
    x
}
573
574/// Cholesky solve: A x = b where A is p×p symmetric positive definite.
575fn cholesky_solve(a: &[f64], b: &[f64], p: usize) -> Option<Vec<f64>> {
576    let l = cholesky_factor_famm(a, p)?;
577    Some(cholesky_triangular_solve(&l, b, p))
578}
579
580/// Compute OLS residuals: r = y - X*gamma.
581fn compute_ols_residuals(
582    y: &[f64],
583    covariates: Option<&FdMatrix>,
584    gamma: &[f64],
585    p: usize,
586    n: usize,
587) -> Vec<f64> {
588    let mut residuals = y.to_vec();
589    if p > 0 {
590        if let Some(cov) = covariates {
591            for i in 0..n {
592                for j in 0..p {
593                    residuals[i] -= cov[(i, j)] * gamma[j];
594                }
595            }
596        }
597    }
598    residuals
599}
600
/// Estimate variance components via method of moments.
///
/// One-way random-effects ANOVA decomposition of the residuals:
/// σ²_ε from the within-subject mean square; σ²_u from the between-subject
/// mean square corrected by the average group size (MS_between ≈ σ²_ε + n̄σ²_u).
fn estimate_variance_components(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    n: usize,
) -> (f64, f64) {
    // Per-subject totals and group sizes.
    let mut totals = vec![0.0; n_subjects];
    let mut sizes = vec![0usize; n_subjects];
    for i in 0..n {
        let s = subject_map[i];
        totals[s] += residuals[i];
        sizes[s] += 1;
    }
    let means: Vec<f64> = totals
        .iter()
        .zip(&sizes)
        .map(|(&total, &size)| if size > 0 { total / size as f64 } else { 0.0 })
        .collect();

    // Within-subject sum of squares.
    let ss_within: f64 = (0..n)
        .map(|i| (residuals[i] - means[subject_map[i]]).powi(2))
        .sum();
    let df_within = n.saturating_sub(n_subjects);

    // Between-subject sum of squares around the grand mean.
    let grand_mean = residuals.iter().sum::<f64>() / n as f64;
    let ss_between: f64 = (0..n_subjects)
        .map(|s| sizes[s] as f64 * (means[s] - grand_mean).powi(2))
        .sum();

    let sigma2_eps = if df_within > 0 {
        ss_within / df_within as f64
    } else {
        1e-6
    };

    // Solve MS_between = σ²_ε + n̄·σ²_u for σ²_u, clamped at zero.
    let n_bar = n as f64 / n_subjects.max(1) as f64;
    let df_between = n_subjects.saturating_sub(1).max(1);
    let ms_between = ss_between / df_between as f64;
    let sigma2_u = ((ms_between - sigma2_eps) / n_bar).max(0.0);

    (sigma2_u, sigma2_eps)
}
653
/// Compute BLUP (Best Linear Unbiased Prediction) for random effects.
///
/// û_i = σ²_u / (σ²_u + σ²_ε/n_i) · r̄_i, where r̄_i is the mean residual of
/// subject i. Subjects with no observations get û_i = 0.
fn compute_blup(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    sigma2_u: f64,
    sigma2_eps: f64,
) -> Vec<f64> {
    // Accumulate per-subject residual totals and counts.
    let mut totals = vec![0.0; n_subjects];
    let mut sizes = vec![0usize; n_subjects];
    for (i, &r) in residuals.iter().enumerate() {
        let subject = subject_map[i];
        totals[subject] += r;
        sizes[subject] += 1;
    }

    totals
        .iter()
        .zip(&sizes)
        .map(|(&total, &size)| {
            if size == 0 {
                return 0.0;
            }
            let ni = size as f64;
            let mean_residual = total / ni;
            // Denominator floored to avoid division by (near-)zero variances.
            let shrinkage = sigma2_u / (sigma2_u + sigma2_eps / ni).max(1e-15);
            shrinkage * mean_residual
        })
        .collect()
}
684
685// ---------------------------------------------------------------------------
686// Recovery of functional coefficients
687// ---------------------------------------------------------------------------
688
689/// Recover β̂(t) = Σ_k γ̂_jk φ_k(t) for each covariate j.
690fn recover_beta_functions(
691    gamma: &[Vec<f64>],
692    rotation: &FdMatrix,
693    p: usize,
694    m: usize,
695    k: usize,
696) -> FdMatrix {
697    let mut beta = FdMatrix::zeros(p, m);
698    for j in 0..p {
699        for t in 0..m {
700            let mut val = 0.0;
701            for comp in 0..k {
702                val += gamma[j][comp] * rotation[(t, comp)];
703            }
704            beta[(j, t)] = val;
705        }
706    }
707    beta
708}
709
710/// Recover b̂_i(t) = Σ_k û_ik φ_k(t) for each subject i.
711fn recover_random_effects(
712    u_hat: &[Vec<f64>],
713    rotation: &FdMatrix,
714    n_subjects: usize,
715    m: usize,
716    k: usize,
717) -> FdMatrix {
718    let mut re = FdMatrix::zeros(n_subjects, m);
719    for s in 0..n_subjects {
720        for t in 0..m {
721            let mut val = 0.0;
722            for comp in 0..k {
723                val += u_hat[s][comp] * rotation[(t, comp)];
724            }
725            re[(s, t)] = val;
726        }
727    }
728    re
729}
730
731/// Compute random effect variance function: Var_i(b̂_i(t)).
732fn compute_random_variance(random_effects: &FdMatrix, n_subjects: usize, m: usize) -> Vec<f64> {
733    (0..m)
734        .map(|t| {
735            let mean: f64 =
736                (0..n_subjects).map(|s| random_effects[(s, t)]).sum::<f64>() / n_subjects as f64;
737            let var: f64 = (0..n_subjects)
738                .map(|s| (random_effects[(s, t)] - mean).powi(2))
739                .sum::<f64>()
740                / n_subjects.max(1) as f64;
741            var
742        })
743        .collect()
744}
745
746/// Compute fitted values and residuals.
747fn compute_fitted_residuals(
748    data: &FdMatrix,
749    mean_function: &[f64],
750    beta_functions: &FdMatrix,
751    random_effects: &FdMatrix,
752    covariates: Option<&FdMatrix>,
753    subject_map: &[usize],
754    n_total: usize,
755    m: usize,
756    p: usize,
757) -> (FdMatrix, FdMatrix) {
758    let mut fitted = FdMatrix::zeros(n_total, m);
759    let mut residuals = FdMatrix::zeros(n_total, m);
760
761    for i in 0..n_total {
762        let s = subject_map[i];
763        for t in 0..m {
764            let mut val = mean_function[t] + random_effects[(s, t)];
765            if p > 0 {
766                if let Some(cov) = covariates {
767                    for j in 0..p {
768                        val += cov[(i, j)] * beta_functions[(j, t)];
769                    }
770                }
771            }
772            fitted[(i, t)] = val;
773            residuals[(i, t)] = data[(i, t)] - val;
774        }
775    }
776
777    (fitted, residuals)
778}
779
780// ---------------------------------------------------------------------------
781// Prediction
782// ---------------------------------------------------------------------------
783
784/// Predict curves for new subjects.
785///
786/// # Arguments
787/// * `result` — Fitted FMM result
788/// * `new_covariates` — Covariates for new subjects (n_new × p)
789///
790/// Returns predicted curves (n_new × m) using only fixed effects (no random effects for new subjects).
791#[must_use = "prediction result should not be discarded"]
792pub fn fmm_predict(result: &FmmResult, new_covariates: Option<&FdMatrix>) -> FdMatrix {
793    let m = result.mean_function.len();
794    let n_new = new_covariates.map_or(1, super::matrix::FdMatrix::nrows);
795    let p = result.beta_functions.nrows();
796
797    let mut predicted = FdMatrix::zeros(n_new, m);
798    for i in 0..n_new {
799        for t in 0..m {
800            let mut val = result.mean_function[t];
801            if let Some(cov) = new_covariates {
802                for j in 0..p {
803                    val += cov[(i, j)] * result.beta_functions[(j, t)];
804                }
805            }
806            predicted[(i, t)] = val;
807        }
808    }
809    predicted
810}
811
812// ---------------------------------------------------------------------------
813// Hypothesis testing
814// ---------------------------------------------------------------------------
815
/// Permutation test for fixed effects in functional mixed model.
///
/// Tests H₀: β_j(t) = 0 for each covariate j.
/// Uses integrated squared norm as test statistic: T_j = ∫ β̂_j(t)² dt.
///
/// P-values use the add-one estimator (count + 1) / (n_perm + 1), so they are
/// never exactly zero.
///
/// # Arguments
/// * `data` — All observed curves (n_total × m)
/// * `subject_ids` — Subject identifiers
/// * `covariates` — Subject-level covariates (n_total × p)
/// * `ncomp` — Number of FPC components
/// * `n_perm` — Number of permutations
/// * `seed` — Random seed
///
/// # Errors
///
/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, or
/// `covariates` has zero columns.
/// Propagates errors from [`fmm`] (e.g., dimension mismatches or FPCA failure).
#[must_use = "expensive computation whose result should not be discarded"]
pub fn fmm_test_fixed(
    data: &FdMatrix,
    subject_ids: &[usize],
    covariates: &FdMatrix,
    ncomp: usize,
    n_perm: usize,
    seed: u64,
) -> Result<FmmTestResult, FdarError> {
    let n_total = data.nrows();
    let m = data.ncols();
    let p = covariates.ncols();
    // --- Input validation ---------------------------------------------------
    if n_total == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "non-empty matrix".to_string(),
            actual: format!("{n_total} rows"),
        });
    }
    if p == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "covariates",
            expected: "at least 1 column".to_string(),
            actual: "0 columns".to_string(),
        });
    }

    // Fit observed model
    let result = fmm(data, subject_ids, Some(covariates), ncomp)?;

    // Observed test statistics: ∫ β̂_j(t)² dt for each covariate
    let observed_stats = compute_integrated_beta_sq(&result.beta_functions, p, m);

    // Permutation test against the observed statistics
    let (f_statistics, p_values) = permutation_test(
        data,
        subject_ids,
        covariates,
        ncomp,
        n_perm,
        seed,
        &observed_stats,
        p,
        m,
    );

    Ok(FmmTestResult {
        f_statistics,
        p_values,
    })
}
885
886/// Compute ∫ β̂_j(t)² dt for each covariate.
887fn compute_integrated_beta_sq(beta: &FdMatrix, p: usize, m: usize) -> Vec<f64> {
888    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
889    (0..p)
890        .map(|j| {
891            let ss: f64 = (0..m).map(|t| beta[(j, t)].powi(2)).sum();
892            ss * h
893        })
894        .collect()
895}
896
/// Run permutation test for fixed effects.
///
/// Refits the model `n_perm` times with shuffled covariate rows and counts how
/// often each permuted statistic meets or exceeds the observed one. P-values
/// use the add-one estimator (count + 1) / (n_perm + 1).
///
/// NOTE(review): rows are shuffled at the observation level, which breaks the
/// within-subject repetition of subject-level covariates; for repeated
/// measures, permuting whole subjects may better respect exchangeability
/// under H₀ — confirm this is the intended design.
fn permutation_test(
    data: &FdMatrix,
    subject_ids: &[usize],
    covariates: &FdMatrix,
    ncomp: usize,
    n_perm: usize,
    seed: u64,
    observed_stats: &[f64],
    p: usize,
    m: usize,
) -> (Vec<f64>, Vec<f64>) {
    use rand::prelude::*;
    let n_total = data.nrows();
    // Seeded RNG makes the test reproducible for a given `seed`.
    let mut rng = StdRng::seed_from_u64(seed);
    let mut n_ge = vec![0usize; p];

    for _ in 0..n_perm {
        // Shuffle covariate rows across all observations.
        let mut perm_indices: Vec<usize> = (0..n_total).collect();
        perm_indices.shuffle(&mut rng);
        let perm_cov = permute_rows(covariates, &perm_indices);

        // Failed refits are silently skipped but still count toward the
        // denominator, making the p-value slightly conservative.
        if let Ok(perm_result) = fmm(data, subject_ids, Some(&perm_cov), ncomp) {
            let perm_stats = compute_integrated_beta_sq(&perm_result.beta_functions, p, m);
            for j in 0..p {
                if perm_stats[j] >= observed_stats[j] {
                    n_ge[j] += 1;
                }
            }
        }
    }

    // Add-one estimator: p = (count + 1) / (n_perm + 1), never exactly zero.
    let p_values: Vec<f64> = n_ge
        .iter()
        .map(|&count| (count + 1) as f64 / (n_perm + 1) as f64)
        .collect();
    let f_statistics = observed_stats.to_vec();

    (f_statistics, p_values)
}
938
939/// Permute rows of a matrix according to given indices.
940fn permute_rows(mat: &FdMatrix, indices: &[usize]) -> FdMatrix {
941    let n = indices.len();
942    let m = mat.ncols();
943    let mut result = FdMatrix::zeros(n, m);
944    for (new_i, &old_i) in indices.iter().enumerate() {
945        for j in 0..m {
946            result[(new_i, j)] = mat[(old_i, j)];
947        }
948    }
949    result
950}
951
952// ---------------------------------------------------------------------------
953// Tests
954// ---------------------------------------------------------------------------
955
956#[cfg(test)]
957mod tests {
958    use super::*;
959    use crate::test_helpers::uniform_grid;
960    use std::f64::consts::PI;
961
    /// Generate repeated measurements: n_subjects × n_visits curves.
    /// Subject-level covariate z affects the curve amplitude.
    ///
    /// Returns (data, subject_ids, covariates, grid). The noise term is a
    /// deterministic modular pattern (no RNG), so tests are reproducible.
    fn generate_fmm_data(
        n_subjects: usize,
        n_visits: usize,
        m: usize,
    ) -> (FdMatrix, Vec<usize>, FdMatrix, Vec<f64>) {
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;
        // Data are assembled column-major for FdMatrix::from_column_major.
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            let z = s as f64 / n_subjects as f64; // covariate in [0, 1)
            let subject_effect = 0.5 * (s as f64 - n_subjects as f64 / 2.0); // random-like effect

            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = z;
                let noise_scale = 0.05;

                for (j, &tj) in t.iter().enumerate() {
                    // Y_sv(t) = sin(2πt) + z * t + subject_effect * cos(2πt) + noise
                    let mu = (2.0 * PI * tj).sin();
                    let fixed = z * tj * 3.0;
                    let random = subject_effect * (2.0 * PI * tj).cos() * 0.3;
                    // Deterministic pseudo-noise in [0, noise_scale).
                    let noise = noise_scale * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                    col_major[obs + j * n_total] = mu + fixed + random + noise;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
        (data, subject_ids, covariates, t)
    }
1000
1001    #[test]
1002    fn test_fmm_basic() {
1003        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
1004        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1005
1006        assert_eq!(result.mean_function.len(), 50);
1007        assert_eq!(result.beta_functions.nrows(), 1); // 1 covariate
1008        assert_eq!(result.beta_functions.ncols(), 50);
1009        assert_eq!(result.random_effects.nrows(), 10);
1010        assert_eq!(result.fitted.nrows(), 30);
1011        assert_eq!(result.residuals.nrows(), 30);
1012        assert_eq!(result.n_subjects, 10);
1013    }
1014
1015    #[test]
1016    fn test_fmm_fitted_plus_residuals_equals_data() {
1017        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 40);
1018        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1019
1020        let n = data.nrows();
1021        let m = data.ncols();
1022        for i in 0..n {
1023            for t in 0..m {
1024                let reconstructed = result.fitted[(i, t)] + result.residuals[(i, t)];
1025                assert!(
1026                    (reconstructed - data[(i, t)]).abs() < 1e-8,
1027                    "Fitted + residual should equal data at ({}, {}): {} vs {}",
1028                    i,
1029                    t,
1030                    reconstructed,
1031                    data[(i, t)]
1032                );
1033            }
1034        }
1035    }
1036
1037    #[test]
1038    fn test_fmm_random_variance_positive() {
1039        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
1040        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1041
1042        for &v in &result.random_variance {
1043            assert!(v >= 0.0, "Random variance should be non-negative");
1044        }
1045    }
1046
1047    #[test]
1048    fn test_fmm_no_covariates() {
1049        let (data, subject_ids, _cov, _t) = generate_fmm_data(8, 3, 40);
1050        let result = fmm(&data, &subject_ids, None, 3).unwrap();
1051
1052        assert_eq!(result.beta_functions.nrows(), 0);
1053        assert_eq!(result.n_subjects, 8);
1054        assert_eq!(result.fitted.nrows(), 24);
1055    }
1056
1057    #[test]
1058    fn test_fmm_predict() {
1059        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
1060        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1061
1062        // Predict for new subjects with covariate = 0.5
1063        let new_cov = FdMatrix::from_column_major(vec![0.5], 1, 1).unwrap();
1064        let predicted = fmm_predict(&result, Some(&new_cov));
1065
1066        assert_eq!(predicted.nrows(), 1);
1067        assert_eq!(predicted.ncols(), 50);
1068
1069        // Predicted curve should be reasonable (not NaN or extreme)
1070        for t in 0..50 {
1071            assert!(predicted[(0, t)].is_finite());
1072            assert!(
1073                predicted[(0, t)].abs() < 20.0,
1074                "Predicted value too extreme at t={}: {}",
1075                t,
1076                predicted[(0, t)]
1077            );
1078        }
1079    }
1080
1081    #[test]
1082    fn test_fmm_test_fixed_detects_effect() {
1083        let (data, subject_ids, covariates, _t) = generate_fmm_data(15, 3, 40);
1084
1085        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
1086
1087        assert_eq!(result.f_statistics.len(), 1);
1088        assert_eq!(result.p_values.len(), 1);
1089        assert!(
1090            result.p_values[0] < 0.1,
1091            "Should detect covariate effect, got p={}",
1092            result.p_values[0]
1093        );
1094    }
1095
1096    #[test]
1097    fn test_fmm_test_fixed_no_effect() {
1098        let n_subjects = 10;
1099        let n_visits = 3;
1100        let m = 40;
1101        let t = uniform_grid(m);
1102        let n_total = n_subjects * n_visits;
1103
1104        // No covariate effect: Y = sin(2πt) + noise
1105        let mut col_major = vec![0.0; n_total * m];
1106        let mut subject_ids = vec![0usize; n_total];
1107        let mut cov_data = vec![0.0; n_total];
1108
1109        for s in 0..n_subjects {
1110            for v in 0..n_visits {
1111                let obs = s * n_visits + v;
1112                subject_ids[obs] = s;
1113                cov_data[obs] = s as f64 / n_subjects as f64;
1114                for (j, &tj) in t.iter().enumerate() {
1115                    col_major[obs + j * n_total] =
1116                        (2.0 * PI * tj).sin() + 0.1 * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
1117                }
1118            }
1119        }
1120
1121        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
1122        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
1123
1124        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
1125        assert!(
1126            result.p_values[0] > 0.05,
1127            "Should not detect effect, got p={}",
1128            result.p_values[0]
1129        );
1130    }
1131
1132    #[test]
1133    fn test_fmm_invalid_input() {
1134        let data = FdMatrix::zeros(0, 0);
1135        assert!(fmm(&data, &[], None, 1).is_err());
1136
1137        let data = FdMatrix::zeros(10, 50);
1138        let ids = vec![0; 5]; // wrong length
1139        assert!(fmm(&data, &ids, None, 1).is_err());
1140    }
1141
1142    #[test]
1143    fn test_fmm_single_visit_per_subject() {
1144        let n = 10;
1145        let m = 40;
1146        let t = uniform_grid(m);
1147        let mut col_major = vec![0.0; n * m];
1148        let subject_ids: Vec<usize> = (0..n).collect();
1149
1150        for i in 0..n {
1151            for (j, &tj) in t.iter().enumerate() {
1152                col_major[i + j * n] = (2.0 * PI * tj).sin();
1153            }
1154        }
1155        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();
1156
1157        // Should still work with 1 visit per subject
1158        let result = fmm(&data, &subject_ids, None, 2).unwrap();
1159        assert_eq!(result.n_subjects, n);
1160        assert_eq!(result.fitted.nrows(), n);
1161    }
1162
1163    #[test]
1164    fn test_build_subject_map() {
1165        let (map, n) = build_subject_map(&[5, 5, 10, 10, 20]);
1166        assert_eq!(n, 3);
1167        assert_eq!(map, vec![0, 0, 1, 1, 2]);
1168    }
1169
1170    #[test]
1171    fn test_variance_components_positive() {
1172        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
1173        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1174
1175        assert!(result.sigma2_eps >= 0.0);
1176        for &s in &result.sigma2_u {
1177            assert!(s >= 0.0);
1178        }
1179    }
1180
1181    // -------------------------------------------------------------------
1182    // Additional tests
1183    // -------------------------------------------------------------------
1184
1185    #[test]
1186    fn test_fmm_ncomp_zero_returns_error() {
1187        let (data, subject_ids, _cov, _t) = generate_fmm_data(5, 2, 20);
1188        let err = fmm(&data, &subject_ids, None, 0).unwrap_err();
1189        match err {
1190            FdarError::InvalidParameter { parameter, .. } => {
1191                assert_eq!(parameter, "ncomp");
1192            }
1193            other => panic!("Expected InvalidParameter, got {:?}", other),
1194        }
1195    }
1196
1197    #[test]
1198    fn test_fmm_single_component() {
1199        // Fit with only 1 FPC component
1200        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 30);
1201        let result = fmm(&data, &subject_ids, Some(&covariates), 1).unwrap();
1202
1203        assert_eq!(result.ncomp, 1);
1204        assert_eq!(result.sigma2_u.len(), 1);
1205        assert_eq!(result.eigenvalues.len(), 1);
1206        assert_eq!(result.mean_function.len(), 30);
1207        // Fitted + residuals = data
1208        for i in 0..data.nrows() {
1209            for t in 0..data.ncols() {
1210                let diff = (result.fitted[(i, t)] + result.residuals[(i, t)] - data[(i, t)]).abs();
1211                assert!(diff < 1e-8);
1212            }
1213        }
1214    }
1215
1216    #[test]
1217    fn test_fmm_two_subjects() {
1218        // Minimal number of subjects (2) with multiple visits
1219        let n_subjects = 2;
1220        let n_visits = 5;
1221        let m = 20;
1222        let t = uniform_grid(m);
1223        let n_total = n_subjects * n_visits;
1224        let mut col_major = vec![0.0; n_total * m];
1225        let mut subject_ids = vec![0usize; n_total];
1226
1227        for s in 0..n_subjects {
1228            for v in 0..n_visits {
1229                let obs = s * n_visits + v;
1230                subject_ids[obs] = s;
1231                for (j, &tj) in t.iter().enumerate() {
1232                    col_major[obs + j * n_total] =
1233                        (2.0 * PI * tj).sin() + (s as f64) * 0.5 + 0.01 * v as f64;
1234                }
1235            }
1236        }
1237        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
1238        let result = fmm(&data, &subject_ids, None, 2).unwrap();
1239
1240        assert_eq!(result.n_subjects, 2);
1241        assert_eq!(result.random_effects.nrows(), 2);
1242        assert_eq!(result.fitted.nrows(), n_total);
1243    }
1244
1245    #[test]
1246    fn test_fmm_predict_no_covariates() {
1247        let (data, subject_ids, _cov, _t) = generate_fmm_data(6, 3, 30);
1248        let result = fmm(&data, &subject_ids, None, 2).unwrap();
1249
1250        // Predict without covariates — should return mean function
1251        let predicted = fmm_predict(&result, None);
1252        assert_eq!(predicted.nrows(), 1);
1253        assert_eq!(predicted.ncols(), 30);
1254        for t in 0..30 {
1255            let diff = (predicted[(0, t)] - result.mean_function[t]).abs();
1256            assert!(
1257                diff < 1e-12,
1258                "Without covariates, prediction should equal mean"
1259            );
1260        }
1261    }
1262
1263    #[test]
1264    fn test_fmm_predict_multiple_new_subjects() {
1265        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 40);
1266        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1267
1268        // Predict for 3 new subjects with different covariate values
1269        let new_cov = FdMatrix::from_column_major(vec![0.1, 0.5, 0.9], 3, 1).unwrap();
1270        let predicted = fmm_predict(&result, Some(&new_cov));
1271
1272        assert_eq!(predicted.nrows(), 3);
1273        assert_eq!(predicted.ncols(), 40);
1274
1275        // All predictions should be finite
1276        for i in 0..3 {
1277            for t in 0..40 {
1278                assert!(predicted[(i, t)].is_finite());
1279            }
1280        }
1281
1282        // Predictions for different covariates should differ
1283        let diff_01: f64 = (0..40)
1284            .map(|t| (predicted[(0, t)] - predicted[(1, t)]).powi(2))
1285            .sum();
1286        assert!(
1287            diff_01 > 1e-10,
1288            "Different covariates should yield different predictions"
1289        );
1290    }
1291
1292    #[test]
1293    fn test_fmm_eigenvalues_decreasing() {
1294        let (data, subject_ids, _cov, _t) = generate_fmm_data(10, 3, 50);
1295        let result = fmm(&data, &subject_ids, None, 5).unwrap();
1296
1297        // Eigenvalues should be in decreasing order (from FPCA)
1298        for i in 1..result.eigenvalues.len() {
1299            assert!(
1300                result.eigenvalues[i] <= result.eigenvalues[i - 1] + 1e-10,
1301                "Eigenvalues should be non-increasing: {} > {}",
1302                result.eigenvalues[i],
1303                result.eigenvalues[i - 1]
1304            );
1305        }
1306    }
1307
1308    #[test]
1309    fn test_fmm_random_effects_sum_near_zero() {
1310        // Random effects should approximately sum to zero across subjects
1311        let (data, subject_ids, covariates, _t) = generate_fmm_data(20, 3, 40);
1312        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();
1313
1314        let m = result.mean_function.len();
1315        for t in 0..m {
1316            let sum: f64 = (0..result.n_subjects)
1317                .map(|s| result.random_effects[(s, t)])
1318                .sum();
1319            let mean_abs: f64 = (0..result.n_subjects)
1320                .map(|s| result.random_effects[(s, t)].abs())
1321                .sum::<f64>()
1322                / result.n_subjects as f64;
1323            // Relative to the scale of random effects, the sum should be small
1324            if mean_abs > 1e-10 {
1325                assert!(
1326                    (sum / result.n_subjects as f64).abs() < mean_abs * 2.0,
1327                    "Random effects should roughly center around zero at t={}: sum={}, mean_abs={}",
1328                    t,
1329                    sum,
1330                    mean_abs
1331                );
1332            }
1333        }
1334    }
1335
1336    #[test]
1337    fn test_fmm_subject_ids_mismatch_error() {
1338        let data = FdMatrix::zeros(10, 20);
1339        let ids = vec![0; 7]; // wrong length
1340        let err = fmm(&data, &ids, None, 1).unwrap_err();
1341        match err {
1342            FdarError::InvalidDimension { parameter, .. } => {
1343                assert_eq!(parameter, "subject_ids");
1344            }
1345            other => panic!("Expected InvalidDimension, got {:?}", other),
1346        }
1347    }
1348
1349    #[test]
1350    fn test_fmm_test_fixed_empty_data_error() {
1351        let data = FdMatrix::zeros(0, 0);
1352        let covariates = FdMatrix::zeros(0, 1);
1353        let err = fmm_test_fixed(&data, &[], &covariates, 1, 10, 42).unwrap_err();
1354        match err {
1355            FdarError::InvalidDimension { parameter, .. } => {
1356                assert_eq!(parameter, "data");
1357            }
1358            other => panic!("Expected InvalidDimension for data, got {:?}", other),
1359        }
1360    }
1361
1362    #[test]
1363    fn test_fmm_test_fixed_zero_covariates_error() {
1364        let data = FdMatrix::zeros(10, 20);
1365        let ids = vec![0; 10];
1366        let covariates = FdMatrix::zeros(10, 0);
1367        let err = fmm_test_fixed(&data, &ids, &covariates, 1, 10, 42).unwrap_err();
1368        match err {
1369            FdarError::InvalidDimension { parameter, .. } => {
1370                assert_eq!(parameter, "covariates");
1371            }
1372            other => panic!("Expected InvalidDimension for covariates, got {:?}", other),
1373        }
1374    }
1375
1376    #[test]
1377    fn test_build_subject_map_single_subject() {
1378        let (map, n) = build_subject_map(&[42, 42, 42]);
1379        assert_eq!(n, 1);
1380        assert_eq!(map, vec![0, 0, 0]);
1381    }
1382
1383    #[test]
1384    fn test_build_subject_map_non_contiguous_ids() {
1385        let (map, n) = build_subject_map(&[100, 200, 100, 300, 200]);
1386        assert_eq!(n, 3);
1387        // sorted unique: [100, 200, 300] -> indices [0, 1, 2]
1388        assert_eq!(map, vec![0, 1, 0, 2, 1]);
1389    }
1390
1391    #[test]
1392    fn test_fmm_many_components_clamped() {
1393        // Request more components than available; FPCA should clamp
1394        let (data, subject_ids, _cov, _t) = generate_fmm_data(5, 3, 20);
1395        let n_total = data.nrows();
1396        // Request 100 components — should be clamped to min(n_total, m) - 1
1397        let result = fmm(&data, &subject_ids, None, 100).unwrap();
1398        assert!(
1399            result.ncomp <= n_total.min(20),
1400            "ncomp should be clamped: got {}",
1401            result.ncomp
1402        );
1403        assert!(result.ncomp >= 1);
1404    }
1405
1406    #[test]
1407    fn test_fmm_residuals_small_with_enough_components() {
1408        // With enough components, residuals should be small relative to data
1409        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 30);
1410        let result = fmm(&data, &subject_ids, Some(&covariates), 5).unwrap();
1411
1412        let n = data.nrows();
1413        let m = data.ncols();
1414        let mut data_ss = 0.0_f64;
1415        let mut resid_ss = 0.0_f64;
1416        for i in 0..n {
1417            for t in 0..m {
1418                data_ss += data[(i, t)].powi(2);
1419                resid_ss += result.residuals[(i, t)].powi(2);
1420            }
1421        }
1422
1423        // R-squared should be reasonably high for structured data
1424        let r_squared = 1.0 - resid_ss / data_ss;
1425        assert!(
1426            r_squared > 0.5,
1427            "R-squared should be high with enough components: {}",
1428            r_squared
1429        );
1430    }
1431}