use crate::fit::lmfit;
use ndarray::Array2;
/// Information criterion that [`select_model`] uses to score candidate designs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SelectCriterion {
    /// Akaike information criterion: every counted parameter costs a penalty of 2.
    Aic,
    /// Bayesian information criterion: every counted parameter costs the natural
    /// log of the number of arrays (columns of `y`).
    Bic,
    /// Mallows' Cp: residual sum of squares scaled by a caller-supplied variance
    /// (`s2_true`), which is mandatory for this criterion.
    MallowsCp,
}
/// Outcome of a [`select_model`] run.
#[derive(Clone, Debug)]
pub struct SelectModelResult {
    /// Criterion values: one row per gene (row of `y`), one column per
    /// candidate design, in the order the designs were supplied.
    pub ic: Array2<f64>,
    /// For each gene, the zero-based column of `ic` holding the preferred
    /// (lowest-scoring) design.
    pub pref: Vec<usize>,
    /// The criterion that produced `ic`.
    pub criterion: SelectCriterion,
}
/// Scores every candidate design for every gene with an information criterion
/// and reports which design each gene prefers.
///
/// Each design in `designlist` is fitted to `y` with [`lmfit`]; the residual
/// sum of squares of gene `g` is taken as `df_residual[g] * sigma[g]^2`.
///
/// * [`SelectCriterion::MallowsCp`]:
///   `rss / s2_true + 2 * npar - narrays`, with
///   `npar = narrays - df_residual[0]`.
/// * [`SelectCriterion::Aic`] / [`SelectCriterion::Bic`]:
///   `ntotal * ln(s2_post) + npar * penalty`, where
///   `ntotal = df_prior + narrays`,
///   `s2_post = (df_prior * s2_prior + rss) / ntotal`,
///   `npar = narrays - df_residual[0] + 1`, and `penalty` is `2` for AIC or
///   `ln(narrays)` for BIC.
///
/// # Arguments
///
/// * `y` - response matrix, one row per gene and one column per array; every
///   entry must be finite.
/// * `designlist` - non-empty list of design matrices to compare.
/// * `criterion` - which criterion to compute.
/// * `df_prior` - prior degrees of freedom (AIC/BIC only).
/// * `s2_prior` - prior variance; required when `df_prior` is non-zero and
///   ignored (treated as `0`) when `df_prior == 0.0`. AIC/BIC only.
/// * `s2_true` - known variances for Mallows' Cp: either one value shared by
///   all genes or one value per gene.
///
/// # Returns
///
/// A [`SelectModelResult`] whose `ic` is `ngenes x nmodels` and whose `pref`
/// holds, per gene, the index of the design with the smallest criterion value.
/// Ties go to the earliest design. NaN scores are ignored when choosing (as
/// R's `which.min` does); if every score of a gene is NaN, `pref` is `0`.
///
/// # Panics
///
/// * `designlist` is empty.
/// * `y` contains a non-finite value.
/// * `criterion` is `MallowsCp` and `s2_true` is `None` or has a length that
///   is neither `1` nor the number of genes.
/// * `criterion` is `Aic`/`Bic`, `df_prior != 0.0` and `s2_prior` is `None`.
/// * [`lmfit`] returns an error for any design.
pub fn select_model(
    y: &Array2<f64>,
    designlist: &[Array2<f64>],
    criterion: SelectCriterion,
    df_prior: f64,
    s2_prior: Option<f64>,
    s2_true: Option<&[f64]>,
) -> SelectModelResult {
    assert!(!designlist.is_empty(), "designlist must be non-empty");
    assert!(y.iter().all(|v| v.is_finite()), "NAs not allowed");

    let ngenes = y.nrows();
    let narrays = y.ncols() as f64;
    let nmodels = designlist.len();
    let gene_names = vec![String::new(); ngenes];

    // Both criterion families start from the same fit, so build it in one place.
    let fit_design = |design: &Array2<f64>| {
        let coef_names = vec![String::new(); design.ncols()];
        lmfit(y, design, gene_names.clone(), coef_names).expect("lmFit failed")
    };

    let mut ic = Array2::<f64>::zeros((ngenes, nmodels));
    match criterion {
        SelectCriterion::MallowsCp => {
            let s2t = s2_true.expect("Need s2.true values");
            assert!(
                s2t.len() == ngenes || s2t.len() == 1,
                "s2.true wrong length"
            );
            for (i, design) in designlist.iter().enumerate() {
                let fit = fit_design(design);
                // With no genes there is nothing to score, and reading
                // `df_residual[0]` would index out of bounds.
                if ngenes == 0 {
                    continue;
                }
                let npar = narrays - fit.df_residual[0];
                for g in 0..ngenes {
                    // A single `s2_true` value is shared by every gene.
                    let s2tg = if s2t.len() == 1 { s2t[0] } else { s2t[g] };
                    let rss = fit.df_residual[g] * fit.sigma[g].powi(2);
                    ic[[g, i]] = rss / s2tg + npar * 2.0 - narrays;
                }
            }
        }
        SelectCriterion::Aic | SelectCriterion::Bic => {
            // The prior variance only matters when it carries weight.
            let s2_prior = if df_prior == 0.0 {
                0.0
            } else {
                s2_prior.expect("s2.prior must be set")
            };
            let ntotal = df_prior + narrays;
            let penalty = if criterion == SelectCriterion::Bic {
                narrays.ln()
            } else {
                2.0
            };
            for (i, design) in designlist.iter().enumerate() {
                let fit = fit_design(design);
                // See the Mallows' Cp branch: avoid indexing an empty fit.
                if ngenes == 0 {
                    continue;
                }
                // The extra parameter is added on top of the coefficient count.
                let npar = narrays - fit.df_residual[0] + 1.0;
                for g in 0..ngenes {
                    let rss = fit.df_residual[g] * fit.sigma[g].powi(2);
                    let s2_post = (df_prior * s2_prior + rss) / ntotal;
                    ic[[g, i]] = ntotal * s2_post.ln() + npar * penalty;
                }
            }
        }
    }

    let pref = (0..ngenes)
        .map(|g| argmin_ignoring_nan((0..nmodels).map(|i| ic[[g, i]])))
        .collect();

    SelectModelResult {
        ic,
        pref,
        criterion,
    }
}

