1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use std;
use ndarray::{Array1, Array2};
use ndarray_linalg::Inverse;
use errors::*;
use prelude::*;
use problem::ArgminProblem;
use result::ArgminResult;
use termination::TerminationReason;
/// Newton's method solver: holds the user-facing configuration together with
/// the state of the run that is currently in progress (if any).
pub struct Newton<'a> {
    /// Factor the Newton step (inverse Hessian times gradient) is multiplied
    /// with before it is subtracted from the current parameter vector.
    gamma: f64,
    /// Iteration count at which the termination check reports
    /// `TerminationReason::MaxItersReached`.
    max_iters: u64,
    /// State of the current run; `None` until `init` has stored one.
    state: Option<NewtonState<'a>>,
}
/// Everything a single run of the Newton solver needs to carry from one
/// iteration to the next.
struct NewtonState<'a> {
    /// Problem being solved; its `gradient` and `hessian` entries are
    /// evaluated at `param` in every iteration.
    problem: &'a ArgminProblem<'a, Array1<f64>, f64, Array2<f64>>,
    /// Current parameter vector (the iterate).
    param: Array1<f64>,
    /// Number of iterations performed so far.
    iter: u64,
}
impl<'a> NewtonState<'a> {
    /// Builds the state for a fresh run.
    ///
    /// `problem` is kept by reference, `param` becomes the first iterate and
    /// the iteration counter starts at zero.
    pub fn new(
        problem: &'a ArgminProblem<'a, Array1<f64>, f64, Array2<f64>>,
        param: Array1<f64>,
    ) -> Self {
        // No iteration has been performed yet.
        let iter = 0;
        NewtonState {
            problem,
            param,
            iter,
        }
    }
}
impl<'a> Newton<'a> {
    /// Creates a solver with default settings: a step factor (`gamma`) of
    /// `1.0`, an iteration limit of `u64::MAX` and no run state.
    ///
    /// The solver has to be initialised with a problem and a starting point
    /// before it can iterate.
    pub fn new() -> Self {
        Newton {
            gamma: 1.0,
            max_iters: std::u64::MAX,
            state: None,
        }
    }

    /// Sets the iteration count at which the solver reports
    /// `TerminationReason::MaxItersReached`.
    ///
    /// Returns `&mut Self` so that configuration calls can be chained.
    pub fn max_iters(&mut self, max_iters: u64) -> &mut Self {
        self.max_iters = max_iters;
        self
    }

    /// Sets the factor the Newton step is multiplied with in every iteration.
    ///
    /// The default of `1.0` takes the full Newton step; smaller positive
    /// values give a damped iteration. The value is stored as given, without
    /// validation.
    ///
    /// Returns `&mut Self` so that configuration calls can be chained.
    pub fn gamma(&mut self, gamma: f64) -> &mut Self {
        self.gamma = gamma;
        self
    }
}
impl<'a> ArgminSolver<'a> for Newton<'a> {
    /// Parameter vector the solver iterates on.
    type Parameter = Array1<f64>;
    /// Type of the cost function value.
    type CostValue = f64;
    /// Type of the Hessian matrix.
    type Hessian = Array2<f64>;
    /// A single starting point, given as a parameter vector.
    type StartingPoints = Self::Parameter;
    /// The problem is handed over by reference.
    type ProblemDefinition = &'a ArgminProblem<'a, Self::Parameter, Self::CostValue, Self::Hessian>;

    /// Starts a new run on `problem` from `init_param`.
    ///
    /// `init_param` is cloned into the solver. Any state of a previous run is
    /// replaced, which also resets the iteration counter to zero.
    fn init(
        &mut self,
        problem: Self::ProblemDefinition,
        init_param: &Self::StartingPoints,
    ) -> Result<()> {
        self.state = Some(NewtonState::new(problem, init_param.clone()));
        Ok(())
    }

    /// Performs one Newton step, `param <- param - gamma * H(param)^-1 * g(param)`.
    ///
    /// The returned result carries the new parameter vector, the iteration
    /// count and the outcome of the termination check. The cost function is
    /// not evaluated here, so the reported cost value is `NaN`.
    ///
    /// # Errors
    ///
    /// Returns an error if inverting the Hessian fails. The solver state is
    /// left as it was before the call in that case, so the solver remains
    /// usable.
    ///
    /// # Panics
    ///
    /// Panics if `init` has not been called, or if the problem provides no
    /// gradient or no Hessian.
    fn next_iter(&mut self) -> Result<ArgminResult<Self::Parameter, Self::CostValue>> {
        // Take ownership of the state so that `param` can be moved into the
        // update expression below; it is put back on every return path.
        let mut state = self
            .state
            .take()
            .expect("Newton::next_iter called before Newton::init");
        let g = (state
            .problem
            .gradient
            .expect("Newton's method requires a gradient"))(&state.param);
        let hessian = (state
            .problem
            .hessian
            .expect("Newton's method requires a Hessian"))(&state.param);
        let h_inv = match hessian.inv() {
            Ok(h_inv) => h_inv,
            Err(e) => {
                // Restore the untouched state before propagating the error;
                // otherwise it would be dropped and every later call would
                // panic on the missing state.
                self.state = Some(state);
                return Err(From::from(e));
            }
        };
        state.param = state.param - self.gamma * h_inv.dot(&g);
        state.iter += 1;
        let mut out = ArgminResult::new(state.param.clone(), std::f64::NAN, state.iter);
        // The termination check reads `self.state`, so the state has to be
        // stored again before it runs.
        self.state = Some(state);
        out.set_termination_reason(self.terminate());
        Ok(out)
    }

    // Termination criterion: the iteration counter has reached `max_iters`.
    // NOTE(review): presumably this macro generates the `terminate` method
    // called in `next_iter` - confirm against the macro definition.
    make_terminate!(self,
        self.state.as_ref().unwrap().iter >= self.max_iters, TerminationReason::MaxItersReached;
    );

    // NOTE(review): presumably generates the `run` method that drives
    // `init`/`next_iter` until termination - confirm against the macro
    // definition.
    make_run!(
        Self::ProblemDefinition,
        Self::StartingPoints,
        Self::Parameter,
        Self::CostValue
    );
}
impl<'a> Default for Newton<'a> {
fn default() -> Self {
Self::new()
}
}