use anyhow::{bail, Result};
use ndarray::Array2;
use crate::fit::MArrayLM;
use crate::linalg::{cholesky_upper, cov2cor};
/// Re-express a fitted linear model in terms of contrasts of its coefficients.
///
/// Each column of `contrasts` (one row per coefficient of `fit`) defines one new
/// coefficient as a linear combination of the old ones. The returned fit carries:
///
/// - `coefficients · contrasts` as its coefficients;
/// - transformed unscaled standard deviations;
/// - a transformed coefficient covariance;
/// - `contrast_names` as its coefficient names;
/// - `contrasts` stored in `out.contrasts`.
///
/// All moderated / test statistics (`df_prior`, `s2_prior`, `var_prior`, `proportion`,
/// `s2_post`, `t`, `df_total`, `p_value`, `lods`, `f_stat`, `f_p_value`) are reset to
/// `None` because they no longer describe the new coefficients.
///
/// NOTE(review): the algorithm mirrors limma's `contrasts.fit` (port origin inferred
/// from naming — confirm).
///
/// Standard deviations are computed one of two ways:
///
/// - When the coefficient correlation matrix is diagonal (all off-diagonal |r| < 1e-14),
///   they are `sqrt(stdev² · contrasts²)`.
/// - Otherwise each gene's row is transformed through the Cholesky factor of the
///   correlation matrix.
///
/// NaN coefficients in `fit` are treated as non-estimable. They are temporarily set to
/// 0 with a huge standard deviation. Any resulting contrast whose standard deviation
/// then exceeds 1e20 (i.e. one that puts appreciable weight on such a coefficient) is
/// reported as NaN in both coefficient and standard deviation.
///
/// # Errors
/// - `contrasts.nrows()` differs from `fit.n_coef()`;
/// - `contrasts` contains NaN;
/// - `contrast_names.len()` differs from `contrasts.ncols()`.
pub fn contrasts_fit(
    fit: &MArrayLM,
    contrasts: &Array2<f64>,
    contrast_names: Vec<String>,
) -> Result<MArrayLM> {
    // |correlation| below this is treated as zero for the orthogonality shortcut.
    const ORTHOG_TOL: f64 = 1e-14;
    // Stand-in standard deviation for non-estimable (NaN) coefficients.
    const NA_STDEV: f64 = 1e30;
    // Any output stdev above this must have been contaminated by NA_STDEV.
    const NA_THRESHOLD: f64 = 1e20;

    let ncoef = fit.n_coef();
    if contrasts.nrows() != ncoef {
        bail!(
            "number of rows of contrast matrix ({}) must match number of coefficients in fit ({})",
            contrasts.nrows(),
            ncoef
        );
    }
    if contrasts.iter().any(|v| v.is_nan()) {
        bail!("contrasts must be a numeric matrix (NA not allowed)");
    }
    let ncont = contrasts.ncols();
    // Without this check the output's `coef_names` could disagree with its columns.
    if contrast_names.len() != ncont {
        bail!(
            "number of contrast names ({}) must match number of columns of contrast matrix ({})",
            contrast_names.len(),
            ncont
        );
    }
    let n_genes = fit.n_genes();

    // Orthogonal iff no strictly-lower-triangular correlation reaches the tolerance.
    // Written as `!any(>= tol)` so a NaN correlation still counts as orthogonal,
    // exactly as the original element-wise scan did. A 0x0 or 1x1 matrix is trivially
    // orthogonal (the ranges are empty).
    let cormatrix = cov2cor(&fit.cov_coefficients);
    let k = cormatrix.nrows();
    let orthog = !(0..k).any(|i| (0..i).any(|j| cormatrix[[i, j]].abs() >= ORTHOG_TOL));

    // Neutralise non-estimable coefficients so the matrix products below stay finite.
    let mut coef = fit.coefficients.clone();
    let mut stdev = fit.stdev_unscaled.clone();
    let na_coef = coef.iter().any(|v| v.is_nan());
    if na_coef {
        for ((g, j), c) in coef.indexed_iter_mut() {
            if c.is_nan() {
                *c = 0.0;
                stdev[[g, j]] = NA_STDEV;
            }
        }
    }

    let mut new_coef = coef.dot(contrasts);

    // Assumes `cholesky_upper` returns R with Rᵀ·R = cov, so this is Cᵀ·cov·C
    // — TODO confirm against linalg.
    let r_cov = cholesky_upper(&fit.cov_coefficients);
    let tmp = r_cov.dot(contrasts);
    let new_cov = tmp.t().dot(&tmp);

    let mut new_stdev = if orthog {
        // Uncorrelated coefficients: variances simply add with squared weights.
        stdev
            .mapv(|v| v * v)
            .dot(&contrasts.mapv(|v| v * v))
            .mapv(f64::sqrt)
    } else {
        // For each gene: RUC = R_cor · diag(stdev[i, :]) · C, then column norms of RUC.
        let r_cor = cholesky_upper(&cormatrix);
        let mut u = Array2::<f64>::zeros((n_genes, ncont));
        for i in 0..n_genes {
            for c in 0..ncont {
                let mut acc = 0.0;
                for kk in 0..ncoef {
                    let mut ruc = 0.0;
                    for j in 0..ncoef {
                        ruc += r_cor[[kk, j]] * contrasts[[j, c]] * stdev[[i, j]];
                    }
                    acc += ruc * ruc;
                }
                u[[i, c]] = acc.sqrt();
            }
        }
        u
    };

    // Contrasts that touched a non-estimable coefficient are reported as NaN.
    if na_coef {
        for ((g, c), s) in new_stdev.indexed_iter_mut() {
            if *s > NA_THRESHOLD {
                *s = f64::NAN;
                new_coef[[g, c]] = f64::NAN;
            }
        }
    }

    let mut out = fit.clone();
    out.coefficients = new_coef;
    out.stdev_unscaled = new_stdev;
    out.cov_coefficients = new_cov;
    out.coef_names = contrast_names;
    out.contrasts = Some(contrasts.clone());
    // Statistics derived from the old coefficients are invalid after reparameterisation.
    out.df_prior = None;
    out.s2_prior = None;
    out.var_prior = None;
    out.proportion = None;
    out.s2_post = None;
    out.t = None;
    out.df_total = None;
    out.p_value = None;
    out.lods = None;
    out.f_stat = None;
    out.f_p_value = None;
    Ok(out)
}
/// Build a contrast matrix from textual contrast expressions.
///
/// `levels` are the coefficient / group names, one per row of the result. A level
/// spelled `(Intercept)` is addressed as `Intercept` inside expressions.
///
/// Each `(name, expr)` pair in `contrasts` yields one column. `expr` is an arithmetic
/// expression over level names and numbers using `+ - * /` and parentheses. A level
/// name stands for a unit vector selecting that level. An expression that evaluates to
/// a plain number fills every row of its column with that number.
///
/// The column name is `name`, or `expr` itself when `name` is empty.
///
/// # Errors
/// - `levels` is empty;
/// - two levels map to the same name (including `(Intercept)` vs `Intercept`);
/// - an expression cannot be tokenized or parsed, or references an unknown level.
pub fn make_contrasts(
    contrasts: &[(String, String)],
    levels: &[String],
) -> Result<(Array2<f64>, Vec<String>)> {
    let n = levels.len();
    if n < 1 {
        bail!("No levels to construct contrasts from");
    }

    // Level name -> row of the contrast matrix.
    let mut idx: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
    for (row, level) in levels.iter().enumerate() {
        let key = match level.as_str() {
            "(Intercept)" => String::from("Intercept"),
            other => other.to_string(),
        };
        if idx.insert(key.clone(), row).is_some() {
            bail!("duplicate level name '{}'", key);
        }
    }

    let mut cm = Array2::<f64>::zeros((n, contrasts.len()));
    let mut names = Vec::with_capacity(contrasts.len());
    for (col_idx, (label, expr)) in contrasts.iter().enumerate() {
        let tokens = tokenize(expr)?;
        let mut parser = ContrastParser {
            toks: &tokens,
            pos: 0,
            n,
            idx: &idx,
        };
        let value = parser.parse_expr()?;
        parser.expect_end(expr)?;

        // A scalar result is recycled down the whole column.
        let weights = match value {
            Value::Scalar(s) => vec![s; n],
            Value::Vector(v) => v,
        };
        for (row, w) in weights.into_iter().enumerate() {
            cm[[row, col_idx]] = w;
        }

        let column_name = if label.is_empty() {
            expr.clone()
        } else {
            label.clone()
        };
        names.push(column_name);
    }
    Ok((cm, names))
}
/// Result of evaluating a (sub-)expression of a contrast.
///
/// A bare number evaluates to `Scalar`. A level name evaluates to a `Vector` with one
/// weight per level, `1.0` at that level's row and `0.0` elsewhere. `Debug` and
/// `PartialEq` are derived so evaluated values can be printed and compared.
#[derive(Clone, Debug, PartialEq)]
enum Value {
    /// A plain number; arithmetic with a `Vector` applies it to every element.
    Scalar(f64),
    /// One weight per level.
    Vector(Vec<f64>),
}
/// A lexical token of a contrast expression, as produced by `tokenize`.
///
/// `Debug` is derived alongside `PartialEq` so token sequences can be compared with
/// `assert_eq!` and shown in diagnostics.
#[derive(Clone, Debug, PartialEq)]
enum Tok {
    /// A level name.
    Ident(String),
    /// A numeric literal.
    Num(f64),
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `(`
    LParen,
    /// `)`
    RParen,
}
/// Split a contrast expression into tokens.
///
/// Recognised input:
/// * the operators `+ - * /` and the parentheses `(` `)`;
/// * numeric literals: ASCII digits and `.` with an optional `e`/`E` exponent
///   (e.g. `2`, `0.5`, `.5`, `1e-3`);
/// * identifiers (level names) that start with a letter or with `.` not followed by a
///   digit, and continue with letters, digits, `.` or `_`.
///
/// Letters and digits in identifiers may be any Unicode alphabetic / alphanumeric
/// character. This means level names such as `Größe` can be referenced, matching R's
/// notion of a syntactic name in a UTF-8 locale. Whitespace between tokens is skipped.
///
/// # Errors
/// Fails on a character that cannot start any token (e.g. `$`), and on a malformed
/// numeric literal (e.g. `1.2.3` or `1e`).
fn tokenize(s: &str) -> Result<Vec<Tok>> {
    let cs: Vec<char> = s.chars().collect();
    let mut out = Vec::new();
    let mut i = 0;
    while i < cs.len() {
        let c = cs[i];
        if c.is_whitespace() {
            i += 1;
            continue;
        }

        let single = match c {
            '+' => Some(Tok::Plus),
            '-' => Some(Tok::Minus),
            '*' => Some(Tok::Star),
            '/' => Some(Tok::Slash),
            '(' => Some(Tok::LParen),
            ')' => Some(Tok::RParen),
            _ => None,
        };

        if let Some(tok) = single {
            out.push(tok);
            i += 1;
        } else if c.is_ascii_digit()
            || (c == '.' && matches!(cs.get(i + 1), Some(d) if d.is_ascii_digit()))
        {
            // Numeric literal: mantissa of digits / dots, then an optional exponent.
            // Malformed text (e.g. several dots) is caught by the `parse` below.
            let start = i;
            while i < cs.len() && (cs[i].is_ascii_digit() || cs[i] == '.') {
                i += 1;
            }
            if i < cs.len() && (cs[i] == 'e' || cs[i] == 'E') {
                i += 1;
                if i < cs.len() && (cs[i] == '+' || cs[i] == '-') {
                    i += 1;
                }
                while i < cs.len() && cs[i].is_ascii_digit() {
                    i += 1;
                }
            }
            let num: String = cs[start..i].iter().collect();
            let value = num
                .parse()
                .map_err(|_| anyhow::anyhow!("invalid number '{}' in contrast", num))?;
            out.push(Tok::Num(value));
        } else if c.is_alphabetic() || c == '.' {
            // Identifier; `.` followed by a digit was already taken as a number above.
            let start = i;
            while i < cs.len() && (cs[i].is_alphanumeric() || cs[i] == '.' || cs[i] == '_') {
                i += 1;
            }
            out.push(Tok::Ident(cs[start..i].iter().collect()));
        } else {
            bail!("unexpected character '{}' in contrast expression", c);
        }
    }
    Ok(out)
}
/// Recursive-descent parser state for a single contrast expression.
///
/// `make_contrasts` builds one of these per expression, starting at `pos: 0`.
struct ContrastParser<'a> {
    /// Index into `toks` of the next token to be consumed.
    pos: usize,
    /// The expression's tokens.
    toks: &'a [Tok],
    /// Level name -> row index in the contrast matrix.
    idx: &'a std::collections::HashMap<String, usize>,
    /// Number of levels; length of every vector built for a level name.
    n: usize,
}
/// Grammar, from lowest to highest precedence (operators are left-associative):
///
/// ```text
/// expr    := term    (('+' | '-') term)*
/// term    := unary   (('*' | '/') unary)*
/// unary   := ('+' | '-') unary | primary
/// primary := NUMBER | LEVEL | '(' expr ')'
/// ```
impl ContrastParser<'_> {
    /// Look at the next unread token without consuming it.
    fn peek(&self) -> Option<&Tok> {
        self.toks.get(self.pos)
    }

    /// Succeed only if every token has been consumed.
    ///
    /// `expr` is the original expression text, quoted in the error.
    fn expect_end(&self, expr: &str) -> Result<()> {
        if self.pos == self.toks.len() {
            Ok(())
        } else {
            bail!("trailing tokens in contrast expression '{}'", expr)
        }
    }

    /// `expr := term (('+' | '-') term)*`
    fn parse_expr(&mut self) -> Result<Value> {
        let mut acc = self.parse_term()?;
        loop {
            let op: fn(f64, f64) -> f64 = match self.peek() {
                Some(Tok::Plus) => |a, b| a + b,
                Some(Tok::Minus) => |a, b| a - b,
                _ => break,
            };
            self.pos += 1;
            let rhs = self.parse_term()?;
            acc = combine(acc, rhs, op, self.n);
        }
        Ok(acc)
    }

    /// `term := unary (('*' | '/') unary)*`
    fn parse_term(&mut self) -> Result<Value> {
        let mut acc = self.parse_unary()?;
        loop {
            let op: fn(f64, f64) -> f64 = match self.peek() {
                Some(Tok::Star) => |a, b| a * b,
                Some(Tok::Slash) => |a, b| a / b,
                _ => break,
            };
            self.pos += 1;
            let rhs = self.parse_unary()?;
            acc = combine(acc, rhs, op, self.n);
        }
        Ok(acc)
    }

    /// `unary := ('+' | '-') unary | primary`
    ///
    /// Negation is evaluated as `0 - v`.
    fn parse_unary(&mut self) -> Result<Value> {
        match self.peek() {
            Some(Tok::Plus) => {
                self.pos += 1;
                self.parse_unary()
            }
            Some(Tok::Minus) => {
                self.pos += 1;
                let v = self.parse_unary()?;
                Ok(combine(Value::Scalar(0.0), v, |a, b| a - b, self.n))
            }
            _ => self.parse_primary(),
        }
    }

    /// `primary := NUMBER | LEVEL | '(' expr ')'`
    ///
    /// A level name becomes a unit vector of length `n` selecting that level.
    fn parse_primary(&mut self) -> Result<Value> {
        // Copy the slice reference so the matched token borrows the token buffer,
        // not `self`; that lets us advance `self.pos` without cloning the token.
        let toks = self.toks;
        match toks.get(self.pos) {
            Some(Tok::Num(x)) => {
                self.pos += 1;
                Ok(Value::Scalar(*x))
            }
            Some(Tok::Ident(name)) => {
                self.pos += 1;
                let &i = self.idx.get(name).ok_or_else(|| {
                    anyhow::anyhow!("contrast references unknown level '{}'", name)
                })?;
                let mut v = vec![0.0; self.n];
                v[i] = 1.0;
                Ok(Value::Vector(v))
            }
            Some(Tok::LParen) => {
                self.pos += 1;
                let v = self.parse_expr()?;
                match self.peek() {
                    Some(Tok::RParen) => {
                        self.pos += 1;
                        Ok(v)
                    }
                    _ => bail!("unbalanced parentheses in contrast expression"),
                }
            }
            // Name the offending token: ')' used to be reported as "an operator".
            other => bail!(
                "expected a level, number or '(' in contrast expression, found {}",
                match other {
                    None => "end of input",
                    Some(Tok::Star) => "'*'",
                    Some(Tok::Slash) => "'/'",
                    Some(Tok::RParen) => "')'",
                    Some(_) => "an operator",
                }
            ),
        }
    }
}
/// Apply the binary operator `f` to two values, element-wise.
///
/// Two scalars give a scalar. A scalar paired with a vector is applied against every
/// element, keeping operand order. Two vectors are combined position by position
/// (zip semantics, so the shorter length wins). The `_n` argument is ignored.
fn combine(a: Value, b: Value, f: impl Fn(f64, f64) -> f64, _n: usize) -> Value {
    match (a, b) {
        (Value::Scalar(lhs), Value::Scalar(rhs)) => Value::Scalar(f(lhs, rhs)),
        (Value::Scalar(lhs), Value::Vector(rhs)) => {
            let mapped: Vec<f64> = rhs.into_iter().map(|r| f(lhs, r)).collect();
            Value::Vector(mapped)
        }
        (Value::Vector(lhs), Value::Scalar(rhs)) => {
            let mapped: Vec<f64> = lhs.into_iter().map(|l| f(l, rhs)).collect();
            Value::Vector(mapped)
        }
        (Value::Vector(lhs), Value::Vector(rhs)) => {
            let mapped: Vec<f64> = lhs.into_iter().zip(rhs).map(|(l, r)| f(l, r)).collect();
            Value::Vector(mapped)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn levels() -> Vec<String> {
        vec!["A".into(), "B".into(), "C".into()]
    }

    #[test]
    fn simple_difference() {
        let (cm, names) = make_contrasts(
            &[
                ("BvsA".into(), "B-A".into()),
                ("CvsA".into(), "C - A".into()),
            ],
            &levels(),
        )
        .unwrap();
        assert_eq!(names, vec!["BvsA", "CvsA"]);
        assert_eq!(cm[[0, 0]], -1.0);
        assert_eq!(cm[[1, 0]], 1.0);
        assert_eq!(cm[[2, 0]], 0.0);
        assert_eq!(cm[[0, 1]], -1.0);
        assert_eq!(cm[[1, 1]], 0.0);
        assert_eq!(cm[[2, 1]], 1.0);
    }

    #[test]
    fn average_and_scaling() {
        let (cm, _) = make_contrasts(&[("".into(), "(B+C)/2 - A".into())], &levels()).unwrap();
        assert_eq!(cm[[0, 0]], -1.0);
        assert!((cm[[1, 0]] - 0.5).abs() < 1e-15);
        assert!((cm[[2, 0]] - 0.5).abs() < 1e-15);
    }

    #[test]
    fn unnamed_defaults_to_expression() {
        let (_, names) = make_contrasts(&[("".into(), "B-A".into())], &levels()).unwrap();
        assert_eq!(names, vec!["B-A"]);
    }

    #[test]
    fn unary_minus_and_coeffs() {
        let (cm, _) = make_contrasts(&[("x".into(), "-A + 2*B - C".into())], &levels()).unwrap();
        assert_eq!(cm[[0, 0]], -1.0);
        assert_eq!(cm[[1, 0]], 2.0);
        assert_eq!(cm[[2, 0]], -1.0);
    }

    #[test]
    fn unary_plus_and_double_negation() {
        let (cm, _) = make_contrasts(&[("x".into(), "+A - -B".into())], &levels()).unwrap();
        assert_eq!(cm[[0, 0]], 1.0);
        assert_eq!(cm[[1, 0]], 1.0);
        assert_eq!(cm[[2, 0]], 0.0);
    }

    #[test]
    fn scalar_expression_is_recycled() {
        let (cm, _) = make_contrasts(&[("x".into(), "2*3".into())], &levels()).unwrap();
        for i in 0..3 {
            assert_eq!(cm[[i, 0]], 6.0);
        }
    }

    #[test]
    fn division_by_scalar() {
        let (cm, _) = make_contrasts(&[("x".into(), "(A+B+C)/3".into())], &levels()).unwrap();
        for i in 0..3 {
            assert!((cm[[i, 0]] - 1.0 / 3.0).abs() < 1e-15);
        }
    }

    #[test]
    fn exponent_and_leading_dot_literals() {
        let (cm, _) =
            make_contrasts(&[("x".into(), "1e0*A - 2.5E-1*B + .5*C".into())], &levels()).unwrap();
        assert_eq!(cm[[0, 0]], 1.0);
        assert_eq!(cm[[1, 0]], -0.25);
        assert_eq!(cm[[2, 0]], 0.5);
    }

    #[test]
    fn unknown_level_errors() {
        assert!(make_contrasts(&[("x".into(), "B-Z".into())], &levels()).is_err());
    }

    #[test]
    fn unbalanced_parentheses_error() {
        assert!(make_contrasts(&[("x".into(), "(A-B".into())], &levels()).is_err());
        assert!(make_contrasts(&[("x".into(), "A-B)".into())], &levels()).is_err());
    }

    #[test]
    fn empty_expression_errors() {
        assert!(make_contrasts(&[("x".into(), "".into())], &levels()).is_err());
        assert!(make_contrasts(&[("x".into(), "A -".into())], &levels()).is_err());
    }

    #[test]
    fn unexpected_character_errors() {
        assert!(make_contrasts(&[("x".into(), "A$B".into())], &levels()).is_err());
    }

    #[test]
    fn no_levels_errors() {
        assert!(make_contrasts(&[("x".into(), "1".into())], &[]).is_err());
    }

    #[test]
    fn duplicate_levels_error() {
        let dup: Vec<String> = vec!["A".into(), "A".into()];
        assert!(make_contrasts(&[("x".into(), "A".into())], &dup).is_err());
        let clash: Vec<String> = vec!["(Intercept)".into(), "Intercept".into()];
        assert!(make_contrasts(&[("x".into(), "Intercept".into())], &clash).is_err());
    }

    #[test]
    fn intercept_is_renamed() {
        let levs = vec!["(Intercept)".into(), "groupB".into()];
        let (cm, _) = make_contrasts(&[("x".into(), "groupB - Intercept".into())], &levs).unwrap();
        assert_eq!(cm[[0, 0]], -1.0);
        assert_eq!(cm[[1, 0]], 1.0);
    }
}