1use anyhow::{bail, Result};
12use ndarray::Array2;
13
14use crate::linalg::{qr_econ, solve_upper};
15
/// Collects every target name that appears in either channel, without duplicates,
/// in ascending (byte-wise lexicographic) order.
pub fn unique_targets(cy3: &[&str], cy5: &[&str]) -> Vec<String> {
    // A BTreeSet removes repeats and iterates its members in sorted order.
    let distinct: std::collections::BTreeSet<&str> = cy3.iter().chain(cy5).copied().collect();
    distinct.into_iter().map(String::from).collect()
}
28
29pub enum ModelParam<'a> {
31 Reference(&'a str),
35 Parameters {
39 matrix: Array2<f64>,
40 target_names: Vec<String>,
41 coef_names: Vec<String>,
42 },
43}
44
45#[derive(Debug, Clone)]
47pub struct ModelMatrix {
48 pub design: Array2<f64>,
50 pub coef_names: Vec<String>,
51}
52
53pub fn model_matrix(cy3: &[&str], cy5: &[&str], param: ModelParam) -> Result<ModelMatrix> {
59 let narrays = cy3.len();
60 if cy5.len() != narrays {
61 bail!("Cy3 and Cy5 have different lengths");
62 }
63 let sorted = unique_targets(cy3, cy5);
64
65 let (parameters, target_names, coef_names) = match param {
66 ModelParam::Reference(reference) => {
67 if !sorted.iter().any(|t| t == reference) {
68 bail!("\"{reference}\" not among the target names found");
69 }
70 let others: Vec<String> = sorted.into_iter().filter(|t| t != reference).collect();
71 let ntargets = others.len() + 1;
72 let ncoef = ntargets - 1;
73 let mut p = Array2::<f64>::zeros((ntargets, ncoef));
75 for j in 0..ncoef {
76 p[[0, j]] = -1.0;
77 p[[j + 1, j]] = 1.0;
78 }
79 let mut names = Vec::with_capacity(ntargets);
80 names.push(reference.to_string());
81 names.extend(others.iter().cloned());
82 (p, names, others)
83 }
84 ModelParam::Parameters {
85 matrix,
86 target_names,
87 coef_names,
88 } => {
89 if matrix.nrows() != target_names.len() {
90 bail!("rows of parameters don't match unique target names");
91 }
92 if matrix.ncols() != coef_names.len() {
93 bail!("columns of parameters don't match coefficient names");
94 }
95 let mut a = sorted.clone();
96 a.sort();
97 let mut b = target_names.clone();
98 b.sort();
99 if a != b {
100 bail!("rownames of parameters don't match unique target names");
101 }
102 (matrix, target_names, coef_names)
103 }
104 };
105
106 let ntargets = target_names.len();
107 let ncoef = parameters.ncols();
108
109 let mut j = Array2::<f64>::zeros((ntargets, narrays));
111 for (t, name) in target_names.iter().enumerate() {
112 for a in 0..narrays {
113 let v = i32::from(cy5[a] == name) - i32::from(cy3[a] == name);
114 j[[t, a]] = f64::from(v);
115 }
116 }
117
118 let (q, r) = qr_econ(¶meters);
120 let mut design = Array2::<f64>::zeros((narrays, ncoef));
121 for a in 0..narrays {
122 let qtj = q.t().dot(&j.column(a));
123 let beta = solve_upper(&r, &qtj);
124 for (k, &bk) in beta.iter().enumerate() {
125 design[[a, k]] = bk;
126 }
127 }
128 zapsmall(&mut design, 14);
129
130 Ok(ModelMatrix { design, coef_names })
131}
132
133fn zapsmall(m: &mut Array2<f64>, digits: i32) {
137 let mx = m
138 .iter()
139 .filter(|v| v.is_finite())
140 .fold(0.0f64, |acc, &v| acc.max(v.abs()));
141 let dp = if mx > 0.0 {
142 (digits - mx.log10().floor() as i32).max(0)
143 } else {
144 digits
145 };
146 let factor = 10f64.powi(dp);
147 for v in m.iter_mut() {
148 *v = (*v * factor).round() / factor;
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Asserts that `design` has the shape of `rows` and matches it element-wise to 1e-10.
    fn check(design: &Array2<f64>, rows: &[&[f64]]) {
        assert_eq!(design.nrows(), rows.len(), "row count");
        for (a, row) in rows.iter().enumerate() {
            assert_eq!(design.ncols(), row.len(), "col count");
            for (k, &want) in row.iter().enumerate() {
                assert!(
                    (design[[a, k]] - want).abs() < 1e-10,
                    "design[{a},{k}] = {} vs {want}",
                    design[[a, k]]
                );
            }
        }
    }

    /// Builds a `ModelParam::Parameters` from string-slice labels.
    fn explicit(matrix: Array2<f64>, targets: &[&str], coefs: &[&str]) -> ModelParam<'static> {
        ModelParam::Parameters {
            matrix,
            target_names: targets.iter().map(|s| s.to_string()).collect(),
            coef_names: coefs.iter().map(|s| s.to_string()).collect(),
        }
    }

    #[test]
    fn unique_targets_sorts_and_dedups() {
        let ut = unique_targets(&["Ref", "Ref", "B", "A"], &["A", "B", "Ref", "Ref"]);
        assert_eq!(ut, vec!["A", "B", "Ref"]);
    }

    /// Common-reference design: every array compares one treatment against Ref.
    #[test]
    fn model_matrix_common_reference() {
        let cy3 = ["Ref", "Ref", "Ref", "Ref"];
        let cy5 = ["A", "B", "A", "B"];
        let out = model_matrix(&cy3, &cy5, ModelParam::Reference("Ref")).unwrap();
        assert_eq!(out.coef_names, vec!["A", "B"]);
        check(
            &out.design,
            &[&[1.0, 0.0], &[0.0, 1.0], &[1.0, 0.0], &[0.0, 1.0]],
        );
    }

    /// Dye swaps flip the sign; a direct WT-vs-Mut array gives Mut - WT.
    #[test]
    fn model_matrix_dye_swaps() {
        let cy3 = ["WT", "Mut", "Ref", "Ref", "WT"];
        let cy5 = ["Ref", "Ref", "WT", "Mut", "Mut"];
        let out = model_matrix(&cy3, &cy5, ModelParam::Reference("Ref")).unwrap();
        assert_eq!(out.coef_names, vec!["Mut", "WT"]);
        check(
            &out.design,
            &[
                &[0.0, -1.0],
                &[-1.0, 0.0],
                &[0.0, 1.0],
                &[1.0, 0.0],
                &[1.0, -1.0],
            ],
        );
    }

    /// A user-supplied parameter matrix is honoured for a loop design.
    #[test]
    fn model_matrix_explicit_parameters() {
        let cy3 = ["A", "A", "B", "C"];
        let cy5 = ["B", "C", "C", "A"];
        let param = ModelParam::Parameters {
            matrix: array![[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]],
            target_names: vec!["A".into(), "B".into(), "C".into()],
            coef_names: vec!["B".into(), "C".into()],
        };
        let out = model_matrix(&cy3, &cy5, param).unwrap();
        assert_eq!(out.coef_names, vec!["B", "C"]);
        check(
            &out.design,
            &[&[1.0, 0.0], &[0.0, 1.0], &[-1.0, 1.0], &[0.0, -1.0]],
        );
    }

    /// Several treatments against a common control.
    #[test]
    fn model_matrix_three_treatments() {
        let cy3 = ["Ctl", "Ctl", "Ctl", "Ctl", "Ctl", "Ctl"];
        let cy5 = ["Drug1", "Drug2", "Drug3", "Drug1", "Drug2", "Drug3"];
        let out = model_matrix(&cy3, &cy5, ModelParam::Reference("Ctl")).unwrap();
        assert_eq!(out.coef_names, vec!["Drug1", "Drug2", "Drug3"]);
        check(
            &out.design,
            &[
                &[1.0, 0.0, 0.0],
                &[0.0, 1.0, 0.0],
                &[0.0, 0.0, 1.0],
                &[1.0, 0.0, 0.0],
                &[0.0, 1.0, 0.0],
                &[0.0, 0.0, 1.0],
            ],
        );
    }

    #[test]
    fn model_matrix_rejects_unknown_reference() {
        let err = model_matrix(&["A"], &["B"], ModelParam::Reference("Z")).unwrap_err();
        assert!(err.to_string().contains("not among the target names"));
    }

    #[test]
    fn model_matrix_rejects_mismatched_channel_lengths() {
        let err = model_matrix(&["A", "B"], &["B"], ModelParam::Reference("A")).unwrap_err();
        assert_eq!(err.to_string(), "Cy3 and Cy5 have different lengths");
    }

    #[test]
    fn model_matrix_rejects_parameter_row_count_mismatch() {
        let param = explicit(array![[-1.0], [1.0], [0.0]], &["A", "B"], &["B"]);
        let err = model_matrix(&["A"], &["B"], param).unwrap_err();
        assert_eq!(
            err.to_string(),
            "rows of parameters don't match unique target names"
        );
    }

    #[test]
    fn model_matrix_rejects_parameter_column_count_mismatch() {
        let param = explicit(array![[-1.0], [1.0]], &["A", "B"], &["B", "X"]);
        let err = model_matrix(&["A"], &["B"], param).unwrap_err();
        assert_eq!(
            err.to_string(),
            "columns of parameters don't match coefficient names"
        );
    }

    #[test]
    fn model_matrix_rejects_unknown_parameter_rownames() {
        let param = explicit(array![[-1.0], [1.0]], &["A", "C"], &["C"]);
        let err = model_matrix(&["A"], &["B"], param).unwrap_err();
        assert_eq!(
            err.to_string(),
            "rownames of parameters don't match unique target names"
        );
    }
}