loss

Function loss 

Source
pub fn loss(
    train: &Vec<f32>,
    validation: &Vec<f32>,
    accuracy: &Vec<f32>,
    title: &str,
    path: &str,
)
Expand description

Plots the training loss, validation loss, and validation accuracy as a simple line plot titled `title`, and saves it to `path`.

§Arguments

  • train - The training loss.
  • validation - The validation loss.
  • accuracy - The validation accuracy.
  • title - The title of the plot.
  • path - The path to save the plot.
Examples found in repository?
examples/bike/plain.rs (lines 80-86)
43fn main() {
44    // Load the bike dataset
45    let (x, y) = data("./examples/datasets/bike/hour.csv");
46
47    let split = (x.len() as f32 * 0.8) as usize;
48    let x = x.split_at(split);
49    let y = y.split_at(split);
50
51    let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
52    let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
53    let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
54    let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
55
56    // Create the network
57    let mut network = network::Network::new(tensor::Shape::Single(12));
58
59    network.dense(24, activation::Activation::ReLU, false, None);
60    network.dense(24, activation::Activation::ReLU, false, None);
61    network.dense(24, activation::Activation::ReLU, false, None);
62
63    network.dense(1, activation::Activation::Linear, false, None);
64    network.set_objective(objective::Objective::RMSE, None);
65
66    network.set_optimizer(optimizer::Adam::create(0.01, 0.9, 0.999, 1e-4, None));
67
68    println!("{}", network);
69
70    // Train the network
71
72    let (train_loss, val_loss, val_acc) = network.learn(
73        &x_train,
74        &y_train,
75        Some((&x_test, &y_test, 25)),
76        64,
77        600,
78        Some(100),
79    );
80    plot::loss(
81        &train_loss,
82        &val_loss,
83        &val_acc,
84        &"PLAIN : BIKE",
85        &"./output/bike/plain.png",
86    );
87
88    // Use the network
89    let prediction = network.predict(x_test.get(0).unwrap());
90    println!(
91        "Prediction. Target: {}. Output: {}.",
92        y_test[0].data, prediction.data
93    );
94}
More examples
Hide additional examples
examples/bike/looping.rs (lines 83-89)
44fn main() {
45    // Load the bike dataset
46    let (x, y) = data("./examples/datasets/bike/hour.csv");
47
48    let split = (x.len() as f32 * 0.8) as usize;
49    let x = x.split_at(split);
50    let y = y.split_at(split);
51
52    let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
53    let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
54    let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
55    let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
56
57    // Create the network
58    let mut network = network::Network::new(tensor::Shape::Single(12));
59
60    network.dense(24, activation::Activation::ReLU, false, None);
61    network.dense(24, activation::Activation::ReLU, false, None);
62    network.dense(24, activation::Activation::ReLU, false, None);
63
64    network.dense(1, activation::Activation::Linear, false, None);
65    network.set_objective(objective::Objective::RMSE, None);
66
67    network.loopback(2, 1, 2, Arc::new(|_loops| 1.0), false);
68
69    network.set_optimizer(optimizer::Adam::create(0.01, 0.9, 0.999, 1e-4, None));
70
71    println!("{}", network);
72
73    // Train the network
74
75    let (train_loss, val_loss, val_acc) = network.learn(
76        &x_train,
77        &y_train,
78        Some((&x_test, &y_test, 25)),
79        64,
80        600,
81        Some(100),
82    );
83    plot::loss(
84        &train_loss,
85        &val_loss,
86        &val_acc,
87        &"LOOP : BIKE",
88        &"./output/bike/loop.png",
89    );
90
91    // Use the network
92    let prediction = network.predict(x_test.get(0).unwrap());
93    println!(
94        "Prediction. Target: {}. Output: {}.",
95        y_test[0].data, prediction.data
96    );
97}
examples/bike/feedback.rs (lines 88-94)
43fn main() {
44    // Load the bike dataset
45    let (x, y) = data("./examples/datasets/bike/hour.csv");
46
47    let split = (x.len() as f32 * 0.8) as usize;
48    let x = x.split_at(split);
49    let y = y.split_at(split);
50
51    let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
52    let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
53    let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
54    let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
55
56    // Create the network
57    let mut network = network::Network::new(tensor::Shape::Single(12));
58
59    network.dense(24, activation::Activation::ReLU, false, None);
60    network.feedback(
61        vec![
62            feedback::Layer::Dense(24, activation::Activation::ReLU, false, None),
63            feedback::Layer::Dense(24, activation::Activation::ReLU, false, None),
64        ],
65        3,
66        true,
67        true,
68        feedback::Accumulation::Mean,
69    );
70
71    network.dense(1, activation::Activation::Linear, false, None);
72    network.set_objective(objective::Objective::RMSE, None);
73
74    network.set_optimizer(optimizer::Adam::create(0.01, 0.9, 0.999, 1e-4, None));
75
76    println!("{}", network);
77
78    // Train the network
79
80    let (train_loss, val_loss, val_acc) = network.learn(
81        &x_train,
82        &y_train,
83        Some((&x_test, &y_test, 25)),
84        64,
85        600,
86        Some(100),
87    );
88    plot::loss(
89        &train_loss,
90        &val_loss,
91        &val_acc,
92        &"FEEDBACK : BIKE",
93        &"./output/bike/feedback.png",
94    );
95
96    // Use the network
97    let prediction = network.predict(x_test.get(0).unwrap());
98    println!(
99        "Prediction. Target: {}. Output: {}.",
100        y_test[0].data, prediction.data
101    );
102}
examples/mnist-fashion/skip.rs (lines 126-132)
54fn main() {
55    let x_train = load_mnist("./examples/datasets/mnist-fashion/train-images-idx3-ubyte").unwrap();
56    let y_train = load_labels(
57        "./examples/datasets/mnist-fashion/train-labels-idx1-ubyte",
58        10,
59    )
60    .unwrap();
61    let x_test = load_mnist("./examples/datasets/mnist-fashion/t10k-images-idx3-ubyte").unwrap();
62    let y_test = load_labels(
63        "./examples/datasets/mnist-fashion/t10k-labels-idx1-ubyte",
64        10,
65    )
66    .unwrap();
67    println!(
68        "Train: {} images, Test: {} images",
69        x_train.len(),
70        x_test.len()
71    );
72
73    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
74    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
75    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
76    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
77
78    let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
79
80    network.convolution(
81        1,
82        (3, 3),
83        (1, 1),
84        (1, 1),
85        (1, 1),
86        activation::Activation::ReLU,
87        None,
88    );
89    network.convolution(
90        1,
91        (3, 3),
92        (1, 1),
93        (1, 1),
94        (1, 1),
95        activation::Activation::ReLU,
96        None,
97    );
98    network.convolution(
99        1,
100        (3, 3),
101        (1, 1),
102        (1, 1),
103        (1, 1),
104        activation::Activation::ReLU,
105        None,
106    );
107    network.maxpool((2, 2), (2, 2));
108    network.dense(10, activation::Activation::Softmax, true, None);
109
110    network.connect(0, 2);
111
112    network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
113    network.set_objective(objective::Objective::CrossEntropy, None);
114
115    println!("{}", network);
116
117    // Train the network
118    let (train_loss, val_loss, val_acc) = network.learn(
119        &x_train,
120        &y_train,
121        Some((&x_test, &y_test, 10)),
122        32,
123        25,
124        Some(5),
125    );
126    plot::loss(
127        &train_loss,
128        &val_loss,
129        &val_acc,
130        "SKIP : Fashion-MNIST",
131        "./output/mnist-fashion/skip.png",
132    );
133
134    // Validate the network
135    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
136    println!(
137        "Final validation accuracy: {:.2} % and loss: {:.5}",
138        val_acc * 100.0,
139        val_loss
140    );
141
142    // Use the network
143    let prediction = network.predict(x_test.get(0).unwrap());
144    println!(
145        "Prediction on input: Target: {}. Output: {}.",
146        y_test[0].argmax(),
147        prediction.argmax()
148    );
149
150    // let x = x_test.get(5).unwrap();
151    // let y = y_test.get(5).unwrap();
152    // Plot the pre- and post-activation heatmaps for each (image) layer.
153    // let (pre, post, _) = network.forward(x);
154    // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
155    //     let pre_title = format!("layer_{}_pre", i);
156    //     let post_title = format!("layer_{}_post", i);
157    //     let pre_file = format!("layer_{}_pre.png", i);
158    //     let post_file = format!("layer_{}_post.png", i);
159    //     plot::heatmap(&i_pre, &pre_title, &pre_file);
160    //     plot::heatmap(&i_post, &post_title, &post_file);
161    // }
162}
examples/mnist/plain.rs (lines 119-125)
54fn main() {
55    let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
56    let y_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
57    let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
58    let y_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
59    println!(
60        "Train: {} images, Test: {} images",
61        x_train.len(),
62        x_test.len()
63    );
64
65    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
66    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
67    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
68    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
69
70    let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
71
72    network.convolution(
73        1,
74        (3, 3),
75        (1, 1),
76        (1, 1),
77        (1, 1),
78        activation::Activation::ReLU,
79        None,
80    );
81    network.convolution(
82        1,
83        (3, 3),
84        (1, 1),
85        (1, 1),
86        (1, 1),
87        activation::Activation::ReLU,
88        None,
89    );
90    network.convolution(
91        1,
92        (3, 3),
93        (1, 1),
94        (1, 1),
95        (1, 1),
96        activation::Activation::ReLU,
97        None,
98    );
99    network.maxpool((2, 2), (2, 2));
100    network.dense(10, activation::Activation::Softmax, true, None);
101
102    network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
103    network.set_objective(
104        objective::Objective::CrossEntropy, // Objective function
105        None,                               // Gradient clipping
106    );
107
108    println!("{}", network);
109
110    // Train the network
111    let (train_loss, val_loss, val_acc) = network.learn(
112        &x_train,
113        &y_train,
114        Some((&x_test, &y_test, 10)),
115        32,
116        25,
117        Some(5),
118    );
119    plot::loss(
120        &train_loss,
121        &val_loss,
122        &val_acc,
123        "PLAIN : MNIST",
124        "./output/mnist/plain.png",
125    );
126
127    // Validate the network
128    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
129    println!(
130        "Final validation accuracy: {:.2} % and loss: {:.5}",
131        val_acc * 100.0,
132        val_loss
133    );
134
135    // Use the network
136    let prediction = network.predict(x_test.get(0).unwrap());
137    println!(
138        "Prediction on input: Target: {}. Output: {}.",
139        y_test[0].argmax(),
140        prediction.argmax()
141    );
142
143    // let x = x_test.get(5).unwrap();
144    // let y = y_test.get(5).unwrap();
145    // plot::heatmap(
146    //     &x,
147    //     &format!("Target: {}", y.argmax()),
148    //     "./output/mnist/input.png",
149    // );
150
151    // Plot the pre- and post-activation heatmaps for each (image) layer.
152    // let (pre, post, _) = network.forward(x);
153    // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
154    //     let pre_title = format!("layer_{}_pre", i);
155    //     let post_title = format!("layer_{}_post", i);
156    //     let pre_file = format!("layer_{}_pre.png", i);
157    //     let post_file = format!("layer_{}_post.png", i);
158    //     plot::heatmap(&i_pre, &pre_title, &pre_file);
159    //     plot::heatmap(&i_post, &post_title, &post_file);
160    // }
161}
examples/mnist/skip.rs (lines 121-127)
54fn main() {
55    let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
56    let y_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
57    let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
58    let y_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
59    println!(
60        "Train: {} images, Test: {} images",
61        x_train.len(),
62        x_test.len()
63    );
64
65    let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
66    let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
67    let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
68    let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
69
70    let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
71
72    network.convolution(
73        1,
74        (3, 3),
75        (1, 1),
76        (1, 1),
77        (1, 1),
78        activation::Activation::ReLU,
79        None,
80    );
81    network.convolution(
82        1,
83        (3, 3),
84        (1, 1),
85        (1, 1),
86        (1, 1),
87        activation::Activation::ReLU,
88        None,
89    );
90    network.convolution(
91        1,
92        (3, 3),
93        (1, 1),
94        (1, 1),
95        (1, 1),
96        activation::Activation::ReLU,
97        None,
98    );
99    network.maxpool((2, 2), (2, 2));
100    network.dense(10, activation::Activation::Softmax, true, None);
101
102    network.connect(0, 3);
103
104    network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
105    network.set_objective(
106        objective::Objective::CrossEntropy, // Objective function
107        None,                               // Gradient clipping
108    );
109
110    println!("{}", network);
111
112    // Train the network
113    let (train_loss, val_loss, val_acc) = network.learn(
114        &x_train,
115        &y_train,
116        Some((&x_test, &y_test, 10)),
117        32,
118        25,
119        Some(5),
120    );
121    plot::loss(
122        &train_loss,
123        &val_loss,
124        &val_acc,
125        "SKIP : MNIST",
126        "./output/mnist/skip.png",
127    );
128
129    // Validate the network
130    let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
131    println!(
132        "Final validation accuracy: {:.2} % and loss: {:.5}",
133        val_acc * 100.0,
134        val_loss
135    );
136
137    // Use the network
138    let prediction = network.predict(x_test.get(0).unwrap());
139    println!(
140        "Prediction on input: Target: {}. Output: {}.",
141        y_test[0].argmax(),
142        prediction.argmax()
143    );
144
145    // let x = x_test.get(5).unwrap();
146    // let y = y_test.get(5).unwrap();
147    // plot::heatmap(
148    //     &x,
149    //     &format!("Target: {}", y.argmax()),
150    //     "./output/mnist/input.png",
151    // );
152
153    // Plot the pre- and post-activation heatmaps for each (image) layer.
154    // let (pre, post, _) = network.forward(x);
155    // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
156    //     let pre_title = format!("layer_{}_pre", i);
157    //     let post_title = format!("layer_{}_post", i);
158    //     let pre_file = format!("layer_{}_pre.png", i);
159    //     let post_file = format!("layer_{}_post.png", i);
160    //     plot::heatmap(&i_pre, &pre_title, &pre_file);
161    //     plot::heatmap(&i_post, &post_title, &post_file);
162    // }
163}