1use ascii_converter::decimals_to_string;
2use rand::{Rng, seq::SliceRandom, thread_rng};
3use serde::{Serialize, Deserialize};
4use std::{fs, path::Path};
5use crate::{
6 categorize::CatNetwork,
7 node::Node,
8 activation::ActivationFunction,
9 DEBUG,
10 error::DarjeelingError,
11 input::Input,
12 types::{Types, Types::Boolean},
13 dbg_println
14};
15use rayon::prelude::*;
16
/// A feed-forward "generator" network that [`GenNetwork::learn`] trains adversarially
/// against a freshly built [`CatNetwork`] discriminator.
#[derive(Debug, Serialize, Deserialize)]
pub struct GenNetwork {
    /// Layers of nodes, front to back: `node_array[sensor]` is the input layer,
    /// `node_array[answer]` the output layer, and any layers in between are hidden.
    node_array: Vec<Vec<Node>>,
    /// Index of the input (sensor) layer in `node_array`; the constructors here set `Some(0)`.
    sensor: Option<usize>,
    /// Index of the output (answer) layer in `node_array`.
    answer: Option<usize>,
    /// Total number of biases plus link weights, or `None` if it has not been computed.
    parameters: Option<u128>,
    /// Activation function handed to each node's `output` call.
    activation_function: ActivationFunction
}
26#[warn(clippy::unwrap_in_result)]
27impl GenNetwork {
28
29 pub fn new(input_num: i32, hidden_num: i32, answer_num: i32, hidden_layers: i32, activation_function: ActivationFunction) -> GenNetwork {
53 let mut net: GenNetwork = GenNetwork { node_array: vec![], sensor: Some(0), answer: Some(hidden_layers as usize + 1), parameters: None, activation_function};
54 let mut rng = rand::thread_rng();
55 net.node_array.push(vec![]);
56 for _i in 0..input_num {
57 net.node_array[net.sensor.unwrap()].push(Node::new(&vec![], None));
58 }
59
60 for i in 1..hidden_layers + 1 {
61 let mut hidden_vec:Vec<Node> = vec![];
62 let hidden_links = net.node_array[(i - 1) as usize].len();
63 dbg_println!("Hidden Links: {:?}", hidden_links);
64 for _j in 0..hidden_num{
65 hidden_vec.push(Node { link_weights: vec![], link_vals: vec![], links: hidden_links, err_sig: None, correct_answer: None, cached_output: None, category: None, b_weight: None });
66 }
67 net.node_array.push(hidden_vec);
68 }
69
70 net.node_array.push(vec![]);
71 let answer_links = net.node_array[hidden_layers as usize].len();
72 println!("Answer Links: {:?}", answer_links);
73 for _i in 0..answer_num {
74 net.node_array[net.answer.unwrap()].push(Node { link_weights: vec![], link_vals: vec![], links: answer_links, err_sig: None, correct_answer: None, cached_output: Some(0.0), category: None, b_weight: None });
75 }
76
77 net.node_array
78 .iter_mut()
79 .for_each(|layer| {
80 layer
81 .iter_mut()
82 .for_each(|mut node| {
83 node.b_weight = Some(rng.gen_range(-0.5..0.5));
84 dbg_println!("Made it to pushing link weights");
85 (0..node.links)
86 .into_iter()
87 .for_each(|_| {
88 node.link_weights.push(rng.gen_range(-0.5..0.5));
89 node.link_vals.push(None);
90 })
91 })
92 });
93 let mut params = 0;
94 (0..net.node_array.len())
95 .into_iter()
96 .for_each(|i| {
97 (0..net.node_array[i].len())
98 .into_iter()
99 .for_each(|j| {
100 params += 1 + net.node_array[i][j].links as u128;
101 })
102 });
103 net.parameters = Some(params);
104 net
105 }
106
107 pub fn learn( &mut self,
174 data: &mut Vec<Input>,
175 learning_rate: f32,
176 name: &str, max_cycles: i32,
177 distinguising_learning_rate: f32, distinguising_hidden_neurons: i32,
178 distinguising_hidden_layers: i32, distinguising_activation: ActivationFunction,
179 distinguishing_target_err_percent: f32
180 ) -> Result<String, DarjeelingError> {
181 let mut epochs: f32 = 0.0;
182 let distinguishing_model: *mut CatNetwork = &mut CatNetwork::new(self.node_array[self.answer.unwrap()].len() as i32, distinguising_hidden_neurons, 2, distinguising_hidden_layers, distinguising_activation);
183 let mut outputs: Vec<Input> = vec![];
184 for _i in 0..max_cycles {
185 let mse: f32;
186 data.shuffle(&mut thread_rng());
187 for line in 0..data.len() {
188 dbg_println!("Training Checkpoint One Passed");
189 self.push_downstream(data, line as i32);
190 let mut output = vec![];
191 for i in 0..self.node_array[self.answer.unwrap()].len() {
192 output.push(self.node_array[self.answer.unwrap()][i].output(&self.activation_function));
193 }
194 outputs.push(Input::new(output, Some(Boolean(false)))); data[line].answer = Some(Boolean(true));
196 outputs.push(data[line].clone());
197 }
198 let mut new_model: CatNetwork = CatNetwork::new(self.node_array[self.answer.unwrap()].len() as i32, distinguising_hidden_neurons, 2, distinguising_hidden_layers, distinguising_activation);
201 match new_model.learn(
202 data,
203 vec![Boolean(true), Boolean(false)],
204 distinguising_learning_rate,
205 &("distinguishing".to_owned() + &name), distinguishing_target_err_percent, false)
206 {
207 Ok((_name, _err_percent, errmse)) => mse = errmse,
208 Err(error) => return Err(DarjeelingError::DisinguishingModelError(error.to_string()))
209 };
210
211 unsafe { distinguishing_model.write(new_model) };
212
213 self.backpropogate(learning_rate, mse);
214 epochs += 1.0;
215 println!("Epoch: {:?}", epochs);
216 }
217 #[allow(unused_mut)]
218 let mut model_name: String;
219 match self.write_model(&name) {
220 Ok(m_name) => {
221 model_name = m_name;
222 },
223 Err(error) => return Err(error)
224 }
225 Ok(model_name)
226 }
227
228 pub fn test(&mut self, data: &mut Vec<Input>) -> Result<Vec<Input>, DarjeelingError> {
229 data.shuffle(&mut thread_rng());
230 let mut outputs: Vec<Input> = vec![];
231 for i in 0..data.len() {
232 self.push_downstream(data, i as i32);
233 let mut output = vec![];
234 for i in 0..self.node_array[self.answer.unwrap()].len() {
235 output.push(self.node_array[self.answer.unwrap()][i].output(&self.activation_function));
236 }
237 outputs.push(Input::new(output, None)); }
239 Ok(outputs)
240 }
241
242 fn push_downstream(&mut self, data: &mut Vec<Input>, line: i32) {
244
245 for i in 0..self.node_array[self.sensor.unwrap()].len() {
247 let input = data[line as usize].inputs[i];
248
249 self.node_array[self.sensor.unwrap()][i].cached_output = Some(input);
250 }
251
252 for layer in 1..self.node_array.len() {
254
255 for node in 0..self.node_array[layer].len() {
256
257 for prev_node in 0..self.node_array[layer-1].len() {
258
259 self.node_array[layer][node].link_vals[prev_node] = Some(self.node_array[layer-1][prev_node].cached_output.unwrap());
261 self.node_array[layer][node].output(&self.activation_function);
263 if DEBUG { if layer == self.answer.unwrap() { println!("Ran output on answer {:?}", self.node_array[layer][node].cached_output) } }
264 }
265 self.node_array[layer][node].output(&self.activation_function);
266 }
267 }
268 }
269
270 fn largest_node(&self) -> usize {
272 let mut largest_node = 0;
273 (0..self.node_array[self.answer.unwrap()].len())
274 .into_iter()
275 .for_each(|node| {
276 if self.node_array[self.answer.unwrap()][node].cached_output > self.node_array[self.answer.unwrap()][largest_node].cached_output {
277 largest_node = node;
278 }
279 });
280 largest_node
281 }
282 fn backpropogate(&mut self, learning_rate: f32, mse: f32) {
284 let hidden_layers = (self.node_array.len() - 2) as i32;
285 (self.node_array[self.answer.unwrap()])
286 .par_iter_mut()
287 .for_each(|answer_node| {
288 println!("Node: {:?}", answer_node);
289 answer_node.compute_answer_err_sig_gen(mse, &self.activation_function);
290 dbg_println!("Error: {:?}", answer_node.err_sig.unwrap());
291 });
292 self.adjust_hidden_weights(learning_rate, hidden_layers);
293 (self.node_array[self.answer.unwrap()])
295 .par_iter_mut()
296 .for_each(|node| {
297 node.adjust_weights(learning_rate);
298 });
299 }
300
301 #[allow(non_snake_case)]
302 fn adjust_hidden_weights(&mut self, learning_rate: f32, hidden_layers: i32) {
304 (1..(hidden_layers + 1) as usize)
306 .into_iter()
307 .for_each(|HIDDEN| {
308 (0..self.node_array[HIDDEN].len())
309 .into_iter()
310 .for_each(|hidden| {
311 self.node_array[HIDDEN][hidden].err_sig = Some(0.0);
312 (0..self.node_array[HIDDEN + 1 ].len())
313 .into_iter()
314 .for_each(|next_layer| {
315 let next_weight = self.node_array[HIDDEN + 1][next_layer].link_weights[hidden];
316 self.node_array[HIDDEN + 1][next_layer].err_sig = match self.node_array[HIDDEN + 1][next_layer].err_sig.is_none() {
317 true => {
318 Some(0.0)
319 },
320 false => {
321 self.node_array[HIDDEN + 1][next_layer].err_sig
322 }
323 };
324 self.node_array[HIDDEN][hidden].err_sig = Some(self.node_array[HIDDEN][hidden].err_sig.unwrap() + (self.node_array[HIDDEN + 1][next_layer].err_sig.unwrap() * next_weight));
326
327 let hidden_result = self.node_array[HIDDEN][hidden].cached_output.unwrap();
328 let multiplied_value = self.node_array[HIDDEN][hidden].err_sig.unwrap() * (hidden_result) * (1.0 - hidden_result);
329 self.node_array[HIDDEN][hidden].err_sig = Some(multiplied_value);
330 dbg_println!("next err sig {:?}\nnext weight {:?}\nnew hidden errsig multiply: {:?}\n\nLayer: {:?}\nNode: {:?}\n", self.node_array[HIDDEN + 1][next_layer].err_sig.unwrap(), next_weight, multiplied_value, HIDDEN, hidden);
331 self.node_array[HIDDEN][hidden].adjust_weights(learning_rate);
332 });
333
334 });
335 });
336 }
337
338 fn self_analysis<'b>(&'b self, epochs: &mut Option<f32>, sum: &'b mut f32, count: &'b mut f32, data: &mut Vec<Input>, line: usize, expected_type: Types) -> Result<Vec<Types>, DarjeelingError> {
343 let brightest_node: &Node = &self.node_array[self.answer.unwrap()][self.largest_node()];
347 let brightness: f32 = brightest_node.cached_output.unwrap();
348
349 if !(epochs.is_none()) { if epochs.unwrap() % 10.0 == 0.0 && epochs.unwrap() != 0.0 {
351 println!("\n-------------------------\n");
352 println!("Epoch: {:?}", epochs);
353 println!("Category: {:?} \nBrightness: {:?}", brightest_node.category.as_ref().unwrap(), brightness);
354 if DEBUG {
355 let dimest_node: &Node = &self.node_array[self.answer.unwrap()][self.node_array[self.answer.unwrap()].len()-1-self.largest_node()];
356 println!("Chosen category: {:?} \nDimest Brightness: {:?}", dimest_node.category.as_ref().unwrap(), dimest_node.cached_output.unwrap());
357 }
358 }
359 }
360
361 dbg_println!("Category: {:?} \nBrightness: {:?}", brightest_node.category.as_ref().unwrap(), brightness);
362 if brightest_node.category.as_ref().unwrap().eq(&data[line].answer.as_ref().unwrap()) {
363 dbg_println!("Correct Answer Chosen\nSum++");
364 *sum += 1.0;
365 }
366 *count += 1.0;
367 let mut ret: Vec<Types> = vec![];
368 match expected_type {
369 Types::Integer(_) => {
370 let _ = &self.node_array[self.answer.unwrap()]
371 .iter()
372 .for_each(|node| {
373 let int = node.cached_output.unwrap() as i32;
374 ret.push(Types::Integer(int));
375 });
376 }
377 Types::Boolean(_) => {
378 let _ = &self.node_array[self.answer.unwrap()]
379 .iter()
380 .for_each(|node| {
381 let bool = node.cached_output.unwrap() > 0.0;
382 ret.push(Types::Boolean(bool));
383 });
384 }
385 Types::Float(_) => {
386 let _ = &self.node_array[self.answer.unwrap()]
387 .iter()
388 .for_each(|node| {
389 ret.push(Types::Float(node.cached_output.unwrap()));
390 });
391 }
392 Types::String(_) => {
393 for node in &self.node_array[self.answer.unwrap()] {
394 let inputs = vec![(node.cached_output.unwrap() as u8)];
395 let buff = match decimals_to_string(&inputs) {
396 Ok(val) => val,
397 Err(err) => return Err(DarjeelingError::SelfAnalysisStringConversion(err))
398 };
399 ret.push(Types::String(buff));
400 }
401 }
402 };
403 Ok(ret)
404 }
405
406 pub fn write_model(&mut self, name: &str) -> Result<String, DarjeelingError> {
421 let mut rng = rand::thread_rng();
422 let file_num: u32 = rng.gen();
423 let model_name: String = format!("model_{}_{}.darj", name, file_num);
424
425 match Path::new(&model_name).try_exists() {
426
427 Ok(false) => {
428 let _file: fs::File = fs::File::create(&model_name).unwrap();
429 let mut serialized = "".to_string();
430 println!("write, length: {}", self.node_array.len());
431 for i in 0..self.node_array.len() {
432 if i != 0 {
433 let _ = serialized.push_str("lb\n");
434 }
435 for j in 0..self.node_array[i].len() {
436 for k in 0..self.node_array[i][j].link_weights.len() {
437 print!("{}", self.node_array[i][j].link_weights[k]);
438 if k == self.node_array[i][j].link_weights.len() - 1 {
439 let _ = serialized.push_str(format!("{}", self.node_array[i][j].link_weights[k]).as_str());
440 } else {
441 let _ = serialized.push_str(format!("{},", self.node_array[i][j].link_weights[k]).as_str());
442 }
443 }
444 let _ = serialized.push_str(format!(";{}", self.node_array[i][j].b_weight.unwrap().to_string()).as_str());
445 let _ = serialized.push_str("\n");
446 }
447 }
448 serialized.push_str("lb\n");
449 serialized.push_str(format!("{}", self.activation_function).as_str());
450 match fs::write(&model_name, serialized) {
452 Ok(()) => {
453 println!("Model {:?} Saved", file_num);
454 Ok(model_name)
455 },
456 Err(_error) => {
457 Err(DarjeelingError::WriteModelFailed(model_name))
458 }
459 }
460 },
461 Ok(true) => {
462 return self.write_model(name);
463 },
464 Err(error) => Err(DarjeelingError::UnknownError(error.to_string()))
465 }
466 }
467
468 pub fn read_model(model_name: String) -> Result<GenNetwork, DarjeelingError> {
479
480 println!("Loading model");
481
482 let serialized_net: String = match fs::read_to_string(&model_name) {
484
485 Ok(serizalized_net) => serizalized_net,
486 Err(error) => return Err(DarjeelingError::ReadModelFailed(model_name.clone() + ";" + &error.to_string()))
487 };
488
489 let mut node_array: Vec<Vec<Node>> = vec![];
490 let mut layer: Vec<Node> = vec![];
491 let mut activation: Option<ActivationFunction> = None;
492 for i in serialized_net.lines() {
493 match i {
494 "sigmoid" => activation = Some(ActivationFunction::Sigmoid),
495
496 "linear" => activation = Some(ActivationFunction::Linear),
497
498 "tanh" => activation = Some(ActivationFunction::Tanh),
499
500 _ => {
503
504 if i.trim() == "lb" {
505 node_array.push(layer.clone());
506 layer = vec![];
508 continue;
509 }
510 #[allow(unused_mut)]
511 let mut node: Option<Node>;
512 if node_array.len() == 0 {
513 let b_weight: Vec<&str> = i.split(";").collect();
514 node = Some(Node::new(&vec![], Some(b_weight[1].parse().unwrap())));
516 } else {
517 let node_data: Vec<&str> = i.trim().split(";").collect();
518 let str_weight_array: Vec<&str> = node_data[0].split(",").collect();
519 let mut weight_array: Vec<f32> = vec![];
520 let b_weight: &str = node_data[1];
521 for weight in 0..str_weight_array.len() {
524 let val: f32 = str_weight_array[weight].parse().unwrap();
526 weight_array.push(val);
527 }
528 node = Some(Node::new(&weight_array, Some(b_weight.parse().unwrap())));
530 }
531
532 layer.push(node.expect("Both cases provide a Some value for node"));
533 }
535 }
536
537 }
538 let sensor: Option<usize> = Some(0);
540 let answer: Option<usize> = Some(node_array.len() - 1);
541
542 let net = GenNetwork {
543 node_array,
544 sensor,
545 answer,
546 parameters: None,
547 activation_function: activation.unwrap()
548 };
549 Ok(net)
552 }
553
554 pub fn add_hidden_layer_with_size(&mut self, size: usize) {
555 let mut rng = rand::thread_rng();
556 let a = self.answer.expect("initialized network");
557 self.node_array.push(self.node_array[a].clone());
558 self.node_array[a] = vec![];
559 let links = self.node_array[a - 1].len();
560 (0..size).into_iter().for_each(|i| {
561 self.node_array[a].push(Node::new(&vec![], Some(rng.gen_range(-0.5..0.5))));
562 self.node_array[a][i].links = links;
563 (0..self.node_array[a][i].links).into_iter().for_each(|_| {
564 self.node_array[a][i].link_weights.push(rng.gen_range(-0.5..0.5));
565 self.node_array[a][i].link_vals.push(None);
566 })
567 });
568 self.answer = Some(a + 1);
569 }
570}
571