proximal_optimize/
lib.rs

1//! A port of the proximal gradient method from `https://github.com/pmelchior/proxmin`.
2//!
3#![allow(unused)]
4#![allow(non_snake_case)]
5#![allow(non_upper_case_globals)]
6
7mod examples;
8pub(crate) mod misc;
9pub mod utils;
10
11use crate::{
12  misc::*,
13  utils::NesterovStepper,
14};
15
/// Errors that can be returned by the optimizers in this crate.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum ProximalOptimizerErr {
  /// Caused when the length of parameter vectors does not match the number
  /// of parameters specified when the optimizer was created.
  ParameterLengthMismatch,
  /// The objective function's value at the start position could not be compared
  /// with itself (perhaps a `NaN` condition?)
  StartUnorderable,
  /// The candidate solution is no better than the starting position, and
  /// may actually be worse.
  SolutionNoBetter,
}

impl std::fmt::Display for ProximalOptimizerErr {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    let msg = match *self {
      ProximalOptimizerErr::ParameterLengthMismatch =>
        "parameter vector length does not match the expected number of parameters",
      ProximalOptimizerErr::StartUnorderable =>
        "objective value at the start position is unorderable (possibly NaN)",
      ProximalOptimizerErr::SolutionNoBetter =>
        "candidate solution is no better than the starting position",
    };
    f.write_str(msg)
  }
}

/// Lets the error be boxed as `Box<dyn std::error::Error>` and propagated
/// with `?` through generic error-handling code.
impl std::error::Error for ProximalOptimizerErr {}
28
29/// Proximal Gradient Method (PGM), ported
30/// from `https://github.com/pmelchior/proxmin`
31///
32///  Proximal Gradient Method
33///
34///  Adapted from Combettes 2009, Algorithm 3.4.
35///  The accelerated version is Algorithm 3.6 with modifications
36///  from Xu & Yin (2015).
37///
38///  Args:
39///  - start_x: starting position
40///  - prox_f: proxed function f (the forward-backward step)
41///  - step_f: step size, < 1/L with L being the Lipschitz constant of grad f
42///  - accelerated: If Nesterov acceleration should be used
43///  - relax: (over)relaxation parameter, 0 < relax < 1.5
44///  - e_rel: relative error of X
45///  - max_iter: maximum iteration, irrespective of residual error
46///  - traceback: utils.Traceback to hold variable histories
47///
48///  Returns: A 3-tuple containing:
49///  - The optimized value for X
50///  - converged: whether the optimizer has converged within e_rel
51///  - error: X^it - X^it-1
52fn pgm<F>(start_x: &[f64],
53          prox_f: F,
54          step_f: &[f64],
55          accelerated: bool,
56          relax: Option<f64>,
57          e_rel: f64,
58          max_iter: usize)
59          -> Result<(Vec<f64>, bool, Vec<f64>), ProximalOptimizerErr>
60  where F: Fn(&[f64], &[f64]) -> Vec<f64>
61{
62  let mut stepper = NesterovStepper::new(accelerated);
63
64  if let Some(relax_val) = relax {
65    assert!(relax_val > 0.0);
66    assert!(relax_val < 1.5);
67  }
68
69  let mut X = Vec::from(&start_x[..]);
70  let mut X_ = vec![0.0; start_x.len()];
71
72  let mut it: usize = 0;
73  let mut converged: bool = false;
74  while it < max_iter {
75    let _X;
76
77    // use Nesterov acceleration (if omega > 0), automatically incremented
78    let omega = stepper.omega();
79    log::trace!("Omega: {}", &omega);
80    if omega > 0.0 {
81      // In Python: _X = X + omega*(X - X_)
82      let tmp1 = vec_sub(&X[..], &X_[..])?;
83      let tmp2 = vec_mul_scalar(&tmp1[..], omega);
84      _X = vec_add(&X[..], &tmp2[..])?;
85    } else {
86      _X = X.clone();
87    }
88
89    log::trace!("_X: {:?}", &_X);
90
91    X_ = X.clone();
92
93    X = prox_f(&_X[..], step_f);
94
95    log::trace!("X: {:?}", &X);
96
97    if let Some(relax_val) = relax {
98      // In Python: X += (relax-1)*(X - X_)
99      let tmp1 = relax_val - 1.0;
100      let tmp2 = vec_sub(&X[..], &X_[..])?;
101      let tmp3 = vec_mul_scalar(&tmp2[..], tmp1);
102    }
103
104    // test for fixed point convergence
105    // In Python: converged = utils.l2sq(X - X_) <= e_rel**2*utils.l2sq(X)
106    let tmp1 = vec_sub(&X[..], &X_[..])?;
107    let left = utils::l2sq(&tmp1[..]);
108    let right = utils::l2sq(&X[..]) * e_rel * e_rel;
109    converged = left <= right;
110    if converged {
111      break;
112    }
113    it += 1;
114  }
115
116  log::info!("Completed {} iterations", it + 1);
117  if !converged {
118    log::warn!("Solution did not converge");
119  }
120
121  let error = vec_sub(&X[..], &X_[..])?;
122
123  return Ok((X, converged, error));
124}