Skip to main content

gam_problem/
linear_constraints.rs

1use ndarray::{Array1, Array2};
2
3#[derive(Clone, Debug)]
4pub struct LinearInequalityConstraints {
5    pub a: Array2<f64>,
6    pub b: Array1<f64>,
7}
8
9impl LinearInequalityConstraints {
10    /// Construct with the equal-row-count invariant enforced. The dimensions
11    /// `a.nrows() == b.len()` are required by every downstream KKT / active-set
12    /// routine; routing every construction site through this constructor
13    /// eliminates a class of "rows out of sync" bugs at the type boundary.
14    #[inline]
15    pub fn new(a: Array2<f64>, b: Array1<f64>) -> Result<Self, String> {
16        if a.nrows() != b.len() {
17            return Err(format!(
18                "LinearInequalityConstraints: row count mismatch (A has {} rows, b has length {})",
19                a.nrows(),
20                b.len(),
21            ));
22        }
23        Ok(Self { a, b })
24    }
25
26    /// Build the per-coordinate `β_i >= lower_bounds[i]` inequality system.
27    /// Non-finite entries are treated as "no bound" and skipped; returns
28    /// `None` when every entry is non-finite so callers can short-circuit
29    /// the no-constraint case without allocating the empty A/b pair.
30    pub fn from_per_coordinate_lower_bounds(lower_bounds: &Array1<f64>) -> Option<Self> {
31        let active_rows: Vec<usize> = (0..lower_bounds.len())
32            .filter(|&i| lower_bounds[i].is_finite())
33            .collect();
34        if active_rows.is_empty() {
35            return None;
36        }
37        let p = lower_bounds.len();
38        let mut a = Array2::<f64>::zeros((active_rows.len(), p));
39        let mut b = Array1::<f64>::zeros(active_rows.len());
40        for (r, &idx) in active_rows.iter().enumerate() {
41            a[[r, idx]] = 1.0;
42            b[r] = lower_bounds[idx];
43        }
44        Some(Self { a, b })
45    }
46}
47
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array1, Array2};

    /// Shorthand for the lower-bound constructor under test.
    fn lower(bounds: &Array1<f64>) -> Option<LinearInequalityConstraints> {
        LinearInequalityConstraints::from_per_coordinate_lower_bounds(bounds)
    }

    #[test]
    fn new_ok_when_rows_match_b_len() {
        let result =
            LinearInequalityConstraints::new(Array2::<f64>::eye(3), Array1::<f64>::zeros(3));
        assert!(result.is_ok());
    }

    #[test]
    fn new_err_on_row_count_mismatch() {
        let result =
            LinearInequalityConstraints::new(Array2::<f64>::eye(3), Array1::<f64>::zeros(2));
        assert!(result.is_err());
    }

    #[test]
    fn from_lower_bounds_none_when_all_non_finite() {
        let non_finite = array![f64::NAN, f64::INFINITY, f64::NEG_INFINITY];
        assert!(lower(&non_finite).is_none());
    }

    #[test]
    fn from_lower_bounds_selects_finite_entries() {
        // Only coordinate 1 is bounded, giving the single row β₁ ≥ 1.0.
        let system = lower(&array![f64::NAN, 1.0_f64, f64::NAN]).unwrap();
        assert_eq!(system.a.dim(), (1, 3));
        assert_eq!(system.a[[0, 1]], 1.0);
        assert_eq!(system.b[0], 1.0);
    }

    #[test]
    fn from_lower_bounds_multiple_active_rows() {
        let system = lower(&array![0.5_f64, f64::NAN, -1.0]).unwrap();
        assert_eq!(system.a.nrows(), 2);
        // Expected (active column, bound) for each row, in coordinate order.
        let expected = [(0_usize, 0.5_f64), (2, -1.0)];
        for (row, &(col, bound)) in expected.iter().enumerate() {
            assert_eq!(system.a[[row, col]], 1.0);
            assert_eq!(system.b[row], bound);
        }
    }
}