Function evaluate_regression

Source
pub fn evaluate_regression<'a, T: EvaluateArgTrait<'a, f32>>(
    dataset: T,
    model: &impl Model<f32>,
) -> f32
Expand description

Evaluates a regression model on a dataset and returns its mean absolute error.

  • dataset: &Dataset or &mut Dataloader<f32, &Dataset>
  • returns: the mean absolute error, as an f32
Examples found in repository?
examples/03_linear_regression.rs (line 27)
// Example: trains a `LinearRegression` for `EPOCH` passes over the training loader, calls
// `evaluate_regression` on the held-out dataset after every pass, and finally reports the
// pass with the lowest recorded error.
// NOTE(review): the number fused to the start of each line is the listing's line-number
// gutter from the original example file, not part of the Rust code.
9fn main() -> std::io::Result<()> {
10    let path = ".data/TianchiCarPriceRegression/train_5w.csv";
11        let dataset = Dataset::<f32>::from_name(path, DatasetName::CarPriceRegressionDataset, None);
          // presumably an 80% / 20% split with seed 0 — TODO confirm against `split_dataset`
12        let mut res = dataset.split_dataset(vec![0.8, 0.2], 0);
          // the second `remove(0)` yields what was originally at index 1 (assuming `res` is a `Vec`)
13        let (train_dataset, test_dataset) = (res.remove(0), res.remove(0));
14
15        let mut model = LinearRegression::new(train_dataset.feature_len(), Some(Penalty::RidgeL2(0.1)), |_| {});
16
          // `train_dataset` is moved into the loader; presumably batch size 64 with shuffling on — TODO confirm
17        let mut train_dataloader = Dataloader::new(train_dataset, 64, true, None);
18
19        const EPOCH: usize = 10;
          // one evaluation result per epoch, in epoch order
20        let mut error_records = vec![];
21        for ep in 0..EPOCH {
              // per-batch values returned by `one_step`, averaged for the progress line below
22            let mut losses = vec![];
23            for (feature, label) in train_dataloader.iter_mut() {
                  // presumably learning rate 1e-2 and an L2 gradient-norm limit of 1.0 — TODO confirm
24                let loss = model.one_step(&feature, &label, 1e-2, Some(NormType::L2(1.0)));
25                losses.push(loss);
26            }
              // the test split is passed by shared reference (`&Dataset`)
27            let mean_abs_error = evaluate_regression(&test_dataset, &model);
28            error_records.push(mean_abs_error);
              // progress bar: `ep * 50 / EPOCH` '>' characters, so at most 45 here since `ep` < EPOCH
29            let width = ">".repeat(ep * 50 / EPOCH);
              // `\r` rewrites the same terminal line; `{width:-<50}` left-aligns the bar and pads it
              // with '-' to 50 columns, followed by the mean of `losses` and the evaluation error
30            print!("\r{width:-<50}\t{:.3}\t{mean_abs_error:.3}", losses.iter().sum::<f32>() / losses.len() as f32);
              // no newline was printed, so flush explicitly
31            stdout().flush()?;
32        }
          // index and value of the smallest recorded error; strict `<` keeps the earliest epoch on ties
33        let (best_ep, best_error) = error_records.iter().enumerate().fold((0, f32::MAX), |s, (i, e)| {
34            if *e < s.1 {
35                (i, *e)
36            } else {
37                s
38            }
39        });
40        println!("\n{error_records:?}\nbest ep {best_ep} best mean abs error {best_error:.5}");
41
42        Ok(())
43}
More examples
Hide additional examples
examples/06_mlp.rs (line 49)
// Example: trains a `NeuralNetwork` built from a stack of `NNmodule` blocks for `EPOCH`
// passes, calls `evaluate_regression` on a test `Dataloader` after every pass while timing
// both phases, and finally reports the pass with the lowest recorded error.
// NOTE(review): the number fused to the start of each line is the listing's line-number
// gutter from the original example file, not part of the Rust code.
9fn main() -> std::io::Result<()> {
10
11    let path = ".data/TianchiCarPriceRegression/train_5w.csv";
12    let dataset = Dataset::<f32>::from_name(path, DatasetName::CarPriceRegressionDataset, None);
      // presumably an 80% / 20% split with seed 0 — TODO confirm against `split_dataset`
13    let mut res = dataset.split_dataset(vec![0.8, 0.2], 0);
      // the second `remove(0)` yields what was originally at index 1 (assuming `res` is a `Vec`)
14    let (train_dataset, test_dataset) = (res.remove(0), res.remove(0));
15
      // layer stack; `Linear` arguments look like (inputs, outputs, optional penalty) — TODO confirm.
      // Sizes chain feature_len -> 512 -> 128 -> 128 -> 16 -> 1, and the last layer has no penalty.
16    let blocks = vec![
17        NNmodule::Linear(train_dataset.feature_len(), 512, Some(Penalty::RidgeL2(1e-1))), 
18        NNmodule::Tanh,
19        NNmodule::Linear(512, 128, Some(Penalty::RidgeL2(1e-1))), 
20        NNmodule::Tanh,
21        NNmodule::Linear(128, 128, Some(Penalty::RidgeL2(1e-1))), 
22        NNmodule::Relu,
23        NNmodule::Linear(128, 16, Some(Penalty::RidgeL2(1e-1))), 
24        NNmodule::Tanh,
25        NNmodule::Linear(16, 1, None)
26    ];
27    let mut model = NeuralNetwork::new(blocks);
28    model.weight_init(None);
29    let mut criterion = MeanSquaredError::new();
30
      // both datasets are moved into their loaders; presumably (data, batch size, shuffle, seed) — TODO confirm
31    let mut train_dataloader = Dataloader::new(train_dataset, 64, true, Some(0));
32    let mut test_dataloader = Dataloader::new(test_dataset, 128, false, None);
33
34    const EPOCH: usize = 10;
      // one evaluation result per epoch, in epoch order
35    let mut error_records = vec![];
36
37    for ep in 0..EPOCH {
          // per-batch `criterion.avg_loss` values, averaged for the progress line below
38        let mut losses = vec![];
39        let start = Instant::now();
40        for (feature, label) in train_dataloader.iter_mut() {
41            let logits = model.forward(&feature, true);
              // the criterion's return value is what gets fed to `backward`
42            let grad = criterion.forward(logits, &label);
43            model.backward(grad);
              // presumably (batch size, learning rate, gradient-norm limit) — TODO confirm
44            model.step(label.len(), 1e-1, Some(NormType::L2(1.0)));
45            losses.push(criterion.avg_loss);
46        }
          // wall-clock duration of the training loop above
47        let train_time = Instant::now() - start;
          // `start` is shadowed to time the evaluation call separately
48        let start = Instant::now();
          // here the evaluation input is a mutable reference to a `Dataloader`, not a `&Dataset`
49        let mean_abs_error = evaluate_regression(&mut test_dataloader, &model);
50        let test_time = Instant::now() - start;
51        error_records.push(mean_abs_error);
          // progress bar: `ep * 50 / EPOCH` '>' characters, so at most 45 here since `ep` < EPOCH
52        let width = ">".repeat(ep * 50 / EPOCH);
          // `{width:-<50}` left-aligns the bar and pads it with '-' to 50 columns,
          // followed by the mean of `losses` and the evaluation error
53        print!("\r{width:-<50}\t{:.3}\t{mean_abs_error:.3}\t", losses.iter().sum::<f32>() / losses.len() as f32);
          // the leading "\n" ends the progress line, so each epoch's output stays on its own lines
54        println!("\ntime cost train {:.3} test {:.3}", train_time.as_secs_f64(), test_time.as_secs_f64());
55        stdout().flush()?;
56    }
      // index and value of the smallest recorded error; strict `<` keeps the earliest epoch on ties
57    let (best_ep, best_error) = error_records.iter().enumerate().fold((0, f32::MAX), |s, (i, e)| {
58        if *e < s.1 {
59            (i, *e)
60        } else {
61            s
62        }
63    });
64    println!("\n{error_records:?}\nbest ep {best_ep} best mean abs error {best_error:.5}");
65
66    Ok(())
67}