Skip to main content

limma/
selectmodel.rs

1//! Model selection by information criterion (limma `selectModel`, selmod.R).
2//!
//! [`select_model`] is a pure-Rust port of limma's `selectModel`: it fits every
//! candidate design to the same expression matrix with [`crate::fit::lmfit`] and
//! scores each with AIC, BIC or Mallows' Cp. It returns the per-gene information
//! criterion for every model and the index of the preferred (minimum-criterion)
//! model.
7
8use crate::fit::lmfit;
9use ndarray::Array2;
10
/// Information criterion used to rank candidate models.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SelectCriterion {
    /// Akaike information criterion: each parameter is penalised by 2.
    Aic,
    /// Bayesian information criterion: each parameter is penalised by
    /// `ln(narrays)`.
    Bic,
    /// Mallows' Cp; needs the known residual variance (`s2_true`).
    MallowsCp,
}
18
/// Output of [`select_model`].
#[derive(Debug, Clone)]
pub struct SelectModelResult {
    /// `ngenes × nmodels` information-criterion values (lower is better).
    pub ic: Array2<f64>,
    /// Per-gene 0-based index of the preferred (minimum-IC) model.
    pub pref: Vec<usize>,
    /// Criterion that was requested and used to fill `ic`.
    pub criterion: SelectCriterion,
}
28
29/// Score `designlist` against expression matrix `y` (genes × arrays).
30///
31/// `df_prior` / `s2_prior` supply a prior for the residual variance (AIC/BIC
32/// only; `s2_prior` is required when `df_prior > 0`). `s2_true` (length 1 or
33/// `ngenes`) is the known variance required by Mallows' Cp. Panics on any NaN
34/// in `y`, mirroring limma's `stop("NAs not allowed")`.
35pub fn select_model(
36    y: &Array2<f64>,
37    designlist: &[Array2<f64>],
38    criterion: SelectCriterion,
39    df_prior: f64,
40    s2_prior: Option<f64>,
41    s2_true: Option<&[f64]>,
42) -> SelectModelResult {
43    assert!(!designlist.is_empty(), "designlist must be non-empty");
44    assert!(y.iter().all(|v| v.is_finite()), "NAs not allowed");
45    let ngenes = y.nrows();
46    let narrays = y.ncols() as f64;
47    let nmodels = designlist.len();
48    let gene_names = vec![String::new(); ngenes];
49
50    let mut ic = Array2::<f64>::zeros((ngenes, nmodels));
51
52    match criterion {
53        SelectCriterion::MallowsCp => {
54            let s2t = s2_true.expect("Need s2.true values");
55            assert!(
56                s2t.len() == ngenes || s2t.len() == 1,
57                "s2.true wrong length"
58            );
59            for (i, design) in designlist.iter().enumerate() {
60                let coef_names = vec![String::new(); design.ncols()];
61                let fit = lmfit(y, design, gene_names.clone(), coef_names).expect("lmFit failed");
62                let npar = narrays - fit.df_residual[0]; // = model rank
63                for g in 0..ngenes {
64                    let s2tg = if s2t.len() == 1 { s2t[0] } else { s2t[g] };
65                    let rss = fit.df_residual[g] * fit.sigma[g].powi(2);
66                    ic[[g, i]] = rss / s2tg + npar * 2.0 - narrays;
67                }
68            }
69        }
70        SelectCriterion::Aic | SelectCriterion::Bic => {
71            let s2_prior = if df_prior == 0.0 {
72                0.0
73            } else {
74                s2_prior.expect("s2.prior must be set")
75            };
76            let ntotal = df_prior + narrays;
77            let penalty = if criterion == SelectCriterion::Bic {
78                narrays.ln()
79            } else {
80                2.0
81            };
82            for (i, design) in designlist.iter().enumerate() {
83                let coef_names = vec![String::new(); design.ncols()];
84                let fit = lmfit(y, design, gene_names.clone(), coef_names).expect("lmFit failed");
85                let npar = narrays - fit.df_residual[0] + 1.0; // = model rank + 1
86                for g in 0..ngenes {
87                    let rss = fit.df_residual[g] * fit.sigma[g].powi(2);
88                    let s2_post = (df_prior * s2_prior + rss) / ntotal;
89                    ic[[g, i]] = ntotal * s2_post.ln() + npar * penalty;
90                }
91            }
92        }
93    }
94
95    let pref = (0..ngenes)
96        .map(|g| {
97            let mut best = 0usize;
98            for i in 1..nmodels {
99                if ic[[g, i]] < ic[[g, best]] {
100                    best = i;
101                }
102            }
103            best
104        })
105        .collect();
106
107    SelectModelResult {
108        ic,
109        pref,
110        criterion,
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Element-wise comparison with a mixed absolute/relative tolerance:
    /// `|x - y| <= tol + tol * |y|`, plus a shape check.
    fn close(a: &Array2<f64>, b: &Array2<f64>, tol: f64) -> bool {
        a.shape() == b.shape()
            && a.iter()
                .zip(b.iter())
                .all(|(&x, &y)| (x - y).abs() <= tol + tol * y.abs())
    }

    /// Shared fixture: a 5-gene × 6-array expression matrix and three
    /// candidate designs (intercept only, three groups, two groups).
    fn data() -> (Array2<f64>, [Array2<f64>; 3]) {
        let y = array![
            [
                2.28724716134052,
                -0.947279945228108,
                0.356986230329022,
                0.467680511321698,
                0.839750359624071,
                0.184192771235767
            ],
            [
                -1.19677168222235,
                0.748139340290551,
                2.71675178313072,
                -0.893800723085444,
                0.7053418309055,
                0.752279895740033
            ],
            [
                -0.694292510435459,
                -0.116955225887152,
                2.28145192598956,
                -0.307328299537195,
                1.30596472081169,
                0.591745052462727
            ],
            [
                -0.412292951136803,
                0.152657626282234,
                0.324020540138516,
                -0.00482242226757041,
                -1.38799621659285,
                -0.983052595771021
            ],
            [
                -0.970673341119483,
                2.18997810732938,
                1.89606706680993,
                0.988164149499945,
                1.27291686425524,
                -0.276063955112006
            ],
        ];
        let d1 = Array2::<f64>::ones((6, 1));
        let d2 = array![
            [1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [1.0, 1.0, 0.0],
            [1.0, 1.0, 0.0],
            [1.0, 0.0, 1.0],
            [1.0, 0.0, 1.0],
        ];
        let d3 = array![
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 0.0],
        ];
        (y, [d1, d2, d3])
    }

    #[test]
    fn aic_matches_r() {
        let (y, designs) = data();
        let out = select_model(&y, &designs, SelectCriterion::Aic, 0.0, None, None);
        let want = array![
            [3.4992829066356, 7.4254343299637, 5.49804808795198],
            [7.00366938955327, 10.0263185586882, 8.88251303902079],
            [4.28292611841211, 5.21603854954951, 5.70504652831142],
            [-1.75118839602609, -10.0621707410176, -10.5654610000774],
            [5.47382556689668, 8.57792477040455, 7.1764167632861],
        ];
        assert!(close(&out.ic, &want, 1e-9));
        assert_eq!(out.pref, vec![0, 0, 0, 2, 0]);
    }

    #[test]
    fn bic_matches_r() {
        let (y, designs) = data();
        let out = select_model(&y, &designs, SelectCriterion::Bic, 0.0, None, None);
        let want = array![
            [3.08280184509171, 6.59247220687592, 4.87332649563614],
            [6.58718832800938, 9.1933564356004, 8.25779144670496],
            [3.86644505686822, 4.38307642646173, 5.08032493599559],
            [-2.16766945756998, -10.8951328641054, -11.1901825923932],
            [5.05734450535279, 7.74496264731677, 6.55169517097026],
        ];
        assert!(close(&out.ic, &want, 1e-9));
        assert_eq!(out.pref, vec![0, 0, 0, 2, 0]);
    }

    #[test]
    fn aic_with_prior_matches_r() {
        let (y, designs) = data();
        let out = select_model(&y, &designs, SelectCriterion::Aic, 4.0, Some(1.5), None);
        let want = array![
            [5.4146542709521, 9.35586916403886, 7.41366822047241],
            [8.63632078653527, 11.653710245536, 10.5110800876037],
            [6.06176649208264, 7.76998027928027, 7.58044145236395],
            [2.13758814889168, 3.37272728529353, 1.50489760368478],
            [7.12663696702876, 10.3164121823946, 8.85154013604554],
        ];
        assert!(close(&out.ic, &want, 1e-9));
        assert_eq!(out.pref, vec![0, 0, 0, 2, 0]);
    }

    #[test]
    fn mallows_cp_matches_r() {
        let (y, designs) = data();
        let s2true = [0.8, 1.2, 1.0, 0.5, 2.0];
        let out = select_model(
            &y,
            &designs,
            SelectCriterion::MallowsCp,
            0.0,
            None,
            Some(&s2true),
        );
        let want = array![
            [2.89950846181918, 6.81510911166576, 4.89808866759715],
            [4.2486493959232, 7.00873811148601, 6.08375710151472],
            [2.28970281448665, 3.77260556543223, 3.71217897257395],
            [0.601467389991582, 0.591286184894407, -1.24118973789686],
            [-0.16469146823778, 3.30333671615237, 1.64983422668269],
        ];
        assert!(close(&out.ic, &want, 1e-9));
        assert_eq!(out.pref, vec![0, 0, 0, 2, 0]);
    }

    /// A length-1 `s2_true` must behave exactly like the same value repeated
    /// once per gene.
    #[test]
    fn mallows_cp_scalar_s2_true_is_recycled() {
        let (y, designs) = data();
        let scalar = [0.8];
        let repeated = [0.8; 5];
        let from_scalar = select_model(
            &y,
            &designs,
            SelectCriterion::MallowsCp,
            0.0,
            None,
            Some(&scalar),
        );
        let from_repeated = select_model(
            &y,
            &designs,
            SelectCriterion::MallowsCp,
            0.0,
            None,
            Some(&repeated),
        );
        assert_eq!(from_scalar.ic, from_repeated.ic);
        assert_eq!(from_scalar.pref, from_repeated.pref);
    }

    #[test]
    #[should_panic(expected = "Need s2.true values")]
    fn mallows_cp_requires_s2_true() {
        let (y, designs) = data();
        let _ = select_model(&y, &designs, SelectCriterion::MallowsCp, 0.0, None, None);
    }

    #[test]
    #[should_panic(expected = "s2.prior must be set")]
    fn prior_df_requires_s2_prior() {
        let (y, designs) = data();
        let _ = select_model(&y, &designs, SelectCriterion::Aic, 4.0, None, None);
    }

    #[test]
    #[should_panic(expected = "NAs not allowed")]
    fn rejects_nan_expression_values() {
        let (mut y, designs) = data();
        y[[0, 0]] = f64::NAN;
        let _ = select_model(&y, &designs, SelectCriterion::Aic, 0.0, None, None);
    }
}