1use crate::fit::lmfit;
9use ndarray::Array2;
10
/// Information criterion used by [`select_model`] to score candidate designs.
///
/// For every variant a smaller score is preferred.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SelectCriterion {
    /// Akaike information criterion: a penalty of `2` per parameter.
    Aic,
    /// Bayesian information criterion: a penalty of `ln(narrays)` per parameter.
    Bic,
    /// Mallows' Cp; needs known true variances (`s2_true`) for every gene.
    MallowsCp,
}
18
19#[derive(Debug, Clone)]
21pub struct SelectModelResult {
22 pub ic: Array2<f64>,
24 pub pref: Vec<usize>,
26 pub criterion: SelectCriterion,
27}
28
29pub fn select_model(
36 y: &Array2<f64>,
37 designlist: &[Array2<f64>],
38 criterion: SelectCriterion,
39 df_prior: f64,
40 s2_prior: Option<f64>,
41 s2_true: Option<&[f64]>,
42) -> SelectModelResult {
43 assert!(!designlist.is_empty(), "designlist must be non-empty");
44 assert!(y.iter().all(|v| v.is_finite()), "NAs not allowed");
45 let ngenes = y.nrows();
46 let narrays = y.ncols() as f64;
47 let nmodels = designlist.len();
48 let gene_names = vec![String::new(); ngenes];
49
50 let mut ic = Array2::<f64>::zeros((ngenes, nmodels));
51
52 match criterion {
53 SelectCriterion::MallowsCp => {
54 let s2t = s2_true.expect("Need s2.true values");
55 assert!(
56 s2t.len() == ngenes || s2t.len() == 1,
57 "s2.true wrong length"
58 );
59 for (i, design) in designlist.iter().enumerate() {
60 let coef_names = vec![String::new(); design.ncols()];
61 let fit = lmfit(y, design, gene_names.clone(), coef_names).expect("lmFit failed");
62 let npar = narrays - fit.df_residual[0]; for g in 0..ngenes {
64 let s2tg = if s2t.len() == 1 { s2t[0] } else { s2t[g] };
65 let rss = fit.df_residual[g] * fit.sigma[g].powi(2);
66 ic[[g, i]] = rss / s2tg + npar * 2.0 - narrays;
67 }
68 }
69 }
70 SelectCriterion::Aic | SelectCriterion::Bic => {
71 let s2_prior = if df_prior == 0.0 {
72 0.0
73 } else {
74 s2_prior.expect("s2.prior must be set")
75 };
76 let ntotal = df_prior + narrays;
77 let penalty = if criterion == SelectCriterion::Bic {
78 narrays.ln()
79 } else {
80 2.0
81 };
82 for (i, design) in designlist.iter().enumerate() {
83 let coef_names = vec![String::new(); design.ncols()];
84 let fit = lmfit(y, design, gene_names.clone(), coef_names).expect("lmFit failed");
85 let npar = narrays - fit.df_residual[0] + 1.0; for g in 0..ngenes {
87 let rss = fit.df_residual[g] * fit.sigma[g].powi(2);
88 let s2_post = (df_prior * s2_prior + rss) / ntotal;
89 ic[[g, i]] = ntotal * s2_post.ln() + npar * penalty;
90 }
91 }
92 }
93 }
94
95 let pref = (0..ngenes)
96 .map(|g| {
97 let mut best = 0usize;
98 for i in 1..nmodels {
99 if ic[[g, i]] < ic[[g, best]] {
100 best = i;
101 }
102 }
103 best
104 })
105 .collect();
106
107 SelectModelResult {
108 ic,
109 pref,
110 criterion,
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Preferred design per gene in every R reference run below.
    const EXPECTED_PREF: [usize; 5] = [0, 0, 0, 2, 0];

    /// True when `a` and `b` share a shape and each pair of entries differs by at most
    /// `tol` plus `tol` times the magnitude of the reference entry from `b`.
    fn close(a: &Array2<f64>, b: &Array2<f64>, tol: f64) -> bool {
        if a.shape() != b.shape() {
            return false;
        }
        a.iter()
            .zip(b.iter())
            .all(|(&got, &want)| (got - want).abs() <= tol + tol * want.abs())
    }

    /// Asserts that `out` reproduces the R reference scores and preferred designs.
    fn assert_matches_r(out: &SelectModelResult, want: &Array2<f64>) {
        assert!(close(&out.ic, want, 1e-9));
        assert_eq!(out.pref, EXPECTED_PREF.to_vec());
    }

    /// Five genes measured on six arrays, plus three candidate designs: intercept only,
    /// three groups of two arrays, and arrays 1-4 versus arrays 5-6.
    fn data() -> (Array2<f64>, [Array2<f64>; 3]) {
        let y = array![
            [
                2.28724716134052, -0.947279945228108, 0.356986230329022,
                0.467680511321698, 0.839750359624071, 0.184192771235767
            ],
            [
                -1.19677168222235, 0.748139340290551, 2.71675178313072,
                -0.893800723085444, 0.7053418309055, 0.752279895740033
            ],
            [
                -0.694292510435459, -0.116955225887152, 2.28145192598956,
                -0.307328299537195, 1.30596472081169, 0.591745052462727
            ],
            [
                -0.412292951136803, 0.152657626282234, 0.324020540138516,
                -0.00482242226757041, -1.38799621659285, -0.983052595771021
            ],
            [
                -0.970673341119483, 2.18997810732938, 1.89606706680993,
                0.988164149499945, 1.27291686425524, -0.276063955112006
            ],
        ];
        let intercept_only = Array2::<f64>::ones((6, 1));
        let three_groups = array![
            [1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [1.0, 1.0, 0.0],
            [1.0, 1.0, 0.0],
            [1.0, 0.0, 1.0],
            [1.0, 0.0, 1.0],
        ];
        let two_groups = array![
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 0.0],
        ];
        (y, [intercept_only, three_groups, two_groups])
    }

    #[test]
    fn aic_matches_r() {
        let (y, designs) = data();
        let reference = array![
            [3.4992829066356, 7.4254343299637, 5.49804808795198],
            [7.00366938955327, 10.0263185586882, 8.88251303902079],
            [4.28292611841211, 5.21603854954951, 5.70504652831142],
            [-1.75118839602609, -10.0621707410176, -10.5654610000774],
            [5.47382556689668, 8.57792477040455, 7.1764167632861],
        ];
        let result = select_model(&y, &designs, SelectCriterion::Aic, 0.0, None, None);
        assert_matches_r(&result, &reference);
    }

    #[test]
    fn bic_matches_r() {
        let (y, designs) = data();
        let reference = array![
            [3.08280184509171, 6.59247220687592, 4.87332649563614],
            [6.58718832800938, 9.1933564356004, 8.25779144670496],
            [3.86644505686822, 4.38307642646173, 5.08032493599559],
            [-2.16766945756998, -10.8951328641054, -11.1901825923932],
            [5.05734450535279, 7.74496264731677, 6.55169517097026],
        ];
        let result = select_model(&y, &designs, SelectCriterion::Bic, 0.0, None, None);
        assert_matches_r(&result, &reference);
    }

    #[test]
    fn aic_with_prior_matches_r() {
        let (y, designs) = data();
        let reference = array![
            [5.4146542709521, 9.35586916403886, 7.41366822047241],
            [8.63632078653527, 11.653710245536, 10.5110800876037],
            [6.06176649208264, 7.76998027928027, 7.58044145236395],
            [2.13758814889168, 3.37272728529353, 1.50489760368478],
            [7.12663696702876, 10.3164121823946, 8.85154013604554],
        ];
        let result = select_model(&y, &designs, SelectCriterion::Aic, 4.0, Some(1.5), None);
        assert_matches_r(&result, &reference);
    }

    #[test]
    fn mallows_cp_matches_r() {
        let (y, designs) = data();
        let s2true = [0.8, 1.2, 1.0, 0.5, 2.0];
        let reference = array![
            [2.89950846181918, 6.81510911166576, 4.89808866759715],
            [4.2486493959232, 7.00873811148601, 6.08375710151472],
            [2.28970281448665, 3.77260556543223, 3.71217897257395],
            [0.601467389991582, 0.591286184894407, -1.24118973789686],
            [-0.16469146823778, 3.30333671615237, 1.64983422668269],
        ];
        let result = select_model(
            &y,
            &designs,
            SelectCriterion::MallowsCp,
            0.0,
            None,
            Some(&s2true[..]),
        );
        assert_matches_r(&result, &reference);
    }
}