// supervised_classification.rs
use train_station::{
    gradtrack::clear_all_graphs_known,
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[allow(clippy::duplicate_mod)]
#[path = "../neural_networks/feedforward_network.rs"]
mod feedforward_network;
use feedforward_network::{FeedForwardConfig, FeedForwardNetwork};
22
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24 let mut total_sq = 0.0f32;
25 for p in parameters.iter() {
26 if let Some(g) = p.grad_owned() {
27 for &v in g.data() {
28 total_sq += v * v;
29 }
30 }
31 }
32 let norm = total_sq.sqrt();
33 if norm > max_norm {
34 let scale = max_norm / (norm + eps);
35 for p in parameters.iter_mut() {
36 if let Some(g) = p.grad_owned() {
37 p.set_grad(g.mul_scalar(scale));
38 }
39 }
40 }
41}
42
43fn cross_entropy_logits(
45 logits: &Tensor,
46 labels: &[usize],
47 batch: usize,
48 _num_classes: usize,
49) -> Tensor {
50 let max_logits = logits.max_dims(&[1], true);
52 let shifted = logits.sub_tensor(&max_logits);
53 let exp = shifted.exp();
54 let sum_exp = exp.sum_dims(&[1], true);
55 let log_sum_exp = sum_exp.log();
56 let log_softmax = shifted.sub_tensor(&log_sum_exp);
57 let ll = log_softmax.gather(1, labels, &[batch, 1]); ll.mul_scalar(-1.0).mean()
59}
60
61fn accuracy_from_logits(
62 logits: &Tensor,
63 labels: &[usize],
64 batch: usize,
65 num_classes: usize,
66) -> f32 {
67 let row = logits.data();
68 let mut correct = 0usize;
69 for (i, &label) in labels.iter().enumerate().take(batch) {
70 let base = i * num_classes;
71 let mut best_j = 0usize;
72 let mut best_v = row[base];
73 for j in 1..num_classes {
74 let v = row[base + j];
75 if v > best_v {
76 best_v = v;
77 best_j = j;
78 }
79 }
80 if best_j == label {
81 correct += 1;
82 }
83 }
84 correct as f32 / batch as f32
85}
86
87pub fn main() -> Result<(), Box<dyn std::error::Error>> {
88 println!("=== Supervised Classification Example (Cross-Entropy) ===");
89
90 let n = 1200usize;
92 let classes = 3usize;
93 let mut xs: Vec<f32> = Vec::with_capacity(n * 2);
94 let mut ys: Vec<usize> = Vec::with_capacity(n);
95
96 let mut state: u64 = 424242;
98 let mut rand_f32 = || {
99 state = state.wrapping_mul(1664525).wrapping_add(1013904223);
100 (state >> 16) as f32 / (u32::MAX as f32)
101 };
102
103 for _ in 0..n {
104 let x1 = rand_f32() * 4.0 - 2.0;
105 let x2 = rand_f32() * 4.0 - 2.0;
106 let mut c = if x1 + 0.5 * x2 > 0.5 {
108 0
109 } else if x1 - x2 < -0.5 {
110 1
111 } else {
112 2
113 };
114 if rand_f32() < 0.05 {
115 c = (c + 1) % classes;
116 }
117 xs.push(x1);
118 xs.push(x2);
119 ys.push(c);
120 }
121
122 let mut min1 = f32::INFINITY;
124 let mut max1 = f32::NEG_INFINITY;
125 let mut min2 = f32::INFINITY;
126 let mut max2 = f32::NEG_INFINITY;
127 for i in (0..xs.len()).step_by(2) {
128 let a = xs[i];
129 let b = xs[i + 1];
130 if a < min1 {
131 min1 = a;
132 }
133 if a > max1 {
134 max1 = a;
135 }
136 if b < min2 {
137 min2 = b;
138 }
139 if b > max2 {
140 max2 = b;
141 }
142 }
143 let rng1 = (max1 - min1).max(1e-8);
144 let rng2 = (max2 - min2).max(1e-8);
145 for i in (0..xs.len()).step_by(2) {
146 let a = xs[i];
147 let b = xs[i + 1];
148 xs[i] = 2.0 * (a - min1) / rng1 - 1.0;
149 xs[i + 1] = 2.0 * (b - min2) / rng2 - 1.0;
150 }
151
152 let n_train = (n as f32 * 0.8) as usize;
154 let x_train = Tensor::from_slice(&xs[..n_train * 2], vec![n_train, 2]).unwrap();
155 let y_train = ys[..n_train].to_vec();
156 let x_val = Tensor::from_slice(&xs[n_train * 2..], vec![n - n_train, 2]).unwrap();
157 let y_val = ys[n_train..].to_vec();
158
159 let cfg = FeedForwardConfig {
161 input_size: 2,
162 hidden_sizes: vec![64, 64],
163 output_size: classes,
164 use_bias: true,
165 };
166 let mut net = FeedForwardNetwork::new(cfg, Some(303));
167
168 let mut opt = Adam::with_learning_rate(1e-3);
170 for p in net.parameters() {
171 opt.add_parameter(p);
172 }
173
174 let epochs = 300usize;
175 let max_grad_norm = 1.0f32;
176 let mut best_val_acc = 0.0f32;
177 let mut best_val_loss = f32::INFINITY;
178
179 for e in 0..epochs {
180 {
182 let mut params = net.parameters();
183 opt.zero_grad(&mut params);
184 }
185
186 let logits = net.forward(&x_train);
188 let mut loss = cross_entropy_logits(&logits, &y_train, n_train, classes);
189 loss.backward(None);
190
191 {
193 let params = net.parameters();
194 let mut with_grads: Vec<&mut Tensor> = Vec::new();
195 for p in params {
196 if p.grad_owned().is_some() {
197 with_grads.push(p);
198 }
199 }
200 if !with_grads.is_empty() {
201 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
202 opt.step(&mut with_grads);
203 opt.zero_grad(&mut with_grads);
204 }
205 }
206
207 let train_acc = accuracy_from_logits(&logits, &y_train, n_train, classes);
209 let val_logits = net.forward(&x_val);
210 let val_loss = cross_entropy_logits(&val_logits, &y_val, n - n_train, classes).value();
211 let val_acc = accuracy_from_logits(&val_logits, &y_val, n - n_train, classes);
212 if val_acc > best_val_acc {
213 best_val_acc = val_acc;
214 }
215 if val_loss < best_val_loss {
216 best_val_loss = val_loss;
217 }
218
219 if e % 10 == 0 || e + 1 == epochs {
220 println!(
221 "epoch {:4} | loss={:.4} acc={:.3} | val_loss={:.4} val_acc={:.3} | best_val_acc={:.3}",
222 e, loss.value(), train_acc, val_loss, val_acc, best_val_acc
223 );
224 }
225
226 clear_all_graphs_known();
227 }
228
229 let samples = Tensor::from_slice(&[-1.0, -1.0, 0.0, 0.0, 1.0, 1.0], vec![3, 2]).unwrap();
231 let sm = net.forward(&samples).softmax(1);
232 println!("sample class probs: {:?}", sm.data());
233
234 println!("=== Supervised classification finished ===");
235 Ok(())
236}