neural_network_rs/
main.rs

1pub mod dataset;
2pub mod neural_network;
3pub mod plotter;
4
5#[allow(unused_imports)]
6use crate::{
7    dataset::example_datasets::{CIRCLE, RGB_DONUT, RGB_TEST, XOR},
8    neural_network::{
9        activation_function::{linear::LINEAR, relu::RELU, sigmoid::SIGMOID},
10        cost_function::quadratic_cost::QUADRATIC_COST,
11        optimizer::{adam_optimizer::ADAM, rmsprop_optimizer::RMS_PROP, sgd_optimzer::SGD},
12        Network, Summary,
13    },
14    plotter::{graph_plotter::plot_graph, png_plotter::plot_png},
15};
16
17#[allow(dead_code)]
18fn main() {
19    //Define Network Shape
20    let network_shape = [
21        (&RELU, 2),
22        (&RELU, 32),
23        (&RELU, 32),
24        (&RELU, 32),
25        (&RELU, 3),
26    ];
27
28    //Define Optimizer
29    let mut optimizer = ADAM::default();
30
31    //Create Network
32    let mut network = Network::new(&network_shape, &mut optimizer, &QUADRATIC_COST);
33
34    //Define Dataset
35    let dataset = &RGB_DONUT;
36
37    //Train
38    let cost_history = network.train_and_log(dataset, 128, 512, 10000);
39
40    //Prepare Plot-data
41    let (dim, unit_square_prediction) = network.predict_unit_square(512);
42    let name = String::from(dataset.name) + "_" + &network.summerize();
43
44    //Plot
45    plot_png(&name, dim, &unit_square_prediction, png::ColorType::Rgb).unwrap();
46    plot_graph(&name, &cost_history).unwrap();
47}