pub fn loss(
train: &Vec<f32>,
validation: &Vec<f32>,
accuracy: &Vec<f32>,
title: &str,
path: &str,
)

Expand description
Plots a simple line plot of the given data.
§ Arguments

* `train` - The training loss.
* `validation` - The validation loss.
* `accuracy` - The validation accuracy.
* `title` - The title of the plot.
* `path` - The path to save the plot.
Examples found in repository?
examples/bike/plain.rs (lines 80-86)
43fn main() {
44 // Load the bike dataset
45 let (x, y) = data("./examples/datasets/bike/hour.csv");
46
47 let split = (x.len() as f32 * 0.8) as usize;
48 let x = x.split_at(split);
49 let y = y.split_at(split);
50
51 let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
52 let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
53 let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
54 let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
55
56 // Create the network
57 let mut network = network::Network::new(tensor::Shape::Single(12));
58
59 network.dense(24, activation::Activation::ReLU, false, None);
60 network.dense(24, activation::Activation::ReLU, false, None);
61 network.dense(24, activation::Activation::ReLU, false, None);
62
63 network.dense(1, activation::Activation::Linear, false, None);
64 network.set_objective(objective::Objective::RMSE, None);
65
66 network.set_optimizer(optimizer::Adam::create(0.01, 0.9, 0.999, 1e-4, None));
67
68 println!("{}", network);
69
70 // Train the network
71
72 let (train_loss, val_loss, val_acc) = network.learn(
73 &x_train,
74 &y_train,
75 Some((&x_test, &y_test, 25)),
76 64,
77 600,
78 Some(100),
79 );
80 plot::loss(
81 &train_loss,
82 &val_loss,
83 &val_acc,
84 &"PLAIN : BIKE",
85 &"./output/bike/plain.png",
86 );
87
88 // Use the network
89 let prediction = network.predict(x_test.get(0).unwrap());
90 println!(
91 "Prediction. Target: {}. Output: {}.",
92 y_test[0].data, prediction.data
93 );
94}

More examples
examples/bike/looping.rs (lines 83-89)
44fn main() {
45 // Load the bike dataset
46 let (x, y) = data("./examples/datasets/bike/hour.csv");
47
48 let split = (x.len() as f32 * 0.8) as usize;
49 let x = x.split_at(split);
50 let y = y.split_at(split);
51
52 let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
53 let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
54 let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
55 let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
56
57 // Create the network
58 let mut network = network::Network::new(tensor::Shape::Single(12));
59
60 network.dense(24, activation::Activation::ReLU, false, None);
61 network.dense(24, activation::Activation::ReLU, false, None);
62 network.dense(24, activation::Activation::ReLU, false, None);
63
64 network.dense(1, activation::Activation::Linear, false, None);
65 network.set_objective(objective::Objective::RMSE, None);
66
67 network.loopback(2, 1, 2, Arc::new(|_loops| 1.0), false);
68
69 network.set_optimizer(optimizer::Adam::create(0.01, 0.9, 0.999, 1e-4, None));
70
71 println!("{}", network);
72
73 // Train the network
74
75 let (train_loss, val_loss, val_acc) = network.learn(
76 &x_train,
77 &y_train,
78 Some((&x_test, &y_test, 25)),
79 64,
80 600,
81 Some(100),
82 );
83 plot::loss(
84 &train_loss,
85 &val_loss,
86 &val_acc,
87 &"LOOP : BIKE",
88 &"./output/bike/loop.png",
89 );
90
91 // Use the network
92 let prediction = network.predict(x_test.get(0).unwrap());
93 println!(
94 "Prediction. Target: {}. Output: {}.",
95 y_test[0].data, prediction.data
96 );
97}

