mnist/mnist.rs

#[macro_use]
extern crate alumina as al;
extern crate ndarray;

use std::path::Path;
use std::sync::Arc;
use std::cell::Cell;
use al::graph::{GraphDef, Result};
use al::id::NodeTag;
use al::ops::nn::linear::Linear;
use al::ops::shape::avg_pool::AvgPool;
use al::ops::nn::bias::Bias;
use al::ops::nn::conv::Conv;
use al::ops::activ::tanh::Tanh;
use al::ops::activ::spline::Spline;
use al::ops::activ::softmax::Softmax;
use al::ops::loss::cross_entropy::CrossEntropy;
use al::ops::loss::prediction::Prediction;
use al::opt::{Opt, UnboxedCallbacks, CallbackSignal, max_steps, every_n_steps};

#[allow(unused_imports)]
use al::opt::sgd::Sgd;
use al::opt::adam::Adam;

use al::data::mnist::Mnist;
use al::data::*;
use ndarray::ArrayD;

fn main() {
	learn_mnist().unwrap();
}

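/// Builds the network, then trains it with Adam: a smoothed training error is printed
/// every 100 steps, validation runs once per epoch, and training stops after 10 epochs.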
fn learn_mnist() -> Result<()> {
	let g = mnist_tanh_800(1.0e-3)?;
	//let g = mnist_lenet(1.0e-3)?;

	let batch_size = 16;

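	// Stream shuffled minibatches from the training set, keeping up to 32 batches
	// buffered ahead of the solver.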
	let data = Mnist::training("D:/ML/Mnist");
	let epoch = data.length();
	let mut data_stream = data
		.shuffle_random()
		.batch(batch_size)
		.buffered(32);

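	// An Adam optimiser over every parameter node in the graph.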
	let mut solver = Adam::new(&g)?
		.rate(1e-2)
		.beta1(0.9)
		.beta2(0.995);

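	// Generate initial values for the parameters, using the initialisers attached to each op.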
	let params = g.initialise_nodes(solver.parameters())?;

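	// Once per epoch: run validation on the test set, print the epoch count,
	// and stop after 10 epochs.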
	let mut validation = validation(&g)?;
	solver.add_boxed_callback(every_n_steps(epoch/batch_size, Box::new(move |data| {
		validation(data.params);
		CallbackSignal::Continue
	})));
	let mut i = 0;
	solver.add_boxed_callback(every_n_steps(epoch/batch_size, Box::new(move |_| {
		i += 1;
		println!("epoch:{}", i);
		CallbackSignal::Continue
	})));
	solver.add_boxed_callback(max_steps(10 * epoch/batch_size));

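	// Track an exponential moving average of the per-example training error.
	// Arc<Cell<f32>> lets the updating and printing callbacks share one value; the
	// initial 2.3 is roughly ln(10), the cross-entropy of an untrained 10-class model.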
	let avg_err1 = Arc::new(Cell::new(2.3));
	let avg_err2 = avg_err1.clone();
	solver.add_callback(move |data| {
		let new_avg_err = 0.95 * avg_err1.get() + 0.05 * data.err / batch_size as f32;
		avg_err1.set(new_avg_err);
		CallbackSignal::Continue
	});
	solver.add_boxed_callback(every_n_steps(100, Box::new(move |_data| {
		println!("err: {}", avg_err2.get());
		CallbackSignal::Continue
	})));

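	// Run the optimisation loop; it returns the trained parameter values once a
	// callback (here max_steps) signals a stop.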
	let _params = solver.optimise_from(&mut data_stream, params).unwrap();

	Ok(())
}

fn validation(g: &GraphDef) -> Result<Box<dyn FnMut(&[ArrayD<f32>])>> {
	let data = Mnist::testing(Path::new("D:/ML/Mnist")); // I mean, who doesn't validate on the test set!
	let epoch = data.length();
	let batch_size = 100;
	let mut data_stream = data
		.sequential()
		.batch(batch_size)
		.buffered(32);

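	// Build a subgraph whose inputs are the input images, the labels, and every
	// parameter value, and which evaluates only the "prediction_loss" node.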
	let inputs: Vec<_> = [g.node_id("input"), g.node_id("labels")].iter()
		.chain(&g.node_ids(NodeTag::Parameter))
		.map(|node_id| node_id.value_id()).collect();
	let prediction_loss = g.node_id("prediction_loss");
	let mut subgraph = g.subgraph(&inputs, &[prediction_loss.value_id()])?;

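	// The returned closure runs one full pass over the test set with the supplied
	// parameter values and prints the classification error as a percentage.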
	Ok(Box::new(move |parameters: &[ArrayD<f32>]| {
		let mut err = 0.0;
		for _ in 0..epoch/batch_size {
			let mut inputs = data_stream.next();
			inputs.extend(parameters.iter().cloned());
			let storage = subgraph.execute(inputs).expect("Could not execute validation");
			let err_vec = storage.get(&prediction_loss.value_id()).unwrap();
			err += err_vec.scalar_sum();
		}

		println!("Validation error is: {}%", 100.0 * err / epoch as f32);
	}))
}

/// A common MNIST network with two hidden layers of 800 units and tanh activation functions.
/// The regularisation strength is currently unused.
#[allow(unused)]
fn mnist_tanh_800(_regularise: f32) -> Result<GraphDef> {
	let mut g = GraphDef::new();

	let input = g.new_node(shape![Unknown, 28, 28, 1], "input", tag![])?;
	let labels = g.new_node(shape![Unknown, 10], "labels", tag![])?;

	let layer1 = g.new_node(shape![Unknown, 800], "layer1", tag![])?;
	let layer1_activ = g.new_node(shape![Unknown, 800], "layer1_activ", tag![])?;

	let layer2 = g.new_node(shape![Unknown, 800], "layer2", tag![])?;
	let layer2_activ = g.new_node(shape![Unknown, 800], "layer2_activ", tag![])?;

	let prediction = g.new_node(shape![Unknown, 10], "prediction", tag![])?;
	let softmax = g.new_node(shape![Unknown, 10], "softmax", tag![])?;

	let prediction_loss = g.new_node(shape![Unknown], "prediction_loss", tag![])?;

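	// Wire up input -> 800 -> 800 -> 10, with a bias and a tanh after each hidden linear layer.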
	g.new_op(Linear::new(&input, &layer1).init(Linear::msra(1.0)), tag![])?;
	g.new_op(Bias::new(&layer1), tag![])?;
	g.new_op(Tanh::new(&layer1, &layer1_activ), tag![])?;

	g.new_op(Linear::new(&layer1_activ, &layer2).init(Linear::msra(1.0)), tag![])?;
	g.new_op(Bias::new(&layer2), tag![])?;
	g.new_op(Tanh::new(&layer2, &layer2_activ), tag![])?;

	g.new_op(Linear::new(&layer2_activ, &prediction).init(Linear::msra(1.0)), tag![])?;
	g.new_op(Softmax::new(&prediction, &softmax), tag![])?;
	g.new_op(CrossEntropy::new(&softmax, &labels), tag![])?;

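	// Prediction compares the network output with the labels along the class axis, giving
	// the per-example classification error that validation sums into an error percentage.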
	g.new_op(Prediction::new(&prediction, &labels, &prediction_loss).axes(&[-1]), tag![])?;

	Ok(g)
}

/// Based on the LeNet variant described at http://luizgh.github.io/libraries/2015/12/08/getting-started-with-lasagne/
/// The activation used is the non-traditional Spline; the regularisation strength is currently unused.
#[allow(unused)]
fn mnist_lenet(_regularise: f32) -> Result<GraphDef> {
	let mut g = GraphDef::new();

	let input = g.new_node(shape![Unknown, 28, 28, 1], "input", tag![])?;
	let labels = g.new_node(shape![Unknown, 10], "labels", tag![])?;

	let c1 = 6;
	let layer1 = g.new_node(shape![Unknown, Unknown, Unknown, c1], "layer1", tag![])?;
	let layer1_activ = g.new_node(shape![Unknown, Unknown, Unknown, c1], "layer1_activ", tag![])?;
	let layer1_pool = g.new_node(shape![Unknown, Unknown, Unknown, c1], "layer1_pool", tag![])?;

	let c2 = 10;
	let layer2 = g.new_node(shape![Unknown, Unknown, Unknown, c2], "layer2", tag![])?;
	let layer2_activ = g.new_node(shape![Unknown, Unknown, Unknown, c2], "layer2_activ", tag![])?;
	let layer2_pool = g.new_node(shape![Unknown, 7, 7, c2], "layer2_pool", tag![])?;

	let c3 = 32;
	let layer3 = g.new_node(shape![Unknown, c3], "layer3", tag![])?;
	let layer3_activ = g.new_node(shape![Unknown, c3], "layer3_activ", tag![])?;

	let prediction = g.new_node(shape![Unknown, 10], "prediction", tag![])?;
	let softmax = g.new_node(shape![Unknown, 10], "softmax", tag![])?;

	let prediction_loss = g.new_node(shape![Unknown], "prediction_loss", tag![])?;

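	// Two conv -> bias -> spline -> average-pool blocks reduce the 28x28 input to 7x7;
	// the pooling factors are per axis: batch 1, height 2, width 2, channels 1.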
	g.new_op(Conv::new(&input, &layer1, &[5, 5]).init(Conv::msra(1.0)), tag![])?;
	g.new_op(Bias::new(&layer1), tag![])?;
	g.new_op(Spline::new(&layer1, &layer1_activ).init(Spline::swan()), tag![])?;
	g.new_op(AvgPool::new(&layer1_activ, &layer1_pool, &[1, 2, 2, 1]), tag![])?;

	g.new_op(Conv::new(&layer1_pool, &layer2, &[5, 5]).init(Conv::msra(1.0)), tag![])?;
	g.new_op(Bias::new(&layer2), tag![])?;
	g.new_op(Spline::new(&layer2, &layer2_activ).init(Spline::swan()), tag![])?;
	g.new_op(AvgPool::new(&layer2_activ, &layer2_pool, &[1, 2, 2, 1]), tag![])?;

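	// Fully connected layers map the pooled 7x7 feature maps to 32 units, then to the 10 classes.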
	g.new_op(Linear::new(&layer2_pool, &layer3).init(Linear::msra(1.0)), tag![])?;
	g.new_op(Bias::new(&layer3), tag![])?;
	g.new_op(Spline::new(&layer3, &layer3_activ).init(Spline::swan()), tag![])?;

	g.new_op(Linear::new(&layer3_activ, &prediction).init(Linear::msra(1.0)), tag![])?;
	g.new_op(Softmax::new(&prediction, &softmax), tag![])?;
	g.new_op(CrossEntropy::new(&softmax, &labels), tag![])?;

	g.new_op(Prediction::new(&prediction, &labels, &prediction_loss).axes(&[-1]), tag![])?;

	Ok(g)
}