/// Interface for a neural-network module that supports back-propagation:
/// a forward pass (optionally recording what `backward` needs), a backward
/// pass, and a parameter update via `step`.
pub trait NNBackPropagation {
// Required methods
/// Forward pass.
///
/// - `required_grad`: if `true`, the module saves the state needed to build the
///   gradient graph for a later `backward` call; this is why it takes `&mut self`.
fn forward(&mut self, input: &NdArray, required_grad: bool) -> NdArray;
/// Forward pass that does not record the gradient graph, so it only needs `&self`.
fn forward_as_borrow(&self, input: &NdArray) -> NdArray;
/// Backward pass.
///
/// NOTE(review): not documented on this page. In the repository example it is called
/// with the gradient returned by the loss criterion's `forward`, after
/// `forward(.., true)`. Presumably it returns the gradient with respect to the input;
/// confirm against the implementation.
fn backward(&mut self, bp_grad: NdArray) -> NdArray;
// Provided methods
/// Update the weights and biases with the accumulated gradients (grad_w / grad_b).
///
/// - `_reduction`: pass the batch length to average over the batch, or `1` to sum.
/// - `_lr`: learning rate.
/// - `_gradient_clip_by_norm`: optional gradient clipping by `NormType`; `None` is the default.
fn step(
&mut self,
_reduction: usize,
_lr: f32,
_gradient_clip_by_norm: Option<NormType>,
) { ... }
/// Mutably borrow the raw data of the weights as a flat `f32` slice.
fn weight_mut_borrow(&mut self) -> &mut [f32] { ... }
}
Required Methods
fn forward(&mut self, input: &NdArray, required_grad: bool) -> NdArray
Neural-network forward pass.
- required_grad: if set to true, the module saves the state needed to build the gradient graph for a later backward pass; this is also why the method takes &mut self.
fn forward_as_borrow(&self, input: &NdArray) -> NdArray
Forward pass without recording the gradient graph; it therefore only requires an immutable reference to self.
Provided Methods
fn step(&mut self, _reduction: usize, _lr: f32, _gradient_clip_by_norm: Option<NormType>)
Update the weights and biases using the gradients grad_w and grad_b, respectively.
- reduction:
  - len(batch): average the gradients over the batch
  - 1: sum the gradients over the batch
- lr: learning rate
- gradient_clip_by_norm: gradient clipping by NormType; the default is None (no clipping)
Examples found in the repository:
examples/06_mlp.rs (line 44)
9fn main() -> std::io::Result<()> {
10
11 let path = ".data/TianchiCarPriceRegression/train_5w.csv";
12 let dataset = Dataset::<f32>::from_name(path, DatasetName::CarPriceRegressionDataset, None);
13 let mut res = dataset.split_dataset(vec![0.8, 0.2], 0);
14 let (train_dataset, test_dataset) = (res.remove(0), res.remove(0));
15
16 let blocks = vec![
17 NNmodule::Linear(train_dataset.feature_len(), 512, Some(Penalty::RidgeL2(1e-1))),
18 NNmodule::Tanh,
19 NNmodule::Linear(512, 128, Some(Penalty::RidgeL2(1e-1))),
20 NNmodule::Tanh,
21 NNmodule::Linear(128, 128, Some(Penalty::RidgeL2(1e-1))),
22 NNmodule::Relu,
23 NNmodule::Linear(128, 16, Some(Penalty::RidgeL2(1e-1))),
24 NNmodule::Tanh,
25 NNmodule::Linear(16, 1, None)
26 ];
27 let mut model = NeuralNetwork::new(blocks);
28 model.weight_init(None);
29 let mut criterion = MeanSquaredError::new();
30
31 let mut train_dataloader = Dataloader::new(train_dataset, 64, true, Some(0));
32 let mut test_dataloader = Dataloader::new(test_dataset, 128, false, None);
33
34 const EPOCH: usize = 10;
35 let mut error_records = vec![];
36
37 for ep in 0..EPOCH {
38 let mut losses = vec![];
39 let start = Instant::now();
40 for (feature, label) in train_dataloader.iter_mut() {
41 let logits = model.forward(&feature, true);
42 let grad = criterion.forward(logits, &label);
43 model.backward(grad);
44 model.step(label.len(), 1e-1, Some(NormType::L2(1.0)));
45 losses.push(criterion.avg_loss);
46 }
47 let train_time = Instant::now() - start;
48 let start = Instant::now();
49 let mean_abs_error = evaluate_regression(&mut test_dataloader, &model);
50 let test_time = Instant::now() - start;
51 error_records.push(mean_abs_error);
52 let width = ">".repeat(ep * 50 / EPOCH);
53 print!("\r{width:-<50}\t{:.3}\t{mean_abs_error:.3}\t", losses.iter().sum::<f32>() / losses.len() as f32);
54 println!("\ntime cost train {:.3} test {:.3}", train_time.as_secs_f64(), test_time.as_secs_f64());
55 stdout().flush()?;
56 }
57 let (best_ep, best_error) = error_records.iter().enumerate().fold((0, f32::MAX), |s, (i, e)| {
58 if *e < s.1 {
59 (i, *e)
60 } else {
61 s
62 }
63 });
64 println!("\n{error_records:?}\nbest ep {best_ep} best mean abs error {best_error:.5}");
65
66 Ok(())
67}
fn weight_mut_borrow(&mut self) -> &mut [f32]
Mutably borrow the raw data of the weights.