1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use errors::*;
use ndarray::Array1;
/// Why [`BacktrackingLineSearch::run`] stopped.
///
/// Derives `Clone`, `Copy`, `PartialEq` and `Eq` so callers can compare and
/// copy the reason cheaply (e.g. `reason == TerminationReason::Converged`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TerminationReason {
    /// The iteration limit was hit before the stopping condition was met.
    MaxNumberIterations,
    /// The stopping condition was satisfied.
    Converged,
    /// Reason not determined.
    ///
    /// NOTE(review): the variant name is misspelled ("Unkown"); it is kept
    /// as-is because renaming it would break existing callers.
    Unkown,
}
pub struct BacktrackingLineSearch<'a> {
cost_function: &'a Fn(&Array1<f64>) -> f64,
gradient: &'a Fn(&Array1<f64>) -> Array1<f64>,
alpha: f64,
max_iters: u64,
tau: f64,
c: f64,
}
impl<'a> BacktrackingLineSearch<'a> {
    /// Creates a line search for the given cost function and its gradient.
    ///
    /// Defaults: `alpha = 1.0`, `max_iters = 100`, `tau = 0.5`, `c = 0.5`.
    pub fn new(
        cost_function: &'a dyn Fn(&Array1<f64>) -> f64,
        gradient: &'a dyn Fn(&Array1<f64>) -> Array1<f64>,
    ) -> Self {
        BacktrackingLineSearch {
            cost_function,
            gradient,
            alpha: 1.0,
            max_iters: 100,
            tau: 0.5,
            c: 0.5,
        }
    }

    /// Sets the initial step length tried by [`run`](Self::run).
    ///
    /// NOTE(review): the value is not validated; a non-positive or NaN
    /// `alpha` makes the search meaningless.
    pub fn alpha(&mut self, alpha: f64) -> &mut Self {
        self.alpha = alpha;
        self
    }

    /// Sets the maximum number of step-length reductions.
    ///
    /// With `max_iters = 0` only the initial step length is tested.
    pub fn max_iters(&mut self, max_iters: u64) -> &mut Self {
        self.max_iters = max_iters;
        self
    }

    /// Sets the Armijo sufficient-decrease constant.
    ///
    /// # Errors
    /// Returns `ErrorKind::InvalidParameter` unless `0 < c < 1`; NaN is rejected.
    pub fn c(&mut self, c: f64) -> Result<&mut Self> {
        // `is_nan` is needed: NaN fails every ordered comparison, so a plain
        // `c <= 0.0 || c >= 1.0` test would silently accept it.
        if c.is_nan() || c <= 0.0 || c >= 1.0 {
            return Err(ErrorKind::InvalidParameter(
                "BacktrackingLineSearch: Parameter `c` must satisfy 0 < c < 1.".into(),
            )
            .into());
        }
        self.c = c;
        Ok(self)
    }

    /// Sets the factor by which the step length is shrunk on each reduction.
    ///
    /// # Errors
    /// Returns `ErrorKind::InvalidParameter` unless `0 < tau < 1`; NaN is rejected.
    pub fn tau(&mut self, tau: f64) -> Result<&mut Self> {
        // See `c`: the explicit NaN check closes the hole left by ordered comparisons.
        if tau.is_nan() || tau <= 0.0 || tau >= 1.0 {
            return Err(ErrorKind::InvalidParameter(
                "BacktrackingLineSearch: Parameter `tau` must satisfy 0 < tau < 1.".into(),
            )
            .into());
        }
        self.tau = tau;
        Ok(self)
    }

    /// Runs the backtracking search from point `x` along direction `p`.
    ///
    /// A step length `alpha` is accepted when
    /// `f(x) - f(x + alpha * p) >= -c * alpha * (p · ∇f(x))`.
    /// Otherwise `alpha` is multiplied by `tau`, at most `max_iters` times.
    ///
    /// Returns `(alpha, reductions, reason)`:
    /// - `alpha` is the last step length tested. On `MaxNumberIterations` it
    ///   did NOT satisfy the test.
    /// - `reductions` is how many times `alpha` was shrunk (never exceeds `max_iters`).
    /// - `reason` is `Converged` or `MaxNumberIterations`.
    ///
    /// NOTE(review): `p` is assumed to be a descent direction
    /// (`p · ∇f(x) < 0`). If it is not, the test can pass without any decrease in `f`.
    pub fn run(&self, p: &Array1<f64>, x: &Array1<f64>) -> Result<(f64, u64, TerminationReason)> {
        // Directional derivative of the cost along `p`.
        let m: f64 = p.dot(&(self.gradient)(x));
        // Required decrease per unit of step length.
        let t = -self.c * m;
        let fx = (self.cost_function)(x);

        let mut alpha = self.alpha;
        let mut idx: u64 = 0;
        loop {
            let param = x + &(alpha * p);
            if fx - (self.cost_function)(&param) >= alpha * t {
                return Ok((alpha, idx, TerminationReason::Converged));
            }
            // `>=` (not `>`) so at most `max_iters` reductions are performed.
            if idx >= self.max_iters {
                return Ok((alpha, idx, TerminationReason::MaxNumberIterations));
            }
            idx += 1;
            alpha *= self.tau;
        }
    }
}