examples/bike/feedback.rs (lines 88-94)
43fn main() {
44 // Load the bike dataset
45 let (x, y) = data("./examples/datasets/bike/hour.csv");
46
47 let split = (x.len() as f32 * 0.8) as usize;
48 let x = x.split_at(split);
49 let y = y.split_at(split);
50
51 let x_train: Vec<&tensor::Tensor> = x.0.iter().collect();
52 let y_train: Vec<&tensor::Tensor> = y.0.iter().collect();
53 let x_test: Vec<&tensor::Tensor> = x.1.iter().collect();
54 let y_test: Vec<&tensor::Tensor> = y.1.iter().collect();
55
56 // Create the network
57 let mut network = network::Network::new(tensor::Shape::Single(12));
58
59 network.dense(24, activation::Activation::ReLU, false, None);
60 network.feedback(
61 vec![
62 feedback::Layer::Dense(24, activation::Activation::ReLU, false, None),
63 feedback::Layer::Dense(24, activation::Activation::ReLU, false, None),
64 ],
65 3,
66 true,
67 true,
68 feedback::Accumulation::Mean,
69 );
70
71 network.dense(1, activation::Activation::Linear, false, None);
72 network.set_objective(objective::Objective::RMSE, None);
73
74 network.set_optimizer(optimizer::Adam::create(0.01, 0.9, 0.999, 1e-4, None));
75
76 println!("{}", network);
77
78 // Train the network
79
80 let (train_loss, val_loss, val_acc) = network.learn(
81 &x_train,
82 &y_train,
83 Some((&x_test, &y_test, 25)),
84 64,
85 600,
86 Some(100),
87 );
88 plot::loss(
89 &train_loss,
90 &val_loss,
91 &val_acc,
92 &"FEEDBACK : BIKE",
93 &"./output/bike/feedback.png",
94 );
95
96 // Use the network
97 let prediction = network.predict(x_test.get(0).unwrap());
98 println!(
99 "Prediction. Target: {}. Output: {}.",
100 y_test[0].data, prediction.data
101 );
162}

examples/mnist-fashion/skip.rs (lines 126-132)
54fn main() {
55 let x_train = load_mnist("./examples/datasets/mnist-fashion/train-images-idx3-ubyte").unwrap();
56 let y_train = load_labels(
57 "./examples/datasets/mnist-fashion/train-labels-idx1-ubyte",
58 10,
59 )
60 .unwrap();
61 let x_test = load_mnist("./examples/datasets/mnist-fashion/t10k-images-idx3-ubyte").unwrap();
62 let y_test = load_labels(
63 "./examples/datasets/mnist-fashion/t10k-labels-idx1-ubyte",
64 10,
65 )
66 .unwrap();
67 println!(
68 "Train: {} images, Test: {} images",
69 x_train.len(),
70 x_test.len()
71 );
72
73 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
74 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
75 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
76 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
77
78 let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
79
80 network.convolution(
81 1,
82 (3, 3),
83 (1, 1),
84 (1, 1),
85 (1, 1),
86 activation::Activation::ReLU,
87 None,
88 );
89 network.convolution(
90 1,
91 (3, 3),
92 (1, 1),
93 (1, 1),
94 (1, 1),
95 activation::Activation::ReLU,
96 None,
97 );
98 network.convolution(
99 1,
100 (3, 3),
101 (1, 1),
102 (1, 1),
103 (1, 1),
104 activation::Activation::ReLU,
105 None,
106 );
107 network.maxpool((2, 2), (2, 2));
108 network.dense(10, activation::Activation::Softmax, true, None);
109
110 network.connect(0, 2);
111
112 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
113 network.set_objective(objective::Objective::CrossEntropy, None);
114
115 println!("{}", network);
116
117 // Train the network
118 let (train_loss, val_loss, val_acc) = network.learn(
119 &x_train,
120 &y_train,
121 Some((&x_test, &y_test, 10)),
122 32,
123 25,
124 Some(5),
125 );
126 plot::loss(
127 &train_loss,
128 &val_loss,
129 &val_acc,
130 "SKIP : Fashion-MNIST",
131 "./output/mnist-fashion/skip.png",
132 );
133
134 // Validate the network
135 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
136 println!(
137 "Final validation accuracy: {:.2} % and loss: {:.5}",
138 val_acc * 100.0,
139 val_loss
140 );
141
142 // Use the network
143 let prediction = network.predict(x_test.get(0).unwrap());
144 println!(
145 "Prediction on input: Target: {}. Output: {}.",
146 y_test[0].argmax(),
147 prediction.argmax()
148 );
149
150 // let x = x_test.get(5).unwrap();
151 // let y = y_test.get(5).unwrap();
152 // Plot the pre- and post-activation heatmaps for each (image) layer.
153 // let (pre, post, _) = network.forward(x);
154 // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
155 // let pre_title = format!("layer_{}_pre", i);
156 // let post_title = format!("layer_{}_post", i);
157 // let pre_file = format!("layer_{}_pre.png", i);
158 // let post_file = format!("layer_{}_post.png", i);
159 // plot::heatmap(&i_pre, &pre_title, &pre_file);
160 // plot::heatmap(&i_post, &post_title, &post_file);
161 // }
161}

