ellalgo_rs/ell.rs
1// mod lib;
2use crate::cutting_plane::{CutStatus, SearchSpace, SearchSpaceQ, UpdateByCutChoice};
3use crate::ell_calc::EllCalc;
4// #[macro_use]
5// extern crate ndarray;
6// use ndarray::prelude::*;
7use ndarray::Array1;
8use ndarray::Array2;
9
/// Ellipsoid search space used by the ellipsoid (cutting-plane) method.
///
/// The set represented is
///
/// Ell = {x | (x - xc)^T mq^-1 (x - xc) \le \kappa}
///
/// Fields:
///
/// * `no_defer_trick`: When `false` (the value set by the constructors), the scalar factor
///   produced by each cut is accumulated in `kappa` and `mq` is left unscaled. When `true`,
///   `update_core` multiplies `mq` by `kappa` after every successful update and resets
///   `kappa` to `1.0`.
/// * `mq`: Shape matrix of the ellipsoid, a 2-dimensional array of `f64`.
/// * `xc`: Center of the ellipsoid, a 1-dimensional array of `f64`.
/// * `kappa`: Scalar scale factor applied to `mq`; a larger value means a larger ellipsoid.
/// * `ndim`: Number of dimensions, taken from `xc.len()` at construction time.
/// * `helper`: The [`EllCalc`] instance created for `ndim`. Its `calc_*` methods are called by
///   the `UpdateByCutChoice` implementations to obtain the cut status and the
///   `(rho, sigma, delta)` update coefficients.
/// * `tsq`: Set by `update_core` to `kappa * grad^T * mq * grad` for the most recent cut
///   gradient; it is `0.0` until the first update.
#[derive(Debug, Clone)]
pub struct Ell {
    pub no_defer_trick: bool,
    pub mq: Array2<f64>,
    pub xc: Array1<f64>,
    pub kappa: f64,
    ndim: usize,
    helper: EllCalc,
    pub tsq: f64,
}
41
42impl Ell {
43 /// The function `new_with_matrix` constructs a new `Ell` object with the given parameters.
44 ///
45 /// Arguments:
46 ///
47 /// * `kappa`: The `kappa` parameter is a floating-point number that represents the curvature of the
48 /// ellipse. It determines the shape of the ellipse, with higher values resulting in a more elongated
49 /// shape and lower values resulting in a more circular shape.
50 /// * `mq`: The `mq` parameter is of type `Array2<f64>`, which represents a 2-dimensional array of `f64`
51 /// (floating-point) values. It is used to store the matrix `mq` in the `Ell` object.
52 /// * `xc`: The parameter `xc` represents the center of the ellipsoid in n-dimensional space. It is an
53 /// array of length `ndim`, where each element represents the coordinate of the center along a specific
54 /// dimension.
55 ///
56 /// Returns:
57 ///
58 /// an instance of the `Ell` struct.
59 pub fn new_with_matrix(kappa: f64, mq: Array2<f64>, xc: Array1<f64>) -> Ell {
60 let ndim = xc.len();
61 let helper = EllCalc::new(ndim);
62
63 Ell {
64 kappa,
65 mq,
66 xc,
67 ndim,
68 helper,
69 no_defer_trick: false,
70 tsq: 0.0,
71 }
72 }
73
74 /// Creates a new [`Ell`].
75 ///
76 /// The function `new` creates a new `Ell` object with the given values.
77 ///
78 /// Arguments:
79 ///
80 /// * `val`: An array of f64 values representing the diagonal elements of a matrix.
81 /// * `xc`: `xc` is an `Array1<f64>` which represents the center of the ellipse. It contains the x and y
82 /// coordinates of the center point.
83 ///
84 /// Returns:
85 ///
86 /// The function `new` returns an instance of the [`Ell`] struct.
87 ///
88 /// # Examples
89 ///
90 /// ```
91 /// use ellalgo_rs::ell::Ell;
92 /// use ndarray::arr1;
93 /// let val = arr1(&[1.0, 1.0]);
94 /// let xc = arr1(&[0.0, 0.0]);
95 /// let ellip = Ell::new(val, xc);
96 /// assert_eq!(ellip.kappa, 1.0);
97 /// assert_eq!(ellip.mq.shape(), &[2, 2]);
98 /// ```
99 pub fn new(val: Array1<f64>, xc: Array1<f64>) -> Ell {
100 Ell::new_with_matrix(1.0, Array2::from_diag(&val), xc)
101 }
102
103 /// The function `new_with_scalar` constructs a new [`Ell`] object with a scalar value and a vector.
104 ///
105 /// Arguments:
106 ///
107 /// * `val`: The `val` parameter is a scalar value of type `f64`. It represents the value of the scalar
108 /// component of the `Ell` object.
109 /// * `xc`: The parameter `xc` is an array of type `Array1<f64>`. It represents the center coordinates
110 /// of the ellipse.
111 ///
112 /// Returns:
113 ///
114 /// an instance of the [`Ell`] struct.
115 ///
116 /// # Examples
117 ///
118 /// ```
119 /// use ellalgo_rs::ell::Ell;
120 /// use ndarray::arr1;
121 /// let val = 1.0;
122 /// let xc = arr1(&[0.0, 0.0]);
123 /// let ellip = Ell::new_with_scalar(val, xc);
124 /// assert_eq!(ellip.kappa, 1.0);
125 /// assert_eq!(ellip.mq.shape(), &[2, 2]);
126 /// assert_eq!(ellip.xc.shape(), &[2]);
127 /// assert_eq!(ellip.xc[0], 0.0);
128 /// assert_eq!(ellip.xc[1], 0.0);
129 /// assert_eq!(ellip.tsq, 0.0);
130 /// ```
131 pub fn new_with_scalar(val: f64, xc: Array1<f64>) -> Ell {
132 Ell::new_with_matrix(val, Array2::eye(xc.len()), xc)
133 }
134
135 /// Update ellipsoid core function using the cut
136 ///
137 /// $grad^T * (x - xc) + beta <= 0$
138 ///
139 /// The `update_core` function in Rust updates the ellipsoid core based on a given gradient and beta
140 /// value using a cut strategy.
141 ///
142 /// Arguments:
143 ///
144 /// * `grad`: A reference to an Array1<f64> representing the gradient vector.
145 /// * `beta`: The `beta` parameter is a value that is used in the inequality constraint of the ellipsoid
146 /// core function. It represents the threshold for the constraint, and the function checks if the dot
147 /// product of the gradient and the difference between `x` and `xc` plus `beta` is less than or
148 /// * `cut_strategy`: The `cut_strategy` parameter is a closure that takes two arguments: `beta` and
149 /// `tsq`. It returns a tuple containing a `CutStatus` and a tuple `(rho, sigma, delta)`. The
150 /// `cut_strategy` function is used to determine the values of `rho`, `
151 ///
152 /// Returns:
153 ///
154 /// a value of type `CutStatus`.
155 fn update_core<T, F>(&mut self, grad: &Array1<f64>, beta: &T, cut_strategy: F) -> CutStatus
156 where
157 T: UpdateByCutChoice<Self, ArrayType = Array1<f64>>,
158 F: FnOnce(&T, f64) -> (CutStatus, (f64, f64, f64)),
159 {
160 let grad_t = self.mq.dot(grad);
161 let omega = grad.dot(&grad_t);
162
163 self.tsq = self.kappa * omega;
164 // let status = self.helper.calc_bias_cut(*beta);
165 let (status, (rho, sigma, delta)) = cut_strategy(beta, self.tsq);
166 if status != CutStatus::Success {
167 return status;
168 }
169
170 self.xc -= &((rho / omega) * &grad_t); // n
171
172 // n*(n+1)/2 + n
173 // self.mq -= (self.sigma / omega) * xt::linalg::outer(grad_t, grad_t);
174 let r = sigma / omega;
175 for i in 0..self.ndim {
176 let r_grad_t = r * grad_t[i];
177 for j in 0..i {
178 self.mq[[i, j]] -= r_grad_t * grad_t[j];
179 self.mq[[j, i]] = self.mq[[i, j]];
180 }
181 self.mq[[i, i]] -= r_grad_t * grad_t[i];
182 }
183
184 self.kappa *= delta;
185
186 if self.no_defer_trick {
187 self.mq *= self.kappa;
188 self.kappa = 1.0;
189 }
190 status
191 }
192}
193
194/// The `impl SearchSpace for Ell` block is implementing the `SearchSpace` trait for the `Ell` struct.
195impl SearchSpace for Ell {
196 type ArrayType = Array1<f64>;
197
198 /// The function `xc` returns a copy of the `xc` array.
199 #[inline]
200 fn xc(&self) -> Self::ArrayType {
201 self.xc.clone()
202 }
203
204 /// The `tsq` function returns the value of the `tsq` field of the struct.
205 ///
206 /// Returns:
207 ///
208 /// The method `tsq` is returning a value of type `f64`.
209 #[inline]
210 fn tsq(&self) -> f64 {
211 self.tsq
212 }
213
214 /// The `update_bias_cut` function updates the decision variable based on the given cut.
215 ///
216 /// Arguments:
217 ///
218 /// * `cut`: A tuple containing two elements:
219 ///
220 /// Returns:
221 ///
222 /// The `update_bias_cut` function returns a value of type `CutStatus`.
223 fn update_bias_cut<T>(&mut self, cut: &(Self::ArrayType, T)) -> CutStatus
224 where
225 T: UpdateByCutChoice<Self, ArrayType = Self::ArrayType>,
226 {
227 let (grad, beta) = cut;
228 beta.update_bias_cut_by(self, grad)
229 }
230
231 /// The `update_central_cut` function updates the cut choices using the gradient and beta values.
232 ///
233 /// Arguments:
234 ///
235 /// * `cut`: The `cut` parameter is a tuple containing two elements. The first element is of type
236 /// `Self::ArrayType`, and the second element is of type `T`.
237 ///
238 /// Returns:
239 ///
240 /// The function `update_central_cut` returns a value of type `CutStatus`.
241 fn update_central_cut<T>(&mut self, cut: &(Self::ArrayType, T)) -> CutStatus
242 where
243 T: UpdateByCutChoice<Self, ArrayType = Self::ArrayType>,
244 {
245 let (grad, beta) = cut;
246 beta.update_central_cut_by(self, grad)
247 }
248
249 fn set_xc(&mut self, x: Self::ArrayType) {
250 self.xc = x;
251 }
252}
253
254impl SearchSpaceQ for Ell {
255 type ArrayType = Array1<f64>;
256
257 /// The function `xc` returns a copy of the `xc` array.
258 #[inline]
259 fn xc(&self) -> Self::ArrayType {
260 self.xc.clone()
261 }
262
263 /// The `tsq` function returns the value of the `tsq` field of the struct.
264 ///
265 /// Returns:
266 ///
267 /// The method `tsq` is returning a value of type `f64`.
268 #[inline]
269 fn tsq(&self) -> f64 {
270 self.tsq
271 }
272
273 /// The `update_q` function updates the decision variable based on the given cut.
274 ///
275 /// Arguments:
276 ///
277 /// * `cut`: A tuple containing two elements:
278 ///
279 /// Returns:
280 ///
281 /// The `update_bias_cut` function returns a value of type `CutStatus`.
282 fn update_q<T>(&mut self, cut: &(Self::ArrayType, T)) -> CutStatus
283 where
284 T: UpdateByCutChoice<Self, ArrayType = Self::ArrayType>,
285 {
286 let (grad, beta) = cut;
287 beta.update_q_by(self, grad)
288 }
289}
290
291impl UpdateByCutChoice<Ell> for f64 {
292 type ArrayType = Array1<f64>;
293
294 fn update_bias_cut_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
295 let beta = self;
296 let helper = ellip.helper.clone();
297 ellip.update_core(grad, beta, |beta, tsq| helper.calc_bias_cut(*beta, tsq))
298 }
299
300 fn update_central_cut_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
301 let beta = self;
302 let helper = ellip.helper.clone();
303 ellip.update_core(grad, beta, |_beta, tsq| helper.calc_central_cut(tsq))
304 }
305
306 fn update_q_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
307 let beta = self;
308 let helper = ellip.helper.clone();
309 ellip.update_core(grad, beta, |beta, tsq| helper.calc_bias_cut_q(*beta, tsq))
310 }
311}
312
313impl UpdateByCutChoice<Ell> for (f64, Option<f64>) {
314 type ArrayType = Array1<f64>;
315
316 fn update_bias_cut_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
317 let beta = self;
318 let helper = ellip.helper.clone();
319 ellip.update_core(grad, beta, |beta, tsq| {
320 helper.calc_single_or_parallel_bias_cut(beta, tsq)
321 })
322 }
323
324 fn update_central_cut_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
325 let beta = self;
326 let helper = ellip.helper.clone();
327 ellip.update_core(grad, beta, |beta, tsq| {
328 helper.calc_single_or_parallel_central_cut(beta, tsq)
329 })
330 }
331
332 fn update_q_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
333 let beta = self;
334 let helper = ellip.helper.clone();
335 ellip.update_core(grad, beta, |beta, tsq| {
336 helper.calc_single_or_parallel_q(beta, tsq)
337 })
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use approx_eq::assert_approx_eq;

    /// Fixture shared by all tests: a 4-dimensional ball with `kappa = 0.01` centered at the
    /// origin.
    fn small_ball() -> Ell {
        Ell::new_with_scalar(0.01, Array1::zeros(4))
    }

    /// Gradient shared by all update tests: every component equals `0.5`.
    fn half_grad() -> Array1<f64> {
        0.5 * Array1::ones(4)
    }

    #[test]
    fn test_construct() {
        let ellip = small_ball();
        assert!(!ellip.no_defer_trick);
        assert_approx_eq!(ellip.kappa, 0.01);
        assert_eq!(ellip.mq, Array2::eye(4));
        assert_eq!(ellip.xc, Array1::zeros(4));
        assert_approx_eq!(ellip.tsq, 0.0);
    }

    #[test]
    fn test_update_central_cut() {
        let mut ellip = small_ball();
        let cut = (half_grad(), 0.0);
        assert_eq!(ellip.update_central_cut(&cut), CutStatus::Success);
        assert_eq!(ellip.xc, -0.01 * Array1::ones(4));
        assert_eq!(ellip.mq, Array2::eye(4) - 0.1 * Array2::ones((4, 4)));
        assert_approx_eq!(ellip.kappa, 0.16 / 15.0);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_bias_cut() {
        let mut ellip = small_ball();
        let cut = (half_grad(), 0.05);
        assert_eq!(ellip.update_bias_cut(&cut), CutStatus::Success);
        assert_approx_eq!(ellip.xc[0], -0.03);
        assert_approx_eq!(ellip.mq[(0, 0)], 0.8);
        assert_approx_eq!(ellip.kappa, 0.008);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_parallel_central_cut() {
        let mut ellip = small_ball();
        let cut = (half_grad(), (0.0, Some(0.05)));
        assert_eq!(ellip.update_central_cut(&cut), CutStatus::Success);
        assert_eq!(ellip.xc, -0.01 * Array1::ones(4));
        assert_eq!(ellip.mq, Array2::eye(4) - 0.2 * Array2::ones((4, 4)));
        assert_approx_eq!(ellip.kappa, 0.012);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_parallel() {
        let mut ellip = small_ball();
        let cut = (half_grad(), (0.01, Some(0.04)));
        assert_eq!(ellip.update_bias_cut(&cut), CutStatus::Success);
        assert_approx_eq!(ellip.xc[0], -0.0116);
        assert_approx_eq!(ellip.mq[(0, 0)], 1.0 - 0.232);
        assert_approx_eq!(ellip.kappa, 0.01232);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_parallel_no_effect() {
        // The bias-cut path reports success here while leaving the ellipsoid unchanged.
        let mut ellip = small_ball();
        let cut = (half_grad(), (-0.04, Some(0.0625)));
        assert_eq!(ellip.update_bias_cut(&cut), CutStatus::Success);
        assert_eq!(ellip.xc, Array1::zeros(4));
        assert_eq!(ellip.mq, Array2::eye(4));
        assert_approx_eq!(ellip.kappa, 0.01);
    }

    #[test]
    fn test_update_q_no_effect() {
        // The same cut through the `update_q` path is reported as `NoEffect`.
        let mut ellip = small_ball();
        let cut = (half_grad(), (-0.04, Some(0.0625)));
        assert_eq!(ellip.update_q(&cut), CutStatus::NoEffect);
        assert_eq!(ellip.xc, Array1::zeros(4));
        assert_eq!(ellip.mq, Array2::eye(4));
        assert_approx_eq!(ellip.kappa, 0.01);
    }
}
425}