ellalgo_rs/oracles/lmi_oracle.rs

1use super::ldlt_mgr::LDLTMgr;
2use crate::cutting_plane::OracleFeas;
3use ndarray::{Array1, Array2};
4
5pub type Arr = Array1<f64>;
6pub type Cut = (Arr, f64);
7
8/// The `LMIOracle` struct represents an oracle for a Linear Matrix Inequality (LMI) constraint.
9/// It contains the necessary data to evaluate the LMI constraint, including the matrix `mat_f`,
10/// the matrix `mat_f0`, and an `LDLTMgr` instance for managing the Cholesky decomposition.
11/// This oracle can be used to check the feasibility of a given point with respect to the LMI constraint.
12pub struct LMIOracle {
13    mat_f: Vec<Array2<f64>>,
14    mat_f0: Array2<f64>,
15    ldlt_mgr: LDLTMgr,
16}
17
18impl LMIOracle {
19    /// This function initializes a new LMIOracle struct with given matrices and an LDLTMgr instance.
20    ///
21    /// Arguments:
22    ///
23    /// * `mat_f`: The `mat_f` parameter is a vector of 2D arrays of type `f64`.
24    /// * `mat_b`: The `mat_b` parameter is an `Array2<f64>` type, which represents a 2-dimensional array
25    ///            of f64 (floating point numbers).
26    ///
27    /// Returns:
28    ///
29    /// An instance of the `LMIOracle` struct is being returned.
30    pub fn new(mat_f: Vec<Array2<f64>>, mat_b: Array2<f64>) -> Self {
31        let ldlt_mgr = LDLTMgr::new(mat_b.nrows());
32        LMIOracle {
33            mat_f,
34            mat_f0: mat_b,
35            ldlt_mgr,
36        }
37    }
38}
39
40impl OracleFeas<Arr> for LMIOracle {
41    type CutChoice = f64; // single cut
42
43    /// The function assesses the feasibility of a solution by calculating the difference between
44    /// elements of matrices based on input arrays.
45    ///
46    /// Arguments:
47    ///
48    /// * `mat_f0`: `mat_f0` is a reference to a 2D array of `f64` values.
49    /// * `mat_f`: The `mat_f` parameter in the `get_elem` function is a slice of `Array2<f64>` types.
50    ///            It represents an array of 2D matrices. Each element in the slice is a 2D matrix of f64 values.
51    /// * `xc`: The `xc` parameter in the `assess_feas` function is a reference to an `Array1<f64>`, which
52    ///         represents a one-dimensional array of floating-point numbers. This array is used as input to the
53    ///         function for some calculations related to feasibility assessment.
54    fn assess_feas(&mut self, xc: &Array1<f64>) -> Option<Cut> {
55        fn get_elem(
56            mat_f0: &Array2<f64>,
57            mat_f: &[Array2<f64>],
58            xc: &Array1<f64>,
59            i: usize,
60            j: usize,
61        ) -> f64 {
62            mat_f0[(i, j)]
63                - mat_f
64                    .iter()
65                    .zip(xc.iter())
66                    .map(|(mat_fk, xk)| mat_fk[(i, j)] * xk)
67                    .sum::<f64>()
68        }
69
70        let get_elem = |i: usize, j: usize| get_elem(&self.mat_f0, &self.mat_f, xc, i, j);
71
72        if self.ldlt_mgr.factor(get_elem) {
73            None
74        } else {
75            let ep = self.ldlt_mgr.witness();
76            let g = self
77                .mat_f
78                .iter()
79                .map(|mat_fk| self.ldlt_mgr.sym_quad(mat_fk))
80                .collect();
81            Some((g, ep))
82        }
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cutting_plane::{cutting_plane_optim, Options, OracleOptim};
    use crate::ell::Ell;
    use ndarray::{array, Array2, ShapeError};

    /// Test problem: minimize `c·x` subject to two independent LMI constraints.
    struct MyOracle {
        c: Array1<f64>,
        lmi1: LMIOracle,
        lmi2: LMIOracle,
    }

    impl OracleOptim<Arr> for MyOracle {
        type CutChoice = f64; // single cut

        fn assess_optim(&mut self, xc: &Arr, gamma: &mut f64) -> ((Arr, f64), bool) {
            // Feasibility cuts take priority; the second LMI is only consulted
            // when the first one is satisfied.
            let feasibility_cut = match self.lmi1.assess_feas(xc) {
                Some(cut) => Some(cut),
                None => self.lmi2.assess_feas(xc),
            };
            if let Some(cut) = feasibility_cut {
                return (cut, false);
            }

            let objective = self.c.dot(xc);
            let excess = objective - *gamma;
            if excess > 0.0 {
                // Feasible, but no better than the incumbent: objective cut.
                ((self.c.clone(), excess), false)
            } else {
                // New best value found: record it and emit a central cut.
                *gamma = objective;
                ((self.c.clone(), 0.0), true)
            }
        }
    }

    /// Runs the cutting-plane optimizer on `MyOracle` from the origin in R^3 and
    /// returns the iteration count, asserting that a best point was found.
    fn run_lmi(oracle1: LMIOracle, oracle2: LMIOracle) -> usize {
        let mut omega = MyOracle {
            c: array![1.0, -1.0, 1.0],
            lmi1: oracle1,
            lmi2: oracle2,
        };
        let mut ellip = Ell::new_with_scalar(10.0, Arr::zeros(3));
        let mut gamma = f64::INFINITY;
        let options = Options::default();

        let (xbest, num_iters) = cutting_plane_optim(&mut omega, &mut ellip, &mut gamma, &options);
        assert!(xbest.is_some());
        num_iters
    }

    /// Builds an `n x n` matrix from row-major `data`.
    fn square_matrix(n: usize, data: Vec<f64>) -> Result<Array2<f64>, ShapeError> {
        Array2::from_shape_vec((n, n), data)
    }

    #[test]
    fn test_lmi() -> Result<(), ShapeError> {
        let f1 = vec![
            square_matrix(2, vec![-7.0, -11.0, -11.0, 3.0])?,
            square_matrix(2, vec![7.0, -18.0, -18.0, 8.0])?,
            square_matrix(2, vec![-2.0, -8.0, -8.0, 1.0])?,
        ];
        let b1 = square_matrix(2, vec![33.0, -9.0, -9.0, 26.0])?;

        let f2 = vec![
            square_matrix(3, vec![-21.0, -11.0, 0.0, -11.0, 10.0, 8.0, 0.0, 8.0, 5.0])?,
            square_matrix(3, vec![0.0, 10.0, 16.0, 10.0, -10.0, -10.0, 16.0, -10.0, 3.0])?,
            square_matrix(3, vec![-5.0, 2.0, -17.0, 2.0, -6.0, 8.0, -17.0, 8.0, 6.0])?,
        ];
        let b2 = square_matrix(3, vec![14.0, 9.0, 40.0, 9.0, 91.0, 10.0, 40.0, 10.0, 15.0])?;

        let num_iters = run_lmi(LMIOracle::new(f1, b1), LMIOracle::new(f2, b2));
        assert_eq!(num_iters, 281);
        Ok(())
    }
}