use anyhow::{bail, Result};
use ndarray::Array2;
/// Returns a copy of `weights` with rows rescaled according to their status.
///
/// For each `values[v]` (in order), every row `i` with `status[i] == values[v]` is
/// multiplied element-wise by the multiplier for `v`. Rows whose status matches no
/// entry of `values` are left unchanged. If `values` contains duplicates, the
/// corresponding multipliers compound on the matching rows.
///
/// `multipliers` must either have one entry per value, or a single entry that is
/// recycled for every value.
///
/// # Panics
///
/// - If `status.len() != weights.nrows()`.
/// - If `multipliers.len()` is neither 1 nor `values.len()`.
pub fn modify_weights<S: AsRef<str>>(
    weights: &Array2<f64>,
    status: &[S],
    values: &[S],
    multipliers: &[f64],
) -> Array2<f64> {
    assert_eq!(
        status.len(),
        weights.nrows(),
        "nrows of weights must equal length of status"
    );
    let nvalues = values.len();
    // A single multiplier is recycled across all values, so no per-value Vec is needed.
    let recycle = multipliers.len() == 1;
    if !recycle {
        assert_eq!(
            multipliers.len(),
            nvalues,
            "no. values doesn't match no. multipliers"
        );
    }
    let mut out = weights.clone();
    for (v, value) in values.iter().enumerate() {
        let target = value.as_ref();
        let m = if recycle { multipliers[0] } else { multipliers[v] };
        for (i, s) in status.iter().enumerate() {
            if s.as_ref() == target {
                // Scale the whole row in place rather than indexing each column.
                out.row_mut(i).mapv_inplace(|x| x * m);
            }
        }
    }
    out
}
/// Expands `weights` into a full `nr x nc` weight matrix, where `dim = Some((nr, nc))`.
///
/// With `dim == None` the weights are returned unchanged. Otherwise the rules are
/// tried in this order:
///
/// 1. `weights` already has shape `(nr, nc)`: returned as-is.
/// 2. `weights` is a `1 x k` row vector with `k > 1` and `k == nc`: the row is
///    repeated down every output row.
/// 3. `weights` holds a single value, or exactly `nr` values: recycled column-major,
///    so row `i` gets the `i`-th value in every column.
/// 4. `weights` holds exactly `nc` values: recycled row-wise, so every row gets the
///    same per-column values.
///
/// # Errors
///
/// Returns an error if either requested dimension is zero. It also errors if the
/// smaller dimension of `weights` is not exactly 1, or if its length fits none of
/// the rules above.
pub fn as_matrix_weights(
    weights: &Array2<f64>,
    dim: Option<(usize, usize)>,
) -> Result<Array2<f64>> {
    let (rows, cols) = if let Some(target) = dim {
        target
    } else {
        return Ok(weights.clone());
    };
    if rows == 0 || cols == 0 {
        bail!("zero or negative dimensions not allowed");
    }

    let (have_rows, have_cols) = weights.dim();
    if (have_rows, have_cols) == (rows, cols) {
        return Ok(weights.clone());
    }
    // Only vector-shaped weights can be recycled; a zero-length dimension fails here too.
    if std::cmp::min(have_rows, have_cols) != 1 {
        bail!("weights is of unexpected shape");
    }

    let values: Vec<f64> = weights.iter().cloned().collect();
    let len = values.len();

    // Must be checked before the length-based rules: a row vector whose width matches
    // `cols` is always spread across rows, even if its length also equals `rows`.
    if have_cols > 1 && have_cols == cols {
        return Ok(fill_byrow(&values, rows, cols));
    }
    if len == 1 || len == rows {
        return Ok(fill_colmajor(&values, rows, cols));
    }
    if len != cols {
        bail!("weights is of unexpected size");
    }
    Ok(fill_byrow(&values, rows, cols))
}
/// Builds an `nr x nc` matrix by reading `flat` in column-major order. `flat` is
/// recycled cyclically when it holds fewer than `nr * nc` values, so a slice of
/// length `nr` fills every column identically.
///
/// Panics (remainder by zero) if `flat` is empty while the output is non-empty.
fn fill_colmajor(flat: &[f64], nr: usize, nc: usize) -> Array2<f64> {
    let len = flat.len();
    Array2::from_shape_fn((nr, nc), |(row, col)| {
        let linear = col * nr + row;
        flat[linear % len]
    })
}
/// Builds an `nr x nc` matrix by reading `flat` in row-major order. `flat` is
/// recycled cyclically when it holds fewer than `nr * nc` values, so a slice of
/// length `nc` becomes the contents of every row.
///
/// Panics (remainder by zero) if `flat` is empty while the output is non-empty.
fn fill_byrow(flat: &[f64], nr: usize, nc: usize) -> Array2<f64> {
    let len = flat.len();
    Array2::from_shape_fn((nr, nc), |(row, col)| {
        let linear = row * nc + col;
        flat[linear % len]
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Shared 4x2 fixture for the `modify_weights` tests.
    fn weights() -> Array2<f64> {
        array![[1.0, 0.5], [2.0, 2.0], [1.0, 1.0], [0.25, 4.0]]
    }

    #[test]
    fn two_values_matches_r() {
        let status = ["gene", "control", "gene", "spike"];
        let got = modify_weights(&weights(), &status, &["control", "spike"], &[0.0, 3.0]);
        let want = array![[1.0, 0.5], [0.0, 0.0], [1.0, 1.0], [0.75, 12.0]];
        assert_eq!(got, want);
    }

    #[test]
    fn scalar_multiplier_recycles() {
        let status = ["gene", "control", "gene", "spike"];
        let got = modify_weights(&weights(), &status, &["control", "spike"], &[2.0]);
        let want = array![[1.0, 0.5], [4.0, 4.0], [1.0, 1.0], [0.5, 8.0]];
        assert_eq!(got, want);
    }

    #[test]
    fn default_unit_weights() {
        let status = ["gene", "control", "gene", "spike"];
        let w = Array2::<f64>::ones((4, 1));
        let got = modify_weights(&w, &status, &["gene"], &[5.0]);
        let want = array![[5.0], [1.0], [5.0], [1.0]];
        assert_eq!(got, want);
    }

    #[test]
    fn duplicate_values_compound() {
        // Each occurrence of a value applies its own multiplier to the matching rows.
        let status = ["a", "b"];
        let w = Array2::<f64>::ones((2, 1));
        let got = modify_weights(&w, &status, &["a", "a"], &[2.0, 3.0]);
        assert_eq!(got, array![[6.0], [1.0]]);
    }

    #[test]
    fn empty_values_leave_weights_unchanged() {
        let status = ["gene", "control", "gene", "spike"];
        let no_values: [&str; 0] = [];
        let got = modify_weights(&weights(), &status, &no_values, &[]);
        assert_eq!(got, weights());
    }

    #[test]
    #[should_panic(expected = "nrows of weights must equal length of status")]
    fn status_length_mismatch_panics() {
        let status = ["gene", "control", "gene"];
        let _ = modify_weights(&weights(), &status, &["gene"], &[2.0]);
    }

    #[test]
    #[should_panic(expected = "no. values doesn't match no. multipliers")]
    fn multiplier_length_mismatch_panics() {
        let status = ["gene", "control", "gene", "spike"];
        let _ = modify_weights(&weights(), &status, &["control", "spike"], &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn as_matrix_weights_matches_r() {
        let pw = array![[1.0], [2.0], [3.0], [4.0]];
        assert_eq!(
            as_matrix_weights(&pw, Some((4, 3))).unwrap(),
            array![
                [1.0, 1.0, 1.0],
                [2.0, 2.0, 2.0],
                [3.0, 3.0, 3.0],
                [4.0, 4.0, 4.0]
            ]
        );
        let aw = array![[10.0], [20.0], [30.0]];
        assert_eq!(
            as_matrix_weights(&aw, Some((4, 3))).unwrap(),
            array![
                [10.0, 20.0, 30.0],
                [10.0, 20.0, 30.0],
                [10.0, 20.0, 30.0],
                [10.0, 20.0, 30.0]
            ]
        );
        let s = array![[0.5]];
        assert_eq!(
            as_matrix_weights(&s, Some((2, 3))).unwrap(),
            Array2::from_elem((2, 3), 0.5)
        );
        let fm = array![[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]];
        assert_eq!(as_matrix_weights(&fm, Some((2, 3))).unwrap(), fm);
        let sq = array![[1.0], [2.0], [3.0]];
        assert_eq!(
            as_matrix_weights(&sq, Some((3, 3))).unwrap(),
            array![[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
        );
        let rm = array![[7.0, 8.0, 9.0]];
        assert_eq!(
            as_matrix_weights(&rm, Some((4, 3))).unwrap(),
            array![
                [7.0, 8.0, 9.0],
                [7.0, 8.0, 9.0],
                [7.0, 8.0, 9.0],
                [7.0, 8.0, 9.0]
            ]
        );
        assert_eq!(as_matrix_weights(&pw, None).unwrap(), pw);
    }

    #[test]
    fn row_vector_matching_ncol_takes_precedence() {
        // A 1x3 row vector also has length == nrow for a 3x3 target. The row-vector
        // rule is checked first, so the row is repeated down the rows rather than
        // being spread column-major.
        let rv = array![[1.0, 2.0, 3.0]];
        assert_eq!(
            as_matrix_weights(&rv, Some((3, 3))).unwrap(),
            array![[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
        );
    }

    #[test]
    fn as_matrix_weights_errors() {
        let v = array![[1.0], [2.0], [3.0], [4.0]];
        assert!(as_matrix_weights(&v, Some((2, 3))).is_err());
        assert!(as_matrix_weights(&v, Some((0, 3))).is_err());
        assert!(as_matrix_weights(&v, Some((2, 0))).is_err());
        let blk = array![[1.0, 2.0], [3.0, 4.0]];
        assert!(as_matrix_weights(&blk, Some((4, 3))).is_err());
        // Zero-length weights cannot be recycled into a non-empty matrix.
        let empty = Array2::<f64>::zeros((0, 1));
        assert!(as_matrix_weights(&empty, Some((2, 3))).is_err());
    }
}