Function normalize_dataset

Source
pub fn normalize_dataset<T: TaskLabelType + Copy>(
    dataset: &mut Dataset<T>,
    scaler: ScalerType,
)
Examples found in repository:
examples/02_logistic_regression.rs (line 14)
10fn main() -> std::io::Result<()> {
11    let path = ".data/MobilePhonePricePredict/train.csv";
12        let mut dataset: Dataset<usize> = Dataset::<usize>::from_name(path, DatasetName::MobilePhonePricePredictDataset, None);
13
14        normalize_dataset(&mut dataset, ScalerType::Standard);
15
16        let mut res = dataset.split_dataset(vec![0.8, 0.2], 0);
17        let (train_dataset, test_dataset) = (res.remove(0), res.remove(0));
18        
19        let mut rng = RandGenerator::new(0);
20
21        // initing weights is also important, or the gradients may not be good
22        let mut model = LogisticRegression::new(train_dataset.feature_len(), train_dataset.class_num(), Some(Penalty::RidgeL2(1e-1)),|item| item.iter_mut().for_each(move |i| {*i = rng.gen_f32()}));
23
24        let mut train_dataloader = Dataloader::new(train_dataset, 8, true, None);
25
26        const EPOCH: usize = 3000;
27        let mut best_acc = vec![];
28        for ep in 0..EPOCH {
29            let mut losses = vec![];
30            let lr = match ep {
31                i if i < 200 => 1e-3,
32                _ => 2e-3,
33            };
34            for (feature, label) in train_dataloader.iter_mut() {
35                let loss = model.one_step(&feature, &label, lr, Some(NormType::L2(1.0)));
36                losses.push(loss);
37            }
38            let (_, acc) = evaluate(&test_dataset, &model);
39            best_acc.push(acc);
40            let width = ">".repeat(ep * 100 / EPOCH);
41            print!("\r{width:-<100}\t{:.3}\t{acc:.3}", losses.iter().sum::<f32>() / losses.len() as f32);
42            stdout().flush()?;
43        }
44        let acc = best_acc.iter().fold(0.0, |s, i| f32::max(s, *i));
45        let best_ep = argmax(&NdArray::new(best_acc), 0);
46        println!("\nbest acc = {acc} ep {}", best_ep[0][0]);
47        assert!(acc > 0.9); // gradient clip greatly helps it
48
49        Ok(())
50}