1pub mod sgd;
2pub mod adam;
3
4use graph::{GraphDef, Subgraph, Result};
5use id::{NodeID, DataID};
6use data::DataStream;
7use ndarray::ArrayD;
8
/// Decision returned by a training callback after each optimisation step.
///
/// The optimisation loop keeps running while every callback returns `Continue`
/// and stops once any callback returns `Stop`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallbackSignal {
    /// Request that optimisation ends after the current step.
    Stop,
    /// Allow optimisation to proceed to another step.
    Continue,
}
13
14pub struct CallbackData<'a>{
15 pub err: f32,
16 pub step: usize,
17 pub change_norm: f32,
18 pub params: &'a [ArrayD<f32>],
19 pub stream: &'a DataStream,
20}
21
22pub trait Opt {
23
24 fn subgraph(&self) -> &Subgraph;
26
27 fn inputs(&self) -> &[DataID];
29
30 fn parameters(&self) -> &[NodeID];
33
34 fn step(&mut self, inputs: Vec<ArrayD<f32>>, parameters: Vec<ArrayD<f32>>) -> Result<(f32, usize, f32, Vec<ArrayD<f32>>)>;
36
37 fn callbacks(&mut self) -> &mut [Box<FnMut(&CallbackData)->CallbackSignal>];
38
39 fn add_boxed_callback(&mut self, func: Box<FnMut(&CallbackData)->CallbackSignal>);
40
41 fn optimise(&mut self, training_stream: &mut DataStream, graph: &GraphDef) -> Result<Vec<ArrayD<f32>>>{
42 let params = graph.initialise_nodes(self.parameters())?;
43 self.optimise_from(training_stream, params)
44 }
45
46 fn optimise_from(&mut self, training_stream: &mut DataStream, mut params: Vec<ArrayD<f32>>) -> Result<Vec<ArrayD<f32>>>{
47 let mut stop = false;
48 while !stop {
49 let (err, step, change_norm, new_params) = self.step(training_stream.next(), params)?;
50 params = new_params;
51
52 let data = CallbackData{err: err, step: step, change_norm: change_norm, params: ¶ms, stream: training_stream};
53 for func in self.callbacks().iter_mut(){
54 stop = stop | matches!(func(&data), CallbackSignal::Stop);
55 }
56 }
57 Ok(params)
58 }
59}
60
61pub trait UnboxedCallbacks: Opt {
62 fn add_callback<F: 'static + FnMut(&CallbackData)->CallbackSignal>(&mut self, func: F){
63 self.add_boxed_callback(Box::new(func));
64 }
65}
66
67impl<O: Opt> UnboxedCallbacks for O {}
68
69pub fn print_step_data() -> Box<FnMut(&CallbackData)->CallbackSignal>{
70 let mut step = 0;
71 Box::new(move |data|{
72 println!("step:{}\terr:{}", step, data.err);
73 step += 1;
74 CallbackSignal::Continue
75 })
76}
77
78pub fn max_steps(max: usize) -> Box<FnMut(&CallbackData)->CallbackSignal>{
79 let mut step = 0;
80 Box::new(move |_data|{
81 if step < max {
82 step += 1;
83 CallbackSignal::Continue
84 } else {
85 CallbackSignal::Stop
86 }
87 })
88}
89
90pub fn min_err(min: f32) -> Box<FnMut(&CallbackData)->CallbackSignal>{
91 Box::new(move |data|{
92 if data.err > min {
93 CallbackSignal::Continue
94 } else {
95 CallbackSignal::Stop
96 }
97 })
98}
99
100pub fn every_n_steps(n: usize, mut func: Box<FnMut(&CallbackData)->CallbackSignal>) -> Box<FnMut(&CallbackData)->CallbackSignal>{
101 let mut step = 0;
102 Box::new(move |data|{
103 if step % n == 0 {
104 step += 1;
105 func(data)
106 } else {
107 step += 1;
108 CallbackSignal::Continue
109 }
110 })
111}