Skip to main content

limma/
fit.rs

1//! Gene-wise linear model fitting. Port of limma's `lmFit`
2//! (`lm.series` / `nonEstimable`), least-squares path only.
3
4use anyhow::{bail, Result};
5use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
6
7use crate::linalg::{eigh, inv_upper, matrix_rank, qr_econ, xtx_inv_from_r};
8
9/// Result of fitting gene-wise linear models. Port of `MArrayLM`.
10#[derive(Clone, Debug)]
11pub struct MArrayLM {
12    pub coefficients: Array2<f64>,     // n_genes x n_coef
13    pub stdev_unscaled: Array2<f64>,   // n_genes x n_coef
14    pub sigma: Array1<f64>,            // n_genes
15    pub df_residual: Array1<f64>,      // n_genes
16    pub cov_coefficients: Array2<f64>, // n_coef x n_coef
17    pub gene_names: Vec<String>,
18    pub coef_names: Vec<String>,
19    pub amean: Array1<f64>, // n_genes (per-feature mean over samples, NaN-skipping)
20    pub design: Option<Array2<f64>>,
21    pub contrasts: Option<Array2<f64>>,
22
23    // Filled by eBayes.
24    pub df_prior: Option<Array1<f64>>,
25    pub s2_prior: Option<Array1<f64>>, // len 1 (constant) or n_genes (trend)
26    pub var_prior: Option<Array1<f64>>, // per coef
27    pub proportion: Option<f64>,
28    pub s2_post: Option<Array1<f64>>,
29    pub t: Option<Array2<f64>>,
30    pub df_total: Option<Array1<f64>>,
31    pub p_value: Option<Array2<f64>>,
32    pub lods: Option<Array2<f64>>,
33    pub f_stat: Option<Array1<f64>>,
34    pub f_p_value: Option<Array1<f64>>,
35}
36
37impl MArrayLM {
38    pub fn n_genes(&self) -> usize {
39        self.coefficients.nrows()
40    }
41    pub fn n_coef(&self) -> usize {
42        self.coefficients.ncols()
43    }
44
45    /// `fitted.MArrayLM`: fitted values `coefficients %*% t(design)`
46    /// (`n_genes x n_samples`).
47    ///
48    /// Errors if the fit holds contrasts (its coefficients are contrasts rather
49    /// than the original coefficients) or carries no design matrix.
50    pub fn fitted(&self) -> Result<Array2<f64>> {
51        if self.contrasts.is_some() {
52            bail!("fit contains contrasts rather than coefficients, so fitted values cannot be computed");
53        }
54        let design = match &self.design {
55            Some(d) => d,
56            None => bail!("fit has no design matrix, so fitted values cannot be computed"),
57        };
58        Ok(self.coefficients.dot(&design.t()))
59    }
60
61    /// `residuals.MArrayLM`: observed minus fitted values, `y - fitted`.
62    pub fn residuals(&self, y: &Array2<f64>) -> Result<Array2<f64>> {
63        let fitted = self.fitted()?;
64        Ok(y - &fitted)
65    }
66}
67
68/// NaN-skipping mean of each row (matches R `rowMeans(x, na.rm = TRUE)`).
69fn row_nanmean(exprs: &Array2<f64>) -> Array1<f64> {
70    let n = exprs.nrows();
71    let mut out = Array1::<f64>::zeros(n);
72    for i in 0..n {
73        let mut sum = 0.0;
74        let mut cnt = 0usize;
75        for &v in exprs.row(i) {
76            if v.is_finite() {
77                sum += v;
78                cnt += 1;
79            }
80        }
81        out[i] = if cnt > 0 { sum / cnt as f64 } else { f64::NAN };
82    }
83    out
84}
85
86/// Return the indices of design columns that are linearly dependent on
87/// previous columns, or `None` if the design has full column rank.
88/// Port of `nonEstimable` (returns indices rather than reconstructed names).
89pub fn non_estimable(design: &Array2<f64>) -> Option<Vec<usize>> {
90    let p = design.ncols();
91    let rank = matrix_rank(design);
92    if rank < p {
93        // Identify dependent columns greedily.
94        let mut kept: Vec<usize> = Vec::new();
95        let mut dependent: Vec<usize> = Vec::new();
96        for j in 0..p {
97            let mut trial: Vec<usize> = kept.clone();
98            trial.push(j);
99            let sub = design.select(Axis(1), &trial);
100            if matrix_rank(&sub) == trial.len() {
101                kept.push(j);
102            } else {
103                dependent.push(j);
104            }
105        }
106        Some(dependent)
107    } else {
108        None
109    }
110}
111
112/// Whether `x` has full column rank, judged (as in limma's `is.fullrank`) by the
113/// ratio of the smallest to the largest eigenvalue of `x' x` exceeding `1e-13`.
114pub fn is_fullrank(x: &Array2<f64>) -> bool {
115    let (evals, _) = eigh(&x.t().dot(x));
116    let n = evals.len();
117    let largest = evals[n - 1];
118    let smallest = evals[0];
119    largest > 0.0 && (smallest / largest).abs() > 1e-13
120}
121
122/// Fit gene-wise linear models by ordinary least squares.
123///
124/// * `exprs` — `n_genes x n_samples` expression matrix (may contain NaN).
125/// * `design` — `n_samples x n_coef` design matrix (no NaN).
126///
127/// Port of `lmFit` (method="ls", ndups=1, block=None) + `lm_series`.
128pub fn lmfit(
129    exprs: &Array2<f64>,
130    design: &Array2<f64>,
131    gene_names: Vec<String>,
132    coef_names: Vec<String>,
133) -> Result<MArrayLM> {
134    let n_genes = exprs.nrows();
135    let n_samples = exprs.ncols();
136    let p = design.ncols();
137
138    if n_genes == 0 {
139        bail!("expression matrix has zero rows");
140    }
141    if design.nrows() != n_samples {
142        bail!(
143            "row dimension of design ({}) does not match column dimension of data ({})",
144            design.nrows(),
145            n_samples
146        );
147    }
148    if design.iter().any(|v| v.is_nan()) {
149        bail!("NAs not allowed in design matrix");
150    }
151    if matrix_rank(design) < p {
152        bail!(
153            "design matrix is not of full column rank ({} of {} columns estimable); \
154             reduce the design to estimable coefficients",
155            matrix_rank(design),
156            p
157        );
158    }
159
160    let amean = row_nanmean(exprs);
161    let any_missing = exprs.iter().any(|v| !v.is_finite());
162
163    let mut fit = if !any_missing {
164        lm_series_fast(exprs, design, &gene_names, &coef_names)
165    } else {
166        lm_series_genewise(exprs, design, &gene_names, &coef_names)
167    };
168
169    fit.amean = amean;
170    fit.design = Some(design.clone());
171    Ok(fit)
172}
173
174/// Fit gene-wise linear models by *weighted* least squares.
175///
176/// * `exprs` — `n_genes x n_samples` expression matrix (may contain NaN).
177/// * `design` — `n_samples x n_coef` design matrix (no NaN).
178/// * `weights` — `n_genes x n_samples` observation weights (the `weights`
179///   argument of `lmFit`, e.g. the matrix returned by `voom`).
180///
181/// Port of `lmFit(method="ls")` -> `lm.series` weighted gene-wise path. Each
182/// gene is fit on its observations with finite expression and finite positive
183/// weight; an observation with weight `<= 0` or non-finite is dropped, exactly
184/// as R's `weights[weights<=0] <- NA; M[!is.finite(weights)] <- NA`. As in
185/// limma, `cov.coefficients` is the *unweighted* `(XᵀX)⁻¹` of the full design.
186pub fn lmfit_weighted(
187    exprs: &Array2<f64>,
188    design: &Array2<f64>,
189    weights: &Array2<f64>,
190    gene_names: Vec<String>,
191    coef_names: Vec<String>,
192) -> Result<MArrayLM> {
193    let n_genes = exprs.nrows();
194    let n_samples = exprs.ncols();
195    let p = design.ncols();
196
197    if n_genes == 0 {
198        bail!("expression matrix has zero rows");
199    }
200    if design.nrows() != n_samples {
201        bail!(
202            "row dimension of design ({}) does not match column dimension of data ({})",
203            design.nrows(),
204            n_samples
205        );
206    }
207    if weights.nrows() != n_genes || weights.ncols() != n_samples {
208        bail!(
209            "weights dimensions ({}x{}) must match expression matrix ({}x{})",
210            weights.nrows(),
211            weights.ncols(),
212            n_genes,
213            n_samples
214        );
215    }
216    if design.iter().any(|v| v.is_nan()) {
217        bail!("NAs not allowed in design matrix");
218    }
219    if matrix_rank(design) < p {
220        bail!(
221            "design matrix is not of full column rank ({} of {} columns estimable); \
222             reduce the design to estimable coefficients",
223            matrix_rank(design),
224            p
225        );
226    }
227
228    let amean = row_nanmean(exprs);
229    let mut fit = lm_series_weighted(exprs, design, weights, &gene_names, &coef_names);
230    fit.amean = amean;
231    fit.design = Some(design.clone());
232    Ok(fit)
233}
234
235/// Fast path: identical QR for every gene (no missing values, no weights).
236fn lm_series_fast(
237    exprs: &Array2<f64>,
238    design: &Array2<f64>,
239    gene_names: &[String],
240    coef_names: &[String],
241) -> MArrayLM {
242    let n = design.nrows();
243    let p = design.ncols();
244    let n_genes = exprs.nrows();
245
246    let (q, r) = qr_econ(design);
247    let rinv = inv_upper(&r);
248    let xtx_inv = xtx_inv_from_r(&r);
249
250    // effects head = Q^T Y  (p x n_genes), Y = exprs^T (n_samples x n_genes)
251    let yt = exprs.t().to_owned(); // n_samples x n_genes
252    let qty = q.t().dot(&yt); // p x n_genes
253                              // beta = R^{-1} Q^T y  (p x n_genes)
254    let beta = rinv.dot(&qty); // p x n_genes
255    let coefficients = beta.t().to_owned(); // n_genes x p
256
257    let df_res = (n - p) as f64;
258    let mut sigma = Array1::<f64>::zeros(n_genes);
259    // Fitted sum of squares per gene = sum_j (Q^T y)[j, g]^2. `qty` is
260    // p x n_genes (row-major), so accumulate row-wise — each `qty.row(j)` is
261    // contiguous, whereas the per-gene `qty.column(g)` is strided. For a fixed
262    // gene the j = 0..p terms still accumulate in the same order, so each
263    // gene's value is bit-identical; only the cache-unfriendly stride is gone,
264    // which is what stung the wide, many-gene unweighted regime.
265    if df_res > 0.0 {
266        let mut ss_fit = vec![0.0f64; n_genes];
267        for j in 0..p {
268            for (g, &v) in qty.row(j).iter().enumerate() {
269                ss_fit[g] += v * v;
270            }
271        }
272        for g in 0..n_genes {
273            let ssy: f64 = exprs.row(g).iter().map(|&v| v * v).sum();
274            let rss = (ssy - ss_fit[g]).max(0.0);
275            sigma[g] = (rss / df_res).sqrt();
276        }
277    } else {
278        sigma.fill(f64::NAN);
279    }
280
281    let diag_se: Array1<f64> = (0..p).map(|j| xtx_inv[[j, j]].sqrt()).collect();
282    let mut stdev_unscaled = Array2::<f64>::zeros((n_genes, p));
283    for g in 0..n_genes {
284        for j in 0..p {
285            stdev_unscaled[[g, j]] = diag_se[j];
286        }
287    }
288
289    let df_residual = Array1::from_elem(n_genes, df_res);
290
291    new_fit(
292        coefficients,
293        stdev_unscaled,
294        sigma,
295        df_residual,
296        xtx_inv,
297        gene_names,
298        coef_names,
299    )
300}
301
/// One gene's fit from the gene-wise solvers.
///
/// Returning a value per gene lets the outer loop be a plain (optionally
/// parallel) map, after which the rows are written into the output matrices
/// in gene order.
struct RowFit {
    /// Coefficient estimates, length `n_coef`; all NaN if the gene was skipped.
    coef: Vec<f64>,
    /// Unscaled standard errors, length `n_coef`; all NaN if the gene was skipped.
    stdev: Vec<f64>,
    /// Residual standard deviation; NaN if `df_residual == 0` or the gene was skipped.
    sigma: f64,
    /// Residual degrees of freedom; `0.0` if the gene was skipped.
    df: f64,
}
311
312impl RowFit {
313    /// A gene with too few usable observations to estimate the coefficients:
314    /// NaN estimates and zero residual df, matching the serial `continue` path.
315    fn skipped(p: usize) -> Self {
316        RowFit {
317            coef: vec![f64::NAN; p],
318            stdev: vec![f64::NAN; p],
319            sigma: f64::NAN,
320            df: 0.0,
321        }
322    }
323}
324
325/// Per-worker scratch for the gene-wise OLS/WLS paths: a reusable index list plus
326/// full-size design/response buffers. Each gene fills the first `n_obs` rows
327/// (`n_obs <= n_arrays`) and hands that region to `solve_scaled` as a view, so
328/// the inner loop reuses three allocations across genes instead of making them
329/// fresh every gene. The values placed in the buffers — and therefore the
330/// result — are identical to the previous allocate-per-gene path.
331struct GeneScratch {
332    obs: Vec<usize>,
333    xtil: Array2<f64>,
334    ytil: Array1<f64>,
335}
336
337impl GeneScratch {
338    fn new(n_arrays: usize, p: usize) -> Self {
339        GeneScratch {
340            obs: Vec::with_capacity(n_arrays),
341            xtil: Array2::zeros((n_arrays, p)),
342            ytil: Array1::zeros(n_arrays),
343        }
344    }
345}
346
347/// Map `f` over gene indices `0..n`, threading a per-worker mutable scratch `S`
348/// (built by `init`) into each call. Under `parallel` this is rayon's
349/// `map_init`, so each worker thread reuses one scratch across the genes it
350/// handles; serially a single scratch is reused. Results are always returned in
351/// gene order, so the output is bit-identical regardless of feature or thread
352/// count — only the *across-gene* iteration is parallel.
353fn map_genes_init<S, I, F>(n: usize, init: I, f: F) -> Vec<RowFit>
354where
355    S: Send,
356    I: Fn() -> S + Sync + Send,
357    F: Fn(&mut S, usize) -> RowFit + Sync + Send,
358{
359    #[cfg(feature = "parallel")]
360    {
361        use rayon::prelude::*;
362        (0..n)
363            .into_par_iter()
364            .map_init(init, |s, g| f(s, g))
365            .collect()
366    }
367    #[cfg(not(feature = "parallel"))]
368    {
369        let mut s = init();
370        (0..n).map(|g| f(&mut s, g)).collect()
371    }
372}
373
374/// Write per-gene fits into the output matrices (sequential, gene order).
375fn scatter_rows(
376    rows: Vec<RowFit>,
377    coefficients: &mut Array2<f64>,
378    stdev_unscaled: &mut Array2<f64>,
379    sigma: &mut Array1<f64>,
380    df_residual: &mut Array1<f64>,
381) {
382    let p = coefficients.ncols();
383    for (g, r) in rows.into_iter().enumerate() {
384        for j in 0..p {
385            coefficients[[g, j]] = r.coef[j];
386            stdev_unscaled[[g, j]] = r.stdev[j];
387        }
388        sigma[g] = r.sigma;
389        df_residual[g] = r.df;
390    }
391}
392
/// Solve the (already `sqrt(w)`-scaled) least-squares system `X̃ β ≈ ỹ` for one
/// gene, returning `(coefficients, unscaled SEs, residual sum of squares)`, where
/// the unscaled SE of coefficient `j` is `sqrt(diag((X̃ᵀX̃)⁻¹))[j]`. This is the
/// hot inner kernel of the gene-wise (voom / missing-value) fits, so the `n×p`
/// QR — the part whose cost grows with sample count — is what we want fast.
///
/// With the `faer` feature the QR runs through faer's vectorised pure-Rust
/// kernel; without it the in-crate ndarray Householder QR is used. The two are
/// numerically equivalent (both match R to 8+ significant figures) but not
/// bit-identical, since the QR algorithms round differently (~1e-13). The tiny
/// `p×p` triangular work is shared via the in-crate [`inv_upper`].
///
/// Unlike the ndarray variant, this one computes the RSS directly from the
/// residuals `ỹ − X̃β`, not as `‖ỹ‖² − ‖Qᵀỹ‖²`.
///
/// NOTE(review): `qr()` is presumably faer's unpivoted QR, so a rank-deficient
/// `X̃` is not detected here; confirm against the faer version in use.
#[cfg(feature = "faer")]
fn solve_scaled(xtil: ArrayView2<f64>, ytil: ArrayView1<f64>) -> (Vec<f64>, Vec<f64>, f64) {
    use faer::{prelude::*, Mat};

    let n_obs = xtil.nrows();
    let p = xtil.ncols();

    // Copy the ndarray views into faer matrices (element-wise).
    let xf = Mat::from_fn(n_obs, p, |i, j| xtil[[i, j]]);
    let yf = Mat::from_fn(n_obs, 1, |i, _| ytil[i]);
    let qr = xf.qr();
    let beta = qr.solve_lstsq(&yf); // p x 1
    let coef: Vec<f64> = (0..p).map(|j| beta[(j, 0)]).collect();

    // R factor (p×p, upper-triangular): (X̃ᵀX̃)⁻¹ = R⁻¹R⁻ᵀ, so the unscaled SEs
    // are the row norms of R⁻¹. Only the upper triangle is read.
    let rf = qr.thin_R();
    let mut r = Array2::<f64>::zeros((p, p));
    for i in 0..p {
        for j in i..p {
            r[[i, j]] = rf[(i, j)];
        }
    }
    let rinv = inv_upper(&r);
    // Row j of R⁻¹ squared and summed = diag((X̃ᵀX̃)⁻¹)[j]; sign of R's diagonal
    // does not matter because only squares enter.
    let stdev: Vec<f64> = (0..p)
        .map(|j| {
            (0..p)
                .map(|k| rinv[[j, k]] * rinv[[j, k]])
                .sum::<f64>()
                .sqrt()
        })
        .collect();

    // RSS from the fitted residuals of the scaled system.
    let mut rss = 0.0_f64;
    for i in 0..n_obs {
        let mut fit = 0.0;
        for (j, &b) in coef.iter().enumerate() {
            fit += xtil[[i, j]] * b;
        }
        let e = ytil[i] - fit;
        rss += e * e;
    }
    (coef, stdev, rss.max(0.0))
}
448
449/// Pure-ndarray fallback used when the `faer` feature is off. Bit-identical to
450/// the gene-wise arithmetic shipped in earlier releases.
451#[cfg(not(feature = "faer"))]
452fn solve_scaled(xtil: ArrayView2<f64>, ytil: ArrayView1<f64>) -> (Vec<f64>, Vec<f64>, f64) {
453    let p = xtil.ncols();
454    // qr_econ wants an owned matrix; the per-gene `xtil` slice is small and this
455    // fallback path already allocates Q/R inside qr_econ.
456    let xq = xtil.to_owned();
457    let (q, r) = qr_econ(&xq);
458    let rinv = inv_upper(&r);
459    let qty = q.t().dot(&ytil);
460    let beta = rinv.dot(&qty);
461    let xtx_inv = xtx_inv_from_r(&r);
462    let coef: Vec<f64> = (0..p).map(|j| beta[j]).collect();
463    let stdev: Vec<f64> = (0..p).map(|j| xtx_inv[[j, j]].sqrt()).collect();
464    let ssy: f64 = ytil.iter().map(|&v| v * v).sum();
465    let ss_fit: f64 = qty.iter().map(|&v| v * v).sum();
466    let rss = (ssy - ss_fit).max(0.0);
467    (coef, stdev, rss)
468}
469
470/// Ordinary least squares for one gene on its finite observations (missing-value
471/// path).
472fn ols_one_gene(
473    scratch: &mut GeneScratch,
474    row: ArrayView1<f64>,
475    design: &Array2<f64>,
476    p: usize,
477) -> RowFit {
478    let GeneScratch { obs, xtil, ytil } = scratch;
479    obs.clear();
480    obs.extend((0..row.len()).filter(|&i| row[i].is_finite()));
481    if obs.is_empty() {
482        return RowFit::skipped(p);
483    }
484    let n_obs = obs.len();
485    if n_obs < p {
486        return RowFit::skipped(p);
487    }
488    // Gather this gene's finite rows of (design, y) into the scratch, in `obs`
489    // order — exactly the rows `design.select(Axis(0), &obs)` would have built.
490    for (r, &i) in obs.iter().enumerate() {
491        ytil[r] = row[i];
492        for j in 0..p {
493            xtil[[r, j]] = design[[i, j]];
494        }
495    }
496    let (coef, stdev, rss) = solve_scaled(xtil.slice(s![..n_obs, ..]), ytil.slice(s![..n_obs]));
497    let df = (n_obs - p) as f64;
498    let sigma = if df > 0.0 {
499        (rss / df).sqrt()
500    } else {
501        f64::NAN
502    };
503    RowFit {
504        coef,
505        stdev,
506        sigma,
507        df,
508    }
509}
510
511/// Weighted least squares for one gene on its observations with finite
512/// expression and finite positive weight (the `sqrt(w)`-scaled OLS trick).
513fn wls_one_gene(
514    scratch: &mut GeneScratch,
515    yr: ArrayView1<f64>,
516    wr: ArrayView1<f64>,
517    design: &Array2<f64>,
518    p: usize,
519) -> RowFit {
520    let GeneScratch { obs, xtil, ytil } = scratch;
521    obs.clear();
522    obs.extend((0..yr.len()).filter(|&k| yr[k].is_finite() && wr[k].is_finite() && wr[k] > 0.0));
523    let n_obs = obs.len();
524    if n_obs < p {
525        return RowFit::skipped(p);
526    }
527    // Scale each observed row of (design, y) by sqrt(weight) into the scratch:
528    // WLS becomes OLS on (X̃, ỹ), so the same `solve_scaled` kernel handles it.
529    for (r, &k) in obs.iter().enumerate() {
530        let sw = wr[k].sqrt();
531        ytil[r] = sw * yr[k];
532        for j in 0..p {
533            xtil[[r, j]] = sw * design[[k, j]];
534        }
535    }
536    let (coef, stdev, rss) = solve_scaled(xtil.slice(s![..n_obs, ..]), ytil.slice(s![..n_obs]));
537    let df = (n_obs - p) as f64;
538    let sigma = if df > 0.0 {
539        (rss / df).sqrt()
540    } else {
541        f64::NAN
542    };
543    RowFit {
544        coef,
545        stdev,
546        sigma,
547        df,
548    }
549}
550
551/// Slow path: genewise QR on the observed (finite) entries of each row.
552fn lm_series_genewise(
553    exprs: &Array2<f64>,
554    design: &Array2<f64>,
555    gene_names: &[String],
556    coef_names: &[String],
557) -> MArrayLM {
558    let p = design.ncols();
559    let n_genes = exprs.nrows();
560
561    let mut coefficients = Array2::<f64>::from_elem((n_genes, p), f64::NAN);
562    let mut stdev_unscaled = Array2::<f64>::from_elem((n_genes, p), f64::NAN);
563    let mut sigma = Array1::<f64>::from_elem(n_genes, f64::NAN);
564    let mut df_residual = Array1::<f64>::zeros(n_genes);
565
566    let n_arrays = design.nrows();
567    let rows = map_genes_init(
568        n_genes,
569        || GeneScratch::new(n_arrays, p),
570        |s, g| ols_one_gene(s, exprs.row(g), design, p),
571    );
572    scatter_rows(
573        rows,
574        &mut coefficients,
575        &mut stdev_unscaled,
576        &mut sigma,
577        &mut df_residual,
578    );
579
580    // cov_coefficients uses the full design (matches limma).
581    let (_q, r) = qr_econ(design);
582    let xtx_inv = xtx_inv_from_r(&r);
583
584    new_fit(
585        coefficients,
586        stdev_unscaled,
587        sigma,
588        df_residual,
589        xtx_inv,
590        gene_names,
591        coef_names,
592    )
593}
594
595/// Weighted gene-wise path: per-gene QR on the `sqrt(w)`-scaled observations.
596fn lm_series_weighted(
597    exprs: &Array2<f64>,
598    design: &Array2<f64>,
599    weights: &Array2<f64>,
600    gene_names: &[String],
601    coef_names: &[String],
602) -> MArrayLM {
603    let p = design.ncols();
604    let n_genes = exprs.nrows();
605
606    let mut coefficients = Array2::<f64>::from_elem((n_genes, p), f64::NAN);
607    let mut stdev_unscaled = Array2::<f64>::from_elem((n_genes, p), f64::NAN);
608    let mut sigma = Array1::<f64>::from_elem(n_genes, f64::NAN);
609    let mut df_residual = Array1::<f64>::zeros(n_genes);
610
611    let n_arrays = design.nrows();
612    let rows = map_genes_init(
613        n_genes,
614        || GeneScratch::new(n_arrays, p),
615        |s, g| wls_one_gene(s, exprs.row(g), weights.row(g), design, p),
616    );
617    scatter_rows(
618        rows,
619        &mut coefficients,
620        &mut stdev_unscaled,
621        &mut sigma,
622        &mut df_residual,
623    );
624
625    // cov_coefficients uses the unweighted full design, matching lm.series.
626    let (_q, r) = qr_econ(design);
627    let xtx_inv = xtx_inv_from_r(&r);
628
629    new_fit(
630        coefficients,
631        stdev_unscaled,
632        sigma,
633        df_residual,
634        xtx_inv,
635        gene_names,
636        coef_names,
637    )
638}
639
640#[allow(clippy::too_many_arguments)]
641pub(crate) fn new_fit(
642    coefficients: Array2<f64>,
643    stdev_unscaled: Array2<f64>,
644    sigma: Array1<f64>,
645    df_residual: Array1<f64>,
646    cov_coefficients: Array2<f64>,
647    gene_names: &[String],
648    coef_names: &[String],
649) -> MArrayLM {
650    let n_genes = coefficients.nrows();
651    MArrayLM {
652        coefficients,
653        stdev_unscaled,
654        sigma,
655        df_residual,
656        cov_coefficients,
657        gene_names: gene_names.to_vec(),
658        coef_names: coef_names.to_vec(),
659        amean: Array1::zeros(n_genes),
660        design: None,
661        contrasts: None,
662        df_prior: None,
663        s2_prior: None,
664        var_prior: None,
665        proportion: None,
666        s2_post: None,
667        t: None,
668        df_total: None,
669        p_value: None,
670        lods: None,
671        f_stat: None,
672        f_p_value: None,
673    }
674}
675
// Unit tests comparing the fits in this file against reference values
// recorded from limma (R). Reference constants are pasted at full double
// precision, hence the `excessive_precision` allowance.
#[cfg(test)]
#[allow(clippy::excessive_precision)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn is_fullrank_matches_r() {
        // Reference: is.fullrank() in limma 3.68.3.
        let full = array![[1.0, 0.0], [1.0, 0.0], [1.0, 1.0]];
        // Second column is exactly twice the first -> rank 1.
        let deficient = array![[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]];
        let single = array![[1.0], [2.0]];
        assert!(is_fullrank(&full));
        assert!(!is_fullrank(&deficient));
        assert!(is_fullrank(&single));
    }

    // Mixed relative/absolute closeness: |a - b| <= 1e-7 * (1 + |b|), where
    // `b` is the reference value.
    fn rclose(a: f64, b: f64) -> bool {
        (a - b).abs() <= 1e-7 * (1.0 + b.abs())
    }

    /// Rebuild scratch/lmfit_weighted_ref.R's 8x6 expression + weight matrices
    /// from the same purely rational, 0-indexed formula (bit-identical inputs).
    // Returns (expression, design, weights); the design is intercept + a
    // 0/0/0/1/1/1 group indicator.
    fn weighted_fixture() -> (Array2<f64>, Array2<f64>, Array2<f64>) {
        let ngenes = 8usize;
        let narrays = 6usize;
        let group = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let mut y = Array2::<f64>::zeros((ngenes, narrays));
        let mut w = Array2::<f64>::zeros((ngenes, narrays));
        for g in 0..ngenes {
            let gi = g as i64;
            let base = ((gi % 5) - 2) as f64;
            let eff = (((gi * 3) % 7) - 3) as f64;
            let mu = 6.0 + (gi % 4) as f64 * 0.5;
            for k in 0..narrays {
                let ki = k as i64;
                let noise = (((gi * 13 + ki * 17) % 19) - 9) as f64 * 0.05;
                y[[g, k]] = mu + eff * 0.3 * group[k] + base * 0.1 + noise;
                w[[g, k]] = 0.5 + (((gi * 7 + ki * 5) % 11) as f64) * 0.1;
            }
        }
        let mut design = Array2::<f64>::zeros((narrays, 2));
        for k in 0..narrays {
            design[[k, 0]] = 1.0;
            design[[k, 1]] = group[k];
        }
        (y, design, w)
    }

    // Gene names g0..g{n-1} and coefficient names Intercept, x1..x{p-1}.
    fn names(n: usize, p: usize) -> (Vec<String>, Vec<String>) {
        ((0..n).map(|i| format!("g{i}")).collect(), {
            let mut c = vec!["Intercept".to_string()];
            for j in 1..p {
                c.push(format!("x{j}"));
            }
            c
        })
    }

    #[test]
    fn lmfit_weighted_matches_r() {
        let (y, design, w) = weighted_fixture();
        let (gn, cn) = names(8, 2);
        let fit = lmfit_weighted(&y, &design, &w, gn, cn).unwrap();

        // optim/lmFit(Y, design, weights=W) reference from limma 3.68.3.
        let coef = [
            [6.0083333333333346, -1.0051075268817207],
            [6.5034482758620706, -0.33678160919540306],
            [6.8035714285714244, 1.1567733990147793],
            [7.677631578947369, -0.22406015037593929],
            [6.3538461538461544, 0.29878542510121447],
            [6.1539999999999981, -0.534769230769229],
            [6.8057142857142869, 0.59828571428571331],
            [7.7029411764705884, -1.200084033613446],
        ];
        let stdev = [
            [0.57735026918962584, 0.80988516376991593],
            [0.58722021951470349, 0.8235052638205963],
            [0.59761430466719667, 0.83783676414308383],
            [0.51298917604257699, 0.78759174188135006],
            [0.62017367294604198, 0.80484363658553359],
            [0.63245553203367566, 0.88578517972214033],
            [0.5345224838248489, 0.82807867121082501],
            [0.54232614454664041, 0.7614669610515673],
        ];
        let sigma = [
            0.26599314305172844,
            0.099539168054648686,
            0.33639780799178071,
            0.38269209879641242,
            0.11173095214852602,
            0.31304890254497836,
            0.36171318550949705,
            0.10717044462761564,
        ];
        for g in 0..8 {
            for j in 0..2 {
                assert!(
                    rclose(fit.coefficients[[g, j]], coef[g][j]),
                    "coef[{g}][{j}] {}",
                    fit.coefficients[[g, j]]
                );
                assert!(
                    rclose(fit.stdev_unscaled[[g, j]], stdev[g][j]),
                    "stdev[{g}][{j}] {}",
                    fit.stdev_unscaled[[g, j]]
                );
            }
            assert!(
                rclose(fit.sigma[g], sigma[g]),
                "sigma[{g}] {}",
                fit.sigma[g]
            );
            // 6 observations - 2 coefficients.
            assert_eq!(fit.df_residual[g], 4.0);
        }
        // cov.coefficients is the unweighted (XᵀX)⁻¹.
        assert!(rclose(fit.cov_coefficients[[0, 0]], 0.33333333333333348));
        assert!(rclose(fit.cov_coefficients[[0, 1]], -0.33333333333333354));
        assert!(rclose(fit.cov_coefficients[[1, 1]], 0.66666666666666685));
    }

    #[test]
    fn lmfit_weighted_drops_zero_weight_and_missing() {
        let (mut y, design, mut w) = weighted_fixture();
        w[[3, 4]] = 0.0; // zero weight -> observation dropped
        y[[5, 2]] = f64::NAN; // missing expression -> observation dropped
        let (gn, cn) = names(8, 2);
        let fit = lmfit_weighted(&y, &design, &w, gn, cn).unwrap();

        // Reference (lmFit with W[4,5]=0 and Y[6,3]=NA, 1-indexed): only genes
        // 4 and 6 (0-indexed 3 and 5) lose one observation -> df 4 becomes 3.
        let df = [4.0, 4.0, 4.0, 3.0, 4.0, 3.0, 4.0, 4.0];
        for (g, &expect) in df.iter().enumerate() {
            assert_eq!(fit.df_residual[g], expect, "df[{g}]");
        }
        assert!(rclose(fit.coefficients[[3, 0]], 7.6776315789473708));
        assert!(rclose(fit.coefficients[[3, 1]], -0.22096491228070209));
        assert!(rclose(fit.coefficients[[5, 0]], 6.1868421052631613));
        assert!(rclose(fit.coefficients[[5, 1]], -0.56761133603239));
        assert!(
            rclose(fit.sigma[3], 0.44188309824502131),
            "sigma3 {}",
            fit.sigma[3]
        );
        assert!(
            rclose(fit.sigma[5], 0.35751900376998208),
            "sigma5 {}",
            fit.sigma[5]
        );
        assert!(rclose(fit.stdev_unscaled[[3, 1]], 0.96427411113412587));
        assert!(rclose(fit.stdev_unscaled[[5, 0]], 0.72547625011001193));
        // Unaffected genes still match the clean fit.
        assert!(rclose(fit.coefficients[[0, 1]], -1.0051075268817207));
        assert!(rclose(fit.sigma[7], 0.10717044462761564));
    }

    #[test]
    fn fitted_and_residuals_match_r() {
        // Reference: scratch/fitted_resid_ref.R (lmFit -> fitted/residuals).
        // 5 genes x 6 samples with no missing values, so lmfit takes the
        // single-QR fast path.
        let group = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let mut y = Array2::<f64>::zeros((5, 6));
        for g0 in 0..5 {
            let lvl = 5.0 + (g0 % 4) as f64 * 0.3;
            let eff = (((g0 as i64 * 3) % 7) - 3) as f64 * 0.2;
            for k0 in 0..6 {
                let noise = (((g0 as i64 * 7 + k0 as i64 * 13) % 11) - 5) as f64 * 0.05;
                y[[g0, k0]] = lvl + eff * group[k0] + noise;
            }
        }
        let mut design = Array2::<f64>::zeros((6, 2));
        for k0 in 0..6 {
            design[[k0, 0]] = 1.0;
            design[[k0, 1]] = group[k0];
        }
        let (gn, cn) = (
            (0..5).map(|i| format!("g{i}")).collect::<Vec<_>>(),
            vec!["Intercept".into(), "group".into()],
        );
        let fit = lmfit(&y, &design, gn, cn).unwrap();

        let fitted = fit.fitted().unwrap();
        assert_eq!(fitted.dim(), (5, 6));
        assert!(rclose(fitted[[0, 0]], 4.8500000000000014));
        assert!(rclose(fitted[[0, 3]], 4.5500000000000007));
        assert!(rclose(fitted[[4, 0]], 5.1499999999999995));
        assert!(rclose(fitted[[4, 3]], 5.2999999999999998));
        // Constant within each group (intercept + group effect only).
        assert!(rclose(fitted[[0, 1]], fitted[[0, 0]]));
        assert!(rclose(fitted[[0, 5]], fitted[[0, 3]]));

        let resid = fit.residuals(&y).unwrap();
        assert_eq!(resid.dim(), (5, 6));
        assert!(rclose(resid[[0, 0]], -0.10000000000000142));
        assert!(rclose(resid[[0, 2]], 0.099999999999998757));
        assert!(rclose(resid[[4, 2]], 0.10000000000000053));
        // y == fitted + residuals.
        for g in 0..5 {
            for k in 0..6 {
                assert!(rclose(y[[g, k]], fitted[[g, k]] + resid[[g, k]]));
            }
        }
    }
}