1use anyhow::{bail, Result};
12use ndarray::Array2;
13
14pub fn modify_weights<S: AsRef<str>>(
22 weights: &Array2<f64>,
23 status: &[S],
24 values: &[S],
25 multipliers: &[f64],
26) -> Array2<f64> {
27 assert_eq!(
28 status.len(),
29 weights.nrows(),
30 "nrows of weights must equal length of status"
31 );
32 let nvalues = values.len();
33 let mult: Vec<f64> = if multipliers.len() == 1 {
34 vec![multipliers[0]; nvalues]
35 } else {
36 assert_eq!(
37 multipliers.len(),
38 nvalues,
39 "no. values doesn't match no. multipliers"
40 );
41 multipliers.to_vec()
42 };
43
44 let mut out = weights.clone();
45 for (v, value) in values.iter().enumerate() {
46 let target = value.as_ref();
47 for (i, s) in status.iter().enumerate() {
48 if s.as_ref() == target {
49 for j in 0..out.ncols() {
50 out[[i, j]] *= mult[v];
51 }
52 }
53 }
54 }
55 out
56}
57
58pub fn as_matrix_weights(
73 weights: &Array2<f64>,
74 dim: Option<(usize, usize)>,
75) -> Result<Array2<f64>> {
76 let (nr, nc) = match dim {
77 None => return Ok(weights.clone()),
78 Some(d) => d,
79 };
80 if nr == 0 || nc == 0 {
81 bail!("zero or negative dimensions not allowed");
82 }
83 let (dwr, dwc) = weights.dim();
84 if dwr == nr && dwc == nc {
86 return Ok(weights.clone());
87 }
88 if dwr.min(dwc) != 1 {
89 bail!("weights is of unexpected shape");
90 }
91 let flat: Vec<f64> = weights.iter().copied().collect();
92 let lw = flat.len();
93 if dwc > 1 && dwc == nc {
95 return Ok(fill_byrow(&flat, nr, nc));
96 }
97 if lw == 1 || lw == nr {
99 return Ok(fill_colmajor(&flat, nr, nc));
100 }
101 if lw == nc {
103 return Ok(fill_byrow(&flat, nr, nc));
104 }
105 bail!("weights is of unexpected size");
106}
107
108fn fill_colmajor(flat: &[f64], nr: usize, nc: usize) -> Array2<f64> {
110 Array2::from_shape_fn((nr, nc), |(i, j)| flat[(j * nr + i) % flat.len()])
111}
112
113fn fill_byrow(flat: &[f64], nr: usize, nc: usize) -> Array2<f64> {
115 Array2::from_shape_fn((nr, nc), |(i, j)| flat[(i * nc + j) % flat.len()])
116}
117
#[cfg(test)]
mod tests {
    // Expected values are hand-computed from the weight-expansion and
    // row-scaling rules implemented above.
    use super::*;
    use ndarray::array;

    /// Shared 4 x 2 fixture; rows line up with the four-entry status vectors.
    fn weights() -> Array2<f64> {
        array![[1.0, 0.5], [2.0, 2.0], [1.0, 1.0], [0.25, 4.0]]
    }

    #[test]
    fn two_values_matches_r() {
        let status = ["gene", "control", "gene", "spike"];
        let got = modify_weights(&weights(), &status, &["control", "spike"], &[0.0, 3.0]);
        let want = array![[1.0, 0.5], [0.0, 0.0], [1.0, 1.0], [0.75, 12.0]];
        assert_eq!(got, want);
    }

    #[test]
    fn scalar_multiplier_recycles() {
        let status = ["gene", "control", "gene", "spike"];
        let got = modify_weights(&weights(), &status, &["control", "spike"], &[2.0]);
        let want = array![[1.0, 0.5], [4.0, 4.0], [1.0, 1.0], [0.5, 8.0]];
        assert_eq!(got, want);
    }

    #[test]
    fn default_unit_weights() {
        let status = ["gene", "control", "gene", "spike"];
        let w = Array2::<f64>::ones((4, 1));
        let got = modify_weights(&w, &status, &["gene"], &[5.0]);
        let want = array![[5.0], [1.0], [5.0], [1.0]];
        assert_eq!(got, want);
    }

    #[test]
    fn unmatched_or_empty_values_leave_weights_unchanged() {
        let status = ["gene", "control", "gene", "spike"];
        let got = modify_weights(&weights(), &status, &["absent"], &[9.0]);
        assert_eq!(got, weights());

        let no_values: [&str; 0] = [];
        let got = modify_weights(&weights(), &status, &no_values, &[]);
        assert_eq!(got, weights());
    }

