neural_network_rs/
main.rs1pub mod dataset;
2pub mod neural_network;
3pub mod plotter;
4
5#[allow(unused_imports)]
6use crate::{
7 dataset::example_datasets::{CIRCLE, RGB_DONUT, RGB_TEST, XOR},
8 neural_network::{
9 activation_function::{linear::LINEAR, relu::RELU, sigmoid::SIGMOID},
10 cost_function::quadratic_cost::QUADRATIC_COST,
11 optimizer::{adam_optimizer::ADAM, rmsprop_optimizer::RMS_PROP, sgd_optimzer::SGD},
12 Network, Summary,
13 },
14 plotter::{graph_plotter::plot_graph, png_plotter::plot_png},
15};
16
17#[allow(dead_code)]
18fn main() {
19 let network_shape = [
21 (&RELU, 2),
22 (&RELU, 32),
23 (&RELU, 32),
24 (&RELU, 32),
25 (&RELU, 3),
26 ];
27
28 let mut optimizer = ADAM::default();
30
31 let mut network = Network::new(&network_shape, &mut optimizer, &QUADRATIC_COST);
33
34 let dataset = &RGB_DONUT;
36
37 let cost_history = network.train_and_log(dataset, 128, 512, 10000);
39
40 let (dim, unit_square_prediction) = network.predict_unit_square(512);
42 let name = String::from(dataset.name) + "_" + &network.summerize();
43
44 plot_png(&name, dim, &unit_square_prediction, png::ColorType::Rgb).unwrap();
46 plot_graph(&name, &cost_history).unwrap();
47}