ellalgo_rs/
ell.rs

1// mod lib;
2use crate::cutting_plane::{CutStatus, SearchSpace, SearchSpaceQ, UpdateByCutChoice};
3use crate::ell_calc::EllCalc;
4// #[macro_use]
5// extern crate ndarray;
6// use ndarray::prelude::*;
7use ndarray::Array1;
8use ndarray::Array2;
9use ndarray::Axis;
10
11/// The code defines a struct called "Ell" that represents an ellipsoid search space in the Ellipsoid
12/// method.
13///
14///   Ell = {x | (x - xc)^T mq^-1 (x - xc) \le \kappa}
15///
16/// Properties:
17///
18/// * `no_defer_trick`: A boolean flag indicating whether the defer trick should be used. The defer
19///          trick is a technique used in the Ellipsoid method to improve efficiency by deferring the update of
20///          the ellipsoid until a certain condition is met.
21/// * `mq`: A matrix representing the shape of the ellipsoid. It is a 2-dimensional array of f64 values.
22/// * `xc`: The `xc` property represents the center of the ellipsoid search space. It is a 1-dimensional
23///          array of floating-point numbers.
24/// * `kappa`: A scalar value that determines the size of the ellipsoid. A larger value of kappa results
25///          in a larger ellipsoid.
26/// * `ndim`: The `ndim` property represents the number of dimensions of the ellipsoid search space.
27/// * `helper`: The `helper` property is an instance of the `EllCalc` struct, which is used to perform
28///          calculations related to the ellipsoid search space. It provides methods for calculating the distance
29///          constant (`dc`), the center constant (`cc`), and the quadratic constant (`q`) used in the ell
30/// * `tsq`: The `tsq` property represents the squared Mahalanobis distance threshold of the ellipsoid.
31///          It is used to determine whether a point is inside or outside the ellipsoid.
32#[derive(Debug, Clone)]
33pub struct Ell {
34    pub no_defer_trick: bool,
35    pub mq: Array2<f64>,
36    pub xc: Array1<f64>,
37    pub kappa: f64,
38    helper: EllCalc,
39    pub tsq: f64,
40}
41
42impl Ell {
43    /// The function `new_with_matrix` constructs a new `Ell` object with the given parameters.
44    ///
45    /// Arguments:
46    ///
47    /// * `kappa`: The `kappa` parameter is a floating-point number that represents the curvature of the
48    ///            ellipse. It determines the shape of the ellipse, with higher values resulting in a more elongated
49    ///            shape and lower values resulting in a more circular shape.
50    /// * `mq`: The `mq` parameter is of type `Array2<f64>`, which represents a 2-dimensional array of `f64`
51    ///            (floating-point) values. It is used to store the matrix `mq` in the `Ell` object.
52    /// * `xc`: The parameter `xc` represents the center of the ellipsoid in n-dimensional space. It is an
53    ///            array of length `ndim`, where each element represents the coordinate of the center along a specific
54    ///            dimension.
55    ///
56    /// Returns:
57    ///
58    /// an instance of the `Ell` struct.
59    pub fn new_with_matrix(kappa: f64, mq: Array2<f64>, xc: Array1<f64>) -> Ell {
60        let helper = EllCalc::new(xc.len());
61
62        Ell {
63            kappa,
64            mq,
65            xc,
66            helper,
67            no_defer_trick: false,
68            tsq: 0.0,
69        }
70    }
71
72    /// Creates a new [`Ell`].
73    ///
74    /// The function `new` creates a new `Ell` object with the given values.
75    ///
76    /// Arguments:
77    ///
78    /// * `val`: An array of f64 values representing the diagonal elements of a matrix.
79    /// * `xc`: `xc` is an `Array1<f64>` which represents the center of the ellipse. It contains the x and y
80    ///         coordinates of the center point.
81    ///
82    /// Returns:
83    ///
84    /// The function `new` returns an instance of the [`Ell`] struct.
85    ///
86    /// # Examples
87    ///
88    /// ```
89    /// use ellalgo_rs::ell::Ell;
90    /// use ndarray::arr1;
91    /// let val = arr1(&[1.0, 1.0]);
92    /// let xc = arr1(&[0.0, 0.0]);
93    /// let ellip = Ell::new(val, xc);
94    /// assert_eq!(ellip.kappa, 1.0);
95    /// assert_eq!(ellip.mq.shape(), &[2, 2]);
96    /// ```
97    pub fn new(val: Array1<f64>, xc: Array1<f64>) -> Ell {
98        Ell::new_with_matrix(1.0, Array2::from_diag(&val), xc)
99    }
100
101    /// The function `new_with_scalar` constructs a new [`Ell`] object with a scalar value and a vector.
102    ///
103    /// Arguments:
104    ///
105    /// * `val`: The `val` parameter is a scalar value of type `f64`. It represents the value of the scalar
106    ///          component of the `Ell` object.
107    /// * `xc`: The parameter `xc` is an array of type `Array1<f64>`. It represents the center coordinates
108    ///          of the ellipse.
109    ///
110    /// Returns:
111    ///
112    /// an instance of the [`Ell`] struct.
113    ///
114    /// # Examples
115    ///
116    /// ```
117    /// use ellalgo_rs::ell::Ell;
118    /// use ndarray::arr1;
119    /// let val = 1.0;
120    /// let xc = arr1(&[0.0, 0.0]);
121    /// let ellip = Ell::new_with_scalar(val, xc);
122    /// assert_eq!(ellip.kappa, 1.0);
123    /// assert_eq!(ellip.mq.shape(), &[2, 2]);
124    /// assert_eq!(ellip.xc.shape(), &[2]);
125    /// assert_eq!(ellip.xc[0], 0.0);
126    /// assert_eq!(ellip.xc[1], 0.0);
127    /// assert_eq!(ellip.tsq, 0.0);
128    /// ```
129    pub fn new_with_scalar(val: f64, xc: Array1<f64>) -> Ell {
130        Ell::new_with_matrix(val, Array2::eye(xc.len()), xc)
131    }
132
133    /// Update ellipsoid core function using the cut
134    ///
135    ///  $grad^T * (x - xc) + beta <= 0$
136    ///
137    /// The `update_core` function in Rust updates the ellipsoid core based on a given gradient and beta
138    /// value using a cut strategy.
139    ///
140    /// Arguments:
141    ///
142    /// * `grad`: A reference to an Array1<f64> representing the gradient vector.
143    /// * `beta`: The `beta` parameter is a value that is used in the inequality constraint of the ellipsoid
144    ///            core function. It represents the threshold for the constraint, and the function checks if the dot
145    ///            product of the gradient and the difference between `x` and `xc` plus `beta` is less than or
146    /// * `cut_strategy`: The `cut_strategy` parameter is a closure that takes two arguments: `beta` and
147    ///            `tsq`. It returns a tuple containing a `CutStatus` and a tuple `(rho, sigma, delta)`. The
148    ///            `cut_strategy` function is used to determine the values of `rho`, `
149    ///
150    /// Returns:
151    ///
152    /// a value of type `CutStatus`.
153    fn update_core<T, F>(&mut self, grad: &Array1<f64>, beta: &T, cut_strategy: F) -> CutStatus
154    where
155        T: UpdateByCutChoice<Self, ArrayType = Array1<f64>>,
156        F: FnOnce(&T, f64) -> (CutStatus, (f64, f64, f64)),
157    {
158        let grad_t = self.mq.dot(grad);
159        let omega = grad.dot(&grad_t);
160
161        self.tsq = self.kappa * omega;
162        // let status = self.helper.calc_bias_cut(*beta);
163        let (status, (rho, sigma, delta)) = cut_strategy(beta, self.tsq);
164        if status != CutStatus::Success {
165            return status;
166        }
167
168        self.xc -= &((rho / omega) * &grad_t); // n
169
170        // n*(n+1)/2 + n
171        let r = sigma / omega;
172        let grad_t_view = grad_t.view();
173        self.mq.scaled_add(
174            -r,
175            &(&grad_t_view.insert_axis(Axis(1)) * &grad_t_view.insert_axis(Axis(0))),
176        );
177
178        self.kappa *= delta;
179
180        if self.no_defer_trick {
181            self.mq *= self.kappa;
182            self.kappa = 1.0;
183        }
184        status
185    }
186}
187
188/// The `impl SearchSpace for Ell` block is implementing the `SearchSpace` trait for the `Ell` struct.
189impl SearchSpace for Ell {
190    type ArrayType = Array1<f64>;
191
192    /// The function `xc` returns a copy of the `xc` array.
193    #[inline]
194    fn xc(&self) -> Self::ArrayType {
195        self.xc.clone()
196    }
197
198    /// The `tsq` function returns the value of the `tsq` field of the struct.
199    ///
200    /// Returns:
201    ///
202    /// The method `tsq` is returning a value of type `f64`.
203    #[inline]
204    fn tsq(&self) -> f64 {
205        self.tsq
206    }
207
208    /// The `update_bias_cut` function updates the decision variable based on the given cut.
209    ///
210    /// Arguments:
211    ///
212    /// * `cut`: A tuple containing two elements:
213    ///
214    /// Returns:
215    ///
216    /// The `update_bias_cut` function returns a value of type `CutStatus`.
217    fn update_bias_cut<T>(&mut self, cut: &(Self::ArrayType, T)) -> CutStatus
218    where
219        T: UpdateByCutChoice<Self, ArrayType = Self::ArrayType>,
220    {
221        let (grad, beta) = cut;
222        beta.update_bias_cut_by(self, grad)
223    }
224
225    /// The `update_central_cut` function updates the cut choices using the gradient and beta values.
226    ///
227    /// Arguments:
228    ///
229    /// * `cut`: The `cut` parameter is a tuple containing two elements. The first element is of type
230    ///          `Self::ArrayType`, and the second element is of type `T`.
231    ///
232    /// Returns:
233    ///
234    /// The function `update_central_cut` returns a value of type `CutStatus`.
235    fn update_central_cut<T>(&mut self, cut: &(Self::ArrayType, T)) -> CutStatus
236    where
237        T: UpdateByCutChoice<Self, ArrayType = Self::ArrayType>,
238    {
239        let (grad, beta) = cut;
240        beta.update_central_cut_by(self, grad)
241    }
242
243    fn set_xc(&mut self, x: Self::ArrayType) {
244        self.xc = x;
245    }
246}
247
248impl SearchSpaceQ for Ell {
249    type ArrayType = Array1<f64>;
250
251    /// The function `xc` returns a copy of the `xc` array.
252    #[inline]
253    fn xc(&self) -> Self::ArrayType {
254        self.xc.clone()
255    }
256
257    /// The `tsq` function returns the value of the `tsq` field of the struct.
258    ///
259    /// Returns:
260    ///
261    /// The method `tsq` is returning a value of type `f64`.
262    #[inline]
263    fn tsq(&self) -> f64 {
264        self.tsq
265    }
266
267    /// The `update_q` function updates the decision variable based on the given cut.
268    ///
269    /// Arguments:
270    ///
271    /// * `cut`: A tuple containing two elements:
272    ///
273    /// Returns:
274    ///
275    /// The `update_bias_cut` function returns a value of type `CutStatus`.
276    fn update_q<T>(&mut self, cut: &(Self::ArrayType, T)) -> CutStatus
277    where
278        T: UpdateByCutChoice<Self, ArrayType = Self::ArrayType>,
279    {
280        let (grad, beta) = cut;
281        beta.update_q_by(self, grad)
282    }
283}
284
285trait CutType {
286    fn call_bias_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64));
287    fn call_central_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64));
288    fn call_q_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64));
289}
290
291impl CutType for f64 {
292    fn call_bias_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64)) {
293        helper.calc_bias_cut(*self, tsq)
294    }
295
296    fn call_central_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64)) {
297        helper.calc_central_cut(tsq)
298    }
299
300    fn call_q_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64)) {
301        helper.calc_bias_cut_q(*self, tsq)
302    }
303}
304
305impl CutType for (f64, Option<f64>) {
306    fn call_bias_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64)) {
307        helper.calc_single_or_parallel_bias_cut(self, tsq)
308    }
309
310    fn call_central_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64)) {
311        helper.calc_single_or_parallel_central_cut(self, tsq)
312    }
313
314    fn call_q_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64)) {
315        helper.calc_single_or_parallel_q(self, tsq)
316    }
317}
318
319impl<T: CutType> UpdateByCutChoice<Ell> for T {
320    type ArrayType = Array1<f64>;
321
322    fn update_bias_cut_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
323        let helper = ellip.helper.clone();
324        ellip.update_core(grad, self, |beta, tsq| beta.call_bias_cut(&helper, tsq))
325    }
326
327    fn update_central_cut_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
328        let helper = ellip.helper.clone();
329        ellip.update_core(grad, self, |beta, tsq| beta.call_central_cut(&helper, tsq))
330    }
331
332    fn update_q_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
333        let helper = ellip.helper.clone();
334        ellip.update_core(grad, self, |beta, tsq| beta.call_q_cut(&helper, tsq))
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use approx_eq::assert_approx_eq;

    /// The 4-D starting ellipsoid shared by every test: `kappa = 0.01`, `mq = I`, `xc = 0`.
    fn unit_ellip() -> Ell {
        Ell::new_with_scalar(0.01, Array1::zeros(4))
    }

    /// The cut gradient shared by every test: all four entries equal `0.5`.
    fn half_grad() -> Array1<f64> {
        0.5 * Array1::ones(4)
    }

    /// Asserts that every entry of `expected` approximately matches the one in `actual`.
    fn assert_mq_approx(actual: &Array2<f64>, expected: &Array2<f64>) {
        for ((i, j), &want) in expected.indexed_iter() {
            assert_approx_eq!(actual[[i, j]], want);
        }
    }

    #[test]
    fn test_construct() {
        let ellip = unit_ellip();
        assert!(!ellip.no_defer_trick);
        assert_approx_eq!(ellip.kappa, 0.01);
        assert_eq!(ellip.mq, Array2::eye(4));
        assert_eq!(ellip.xc, Array1::zeros(4));
        assert_approx_eq!(ellip.tsq, 0.0);
    }

    #[test]
    fn test_update_central_cut() {
        let mut ellip = unit_ellip();
        let status = ellip.update_central_cut(&(half_grad(), 0.0));
        assert_eq!(status, CutStatus::Success);
        assert_eq!(ellip.xc, -0.01 * Array1::ones(4));
        assert_eq!(ellip.mq, Array2::eye(4) - 0.1 * Array2::ones((4, 4)));
        assert_approx_eq!(ellip.kappa, 0.16 / 15.0);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_bias_cut() {
        let mut ellip = unit_ellip();
        let status = ellip.update_bias_cut(&(half_grad(), 0.05));
        assert_eq!(status, CutStatus::Success);
        assert_approx_eq!(ellip.xc[0], -0.03);
        assert_approx_eq!(ellip.mq[(0, 0)], 0.8);
        assert_approx_eq!(ellip.kappa, 0.008);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_parallel_central_cut() {
        let mut ellip = unit_ellip();
        let status = ellip.update_central_cut(&(half_grad(), (0.0, Some(0.05))));
        assert_eq!(status, CutStatus::Success);
        assert_eq!(ellip.xc, -0.01 * Array1::ones(4));
        assert_eq!(ellip.mq, Array2::eye(4) - 0.2 * Array2::ones((4, 4)));
        assert_approx_eq!(ellip.kappa, 0.012);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_parallel() {
        let mut ellip = unit_ellip();
        let status = ellip.update_bias_cut(&(half_grad(), (0.01, Some(0.04))));
        assert_eq!(status, CutStatus::Success);
        assert_approx_eq!(ellip.xc[0], -0.0116);
        assert_approx_eq!(ellip.mq[(0, 0)], 1.0 - 0.232);
        assert_approx_eq!(ellip.kappa, 0.01232);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_parallel_no_effect() {
        let mut ellip = unit_ellip();
        let status = ellip.update_bias_cut(&(half_grad(), (-0.04, Some(0.0625))));
        assert_eq!(status, CutStatus::Success);
        assert_eq!(ellip.xc, Array1::zeros(4));
        assert_eq!(ellip.mq, Array2::eye(4));
        assert_approx_eq!(ellip.kappa, 0.01);
    }

    #[test]
    fn test_update_q_no_effect() {
        let mut ellip = unit_ellip();
        let status = ellip.update_q(&(half_grad(), (-0.04, Some(0.0625))));
        assert_eq!(status, CutStatus::NoEffect);
        assert_eq!(ellip.xc, Array1::zeros(4));
        assert_eq!(ellip.mq, Array2::eye(4));
        assert_approx_eq!(ellip.kappa, 0.01);
    }

    #[test]
    fn test_update_q() {
        let mut ellip = unit_ellip();
        let status = ellip.update_q(&(half_grad(), (0.01, Some(0.04))));
        assert_eq!(status, CutStatus::Success);
        assert_approx_eq!(ellip.xc[0], -0.0116);
        assert_approx_eq!(ellip.mq[(0, 0)], 1.0 - 0.232);
        assert_approx_eq!(ellip.kappa, 0.01232);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_central_cut_mq() {
        let mut ellip = unit_ellip();
        let _ = ellip.update_central_cut(&(half_grad(), 0.0));
        let mq_expected: Array2<f64> = Array2::eye(4) - 0.1 * Array2::ones((4, 4));
        assert_mq_approx(&ellip.mq, &mq_expected);
    }

    #[test]
    fn test_no_defer_trick() {
        let mut ellip = unit_ellip();
        ellip.no_defer_trick = true;
        let _ = ellip.update_central_cut(&(half_grad(), 0.0));
        // With the defer trick disabled, kappa is folded into mq and reset to 1.
        assert_approx_eq!(ellip.kappa, 1.0);
        let mq_expected: Array2<f64> =
            (Array2::eye(4) - 0.1 * Array2::ones((4, 4))) * (0.16 / 15.0);
        assert_mq_approx(&ellip.mq, &mq_expected);
    }
}