1use crate::cutting_plane::{CutStatus, SearchSpace, SearchSpaceQ, UpdateByCutChoice};
3use crate::ell_calc::EllCalc;
4use ndarray::Array1;
8use ndarray::Array2;
9use ndarray::Axis;
10
/// Ellipsoidal search space used by the cutting-plane routines.
///
/// The ellipsoid is stored as a center `xc` and a shape matrix `mq` together with a
/// scale factor `kappa`. Folding `kappa` into `mq` (as `no_defer_trick` does) keeps the
/// product `kappa * mq` unchanged. NOTE(review): the conventional set is
/// `{x | (x - xc)^T (kappa * mq)^{-1} (x - xc) <= 1}` — confirm against `EllCalc`.
#[derive(Debug, Clone)]
pub struct Ell {
    /// When `true`, every successful update multiplies `mq` by `kappa` and resets
    /// `kappa` to 1.0 instead of deferring the scaling.
    pub no_defer_trick: bool,
    /// Shape matrix before scaling by `kappa`; the constructors set it to a given
    /// matrix, a diagonal matrix, or the identity.
    pub mq: Array2<f64>,
    /// Center of the ellipsoid.
    pub xc: Array1<f64>,
    /// Scale factor for `mq`; multiplied by `delta` after each successful cut.
    pub kappa: f64,
    /// Cut-parameter calculator, constructed for dimension `xc.len()`.
    helper: EllCalc,
    /// `kappa * grad^T * mq * grad` computed at the start of the most recent update
    /// (0.0 until the first update).
    pub tsq: f64,
}
41
42impl Ell {
43 pub fn new_with_matrix(kappa: f64, mq: Array2<f64>, xc: Array1<f64>) -> Ell {
60 let helper = EllCalc::new(xc.len());
61
62 Ell {
63 kappa,
64 mq,
65 xc,
66 helper,
67 no_defer_trick: false,
68 tsq: 0.0,
69 }
70 }
71
72 pub fn new(val: Array1<f64>, xc: Array1<f64>) -> Ell {
98 Ell::new_with_matrix(1.0, Array2::from_diag(&val), xc)
99 }
100
101 pub fn new_with_scalar(val: f64, xc: Array1<f64>) -> Ell {
130 Ell::new_with_matrix(val, Array2::eye(xc.len()), xc)
131 }
132
133 fn update_core<T, F>(&mut self, grad: &Array1<f64>, beta: &T, cut_strategy: F) -> CutStatus
154 where
155 T: UpdateByCutChoice<Self, ArrayType = Array1<f64>>,
156 F: FnOnce(&T, f64) -> (CutStatus, (f64, f64, f64)),
157 {
158 let grad_t = self.mq.dot(grad);
159 let omega = grad.dot(&grad_t);
160
161 self.tsq = self.kappa * omega;
162 let (status, (rho, sigma, delta)) = cut_strategy(beta, self.tsq);
164 if status != CutStatus::Success {
165 return status;
166 }
167
168 self.xc -= &((rho / omega) * &grad_t); let r = sigma / omega;
172 let grad_t_view = grad_t.view();
173 self.mq.scaled_add(
174 -r,
175 &(&grad_t_view.insert_axis(Axis(1)) * &grad_t_view.insert_axis(Axis(0))),
176 );
177
178 self.kappa *= delta;
179
180 if self.no_defer_trick {
181 self.mq *= self.kappa;
182 self.kappa = 1.0;
183 }
184 status
185 }
186}
187
188impl SearchSpace for Ell {
190 type ArrayType = Array1<f64>;
191
192 #[inline]
194 fn xc(&self) -> Self::ArrayType {
195 self.xc.clone()
196 }
197
198 #[inline]
204 fn tsq(&self) -> f64 {
205 self.tsq
206 }
207
208 fn update_bias_cut<T>(&mut self, cut: &(Self::ArrayType, T)) -> CutStatus
218 where
219 T: UpdateByCutChoice<Self, ArrayType = Self::ArrayType>,
220 {
221 let (grad, beta) = cut;
222 beta.update_bias_cut_by(self, grad)
223 }
224
225 fn update_central_cut<T>(&mut self, cut: &(Self::ArrayType, T)) -> CutStatus
236 where
237 T: UpdateByCutChoice<Self, ArrayType = Self::ArrayType>,
238 {
239 let (grad, beta) = cut;
240 beta.update_central_cut_by(self, grad)
241 }
242
243 fn set_xc(&mut self, x: Self::ArrayType) {
244 self.xc = x;
245 }
246}
247
248impl SearchSpaceQ for Ell {
249 type ArrayType = Array1<f64>;
250
251 #[inline]
253 fn xc(&self) -> Self::ArrayType {
254 self.xc.clone()
255 }
256
257 #[inline]
263 fn tsq(&self) -> f64 {
264 self.tsq
265 }
266
267 fn update_q<T>(&mut self, cut: &(Self::ArrayType, T)) -> CutStatus
277 where
278 T: UpdateByCutChoice<Self, ArrayType = Self::ArrayType>,
279 {
280 let (grad, beta) = cut;
281 beta.update_q_by(self, grad)
282 }
283}
284
285trait CutType {
286 fn call_bias_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64));
287 fn call_central_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64));
288 fn call_q_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64));
289}
290
291impl CutType for f64 {
292 fn call_bias_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64)) {
293 helper.calc_bias_cut(*self, tsq)
294 }
295
296 fn call_central_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64)) {
297 helper.calc_central_cut(tsq)
298 }
299
300 fn call_q_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64)) {
301 helper.calc_bias_cut_q(*self, tsq)
302 }
303}
304
305impl CutType for (f64, Option<f64>) {
306 fn call_bias_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64)) {
307 helper.calc_single_or_parallel_bias_cut(self, tsq)
308 }
309
310 fn call_central_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64)) {
311 helper.calc_single_or_parallel_central_cut(self, tsq)
312 }
313
314 fn call_q_cut(&self, helper: &EllCalc, tsq: f64) -> (CutStatus, (f64, f64, f64)) {
315 helper.calc_single_or_parallel_q(self, tsq)
316 }
317}
318
319impl<T: CutType> UpdateByCutChoice<Ell> for T {
320 type ArrayType = Array1<f64>;
321
322 fn update_bias_cut_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
323 let helper = ellip.helper.clone();
324 ellip.update_core(grad, self, |beta, tsq| beta.call_bias_cut(&helper, tsq))
325 }
326
327 fn update_central_cut_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
328 let helper = ellip.helper.clone();
329 ellip.update_core(grad, self, |beta, tsq| beta.call_central_cut(&helper, tsq))
330 }
331
332 fn update_q_by(&self, ellip: &mut Ell, grad: &Self::ArrayType) -> CutStatus {
333 let helper = ellip.helper.clone();
334 ellip.update_core(grad, self, |beta, tsq| beta.call_q_cut(&helper, tsq))
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use approx_eq::assert_approx_eq;

    /// Absolute tolerance for element-wise array comparisons.
    const TOL: f64 = 1e-12;

    /// Asserts that two arrays have the same shape and agree element-wise within `TOL`.
    ///
    /// Exact `assert_eq!` on computed floating-point arrays breaks under any
    /// algebraically equivalent rearrangement of the update formulas; this helper
    /// keeps the tests focused on the math rather than on rounding details.
    fn assert_all_close<D: ndarray::Dimension>(
        actual: &ndarray::Array<f64, D>,
        expected: &ndarray::Array<f64, D>,
    ) {
        assert_eq!(actual.shape(), expected.shape(), "shape mismatch");
        for (idx, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() <= TOL,
                "element {} (logical order): {} != {} (tol {})",
                idx,
                a,
                e,
                TOL
            );
        }
    }

    #[test]
    fn test_construct() {
        let ellip = Ell::new_with_scalar(0.01, Array1::zeros(4));
        assert!(!ellip.no_defer_trick);
        assert_approx_eq!(ellip.kappa, 0.01);
        assert_eq!(ellip.mq, Array2::eye(4));
        assert_eq!(ellip.xc, Array1::zeros(4));
        assert_approx_eq!(ellip.tsq, 0.0);
    }

    #[test]
    fn test_new_with_diag() {
        let ellip = Ell::new(Array1::from_vec(vec![1.0, 2.0, 3.0]), Array1::zeros(3));
        assert!(!ellip.no_defer_trick);
        assert_approx_eq!(ellip.kappa, 1.0);
        let mq_expected: Array2<f64> =
            ndarray::arr2(&[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]);
        assert_eq!(ellip.mq, mq_expected);
        assert_eq!(ellip.xc, Array1::zeros(3));
        assert_approx_eq!(ellip.tsq, 0.0);
    }

    #[test]
    fn test_update_central_cut() {
        let mut ellip = Ell::new_with_scalar(0.01, Array1::zeros(4));
        let cut = (0.5 * Array1::ones(4), 0.0);
        let status = ellip.update_central_cut(&cut);
        assert_eq!(status, CutStatus::Success);
        let xc_expected: Array1<f64> = -0.01 * Array1::ones(4);
        assert_all_close(&ellip.xc, &xc_expected);
        let mq_expected: Array2<f64> = Array2::eye(4) - 0.1 * Array2::ones((4, 4));
        assert_all_close(&ellip.mq, &mq_expected);
        assert_approx_eq!(ellip.kappa, 0.16 / 15.0);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_bias_cut() {
        let mut ellip = Ell::new_with_scalar(0.01, Array1::zeros(4));
        let cut = (0.5 * Array1::ones(4), 0.05);
        let status = ellip.update_bias_cut(&cut);
        assert_eq!(status, CutStatus::Success);
        assert_approx_eq!(ellip.xc[0], -0.03);
        assert_approx_eq!(ellip.mq[(0, 0)], 0.8);
        assert_approx_eq!(ellip.kappa, 0.008);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_parallel_central_cut() {
        let mut ellip = Ell::new_with_scalar(0.01, Array1::zeros(4));
        let cut = (0.5 * Array1::ones(4), (0.0, Some(0.05)));
        let status = ellip.update_central_cut(&cut);
        assert_eq!(status, CutStatus::Success);
        let xc_expected: Array1<f64> = -0.01 * Array1::ones(4);
        assert_all_close(&ellip.xc, &xc_expected);
        let mq_expected: Array2<f64> = Array2::eye(4) - 0.2 * Array2::ones((4, 4));
        assert_all_close(&ellip.mq, &mq_expected);
        assert_approx_eq!(ellip.kappa, 0.012);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_parallel() {
        let mut ellip = Ell::new_with_scalar(0.01, Array1::zeros(4));
        let cut = (0.5 * Array1::ones(4), (0.01, Some(0.04)));
        let status = ellip.update_bias_cut(&cut);
        assert_eq!(status, CutStatus::Success);
        assert_approx_eq!(ellip.xc[0], -0.0116);
        assert_approx_eq!(ellip.mq[(0, 0)], 1.0 - 0.232);
        assert_approx_eq!(ellip.kappa, 0.01232);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_parallel_no_effect() {
        let mut ellip = Ell::new_with_scalar(0.01, Array1::zeros(4));
        let cut = (0.5 * Array1::ones(4), (-0.04, Some(0.0625)));
        let status = ellip.update_bias_cut(&cut);
        assert_eq!(status, CutStatus::Success);
        // The state must be left exactly as constructed, so exact comparison is intended.
        assert_eq!(ellip.xc, Array1::zeros(4));
        assert_eq!(ellip.mq, Array2::eye(4));
        assert_approx_eq!(ellip.kappa, 0.01);
        // tsq is recorded before the cut parameters are evaluated.
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_q_no_effect() {
        let mut ellip = Ell::new_with_scalar(0.01, Array1::zeros(4));
        let cut = (0.5 * Array1::ones(4), (-0.04, Some(0.0625)));
        let status = ellip.update_q(&cut);
        assert_eq!(status, CutStatus::NoEffect);
        // The state must be left exactly as constructed, so exact comparison is intended.
        assert_eq!(ellip.xc, Array1::zeros(4));
        assert_eq!(ellip.mq, Array2::eye(4));
        assert_approx_eq!(ellip.kappa, 0.01);
        // tsq is recorded before the cut parameters are evaluated.
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_q() {
        let mut ellip = Ell::new_with_scalar(0.01, Array1::zeros(4));
        let cut = (0.5 * Array1::ones(4), (0.01, Some(0.04)));
        let status = ellip.update_q(&cut);
        assert_eq!(status, CutStatus::Success);
        assert_approx_eq!(ellip.xc[0], -0.0116);
        assert_approx_eq!(ellip.mq[(0, 0)], 1.0 - 0.232);
        assert_approx_eq!(ellip.kappa, 0.01232);
        assert_approx_eq!(ellip.tsq, 0.01);
    }

    #[test]
    fn test_update_central_cut_mq() {
        let mut ellip = Ell::new_with_scalar(0.01, Array1::zeros(4));
        let cut = (0.5 * Array1::ones(4), 0.0);
        let _ = ellip.update_central_cut(&cut);
        let mq_expected: Array2<f64> = Array2::eye(4) - 0.1 * Array2::ones((4, 4));
        assert_all_close(&ellip.mq, &mq_expected);
    }

    #[test]
    fn test_no_defer_trick() {
        let mut ellip = Ell::new_with_scalar(0.01, Array1::zeros(4));
        ellip.no_defer_trick = true;
        let cut = (0.5 * Array1::ones(4), 0.0);
        let _ = ellip.update_central_cut(&cut);
        assert_approx_eq!(ellip.kappa, 1.0);
        let mq_expected: Array2<f64> =
            (Array2::eye(4) - 0.1 * Array2::ones((4, 4))) * (0.16 / 15.0);
        assert_all_close(&ellip.mq, &mq_expected);
    }
}