examples/mnist/plain.rs (lines 119-125)
54fn main() {
55 let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
56 let y_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
57 let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
58 let y_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
59 println!(
60 "Train: {} images, Test: {} images",
61 x_train.len(),
62 x_test.len()
63 );
64
65 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
66 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
67 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
68 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
69
70 let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
71
72 network.convolution(
73 1,
74 (3, 3),
75 (1, 1),
76 (1, 1),
77 (1, 1),
78 activation::Activation::ReLU,
79 None,
80 );
81 network.convolution(
82 1,
83 (3, 3),
84 (1, 1),
85 (1, 1),
86 (1, 1),
87 activation::Activation::ReLU,
88 None,
89 );
90 network.convolution(
91 1,
92 (3, 3),
93 (1, 1),
94 (1, 1),
95 (1, 1),
96 activation::Activation::ReLU,
97 None,
98 );
99 network.maxpool((2, 2), (2, 2));
100 network.dense(10, activation::Activation::Softmax, true, None);
101
102 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
103 network.set_objective(
104 objective::Objective::CrossEntropy, // Objective function
105 None, // Gradient clipping
106 );
107
108 println!("{}", network);
109
110 // Train the network
111 let (train_loss, val_loss, val_acc) = network.learn(
112 &x_train,
113 &y_train,
114 Some((&x_test, &y_test, 10)),
115 32,
116 25,
117 Some(5),
118 );
119 plot::loss(
120 &train_loss,
121 &val_loss,
122 &val_acc,
123 "PLAIN : MNIST",
124 "./output/mnist/plain.png",
125 );
126
127 // Validate the network
128 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
129 println!(
130 "Final validation accuracy: {:.2} % and loss: {:.5}",
131 val_acc * 100.0,
132 val_loss
133 );
134
135 // Use the network
136 let prediction = network.predict(x_test.get(0).unwrap());
137 println!(
138 "Prediction on input: Target: {}. Output: {}.",
139 y_test[0].argmax(),
140 prediction.argmax()
141 );
142
143 // let x = x_test.get(5).unwrap();
144 // let y = y_test.get(5).unwrap();
145 // plot::heatmap(
146 // &x,
147 // &format!("Target: {}", y.argmax()),
148 // "./output/mnist/input.png",
149 // );
150
151 // Plot the pre- and post-activation heatmaps for each (image) layer.
152 // let (pre, post, _) = network.forward(x);
153 // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
154 // let pre_title = format!("layer_{}_pre", i);
155 // let post_title = format!("layer_{}_post", i);
156 // let pre_file = format!("layer_{}_pre.png", i);
157 // let post_file = format!("layer_{}_post.png", i);
158 // plot::heatmap(&i_pre, &pre_title, &pre_file);
159 // plot::heatmap(&i_post, &post_title, &post_file);
160 // }
161}

