Skip to main content

limma/
printtipweights.rs

1//! `printtipWeights` (printtipWeights.R): print-tip array quality weights for
2//! two-colour arrays. Each print-tip block (a contiguous run of `nspots` rows)
3//! gets its own array-weight estimate via the gene-by-gene update algorithm —
4//! the same machinery as `arrayWeights(method = "genebygene")` but with a
5//! `contr.sum` variance design and a per-block prior — and the resulting
6//! `narrays` weights are broadcast across the block's rows.
7//!
8//! Only the default `method = "genebygene"` is ported; the `reml` branch is out
9//! of scope. The design is assumed to have full column rank (limma reports
10//! non-estimable coefficients but otherwise proceeds). The `Agam.del` rank-1
11//! correction is reproduced exactly as shipped, including limma's
12//! `h[1:(length(narrays)-1)]` indexing quirk (which reduces it to `h[1]`).
13
14use anyhow::{bail, Result};
15use ndarray::{Array1, Array2};
16
17use crate::arrayweights::{contr_sum, solve_linear, wfit_resid_lev_s2};
18use crate::normwithin::PrinterLayout;
19
20/// `printtipWeights(M, design, weights, method = "genebygene", layout)`.
21///
22/// * `m` — `n_probes x n_arrays` log-ratio matrix (`NA`/infinite entries are
23///   dropped per spot).
24/// * `design` — `n_arrays x p` design matrix (full column rank).
25/// * `weights` — optional `n_probes x n_arrays` spot weights; a spot's weights
26///   are rescaled to a maximum of 1 (when that maximum exceeds 1) before being
27///   combined with the running variance estimate, exactly as limma does.
28/// * `layout` — print-tip layout; `ngrid_r*ngrid_c*nspot_r*nspot_c` must equal
29///   `n_probes`.
30///
31/// Returns the `n_probes x n_arrays` weight matrix (each block's weights
32/// broadcast across its rows), ready to pass to `lmFit`.
33pub fn printtip_weights(
34    m: &Array2<f64>,
35    design: &Array2<f64>,
36    weights: Option<&Array2<f64>>,
37    layout: PrinterLayout,
38) -> Result<Array2<f64>> {
39    let nprobes = m.nrows();
40    let narrays = m.ncols();
41    let nparams = design.ncols();
42    if design.nrows() != narrays {
43        bail!("design row dimension must equal number of arrays");
44    }
45    if narrays < 3 {
46        bail!("too few arrays");
47    }
48    if nprobes < narrays {
49        bail!("too few probes");
50    }
51    let nspots = layout.nspot_r * layout.nspot_c;
52    let ngrids = layout.ngrid_r * layout.ngrid_c;
53    if ngrids * nspots != nprobes {
54        bail!("printer layout information does not match M row dimension");
55    }
56    if let Some(w) = weights {
57        if w.dim() != (nprobes, narrays) {
58            bail!("dimensions of weights do not match M");
59        }
60    }
61
62    let z = contr_sum(narrays); // narrays x (narrays-1)
63    let ngam = narrays - 1;
64    let ztz = z.t().dot(&z);
65    let prior = 10.0 * (narrays - nparams) as f64 / narrays as f64;
66
67    let mut blockw = Array2::<f64>::zeros((ngrids, narrays));
68    for blk in 0..ngrids {
69        let start = blk * nspots;
70        let mut gammas = Array1::<f64>::zeros(ngam);
71        let mut zinfo = ztz.mapv(|v| v * prior);
72
73        for s in 0..nspots {
74            let i = start + s;
75            if gammas.iter().any(|v| !v.is_finite()) {
76                bail!("convergence problem at block {blk} spot {s}: array weights not estimable");
77            }
78
79            // w = 1/vary = exp(-Z gammas), optionally times the (re-scaled) spot
80            // weights for this probe.
81            let zg = z.dot(&gammas);
82            let mut wfull: Vec<f64> = (0..narrays).map(|a| (-zg[a]).exp()).collect();
83            if let Some(wmat) = weights {
84                let mut wrow: Vec<f64> = (0..narrays).map(|a| wmat[[i, a]]).collect();
85                let mx = wrow
86                    .iter()
87                    .copied()
88                    .filter(|v| v.is_finite())
89                    .fold(f64::NEG_INFINITY, f64::max);
90                if mx > 1.0 {
91                    for v in wrow.iter_mut() {
92                        *v /= mx;
93                    }
94                }
95                for a in 0..narrays {
96                    wfull[a] *= wrow[a];
97                }
98            }
99
100            let yfull: Vec<f64> = (0..narrays).map(|a| m[[i, a]]).collect();
101            let obs: Vec<usize> = (0..narrays)
102                .filter(|&a| yfull[a].is_finite() && wfull[a] != 0.0)
103                .collect();
104            let nobs = obs.len();
105            if nobs <= 1 {
106                continue;
107            }
108
109            let xsub = Array2::from_shape_fn((nobs, nparams), |(r, c)| design[[obs[r], c]]);
110            let ysub: Vec<f64> = obs.iter().map(|&a| yfull[a]).collect();
111            let wsub: Vec<f64> = obs.iter().map(|&a| wfull[a]).collect();
112            let (resid, lev, s2) = wfit_resid_lev_s2(&xsub, &ysub, &wsub);
113            let df_resid = (nobs - nparams) as f64;
114
115            // d (= w*resid^2) and h (leverage) spread to full length; unobserved
116            // arrays carry h = 1, d = 0.
117            let mut d = vec![0.0f64; narrays];
118            let mut h = vec![1.0f64; narrays];
119            for (k, &a) in obs.iter().enumerate() {
120                d[a] = wsub[k] * resid[k] * resid[k];
121                h[a] = lev[k];
122            }
123
124            // Agene.gam = Z' diag(1-h) Z - (1/df) * Agam.del, with limma's
125            // Agam.del reducing to (h[last]-h[first])^2 * ones.
126            let mut agene = Array2::<f64>::zeros((ngam, ngam));
127            for p in 0..ngam {
128                for q in 0..ngam {
129                    let mut acc = 0.0;
130                    for a in 0..narrays {
131                        acc += z[[a, p]] * (1.0 - h[a]) * z[[a, q]];
132                    }
133                    agene[[p, q]] = acc;
134                }
135            }
136            let cdel = h[narrays - 1] - h[0];
137            let del = cdel * cdel / df_resid;
138            agene.mapv_inplace(|v| v - del);
139            if !agene.iter().all(|v| v.is_finite()) {
140                continue;
141            }
142
143            let zd: Array1<f64> = (0..narrays).map(|a| d[a] / s2 - 1.0 + h[a]).collect();
144
145            if nobs == narrays {
146                zinfo = &zinfo + &agene;
147                let zzd = z.t().dot(&zd);
148                let step = solve_linear(&zinfo, &zzd);
149                gammas = &gammas + &step;
150            } else if nobs > 2 {
151                zinfo = &zinfo + &agene;
152                // A1 = (I - J/nobs) Z[obs,] with its last row dropped.
153                let z2 = Array2::from_shape_fn((nobs, ngam), |(r, c)| z[[obs[r], c]]);
154                let mut a1 = Array2::<f64>::zeros((nobs - 1, ngam));
155                for r in 0..(nobs - 1) {
156                    for c in 0..ngam {
157                        let mut acc = 0.0;
158                        for k in 0..nobs {
159                            let centering = (if k == r { 1.0 } else { 0.0 }) - 1.0 / nobs as f64;
160                            acc += centering * z2[[k, c]];
161                        }
162                        a1[[r, c]] = acc;
163                    }
164                }
165                let ztzd = z.t().dot(&zd);
166                let zzd = a1.dot(&ztzd); // length nobs-1
167                                         // A1 Zinfo^-1 A1' via per-column solves of Zinfo x = A1'[,r].
168                let mut zinv_a1t = Array2::<f64>::zeros((ngam, nobs - 1));
169                for r in 0..(nobs - 1) {
170                    let rhs: Array1<f64> = (0..ngam).map(|c| a1[[r, c]]).collect();
171                    let sol = solve_linear(&zinfo, &rhs);
172                    for c in 0..ngam {
173                        zinv_a1t[[c, r]] = sol[c];
174                    }
175                }
176                let alphas_iter = a1.dot(&zinv_a1t).dot(&zzd); // length nobs-1
177                                                               // Us (alphas_new - alphas_old) with Us = [I_{nobs-1}; -1].
178                let mut usalphas = vec![0.0f64; nobs];
179                for k in 0..(nobs - 1) {
180                    usalphas[k] = alphas_iter[k];
181                    usalphas[nobs - 1] -= alphas_iter[k];
182                }
183                let mut usg = z.dot(&gammas);
184                for (k, &a) in obs.iter().enumerate() {
185                    usg[a] += usalphas[k];
186                }
187                gammas = (0..ngam).map(|a| usg[a]).collect();
188            }
189        }
190
191        let zg = z.dot(&gammas);
192        for a in 0..narrays {
193            blockw[[blk, a]] = (-zg[a]).exp();
194        }
195    }
196
197    let mut wts = Array2::<f64>::zeros((nprobes, narrays));
198    for blk in 0..ngrids {
199        for s in 0..nspots {
200            let i = blk * nspots + s;
201            for a in 0..narrays {
202                wts[[i, a]] = blockw[[blk, a]];
203            }
204        }
205    }
206    Ok(wts)
207}
208
#[cfg(test)]
// The expected weights below are copied at full 17-significant-digit precision
// from the R reference run, which clippy would otherwise flag as excessive.
#[allow(clippy::excessive_precision)]
mod tests {
    use super::*;

