1use ndarray::{Array1, Array2, Axis};
19
20use crate::fit::MArrayLM;
21use crate::linalg::{qr_econ, solve_upper};
22
23pub(crate) fn contr_sum(n: usize) -> Array2<f64> {
26 let mut z = Array2::<f64>::zeros((n, n - 1));
27 for j in 0..(n - 1) {
28 z[[j, j]] = 1.0;
29 z[[n - 1, j]] = -1.0;
30 }
31 z
32}
33
34pub(crate) fn solve_linear(a: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
38 let k = a.nrows();
39 let mut m = a.clone();
40 let mut rhs = b.clone();
41 for col in 0..k {
42 let mut piv = col;
43 let mut best = m[[col, col]].abs();
44 for r in (col + 1)..k {
45 let v = m[[r, col]].abs();
46 if v > best {
47 best = v;
48 piv = r;
49 }
50 }
51 if piv != col {
52 for c in 0..k {
53 let tmp = m[[col, c]];
54 m[[col, c]] = m[[piv, c]];
55 m[[piv, c]] = tmp;
56 }
57 rhs.swap(col, piv);
58 }
59 let d = m[[col, col]];
60 for r in (col + 1)..k {
61 let f = m[[r, col]] / d;
62 if f != 0.0 {
63 for c in col..k {
64 let v = m[[col, c]];
65 m[[r, c]] -= f * v;
66 }
67 rhs[r] -= f * rhs[col];
68 }
69 }
70 }
71 let mut x = Array1::<f64>::zeros(k);
72 for i in (0..k).rev() {
73 let mut sum = rhs[i];
74 for j in (i + 1)..k {
75 sum -= m[[i, j]] * x[j];
76 }
77 x[i] = sum / m[[i, i]];
78 }
79 x
80}
81
82pub fn array_weights(
94 exprs: &Array2<f64>,
95 design: &Array2<f64>,
96 var_design: Option<&Array2<f64>>,
97 prior_n: f64,
98 maxiter: usize,
99 tol: f64,
100) -> Array1<f64> {
101 let narrays = exprs.ncols();
102 let ngenes_all = exprs.nrows();
103 let p = design.ncols();
104
105 let mut w = Array1::<f64>::ones(narrays);
106 if ngenes_all < 2 || narrays < p + 2 {
108 return w;
109 }
110
111 let z2 = match var_design {
112 Some(v) => v.to_owned(),
113 None => contr_sum(narrays),
114 };
115 let ngam = z2.ncols();
116 let nz = ngam + 1;
117
118 let mut zmat = Array2::<f64>::ones((narrays, nz));
120 for j in 0..ngam {
121 for i in 0..narrays {
122 zmat[[i, j + 1]] = z2[[i, j]];
123 }
124 }
125 let z2tz2 = z2.t().dot(&z2);
126 let dfres = (narrays - p) as f64;
127
128 let (q0, _r0) = qr_econ(design);
130 let mut kept: Vec<usize> = Vec::with_capacity(ngenes_all);
131 for g in 0..ngenes_all {
132 let yg = exprs.row(g).to_owned();
133 let cg = q0.t().dot(&yg);
134 let s2 = (yg.dot(&yg) - cg.dot(&cg)) / dfres;
135 if s2 >= 1e-15 {
136 kept.push(g);
137 }
138 }
139 if kept.len() < 2 {
140 return w;
141 }
142 let y = exprs.select(Axis(0), &kept);
143 let ngenes = y.nrows();
144 let ngenes_f = ngenes as f64;
145
146 let mut gam = Array1::<f64>::zeros(ngam);
147 let mut convcrit_last = f64::INFINITY;
148 let p2 = p * (p + 1) / 2;
149
150 for _iter in 1..=maxiter {
151 let sw: Array1<f64> = w.mapv(f64::sqrt);
152
153 let mut xw = design.clone();
155 for i in 0..narrays {
156 for j in 0..p {
157 xw[[i, j]] *= sw[i];
158 }
159 }
160 let (qe, r) = qr_econ(&xw);
161
162 let mut resid = Array2::<f64>::zeros((narrays, ngenes));
164 let mut s2 = Array1::<f64>::zeros(ngenes);
165 for gi in 0..ngenes {
166 let yg = y.row(gi).to_owned();
167 let ywg = &yg * &sw;
168 let cg = qe.t().dot(&ywg);
169 let beta = solve_upper(&r, &cg);
170 let fitted = design.dot(&beta);
171 let ss = ywg.dot(&ywg) - cg.dot(&cg);
172 s2[gi] = ss / dfres;
173 for i in 0..narrays {
174 resid[[i, gi]] = yg[i] - fitted[i];
175 }
176 }
177
178 let mut q2 = Array2::<f64>::zeros((narrays, p2));
183 let mut h = Array1::<f64>::zeros(narrays);
184 for i in 0..narrays {
185 let mut col = 0usize;
186 for k in 0..p {
187 for a in 0..(p - k) {
188 q2[[i, col]] = qe[[i, a]] * qe[[i, a + k]];
189 col += 1;
190 }
191 }
192 for c in p..p2 {
193 q2[[i, c]] *= std::f64::consts::SQRT_2;
194 }
195 let mut lev = 0.0;
196 for c in 0..p {
197 lev += q2[[i, c]];
198 }
199 h[i] = lev;
200 }
201
202 let mut info = Array2::<f64>::zeros((nz, nz));
204 for a in 0..nz {
205 for b in 0..nz {
206 let mut acc = 0.0;
207 for i in 0..narrays {
208 acc += zmat[[i, a]] * (1.0 - 2.0 * h[i]) * zmat[[i, b]];
209 }
210 info[[a, b]] = acc;
211 }
212 }
213 let q2tz = q2.t().dot(&zmat);
214 let gram = q2tz.t().dot(&q2tz);
215 info = &info + &gram;
216
217 let i00 = info[[0, 0]];
219 let mut info2 = Array2::<f64>::zeros((ngam, ngam));
220 for a in 0..ngam {
221 for b in 0..ngam {
222 info2[[a, b]] = info[[a + 1, b + 1]] - info[[a + 1, 0]] * info[[0, b + 1]] / i00;
223 }
224 }
225
226 let mut zvec = Array1::<f64>::zeros(narrays);
228 for i in 0..narrays {
229 let mut acc = 0.0;
230 for gi in 0..ngenes {
231 acc += w[i] * resid[[i, gi]] * resid[[i, gi]] / s2[gi];
232 }
233 zvec[i] = acc / ngenes_f - (1.0 - h[i]);
234 }
235
236 for a in 0..ngam {
238 for b in 0..ngam {
239 info2[[a, b]] = ngenes_f * info2[[a, b]] + prior_n * z2tz2[[a, b]];
240 }
241 }
242 for i in 0..narrays {
243 zvec[i] = ngenes_f * zvec[i] + prior_n * (w[i] - 1.0);
244 }
245
246 let dl = z2.t().dot(&zvec);
248 let gamstep = solve_linear(&info2, &dl);
249 let convcrit = dl.dot(&gamstep) / (ngam as f64) / (ngenes_f + prior_n);
250 if convcrit.is_nan() || convcrit >= convcrit_last {
251 break;
252 }
253 convcrit_last = convcrit;
254
255 gam = &gam + &gamstep;
256 w = z2.dot(&gam).mapv(|x| (-x).exp());
257
258 if convcrit < tol {
259 break;
260 }
261 }
262
263 w
264}
265
266pub fn array_weights_prwts_reml(
290 exprs: &Array2<f64>,
291 design: &Array2<f64>,
292 weights: &Array2<f64>,
293 var_design: Option<&Array2<f64>>,
294 prior_n: f64,
295 maxiter: usize,
296 tol: f64,
297) -> Array1<f64> {
298 let narrays = exprs.ncols();
299 let ngenes = exprs.nrows();
300 let p = design.ncols();
301
302 let z2 = match var_design {
303 Some(v) => v.to_owned(),
304 None => contr_sum(narrays),
305 };
306 let ngam = z2.ncols();
307 let nz = ngam + 1;
308
309 let mut zmat = Array2::<f64>::ones((narrays, nz));
311 for j in 0..ngam {
312 for i in 0..narrays {
313 zmat[[i, j + 1]] = z2[[i, j]];
314 }
315 }
316 let z2tz2 = z2.t().dot(&z2);
317 let dfres = (narrays - p) as f64;
318 let denom = ngenes as f64 + prior_n;
319 let p2 = p * (p + 1) / 2;
320
321 let mut gam = Array1::<f64>::zeros(ngam);
322 let mut w = Array1::<f64>::ones(narrays);
323
324 for _iter in 1..=maxiter {
325 let mut info2 = z2tz2.mapv(|x| x * prior_n);
327 let mut zvec = w.mapv(|wi| prior_n * (wi - 1.0));
328
329 for g in 0..ngenes {
330 let cw: Vec<f64> = (0..narrays).map(|i| w[i] * weights[[g, i]]).collect();
332 let sw: Vec<f64> = cw.iter().map(|&v| v.sqrt()).collect();
333 let mut xw = design.clone();
334 for i in 0..narrays {
335 for j in 0..p {
336 xw[[i, j]] *= sw[i];
337 }
338 }
339 let yg = exprs.row(g);
340 let yw: Array1<f64> = (0..narrays).map(|i| yg[i] * sw[i]).collect();
341 let (qe, r) = qr_econ(&xw);
342 let cg = qe.t().dot(&yw);
343 let beta = solve_upper(&r, &cg);
344 let fitted = design.dot(&beta);
345 let resid: Vec<f64> = (0..narrays).map(|i| yg[i] - fitted[i]).collect();
346 let s2 = (yw.dot(&yw) - cg.dot(&cg)) / dfres;
347
348 let mut q2 = Array2::<f64>::zeros((narrays, p2));
351 let mut h = vec![0.0f64; narrays];
352 for i in 0..narrays {
353 let mut col = 0usize;
354 for k in 0..p {
355 for a in 0..(p - k) {
356 q2[[i, col]] = qe[[i, a]] * qe[[i, a + k]];
357 col += 1;
358 }
359 }
360 for c in p..p2 {
361 q2[[i, c]] *= std::f64::consts::SQRT_2;
362 }
363 let mut lev = 0.0;
364 for c in 0..p {
365 lev += q2[[i, c]];
366 }
367 h[i] = lev;
368 }
369
370 let mut info = Array2::<f64>::zeros((nz, nz));
372 for a in 0..nz {
373 for b in 0..nz {
374 let mut acc = 0.0;
375 for i in 0..narrays {
376 acc += zmat[[i, a]] * (1.0 - 2.0 * h[i]) * zmat[[i, b]];
377 }
378 info[[a, b]] = acc;
379 }
380 }
381 let q2tz = q2.t().dot(&zmat);
382 let gram = q2tz.t().dot(&q2tz);
383 info = &info + &gram;
384
385 let i00 = info[[0, 0]];
387 for a in 0..ngam {
388 for b in 0..ngam {
389 info2[[a, b]] +=
390 info[[a + 1, b + 1]] - info[[a + 1, 0]] * info[[0, b + 1]] / i00;
391 }
392 }
393
394 if s2 > 1e-15 {
396 for i in 0..narrays {
397 zvec[i] += cw[i] * resid[i] * resid[i] / s2 - (1.0 - h[i]);
398 }
399 }
400 }
401
402 info2.mapv_inplace(|x| x / denom);
404 zvec.mapv_inplace(|x| x / denom);
405
406 let dl = z2.t().dot(&zvec);
408 let gamstep = solve_linear(&info2, &dl);
409 gam = &gam + &gamstep;
410 w = z2.dot(&gam).mapv(|x| (-x).exp());
411
412 let convcrit = dl.dot(&gamstep) / denom / (ngam as f64);
413 if convcrit.is_nan() || convcrit < tol {
414 break;
415 }
416 }
417
418 w
419}
420
421pub(crate) fn wfit_resid_lev_s2(
426 x: &Array2<f64>,
427 y: &[f64],
428 w: &[f64],
429) -> (Vec<f64>, Vec<f64>, f64) {
430 let n = x.nrows();
431 let p = x.ncols();
432 let sw: Vec<f64> = w.iter().map(|&v| v.sqrt()).collect();
433 let mut xw = x.clone();
434 for i in 0..n {
435 for j in 0..p {
436 xw[[i, j]] *= sw[i];
437 }
438 }
439 let yw: Array1<f64> = (0..n).map(|i| y[i] * sw[i]).collect();
440 let (qe, r) = qr_econ(&xw);
441 let cg = qe.t().dot(&yw);
442 let beta = solve_upper(&r, &cg);
443 let fitted = x.dot(&beta);
444 let resid: Vec<f64> = (0..n).map(|i| y[i] - fitted[i]).collect();
445 let lev: Vec<f64> = (0..n)
446 .map(|i| (0..p).map(|k| qe[[i, k]] * qe[[i, k]]).sum())
447 .collect();
448 let rss = yw.dot(&yw) - cg.dot(&cg);
449 let s2 = rss / (n - p) as f64;
450 (resid, lev, s2)
451}
452
453pub fn array_weights_gene_by_gene(
474 exprs: &Array2<f64>,
475 design: &Array2<f64>,
476 weights: Option<&Array2<f64>>,
477 var_design: Option<&Array2<f64>>,
478 prior_n: f64,
479) -> Array1<f64> {
480 let ngenes = exprs.nrows();
481 let narrays = exprs.ncols();
482 let nparams = design.ncols();
483
484 let z2 = match var_design {
485 Some(v) => v.to_owned(),
486 None => contr_sum(narrays),
487 };
488 let ngam = z2.ncols();
489 let nz = ngam + 1;
490
491 let mut zmat = Array2::<f64>::ones((narrays, nz));
493 for j in 0..ngam {
494 for i in 0..narrays {
495 zmat[[i, j + 1]] = z2[[i, j]];
496 }
497 }
498
499 let mut gam = Array1::<f64>::zeros(ngam);
500 let mut aw = Array1::<f64>::ones(narrays);
501 let mut info2 = z2.t().dot(&z2);
503 info2.mapv_inplace(|x| x * prior_n);
504
505 for i in 0..ngenes {
506 let mut w: Vec<f64> = aw.to_vec();
508 if let Some(wt) = weights {
509 for (j, wj) in w.iter_mut().enumerate() {
510 *wj *= wt[[i, j]];
511 }
512 }
513 let yrow: Vec<f64> = exprs.row(i).to_vec();
514
515 let mut d = vec![0.0f64; narrays];
516 let mut h1 = vec![0.0f64; narrays];
517 let s2;
518 if yrow.iter().any(|v| v.is_nan()) {
519 let obs: Vec<usize> = (0..narrays).filter(|&j| yrow[j].is_finite()).collect();
520 let nobs = obs.len();
521 if nobs <= 2 || nobs < nparams + 2 {
523 continue;
524 }
525 let mut xsub = Array2::<f64>::zeros((nobs, nparams));
526 let mut ysub = vec![0.0f64; nobs];
527 let mut wsub = vec![0.0f64; nobs];
528 for (r, &j) in obs.iter().enumerate() {
529 for c in 0..nparams {
530 xsub[[r, c]] = design[[j, c]];
531 }
532 ysub[r] = yrow[j];
533 wsub[r] = w[j];
534 }
535 let (resid, lev, s2v) = wfit_resid_lev_s2(&xsub, &ysub, &wsub);
536 s2 = s2v;
537 for (r, &j) in obs.iter().enumerate() {
538 d[j] = wsub[r] * resid[r] * resid[r];
539 h1[j] = 1.0 - lev[r];
540 }
541 } else {
542 let (resid, lev, s2v) = wfit_resid_lev_s2(design, &yrow, &w);
543 s2 = s2v;
544 for j in 0..narrays {
545 d[j] = w[j] * resid[j] * resid[j];
546 h1[j] = 1.0 - lev[j];
547 }
548 }
549 if s2 < 1e-15 {
550 continue;
551 }
552
553 let mut info = Array2::<f64>::zeros((nz, nz));
556 for a in 0..nz {
557 for b in 0..nz {
558 let mut acc = 0.0;
559 for j in 0..narrays {
560 acc += zmat[[j, a]] * h1[j] * zmat[[j, b]];
561 }
562 info[[a, b]] = acc;
563 }
564 }
565 let i00 = info[[0, 0]];
566 for a in 0..ngam {
567 for b in 0..ngam {
568 info2[[a, b]] += info[[a + 1, b + 1]] - info[[a + 1, 0]] * info[[0, b + 1]] / i00;
569 }
570 }
571
572 let z: Array1<f64> = (0..narrays).map(|j| d[j] / s2 - h1[j]).collect();
574 let dl = z2.t().dot(&z);
575 let step = solve_linear(&info2, &dl);
576 gam = &gam + &step;
577 aw = z2.dot(&gam).mapv(|x| (-x).exp());
578 }
579
580 aw
581}
582
583pub fn array_weights_quick(y: &Array2<f64>, fit: &MArrayLM) -> Array1<f64> {
596 let design = fit
597 .design
598 .as_ref()
599 .expect("arrayWeightsQuick requires a design in the fit");
600 let narrays = design.nrows();
601 let ngenes = y.nrows();
602
603 let fitted = fit.coefficients.dot(&design.t());
604 let (q, _r) = qr_econ(design);
606 let h: Vec<f64> = (0..narrays)
607 .map(|j| q.row(j).iter().map(|&v| v * v).sum::<f64>())
608 .collect();
609
610 let mut w = Array1::<f64>::zeros(narrays);
611 for j in 0..narrays {
612 let denom_j = 1.0 - h[j];
613 let mut sum = 0.0;
614 let mut cnt = 0usize;
615 for i in 0..ngenes {
616 let e = y[[i, j]] - fitted[[i, j]];
617 let s2 = fit.sigma[i] * fit.sigma[i];
618 let ratio = e * e / (s2 * denom_j);
619 if !ratio.is_nan() {
620 sum += ratio;
621 cnt += 1;
622 }
623 }
624 w[j] = cnt as f64 / sum;
625 }
626 w
627}
628
#[cfg(test)]
// The expected values below are pasted at full double precision (test names
// indicate they come from R), which triggers clippy's `excessive_precision` lint.
#[allow(clippy::excessive_precision)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Deterministic 12-gene x 6-array fixture shared by the gene-by-gene and
    /// prior-weights REML tests.
    ///
    /// Returns `(expression, design, prior_weights)`. The noise amplitude grows
    /// with the array index through `scale`. The design is an intercept plus a
    /// group indicator: arrays 0-2 are group 0 and arrays 3-5 are group 1.
    fn gbg_fixture() -> (Array2<f64>, Array2<f64>, Array2<f64>) {
        let ngenes = 12usize;
        let narrays = 6usize;
        let grp = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let scale = [0.5, 0.7, 1.0, 1.3, 1.6, 2.0];
        let mut e = Array2::<f64>::zeros((ngenes, narrays));
        let mut wt = Array2::<f64>::zeros((ngenes, narrays));
        for g in 0..ngenes {
            let gi = g as i64;
            for j in 0..narrays {
                let ji = j as i64;
                // Pseudo-random integer pattern in [-5, 5], scaled per array.
                let noise = (((gi * 7 + ji * 5) % 11) - 5) as f64 * 0.1 * scale[j];
                e[[g, j]] = 5.0 + (gi % 5) as f64 * 0.5 + grp[j] * 0.3 + noise;
                // Prior weights in [0.5, 1.4].
                wt[[g, j]] = 0.5 + ((gi * 3 + ji * 2) % 7) as f64 * 0.15;
            }
        }
        let mut design = Array2::<f64>::zeros((narrays, 2));
        for j in 0..narrays {
            design[[j, 0]] = 1.0;
            design[[j, 1]] = grp[j];
        }
        (e, design, wt)
    }

    /// Gene-by-gene weights: no prior weights, with prior weights, and with
    /// missing (NaN) observations.
    #[test]
    fn array_weights_gene_by_gene_matches_r() {
        let (e, design, wt) = gbg_fixture();

        let aw = array_weights_gene_by_gene(&e, &design, None, None, 10.0);
        let want = [
            1.5261488797077414,
            1.200903015330324,
            1.2231111922725117,
            1.2216065418856867,
            0.42436321655995812,
            0.86051835803628207,
        ];
        for (g, x) in aw.iter().zip(want.iter()) {
            assert!((g - x).abs() < 1e-7, "no-weights: got {g}, want {x}");
        }

        let aww = array_weights_gene_by_gene(&e, &design, Some(&wt), None, 10.0);
        let want_w = [
            1.5273497496573842,
            1.1661798617674437,
            1.2204452025989252,
            1.2146290021966843,
            0.42565954641887849,
            0.8897574953263433,
        ];
        for (g, x) in aww.iter().zip(want_w.iter()) {
            assert!((g - x).abs() < 1e-7, "with-weights: got {g}, want {x}");
        }

        // Knock out one observation in three different genes to exercise the
        // finite-subset fitting path.
        let mut ena = e.clone();
        ena[[2, 4]] = f64::NAN;
        ena[[6, 0]] = f64::NAN;
        ena[[9, 5]] = f64::NAN;
        let awn = array_weights_gene_by_gene(&ena, &design, None, None, 10.0);
        let want_na = [
            1.4797107434846923,
            1.1254112099887159,
            1.1542363030943652,
            1.2058806596955525,
            0.47074236183939411,
            0.91649393378073563,
        ];
        for (g, x) in awn.iter().zip(want_na.iter()) {
            assert!((g - x).abs() < 1e-7, "NA: got {g}, want {x}");
        }
    }

    /// REML array weights on a fixed 12 x 6 expression matrix with a two-group design.
    #[test]
    fn array_weights_reml_matches_r() {
        let exprs = array![
            [4.871, 4.629, 4.697, 5.807, 4.798, 5.195],
            [6.356, 6.349, 6.764, 4.125, 3.125, 4.752],
            [4.298, 4.659, 4.508, 5.936, 4.075, 7.367],
            [8.896, 9.420, 8.915, 9.165, 9.466, 8.598],
            [6.563, 6.610, 6.813, 6.123, 6.155, 7.309],
            [4.443, 4.283, 3.851, 5.435, 5.304, 5.784],
            [7.247, 7.184, 7.620, 6.533, 7.878, 6.820],
            [7.456, 7.644, 8.368, 9.096, 7.422, 10.245],
            [7.229, 6.945, 6.986, 8.178, 7.445, 10.159],
            [5.378, 5.177, 4.919, 7.692, 6.023, 7.432],
            [8.748, 9.133, 9.280, 9.431, 10.394, 11.954],
            [6.697, 7.010, 6.719, 4.293, 3.114, 5.796],
        ];
        // Intercept + group indicator (arrays 0-2 vs 3-5).
        let design = array![
            [1.0, 0.0],
            [1.0, 0.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
        ];
        let want = [
            1.611164881845,
            1.659781018122,
            1.455189349487,
            1.160250145839,
            0.462418787784,
            0.478963710208,
        ];

        let w = array_weights(&exprs, &design, None, 10.0, 50, 1e-5);
        assert_eq!(w.len(), want.len());
        for (got, exp) in w.iter().zip(want.iter()) {
            assert!(
                (got - exp).abs() < 1e-6,
                "array weight mismatch: got {got}, want {exp}"
            );
        }
    }

    /// REML with prior weights, for a two-coefficient and an intercept-only design.
    #[test]
    fn array_weights_prwts_reml_matches_r() {
        let (e, design, wt) = gbg_fixture();

        let w = array_weights_prwts_reml(&e, &design, &wt, None, 10.0, 50, 1e-5);
        let want = [
            1.5867284550700409,
            1.1019127664080206,
            1.2065279057943283,
            1.4157105291721905,
            0.29591187727091722,
            1.1315557212822311,
        ];
        assert_eq!(w.len(), want.len());
        for (got, exp) in w.iter().zip(want.iter()) {
            assert!((got - exp).abs() < 1e-6, "p=2: got {got}, want {exp}");
        }

        let narrays = e.ncols();
        // Intercept-only design (p = 1).
        let design1 = Array2::<f64>::ones((narrays, 1));
        let w1 = array_weights_prwts_reml(&e, &design1, &wt, None, 10.0, 50, 1e-5);
        let want1 = [
            1.3749226879021179,
            1.4906361441733864,
            1.0601401145509048,
            1.0761314591200781,
            0.62056380371792619,
            0.68918376453280328,
        ];
        for (got, exp) in w1.iter().zip(want1.iter()) {
            assert!((got - exp).abs() < 1e-6, "p=1: got {got}, want {exp}");
        }
    }

    /// Quick weights from an `lmfit` result on a 5-gene x 4-array matrix.
    #[test]
    fn array_weights_quick_matches_r() {
        let y = array![
            [-0.59, 0.01, 0.79, 0.36],
            [0.03, -0.19, -0.23, 0.65],
            [-1.52, -0.77, 1.67, 1.81],
            [-1.36, -0.22, 0.50, 1.22],
            [1.18, -0.98, 0.16, 0.91],
        ];
        let design = array![[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [1.0, 1.0]];
        let fit = crate::fit::lmfit(
            &y,
            &design,
            (0..5).map(|i| i.to_string()).collect(),
            vec!["Int".into(), "grp".into()],
        )
        .unwrap();
        let w = array_weights_quick(&y, &fit);
        let want = [
            0.759166824278653,
            0.759166824278652,
            1.46462963844169,
            1.46462963844169,
        ];
        assert_eq!(w.len(), want.len());
        for (got, exp) in w.iter().zip(want.iter()) {
            assert!((got - exp).abs() < 1e-9, "got {got}, want {exp}");
        }
    }

    /// `contr_sum(4)` has shape 4 x 3, an identity block, a -1 last row and zero column sums.
    #[test]
    fn contr_sum_shape() {
        let c = contr_sum(4);
        assert_eq!(c.dim(), (4, 3));
        assert_eq!(c[[0, 0]], 1.0);
        assert_eq!(c[[1, 1]], 1.0);
        assert_eq!(c[[3, 0]], -1.0);
        assert_eq!(c[[3, 2]], -1.0);
        assert_eq!(c[[0, 1]], 0.0);
        for j in 0..3 {
            // Every contrast column sums to zero.
            let s: f64 = (0..4).map(|i| c[[i, j]]).sum();
            assert!(s.abs() < 1e-15);
        }
    }
}