1use csv::ReaderBuilder;
2use env_logger::{Builder, Target};
3use log::{error, info};
4use runn::{
5 adam::Adam,
6 cross_entropy::CrossEntropy,
7 csv::CSV,
8 dense_layer::Dense,
9 helper,
10 matrix::{DMat, DenseMatrix},
11 min_max::MinMax,
12 network::network_model::{Network, NetworkBuilder},
13 network_io::JSON,
14 network_search::NetworkSearchBuilder,
15 numbers::{Numbers, SequentialNumbers},
16 relu::ReLU,
17 softmax::Softmax,
18};
19use std::error::Error;
20use std::fs::File;
21use std::{env, fs};
22
/// Experiment name. Used as the output directory (log file, saved network, search
/// results), as the log file stem, and as the prefix of the saved network and
/// search-result file names.
const EXP_NAME: &str = "wine";
24
25fn main() {
41 initialize_logger(EXP_NAME);
42
43 let (training_inputs, training_targets, validation_inputs, validation_targets) =
44 wine_inputs_targets("wine.data", 14, 13).unwrap();
45
46 let training_targets = helper::one_hot_encode(&training_targets);
47 let validation_targets = helper::one_hot_encode(&validation_targets);
48
49 let args: Vec<String> = env::args().collect();
50 if args.contains(&"-search".to_string()) {
51 search(&training_inputs, &training_targets, &validation_inputs, &validation_targets);
52 } else {
53 train_and_validate(&training_inputs, &training_targets, &validation_inputs, &validation_targets);
54 }
55}
56
57fn train_and_validate(
58 training_inputs: &DMat, training_targets: &DMat, validation_inputs: &DMat, validation_targets: &DMat,
59) {
60 let network_file = format!("{}_network", EXP_NAME);
61
62 let mut network = one_hot_encode_network(training_inputs.cols(), training_targets.cols());
63
64 let training_result = network.train(training_inputs, training_targets);
65 match training_result {
66 Ok(_) => {
67 info!("Training successfully completed");
68 network
69 .save(
70 JSON::default()
71 .directory(EXP_NAME)
72 .file_name(&network_file)
73 .build()
74 .unwrap(),
75 )
76 .unwrap();
77 let net_results = network.predict(training_inputs, training_targets).unwrap();
78 info!(
79 "{}",
80 helper::pretty_compare_matrices(
81 training_inputs,
82 training_targets,
83 &net_results.predictions,
84 helper::CompareMode::Classification
85 )
86 );
87 info!("Training: {}", net_results.display_metrics());
88 }
89 Err(e) => {
90 eprintln!("Training failed: {}", e);
91 }
92 }
93
94 network = Network::load(
95 JSON::default()
96 .directory(EXP_NAME)
97 .file_name(&network_file)
98 .build()
99 .unwrap(),
100 )
101 .unwrap();
102 let net_results = network.predict(validation_inputs, validation_targets).unwrap();
103 info!(
104 "{}",
105 helper::pretty_compare_matrices(
106 validation_inputs,
107 validation_targets,
108 &net_results.predictions,
109 helper::CompareMode::Classification
110 )
111 );
112 info!("Validation: {}", net_results.display_metrics());
113}
114
115fn one_hot_encode_network(inp_size: usize, targ_size: usize) -> Network {
116 let network = NetworkBuilder::new(inp_size, targ_size)
117 .layer(Dense::default().size(7).activation(ReLU::build()).build())
118 .layer(Dense::default().size(5).activation(ReLU::build()).build())
119 .layer(Dense::default().size(targ_size).activation(Softmax::build()).build())
120 .optimizer(Adam::default().beta1(0.99).beta2(0.999).learning_rate(0.0035).build())
121 .loss_function(CrossEntropy::default().epsilon(1e-8).build())
122 .batch_size(4)
123 .normalize_input(MinMax::default())
124 .epochs(500)
125 .seed(55)
126 .build();
127
128 match network {
129 Ok(net) => net,
130 Err(e) => {
131 error!("Failed to build network: {}", e);
132 std::process::exit(1);
133 }
134 }
135}
136
137fn search(training_inputs: &DMat, training_targets: &DMat, validation_inputs: &DMat, validation_targets: &DMat) {
138 let network = one_hot_encode_network(training_inputs.cols(), training_targets.cols());
139
140 let network_search = NetworkSearchBuilder::new()
141 .network(network)
142 .parallelize(4)
143 .normalize_input(MinMax::default())
144 .learning_rates(
145 SequentialNumbers::new()
146 .lower_limit(0.0015)
147 .upper_limit(0.0035)
148 .increment(0.0005)
149 .floats(),
150 )
151 .batch_sizes(
152 SequentialNumbers::new()
153 .lower_limit(4.0)
154 .upper_limit(9.0)
155 .increment(1.0)
156 .ints(),
157 )
158 .hidden_layer(
159 SequentialNumbers::new()
160 .lower_limit(4.0)
161 .upper_limit(10.0)
162 .increment(1.0)
163 .ints(),
164 ReLU::build(),
165 )
166 .hidden_layer(
167 SequentialNumbers::new()
168 .lower_limit(4.0)
169 .upper_limit(10.0)
170 .increment(1.0)
171 .ints(),
172 ReLU::build(),
173 )
174 .export(
175 CSV::default()
176 .directory(EXP_NAME)
177 .file_name(&format!("{}_search", EXP_NAME))
178 .build(),
179 )
180 .build();
181
182 let mut network_search = match network_search {
183 Ok(ns) => ns,
184 Err(e) => {
185 error!("Failed to build network_search: {}", e);
186 std::process::exit(1);
187 }
188 };
189
190 let search_res = network_search
191 .search(training_inputs, training_targets, validation_inputs, validation_targets)
192 .unwrap();
193
194 info!("Search Result Count: {}", search_res.len());
195}
196
197pub fn wine_inputs_targets(
198 name: &str, fields_count: usize, input_count: usize,
199) -> Result<(DMat, DMat, DMat, DMat), Box<dyn Error>> {
200 let target_count = fields_count - input_count;
201
202 let file_path = format!("./examples/wine/{}", name);
203 let file = File::open(&file_path)?;
204 let mut reader = ReaderBuilder::new().has_headers(true).from_reader(file);
205
206 let mut inputs_data = Vec::new();
207 let mut targets_data = Vec::new();
208
209 for (index, result) in reader.records().enumerate() {
210 let record = result?;
211 if record.is_empty() {
213 println!("Skipping empty record at line {}", index + 2); continue;
215 }
216 if record.len() != fields_count {
218 println!(
219 "Bad record at line {}: expected {} fields, but got {} fields.",
220 index + 2,
221 fields_count,
222 record.len()
223 );
224 println!("Record content: {:?}", record);
225 return Err("Unexpected number of fields".to_string().into());
226 }
227 for (i, value) in record.iter().enumerate() {
228 let parsed_val: f32 = value.parse()?;
229 if i < fields_count - input_count {
230 targets_data.push(parsed_val);
231 } else {
232 inputs_data.push(parsed_val);
233 }
234 }
235 }
236
237 let all_inputs = DenseMatrix::new(inputs_data.len() / input_count, input_count)
238 .data(&inputs_data)
239 .build()
240 .unwrap();
241 let all_targets = DenseMatrix::new(targets_data.len() / target_count, target_count)
242 .data(&targets_data)
243 .build()
244 .unwrap();
245
246 let (training_inputs, training_targets, validation_inputs, validation_targets) =
247 helper::stratified_split(&all_inputs, &all_targets, 0.2, 11);
248
249 Ok((training_inputs, training_targets, validation_inputs, validation_targets))
250}
251
252fn initialize_logger(name: &str) {
256 if !std::path::Path::new(name).exists() {
258 let _res = fs::create_dir_all(name).map_err(|e| {
259 eprintln!("Failed to create log directory: {}", e);
260 });
261 }
262
263 let log_file = match File::create(format!("./{}/{}.log", name, name)) {
265 Ok(file) => file,
266 Err(e) => {
267 eprintln!("Failed to create log file: {}", e);
268 return;
269 }
270 };
271
272 let log_level = env::var("LOG").unwrap_or_else(|_| "info".to_string()); Builder::new()
277 .target(Target::Pipe(Box::new(log_file)))
278 .parse_filters(&log_level) .init();
280}