Skip to main content

math_audio_optimisation/
stack_linear_penalty.rs

1use ndarray::{Array1, Array2};
2
3use crate::LinearPenalty;
4
5pub(crate) fn stack_linear_penalty(dst: &mut LinearPenalty, src: &LinearPenalty) {
6    // Vertically stack A, lb, ub; pick max weight to enforce strongest among merged
7    let a_dst = dst.a.clone();
8    let a_src = src.a.clone();
9    let rows = a_dst.nrows() + a_src.nrows();
10    let cols = a_dst.ncols();
11    assert_eq!(
12        cols,
13        a_src.ncols(),
14        "LinearPenalty A width mismatch while stacking"
15    );
16    let mut a_new = Array2::<f64>::zeros((rows, cols));
17    // copy
18    for i in 0..a_dst.nrows() {
19        for j in 0..cols {
20            a_new[(i, j)] = a_dst[(i, j)];
21        }
22    }
23    for i in 0..a_src.nrows() {
24        for j in 0..cols {
25            a_new[(a_dst.nrows() + i, j)] = a_src[(i, j)];
26        }
27    }
28    let mut lb_new = Array1::<f64>::zeros(rows);
29    let mut ub_new = Array1::<f64>::zeros(rows);
30    for i in 0..a_dst.nrows() {
31        lb_new[i] = dst.lb[i];
32        ub_new[i] = dst.ub[i];
33    }
34    for i in 0..a_src.nrows() {
35        lb_new[a_dst.nrows() + i] = src.lb[i];
36        ub_new[a_dst.nrows() + i] = src.ub[i];
37    }
38    dst.a = a_new;
39    dst.lb = lb_new;
40    dst.ub = ub_new;
41    dst.weight = dst.weight.max(src.weight);
42}
43
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array1, Array2};

    #[test]
    fn stacks_rows_and_uses_stronger_weight() {
        let mut dst = LinearPenalty {
            a: array![[1.0, 0.0]],
            lb: array![0.0],
            ub: array![1.0],
            weight: 10.0,
        };
        let src = LinearPenalty {
            a: array![[0.0, 1.0], [1.0, 1.0]],
            lb: array![-1.0, 0.5],
            ub: array![2.0, 3.0],
            weight: 20.0,
        };

        stack_linear_penalty(&mut dst, &src);

        assert_eq!(dst.a, array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]);
        assert_eq!(dst.lb, array![0.0, -1.0, 0.5]);
        assert_eq!(dst.ub, array![1.0, 2.0, 3.0]);
        assert_eq!(dst.weight, 20.0);
    }

    #[test]
    fn keeps_dst_weight_when_it_is_stronger() {
        let mut dst = LinearPenalty {
            a: array![[1.0, 0.0]],
            lb: array![0.0],
            ub: array![1.0],
            weight: 30.0,
        };
        let src = LinearPenalty {
            a: array![[0.0, 1.0]],
            lb: array![-1.0],
            ub: array![2.0],
            weight: 20.0,
        };

        stack_linear_penalty(&mut dst, &src);

        assert_eq!(dst.a, array![[1.0, 0.0], [0.0, 1.0]]);
        assert_eq!(dst.lb, array![0.0, -1.0]);
        assert_eq!(dst.ub, array![1.0, 2.0]);
        assert_eq!(dst.weight, 30.0);
    }

    #[test]
    fn stacking_empty_src_only_merges_weight() {
        let mut dst = LinearPenalty {
            a: array![[1.0, 0.0]],
            lb: array![0.0],
            ub: array![1.0],
            weight: 10.0,
        };
        let src = LinearPenalty {
            a: Array2::zeros((0, 2)),
            lb: Array1::zeros(0),
            ub: Array1::zeros(0),
            weight: 5.0,
        };

        stack_linear_penalty(&mut dst, &src);

        assert_eq!(dst.a, array![[1.0, 0.0]]);
        assert_eq!(dst.lb, array![0.0]);
        assert_eq!(dst.ub, array![1.0]);
        assert_eq!(dst.weight, 10.0);
    }

    #[test]
    fn stacking_onto_empty_dst_copies_src() {
        let mut dst = LinearPenalty {
            a: Array2::zeros((0, 2)),
            lb: Array1::zeros(0),
            ub: Array1::zeros(0),
            weight: 0.0,
        };
        let src = LinearPenalty {
            a: array![[0.0, 1.0], [1.0, 1.0]],
            lb: array![-1.0, 0.5],
            ub: array![2.0, 3.0],
            weight: 20.0,
        };

        stack_linear_penalty(&mut dst, &src);

        assert_eq!(dst.a, array![[0.0, 1.0], [1.0, 1.0]]);
        assert_eq!(dst.lb, array![-1.0, 0.5]);
        assert_eq!(dst.ub, array![2.0, 3.0]);
        assert_eq!(dst.weight, 20.0);
    }

    #[test]
    #[should_panic(expected = "LinearPenalty A width mismatch while stacking")]
    fn rejects_width_mismatch() {
        let mut dst = LinearPenalty {
            a: array![[1.0, 0.0]],
            lb: array![0.0],
            ub: array![1.0],
            weight: 10.0,
        };
        let src = LinearPenalty {
            a: array![[1.0, 0.0, 1.0]],
            lb: array![0.0],
            ub: array![1.0],
            weight: 10.0,
        };

        stack_linear_penalty(&mut dst, &src);
    }
}