1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
pub mod sgd;
pub mod cain;
use graph::*;
use supplier::Supplier;
/// Instruction a step callback returns to the optimiser that invoked it.
///
/// The common value-type traits are derived so a signal can be compared,
/// copied and debug-printed by callers.
///
/// NOTE(review): how each variant is acted upon is decided by the `Optimiser`
/// implementations, which are not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallbackSignal {
    /// Ask for optimisation to stop.
    Stop,
    /// Ask for optimisation to carry on.
    Continue,
}
/// Read-only snapshot handed to every step callback.
///
/// The borrowed fields live for `'a`, so a callback cannot keep them beyond
/// that lifetime.
pub struct CallbackData<'a> {
    /// Error value reported to the callback.
    /// NOTE(review): presumably the training error of the latest step —
    /// confirm against the optimiser implementations.
    pub err: f32,
    /// Step counter; compared against the limit in `max_steps`.
    pub step_count: u64,
    /// Evaluation counter; compared against the limit in `max_evals`.
    pub eval_count: u64,
    /// Shared reference to a `Graph`.
    /// NOTE(review): presumably the graph being optimised — confirm.
    pub graph: &'a Graph,
    /// Borrowed slice of parameter values.
    /// NOTE(review): presumably the current parameters — confirm.
    pub params: &'a [f32],
}
/// Builds a step callback that prints the error, step count and evaluation
/// count of each `CallbackData` it receives to stdout as one tab-separated
/// line of `label:value` pairs.
///
/// The returned callback never requests a stop; it always yields
/// `CallbackSignal::Continue`.
pub fn print_step_data() -> Box<FnMut(CallbackData)->CallbackSignal>{
    Box::new(move |data| {
        // Every field uses the same `label:value` form; the `eval_count` label
        // previously lacked its colon, fusing label and value in the output.
        println!(
            "err:{}\tstep_count:{}\teval_count:{}",
            data.err, data.step_count, data.eval_count
        );
        CallbackSignal::Continue
    })
}
/// Builds a step callback that enforces an evaluation budget.
///
/// The callback yields `CallbackSignal::Continue` while
/// `data.eval_count` is strictly below `max`, and `CallbackSignal::Stop`
/// once the count has reached or passed `max`.
pub fn max_evals(max: u64) -> Box<FnMut(CallbackData)->CallbackSignal>{
    Box::new(move |data| {
        // Guard clause: budget used up.
        if data.eval_count >= max {
            return CallbackSignal::Stop;
        }
        CallbackSignal::Continue
    })
}
/// Builds a step callback that caps the number of steps.
///
/// The callback yields `CallbackSignal::Stop` as soon as
/// `data.step_count` is at least `max`; below that it yields
/// `CallbackSignal::Continue`.
pub fn max_steps(max: u64) -> Box<FnMut(CallbackData)->CallbackSignal>{
    Box::new(move |data| {
        let limit_reached = data.step_count >= max;
        if limit_reached {
            CallbackSignal::Stop
        } else {
            CallbackSignal::Continue
        }
    })
}
/// Builds a step callback that stops once the error is no longer above a
/// threshold.
///
/// The callback yields `CallbackSignal::Continue` only while
/// `data.err > min` holds. In every other case it yields
/// `CallbackSignal::Stop`; because the test is a plain `>` on `f32`, that
/// includes a NaN error (or a NaN `min`), for which the comparison is false.
pub fn min_err(min: f32) -> Box<FnMut(CallbackData)->CallbackSignal>{
    Box::new(move |data| {
        // Keep the comparison in its `>` form so NaN falls through to Stop.
        let still_above_target = data.err > min;
        if !still_above_target {
            return CallbackSignal::Stop;
        }
        CallbackSignal::Continue
    })
}
/// Interface shared by optimisers that work with a `Graph` and a `Supplier`
/// of training data.
///
/// Implementors provide `get_graph`, `add_boxed_step_callback`,
/// `optimise_from` and `step`; `add_step_callback` and `optimise` come with
/// default implementations built on top of those.
pub trait Optimiser<'a> {
    /// Gives mutable access to a `Graph`; the default `optimise` uses it to
    /// obtain the initial parameters.
    fn get_graph(&mut self) -> &mut Graph;

    /// Registers an already boxed step callback.
    ///
    /// NOTE(review): when callbacks are invoked and how their
    /// `CallbackSignal` is honoured is up to the implementor — not visible
    /// in this file.
    fn add_boxed_step_callback(&mut self, func: Box<FnMut(CallbackData)->CallbackSignal>);

    /// Convenience wrapper: boxes `func` and passes it on to
    /// `add_boxed_step_callback`.
    fn add_step_callback<F>(&mut self, func: F)
    where
        F: 'static + FnMut(CallbackData)->CallbackSignal,
    {
        self.add_boxed_step_callback(Box::new(func));
    }

    /// Runs `optimise_from` starting from the parameters returned by
    /// `get_graph().init_params()`, and returns its result.
    fn optimise(&mut self, training_set: &mut Supplier) -> Vec<f32> {
        // Bind the parameters first so the mutable borrow taken by
        // `get_graph` has ended before `self` is borrowed again.
        let initial_params = self.get_graph().init_params();
        self.optimise_from(training_set, initial_params)
    }

    /// Optimises starting from the supplied `params` and returns a parameter
    /// vector.
    ///
    /// NOTE(review): presumably the returned vector holds the final
    /// parameters — confirm against the implementations.
    fn optimise_from(&mut self, training_set: &mut Supplier, params: Vec<f32>) -> Vec<f32>;

    /// Performs one step from `params`, returning an `f32` together with a
    /// parameter vector.
    ///
    /// NOTE(review): presumably `(error, updated parameters)` — confirm
    /// against the implementations.
    fn step(&mut self, training_set: &mut Supplier, params: Vec<f32>) -> (f32, Vec<f32>);
}