alumina/opt/
mod.rs

1pub mod sgd;
2pub mod adam;
3
4use graph::{GraphDef, Subgraph, Result};
5use id::{NodeID, DataID};
6use data::DataStream;
7use ndarray::ArrayD;
8
/// Signal returned by an optimisation callback after each step.
///
/// `Opt::optimise_from` keeps stepping while every callback returns `Continue`
/// and stops once any callback returns `Stop`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CallbackSignal {
	/// Request that optimisation stop after the current step.
	Stop,
	/// Allow optimisation to proceed to another step.
	Continue,
}
13
14pub struct CallbackData<'a>{
15	pub err: f32,
16	pub step: usize,
17	pub change_norm: f32,
18	pub params: &'a [ArrayD<f32>],
19	pub stream: &'a DataStream,
20}
21
22pub trait Opt {
23
24	/// Borrows subgraph
25	fn subgraph(&self) -> &Subgraph;
26
27	/// This is the list of inputs to the subgraph which will be fed from the supplied `DataStream`.
28	fn inputs(&self) -> &[DataID];
29
30	/// This is the list of value `NodeID`s which correspond to nodes marked `Parameter`,
31	/// excluding `NodeID`s which overlap with DataIDs included in `inputs()`
32	fn parameters(&self) -> &[NodeID];
33
34	/// Returns the error, step number, l2 norm of param change, and the new parameters
35	fn step(&mut self, inputs: Vec<ArrayD<f32>>, parameters: Vec<ArrayD<f32>>) -> Result<(f32, usize, f32, Vec<ArrayD<f32>>)>;
36
37	fn callbacks(&mut self) -> &mut [Box<FnMut(&CallbackData)->CallbackSignal>];
38
39	fn add_boxed_callback(&mut self, func: Box<FnMut(&CallbackData)->CallbackSignal>);
40
41	fn optimise(&mut self, training_stream: &mut DataStream, graph: &GraphDef) -> Result<Vec<ArrayD<f32>>>{
42		let params = graph.initialise_nodes(self.parameters())?;
43		self.optimise_from(training_stream, params)
44	}
45
46	fn optimise_from(&mut self, training_stream: &mut DataStream, mut params: Vec<ArrayD<f32>>) -> Result<Vec<ArrayD<f32>>>{
47		let mut stop = false;
48		while !stop {
49			let (err, step, change_norm, new_params) = self.step(training_stream.next(), params)?;
50			params = new_params;
51
52			let data = CallbackData{err: err, step: step, change_norm: change_norm, params: &params, stream: training_stream};
53			for func in self.callbacks().iter_mut(){
54				stop = stop | matches!(func(&data), CallbackSignal::Stop);
55			}
56		}
57		Ok(params)
58	}
59}
60
61pub trait UnboxedCallbacks: Opt {
62	fn add_callback<F: 'static + FnMut(&CallbackData)->CallbackSignal>(&mut self, func: F){
63		self.add_boxed_callback(Box::new(func));
64	}
65}
66
67impl<O: Opt> UnboxedCallbacks for O {}
68
69pub fn print_step_data() -> Box<FnMut(&CallbackData)->CallbackSignal>{
70	let mut step = 0;
71	Box::new(move |data|{
72		println!("step:{}\terr:{}", step, data.err);
73		step += 1;
74		CallbackSignal::Continue
75	})
76}
77
78pub fn max_steps(max: usize) -> Box<FnMut(&CallbackData)->CallbackSignal>{
79	let mut step = 0;
80	Box::new(move |_data|{
81		if step < max {
82			step += 1;
83			CallbackSignal::Continue
84		} else {
85			CallbackSignal::Stop
86		}
87	})
88}
89
90pub fn min_err(min: f32) -> Box<FnMut(&CallbackData)->CallbackSignal>{
91	Box::new(move |data|{
92		if data.err > min {
93			CallbackSignal::Continue
94		} else {
95			CallbackSignal::Stop
96		}
97	})
98}
99
100pub fn every_n_steps(n: usize, mut func: Box<FnMut(&CallbackData)->CallbackSignal>) -> Box<FnMut(&CallbackData)->CallbackSignal>{
101	let mut step = 0;
102	Box::new(move |data|{
103		if step % n == 0 {
104			step += 1;
105			func(data)
106		} else {
107			step += 1;
108			CallbackSignal::Continue
109		}
110	})
111}