auto-diff 0.5.9

A neural network library in Rust.
Documentation
use tensor_rs::tensor::{PaddingMode};
use auto_diff::op::{Linear, OpCall, Conv2d};
use auto_diff::optim::{SGD, MiniBatch};
use auto_diff::Var;
use rand::prelude::*;
use ::rand::prelude::StdRng;
extern crate openblas_src;

//use tensorboard_rs::summary_writer::SummaryWriter;

mod mnist;
use mnist::{load_images, load_labels};

/// Trains a small convolutional network on MNIST with mini-batch SGD.
///
/// The IDX files under `examples/data/mnist/` are loaded, the computation
/// graph is built once on a first training mini-batch, and afterwards new
/// batches are swapped into the same `input`/`label` variables and the graph
/// is re-run. Every 10th iteration one test mini-batch is evaluated and its
/// loss and accuracy are printed.
///
/// Every fallible call is unwrapped, so any error reported by the library
/// aborts the program with a panic.
fn main() {
    let train_img = load_images("examples/data/mnist/train-images-idx3-ubyte");
    let test_img = load_images("examples/data/mnist/t10k-images-idx3-ubyte");
    let train_label = load_labels("examples/data/mnist/train-labels-idx1-ubyte");
    let test_label = load_labels("examples/data/mnist/t10k-labels-idx1-ubyte");

    // Insert a channel axis of size 1: [n, h, w] -> [n, 1, h, w].
    let train_size = train_img.size();
    let train_data = train_img
        .reshape(&vec![train_size[0], 1, train_size[1], train_size[2]])
        .unwrap();

    let test_size = test_img.size();
    let test_data = test_img
        .reshape(&vec![test_size[0], 1, test_size[1], test_size[2]])
        .unwrap();

    train_data.reset_net();
    train_label.reset_net();
    test_data.reset_net();
    test_label.reset_net();

    // Mini-batch size. The flattening `view` below and the `MiniBatch`
    // sampler must agree on it, so both read this single value.
    let patch_size = 16;

    // Model, as constructed below:
    //   conv(1 -> 32, 3x3) - relu - conv(32 -> 64, 3x3, second tuple (2,2)) - relu
    //   - view [patch_size, 14*14*64] - linear(-> 14*14) - relu - linear(-> 10)
    // NOTE(review): the 14*14 spatial size assumes 28x28 inputs and that the
    // (2,2) tuple is the stride -- confirm against the Conv2d documentation.
    let mut rng = StdRng::seed_from_u64(671);

    let mut op1 = Conv2d::new(1, 32, (3,3), (1,1), (1,1), (1,1), true, PaddingMode::Zeros);
    op1.set_weight(Var::normal(&mut rng, &op1.weight().size(), 0., 1.));
    op1.set_bias(Var::normal(&mut rng, &op1.bias().size(), 0., 1.));

    let mut op2 = Conv2d::new(32, 64, (3,3), (2,2), (1,1), (1,1), true, PaddingMode::Zeros);
    op2.set_weight(Var::normal(&mut rng, &op2.weight().size(), 0., 1.));
    op2.set_bias(Var::normal(&mut rng, &op2.bias().size(), 0., 1.));

    let mut op3 = Linear::new(Some(14*14*64), Some(14*14), true);
    op3.set_weight(Var::normal(&mut rng, &[14*14*64, 14*14], 0., 1.));
    op3.set_bias(Var::normal(&mut rng, &[14*14, ], 0., 1.));

    let mut op4 = Linear::new(Some(14*14), Some(10), true);
    op4.set_weight(Var::normal(&mut rng, &[14*14, 10], 0., 1.));
    op4.set_bias(Var::normal(&mut rng, &[10, ], 0., 1.));

    // The sampler gets its own generator, seeded like the one used for the
    // parameter initialisation above.
    let rng = StdRng::seed_from_u64(671);
    let mut minibatch = MiniBatch::new(rng, patch_size);

    // Build the graph once on a first training batch. `input` and `label`
    // are overwritten with later batches via `set` and the graph is re-run.
    let (input, label) = minibatch.next(&train_data, &train_label).unwrap();

    let output1 = op1.call(&[&input]).unwrap().pop().unwrap();
    let output1_1 = output1.relu().unwrap();
    let output2 = op2.call(&[&output1_1]).unwrap().pop().unwrap();
    let output2_1 = output2.relu().unwrap().view(&[patch_size, 14*14*64]).unwrap();
    let output3 = op3.call(&[&output2_1]).unwrap().pop().unwrap();
    let output3_1 = output3.relu().unwrap();
    let output = op4.call(&[&output3_1]).unwrap().pop().unwrap();

    let loss = output.cross_entropy_loss(&label).unwrap();

    let lr = 0.1;
    let mut opt = SGD::new(lr);

    println!("{:?}", loss);

    for i in 1..900 {
        println!("index: {}", i);

        // Feed the next training batch into the existing graph.
        let (input_next, label_next) = minibatch.next(&train_data, &train_label).unwrap();
        input.set(&input_next);
        label.set(&label_next);
        println!("load data done");

        loss.rerun().unwrap();
        println!("rerun");
        loss.bp().unwrap();
        println!("bp");
        loss.step(&mut opt).unwrap();
        println!("step");

        if i % 10 == 0 {
            // Evaluate on one test mini-batch: forward pass only, no
            // back-propagation and no optimiser step.
            let (test_input, test_batch_label) = minibatch.next(&test_data, &test_label).unwrap();
            input.set(&test_input);
            label.set(&test_batch_label);
            loss.rerun().unwrap();

            println!("test loss: {:?}", loss);

            // Compare the predictions with the labels of THIS batch. The
            // full `test_label` set does not line up with the `patch_size`
            // predictions held by `output`.
            let accuracy = output
                .clone()
                .argmax(Some(&[1]), false)
                .unwrap()
                .eq_elem(&test_batch_label)
                .unwrap()
                .mean(None, false);
            // Mean of the element-wise equality, i.e. the fraction of
            // predictions that match -- an accuracy, not an error rate.
            println!("test accuracy: {:?}", accuracy);
        }
    }
}