1use csv::ReaderBuilder;
2use env_logger::{Builder, Target};
3use log::{error, info};
4use runn::{
5 adam::Adam,
6 csv::CSV,
7 dense_layer::Dense,
8 flexible::{Flexible, MonitorMetric},
9 helper,
10 linear::Linear,
11 matrix::{DMat, DenseMatrix},
12 mean_squared_error::MeanSquared,
13 min_max::MinMax,
14 network::network_model::{Network, NetworkBuilder},
15 network_io::JSON,
16 network_search::NetworkSearchBuilder,
17 numbers::{Numbers, SequentialNumbers},
18 relu::ReLU,
19};
20use std::error::Error;
21use std::fs::File;
22use std::{env, fs};
23
/// Experiment name. Used in this file as the output directory for the saved network,
/// the search export and the log file, and as the prefix of those files' names.
24const EXP_NAME: &str = "energy_efficiency";
25
26fn main() {
42 initialize_logger(EXP_NAME);
43
44 let (training_inputs, training_targets, validation_inputs, validation_targets) =
45 energy_efficiency_inputs_targets("ENB2012_data", 10, 8).unwrap();
46
47 let args: Vec<String> = env::args().collect();
48 if args.contains(&"-search".to_string()) {
49 search(&training_inputs, &training_targets, &validation_inputs, &validation_targets);
50 } else {
51 train_and_validate(&training_inputs, &training_targets, &validation_inputs, &validation_targets);
52 }
53}
54
55fn train_and_validate(
56 training_inputs: &DMat, training_targets: &DMat, validation_inputs: &DMat, validation_targets: &DMat,
57) {
58 let network_file = format!("{}_network", EXP_NAME);
59
60 let mut network = energy_efficiency_network(training_inputs.cols(), training_targets.cols());
61
62 let training_result = network.train(training_inputs, training_targets);
63 match training_result {
64 Ok(_) => {
65 info!("Training successfully completed");
66 network
67 .save(
68 JSON::default()
69 .directory(EXP_NAME)
70 .file_name(&network_file)
71 .build()
72 .unwrap(),
73 )
74 .unwrap();
75 let net_results = network.predict(training_inputs, training_targets).unwrap();
76 info!(
77 "{}",
78 helper::pretty_compare_matrices(
79 training_inputs,
80 training_targets,
81 &net_results.predictions,
82 helper::CompareMode::Regression
83 )
84 );
85 info!("Training: {}", net_results.display_metrics());
86 }
87 Err(e) => {
88 error!("Training failed: {}", e);
89 }
90 }
91
92 network = Network::load(
93 JSON::default()
94 .directory(EXP_NAME)
95 .file_name(&network_file)
96 .build()
97 .unwrap(),
98 )
99 .unwrap();
100 let net_results = network.predict(validation_inputs, validation_targets).unwrap();
101 log::info!(
102 "{}",
103 helper::pretty_compare_matrices(
104 validation_inputs,
105 validation_targets,
106 &net_results.predictions,
107 helper::CompareMode::Regression
108 )
109 );
110 info!("Validation: {}", net_results.display_metrics());
111}
112
113fn energy_efficiency_network(inp_size: usize, targ_size: usize) -> Network {
114 let network = NetworkBuilder::new(inp_size, targ_size)
115 .layer(Dense::default().size(18).activation(ReLU::build()).build())
116 .layer(Dense::default().size(14).activation(ReLU::build()).build())
117 .layer(Dense::default().size(targ_size).activation(Linear::build()).build())
118 .optimizer(Adam::default().beta1(0.99).beta2(0.999).learning_rate(0.0030).build())
119 .loss_function(MeanSquared.build())
120 .early_stopper(
121 Flexible::default()
122 .monitor_metric(MonitorMetric::Loss)
123 .patience(500)
124 .min_delta(0.1)
125 .smoothing_factor(0.5)
126 .build(),
127 )
128 .batch_size(7)
129 .batch_group_size(2)
130 .parallelize(2)
131 .normalize_input(MinMax::default())
132 .epochs(500)
133 .seed(55)
134 .build();
135
136 match network {
137 Ok(net) => net,
138 Err(e) => {
139 eprintln!("Failed to build network: {}", e);
140 std::process::exit(1);
141 }
142 }
143}
144
145fn search(training_inputs: &DMat, training_targets: &DMat, validation_inputs: &DMat, validation_targets: &DMat) {
146 let start_time = std::time::Instant::now();
147 info!("Energy Efficieny network search started");
148 let network = energy_efficiency_network(training_inputs.cols(), training_targets.cols());
149
150 let network_search = NetworkSearchBuilder::new()
151 .network(network)
152 .parallelize(4)
153 .normalize_input(MinMax::default())
154 .learning_rates(
155 SequentialNumbers::new()
156 .lower_limit(0.0025)
157 .upper_limit(0.0035)
158 .increment(0.0005)
159 .floats(),
160 )
161 .batch_sizes(
162 SequentialNumbers::new()
163 .lower_limit(6.0)
164 .upper_limit(9.0)
165 .increment(1.0)
166 .ints(),
167 )
168 .hidden_layer(
169 SequentialNumbers::new()
170 .lower_limit(14.0)
171 .upper_limit(18.0)
172 .increment(2.0)
173 .ints(),
174 ReLU::build(),
175 )
176 .hidden_layer(
177 SequentialNumbers::new()
178 .lower_limit(14.0)
179 .upper_limit(18.0)
180 .increment(2.0)
181 .ints(),
182 ReLU::build(),
183 )
184 .export(
185 CSV::default()
186 .directory(EXP_NAME)
187 .file_name(&format!("{}_search", EXP_NAME))
188 .build(),
189 )
190 .build();
191 let mut network_search = match network_search {
192 Ok(ns) => ns,
193 Err(e) => {
194 error!("Failed to build network_search: {}", e);
195 std::process::exit(1);
196 }
197 };
198
199 let search_res = network_search
200 .search(training_inputs, training_targets, validation_inputs, validation_targets)
201 .unwrap();
202
203 info!("Energy Efficieny network search finished in {} seconds", start_time.elapsed().as_secs());
204 info!("Num Results: {}", search_res.len());
205}
206
207pub fn energy_efficiency_inputs_targets(
208 name: &str, fields_count: usize, input_count: usize,
209) -> Result<(DMat, DMat, DMat, DMat), Box<dyn Error>> {
210 let target_count = fields_count - input_count;
211
212 let file_path = format!("./examples/energy_efficiency/{}.csv", name);
213 let file = File::open(&file_path)?;
214 let mut reader = ReaderBuilder::new().has_headers(true).from_reader(file);
215
216 let mut inputs_data = Vec::new();
217 let mut targets_data = Vec::new();
218
219 for (index, result) in reader.records().enumerate() {
220 let record = result?;
221 if record.is_empty() {
223 error!("Skipping empty record at line {}", index + 2); continue;
225 }
226 if record.len() != fields_count {
228 error!(
229 "Bad record at line {}: expected {} fields, but got {} fields",
230 index + 2,
231 fields_count,
232 record.len()
233 );
234 error!("Record content: {:?}", record);
235 return Err("Unexpected number of fields".to_string().into());
236 }
237 for (i, value) in record.iter().enumerate() {
238 let parsed_val: f32 = value.parse()?;
239 if i >= fields_count - target_count {
240 targets_data.push(parsed_val);
241 } else {
242 inputs_data.push(parsed_val);
243 }
244 }
245 }
246
247 let all_inputs = DenseMatrix::new(inputs_data.len() / input_count, input_count)
248 .data(&inputs_data)
249 .build()
250 .unwrap();
251 let all_targets = DenseMatrix::new(targets_data.len() / target_count, target_count)
252 .data(&targets_data)
253 .build()
254 .unwrap();
255
256 let (training_inputs, training_targets, validation_inputs, validation_targets) =
257 helper::random_split(&all_inputs, &all_targets, 0.2, 55);
258
259 Ok((training_inputs, training_targets, validation_inputs, validation_targets))
260}
261
262fn initialize_logger(name: &str) {
266 if !std::path::Path::new(name).exists() {
268 let _res = fs::create_dir_all(name).map_err(|e| {
269 eprintln!("Failed to create log directory: {}", e);
270 });
271 }
272
273 let log_file = match File::create(format!("./{}/{}.log", name, name)) {
275 Ok(file) => file,
276 Err(e) => {
277 eprintln!("Failed to create log file: {}", e);
278 return;
279 }
280 };
281
282 let log_level = env::var("LOG").unwrap_or_else(|_| "info".to_string()); Builder::new()
287 .target(Target::Pipe(Box::new(log_file)))
288 .parse_filters(&log_level) .init();
290}