proximal_optimize/lib.rs
//! A port of the proximal gradient method from `https://github.com/pmelchior/proxmin`.
//!
#![allow(unused)]
#![allow(non_snake_case)]
#![allow(non_upper_case_globals)]

mod examples;
pub(crate) mod misc;
pub mod utils;

use crate::{
    misc::*,
    utils::NesterovStepper,
};

#[derive(Copy, Clone, Debug)]
pub enum ProximalOptimizerErr {
    /// Caused when the length of a parameter vector does not match the number
    /// of parameters specified when the optimizer was created.
    ParameterLengthMismatch,
    /// The objective function's value at the start position could not be compared
    /// with itself (perhaps a `NaN` condition?).
    StartUnorderable,
    /// The candidate solution is no better than the starting position, and
    /// may actually be worse.
    SolutionNoBetter,
}

/// Proximal Gradient Method (PGM), ported
/// from `https://github.com/pmelchior/proxmin`.
///
/// Adapted from Combettes 2009, Algorithm 3.4.
/// The accelerated version is Algorithm 3.6 with modifications
/// from Xu & Yin (2015).
///
/// Args:
/// - `start_x`: starting position
/// - `prox_f`: proxed function f (the forward-backward step)
/// - `step_f`: step size, < 1/L with L the Lipschitz constant of grad f
/// - `accelerated`: whether Nesterov acceleration should be used
/// - `relax`: optional (over)relaxation parameter, 0 < relax < 1.5
/// - `e_rel`: relative error of X
/// - `max_iter`: maximum number of iterations, irrespective of residual error
///
/// Returns (on success) a 3-tuple containing:
/// - the optimized value for X
/// - `converged`: whether the optimizer converged within `e_rel`
/// - `error`: X^(it) - X^(it-1)
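///
/// # Example
///
/// A minimal sketch (not part of the original port, with an assumed `b` target
/// and `lambda` weight): for f(x) = 0.5*||x - b||^2 + lambda*||x||_1, the
/// forward-backward step is a gradient step on the smooth part followed by
/// soft-thresholding for the l1 term.
///
/// ```ignore
/// let b = [1.0, -2.0];
/// let lambda = 0.1;
/// let prox_f = |x: &[f64], step: &[f64]| -> Vec<f64> {
///     x.iter().zip(step.iter()).zip(b.iter())
///         .map(|((&xi, &si), &bi)| {
///             let y = xi - si * (xi - bi);                    // gradient step
///             y.signum() * (y.abs() - si * lambda).max(0.0)   // soft-threshold
///         })
///         .collect()
/// };
/// let (x, converged, _err) =
///     pgm(&[0.0, 0.0], prox_f, &[0.5, 0.5], false, None, 1e-6, 1000).unwrap();
/// ```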
fn pgm<F>(start_x: &[f64],
          prox_f: F,
          step_f: &[f64],
          accelerated: bool,
          relax: Option<f64>,
          e_rel: f64,
          max_iter: usize)
          -> Result<(Vec<f64>, bool, Vec<f64>), ProximalOptimizerErr>
    where F: Fn(&[f64], &[f64]) -> Vec<f64>
{
    let mut stepper = NesterovStepper::new(accelerated);

    if let Some(relax_val) = relax {
        assert!(relax_val > 0.0);
        assert!(relax_val < 1.5);
    }

    let mut X = Vec::from(&start_x[..]);
    let mut X_ = vec![0.0; start_x.len()];

    let mut it: usize = 0;
    let mut converged: bool = false;
    while it < max_iter {
        let _X;

        // use Nesterov acceleration (if omega > 0), automatically incremented
        let omega = stepper.omega();
        log::trace!("Omega: {}", &omega);
        if omega > 0.0 {
            // In Python: _X = X + omega*(X - X_)
            let tmp1 = vec_sub(&X[..], &X_[..])?;
            let tmp2 = vec_mul_scalar(&tmp1[..], omega);
            _X = vec_add(&X[..], &tmp2[..])?;
        } else {
            _X = X.clone();
        }

        log::trace!("_X: {:?}", &_X);

        X_ = X.clone();

        X = prox_f(&_X[..], step_f);

        log::trace!("X: {:?}", &X);

        if let Some(relax_val) = relax {
            // In Python: X += (relax-1)*(X - X_)
            let tmp1 = vec_sub(&X[..], &X_[..])?;
            let tmp2 = vec_mul_scalar(&tmp1[..], relax_val - 1.0);
            X = vec_add(&X[..], &tmp2[..])?;
        }

        // test for fixed point convergence
        // In Python: converged = utils.l2sq(X - X_) <= e_rel**2 * utils.l2sq(X)
        let diff = vec_sub(&X[..], &X_[..])?;
        converged = utils::l2sq(&diff[..]) <= e_rel * e_rel * utils::l2sq(&X[..]);

        it += 1;
        if converged {
            break;
        }
    }

    log::info!("Completed {} iterations", it);
    if !converged {
        log::warn!("Solution did not converge");
    }

    let error = vec_sub(&X[..], &X_[..])?;

    Ok((X, converged, error))
}
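
// A minimal usage sketch added as a test (hypothetical, not from the original
// proxmin port): minimize f(x) = 0.5*||x - b||^2 by letting `prox_f` be the
// plain gradient step x - step*(x - b); the fixed point is x = b.
#[cfg(test)]
mod pgm_sketch_tests {
    use super::*;

    #[test]
    fn pgm_converges_on_a_simple_quadratic() {
        let b = [1.0, -2.0, 3.0];
        // forward-backward step with no non-smooth term: just the gradient step
        let prox_f = |x: &[f64], step: &[f64]| -> Vec<f64> {
            x.iter()
                .zip(step.iter())
                .zip(b.iter())
                .map(|((&xi, &si), &bi)| xi - si * (xi - bi))
                .collect()
        };
        let start_x = [0.0, 0.0, 0.0];
        let step_f = [0.5, 0.5, 0.5]; // < 1/L with L = 1 for this quadratic
        let (x, converged, _err) =
            pgm(&start_x, prox_f, &step_f, false, None, 1e-6, 1000).unwrap();
        assert!(converged);
        for (xi, bi) in x.iter().zip(b.iter()) {
            assert!((xi - bi).abs() < 1e-3);
        }
    }
}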