Skip to main content

iris/
iris.rs

1use csv::ReaderBuilder;
2use env_logger::{Builder, Target};
3use log::{error, info};
4use runn::{
5    adam::Adam,
6    cross_entropy::CrossEntropy,
7    csv::CSV,
8    dense_layer::Dense,
9    helper,
10    matrix::{DMat, DenseMatrix},
11    network::network_model::{Network, NetworkBuilder},
12    network_io::JSON,
13    network_search::NetworkSearchBuilder,
14    numbers::{Numbers, SequentialNumbers},
15    relu::ReLU,
16    softmax::Softmax,
17};
18use std::error::Error;
19use std::fs::File;
20use std::{env, fs};
21
/// Name of this experiment. It is used as the output directory and as the
/// file-name prefix for the saved network, the search export, and the log file.
const EXP_NAME: &str = "iris";
23
24/// This example demonstrates how to train and validate a neural network on the Iris dataset.
25/// The Iris dataset is a classic dataset used for classification tasks.
26/// The code includes functions to load the dataset, build the neural network,
27/// train the network, validate its performance, and perform a hyperparameter search.
28///
29/// to run the example:
30/// ```bash
31/// cargo run --example iris
32/// ```
33/// to run the hyperparameter search:
34/// ```bash
35/// cargo run --example iris -- -search
36/// ```
37/// The hyperparameter search will create a CSV file with the results in the `iris` directory.
38/// The training and validation results will be logged in the `iris` directory as well.
39fn main() {
40    initialize_logger(EXP_NAME);
41
42    let args: Vec<String> = env::args().collect();
43    if args.contains(&"-search".to_string()) {
44        search();
45    } else {
46        train_and_validate();
47    }
48}
49
50fn train_and_validate() {
51    let network_file = format!("{}_network", EXP_NAME);
52
53    let (training_inputs, training_targets) = iris_inputs_outputs("train", 7, 4).unwrap();
54    let mut network = iris_network(training_inputs.cols(), training_targets.cols());
55
56    let training_result = network.train(&training_inputs, &training_targets);
57    match training_result {
58        Ok(_) => {
59            info!("Training successfully completed");
60            network
61                .save(
62                    JSON::default()
63                        .directory(EXP_NAME)
64                        .file_name(&network_file)
65                        .build()
66                        .unwrap(),
67                )
68                .unwrap();
69            let net_results = network.predict(&training_inputs, &training_targets).unwrap();
70            info!(
71                "{}",
72                helper::pretty_compare_matrices(
73                    &training_inputs,
74                    &training_targets,
75                    &net_results.predictions,
76                    helper::CompareMode::Classification
77                )
78            );
79            info!("Training: {}", net_results.display_metrics());
80        }
81        Err(e) => {
82            eprintln!("Training failed: {}", e);
83        }
84    }
85
86    network = Network::load(
87        JSON::default()
88            .directory(EXP_NAME)
89            .file_name(&network_file)
90            .build()
91            .unwrap(),
92    )
93    .unwrap();
94    let (validation_inputs, validation_targets) = iris_inputs_outputs("test", 7, 4).unwrap();
95    let net_results = network.predict(&validation_inputs, &validation_targets).unwrap();
96    log::info!(
97        "{}",
98        helper::pretty_compare_matrices(
99            &validation_inputs,
100            &validation_targets,
101            &net_results.predictions,
102            helper::CompareMode::Classification
103        )
104    );
105    info!("Validation: {}", net_results.display_metrics());
106}
107
108fn iris_network(inp_size: usize, targ_size: usize) -> Network {
109    let network = NetworkBuilder::new(inp_size, targ_size)
110        .layer(Dense::default().size(12).activation(ReLU::build()).build())
111        .layer(Dense::default().size(12).activation(ReLU::build()).build())
112        .layer(Dense::default().size(targ_size).activation(Softmax::build()).build())
113        .loss_function(CrossEntropy::default().epsilon(1e-8).build())
114        .optimizer(Adam::default().beta1(0.99).beta2(0.999).learning_rate(0.0035).build())
115        .batch_size(9)
116        .batch_group_size(2)
117        .parallelize(2)
118        .epochs(3000)
119        .seed(55)
120        .build();
121
122    match network {
123        Ok(net) => net,
124        Err(e) => {
125            eprintln!("Failed to build network: {}", e);
126            std::process::exit(1);
127        }
128    }
129}
130
131fn search() {
132    let (training_inputs, training_targets) = iris_inputs_outputs("train", 7, 4).unwrap();
133    let (validation_inputs, validation_targets) = iris_inputs_outputs("test", 7, 4).unwrap();
134
135    let network = iris_network(training_inputs.cols(), training_targets.cols());
136
137    let network_search = NetworkSearchBuilder::new()
138        .network(network)
139        .parallelize(4)
140        .learning_rates(
141            SequentialNumbers::new()
142                .lower_limit(0.0025)
143                .upper_limit(0.0035)
144                .increment(0.0005)
145                .floats(),
146        )
147        .batch_sizes(
148            SequentialNumbers::new()
149                .lower_limit(7.0)
150                .upper_limit(10.0)
151                .increment(1.0)
152                .ints(),
153        )
154        .hidden_layer(
155            SequentialNumbers::new()
156                .lower_limit(12.0)
157                .upper_limit(20.0)
158                .increment(4.0)
159                .ints(),
160            ReLU::build(),
161        )
162        .hidden_layer(
163            SequentialNumbers::new()
164                .lower_limit(12.0)
165                .upper_limit(20.0)
166                .increment(4.0)
167                .ints(),
168            ReLU::build(),
169        )
170        .export(
171            CSV::default()
172                .directory(EXP_NAME)
173                .file_name(&format!("{}_search", EXP_NAME))
174                .build(),
175        )
176        .build();
177
178    let mut network_search = match network_search {
179        Ok(ns) => ns,
180        Err(e) => {
181            error!("Failed to build network_search: {}", e);
182            std::process::exit(1);
183        }
184    };
185
186    let search_res = network_search
187        .search(&training_inputs, &training_targets, &validation_inputs, &validation_targets)
188        .unwrap();
189
190    info!("Num Results: {}", search_res.len());
191}
192
193pub fn iris_inputs_outputs(
194    name: &str, fields_count: usize, input_count: usize,
195) -> Result<(DMat, DMat), Box<dyn Error>> {
196    let target_count = fields_count - input_count;
197
198    let file_path = format!("./examples/iris/{}.csv", name);
199    let file = File::open(&file_path)?;
200    let mut reader = ReaderBuilder::new().has_headers(true).from_reader(file);
201
202    let mut inputs_data = Vec::new();
203    let mut labels_data = Vec::new();
204
205    for result in reader.records() {
206        let record = result?;
207        for (i, value) in record.iter().enumerate() {
208            let parsed_val: f32 = value.parse()?;
209            if i >= fields_count - target_count {
210                labels_data.push(parsed_val);
211            } else {
212                inputs_data.push(parsed_val);
213            }
214        }
215    }
216
217    let data_length = inputs_data.len() / input_count;
218
219    let inputs = DenseMatrix::new(data_length, input_count)
220        .data(&inputs_data)
221        .build()
222        .unwrap();
223    let labels = DenseMatrix::new(data_length, target_count)
224        .data(&labels_data)
225        .build()
226        .unwrap();
227
228    Ok((inputs, labels))
229}
230
231/// Initializes the logger for the application.
232/// The LOG environment variable is used to define the log level (e.g., info, debug, warn, error).
233/// If the LOG variable is not set, it defaults to info.
234fn initialize_logger(name: &str) {
235    // Check if the directory exists, and attempt to create it if it doesn't
236    if !std::path::Path::new(name).exists() {
237        let _res = fs::create_dir_all(name).map_err(|e| {
238            eprintln!("Failed to create log directory: {}", e);
239        });
240    }
241
242    // Attempt to create a log file
243    let log_file = match File::create(format!("./{}/{}.log", name, name)) {
244        Ok(file) => file,
245        Err(e) => {
246            eprintln!("Failed to create log file: {}", e);
247            return;
248        }
249    };
250
251    // Check if the "LOG" environment variable is set
252    let log_level = env::var("LOG").unwrap_or_else(|_| "info".to_string()); // Default to "info"
253
254    // Initialize the logger with the specified log level
255    Builder::new()
256        .target(Target::Pipe(Box::new(log_file)))
257        .parse_filters(&log_level) // Use the log level from the environment variable
258        .init();
259}