math_audio_optimisation/
use ndarray::{concatenate, Array1, Array2, Axis};

use crate::LinearPenalty;

5pub(crate) fn stack_linear_penalty(dst: &mut LinearPenalty, src: &LinearPenalty) {
6 let a_dst = dst.a.clone();
8 let a_src = src.a.clone();
9 let rows = a_dst.nrows() + a_src.nrows();
10 let cols = a_dst.ncols();
11 assert_eq!(
12 cols,
13 a_src.ncols(),
14 "LinearPenalty A width mismatch while stacking"
15 );
16 let mut a_new = Array2::<f64>::zeros((rows, cols));
17 for i in 0..a_dst.nrows() {
19 for j in 0..cols {
20 a_new[(i, j)] = a_dst[(i, j)];
21 }
22 }
23 for i in 0..a_src.nrows() {
24 for j in 0..cols {
25 a_new[(a_dst.nrows() + i, j)] = a_src[(i, j)];
26 }
27 }
28 let mut lb_new = Array1::<f64>::zeros(rows);
29 let mut ub_new = Array1::<f64>::zeros(rows);
30 for i in 0..a_dst.nrows() {
31 lb_new[i] = dst.lb[i];
32 ub_new[i] = dst.ub[i];
33 }
34 for i in 0..a_src.nrows() {
35 lb_new[a_dst.nrows() + i] = src.lb[i];
36 ub_new[a_dst.nrows() + i] = src.ub[i];
37 }
38 dst.a = a_new;
39 dst.lb = lb_new;
40 dst.ub = ub_new;
41 dst.weight = dst.weight.max(src.weight);
42}
43
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array1, Array2};

    /// Assembles a `LinearPenalty` from its parts so the fixtures stay compact.
    fn penalty(a: Array2<f64>, lb: Array1<f64>, ub: Array1<f64>, weight: f64) -> LinearPenalty {
        LinearPenalty { a, lb, ub, weight }
    }

    /// Rows of `src` land below those of `dst`, bounds follow the same order,
    /// and the merged weight is the larger of the two inputs.
    #[test]
    fn stacks_rows_and_uses_stronger_weight() {
        let mut dst = penalty(array![[1.0, 0.0]], array![0.0], array![1.0], 10.0);
        let src = penalty(
            array![[0.0, 1.0], [1.0, 1.0]],
            array![-1.0, 0.5],
            array![2.0, 3.0],
            20.0,
        );

        stack_linear_penalty(&mut dst, &src);

        let expected_a = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let expected_lb = array![0.0, -1.0, 0.5];
        let expected_ub = array![1.0, 2.0, 3.0];
        assert_eq!(dst.a, expected_a);
        assert_eq!(dst.lb, expected_lb);
        assert_eq!(dst.ub, expected_ub);
        assert_eq!(dst.weight, 20.0);
    }

    /// A two-column destination cannot absorb a three-column source.
    #[test]
    #[should_panic(expected = "LinearPenalty A width mismatch while stacking")]
    fn rejects_width_mismatch() {
        let mut dst = penalty(array![[1.0, 0.0]], array![0.0], array![1.0], 10.0);
        let src = penalty(array![[1.0, 0.0, 1.0]], array![0.0], array![1.0], 10.0);

        stack_linear_penalty(&mut dst, &src);
    }
}