1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
// This example demonstrates building a simple multi-layer perceptron
// in a functional style and reusing the resulting computation graph.

// The graph gets created by running all model operations once.
// It is then repeatedly fed with new inputs and recomputed.

// This approach offers the best performance as memory for intermediate
// operations doesn't need to be reallocated for repeated executions of the graph.
// It also makes it easy to save & load the model.

// Note that control-flow statements executed during the graph's construction
// are *baked in* and won't be re-evaluated in subsequent runs.
// If you need that kind of flexibility in your model, please check out the
// 'perceptron_eager' example!

use microtensor::{prelude::*, Tensor, Variable};

/// Builds a fully connected layer on top of `input` with `size` output units.
///
/// The weight matrix has shape `[last dim of input, size]` and is drawn from
/// `Tensor::randn`, scaled down by `size`. The bias starts at zero. Both are
/// registered as trainable via `.trained()`.
/// Returns `input · weights + bias`.
fn dense_layer(input: &Variable<f32>, size: usize) -> Variable<f32> {
  // Fan-in is the trailing dimension of the incoming activations.
  let fan_in = input.shape()[-1];
  let initial_weights = Tensor::randn(&[fan_in, size]) / size as f32;
  let weights = initial_weights.trained();
  let bias = Tensor::zeros(&[size]).trained();
  let projected = input.mm(&weights);
  projected + bias
}

/// The model: one hidden dense layer of 16 ReLU units, followed by a dense
/// output layer of 10 units squashed through a sigmoid.
fn perceptron(input: &Variable<f32>) -> Variable<f32> {
  let hidden = dense_layer(input, 16).relu();
  // Pre-activation scores of the output layer.
  let scores = dense_layer(&hidden, 10);
  scores.sigmoid()
}

/// Builds the perceptron graph once on placeholder inputs, then trains it.
///
/// Each step re-feeds the placeholders, re-runs the forward and backward
/// passes, and applies a plain gradient-descent update to every trained
/// parameter.
fn main() {
  // Shapes shared by the graph placeholders and the data fed into them on
  // every step. Defining them once keeps both sides in sync.
  const BATCH_SIZE: usize = 32;
  const INPUT_SIZE: usize = 28 * 28;
  // Must match the output width hard-coded in `perceptron`.
  const NUM_CLASSES: usize = 10;
  const TRAINING_STEPS: usize = 100;

  // Define model by performing all computations on a placeholder once
  let image_input = Tensor::zeros(&[BATCH_SIZE, INPUT_SIZE]).tracked();
  let output = perceptron(&image_input);

  // Define the loss to be minimized: squared difference between labels and
  // predictions, averaged with `mean(0)`
  let label_input = Tensor::zeros(&[BATCH_SIZE, NUM_CLASSES]).tracked();
  let loss = (&label_input - &output).sqr().mean(0);

  // Train with some labeled samples
  let learning_rate = 0.01;
  for _ in 0..TRAINING_STEPS {
    // Insert real training data here
    let images = Tensor::ones(&[BATCH_SIZE, INPUT_SIZE]).tracked();
    // Random class indices in [0, NUM_CLASSES), one-hot encoded
    let labels = (Tensor::rand(&[BATCH_SIZE]) * NUM_CLASSES as f32)
      .cast::<u8>()
      .one_hot(NUM_CLASSES);

    // Feed existing computation graph with new inputs
    image_input.feed(&images);
    label_input.feed(&labels);

    // Recompute output and loss
    loss.forward();

    // Compute gradients
    loss.backward();

    // Minimize loss by updating model parameters (plain gradient descent)
    for mut param in loss.parameters() {
      param -= param
        .grad()
        .expect("trained parameter should have a gradient after backward()")
        * learning_rate;
    }

    // Reset gradients before the next step
    loss.reset();
  }
}