    /// Closeness check used against the R reference values: absolute error
    /// within `1e-7 * (1 + |b|)`.
    fn rclose(a: f64, b: f64) -> bool {
        (a - b).abs() <= 1e-7 * (1.0 + b.abs())
    }

    /// 12x4 M matrix (with per-array scale heterogeneity) and a 12x4 spot-weight
    /// matrix, matching `scratch/printtipweights_ref.R`.
    fn fixture() -> (Array2<f64>, Array2<f64>) {
        let scale = [1.0, 1.5, 0.7, 2.0];
        let (nprobe, narray) = (12usize, 4usize);
        let mut m = Array2::zeros((nprobe, narray));
        let mut w = Array2::zeros((nprobe, narray));
        for g0 in 0..nprobe {
            for j0 in 0..narray {
                let (gi, ji) = (g0 as i64, j0 as i64);
                m[[g0, j0]] = 3.0
                    + (gi % 4) as f64 * 0.5
                    + ((gi * 5 + ji * 3) % 7 - 3) as f64 * 0.2 * scale[j0];
                w[[g0, j0]] = 0.5 + ((gi * 2 + ji * 5) % 6) as f64 * 0.2;
            }
        }
        (m, w)
    }

    /// Two-group design: arrays 1-2 baseline, arrays 3-4 treatment.
    fn design4() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]).unwrap()
    }

    /// Two print-tip blocks of 2x3 = 6 spots each, tiling the 12 fixture rows.
    const LAYOUT: PrinterLayout = PrinterLayout {
        ngrid_r: 1,
        ngrid_c: 2,
        nspot_r: 2,
        nspot_c: 3,
    };

    /// Assert that `out` has one row per spot of `LAYOUT` and 4 columns, and
    /// that each row equals its block's expected weights:
    /// rows 0..6 -> `blocks[0]`, rows 6..12 -> `blocks[1]`.
    fn assert_blocks(out: &Array2<f64>, blocks: &[[f64; 4]; 2], label: &str) {
        let per_block = LAYOUT.nspot_r * LAYOUT.nspot_c;
        assert_eq!(
            out.dim(),
            (blocks.len() * per_block, 4),
            "{label}: unexpected output shape"
        );
        for ((i, a), &got) in out.indexed_iter() {
            let want = blocks[i / per_block][a];
            assert!(rclose(got, want), "{label}[{i},{a}]: {got} vs {want}");
        }
    }

    #[test]
    fn printtip_weights_no_weights_clean() {
        let (m, _w) = fixture();
        let out = printtip_weights(&m, &design4(), None, LAYOUT).unwrap();
        let blocks = [
            [
                1.0279718633477748,
                1.0279718633477748,
                0.97278927143328675,
                0.97278927143328664,
            ],
            [
                1.0117348683445897,
                1.0117348683445895,
                0.98840124155867992,
                0.98840124155868003,
            ],
        ];
        assert_blocks(&out, &blocks, "ptw A");
    }

    #[test]
    fn printtip_weights_no_weights_na_branch() {
        let (m, _w) = fixture();
        let mut mna = m.clone();
        mna[[7, 1]] = f64::NAN; // M[8,2] in 1-based -> sum(obs)=3 branch in block 2.
        let out = printtip_weights(&mna, &design4(), None, LAYOUT).unwrap();
        let blocks = [
            [
                1.0279718633477748,
                1.0279718633477748,
                0.97278927143328675,
                0.97278927143328664,
            ],
            [
                0.99170352734697842,
                0.99170352734697809,
                1.0085121081848278,
                1.0082196729118387,
            ],
        ];
        assert_blocks(&out, &blocks, "ptw B");
    }

    #[test]
    fn printtip_weights_with_spot_weights() {
        let (m, w) = fixture();
        let out = printtip_weights(&m, &design4(), Some(&w), LAYOUT).unwrap();
        let blocks = [
            [
                1.032303811727709,
                1.013658601320103,
                0.96283233132495361,
                0.99254474410404492,
            ],
            [
                1.0258412521778075,
                0.99993246962846849,
                0.99494360599235865,
                0.97982993672697472,
            ],
        ];
        assert_blocks(&out, &blocks, "ptw C");
    }

    #[test]
    fn printtip_weights_rejects_inconsistent_inputs() {
        let (m, w) = fixture();

        // Design rows must match the number of arrays.
        let short_design = Array2::<f64>::ones((3, 2));
        assert!(printtip_weights(&m, &short_design, None, LAYOUT).is_err());

        // The layout must tile exactly the 12 probe rows (here 18 != 12).
        let bad_layout = PrinterLayout {
            ngrid_r: 1,
            ngrid_c: 3,
            nspot_r: 2,
            nspot_c: 3,
        };
        assert!(printtip_weights(&m, &design4(), None, bad_layout).is_err());

        // Spot weights must have the same shape as M.
        let short_w = Array2::from_shape_fn((11, 4), |(i, a)| w[[i, a]]);
        assert!(printtip_weights(&m, &design4(), Some(&short_w), LAYOUT).is_err());

        // At least three arrays are required.
        let m2 = Array2::from_shape_fn((12, 2), |(i, a)| m[[i, a]]);
        let d2 = Array2::<f64>::ones((2, 1));
        assert!(printtip_weights(&m2, &d2, None, LAYOUT).is_err());
    }
}