Skip to main content

fdars_core/
famm.rs

1//! Functional Additive Mixed Models (FAMM).
2//!
3//! Implements functional mixed effects models for repeated functional
4//! measurements with subject-level covariates.
5//!
6//! Model: `Y_ij(t) = μ(t) + X_i'β(t) + b_i(t) + ε_ij(t)`
7//!
8//! Key functions:
9//! - [`fmm`] — Fit a functional mixed model via FPC decomposition
10//! - [`fmm_predict`] — Predict curves for new subjects
11//! - [`fmm_test_fixed`] — Hypothesis test on fixed effects
12
13use crate::error::FdarError;
14use crate::iter_maybe_parallel;
15use crate::linalg::{
16    cholesky_factor as linalg_cholesky_factor,
17    cholesky_forward_back as linalg_cholesky_forward_back,
18};
19use crate::matrix::FdMatrix;
20use crate::regression::fdata_to_pc_1d;
21#[cfg(feature = "parallel")]
22use rayon::iter::ParallelIterator;
23
/// Result of a functional mixed model fit (see [`fmm`]).
///
/// All functional quantities are evaluated on the same `m` grid points as the
/// input curves.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FmmResult {
    /// Overall mean function μ̂(t) (length m), taken from the pooled FPCA.
    pub mean_function: Vec<f64>,
    /// Fixed effect coefficient functions β̂_j(t) (p × m matrix, one row per covariate).
    /// Has zero rows when the model was fitted without covariates.
    pub beta_functions: FdMatrix,
    /// Random effect functions b̂_i(t) per subject (n_subjects × m). Row `s`
    /// belongs to the `s`-th smallest distinct subject identifier.
    pub random_effects: FdMatrix,
    /// Fitted values μ̂(t) + x_i'β̂(t) + b̂_i(t) for all observations (n_total × m).
    pub fitted: FdMatrix,
    /// Residuals `data - fitted` (n_total × m).
    pub residuals: FdMatrix,
    /// Across-subject variance of b̂_i(t) at each time point (length m), computed
    /// with the population divisor `n_subjects`.
    pub random_variance: Vec<f64>,
    /// Residual variance estimate: the mean of the per-component residual variances.
    /// NOTE: expressed on the sqrt(h)-rescaled FPC-score scale (h = 1/(m-1)),
    /// not on the scale of the raw curves.
    pub sigma2_eps: f64,
    /// Random effect variance per FPC component (length ncomp), on the same
    /// rescaled score scale as `sigma2_eps`.
    pub sigma2_u: Vec<f64>,
    /// Number of FPC components actually used (as returned by the FPCA; may differ
    /// from the number requested).
    pub ncomp: usize,
    /// Number of distinct subjects.
    pub n_subjects: usize,
    /// FPC eigenvalues: singular values squared divided by n_total.
    pub eigenvalues: Vec<f64>,
}
51
/// Result of the fixed effect permutation test (see [`fmm_test_fixed`]).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FmmTestResult {
    /// Observed test statistic per covariate (length p). Despite the field name this
    /// is not an F-statistic: it is the integrated squared norm ∫ β̂_j(t)² dt,
    /// approximated by a Riemann sum on a uniform grid over [0, 1].
    pub f_statistics: Vec<f64>,
    /// Permutation p-values per covariate (length p).
    pub p_values: Vec<f64>,
}
61
62// ---------------------------------------------------------------------------
63// Core FMM algorithm
64// ---------------------------------------------------------------------------
65
66/// Fit a functional mixed model via FPC decomposition.
67///
68/// # Arguments
69/// * `data` — All observed curves (n_total × m), stacked across subjects and visits
70/// * `subject_ids` — Subject identifier for each curve (length n_total)
71/// * `covariates` — Subject-level covariates (n_total × p).
72///   Each row corresponds to the same curve in `data`.
73///   If a covariate is subject-level, its value should be repeated across visits.
74/// * `ncomp` — Number of FPC components
75///
76/// # Algorithm
77/// 1. Pool curves, compute FPCA
78/// 2. For each FPC score, fit scalar mixed model: ξ_ijk = x_i'γ_k + u_ik + e_ijk
79/// 3. Recover β̂(t) and b̂_i(t) from component coefficients
80///
81/// # Errors
82///
83/// Returns [`FdarError::InvalidDimension`] if `data` is empty (zero rows or
84/// columns), or if `subject_ids.len()` does not match the number of rows.
85/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
86/// Returns [`FdarError::ComputationFailed`] if the underlying FPCA fails.
87#[must_use = "expensive computation whose result should not be discarded"]
88pub fn fmm(
89    data: &FdMatrix,
90    subject_ids: &[usize],
91    covariates: Option<&FdMatrix>,
92    ncomp: usize,
93) -> Result<FmmResult, FdarError> {
94    let n_total = data.nrows();
95    let m = data.ncols();
96    if n_total == 0 || m == 0 {
97        return Err(FdarError::InvalidDimension {
98            parameter: "data",
99            expected: "non-empty matrix".to_string(),
100            actual: format!("{n_total} x {m}"),
101        });
102    }
103    if subject_ids.len() != n_total {
104        return Err(FdarError::InvalidDimension {
105            parameter: "subject_ids",
106            expected: format!("length {n_total}"),
107            actual: format!("length {}", subject_ids.len()),
108        });
109    }
110    if ncomp == 0 {
111        return Err(FdarError::InvalidParameter {
112            parameter: "ncomp",
113            message: "must be >= 1".to_string(),
114        });
115    }
116
117    // Determine unique subjects
118    let (subject_map, n_subjects) = build_subject_map(subject_ids);
119
120    // Step 1: FPCA on pooled data
121    let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1).max(1) as f64).collect();
122    let fpca = fdata_to_pc_1d(data, ncomp, &argvals)?;
123    let k = fpca.scores.ncols(); // actual number of components
124
125    // Step 2: For each FPC score, fit scalar mixed model (parallelized)
126    let p = covariates.map_or(0, super::matrix::FdMatrix::ncols);
127    let ComponentResults {
128        gamma,
129        u_hat,
130        sigma2_u,
131        sigma2_eps,
132    } = fit_all_components(
133        &fpca.scores,
134        &subject_map,
135        n_subjects,
136        covariates,
137        p,
138        k,
139        n_total,
140        m,
141    );
142
143    // Step 3: Recover functional coefficients (using gamma in original scale)
144    let beta_functions = recover_beta_functions(&gamma, &fpca.rotation, p, m, k);
145    let random_effects = recover_random_effects(&u_hat, &fpca.rotation, n_subjects, m, k);
146
147    // Compute random variance function: Var(b_i(t)) across subjects
148    let random_variance = compute_random_variance(&random_effects, n_subjects, m);
149
150    // Compute fitted and residuals
151    let (fitted, residuals) = compute_fitted_residuals(
152        data,
153        &fpca.mean,
154        &beta_functions,
155        &random_effects,
156        covariates,
157        &subject_map,
158        n_total,
159        m,
160        p,
161    );
162
163    let eigenvalues: Vec<f64> = fpca
164        .singular_values
165        .iter()
166        .map(|&sv| sv * sv / n_total as f64)
167        .collect();
168
169    Ok(FmmResult {
170        mean_function: fpca.mean,
171        beta_functions,
172        random_effects,
173        fitted,
174        residuals,
175        random_variance,
176        sigma2_eps,
177        sigma2_u,
178        ncomp: k,
179        n_subjects,
180        eigenvalues,
181    })
182}
183
/// Build mapping from observation index to subject index (0..n_subjects-1).
///
/// Compact indices are assigned in ascending order of the original identifiers,
/// so the smallest id maps to 0. Returns `(map, n_subjects)` where `map[i]` is the
/// compact subject index of observation `i`.
fn build_subject_map(subject_ids: &[usize]) -> (Vec<usize>, usize) {
    let mut unique_ids: Vec<usize> = subject_ids.to_vec();
    unique_ids.sort_unstable();
    unique_ids.dedup();
    let n_subjects = unique_ids.len();

    // `unique_ids` is sorted and deduplicated, so a binary search locates each id in
    // O(log n_subjects) instead of a linear scan per observation. Every id is present
    // by construction, so the `Err` arm is unreachable; `unwrap_or(0)` keeps the
    // previous fallback.
    let map: Vec<usize> = subject_ids
        .iter()
        .map(|id| unique_ids.binary_search(id).unwrap_or(0))
        .collect();

    (map, n_subjects)
}
198
/// Per-component estimates gathered from the scalar mixed-model fits.
///
/// `gamma` and `u_hat` are stored on the original FPC-score scale, while the
/// variance components stay on the rescaled scale used during fitting.
struct ComponentResults {
    /// `gamma[j][k]`: fixed-effect coefficient of covariate `j` for component `k`.
    gamma: Vec<Vec<f64>>,
    /// `u_hat[i][k]`: predicted random effect of subject `i` for component `k`.
    u_hat: Vec<Vec<f64>>,
    /// Random-effect variance, one entry per component.
    sigma2_u: Vec<f64>,
    /// Residual variance averaged over all components.
    sigma2_eps: f64,
}
206
207/// Fit scalar mixed models for all FPC components (parallelized across components).
208///
209/// For each component k, scales FPC scores to L²-normalized space, fits a scalar
210/// mixed model, then scales coefficients back to the original score space.
211#[allow(clippy::too_many_arguments)]
212fn fit_all_components(
213    scores: &FdMatrix,
214    subject_map: &[usize],
215    n_subjects: usize,
216    covariates: Option<&FdMatrix>,
217    p: usize,
218    k: usize,
219    n_total: usize,
220    m: usize,
221) -> ComponentResults {
222    // Normalize scores by sqrt(h) to match R's L²-weighted FPCA convention.
223    // This ensures variance components are on the same scale as R's lmer().
224    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
225    let score_scale = h.sqrt();
226
227    // Fit each component independently — parallelized when the feature is enabled
228    let per_comp: Vec<ScalarMixedResult> = iter_maybe_parallel!(0..k)
229        .map(|comp| {
230            let comp_scores: Vec<f64> = (0..n_total)
231                .map(|i| scores[(i, comp)] * score_scale)
232                .collect();
233            fit_scalar_mixed_model(&comp_scores, subject_map, n_subjects, covariates, p)
234        })
235        .collect();
236
237    // Unpack per-component results into the aggregate structure
238    let mut gamma = vec![vec![0.0; k]; p];
239    let mut u_hat = vec![vec![0.0; k]; n_subjects];
240    let mut sigma2_u = vec![0.0; k];
241    let mut sigma2_eps_total = 0.0;
242
243    for (comp, result) in per_comp.iter().enumerate() {
244        for j in 0..p {
245            gamma[j][comp] = result.gamma[j] / score_scale;
246        }
247        for s in 0..n_subjects {
248            u_hat[s][comp] = result.u_hat[s] / score_scale;
249        }
250        sigma2_u[comp] = result.sigma2_u;
251        sigma2_eps_total += result.sigma2_eps;
252    }
253    let sigma2_eps = sigma2_eps_total / k as f64;
254
255    ComponentResults {
256        gamma,
257        u_hat,
258        sigma2_u,
259        sigma2_eps,
260    }
261}
262
/// Estimates from one scalar random-intercept fit (a single FPC component).
struct ScalarMixedResult {
    /// Fixed-effect coefficients, one per covariate (length p).
    gamma: Vec<f64>,
    /// Predicted random effect for each subject (length n_subjects).
    u_hat: Vec<f64>,
    /// Random-intercept variance.
    sigma2_u: f64,
    /// Residual variance.
    sigma2_eps: f64,
}
270
/// Observations grouped by subject, computed once per scalar model fit.
struct SubjectStructure {
    /// Number of observations belonging to each subject.
    counts: Vec<usize>,
    /// Observation indices of each subject, in ascending order.
    obs: Vec<Vec<usize>>,
}

