Trait NNBackPropagation

pub trait NNBackPropagation {
    // Required methods
    fn forward(&mut self, input: &NdArray, required_grad: bool) -> NdArray;
    fn forward_as_borrow(&self, input: &NdArray) -> NdArray;
    fn backward(&mut self, bp_grad: NdArray) -> NdArray;

    // Provided methods
    fn step(
        &mut self,
        _reduction: usize,
        _lr: f32,
        _gradient_clip_by_norm: Option<NormType>,
    ) { ... }
    fn weight_mut_borrow(&mut self) -> &mut [f32] { ... }
}
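
As a sketch of the contract, a no-op identity layer could implement the trait as follows. This is illustrative only and not taken from the crate; it assumes NdArray implements Clone.

struct Identity;

impl NNBackPropagation for Identity {
    // No parameters, so there is no gradient graph to record even
    // when required_grad is true; delegate to the immutable path.
    fn forward(&mut self, input: &NdArray, _required_grad: bool) -> NdArray {
        self.forward_as_borrow(input)
    }

    fn forward_as_borrow(&self, input: &NdArray) -> NdArray {
        input.clone() // assumes NdArray: Clone
    }

    // Identity passes the upstream gradient through unchanged;
    // there are no grad_w/grad_b to fill in.
    fn backward(&mut self, bp_grad: NdArray) -> NdArray {
        bp_grad
    }
}

Since step and weight_mut_borrow are provided methods, a parameterless layer like this does not need to override them.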

Required Methods

fn forward(&mut self, input: &NdArray, required_grad: bool) -> NdArray

Forward pass of the network.

  • required_grad: if set to true, the necessary gradient graph is saved; this is also why the method takes &mut self
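
A minimal sketch of the two modes; model and feature stand in for any implementor and input:

// Training-time call: records the gradient graph, hence the &mut borrow.
let logits = model.forward(&feature, true);

// Inference-time call: nothing is saved for backpropagation.
let logits = model.forward(&feature, false);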

fn forward_as_borrow(&self, input: &NdArray) -> NdArray

Forward pass without building the gradient graph; it therefore only requires an immutable reference to self.
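
Because no graph is recorded, this suits evaluation code that only holds a shared reference. A hedged sketch; predict is our name, not a crate API:

fn predict<M: NNBackPropagation>(model: &M, input: &NdArray) -> NdArray {
    // Only an immutable borrow is needed, so this works wherever
    // the model cannot be (or should not be) borrowed mutably.
    model.forward_as_borrow(input)
}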

fn backward(&mut self, bp_grad: NdArray) -> NdArray

Calculates the gradients and saves them to grad_w and grad_b.
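
In a training loop the call sits between the loss and the optimizer step, mirroring examples/06_mlp.rs; model, criterion, feature, and label are placeholders from that example:

let logits = model.forward(&feature, true);   // graph recorded during forward
let grad = criterion.forward(logits, &label); // initial bp_grad from the loss
model.backward(grad);                         // gradients cached in grad_w/grad_b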

Provided Methods

fn step(
    &mut self,
    _reduction: usize,
    _lr: f32,
    _gradient_clip_by_norm: Option<NormType>,
)

Updates the weights and biases with grad_w and grad_b, respectively.

  • reduction:
    • len(batch): average the gradients over the batch
    • 1: sum the gradients over the batch
  • lr: learning rate
  • gradient_clip_by_norm: gradient clipping by NormType; the default is None
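
A condensed sketch of the call; the values are illustrative, and the full context appears in the repository example below:

// Average the gradients over a batch of 64, with learning rate 1e-1
// and gradient clipping by L2 norm at 1.0.
model.step(64, 1e-1, Some(NormType::L2(1.0)));

// Sum the gradients instead of averaging, without clipping.
model.step(1, 1e-1, None);
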
Examples found in repository
examples/06_mlp.rs (line 44)
 9  fn main() -> std::io::Result<()> {
10
11      let path = ".data/TianchiCarPriceRegression/train_5w.csv";
12      let dataset = Dataset::<f32>::from_name(path, DatasetName::CarPriceRegressionDataset, None);
13      let mut res = dataset.split_dataset(vec![0.8, 0.2], 0);
14      let (train_dataset, test_dataset) = (res.remove(0), res.remove(0));
15
16      let blocks = vec![
17          NNmodule::Linear(train_dataset.feature_len(), 512, Some(Penalty::RidgeL2(1e-1))),
18          NNmodule::Tanh,
19          NNmodule::Linear(512, 128, Some(Penalty::RidgeL2(1e-1))),
20          NNmodule::Tanh,
21          NNmodule::Linear(128, 128, Some(Penalty::RidgeL2(1e-1))),
22          NNmodule::Relu,
23          NNmodule::Linear(128, 16, Some(Penalty::RidgeL2(1e-1))),
24          NNmodule::Tanh,
25          NNmodule::Linear(16, 1, None)
26      ];
27      let mut model = NeuralNetwork::new(blocks);
28      model.weight_init(None);
29      let mut criterion = MeanSquaredError::new();
30
31      let mut train_dataloader = Dataloader::new(train_dataset, 64, true, Some(0));
32      let mut test_dataloader = Dataloader::new(test_dataset, 128, false, None);
33
34      const EPOCH: usize = 10;
35      let mut error_records = vec![];
36
37      for ep in 0..EPOCH {
38          let mut losses = vec![];
39          let start = Instant::now();
40          for (feature, label) in train_dataloader.iter_mut() {
41              let logits = model.forward(&feature, true);
42              let grad = criterion.forward(logits, &label);
43              model.backward(grad);
44              model.step(label.len(), 1e-1, Some(NormType::L2(1.0)));
45              losses.push(criterion.avg_loss);
46          }
47          let train_time = Instant::now() - start;
48          let start = Instant::now();
49          let mean_abs_error = evaluate_regression(&mut test_dataloader, &model);
50          let test_time = Instant::now() - start;
51          error_records.push(mean_abs_error);
52          let width = ">".repeat(ep * 50 / EPOCH);
53          print!("\r{width:-<50}\t{:.3}\t{mean_abs_error:.3}\t", losses.iter().sum::<f32>() / losses.len() as f32);
54          println!("\ntime cost train {:.3} test {:.3}", train_time.as_secs_f64(), test_time.as_secs_f64());
55          stdout().flush()?;
56      }
57      let (best_ep, best_error) = error_records.iter().enumerate().fold((0, f32::MAX), |s, (i, e)| {
58          if *e < s.1 {
59              (i, *e)
60          } else {
61              s
62          }
63      });
64      println!("\n{error_records:?}\nbest ep {best_ep} best mean abs error {best_error:.5}");
65
66      Ok(())
67  }

fn weight_mut_borrow(&mut self) -> &mut [f32]

Mutably borrows the raw weight data.
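
A sketch of direct weight manipulation through the raw slice; the constant fill is ours and purely illustrative:

for w in model.weight_mut_borrow().iter_mut() {
    *w = 0.01; // overwrite every weight with a small constant
}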

Implementors