1use anyhow::{bail, Result};
15use ndarray::{Array1, Array2};
16
17use crate::arrayweights::{contr_sum, solve_linear, wfit_resid_lev_s2};
18use crate::normwithin::PrinterLayout;
19
20pub fn printtip_weights(
34 m: &Array2<f64>,
35 design: &Array2<f64>,
36 weights: Option<&Array2<f64>>,
37 layout: PrinterLayout,
38) -> Result<Array2<f64>> {
39 let nprobes = m.nrows();
40 let narrays = m.ncols();
41 let nparams = design.ncols();
42 if design.nrows() != narrays {
43 bail!("design row dimension must equal number of arrays");
44 }
45 if narrays < 3 {
46 bail!("too few arrays");
47 }
48 if nprobes < narrays {
49 bail!("too few probes");
50 }
51 let nspots = layout.nspot_r * layout.nspot_c;
52 let ngrids = layout.ngrid_r * layout.ngrid_c;
53 if ngrids * nspots != nprobes {
54 bail!("printer layout information does not match M row dimension");
55 }
56 if let Some(w) = weights {
57 if w.dim() != (nprobes, narrays) {
58 bail!("dimensions of weights do not match M");
59 }
60 }
61
62 let z = contr_sum(narrays); let ngam = narrays - 1;
64 let ztz = z.t().dot(&z);
65 let prior = 10.0 * (narrays - nparams) as f64 / narrays as f64;
66
67 let mut blockw = Array2::<f64>::zeros((ngrids, narrays));
68 for blk in 0..ngrids {
69 let start = blk * nspots;
70 let mut gammas = Array1::<f64>::zeros(ngam);
71 let mut zinfo = ztz.mapv(|v| v * prior);
72
73 for s in 0..nspots {
74 let i = start + s;
75 if gammas.iter().any(|v| !v.is_finite()) {
76 bail!("convergence problem at block {blk} spot {s}: array weights not estimable");
77 }
78
79 let zg = z.dot(&gammas);
82 let mut wfull: Vec<f64> = (0..narrays).map(|a| (-zg[a]).exp()).collect();
83 if let Some(wmat) = weights {
84 let mut wrow: Vec<f64> = (0..narrays).map(|a| wmat[[i, a]]).collect();
85 let mx = wrow
86 .iter()
87 .copied()
88 .filter(|v| v.is_finite())
89 .fold(f64::NEG_INFINITY, f64::max);
90 if mx > 1.0 {
91 for v in wrow.iter_mut() {
92 *v /= mx;
93 }
94 }
95 for a in 0..narrays {
96 wfull[a] *= wrow[a];
97 }
98 }
99
100 let yfull: Vec<f64> = (0..narrays).map(|a| m[[i, a]]).collect();
101 let obs: Vec<usize> = (0..narrays)
102 .filter(|&a| yfull[a].is_finite() && wfull[a] != 0.0)
103 .collect();
104 let nobs = obs.len();
105 if nobs <= 1 {
106 continue;
107 }
108
109 let xsub = Array2::from_shape_fn((nobs, nparams), |(r, c)| design[[obs[r], c]]);
110 let ysub: Vec<f64> = obs.iter().map(|&a| yfull[a]).collect();
111 let wsub: Vec<f64> = obs.iter().map(|&a| wfull[a]).collect();
112 let (resid, lev, s2) = wfit_resid_lev_s2(&xsub, &ysub, &wsub);
113 let df_resid = (nobs - nparams) as f64;
114
115 let mut d = vec![0.0f64; narrays];
118 let mut h = vec![1.0f64; narrays];
119 for (k, &a) in obs.iter().enumerate() {
120 d[a] = wsub[k] * resid[k] * resid[k];
121 h[a] = lev[k];
122 }
123
124 let mut agene = Array2::<f64>::zeros((ngam, ngam));
127 for p in 0..ngam {
128 for q in 0..ngam {
129 let mut acc = 0.0;
130 for a in 0..narrays {
131 acc += z[[a, p]] * (1.0 - h[a]) * z[[a, q]];
132 }
133 agene[[p, q]] = acc;
134 }
135 }
136 let cdel = h[narrays - 1] - h[0];
137 let del = cdel * cdel / df_resid;
138 agene.mapv_inplace(|v| v - del);
139 if !agene.iter().all(|v| v.is_finite()) {
140 continue;
141 }
142
143 let zd: Array1<f64> = (0..narrays).map(|a| d[a] / s2 - 1.0 + h[a]).collect();
144
145 if nobs == narrays {
146 zinfo = &zinfo + &agene;
147 let zzd = z.t().dot(&zd);
148 let step = solve_linear(&zinfo, &zzd);
149 gammas = &gammas + &step;
150 } else if nobs > 2 {
151 zinfo = &zinfo + &agene;
152 let z2 = Array2::from_shape_fn((nobs, ngam), |(r, c)| z[[obs[r], c]]);
154 let mut a1 = Array2::<f64>::zeros((nobs - 1, ngam));
155 for r in 0..(nobs - 1) {
156 for c in 0..ngam {
157 let mut acc = 0.0;
158 for k in 0..nobs {
159 let centering = (if k == r { 1.0 } else { 0.0 }) - 1.0 / nobs as f64;
160 acc += centering * z2[[k, c]];
161 }
162 a1[[r, c]] = acc;
163 }
164 }
165 let ztzd = z.t().dot(&zd);
166 let zzd = a1.dot(&ztzd); let mut zinv_a1t = Array2::<f64>::zeros((ngam, nobs - 1));
169 for r in 0..(nobs - 1) {
170 let rhs: Array1<f64> = (0..ngam).map(|c| a1[[r, c]]).collect();
171 let sol = solve_linear(&zinfo, &rhs);
172 for c in 0..ngam {
173 zinv_a1t[[c, r]] = sol[c];
174 }
175 }
176 let alphas_iter = a1.dot(&zinv_a1t).dot(&zzd); let mut usalphas = vec![0.0f64; nobs];
179 for k in 0..(nobs - 1) {
180 usalphas[k] = alphas_iter[k];
181 usalphas[nobs - 1] -= alphas_iter[k];
182 }
183 let mut usg = z.dot(&gammas);
184 for (k, &a) in obs.iter().enumerate() {
185 usg[a] += usalphas[k];
186 }
187 gammas = (0..ngam).map(|a| usg[a]).collect();
188 }
189 }
190
191 let zg = z.dot(&gammas);
192 for a in 0..narrays {
193 blockw[[blk, a]] = (-zg[a]).exp();
194 }
195 }
196
197 let mut wts = Array2::<f64>::zeros((nprobes, narrays));
198 for blk in 0..ngrids {
199 for s in 0..nspots {
200 let i = blk * nspots + s;
201 for a in 0..narrays {
202 wts[[i, a]] = blockw[[blk, a]];
203 }
204 }
205 }
206 Ok(wts)
207}
208
#[cfg(test)]
// Expected values are pasted at full printed precision from the reference run.
#[allow(clippy::excessive_precision)]
mod tests {
    use super::*;

