use anyhow::{bail, Result};
use ndarray::Array2;
use crate::linalg::{qr_econ, solve_upper};
/// Returns every distinct target label that appears on either channel, in sorted order.
///
/// Labels are compared and ordered as plain strings (byte-wise), so the result does not
/// depend on the order in which the arrays are listed.
pub fn unique_targets(cy3: &[&str], cy5: &[&str]) -> Vec<String> {
    let mut labels: Vec<&str> = Vec::with_capacity(cy3.len() + cy5.len());
    labels.extend_from_slice(cy3);
    labels.extend_from_slice(cy5);
    // Deduplicate the borrowed labels first so only the survivors get allocated.
    labels.sort_unstable();
    labels.dedup();
    labels.into_iter().map(String::from).collect()
}
/// How the parameter matrix used by [`model_matrix`] is specified.
#[derive(Debug, Clone)]
pub enum ModelParam<'a> {
    /// Name of the reference target. It must occur among the Cy3/Cy5 labels; every
    /// other unique target becomes one coefficient, contrasted with this reference.
    Reference(&'a str),
    /// An explicit parameter matrix with labelled rows and columns.
    Parameters {
        /// One row per target, one column per coefficient.
        matrix: Array2<f64>,
        /// Row labels of `matrix`; must be a permutation of the unique target names.
        target_names: Vec<String>,
        /// Column labels of `matrix`; copied into the resulting design.
        coef_names: Vec<String>,
    },
}
/// Result of [`model_matrix`]: the design matrix together with its column labels.
#[derive(Debug, Clone)]
pub struct ModelMatrix {
    /// One row per array (in input order) and one column per coefficient.
    pub design: Array2<f64>,
    /// Coefficient name for each column of `design`.
    pub coef_names: Vec<String>,
}
pub fn model_matrix(cy3: &[&str], cy5: &[&str], param: ModelParam) -> Result<ModelMatrix> {
let narrays = cy3.len();
if cy5.len() != narrays {
bail!("Cy3 and Cy5 have different lengths");
}
let sorted = unique_targets(cy3, cy5);
let (parameters, target_names, coef_names) = match param {
ModelParam::Reference(reference) => {
if !sorted.iter().any(|t| t == reference) {
bail!("\"{reference}\" not among the target names found");
}
let others: Vec<String> = sorted.into_iter().filter(|t| t != reference).collect();
let ntargets = others.len() + 1;
let ncoef = ntargets - 1;
let mut p = Array2::<f64>::zeros((ntargets, ncoef));
for j in 0..ncoef {
p[[0, j]] = -1.0;
p[[j + 1, j]] = 1.0;
}
let mut names = Vec::with_capacity(ntargets);
names.push(reference.to_string());
names.extend(others.iter().cloned());
(p, names, others)
}
ModelParam::Parameters {
matrix,
target_names,
coef_names,
} => {
if matrix.nrows() != target_names.len() {
bail!("rows of parameters don't match unique target names");
}
if matrix.ncols() != coef_names.len() {
bail!("columns of parameters don't match coefficient names");
}
let mut a = sorted.clone();
a.sort();
let mut b = target_names.clone();
b.sort();
if a != b {
bail!("rownames of parameters don't match unique target names");
}
(matrix, target_names, coef_names)
}
};
let ntargets = target_names.len();
let ncoef = parameters.ncols();
let mut j = Array2::<f64>::zeros((ntargets, narrays));
for (t, name) in target_names.iter().enumerate() {
for a in 0..narrays {
let v = i32::from(cy5[a] == name) - i32::from(cy3[a] == name);
j[[t, a]] = f64::from(v);
}
}
let (q, r) = qr_econ(¶meters);
let mut design = Array2::<f64>::zeros((narrays, ncoef));
for a in 0..narrays {
let qtj = q.t().dot(&j.column(a));
let beta = solve_upper(&r, &qtj);
for (k, &bk) in beta.iter().enumerate() {
design[[a, k]] = bk;
}
}
zapsmall(&mut design, 14);
Ok(ModelMatrix { design, coef_names })
}
/// Rounds values that are negligible relative to the largest finite magnitude.
///
/// This is a port of R's `zapsmall`. Every value is rounded to
/// `max(0, digits - ceil(log10(mx)))` decimal places, where `mx` is the largest finite
/// absolute value in `m`. Non-finite values pass through unchanged, because rounding
/// NaN or +-inf is a no-op. If there are no non-zero finite values, nothing changes.
///
/// Works on anything that can be iterated by shared and by mutable reference over
/// `f64`, e.g. `ndarray::Array2<f64>`, `Vec<f64>` or `[f64]`.
fn zapsmall<M>(m: &mut M, digits: i32)
where
    M: ?Sized,
    for<'a> &'a M: IntoIterator<Item = &'a f64>,
    for<'a> &'a mut M: IntoIterator<Item = &'a mut f64>,
{
    let mx = (&*m)
        .into_iter()
        .filter(|v| v.is_finite())
        .fold(0.0f64, |acc, &v| acc.max(v.abs()));
    if mx <= 0.0 {
        // All finite values are zero: rounding cannot change anything.
        return;
    }
    // R uses `ceiling(log10(mx))`; `floor` would keep one extra digit whenever `mx`
    // is not an exact power of ten.
    let dp = digits.saturating_sub(mx.log10().ceil() as i32).max(0);
    let factor = 10f64.powi(dp);
    // For extremely small `mx` (below ~1e-294) the scale overflows to +inf, and
    // `v * inf / inf` would turn every value into NaN. Such values lie far below any
    // representable rounding grid anyway, so leave them untouched.
    if !factor.is_finite() {
        return;
    }
    for v in m {
        *v = (*v * factor).round() / factor;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Asserts that `design` matches `rows` element-wise within 1e-10.
    fn check(design: &Array2<f64>, rows: &[&[f64]]) {
        assert_eq!(design.nrows(), rows.len(), "row count");
        for (a, row) in rows.iter().enumerate() {
            assert_eq!(design.ncols(), row.len(), "col count");
            for (k, &want) in row.iter().enumerate() {
                assert!(
                    (design[[a, k]] - want).abs() < 1e-10,
                    "design[{a},{k}] = {} vs {want}",
                    design[[a, k]]
                );
            }
        }
    }

    #[test]
    fn unique_targets_sorts_and_dedups() {
        let ut = unique_targets(&["Ref", "Ref", "B", "A"], &["A", "B", "Ref", "Ref"]);
        assert_eq!(ut, vec!["A", "B", "Ref"]);
    }

    #[test]
    fn unique_targets_empty_input() {
        assert!(unique_targets(&[], &[]).is_empty());
    }

    #[test]
    fn model_matrix_common_reference() {
        let cy3 = ["Ref", "Ref", "Ref", "Ref"];
        let cy5 = ["A", "B", "A", "B"];
        let out = model_matrix(&cy3, &cy5, ModelParam::Reference("Ref")).unwrap();
        assert_eq!(out.coef_names, vec!["A", "B"]);
        check(
            &out.design,
            &[&[1.0, 0.0], &[0.0, 1.0], &[1.0, 0.0], &[0.0, 1.0]],
        );
    }

    #[test]
    fn model_matrix_dye_swaps() {
        let cy3 = ["WT", "Mut", "Ref", "Ref", "WT"];
        let cy5 = ["Ref", "Ref", "WT", "Mut", "Mut"];
        let out = model_matrix(&cy3, &cy5, ModelParam::Reference("Ref")).unwrap();
        assert_eq!(out.coef_names, vec!["Mut", "WT"]);
        check(
            &out.design,
            &[
                &[0.0, -1.0],
                &[-1.0, 0.0],
                &[0.0, 1.0],
                &[1.0, 0.0],
                &[1.0, -1.0],
            ],
        );
    }

    #[test]
    fn model_matrix_explicit_parameters() {
        let cy3 = ["A", "A", "B", "C"];
        let cy5 = ["B", "C", "C", "A"];
        let param = ModelParam::Parameters {
            matrix: array![[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]],
            target_names: vec!["A".into(), "B".into(), "C".into()],
            coef_names: vec!["B".into(), "C".into()],
        };
        let out = model_matrix(&cy3, &cy5, param).unwrap();
        assert_eq!(out.coef_names, vec!["B", "C"]);
        check(
            &out.design,
            &[&[1.0, 0.0], &[0.0, 1.0], &[-1.0, 1.0], &[0.0, -1.0]],
        );
    }

    #[test]
    fn model_matrix_three_treatments() {
        let cy3 = ["Ctl", "Ctl", "Ctl", "Ctl", "Ctl", "Ctl"];
        let cy5 = ["Drug1", "Drug2", "Drug3", "Drug1", "Drug2", "Drug3"];
        let out = model_matrix(&cy3, &cy5, ModelParam::Reference("Ctl")).unwrap();
        assert_eq!(out.coef_names, vec!["Drug1", "Drug2", "Drug3"]);
        check(
            &out.design,
            &[
                &[1.0, 0.0, 0.0],
                &[0.0, 1.0, 0.0],
                &[0.0, 0.0, 1.0],
                &[1.0, 0.0, 0.0],
                &[0.0, 1.0, 0.0],
                &[0.0, 0.0, 1.0],
            ],
        );
    }

    #[test]
    fn reference_matches_equivalent_explicit_parameters() {
        let cy3 = ["Ref", "A", "Ref"];
        let cy5 = ["A", "B", "B"];
        let by_ref = model_matrix(&cy3, &cy5, ModelParam::Reference("Ref")).unwrap();
        let explicit = model_matrix(
            &cy3,
            &cy5,
            ModelParam::Parameters {
                matrix: array![[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]],
                target_names: vec!["Ref".into(), "A".into(), "B".into()],
                coef_names: vec!["A".into(), "B".into()],
            },
        )
        .unwrap();
        assert_eq!(by_ref.coef_names, explicit.coef_names);
        assert!((&by_ref.design - &explicit.design)
            .iter()
            .all(|d| d.abs() < 1e-10));
    }

    #[test]
    fn model_matrix_rejects_unknown_reference() {
        let err = model_matrix(&["A"], &["B"], ModelParam::Reference("Z")).unwrap_err();
        assert!(err.to_string().contains("not among the target names"));
    }

    #[test]
    fn model_matrix_rejects_length_mismatch() {
        let err = model_matrix(&["A", "B"], &["B"], ModelParam::Reference("A")).unwrap_err();
        assert!(err.to_string().contains("different lengths"));
    }

    #[test]
    fn model_matrix_rejects_mismatched_rownames() {
        let param = ModelParam::Parameters {
            matrix: array![[-1.0], [1.0]],
            target_names: vec!["A".into(), "C".into()],
            coef_names: vec!["C".into()],
        };
        let err = model_matrix(&["A"], &["B"], param).unwrap_err();
        assert!(err.to_string().contains("rownames of parameters"));
    }

    #[test]
    fn model_matrix_rejects_mismatched_coef_names() {
        let param = ModelParam::Parameters {
            matrix: array![[-1.0], [1.0]],
            target_names: vec!["A".into(), "B".into()],
            coef_names: vec![],
        };
        let err = model_matrix(&["A"], &["B"], param).unwrap_err();
        assert!(err.to_string().contains("columns of parameters"));
    }

    #[test]
    fn zapsmall_removes_rounding_noise() {
        let mut m = array![[1.0, 1e-16], [-1.0, 0.5]];
        zapsmall(&mut m, 14);
        assert_eq!(m, array![[1.0, 0.0], [-1.0, 0.5]]);
    }
}