#[macro_use]
extern crate alumina as al;
extern crate ndarray;

use std::path::Path;
use std::sync::Arc;
use std::cell::Cell;
use al::graph::{GraphDef, Result};
use al::id::NodeTag;
use al::ops::nn::linear::Linear;
use al::ops::shape::avg_pool::AvgPool;
use al::ops::nn::bias::Bias;
use al::ops::nn::conv::Conv;
use al::ops::activ::tanh::Tanh;
use al::ops::activ::spline::Spline;
use al::ops::activ::softmax::Softmax;
use al::ops::loss::cross_entropy::CrossEntropy;
use al::ops::loss::prediction::Prediction;
use al::opt::{Opt, UnboxedCallbacks, CallbackSignal, max_steps, every_n_steps};

#[allow(unused_imports)]
use al::opt::sgd::Sgd;
use al::opt::adam::Adam;

use al::data::mnist::Mnist;
use al::data::*;
use ndarray::ArrayD;

fn main() {
	learn_mnist().unwrap();
}

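/// Builds the two-layer tanh network, then trains it on MNIST with Adam,
/// printing a smoothed training error and running a validation pass once per
/// epoch.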
fn learn_mnist() -> Result<()> {
	let g = mnist_tanh_800(1.0e-3)?;
	let batch_size = 16;

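	// Load the MNIST training set and turn it into a stream of shuffled,
	// batched, buffered input arrays.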
	let data = Mnist::training(Path::new("D:/ML/Mnist"));
	let epoch = data.length();
	let mut data_stream = data
		.shuffle_random()
		.batch(batch_size)
		.buffered(32);

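	// Adam with a fixed learning rate; beta1 and beta2 set the decay of the
	// first- and second-moment estimates.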
	let mut solver = Adam::new(&g)?
		.rate(1e-2)
		.beta1(0.9)
		.beta2(0.995);

	let params = g.initialise_nodes(solver.parameters())?;

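	// Once per epoch: run a validation pass, print the epoch count, and stop
	// entirely after ten epochs.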
	let mut validation = validation(&g)?;
	solver.add_boxed_callback(every_n_steps(epoch/batch_size, Box::new(move |data| {
		validation(data.params);
		CallbackSignal::Continue
	})));
	let mut i = 0;
	solver.add_boxed_callback(every_n_steps(epoch/batch_size, Box::new(move |_| {
		i += 1;
		println!("epoch: {}", i);
		CallbackSignal::Continue
	})));
	solver.add_boxed_callback(max_steps(10 * epoch/batch_size));

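	// Track an exponential moving average of the per-example training error.
	// It starts at 2.3 ≈ ln(10), the cross-entropy of a uniform guess over ten
	// classes, and the Arc<Cell<..>> is shared between the updating callback
	// and the printing callback below.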
	let avg_err1 = Arc::new(Cell::new(2.3));
	let avg_err2 = avg_err1.clone();
	solver.add_callback(move |data| {
		let new_avg_err = 0.95 * avg_err1.get() + 0.05 * data.err / batch_size as f32;
		avg_err1.set(new_avg_err);
		CallbackSignal::Continue
	});
	solver.add_boxed_callback(every_n_steps(100, Box::new(move |_data| {
		println!("err: {}", avg_err2.get());
		CallbackSignal::Continue
	})));

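	// Run the optimisation loop; it returns the final parameter values once
	// the max_steps callback signals a stop.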
	let _params = solver.optimise_from(&mut data_stream, params)?;

	Ok(())
}

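/// Returns a closure that runs the MNIST test set through a loss-only
/// subgraph and prints the classification error for the supplied parameter
/// values.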
fn validation(g: &GraphDef) -> Result<Box<dyn FnMut(&[ArrayD<f32>])>> {
	let data = Mnist::testing(Path::new("D:/ML/Mnist"));
	let epoch = data.length();
	let batch_size = 100;
	let mut data_stream = data
		.sequential()
		.batch(batch_size)
		.buffered(32);

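	// The subgraph takes the input, the labels, and every parameter as
	// inputs, and produces only the per-example prediction loss.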
	let inputs: Vec<_> = [g.node_id("input"), g.node_id("labels")].iter()
		.chain(&g.node_ids(NodeTag::Parameter))
		.map(|node_id| node_id.value_id()).collect();
	let prediction_loss = g.node_id("prediction_loss");
	let mut subgraph = g.subgraph(&inputs, &[prediction_loss.value_id()])?;

	Ok(Box::new(move |parameters: &[ArrayD<f32>]| {
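		// Stream the whole test set through the subgraph, appending the
		// current parameter values to each batch and summing the losses.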
		let mut err = 0.0;
		for _ in 0..epoch/batch_size {
			let mut inputs = data_stream.next();
			inputs.extend(parameters.iter().cloned());
			let storage = subgraph.execute(inputs).expect("Could not execute validation");
			let err_vec = storage.get(&prediction_loss.value_id()).unwrap();
			err += err_vec.scalar_sum();
		}

		println!("Validation error is: {}%", 100.0 * err / epoch as f32);
	}))
}

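/// A small fully connected classifier: two 800-unit tanh layers followed by a
/// 10-way softmax trained with cross-entropy. The `regularise` argument is
/// currently unused.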
#[allow(unused)]
fn mnist_tanh_800(regularise: f32) -> Result<GraphDef> {
	let mut g = GraphDef::new();

	let input = g.new_node(shape![Unknown, 28, 28, 1], "input", tag![])?;
	let labels = g.new_node(shape![Unknown, 10], "labels", tag![])?;

	let layer1 = g.new_node(shape![Unknown, 800], "layer1", tag![])?;
	let layer1_activ = g.new_node(shape![Unknown, 800], "layer1_activ", tag![])?;

	let layer2 = g.new_node(shape![Unknown, 800], "layer2", tag![])?;
	let layer2_activ = g.new_node(shape![Unknown, 800], "layer2_activ", tag![])?;

	let prediction = g.new_node(shape![Unknown, 10], "prediction", tag![])?;
	let softmax = g.new_node(shape![Unknown, 10], "softmax", tag![])?;

	let prediction_loss = g.new_node(shape![Unknown], "prediction_loss", tag![])?;

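	// Wire the layers together. msra(1.0) scales the initial weights by
	// fan-in; the shapes above imply that Linear flattens the 28x28x1 input
	// into a single axis of 784.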
	g.new_op(Linear::new(&input, &layer1).init(Linear::msra(1.0)), tag![])?;
	g.new_op(Bias::new(&layer1), tag![])?;
	g.new_op(Tanh::new(&layer1, &layer1_activ), tag![])?;

	g.new_op(Linear::new(&layer1_activ, &layer2).init(Linear::msra(1.0)), tag![])?;
	g.new_op(Bias::new(&layer2), tag![])?;
	g.new_op(Tanh::new(&layer2, &layer2_activ), tag![])?;

	g.new_op(Linear::new(&layer2_activ, &prediction).init(Linear::msra(1.0)), tag![])?;
	g.new_op(Softmax::new(&prediction, &softmax), tag![])?;
	g.new_op(CrossEntropy::new(&softmax, &labels), tag![])?;

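	// Prediction compares the argmax of `prediction` with `labels` along the
	// last axis, producing the per-example classification error that
	// validation() sums into an error rate.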
	g.new_op(Prediction::new(&prediction, &labels, &prediction_loss).axes(&[-1]), tag![])?;

	Ok(g)
}

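/// A LeNet-style convolutional network: two 5x5 conv/pool stages with spline
/// activations, then a fully connected layer and a 10-way softmax readout.
/// As above, the `regularise` argument is currently unused.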
#[allow(unused)]
fn mnist_lenet(regularise: f32) -> Result<GraphDef> {
	let mut g = GraphDef::new();

	let input = g.new_node(shape![Unknown, 28, 28, 1], "input", tag![])?;
	let labels = g.new_node(shape![Unknown, 10], "labels", tag![])?;

	let c1 = 6;
	let layer1 = g.new_node(shape![Unknown, Unknown, Unknown, c1], "layer1", tag![])?;
	let layer1_activ = g.new_node(shape![Unknown, Unknown, Unknown, c1], "layer1_activ", tag![])?;
	let layer1_pool = g.new_node(shape![Unknown, Unknown, Unknown, c1], "layer1_pool", tag![])?;

	let c2 = 10;
	let layer2 = g.new_node(shape![Unknown, Unknown, Unknown, c2], "layer2", tag![])?;
	let layer2_activ = g.new_node(shape![Unknown, Unknown, Unknown, c2], "layer2_activ", tag![])?;
	let layer2_pool = g.new_node(shape![Unknown, 7, 7, c2], "layer2_pool", tag![])?;

	let c3 = 32;
	let layer3 = g.new_node(shape![Unknown, c3], "layer3", tag![])?;
	let layer3_activ = g.new_node(shape![Unknown, c3], "layer3_activ", tag![])?;

	let prediction = g.new_node(shape![Unknown, 10], "prediction", tag![])?;
	let softmax = g.new_node(shape![Unknown, 10], "softmax", tag![])?;

	let prediction_loss = g.new_node(shape![Unknown], "prediction_loss", tag![])?;

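	// Two conv/activation/pool stages followed by fully connected layers. The
	// AvgPool factors of [1, 2, 2, 1] halve the two spatial axes while leaving
	// batch and channels untouched (28x28 -> 14x14 -> 7x7), and the Spline
	// activations are initialised with the swan() preset.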
	g.new_op(Conv::new(&input, &layer1, &[5, 5]).init(Conv::msra(1.0)), tag![])?;
	g.new_op(Bias::new(&layer1), tag![])?;
	g.new_op(Spline::new(&layer1, &layer1_activ).init(Spline::swan()), tag![])?;
	g.new_op(AvgPool::new(&layer1_activ, &layer1_pool, &[1, 2, 2, 1]), tag![])?;

	g.new_op(Conv::new(&layer1_pool, &layer2, &[5, 5]).init(Conv::msra(1.0)), tag![])?;
	g.new_op(Bias::new(&layer2), tag![])?;
	g.new_op(Spline::new(&layer2, &layer2_activ).init(Spline::swan()), tag![])?;
	g.new_op(AvgPool::new(&layer2_activ, &layer2_pool, &[1, 2, 2, 1]), tag![])?;

	g.new_op(Linear::new(&layer2_pool, &layer3).init(Linear::msra(1.0)), tag![])?;
	g.new_op(Bias::new(&layer3), tag![])?;
	g.new_op(Spline::new(&layer3, &layer3_activ).init(Spline::swan()), tag![])?;

	g.new_op(Linear::new(&layer3_activ, &prediction).init(Linear::msra(1.0)), tag![])?;
	g.new_op(Softmax::new(&prediction, &softmax), tag![])?;
	g.new_op(CrossEntropy::new(&softmax, &labels), tag![])?;

	g.new_op(Prediction::new(&prediction, &labels, &prediction_loss).axes(&[-1]), tag![])?;

	Ok(g)
}