    /// First-block weights shared by both unweighted scenarios (the NaN in
    /// scenario B lives in the second block).
    const UNWEIGHTED_FIRST_BLOCK: [f64; 4] = [
        1.0279718633477748,
        1.0279718633477748,
        0.97278927143328675,
        0.97278927143328664,
    ];

    /// Two printer blocks of 2 x 3 spots each, i.e. 12 probe rows.
    const LAYOUT: PrinterLayout = PrinterLayout {
        ngrid_r: 1,
        ngrid_c: 2,
        nspot_r: 2,
        nspot_c: 3,
    };

    /// Relative closeness with a unit floor: |a - b| <= 1e-7 * (1 + |b|).
    fn rclose(a: f64, b: f64) -> bool {
        let tol = 1e-7 * (1.0 + b.abs());
        (a - b).abs() <= tol
    }

    /// Deterministic 12 x 4 data matrix and matching spot weights.
    /// Each column's pseudo-random term is scaled by `SCALE[col]`.
    fn fixture() -> (Array2<f64>, Array2<f64>) {
        const SCALE: [f64; 4] = [1.0, 1.5, 0.7, 2.0];
        let shape = (12usize, 4usize);
        let m = Array2::from_shape_fn(shape, |(row, col)| {
            let (r, c) = (row as i64, col as i64);
            let level = (r % 4) as f64 * 0.5;
            let wobble = ((r * 5 + c * 3) % 7 - 3) as f64 * 0.2 * SCALE[col];
            3.0 + level + wobble
        });
        let w = Array2::from_shape_fn(shape, |(row, col)| {
            let (r, c) = (row as i64, col as i64);
            0.5 + ((r * 2 + c * 5) % 6) as f64 * 0.2
        });
        (m, w)
    }

    /// Two-group design: arrays 0-1 have only the intercept, arrays 2-3 add
    /// the second column.
    fn design4() -> Array2<f64> {
        Array2::from_shape_fn((4, 2), |(r, c)| if c == 0 || r >= 2 { 1.0 } else { 0.0 })
    }

    /// Rows 0..6 must match `blocks[0]` and rows 6..12 must match `blocks[1]`.
    fn assert_blocks(out: &Array2<f64>, blocks: &[[f64; 4]; 2], label: &str) {
        for (blk, expected) in blocks.iter().enumerate() {
            for i in blk * 6..(blk + 1) * 6 {
                for (a, &want) in expected.iter().enumerate() {
                    let got = out[[i, a]];
                    assert!(rclose(got, want), "{label}[{i},{a}]: {} vs {}", got, want);
                }
            }
        }
    }

    #[test]
    fn printtip_weights_no_weights_clean() {
        let (m, _) = fixture();
        let out = printtip_weights(&m, &design4(), None, LAYOUT).expect("clean input");
        let second = [
            1.0117348683445897,
            1.0117348683445895,
            0.98840124155867992,
            0.98840124155868003,
        ];
        assert_blocks(&out, &[UNWEIGHTED_FIRST_BLOCK, second], "ptw A");
    }

    #[test]
    fn printtip_weights_no_weights_na_branch() {
        let (mut m, _) = fixture();
        m[[7, 1]] = f64::NAN;
        let out = printtip_weights(&m, &design4(), None, LAYOUT).expect("input with one NaN");
        let second = [
            0.99170352734697842,
            0.99170352734697809,
            1.0085121081848278,
            1.0082196729118387,
        ];
        assert_blocks(&out, &[UNWEIGHTED_FIRST_BLOCK, second], "ptw B");
    }

    #[test]
    fn printtip_weights_with_spot_weights() {
        let (m, w) = fixture();
        let out = printtip_weights(&m, &design4(), Some(&w), LAYOUT).expect("weighted input");
        let expected = [
            [
                1.032303811727709,
                1.013658601320103,
                0.96283233132495361,
                0.99254474410404492,
            ],
            [
                1.0258412521778075,
                0.99993246962846849,
                0.99494360599235865,
                0.97982993672697472,
            ],
        ];
        assert_blocks(&out, &expected, "ptw C");
    }
}