// energy_efficiency/energy_efficiency.rs

use csv::ReaderBuilder;
use env_logger::{Builder, Target};
use log::{error, info};
use runn::{
    adam::Adam,
    csv::CSV,
    dense_layer::Dense,
    flexible::{Flexible, MonitorMetric},
    helper,
    linear::Linear,
    matrix::{DMat, DenseMatrix},
    mean_squared_error::MeanSquared,
    min_max::MinMax,
    network::network_model::{Network, NetworkBuilder},
    network_io::JSON,
    network_search::NetworkSearchBuilder,
    numbers::{Numbers, SequentialNumbers},
    relu::ReLU,
};
use std::error::Error;
use std::fs::File;
use std::path::Path;
use std::{env, fs};

/// Experiment name. It is the output directory for the log file, the saved network and the
/// search results, and the prefix of the network and search-result file names.
const EXP_NAME: &str = "energy_efficiency";

26/// This example demonstrates how to train and validate a neural network on the Energy Efficiency dataset.
27/// The Energy Efficiency dataset is a regression dataset used for predicting energy efficiency.
28/// The code includes functions to load the dataset, build the neural network,
29/// train the network, validate its performance, and perform a hyperparameter search.
30///
31/// to run the example:
32/// ```bash
33/// cargo run --example energy_efficiency
34/// ```
35/// to run the hyperparameter search:
36/// ```bash
37/// cargo run --example energy_efficiency -- -search
38/// ```
39/// The hyperparameter search will create a CSV file with the results in the `energy_efficiency` directory.
40/// The training and validation results will be logged in the `energy_efficiency` directory.
41fn main() {
42    initialize_logger(EXP_NAME);
43
44    let (training_inputs, training_targets, validation_inputs, validation_targets) =
45        energy_efficiency_inputs_targets("ENB2012_data", 10, 8).unwrap();
46
47    let args: Vec<String> = env::args().collect();
48    if args.contains(&"-search".to_string()) {
49        search(&training_inputs, &training_targets, &validation_inputs, &validation_targets);
50    } else {
51        train_and_validate(&training_inputs, &training_targets, &validation_inputs, &validation_targets);
52    }
53}
54
55fn train_and_validate(
56    training_inputs: &DMat, training_targets: &DMat, validation_inputs: &DMat, validation_targets: &DMat,
57) {
58    let network_file = format!("{}_network", EXP_NAME);
59
60    let mut network = energy_efficiency_network(training_inputs.cols(), training_targets.cols());
61
62    let training_result = network.train(training_inputs, training_targets);
63    match training_result {
64        Ok(_) => {
65            info!("Training successfully completed");
66            network
67                .save(
68                    JSON::default()
69                        .directory(EXP_NAME)
70                        .file_name(&network_file)
71                        .build()
72                        .unwrap(),
73                )
74                .unwrap();
75            let net_results = network.predict(training_inputs, training_targets).unwrap();
76            info!(
77                "{}",
78                helper::pretty_compare_matrices(
79                    training_inputs,
80                    training_targets,
81                    &net_results.predictions,
82                    helper::CompareMode::Regression
83                )
84            );
85            info!("Training: {}", net_results.display_metrics());
86        }
87        Err(e) => {
88            error!("Training failed: {}", e);
89        }
90    }
91
92    network = Network::load(
93        JSON::default()
94            .directory(EXP_NAME)
95            .file_name(&network_file)
96            .build()
97            .unwrap(),
98    )
99    .unwrap();
100    let net_results = network.predict(validation_inputs, validation_targets).unwrap();
101    log::info!(
102        "{}",
103        helper::pretty_compare_matrices(
104            validation_inputs,
105            validation_targets,
106            &net_results.predictions,
107            helper::CompareMode::Regression
108        )
109    );
110    info!("Validation: {}", net_results.display_metrics());
111}
112
113fn energy_efficiency_network(inp_size: usize, targ_size: usize) -> Network {
114    let network = NetworkBuilder::new(inp_size, targ_size)
115        .layer(Dense::default().size(18).activation(ReLU::build()).build())
116        .layer(Dense::default().size(14).activation(ReLU::build()).build())
117        .layer(Dense::default().size(targ_size).activation(Linear::build()).build())
118        .optimizer(Adam::default().beta1(0.99).beta2(0.999).learning_rate(0.0030).build())
119        .loss_function(MeanSquared.build())
120        .early_stopper(
121            Flexible::default()
122                .monitor_metric(MonitorMetric::Loss)
123                .patience(500)
124                .min_delta(0.1)
125                .smoothing_factor(0.5)
126                .build(),
127        )
128        .batch_size(7)
129        .batch_group_size(2)
130        .parallelize(2)
131        .normalize_input(MinMax::default())
132        .epochs(500)
133        .seed(55)
134        .build();
135
136    match network {
137        Ok(net) => net,
138        Err(e) => {
139            eprintln!("Failed to build network: {}", e);
140            std::process::exit(1);
141        }
142    }
143}
144
145fn search(training_inputs: &DMat, training_targets: &DMat, validation_inputs: &DMat, validation_targets: &DMat) {
146    let start_time = std::time::Instant::now();
147    info!("Energy Efficieny network search started");
148    let network = energy_efficiency_network(training_inputs.cols(), training_targets.cols());
149
150    let network_search = NetworkSearchBuilder::new()
151        .network(network)
152        .parallelize(4)
153        .normalize_input(MinMax::default())
154        .learning_rates(
155            SequentialNumbers::new()
156                .lower_limit(0.0025)
157                .upper_limit(0.0035)
158                .increment(0.0005)
159                .floats(),
160        )
161        .batch_sizes(
162            SequentialNumbers::new()
163                .lower_limit(6.0)
164                .upper_limit(9.0)
165                .increment(1.0)
166                .ints(),
167        )
168        .hidden_layer(
169            SequentialNumbers::new()
170                .lower_limit(14.0)
171                .upper_limit(18.0)
172                .increment(2.0)
173                .ints(),
174            ReLU::build(),
175        )
176        .hidden_layer(
177            SequentialNumbers::new()
178                .lower_limit(14.0)
179                .upper_limit(18.0)
180                .increment(2.0)
181                .ints(),
182            ReLU::build(),
183        )
184        .export(
185            CSV::default()
186                .directory(EXP_NAME)
187                .file_name(&format!("{}_search", EXP_NAME))
188                .build(),
189        )
190        .build();
191    let mut network_search = match network_search {
192        Ok(ns) => ns,
193        Err(e) => {
194            error!("Failed to build network_search: {}", e);
195            std::process::exit(1);
196        }
197    };
198
199    let search_res = network_search
200        .search(training_inputs, training_targets, validation_inputs, validation_targets)
201        .unwrap();
202
203    info!("Energy Efficieny network search finished in {} seconds", start_time.elapsed().as_secs());
204    info!("Num Results: {}", search_res.len());
205}
206
207pub fn energy_efficiency_inputs_targets(
208    name: &str, fields_count: usize, input_count: usize,
209) -> Result<(DMat, DMat, DMat, DMat), Box<dyn Error>> {
210    let target_count = fields_count - input_count;
211
212    let file_path = format!("./examples/energy_efficiency/{}.csv", name);
213    let file = File::open(&file_path)?;
214    let mut reader = ReaderBuilder::new().has_headers(true).from_reader(file);
215
216    let mut inputs_data = Vec::new();
217    let mut targets_data = Vec::new();
218
219    for (index, result) in reader.records().enumerate() {
220        let record = result?;
221        // Skip if record is empty
222        if record.is_empty() {
223            error!("Skipping empty record at line {}", index + 2); // +2 because of header + 0-indexed
224            continue;
225        }
226        // If record has wrong number of fields, print detailed info
227        if record.len() != fields_count {
228            error!(
229                "Bad record at line {}: expected {} fields, but got {} fields",
230                index + 2,
231                fields_count,
232                record.len()
233            );
234            error!("Record content: {:?}", record);
235            return Err("Unexpected number of fields".to_string().into());
236        }
237        for (i, value) in record.iter().enumerate() {
238            let parsed_val: f32 = value.parse()?;
239            if i >= fields_count - target_count {
240                targets_data.push(parsed_val);
241            } else {
242                inputs_data.push(parsed_val);
243            }
244        }
245    }
246
247    let all_inputs = DenseMatrix::new(inputs_data.len() / input_count, input_count)
248        .data(&inputs_data)
249        .build()
250        .unwrap();
251    let all_targets = DenseMatrix::new(targets_data.len() / target_count, target_count)
252        .data(&targets_data)
253        .build()
254        .unwrap();
255
256    let (training_inputs, training_targets, validation_inputs, validation_targets) =
257        helper::random_split(&all_inputs, &all_targets, 0.2, 55);
258
259    Ok((training_inputs, training_targets, validation_inputs, validation_targets))
260}
261
262/// Initializes the logger for the application.
263/// The LOG environment variable is used to define the log level (e.g., info, debug, warn, error).
264/// If the LOG variable is not set, it defaults to info.
265fn initialize_logger(name: &str) {
266    // Check if the directory exists, and attempt to create it if it doesn't
267    if !std::path::Path::new(name).exists() {
268        let _res = fs::create_dir_all(name).map_err(|e| {
269            eprintln!("Failed to create log directory: {}", e);
270        });
271    }
272
273    // Attempt to create a log file
274    let log_file = match File::create(format!("./{}/{}.log", name, name)) {
275        Ok(file) => file,
276        Err(e) => {
277            eprintln!("Failed to create log file: {}", e);
278            return;
279        }
280    };
281
282    // Check if the "LOG" environment variable is set
283    let log_level = env::var("LOG").unwrap_or_else(|_| "info".to_string()); // Default to "info"
284
285    // Initialize the logger with the specified log level
286    Builder::new()
287        .target(Target::Pipe(Box::new(log_file)))
288        .parse_filters(&log_level) // Use the log level from the environment variable
289        .init();
290}