impl SubjectStructure {
    /// Group the first `n` entries of `subject_map` by subject index.
    ///
    /// Panics if `subject_map` has fewer than `n` entries or contains an index
    /// `>= n_subjects`.
    fn new(subject_map: &[usize], n_subjects: usize, n: usize) -> Self {
        let mut obs: Vec<Vec<usize>> = vec![Vec::new(); n_subjects];
        for (idx, &subject) in subject_map[..n].iter().enumerate() {
            obs[subject].push(idx);
        }
        let counts = obs.iter().map(Vec::len).collect();
        Self { counts, obs }
    }
}
289
290/// Compute shrinkage weights: w_s = σ²_u / (σ²_u + σ²_e / n_s).
291fn shrinkage_weights(ss: &SubjectStructure, sigma2_u: f64, sigma2_e: f64) -> Vec<f64> {
292    ss.counts
293        .iter()
294        .map(|&c| {
295            let ns = c as f64;
296            if ns < 1.0 {
297                0.0
298            } else {
299                sigma2_u / (sigma2_u + sigma2_e / ns)
300            }
301        })
302        .collect()
303}
304
305/// GLS fixed effect update using block-diagonal V^{-1}.
306///
307/// Computes γ = (X'V⁻¹X)⁻¹ X'V⁻¹y exploiting the balanced random intercept structure.
308fn gls_update_gamma(
309    cov: &FdMatrix,
310    p: usize,
311    ss: &SubjectStructure,
312    weights: &[f64],
313    y: &[f64],
314    sigma2_e: f64,
315) -> Option<Vec<f64>> {
316    let n_subjects = ss.counts.len();
317    let mut xtvinvx = vec![0.0; p * p];
318    let mut xtvinvy = vec![0.0; p];
319    let inv_e = 1.0 / sigma2_e;
320
321    for s in 0..n_subjects {
322        let ns = ss.counts[s] as f64;
323        if ns < 1.0 {
324            continue;
325        }
326        let (x_sum, y_sum) = subject_sums(cov, y, &ss.obs[s], p);
327        accumulate_gls_terms(
328            cov,
329            y,
330            &ss.obs[s],
331            &x_sum,
332            y_sum,
333            weights[s],
334            ns,
335            inv_e,
336            p,
337            &mut xtvinvx,
338            &mut xtvinvy,
339        );
340    }
341
342    for j in 0..p {
343        xtvinvx[j * p + j] += 1e-10;
344    }
345    cholesky_solve(&xtvinvx, &xtvinvy, p)
346}
347
348/// Compute subject-level covariate sums and response sum.
349fn subject_sums(cov: &FdMatrix, y: &[f64], obs: &[usize], p: usize) -> (Vec<f64>, f64) {
350    let mut x_sum = vec![0.0; p];
351    let mut y_sum = 0.0;
352    for &i in obs {
353        for r in 0..p {
354            x_sum[r] += cov[(i, r)];
355        }
356        y_sum += y[i];
357    }
358    (x_sum, y_sum)
359}
360
361/// Accumulate X'V^{-1}X and X'V^{-1}y for one subject.
362fn accumulate_gls_terms(
363    cov: &FdMatrix,
364    y: &[f64],
365    obs: &[usize],
366    x_sum: &[f64],
367    y_sum: f64,
368    w_s: f64,
369    ns: f64,
370    inv_e: f64,
371    p: usize,
372    xtvinvx: &mut [f64],
373    xtvinvy: &mut [f64],
374) {
375    for &i in obs {
376        let vinv_y = inv_e * (y[i] - w_s * y_sum / ns);
377        for r in 0..p {
378            xtvinvy[r] += cov[(i, r)] * vinv_y;
379            for c in r..p {
380                let vinv_xc = inv_e * (cov[(i, c)] - w_s * x_sum[c] / ns);
381                let val = cov[(i, r)] * vinv_xc;
382                xtvinvx[r * p + c] += val;
383                if r != c {
384                    xtvinvx[c * p + r] += val;
385                }
386            }
387        }
388    }
389}
390
391/// REML EM update for variance components.
392///
393/// Returns (σ²_u_new, σ²_e_new) from the conditional expectations.
394/// Uses n - p divisor for σ²_e (REML correction where p = number of fixed effects).
395fn reml_variance_update(
396    residuals: &[f64],
397    ss: &SubjectStructure,
398    weights: &[f64],
399    sigma2_u: f64,
400    p: usize,
401) -> (f64, f64) {
402    let n_subjects = ss.counts.len();
403    let n: usize = ss.counts.iter().sum();
404    let mut sigma2_u_new = 0.0;
405    let mut sigma2_e_new = 0.0;
406
407    for s in 0..n_subjects {
408        let ns = ss.counts[s] as f64;
409        if ns < 1.0 {
410            continue;
411        }
412        let w_s = weights[s];
413        let mean_r_s: f64 = ss.obs[s].iter().map(|&i| residuals[i]).sum::<f64>() / ns;
414        let u_hat_s = w_s * mean_r_s;
415        let cond_var_s = sigma2_u * (1.0 - w_s);
416
417        sigma2_u_new += u_hat_s * u_hat_s + cond_var_s;
418        for &i in &ss.obs[s] {
419            sigma2_e_new += (residuals[i] - u_hat_s).powi(2);
420        }
421        sigma2_e_new += ns * cond_var_s;
422    }
423
424    // REML divisor: n - p for residual variance (matches R's lmer)
425    let denom_e = (n.saturating_sub(p)).max(1) as f64;
426
427    (
428        (sigma2_u_new / n_subjects as f64).max(1e-15),
429        (sigma2_e_new / denom_e).max(1e-15),
430    )
431}
432
433/// Fit scalar mixed model: y_ij = x_i'γ + u_i + e_ij.
434///
435/// Uses iterative GLS for fixed effects + REML EM for variance components,
436/// matching R's lmer() behavior. Initializes from Henderson's ANOVA, then
437/// iterates until convergence.
438fn fit_scalar_mixed_model(
439    y: &[f64],
440    subject_map: &[usize],
441    n_subjects: usize,
442    covariates: Option<&FdMatrix>,
443    p: usize,
444) -> ScalarMixedResult {
445    let n = y.len();
446    let ss = SubjectStructure::new(subject_map, n_subjects, n);
447
448    // Initialize from OLS + Henderson's ANOVA
449    let gamma_init = estimate_fixed_effects(y, covariates, p, n);
450    let residuals_init = compute_ols_residuals(y, covariates, &gamma_init, p, n);
451    let (mut sigma2_u, mut sigma2_e) =
452        estimate_variance_components(&residuals_init, subject_map, n_subjects, n);
453
454    if sigma2_e < 1e-15 {
455        sigma2_e = 1e-6;
456    }
457    if sigma2_u < 1e-15 {
458        sigma2_u = sigma2_e * 0.1;
459    }
460
461    let mut gamma = gamma_init;
462
463    for _iter in 0..50 {
464        let sigma2_u_old = sigma2_u;
465        let sigma2_e_old = sigma2_e;
466
467        let weights = shrinkage_weights(&ss, sigma2_u, sigma2_e);
468
469        if let Some(cov) = covariates.filter(|_| p > 0) {
470            if let Some(g) = gls_update_gamma(cov, p, &ss, &weights, y, sigma2_e) {
471                gamma = g;
472            }
473        }
474
475        let r = compute_ols_residuals(y, covariates, &gamma, p, n);
476        (sigma2_u, sigma2_e) = reml_variance_update(&r, &ss, &weights, sigma2_u, p);
477
478        let delta = (sigma2_u - sigma2_u_old).abs() + (sigma2_e - sigma2_e_old).abs();
479        if delta < 1e-10 * (sigma2_u_old + sigma2_e_old) {
480            break;
481        }
482    }
483
484    let final_residuals = compute_ols_residuals(y, covariates, &gamma, p, n);
485    let u_hat = compute_blup(
486        &final_residuals,
487        subject_map,
488        n_subjects,
489        sigma2_u,
490        sigma2_e,
491    );
492
493    ScalarMixedResult {
494        gamma,
495        u_hat,
496        sigma2_u,
497        sigma2_eps: sigma2_e,
498    }
499}
500
501/// OLS estimation of fixed effects.
502fn estimate_fixed_effects(
503    y: &[f64],
504    covariates: Option<&FdMatrix>,
505    p: usize,
506    n: usize,
507) -> Vec<f64> {
508    if p == 0 || covariates.is_none() {
509        return Vec::new();
510    }
511    let cov = covariates.expect("checked: covariates is Some");
512
513    // Solve (X'X)γ = X'y via Cholesky
514    let mut xtx = vec![0.0; p * p];
515    let mut xty = vec![0.0; p];
516    for i in 0..n {
517        for r in 0..p {
518            xty[r] += cov[(i, r)] * y[i];
519            for s in r..p {
520                let val = cov[(i, r)] * cov[(i, s)];
521                xtx[r * p + s] += val;
522                if r != s {
523                    xtx[s * p + r] += val;
524                }
525            }
526        }
527    }
528    // Regularize
529    for j in 0..p {
530        xtx[j * p + j] += 1e-8;
531    }
532
533    cholesky_solve(&xtx, &xty, p).unwrap_or(vec![0.0; p])
534}
535
536/// Cholesky solve: A x = b where A is p-by-p symmetric positive definite.
537/// Returns `None` if the matrix is singular.
538fn cholesky_solve(a: &[f64], b: &[f64], p: usize) -> Option<Vec<f64>> {
539    let l = linalg_cholesky_factor(a, p).ok()?;
540    Some(linalg_cholesky_forward_back(&l, b, p))
541}
542
543/// Compute OLS residuals: r = y - X*gamma.
544fn compute_ols_residuals(
545    y: &[f64],
546    covariates: Option<&FdMatrix>,
547    gamma: &[f64],
548    p: usize,
549    n: usize,
550) -> Vec<f64> {
551    let mut residuals = y.to_vec();
552    if p > 0 {
553        if let Some(cov) = covariates {
554            for i in 0..n {
555                for j in 0..p {
556                    residuals[i] -= cov[(i, j)] * gamma[j];
557                }
558            }
559        }
560    }
561    residuals
562}
563
/// Estimate variance components via the method of moments (one-way random-effects ANOVA).
///
/// Returns `(σ²_u, σ²_ε)` computed from the residuals grouped by subject:
///
/// * `σ²_ε = SS_within / (n - n_subjects)`, or `1e-6` when there are no
///   within-subject degrees of freedom;
/// * `σ²_u = max((MS_between - σ²_ε) / n₀, 0)` with
///   `MS_between = SS_between / max(n_subjects - 1, 1)` and the effective group size
///   `n₀ = (n - Σ_s n_s² / n) / (n_subjects - 1)`.
///
/// `n₀` is the standard coefficient for unbalanced one-way designs and equals the
/// common group size when every subject has the same number of observations. With a
/// single subject, or whenever `n₀` is not positive, the mean group size
/// `n / n_subjects` is used instead. In this file the result only provides starting
/// values for the iterative fit.
fn estimate_variance_components(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    n: usize,
) -> (f64, f64) {
    // Subject sums, counts and means
    let mut subject_sums = vec![0.0; n_subjects];
    let mut subject_counts = vec![0usize; n_subjects];
    for i in 0..n {
        let s = subject_map[i];
        subject_sums[s] += residuals[i];
        subject_counts[s] += 1;
    }
    let subject_means: Vec<f64> = subject_sums
        .iter()
        .zip(&subject_counts)
        .map(|(&s, &c)| if c > 0 { s / c as f64 } else { 0.0 })
        .collect();

    // Within-subject SS
    let mut ss_within = 0.0;
    for i in 0..n {
        let s = subject_map[i];
        ss_within += (residuals[i] - subject_means[s]).powi(2);
    }
    let df_within = n.saturating_sub(n_subjects);

    // Between-subject SS
    let grand_mean = residuals.iter().sum::<f64>() / n as f64;
    let mut ss_between = 0.0;
    for s in 0..n_subjects {
        ss_between += subject_counts[s] as f64 * (subject_means[s] - grand_mean).powi(2);
    }

    let sigma2_eps = if df_within > 0 {
        ss_within / df_within as f64
    } else {
        1e-6
    };

    let df_between = n_subjects.saturating_sub(1).max(1);
    let ms_between = ss_between / df_between as f64;

    // Effective group size: n₀ for unbalanced designs, falling back to the mean size.
    let n_f = n as f64;
    let n_bar = n_f / n_subjects.max(1) as f64;
    let n0 = if n_subjects > 1 {
        let sum_sq_counts: f64 = subject_counts.iter().map(|&c| (c as f64).powi(2)).sum();
        (n_f - sum_sq_counts / n_f) / (n_subjects - 1) as f64
    } else {
        n_bar
    };
    let group_size = if n0 > 0.0 { n0 } else { n_bar };
    let sigma2_u = ((ms_between - sigma2_eps) / group_size).max(0.0);

    (sigma2_u, sigma2_eps)
}
616
/// Best linear unbiased predictions of the random intercepts.
///
/// `û_i = σ²_u / (σ²_u + σ²_ε / n_i) · mean_i(r)`, where `r` are the fixed-effect
/// residuals. The denominator is floored at `1e-15`, and subjects without
/// observations get `0.0`.
fn compute_blup(
    residuals: &[f64],
    subject_map: &[usize],
    n_subjects: usize,
    sigma2_u: f64,
    sigma2_eps: f64,
) -> Vec<f64> {
    // (running sum, observation count) per subject
    let mut totals = vec![(0.0_f64, 0usize); n_subjects];
    for (i, &r) in residuals.iter().enumerate() {
        let entry = &mut totals[subject_map[i]];
        entry.0 += r;
        entry.1 += 1;
    }

    totals
        .into_iter()
        .map(|(sum, count)| {
            if count == 0 {
                return 0.0;
            }
            let n_i = count as f64;
            let mean = sum / n_i;
            let denom = (sigma2_u + sigma2_eps / n_i).max(1e-15);
            let shrink = sigma2_u / denom;
            shrink * mean
        })
        .collect()
}
647
648// ---------------------------------------------------------------------------
649// Recovery of functional coefficients
650// ---------------------------------------------------------------------------
651
652/// Recover β̂(t) = Σ_k γ̂_jk φ_k(t) for each covariate j.
653fn recover_beta_functions(
654    gamma: &[Vec<f64>],
655    rotation: &FdMatrix,
656    p: usize,
657    m: usize,
658    k: usize,
659) -> FdMatrix {
660    let mut beta = FdMatrix::zeros(p, m);
661    for j in 0..p {
662        for t in 0..m {
663            let mut val = 0.0;
664            for comp in 0..k {
665                val += gamma[j][comp] * rotation[(t, comp)];
666            }
667            beta[(j, t)] = val;
668        }
669    }
670    beta
671}
672
673/// Recover b̂_i(t) = Σ_k û_ik φ_k(t) for each subject i.
674fn recover_random_effects(
675    u_hat: &[Vec<f64>],
676    rotation: &FdMatrix,
677    n_subjects: usize,
678    m: usize,
679    k: usize,
680) -> FdMatrix {
681    let mut re = FdMatrix::zeros(n_subjects, m);
682    for s in 0..n_subjects {
683        for t in 0..m {
684            let mut val = 0.0;
685            for comp in 0..k {
686                val += u_hat[s][comp] * rotation[(t, comp)];
687            }
688            re[(s, t)] = val;
689        }
690    }
691    re
692}
693
694/// Compute random effect variance function: Var_i(b̂_i(t)).
695fn compute_random_variance(random_effects: &FdMatrix, n_subjects: usize, m: usize) -> Vec<f64> {
696    (0..m)
697        .map(|t| {
698            let mean: f64 =
699                (0..n_subjects).map(|s| random_effects[(s, t)]).sum::<f64>() / n_subjects as f64;
700            let var: f64 = (0..n_subjects)
701                .map(|s| (random_effects[(s, t)] - mean).powi(2))
702                .sum::<f64>()
703                / n_subjects.max(1) as f64;
704            var
705        })
706        .collect()
707}
708
709/// Compute fitted values and residuals.
710fn compute_fitted_residuals(
711    data: &FdMatrix,
712    mean_function: &[f64],
713    beta_functions: &FdMatrix,
714    random_effects: &FdMatrix,
715    covariates: Option<&FdMatrix>,
716    subject_map: &[usize],
717    n_total: usize,
718    m: usize,
719    p: usize,
720) -> (FdMatrix, FdMatrix) {
721    let mut fitted = FdMatrix::zeros(n_total, m);
722    let mut residuals = FdMatrix::zeros(n_total, m);
723
724    for i in 0..n_total {
725        let s = subject_map[i];
726        for t in 0..m {
727            let mut val = mean_function[t] + random_effects[(s, t)];
728            if p > 0 {
729                if let Some(cov) = covariates {
730                    for j in 0..p {
731                        val += cov[(i, j)] * beta_functions[(j, t)];
732                    }
733                }
734            }
735            fitted[(i, t)] = val;
736            residuals[(i, t)] = data[(i, t)] - val;
737        }
738    }
739
740    (fitted, residuals)
741}
742
743// ---------------------------------------------------------------------------
744// Prediction
745// ---------------------------------------------------------------------------
746
747/// Predict curves for new subjects.
748///
749/// # Arguments
750/// * `result` — Fitted FMM result
751/// * `new_covariates` — Covariates for new subjects (n_new × p)
752///
753/// Returns predicted curves (n_new × m) using only fixed effects (no random effects for new subjects).
754#[must_use = "prediction result should not be discarded"]
755pub fn fmm_predict(result: &FmmResult, new_covariates: Option<&FdMatrix>) -> FdMatrix {
756    let m = result.mean_function.len();
757    let n_new = new_covariates.map_or(1, super::matrix::FdMatrix::nrows);
758    let p = result.beta_functions.nrows();
759
760    let mut predicted = FdMatrix::zeros(n_new, m);
761    for i in 0..n_new {
762        for t in 0..m {
763            let mut val = result.mean_function[t];
764            if let Some(cov) = new_covariates {
765                for j in 0..p {
766                    val += cov[(i, j)] * result.beta_functions[(j, t)];
767                }
768            }
769            predicted[(i, t)] = val;
770        }
771    }
772    predicted
773}
774
775// ---------------------------------------------------------------------------
776// Hypothesis testing
777// ---------------------------------------------------------------------------
778
779/// Permutation test for fixed effects in functional mixed model.
780///
781/// Tests H₀: β_j(t) = 0 for each covariate j.
782/// Uses integrated squared norm as test statistic: T_j = ∫ β̂_j(t)² dt.
783///
784/// # Arguments
785/// * `data` — All observed curves (n_total × m)
786/// * `subject_ids` — Subject identifiers
787/// * `covariates` — Subject-level covariates (n_total × p)
788/// * `ncomp` — Number of FPC components
789/// * `n_perm` — Number of permutations
790/// * `seed` — Random seed
791///
792/// # Errors
793///
794/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, or
795/// `covariates` has zero columns.
796/// Propagates errors from [`fmm`] (e.g., dimension mismatches or FPCA failure).
797#[must_use = "expensive computation whose result should not be discarded"]
798pub fn fmm_test_fixed(
799    data: &FdMatrix,
800    subject_ids: &[usize],
801    covariates: &FdMatrix,
802    ncomp: usize,
803    n_perm: usize,
804    seed: u64,
805) -> Result<FmmTestResult, FdarError> {
806    let n_total = data.nrows();
807    let m = data.ncols();
808    let p = covariates.ncols();
809    if n_total == 0 {
810        return Err(FdarError::InvalidDimension {
811            parameter: "data",
812            expected: "non-empty matrix".to_string(),
813            actual: format!("{n_total} rows"),
814        });
815    }
816    if p == 0 {
817        return Err(FdarError::InvalidDimension {
818            parameter: "covariates",
819            expected: "at least 1 column".to_string(),
820            actual: "0 columns".to_string(),
821        });
822    }
823
824    // Fit observed model
825    let result = fmm(data, subject_ids, Some(covariates), ncomp)?;
826
827    // Observed test statistics: ∫ β̂_j(t)² dt for each covariate
828    let observed_stats = compute_integrated_beta_sq(&result.beta_functions, p, m);
829
830    // Permutation test
831    let (f_statistics, p_values) = permutation_test(
832        data,
833        subject_ids,
834        covariates,
835        ncomp,
836        n_perm,
837        seed,
838        &observed_stats,
839        p,
840        m,
841    );
842
843    Ok(FmmTestResult {
844        f_statistics,
845        p_values,
846    })
847}
848
849/// Compute ∫ β̂_j(t)² dt for each covariate.
850fn compute_integrated_beta_sq(beta: &FdMatrix, p: usize, m: usize) -> Vec<f64> {
851    let h = if m > 1 { 1.0 / (m - 1) as f64 } else { 1.0 };
852    (0..p)
853        .map(|j| {
854            let ss: f64 = (0..m).map(|t| beta[(j, t)].powi(2)).sum();
855            ss * h
856        })
857        .collect()
858}
859
860/// Run permutation test for fixed effects.
861fn permutation_test(
862    data: &FdMatrix,
863    subject_ids: &[usize],
864    covariates: &FdMatrix,
865    ncomp: usize,
866    n_perm: usize,
867    seed: u64,
868    observed_stats: &[f64],
869    p: usize,
870    m: usize,
871) -> (Vec<f64>, Vec<f64>) {
872    use rand::prelude::*;
873    let n_total = data.nrows();
874    let mut rng = StdRng::seed_from_u64(seed);
875    let mut n_ge = vec![0usize; p];
876
877    for _ in 0..n_perm {
878        // Permute covariates across subjects
879        let mut perm_indices: Vec<usize> = (0..n_total).collect();
880        perm_indices.shuffle(&mut rng);
881        let perm_cov = permute_rows(covariates, &perm_indices);
882
883        if let Ok(perm_result) = fmm(data, subject_ids, Some(&perm_cov), ncomp) {
884            let perm_stats = compute_integrated_beta_sq(&perm_result.beta_functions, p, m);
885            for j in 0..p {
886                if perm_stats[j] >= observed_stats[j] {
887                    n_ge[j] += 1;
888                }
889            }
890        }
891    }
892
893    let p_values: Vec<f64> = n_ge
894        .iter()
895        .map(|&count| (count + 1) as f64 / (n_perm + 1) as f64)
896        .collect();
897    let f_statistics = observed_stats.to_vec();
898
899    (f_statistics, p_values)
900}
901
902/// Permute rows of a matrix according to given indices.
903fn permute_rows(mat: &FdMatrix, indices: &[usize]) -> FdMatrix {
904    let n = indices.len();
905    let m = mat.ncols();
906    let mut result = FdMatrix::zeros(n, m);
907    for (new_i, &old_i) in indices.iter().enumerate() {
908        for j in 0..m {
909            result[(new_i, j)] = mat[(old_i, j)];
910        }
911    }
912    result
913}
914
915// ---------------------------------------------------------------------------
916// Tests
917// ---------------------------------------------------------------------------
918
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::uniform_grid;
    use std::f64::consts::PI;

    /// Generate balanced repeated measurements: `n_subjects × n_visits` curves
    /// evaluated on the grid returned by `uniform_grid(m)`.
    ///
    /// Observation `obs = s * n_visits + v` (subject `s`, visit `v`) is
    ///
    /// `Y_sv(t) = sin(2πt) + 3·z_s·t + 0.3·e_s·cos(2πt) + noise_sv(t)`
    ///
    /// where `z_s = s / n_subjects` is the subject-level covariate (in `[0, 1)`),
    /// `e_s = 0.5·(s − n_subjects/2)` is a deterministic subject offset, and the
    /// "noise" is a deterministic pseudo-random pattern in `[0, 0.05)`.
    ///
    /// Note: `z_s` and `e_s` are both affine in `s`, so in this design the
    /// subject offset is collinear with the covariate.
    ///
    /// Returns `(data, subject_ids, covariates, t)`, where `data` has
    /// `n_subjects * n_visits` rows and `m` columns and `covariates` has one column.
    fn generate_fmm_data(
        n_subjects: usize,
        n_visits: usize,
        m: usize,
    ) -> (FdMatrix, Vec<usize>, FdMatrix, Vec<f64>) {
        // Amplitude of the deterministic pseudo-noise term.
        const NOISE_SCALE: f64 = 0.05;

        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            let z = s as f64 / n_subjects as f64; // covariate in [0, 1)
            let subject_effect = 0.5 * (s as f64 - n_subjects as f64 / 2.0); // deterministic offset

            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = z;

                for (j, &tj) in t.iter().enumerate() {
                    let mu = (2.0 * PI * tj).sin();
                    let fixed = z * tj * 3.0;
                    let random = subject_effect * (2.0 * PI * tj).cos() * 0.3;
                    let noise = NOISE_SCALE * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                    // Column-major layout: element (obs, j) lives at obs + j * n_total.
                    col_major[obs + j * n_total] = mu + fixed + random + noise;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();
        (data, subject_ids, covariates, t)
    }

    /// Assert that `fitted + residuals` reproduces `data` element-wise (tolerance 1e-8).
    fn assert_fitted_plus_residuals_equals_data(data: &FdMatrix, result: &FmmResult) {
        for i in 0..data.nrows() {
            for t in 0..data.ncols() {
                let reconstructed = result.fitted[(i, t)] + result.residuals[(i, t)];
                assert!(
                    (reconstructed - data[(i, t)]).abs() < 1e-8,
                    "Fitted + residual should equal data at ({}, {}): {} vs {}",
                    i,
                    t,
                    reconstructed,
                    data[(i, t)]
                );
            }
        }
    }

    #[test]
    fn test_fmm_basic() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert_eq!(result.mean_function.len(), 50);
        assert_eq!(result.beta_functions.nrows(), 1); // 1 covariate
        assert_eq!(result.beta_functions.ncols(), 50);
        assert_eq!(result.random_effects.nrows(), 10);
        assert_eq!(result.fitted.nrows(), 30);
        assert_eq!(result.residuals.nrows(), 30);
        assert_eq!(result.n_subjects, 10);
    }

    #[test]
    fn test_fmm_fitted_plus_residuals_equals_data() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert_fitted_plus_residuals_equals_data(&data, &result);
    }

    #[test]
    fn test_fmm_random_variance_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        for &v in &result.random_variance {
            assert!(v >= 0.0, "Random variance should be non-negative");
        }
    }

    #[test]
    fn test_fmm_no_covariates() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(8, 3, 40);
        let result = fmm(&data, &subject_ids, None, 3).unwrap();

        assert_eq!(result.beta_functions.nrows(), 0);
        assert_eq!(result.n_subjects, 8);
        assert_eq!(result.fitted.nrows(), 24);
    }

    #[test]
    fn test_fmm_predict() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        // Predict for a new subject with covariate = 0.5
        let new_cov = FdMatrix::from_column_major(vec![0.5], 1, 1).unwrap();
        let predicted = fmm_predict(&result, Some(&new_cov));

        assert_eq!(predicted.nrows(), 1);
        assert_eq!(predicted.ncols(), 50);

        // Predicted curve should be reasonable (not NaN or extreme)
        for t in 0..50 {
            assert!(predicted[(0, t)].is_finite());
            assert!(
                predicted[(0, t)].abs() < 20.0,
                "Predicted value too extreme at t={}: {}",
                t,
                predicted[(0, t)]
            );
        }
    }

    #[test]
    fn test_fmm_test_fixed_detects_effect() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(15, 3, 40);

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();

        assert_eq!(result.f_statistics.len(), 1);
        assert_eq!(result.p_values.len(), 1);
        assert!(
            result.p_values[0] < 0.1,
            "Should detect covariate effect, got p={}",
            result.p_values[0]
        );
    }

    #[test]
    fn test_fmm_test_fixed_no_effect() {
        let n_subjects = 10;
        let n_visits = 3;
        let m = 40;
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;

        // No covariate effect: Y = sin(2πt) + deterministic pseudo-noise
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];
        let mut cov_data = vec![0.0; n_total];

        for s in 0..n_subjects {
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                cov_data[obs] = s as f64 / n_subjects as f64;
                for (j, &tj) in t.iter().enumerate() {
                    col_major[obs + j * n_total] =
                        (2.0 * PI * tj).sin() + 0.1 * ((obs * 7 + j * 3) % 100) as f64 / 100.0;
                }
            }
        }

        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let covariates = FdMatrix::from_column_major(cov_data, n_total, 1).unwrap();

        let result = fmm_test_fixed(&data, &subject_ids, &covariates, 3, 99, 42).unwrap();
        assert!(
            result.p_values[0] > 0.05,
            "Should not detect effect, got p={}",
            result.p_values[0]
        );
    }

    #[test]
    fn test_fmm_invalid_input() {
        let data = FdMatrix::zeros(0, 0);
        assert!(fmm(&data, &[], None, 1).is_err());

        let data = FdMatrix::zeros(10, 50);
        let ids = vec![0; 5]; // wrong length
        assert!(fmm(&data, &ids, None, 1).is_err());
    }

    #[test]
    fn test_fmm_single_visit_per_subject() {
        let n = 10;
        let m = 40;
        let t = uniform_grid(m);
        let mut col_major = vec![0.0; n * m];
        let subject_ids: Vec<usize> = (0..n).collect();

        for i in 0..n {
            for (j, &tj) in t.iter().enumerate() {
                col_major[i + j * n] = (2.0 * PI * tj).sin();
            }
        }
        let data = FdMatrix::from_column_major(col_major, n, m).unwrap();

        // Should still work with 1 visit per subject
        let result = fmm(&data, &subject_ids, None, 2).unwrap();
        assert_eq!(result.n_subjects, n);
        assert_eq!(result.fitted.nrows(), n);
    }

    #[test]
    fn test_build_subject_map() {
        let (map, n) = build_subject_map(&[5, 5, 10, 10, 20]);
        assert_eq!(n, 3);
        assert_eq!(map, vec![0, 0, 1, 1, 2]);
    }

    #[test]
    fn test_variance_components_positive() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        assert!(result.sigma2_eps >= 0.0);
        for &s in &result.sigma2_u {
            assert!(s >= 0.0);
        }
    }

    // -------------------------------------------------------------------
    // Additional tests
    // -------------------------------------------------------------------

    #[test]
    fn test_fmm_ncomp_zero_returns_error() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(5, 2, 20);
        let err = fmm(&data, &subject_ids, None, 0).unwrap_err();
        match err {
            FdarError::InvalidParameter { parameter, .. } => {
                assert_eq!(parameter, "ncomp");
            }
            other => panic!("Expected InvalidParameter, got {:?}", other),
        }
    }

    #[test]
    fn test_fmm_single_component() {
        // Fit with only 1 FPC component
        let (data, subject_ids, covariates, _t) = generate_fmm_data(8, 3, 30);
        let result = fmm(&data, &subject_ids, Some(&covariates), 1).unwrap();

        assert_eq!(result.ncomp, 1);
        assert_eq!(result.sigma2_u.len(), 1);
        assert_eq!(result.eigenvalues.len(), 1);
        assert_eq!(result.mean_function.len(), 30);
        assert_fitted_plus_residuals_equals_data(&data, &result);
    }

    #[test]
    fn test_fmm_two_subjects() {
        // Minimal number of subjects (2) with multiple visits
        let n_subjects = 2;
        let n_visits = 5;
        let m = 20;
        let t = uniform_grid(m);
        let n_total = n_subjects * n_visits;
        let mut col_major = vec![0.0; n_total * m];
        let mut subject_ids = vec![0usize; n_total];

        for s in 0..n_subjects {
            for v in 0..n_visits {
                let obs = s * n_visits + v;
                subject_ids[obs] = s;
                for (j, &tj) in t.iter().enumerate() {
                    col_major[obs + j * n_total] =
                        (2.0 * PI * tj).sin() + (s as f64) * 0.5 + 0.01 * v as f64;
                }
            }
        }
        let data = FdMatrix::from_column_major(col_major, n_total, m).unwrap();
        let result = fmm(&data, &subject_ids, None, 2).unwrap();

        assert_eq!(result.n_subjects, 2);
        assert_eq!(result.random_effects.nrows(), 2);
        assert_eq!(result.fitted.nrows(), n_total);
    }

    #[test]
    fn test_fmm_predict_no_covariates() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(6, 3, 30);
        let result = fmm(&data, &subject_ids, None, 2).unwrap();

        // Predict without covariates — should return mean function
        let predicted = fmm_predict(&result, None);
        assert_eq!(predicted.nrows(), 1);
        assert_eq!(predicted.ncols(), 30);
        for t in 0..30 {
            let diff = (predicted[(0, t)] - result.mean_function[t]).abs();
            assert!(
                diff < 1e-12,
                "Without covariates, prediction should equal mean"
            );
        }
    }

    #[test]
    fn test_fmm_predict_multiple_new_subjects() {
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        // Predict for 3 new subjects with different covariate values
        let new_cov = FdMatrix::from_column_major(vec![0.1, 0.5, 0.9], 3, 1).unwrap();
        let predicted = fmm_predict(&result, Some(&new_cov));

        assert_eq!(predicted.nrows(), 3);
        assert_eq!(predicted.ncols(), 40);

        // All predictions should be finite
        for i in 0..3 {
            for t in 0..40 {
                assert!(predicted[(i, t)].is_finite());
            }
        }

        // Predictions for different covariates should differ
        let diff_01: f64 = (0..40)
            .map(|t| (predicted[(0, t)] - predicted[(1, t)]).powi(2))
            .sum();
        assert!(
            diff_01 > 1e-10,
            "Different covariates should yield different predictions"
        );
    }

    #[test]
    fn test_fmm_eigenvalues_decreasing() {
        let (data, subject_ids, _cov, _t) = generate_fmm_data(10, 3, 50);
        let result = fmm(&data, &subject_ids, None, 5).unwrap();

        // Eigenvalues should be in decreasing order (from FPCA)
        for i in 1..result.eigenvalues.len() {
            assert!(
                result.eigenvalues[i] <= result.eigenvalues[i - 1] + 1e-10,
                "Eigenvalues should be non-increasing: {} > {}",
                result.eigenvalues[i],
                result.eigenvalues[i - 1]
            );
        }
    }

    #[test]
    fn test_fmm_random_effects_sum_near_zero() {
        // Random effects should approximately cancel across subjects.
        let (data, subject_ids, covariates, _t) = generate_fmm_data(20, 3, 40);
        let result = fmm(&data, &subject_ids, Some(&covariates), 3).unwrap();

        let n_subjects = result.n_subjects as f64;
        let m = result.mean_function.len();
        for t in 0..m {
            let sum: f64 = (0..result.n_subjects)
                .map(|s| result.random_effects[(s, t)])
                .sum();
            let mean_abs: f64 = (0..result.n_subjects)
                .map(|s| result.random_effects[(s, t)].abs())
                .sum::<f64>()
                / n_subjects;
            // By the triangle inequality |mean| <= mean_abs always holds, so any
            // bound at or above mean_abs can never fail. Requiring
            // |mean| < 0.9 * mean_abs means at least ~5% of the absolute mass
            // sits on the opposite side of zero, i.e. the effects genuinely cancel.
            if mean_abs > 1e-10 {
                assert!(
                    (sum / n_subjects).abs() < mean_abs * 0.9,
                    "Random effects should roughly center around zero at t={}: sum={}, mean_abs={}",
                    t,
                    sum,
                    mean_abs
                );
            }
        }
    }

    #[test]
    fn test_fmm_subject_ids_mismatch_error() {
        let data = FdMatrix::zeros(10, 20);
        let ids = vec![0; 7]; // wrong length
        let err = fmm(&data, &ids, None, 1).unwrap_err();
        match err {
            FdarError::InvalidDimension { parameter, .. } => {
                assert_eq!(parameter, "subject_ids");
            }
            other => panic!("Expected InvalidDimension, got {:?}", other),
        }
    }

    #[test]
    fn test_fmm_test_fixed_empty_data_error() {
        let data = FdMatrix::zeros(0, 0);
        let covariates = FdMatrix::zeros(0, 1);
        let err = fmm_test_fixed(&data, &[], &covariates, 1, 10, 42).unwrap_err();
        match err {
            FdarError::InvalidDimension { parameter, .. } => {
                assert_eq!(parameter, "data");
            }
            other => panic!("Expected InvalidDimension for data, got {:?}", other),
        }
    }

    #[test]
    fn test_fmm_test_fixed_zero_covariates_error() {
        let data = FdMatrix::zeros(10, 20);
        let ids = vec![0; 10];
        let covariates = FdMatrix::zeros(10, 0);
        let err = fmm_test_fixed(&data, &ids, &covariates, 1, 10, 42).unwrap_err();
        match err {
            FdarError::InvalidDimension { parameter, .. } => {
                assert_eq!(parameter, "covariates");
            }
            other => panic!("Expected InvalidDimension for covariates, got {:?}", other),
        }
    }

    #[test]
    fn test_build_subject_map_single_subject() {
        let (map, n) = build_subject_map(&[42, 42, 42]);
        assert_eq!(n, 1);
        assert_eq!(map, vec![0, 0, 0]);
    }

    #[test]
    fn test_build_subject_map_non_contiguous_ids() {
        let (map, n) = build_subject_map(&[100, 200, 100, 300, 200]);
        assert_eq!(n, 3);
        // sorted unique: [100, 200, 300] -> indices [0, 1, 2]
        assert_eq!(map, vec![0, 1, 0, 2, 1]);
    }

    #[test]
    fn test_fmm_many_components_clamped() {
        // Request far more components than the data can support; the fit should
        // clamp ncomp to at most min(n_total, m) instead of failing.
        let (data, subject_ids, _cov, _t) = generate_fmm_data(5, 3, 20);
        let n_total = data.nrows();
        let m = data.ncols();
        let max_ncomp = n_total.min(m);
        let result = fmm(&data, &subject_ids, None, 100).unwrap();
        assert!(
            result.ncomp <= max_ncomp,
            "ncomp should be clamped to at most min(n_total, m) = {}: got {}",
            max_ncomp,
            result.ncomp
        );
        assert!(result.ncomp >= 1);
    }

    #[test]
    fn test_fmm_residuals_small_with_enough_components() {
        // With enough components, residuals should be small relative to the
        // total variation of the data.
        let (data, subject_ids, covariates, _t) = generate_fmm_data(10, 3, 30);
        let result = fmm(&data, &subject_ids, Some(&covariates), 5).unwrap();

        let n = data.nrows();
        let m = data.ncols();

        let mut total = 0.0_f64;
        for i in 0..n {
            for t in 0..m {
                total += data[(i, t)];
            }
        }
        let grand_mean = total / (n * m) as f64;

        // Standard R² = 1 - SS_res / SS_tot, with SS_tot taken about the grand
        // mean (an uncentered sum of squares would inflate R²).
        let mut ss_tot = 0.0_f64;
        let mut ss_res = 0.0_f64;
        for i in 0..n {
            for t in 0..m {
                ss_tot += (data[(i, t)] - grand_mean).powi(2);
                ss_res += result.residuals[(i, t)].powi(2);
            }
        }

        let r_squared = 1.0 - ss_res / ss_tot;
        assert!(
            r_squared > 0.5,
            "R-squared should be high with enough components: {}",
            r_squared
        );
    }
}