/// Returns the position of the smallest non-NaN value, preferring the earliest
/// position on ties; returns `0` when there is no non-NaN value at all.
fn argmin_ignoring_nan(values: impl Iterator<Item = f64>) -> usize {
    let mut best: Option<(usize, f64)> = None;
    for (idx, value) in values.enumerate() {
        if value.is_nan() {
            continue;
        }
        match best {
            // Keep the incumbent unless the new value is strictly smaller.
            Some((_, current)) if value >= current => {}
            _ => best = Some((idx, value)),
        }
    }
    best.map_or(0, |(idx, _)| idx)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Preferred design per gene; identical for every scenario exercised here.
    const EXPECTED_PREF: [usize; 5] = [0, 0, 0, 2, 0];

    /// True when both matrices have the same shape and every element of `a`
    /// is within `tol + tol * |b|` of the matching element of `b`.
    fn close(a: &Array2<f64>, b: &Array2<f64>, tol: f64) -> bool {
        if a.shape() != b.shape() {
            return false;
        }
        a.iter()
            .zip(b.iter())
            .all(|(&got, &want)| (got - want).abs() <= tol + tol * want.abs())
    }

    /// Shared fixture: a 5-gene by 6-array response matrix and three
    /// candidate designs (intercept only, three groups, two groups).
    fn data() -> (Array2<f64>, [Array2<f64>; 3]) {
        let y = array![
            [
                2.28724716134052,
                -0.947279945228108,
                0.356986230329022,
                0.467680511321698,
                0.839750359624071,
                0.184192771235767
            ],
            [
                -1.19677168222235,
                0.748139340290551,
                2.71675178313072,
                -0.893800723085444,
                0.7053418309055,
                0.752279895740033
            ],
            [
                -0.694292510435459,
                -0.116955225887152,
                2.28145192598956,
                -0.307328299537195,
                1.30596472081169,
                0.591745052462727
            ],
            [
                -0.412292951136803,
                0.152657626282234,
                0.324020540138516,
                -0.00482242226757041,
                -1.38799621659285,
                -0.983052595771021
            ],
            [
                -0.970673341119483,
                2.18997810732938,
                1.89606706680993,
                0.988164149499945,
                1.27291686425524,
                -0.276063955112006
            ],
        ];

        let intercept_only = Array2::<f64>::ones((6, 1));
        let three_groups = array![
            [1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [1.0, 1.0, 0.0],
            [1.0, 1.0, 0.0],
            [1.0, 0.0, 1.0],
            [1.0, 0.0, 1.0],
        ];
        let two_groups = array![
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 0.0],
            [1.0, 0.0],
        ];

        (y, [intercept_only, three_groups, two_groups])
    }

    /// Checks the criterion matrix against `want` and the preferred designs
    /// against [`EXPECTED_PREF`].
    #[track_caller]
    fn assert_selection(out: &SelectModelResult, want: &Array2<f64>) {
        assert!(close(&out.ic, want, 1e-9));
        assert_eq!(out.pref, EXPECTED_PREF);
    }

    #[test]
    fn aic_matches_r() {
        let (y, designs) = data();
        let out = select_model(&y, &designs, SelectCriterion::Aic, 0.0, None, None);
        // Reference values (from R, per the test name).
        let want = array![
            [3.4992829066356, 7.4254343299637, 5.49804808795198],
            [7.00366938955327, 10.0263185586882, 8.88251303902079],
            [4.28292611841211, 5.21603854954951, 5.70504652831142],
            [-1.75118839602609, -10.0621707410176, -10.5654610000774],
            [5.47382556689668, 8.57792477040455, 7.1764167632861],
        ];
        assert_selection(&out, &want);
    }

    #[test]
    fn bic_matches_r() {
        let (y, designs) = data();
        let out = select_model(&y, &designs, SelectCriterion::Bic, 0.0, None, None);
        // Reference values (from R, per the test name).
        let want = array![
            [3.08280184509171, 6.59247220687592, 4.87332649563614],
            [6.58718832800938, 9.1933564356004, 8.25779144670496],
            [3.86644505686822, 4.38307642646173, 5.08032493599559],
            [-2.16766945756998, -10.8951328641054, -11.1901825923932],
            [5.05734450535279, 7.74496264731677, 6.55169517097026],
        ];
        assert_selection(&out, &want);
    }

    #[test]
    fn aic_with_prior_matches_r() {
        let (y, designs) = data();
        let out = select_model(&y, &designs, SelectCriterion::Aic, 4.0, Some(1.5), None);
        // Reference values (from R, per the test name).
        let want = array![
            [5.4146542709521, 9.35586916403886, 7.41366822047241],
            [8.63632078653527, 11.653710245536, 10.5110800876037],
            [6.06176649208264, 7.76998027928027, 7.58044145236395],
            [2.13758814889168, 3.37272728529353, 1.50489760368478],
            [7.12663696702876, 10.3164121823946, 8.85154013604554],
        ];
        assert_selection(&out, &want);
    }

    #[test]
    fn mallows_cp_matches_r() {
        let (y, designs) = data();
        // One known variance per gene.
        let s2_true = [0.8, 1.2, 1.0, 0.5, 2.0];
        let out = select_model(
            &y,
            &designs,
            SelectCriterion::MallowsCp,
            0.0,
            None,
            Some(&s2_true),
        );
        // Reference values (from R, per the test name).
        let want = array![
            [2.89950846181918, 6.81510911166576, 4.89808866759715],
            [4.2486493959232, 7.00873811148601, 6.08375710151472],
            [2.28970281448665, 3.77260556543223, 3.71217897257395],
            [0.601467389991582, 0.591286184894407, -1.24118973789686],
            [-0.16469146823778, 3.30333671615237, 1.64983422668269],
        ];
        assert_selection(&out, &want);
    }
}