examples/mnist/skip.rs (lines 121-127)
54fn main() {
55 let x_train = load_mnist("./examples/datasets/mnist/train-images-idx3-ubyte").unwrap();
56 let y_train = load_labels("./examples/datasets/mnist/train-labels-idx1-ubyte", 10).unwrap();
57 let x_test = load_mnist("./examples/datasets/mnist/t10k-images-idx3-ubyte").unwrap();
58 let y_test = load_labels("./examples/datasets/mnist/t10k-labels-idx1-ubyte", 10).unwrap();
59 println!(
60 "Train: {} images, Test: {} images",
61 x_train.len(),
62 x_test.len()
63 );
64
65 let x_train: Vec<&tensor::Tensor> = x_train.iter().collect();
66 let y_train: Vec<&tensor::Tensor> = y_train.iter().collect();
67 let x_test: Vec<&tensor::Tensor> = x_test.iter().collect();
68 let y_test: Vec<&tensor::Tensor> = y_test.iter().collect();
69
70 let mut network = network::Network::new(tensor::Shape::Triple(1, 14, 14));
71
72 network.convolution(
73 1,
74 (3, 3),
75 (1, 1),
76 (1, 1),
77 (1, 1),
78 activation::Activation::ReLU,
79 None,
80 );
81 network.convolution(
82 1,
83 (3, 3),
84 (1, 1),
85 (1, 1),
86 (1, 1),
87 activation::Activation::ReLU,
88 None,
89 );
90 network.convolution(
91 1,
92 (3, 3),
93 (1, 1),
94 (1, 1),
95 (1, 1),
96 activation::Activation::ReLU,
97 None,
98 );
99 network.maxpool((2, 2), (2, 2));
100 network.dense(10, activation::Activation::Softmax, true, None);
101
102 network.connect(0, 3);
103
104 network.set_optimizer(optimizer::Adam::create(0.001, 0.9, 0.999, 1e-8, None));
105 network.set_objective(
106 objective::Objective::CrossEntropy, // Objective function
107 None, // Gradient clipping
108 );
109
110 println!("{}", network);
111
112 // Train the network
113 let (train_loss, val_loss, val_acc) = network.learn(
114 &x_train,
115 &y_train,
116 Some((&x_test, &y_test, 10)),
117 32,
118 25,
119 Some(5),
120 );
121 plot::loss(
122 &train_loss,
123 &val_loss,
124 &val_acc,
125 "SKIP : MNIST",
126 "./output/mnist/skip.png",
127 );
128
129 // Validate the network
130 let (val_loss, val_acc) = network.validate(&x_test, &y_test, 1e-6);
131 println!(
132 "Final validation accuracy: {:.2} % and loss: {:.5}",
133 val_acc * 100.0,
134 val_loss
135 );
136
137 // Use the network
138 let prediction = network.predict(x_test.get(0).unwrap());
139 println!(
140 "Prediction on input: Target: {}. Output: {}.",
141 y_test[0].argmax(),
142 prediction.argmax()
143 );
144
145 // let x = x_test.get(5).unwrap();
146 // let y = y_test.get(5).unwrap();
147 // plot::heatmap(
148 // &x,
149 // &format!("Target: {}", y.argmax()),
150 // "./output/mnist/input.png",
151 // );
152
153 // Plot the pre- and post-activation heatmaps for each (image) layer.
154 // let (pre, post, _) = network.forward(x);
155 // for (i, (i_pre, i_post)) in pre.iter().zip(post.iter()).enumerate() {
156 // let pre_title = format!("layer_{}_pre", i);
157 // let post_title = format!("layer_{}_post", i);
158 // let pre_file = format!("layer_{}_pre.png", i);
159 // let post_file = format!("layer_{}_post.png", i);
160 // plot::heatmap(&i_pre, &pre_title, &pre_file);
161 // plot::heatmap(&i_post, &post_title, &post_file);
162 // }
163}

Additional examples can be found in:
- examples/mnist/deconvolution.rs
- examples/mnist-fashion/feedback.rs
- examples/mnist-fashion/plain.rs
- examples/mnist/feedback.rs
- examples/cifar10/plain.rs
- examples/mnist-fashion/looping.rs
- examples/cifar10/looping.rs
- examples/mnist/looping.rs
- examples/ftir/mlp/plain.rs
- examples/cifar10/feedback.rs
- examples/ftir/cnn/plain.rs
- examples/ftir/mlp/skip.rs
- examples/ftir/mlp/looping.rs
- examples/ftir/cnn/skip.rs
- examples/ftir/cnn/looping.rs
- examples/ftir/cnn/feedback.rs
- examples/ftir/mlp/feedback.rs