Skip to main content

wine/
wine.rs

1use csv::ReaderBuilder;
2use env_logger::{Builder, Target};
3use log::{error, info};
4use runn::{
5    adam::Adam,
6    cross_entropy::CrossEntropy,
7    csv::CSV,
8    dense_layer::Dense,
9    helper,
10    matrix::{DMat, DenseMatrix},
11    min_max::MinMax,
12    network::network_model::{Network, NetworkBuilder},
13    network_io::JSON,
14    network_search::NetworkSearchBuilder,
15    numbers::{Numbers, SequentialNumbers},
16    relu::ReLU,
17    softmax::Softmax,
18};
19use std::error::Error;
20use std::fs::File;
21use std::{env, fs};
22
/// Experiment name. It is used as the output directory for the log file, the saved
/// network and the search export, and as the prefix of those file names.
const EXP_NAME: &str = "wine";
24
25/// This example demonstrates how to train and validate a neural network on the Wine dataset.
26/// The Wine dataset is a classic dataset used for classification tasks.
27/// The code includes functions to load the dataset, build the neural network,
28/// train the network, validate its performance, and perform a hyperparameter search.
29///
30/// to run the example:
31/// ```bash
32/// cargo run --example wine
33/// ```
34/// to run the hyperparameter search:
35/// ```bash
36/// cargo run --example wine -- -search
37/// ```
38/// The hyperparameter search will create a CSV file with the results in the `wine` directory.
39/// The training and validation results will be logged in the `wine` directory.
40fn main() {
41    initialize_logger(EXP_NAME);
42
43    let (training_inputs, training_targets, validation_inputs, validation_targets) =
44        wine_inputs_targets("wine.data", 14, 13).unwrap();
45
46    let training_targets = helper::one_hot_encode(&training_targets);
47    let validation_targets = helper::one_hot_encode(&validation_targets);
48
49    let args: Vec<String> = env::args().collect();
50    if args.contains(&"-search".to_string()) {
51        search(&training_inputs, &training_targets, &validation_inputs, &validation_targets);
52    } else {
53        train_and_validate(&training_inputs, &training_targets, &validation_inputs, &validation_targets);
54    }
55}
56
57fn train_and_validate(
58    training_inputs: &DMat, training_targets: &DMat, validation_inputs: &DMat, validation_targets: &DMat,
59) {
60    let network_file = format!("{}_network", EXP_NAME);
61
62    let mut network = one_hot_encode_network(training_inputs.cols(), training_targets.cols());
63
64    let training_result = network.train(training_inputs, training_targets);
65    match training_result {
66        Ok(_) => {
67            info!("Training successfully completed");
68            network
69                .save(
70                    JSON::default()
71                        .directory(EXP_NAME)
72                        .file_name(&network_file)
73                        .build()
74                        .unwrap(),
75                )
76                .unwrap();
77            let net_results = network.predict(training_inputs, training_targets).unwrap();
78            info!(
79                "{}",
80                helper::pretty_compare_matrices(
81                    training_inputs,
82                    training_targets,
83                    &net_results.predictions,
84                    helper::CompareMode::Classification
85                )
86            );
87            info!("Training: {}", net_results.display_metrics());
88        }
89        Err(e) => {
90            eprintln!("Training failed: {}", e);
91        }
92    }
93
94    network = Network::load(
95        JSON::default()
96            .directory(EXP_NAME)
97            .file_name(&network_file)
98            .build()
99            .unwrap(),
100    )
101    .unwrap();
102    let net_results = network.predict(validation_inputs, validation_targets).unwrap();
103    info!(
104        "{}",
105        helper::pretty_compare_matrices(
106            validation_inputs,
107            validation_targets,
108            &net_results.predictions,
109            helper::CompareMode::Classification
110        )
111    );
112    info!("Validation: {}", net_results.display_metrics());
113}
114
115fn one_hot_encode_network(inp_size: usize, targ_size: usize) -> Network {
116    let network = NetworkBuilder::new(inp_size, targ_size)
117        .layer(Dense::default().size(7).activation(ReLU::build()).build())
118        .layer(Dense::default().size(5).activation(ReLU::build()).build())
119        .layer(Dense::default().size(targ_size).activation(Softmax::build()).build())
120        .optimizer(Adam::default().beta1(0.99).beta2(0.999).learning_rate(0.0035).build())
121        .loss_function(CrossEntropy::default().epsilon(1e-8).build())
122        .batch_size(4)
123        .normalize_input(MinMax::default())
124        .epochs(500)
125        .seed(55)
126        .build();
127
128    match network {
129        Ok(net) => net,
130        Err(e) => {
131            error!("Failed to build network: {}", e);
132            std::process::exit(1);
133        }
134    }
135}
136
137fn search(training_inputs: &DMat, training_targets: &DMat, validation_inputs: &DMat, validation_targets: &DMat) {
138    let network = one_hot_encode_network(training_inputs.cols(), training_targets.cols());
139
140    let network_search = NetworkSearchBuilder::new()
141        .network(network)
142        .parallelize(4)
143        .normalize_input(MinMax::default())
144        .learning_rates(
145            SequentialNumbers::new()
146                .lower_limit(0.0015)
147                .upper_limit(0.0035)
148                .increment(0.0005)
149                .floats(),
150        )
151        .batch_sizes(
152            SequentialNumbers::new()
153                .lower_limit(4.0)
154                .upper_limit(9.0)
155                .increment(1.0)
156                .ints(),
157        )
158        .hidden_layer(
159            SequentialNumbers::new()
160                .lower_limit(4.0)
161                .upper_limit(10.0)
162                .increment(1.0)
163                .ints(),
164            ReLU::build(),
165        )
166        .hidden_layer(
167            SequentialNumbers::new()
168                .lower_limit(4.0)
169                .upper_limit(10.0)
170                .increment(1.0)
171                .ints(),
172            ReLU::build(),
173        )
174        .export(
175            CSV::default()
176                .directory(EXP_NAME)
177                .file_name(&format!("{}_search", EXP_NAME))
178                .build(),
179        )
180        .build();
181
182    let mut network_search = match network_search {
183        Ok(ns) => ns,
184        Err(e) => {
185            error!("Failed to build network_search: {}", e);
186            std::process::exit(1);
187        }
188    };
189
190    let search_res = network_search
191        .search(training_inputs, training_targets, validation_inputs, validation_targets)
192        .unwrap();
193
194    info!("Search Result Count: {}", search_res.len());
195}
196
197pub fn wine_inputs_targets(
198    name: &str, fields_count: usize, input_count: usize,
199) -> Result<(DMat, DMat, DMat, DMat), Box<dyn Error>> {
200    let target_count = fields_count - input_count;
201
202    let file_path = format!("./examples/wine/{}", name);
203    let file = File::open(&file_path)?;
204    let mut reader = ReaderBuilder::new().has_headers(true).from_reader(file);
205
206    let mut inputs_data = Vec::new();
207    let mut targets_data = Vec::new();
208
209    for (index, result) in reader.records().enumerate() {
210        let record = result?;
211        // Skip if record is empty
212        if record.is_empty() {
213            println!("Skipping empty record at line {}", index + 2); // +2 because of header + 0-indexed
214            continue;
215        }
216        // If record has wrong number of fields, print detailed info
217        if record.len() != fields_count {
218            println!(
219                "Bad record at line {}: expected {} fields, but got {} fields.",
220                index + 2,
221                fields_count,
222                record.len()
223            );
224            println!("Record content: {:?}", record);
225            return Err("Unexpected number of fields".to_string().into());
226        }
227        for (i, value) in record.iter().enumerate() {
228            let parsed_val: f32 = value.parse()?;
229            if i < fields_count - input_count {
230                targets_data.push(parsed_val);
231            } else {
232                inputs_data.push(parsed_val);
233            }
234        }
235    }
236
237    let all_inputs = DenseMatrix::new(inputs_data.len() / input_count, input_count)
238        .data(&inputs_data)
239        .build()
240        .unwrap();
241    let all_targets = DenseMatrix::new(targets_data.len() / target_count, target_count)
242        .data(&targets_data)
243        .build()
244        .unwrap();
245
246    let (training_inputs, training_targets, validation_inputs, validation_targets) =
247        helper::stratified_split(&all_inputs, &all_targets, 0.2, 11);
248
249    Ok((training_inputs, training_targets, validation_inputs, validation_targets))
250}
251
252/// Initializes the logger for the application.
253/// The LOG environment variable is used to define the log level (e.g., info, debug, warn, error).
254/// If the LOG variable is not set, it defaults to info.
255fn initialize_logger(name: &str) {
256    // Check if the directory exists, and attempt to create it if it doesn't
257    if !std::path::Path::new(name).exists() {
258        let _res = fs::create_dir_all(name).map_err(|e| {
259            eprintln!("Failed to create log directory: {}", e);
260        });
261    }
262
263    // Attempt to create a log file
264    let log_file = match File::create(format!("./{}/{}.log", name, name)) {
265        Ok(file) => file,
266        Err(e) => {
267            eprintln!("Failed to create log file: {}", e);
268            return;
269        }
270    };
271
272    // Check if the "LOG" environment variable is set
273    let log_level = env::var("LOG").unwrap_or_else(|_| "info".to_string()); // Default to "info"
274
275    // Initialize the logger with the specified log level
276    Builder::new()
277        .target(Target::Pipe(Box::new(log_file)))
278        .parse_filters(&log_level) // Use the log level from the environment variable
279        .init();
280}