1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#![allow(unused)]
#![allow(non_snake_case)]
#![allow(non_upper_case_globals)]
mod examples;
pub(crate) mod misc;
pub mod utils;
use crate::{
misc::*,
utils::NesterovStepper,
};
/// Error conditions reported by the proximal optimizers in this crate.
///
/// Variant order is significant only for `as` casts; do not reorder.
#[derive(Debug, Clone, Copy)]
pub enum ProximalOptimizerErr {
    /// Presumably raised when two vectors that must share a length do not
    /// (e.g. by the element-wise vector helpers) — confirm against `misc`.
    ParameterLengthMismatch,
    /// Not used in this chunk; the name suggests the starting point could not
    /// be ordered/compared — TODO confirm at the raising site.
    StartUnorderable,
    /// Not used in this chunk; the name suggests the computed solution was not
    /// an improvement over the start — TODO confirm at the raising site.
    SolutionNoBetter,
}
fn pgm<F>(start_x: &[f64],
prox_f: F,
step_f: &[f64],
accelerated: bool,
relax: Option<f64>,
e_rel: f64,
max_iter: usize)
-> Result<(Vec<f64>, bool, Vec<f64>), ProximalOptimizerErr>
where F: Fn(&[f64], &[f64]) -> Vec<f64>
{
let mut stepper = NesterovStepper::new(accelerated);
if let Some(relax_val) = relax {
assert!(relax_val > 0.0);
assert!(relax_val < 1.5);
}
let mut X = Vec::from(&start_x[..]);
let mut X_ = vec![0.0; start_x.len()];
let mut it: usize = 0;
let mut converged: bool = false;
while it < max_iter {
let _X;
let omega = stepper.omega();
log::trace!("Omega: {}", &omega);
if omega > 0.0 {
let tmp1 = vec_sub(&X[..], &X_[..])?;
let tmp2 = vec_mul_scalar(&tmp1[..], omega);
_X = vec_add(&X[..], &tmp2[..])?;
} else {
_X = X.clone();
}
log::trace!("_X: {:?}", &_X);
X_ = X.clone();
X = prox_f(&_X[..], step_f);
log::trace!("X: {:?}", &X);
if let Some(relax_val) = relax {
let tmp1 = relax_val - 1.0;
let tmp2 = vec_sub(&X[..], &X_[..])?;
let tmp3 = vec_mul_scalar(&tmp2[..], tmp1);
}
let tmp1 = vec_sub(&X[..], &X_[..])?;
let left = utils::l2sq(&tmp1[..]);
let right = utils::l2sq(&X[..]) * e_rel * e_rel;
converged = left <= right;
if converged {
break;
}
it += 1;
}
log::info!("Completed {} iterations", it + 1);
if !converged {
log::warn!("Solution did not converge");
}
let error = vec_sub(&X[..], &X_[..])?;
return Ok((X, converged, error));
}