ellalgo_rs/
ell.rs

1// mod lib;
2use crate::cutting_plane::{CutStatus, SearchSpace, SearchSpaceQ, UpdateByCutChoice};
3use crate::ell_calc::EllCalc;
4// #[macro_use]
5// extern crate ndarray;
6// use ndarray::prelude::*;
7use ndarray::Array1;
8use ndarray::Array2;
9
10/// The code defines a struct called "Ell" that represents an ellipsoid search space in the Ellipsoid
11/// method.
12///
13///   Ell = {x | (x - xc)^T mq^-1 (x - xc) \le \kappa}
14///
15/// Properties:
16///
17/// * `no_defer_trick`: A boolean flag indicating whether the defer trick should be used. The defer
18///          trick is a technique used in the Ellipsoid method to improve efficiency by deferring the update of
19///          the ellipsoid until a certain condition is met.
20/// * `mq`: A matrix representing the shape of the ellipsoid. It is a 2-dimensional array of f64 values.
21/// * `xc`: The `xc` property represents the center of the ellipsoid search space. It is a 1-dimensional
22///          array of floating-point numbers.
23/// * `kappa`: A scalar value that determines the size of the ellipsoid. A larger value of kappa results
24///          in a larger ellipsoid.
25/// * `ndim`: The `ndim` property represents the number of dimensions of the ellipsoid search space.
26/// * `helper`: The `helper` property is an instance of the `EllCalc` struct, which is used to perform
27///          calculations related to the ellipsoid search space. It provides methods for calculating the distance
28///          constant (`dc`), the center constant (`cc`), and the quadratic constant (`q`) used in the ell
29/// * `tsq`: The `tsq` property represents the squared Mahalanobis distance threshold of the ellipsoid.
30///          It is used to determine whether a point is inside or outside the ellipsoid.
31#[derive(Debug, Clone)]
32pub struct Ell {
33    pub no_defer_trick: bool,
34    pub mq: Array2<f64>,
35    pub xc: Array1<f64>,
36    pub kappa: f64,
37    ndim: usize,
38    helper: EllCalc,
39    pub tsq: f64,
40}
41
42impl Ell {
43    /// The function `new_with_matrix` constructs a new `Ell` object with the given parameters.
44    ///
45    /// Arguments:
46    ///
47    /// * `kappa`: The `kappa` parameter is a floating-point number that represents the curvature of the
48    ///            ellipse. It determines the shape of the ellipse, with higher values resulting in a more elongated
49    ///            shape and lower values resulting in a more circular shape.
50    /// * `mq`: The `mq` parameter is of type `Array2<f64>`, which represents a 2-dimensional array of `f64`
51    ///            (floating-point) values. It is used to store the matrix `mq` in the `Ell` object.
52    /// * `xc`: The parameter `xc` represents the center of the ellipsoid in n-dimensional space. It is an
53    ///            array of length `ndim`, where each element represents the coordinate of the center along a specific
54    ///            dimension.
55    ///
56    /// Returns:
57    ///
58    /// an instance of the `Ell` struct.
59    pub fn new_with_matrix(kappa: f64, mq: Array2<f64>, xc: Array1<f64>) -> Ell {
60        let ndim = xc.len();
61        let helper = EllCalc::new(ndim);
62
63        Ell {
64            kappa,
65            mq,
66            xc,
67            ndim,
68            helper,
69            no_defer_trick: false,
70            tsq: 0.0,
71        }
72    }
73
74    /// Creates a new [`Ell`].
75    ///
76    /// The function `new` creates a new `Ell` object with the given values.
77    ///
78    /// Arguments:
79    ///
80    /// * `val`: An array of f64 values representing the diagonal elements of a matrix.
81    /// * `xc`: `xc` is an `Array1<f64>` which represents the center of the ellipse. It contains the x and y
82    ///         coordinates of the center point.
83    ///
84    /// Returns:
85    ///
86    /// The function `new` returns an instance of the [`Ell`] struct.
87    ///
88    /// # Examples
89    ///
90    /// ```
91    /// use ellalgo_rs::ell::Ell;
92    /// use ndarray::arr1;
93    /// let val = arr1(&[1.0, 1.0]);
94    /// let xc = arr1(&[0.0, 0.0]);
95    /// let ellip = Ell::new(val, xc);
96    /// assert_eq!(ellip.kappa, 1.0);
97    /// assert_eq!(ellip.mq.shape(), &[2, 2]);
98    /// ```
99    pub fn new(val: Array1<f64>, xc: Array1<f64>) -> Ell {
100        Ell::new_with_matrix(1.0, Array2::from_diag(&val), xc)
101    }
102
103    /// The function `new_with_scalar` constructs a new [`Ell`] object with a scalar value and a vector.
104    ///
105    /// Arguments:
106    ///
107    /// * `val`: The `val` parameter is a scalar value of type `f64`. It represents the value of the scalar
108    ///          component of the `Ell` object.
109    /// * `xc`: The parameter `xc` is an array of type `Array1<f64>`. It represents the center coordinates
110    ///          of the ellipse.
111    ///
112    /// Returns:
113    ///
114    /// an instance of the [`Ell`] struct.
115    ///
116    /// # Examples
117    ///
118    /// ```
119    /// use ellalgo_rs::ell::Ell;
120    /// use ndarray::arr1;
121    /// let val = 1.0;
122    /// let xc = arr1(&[0.0, 0.0]);
123    /// let ellip = Ell::new_with_scalar(val, xc);
124    /// assert_eq!(ellip.kappa, 1.0);
125    /// assert_eq!(ellip.mq.shape(), &[2, 2]);
126    /// assert_eq!(ellip.xc.shape(), &[2]);
127    /// assert_eq!(ellip.xc[0], 0.0);
128    /// assert_eq!(ellip.xc[1], 0.0);
129    /// assert_eq!(ellip.tsq, 0.0);
130    /// ```
131    pub fn new_with_scalar(val: f64, xc: Array1<f64>) -> Ell {
132        Ell::new_with_matrix(val, Array2::eye(xc.len()), xc)
133    }
134
135    /// Update ellipsoid core function using the cut
136    ///
137    ///  $grad^T * (x - xc) + beta <= 0$
138    ///
139    /// The `update_core` function in Rust updates the ellipsoid core based on a given gradient and beta
140    /// value using a cut strategy.
141    ///
142    /// Arguments:
143    ///
144    /// * `grad`: A reference to an Array1<f64> representing the gradient vector.
145    /// * `beta`: The `beta` parameter is a value that is used in the inequality constraint of the ellipsoid
146    ///            core function. It represents the threshold for the constraint, and the function checks if the dot
147    ///            product of the gradient and the difference between `x` and `xc` plus `beta` is less than or
148    /// * `cut_strategy`: The `cut_strategy` parameter is a closure that takes two arguments: `beta` and
149    ///            `tsq`. It returns a tuple containing a `CutStatus` and a tuple `(rho, sigma, delta)`. The
150    ///            `cut_strategy` function is used to determine the values of `rho`, `
151    ///
152    /// Returns:
153    ///
154    /// a value of type `CutStatus`.
155    fn update_core<T, F>(&mut self, grad: &Array1<f64>, beta: &T, cut_strategy: F) -> CutStatus
156    where
157        T: UpdateByCutChoice<Self, ArrayType = Array1<f64>>,
158        F: FnOnce(&T, f64) -> (CutStatus, (f64, f64, f64)),
159    {
160        let grad_t = self.mq.dot(grad);
161        let omega = grad.dot(&grad_t);
162
163        self.tsq = self.kappa * omega;
164        // let status = self.helper.calc_bias_cut(*beta);
165        let (status, (rho, sigma, delta)) = cut_strategy(beta, self.tsq);
166        if status != CutStatus::Success {
167            return status;
168        }
169
170        self.xc -= &((rho / omega) * &grad_t); // n
171
172        // n*(n+1)/2 + n
173        // self.mq -= (self.sigma / omega) * xt::linalg::outer(grad_t, grad_t);
174        let r = sigma / omega;
175        for i in 0..self.ndim {
176            let r_grad_t = r * grad_t[i];
177            for j in 0..i {
178                self.mq[[i, j]] -= r_grad_t * grad_t[j];
179                self.mq[[j, i]] = self.mq[[i, j]];
180            }
181            self.mq[[i, i]] -= r_grad_t * grad_t[i];
182        }
183
184        self.kappa *= delta;
185
186        if self.no_defer_trick {
187            self.mq *= self.kappa;
188            self.kappa = 1.0;
189        }
190        status
191    }
192}
193
194/// The `impl SearchSpace for Ell` block is implementing the `SearchSpace` trait for the `Ell` struct.
195impl SearchSpace for Ell {
196    type ArrayType = Array1<f64>;
197
198    /// The function `xc` returns a copy of the `xc` array.
199    #[inline]
200    fn xc(&self) -> Self::ArrayType {
201        self.xc.clone()
202    }
203
204    /// The `tsq` function returns the value of the `tsq` field of the struct.
205    ///
206    /// Returns:
207    ///
208    /// The method `tsq` is returning a value of type `f64`.
209    #[inline]
210    fn tsq(&self) -> f64 {
211        self.tsq
212    }
213
214    /// The `update_bias_cut` function updates the decision variable based on the given cut.
215    ///
216    /// Arguments:
217    ///
218    /// * `cut`: A tuple containing two elements:
219    ///
220    /// Returns:
221    ///
222    /// The `update_bias_cut` function returns a value of type `CutStatus`.
223    fn update_bias_cut<T>(&mut self, cut: &(Self::ArrayType, T)) -> CutStatus
224    where
225        T: UpdateByCutChoice<Self, ArrayType = Self::ArrayType>,
226    {
227        let (grad, beta) = cut;
228        beta.update_bias_cut_by(self, grad)
229    }
230
231    /// The `update_central_cut` function updates the cut choices using the gradient and beta values.
232    ///
233    /// Arguments:
234    ///
235    /// * `cut`: The `cut` parameter is a tuple containing two elements. The first element is of type
236    ///          `Self::ArrayType`, and the second element is of type `T`.
237    ///
238    /// Returns:
239    ///
240    /// The function `update_central_cut` returns a value of type `CutStatus`.
241    fn update_central_cut<T>(&mut self, cut: &(Self::ArrayType, T)) -> CutStatus
242    where
243        T: UpdateByCutChoice<Self, ArrayType = Self::ArrayType>,
244    {
245        let (grad, beta) = cut;
246        beta.update_central_cut_by(self, grad)
247    }
248
249    fn set_xc(&mut self, x: Self::ArrayType) {
250        self.xc = x;
251    }
252}
253
254impl SearchSpaceQ for Ell {
255    type ArrayType = Array1<f64>;
256
257    /// The function `xc` returns a copy of the `xc` array.
258    #[inline]
259    fn xc(&self) -> Self::ArrayType {
260        self.xc.clone()
261    }
262
263    /// The `tsq` function returns the value of the `tsq` field of the struct.
264    ///
265    /// Returns:
266    ///
267    /// The method `tsq` is returning a value of type `f64`.
268    #[inline]
269    fn tsq(&self) -> f64 {
270        self.tsq
271    }
272
273    /// The `update_q` function updates the decision variable based on the given cut.
274    ///
275    /// Arguments:
276    ///
277    /// * `cut`: A tuple containing two elements:
278    ///
279    /// Returns:
280    ///
281    /// The `update_bias_cut` function returns a value of type `CutStatus`.
282    fn update_q<T>(&mut self, cut: &(Self::ArrayType, T)) -> CutStatus
283    where
284        T: UpdateByCutChoice<Self, ArrayType = Self::ArrayType>,
285    {
286        let (grad, beta) = cut;
287        beta.update_q_by(self, grad)
288    }
289}
290
291impl UpdateByCutChoice<Ell> for f64 {
292    type ArrayType = Array1<f64>;
293
294    fn update_bias_cut_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
295        let beta = self;
296        let helper = ellip.helper.clone();
297        ellip.update_core(grad, beta, |beta, tsq| helper.calc_bias_cut(*beta, tsq))
298    }
299
300    fn update_central_cut_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
301        let beta = self;
302        let helper = ellip.helper.clone();
303        ellip.update_core(grad, beta, |_beta, tsq| helper.calc_central_cut(tsq))
304    }
305
306    fn update_q_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
307        let beta = self;
308        let helper = ellip.helper.clone();
309        ellip.update_core(grad, beta, |beta, tsq| helper.calc_bias_cut_q(*beta, tsq))
310    }
311}
312
313impl UpdateByCutChoice<Ell> for (f64, Option<f64>) {
314    type ArrayType = Array1<f64>;
315
316    fn update_bias_cut_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
317        let beta = self;
318        let helper = ellip.helper.clone();
319        ellip.update_core(grad, beta, |beta, tsq| {
320            helper.calc_single_or_parallel_bias_cut(beta, tsq)
321        })
322    }
323
324    fn update_central_cut_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
325        let beta = self;
326        let helper = ellip.helper.clone();
327        ellip.update_core(grad, beta, |beta, tsq| {
328            helper.calc_single_or_parallel_central_cut(beta, tsq)
329        })
330    }
331
332    fn update_q_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
333        let beta = self;
334        let helper = ellip.helper.clone();
335        ellip.update_core(grad, beta, |beta, tsq| {
336            helper.calc_single_or_parallel_q(beta, tsq)
337        })
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use approx_eq::assert_approx_eq;

    /// Four-dimensional ellipsoid centered at the origin with identity shape and `kappa = 0.01`.
    fn small_ball() -> Ell {
        Ell::new_with_scalar(0.01, Array1::zeros(4))
    }

    /// Gradient `[0.5, 0.5, 0.5, 0.5]` shared by every cut in these tests.
    fn half_ones() -> Array1<f64> {
        0.5 * Array1::<f64>::ones(4)
    }

    // A freshly built ellipsoid keeps the values handed to the constructor.
    #[test]
    fn test_construct() {
        let ball = small_ball();
        assert!(!ball.no_defer_trick);
        assert_approx_eq!(ball.kappa, 0.01);
        assert_eq!(ball.mq, Array2::eye(4));
        assert_eq!(ball.xc, Array1::zeros(4));
        assert_approx_eq!(ball.tsq, 0.0);
    }

    // Central cut with a scalar offset.
    #[test]
    fn test_update_central_cut() {
        let mut ball = small_ball();
        let cut = (half_ones(), 0.0);
        let outcome = ball.update_central_cut(&cut);
        assert_eq!(outcome, CutStatus::Success);
        assert_eq!(ball.xc, -0.01 * Array1::ones(4));
        assert_eq!(ball.mq, Array2::eye(4) - 0.1 * Array2::ones((4, 4)));
        assert_approx_eq!(ball.kappa, 0.16 / 15.0);
        assert_approx_eq!(ball.tsq, 0.01);
    }

    // Bias cut with a scalar offset.
    #[test]
    fn test_update_bias_cut() {
        let mut ball = small_ball();
        let cut = (half_ones(), 0.05);
        let outcome = ball.update_bias_cut(&cut);
        assert_eq!(outcome, CutStatus::Success);
        assert_approx_eq!(ball.xc[0], -0.03);
        assert_approx_eq!(ball.mq[(0, 0)], 0.8);
        assert_approx_eq!(ball.kappa, 0.008);
        assert_approx_eq!(ball.tsq, 0.01);
    }

    // Central cut with an offset pair.
    #[test]
    fn test_update_parallel_central_cut() {
        let mut ball = small_ball();
        let cut = (half_ones(), (0.0, Some(0.05)));
        let outcome = ball.update_central_cut(&cut);
        assert_eq!(outcome, CutStatus::Success);
        assert_eq!(ball.xc, -0.01 * Array1::ones(4));
        assert_eq!(ball.mq, Array2::eye(4) - 0.2 * Array2::ones((4, 4)));
        assert_approx_eq!(ball.kappa, 0.012);
        assert_approx_eq!(ball.tsq, 0.01);
    }

    // Bias cut with an offset pair.
    #[test]
    fn test_update_parallel() {
        let mut ball = small_ball();
        let cut = (half_ones(), (0.01, Some(0.04)));
        let outcome = ball.update_bias_cut(&cut);
        assert_eq!(outcome, CutStatus::Success);
        assert_approx_eq!(ball.xc[0], -0.0116);
        assert_approx_eq!(ball.mq[(0, 0)], 1.0 - 0.232);
        assert_approx_eq!(ball.kappa, 0.01232);
        assert_approx_eq!(ball.tsq, 0.01);
    }

    // An offset pair that leaves center, shape and scale unchanged via the bias-cut path.
    #[test]
    fn test_update_parallel_no_effect() {
        let mut ball = small_ball();
        let cut = (half_ones(), (-0.04, Some(0.0625)));
        let outcome = ball.update_bias_cut(&cut);
        assert_eq!(outcome, CutStatus::Success);
        assert_eq!(ball.xc, Array1::zeros(4));
        assert_eq!(ball.mq, Array2::eye(4));
        assert_approx_eq!(ball.kappa, 0.01);
    }

    // The same offset pair through `update_q` reports `NoEffect` and changes nothing.
    #[test]
    fn test_update_q_no_effect() {
        let mut ball = small_ball();
        let cut = (half_ones(), (-0.04, Some(0.0625)));
        let outcome = ball.update_q(&cut);
        assert_eq!(outcome, CutStatus::NoEffect);
        assert_eq!(ball.xc, Array1::zeros(4));
        assert_eq!(ball.mq, Array2::eye(4));
        assert_approx_eq!(ball.kappa, 0.01);
    }
}