    #[test]
    fn repeated_value_applies_each_multiplier() {
        let status = ["gene", "control", "gene", "spike"];
        let got = modify_weights(&weights(), &status, &["spike", "spike"], &[2.0, 3.0]);
        let want = array![[1.0, 0.5], [2.0, 2.0], [1.0, 1.0], [1.5, 24.0]];
        assert_eq!(got, want);
    }

    #[test]
    #[should_panic(expected = "nrows of weights must equal length of status")]
    fn status_length_mismatch_panics() {
        let _ = modify_weights(&weights(), &["gene", "control"], &["gene"], &[2.0]);
    }

    #[test]
    #[should_panic(expected = "no. values doesn't match no. multipliers")]
    fn multiplier_count_mismatch_panics() {
        let status = ["gene", "control", "gene", "spike"];
        let _ = modify_weights(&weights(), &status, &["control", "spike"], &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn as_matrix_weights_matches_r() {
        // Probe weights: a column of length nrow is copied into every column.
        let pw = array![[1.0], [2.0], [3.0], [4.0]];
        assert_eq!(
            as_matrix_weights(&pw, Some((4, 3))).unwrap(),
            array![
                [1.0, 1.0, 1.0],
                [2.0, 2.0, 2.0],
                [3.0, 3.0, 3.0],
                [4.0, 4.0, 4.0]
            ]
        );
        // Array weights: a column of length ncol is laid out along each row.
        let aw = array![[10.0], [20.0], [30.0]];
        assert_eq!(
            as_matrix_weights(&aw, Some((4, 3))).unwrap(),
            array![
                [10.0, 20.0, 30.0],
                [10.0, 20.0, 30.0],
                [10.0, 20.0, 30.0],
                [10.0, 20.0, 30.0]
            ]
        );
        // A scalar fills the whole matrix.
        let s = array![[0.5]];
        assert_eq!(
            as_matrix_weights(&s, Some((2, 3))).unwrap(),
            Array2::from_elem((2, 3), 0.5)
        );
        // A full matrix of the right shape is returned unchanged.
        let fm = array![[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]];
        assert_eq!(as_matrix_weights(&fm, Some((2, 3))).unwrap(), fm);
        // Square target: a length-nrow column takes the probe-weight path.
        let sq = array![[1.0], [2.0], [3.0]];
        assert_eq!(
            as_matrix_weights(&sq, Some((3, 3))).unwrap(),
            array![[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
        );
        // A row matrix with ncol columns is repeated down every row.
        let rm = array![[7.0, 8.0, 9.0]];
        assert_eq!(
            as_matrix_weights(&rm, Some((4, 3))).unwrap(),
            array![
                [7.0, 8.0, 9.0],
                [7.0, 8.0, 9.0],
                [7.0, 8.0, 9.0],
                [7.0, 8.0, 9.0]
            ]
        );
        // No target shape: returned as given.
        assert_eq!(as_matrix_weights(&pw, None).unwrap(), pw);
    }

    #[test]
    fn row_vector_of_nrow_length_fills_columns() {
        // A 1 x nrow row that does not match ncol is still treated as probe
        // weights, so each column repeats it.
        let rv = array![[1.0, 2.0, 3.0, 4.0]];
        assert_eq!(
            as_matrix_weights(&rv, Some((4, 3))).unwrap(),
            array![
                [1.0, 1.0, 1.0],
                [2.0, 2.0, 2.0],
                [3.0, 3.0, 3.0],
                [4.0, 4.0, 4.0]
            ]
        );
    }

    #[test]
    fn as_matrix_weights_errors() {
        let v = array![[1.0], [2.0], [3.0], [4.0]];
        assert_eq!(
            as_matrix_weights(&v, Some((2, 3))).unwrap_err().to_string(),
            "weights is of unexpected size"
        );
        assert_eq!(
            as_matrix_weights(&v, Some((0, 3))).unwrap_err().to_string(),
            "zero or negative dimensions not allowed"
        );
        assert_eq!(
            as_matrix_weights(&v, Some((3, 0))).unwrap_err().to_string(),
            "zero or negative dimensions not allowed"
        );
        let blk = array![[1.0, 2.0], [3.0, 4.0]];
        assert_eq!(
            as_matrix_weights(&blk, Some((4, 3))).unwrap_err().to_string(),
            "weights is of unexpected shape"
        );
        let empty = Array2::<f64>::zeros((0, 1));
        assert_eq!(
            as_matrix_weights(&empty, Some((2, 3))).unwrap_err().to_string(),
            "weights is of unexpected shape"
        );
    }
}