// darjeeling/generation.rs

1use ascii_converter::decimals_to_string;
2use rand::{Rng, seq::SliceRandom, thread_rng}; 
3use serde::{Serialize, Deserialize};
4use std::{fs, path::Path};
5use crate::{
6    categorize::CatNetwork,
7    node::Node, 
8    activation::ActivationFunction, 
9    DEBUG, 
10    error::DarjeelingError,
11    input::Input, 
12    types::{Types, Types::Boolean},
13    dbg_println
14};
15use rayon::prelude::*;
16
/// The generation Neural Network struct
#[derive(Debug, Serialize, Deserialize)]
pub struct GenNetwork {
    /// Layers of nodes; index 0 is the sensor (input) layer, the last index is the answer (output) layer.
    node_array: Vec<Vec<Node>>,
    /// Index of the sensor (input) layer in `node_array`; `Some(0)` after construction.
    sensor: Option<usize>,
    /// Index of the answer (output) layer in `node_array`.
    answer: Option<usize>,
    /// Total trainable parameter count (one bias per node plus one weight per link); `None` until computed.
    parameters: Option<u128>,
    activation_function: ActivationFunction
}
26#[warn(clippy::unwrap_in_result)]
27impl GenNetwork {
28    
29    /// Constructor function for the neural network
30    /// Fills a Neural Network's node_array with empty nodes. 
31    /// Initializes random starting link and bias weights between -.5 and .5
32    /// 
33    /// ## Params
34    /// - Inputs: The number of sensors in the input layer
35    /// - Hidden: The number of hidden nodes in each layer
36    /// - Answer: The number of answer nodes, or possible categories
37    /// - Hidden Layers: The number of different hidden layers
38    /// 
39    /// ## Examples
40    /// ``` rust
41    /// use darjeeling::{
42    ///     activation::ActivationFunction,
43    ///     generation::GenNetwork
44    /// };
45    /// 
46    /// let inputs: i32 = 10;
47    /// let hidden: i32 = 40;
48    /// let answer: i32 = 2;
49    /// let hidden_layers: i32 = 1;
50    /// let mut net: GenNetwork = GenNetwork::new(inputs, hidden, answer, hidden_layers, ActivationFunction::Sigmoid);
51    /// ```
52    pub fn new(input_num: i32, hidden_num: i32, answer_num: i32, hidden_layers: i32, activation_function: ActivationFunction) -> GenNetwork {
53        let mut net: GenNetwork = GenNetwork { node_array: vec![], sensor: Some(0), answer: Some(hidden_layers as usize + 1), parameters: None, activation_function};
54        let mut rng = rand::thread_rng();
55        net.node_array.push(vec![]);    
56        for _i in 0..input_num {
57            net.node_array[net.sensor.unwrap()].push(Node::new(&vec![], None));
58        }
59
60        for i in 1..hidden_layers + 1 {
61            let mut hidden_vec:Vec<Node> = vec![];
62            let hidden_links = net.node_array[(i - 1) as usize].len();
63            dbg_println!("Hidden Links: {:?}", hidden_links);
64            for _j in 0..hidden_num{
65                hidden_vec.push(Node { link_weights: vec![], link_vals: vec![], links: hidden_links, err_sig: None, correct_answer: None, cached_output: None, category: None, b_weight: None });
66            }
67            net.node_array.push(hidden_vec);
68        }
69
70        net.node_array.push(vec![]);
71        let answer_links = net.node_array[hidden_layers as usize].len();
72        println!("Answer Links: {:?}", answer_links);
73        for _i in 0..answer_num {
74            net.node_array[net.answer.unwrap()].push(Node { link_weights: vec![], link_vals: vec![], links: answer_links, err_sig: None, correct_answer: None, cached_output: Some(0.0), category: None, b_weight: None });
75        }
76        
77        net.node_array
78            .iter_mut()
79            .for_each(|layer| {
80                layer
81                    .iter_mut()
82                    .for_each(|mut node| {
83                        node.b_weight = Some(rng.gen_range(-0.5..0.5));
84                        dbg_println!("Made it to pushing link weights");
85                        (0..node.links)
86                            .into_iter()
87                            .for_each(|_| {
88                                node.link_weights.push(rng.gen_range(-0.5..0.5));
89                                node.link_vals.push(None);
90                            })
91                    })
92            });
93        let mut params = 0;
94        (0..net.node_array.len())
95            .into_iter()
96            .for_each(|i| {
97                (0..net.node_array[i].len())
98                    .into_iter()
99                    .for_each(|j| {
100                        params += 1 + net.node_array[i][j].links as u128;
101                    })
102            });
103        net.parameters = Some(params);
104        net
105    }
106
107    /// Trains a neural model to generate new data formatted as inputs, based on the given data
108    /// 
109    /// ## Params
110    /// - Data: List of inputs to be trained on
111    /// - Learning Rate: The modifier that is applied to link weights as they're adjusted.
112    /// Try fiddling with this one, but -1.5 - 1.5 is recommended to start.
113    /// - Name: The model name
114    /// - Max Cycles: The maximum number of epochs the training will run for.
115    /// - Distinguising Learning Rate: The learning rate for the distinguishing model.
116    /// - Distinguishing Hidden Neurons: The number of hidden neurons in each layer of the distinguishing model.
117    /// - Distinguising Hidden Layers: The number of hidden layers in the distinguishing model.
118    /// - Distinguishing Activation: The activation function of the distinguishing model.
119    /// - Distinguishing Target Error Percent: The error percentange at which the distinguishing models will stop training.
120    /// 
121    /// ## Returns
122    /// The falable name of the model that this neural network trained
123    /// 
124    /// ## Err
125    /// ### WriteModelFailed
126    /// There was a problem when saving the model to a file
127    /// 
128    /// ### ModelNameAlreadyExists
129    /// The random model name chosen already exists
130    /// Change the name or retrain
131    /// 
132    /// ### RemoveModelFailed
133    /// Everytime a new distinguishing model is written to the project folder, the previous one has to be removed.
134    /// This removal failed,
135    /// 
136    /// ### DistinguishingModel 
137    /// The distinguishing model training failed.
138    /// 
139    /// ### UnknownError
140    /// Not sure what happened, but something failed
141    /// 
142    /// Make an issue on the [darjeeling](https://github.com/Ewie21/darjeeling) github page
143    /// Or contact me at elocolburn@comcast.net
144    /// 
145    /// ## TODO: Refactor to pass around the neural net, not the model name
146    /// 
147    /// ## Examples
148    /// ```ignore
149    /// use darjeeling::{
150    ///     generation::GenNetwork,
151    ///     activation::ActivationFunction,
152    ///     input::Input, 
153    ///     // This file may not be avaliable
154    ///     // Everything found here will be hyper-specific to your project.
155    ///     tests::{categories_str_format, file}
156    /// };
157    /// 
158    /// // A file with data
159    /// // To make sure the networked is properly trained, make sure it follows some sort of pattern
160    /// // This is just sample data, for accurate results, around 3800 datapoints is needed
161    /// // 1 2 3 4 5 6 7 8
162    /// // 3 2 5 4 7 6 1 8
163    /// // 0 2 5 4 3 6 1 8
164    /// // 7 2 3 4 9 6 1 8
165    /// // You also need to write the file input function
166    /// // Automatic file reading and formatting function coming soon
167    /// let mut data: Vec<Input> = file();
168    /// let mut net = GenNetwork::new(2, 2, 2, 1, ActivationFunction::Sigmoid);
169    /// let model_name: String = net.learn(&mut data, 0.5, "gen", 100, 0.5, 10, 1, ActivationFunction::Sigmoid, 99.0).unwrap();
170    /// let new_data: Vec<Input> = net.test(data).unwrap();
171    /// ```
172    pub fn learn( // Frankly this whole function is disgusting and needs to be burned; I concure from the future
173        &mut self, 
174        data: &mut Vec<Input>, 
175        learning_rate: f32, 
176        name: &str, max_cycles: i32, 
177        distinguising_learning_rate: f32, distinguising_hidden_neurons: i32,
178        distinguising_hidden_layers: i32, distinguising_activation: ActivationFunction,
179        distinguishing_target_err_percent: f32
180    ) -> Result<String, DarjeelingError> {
181        let mut epochs: f32 = 0.0;
182        let distinguishing_model: *mut CatNetwork = &mut CatNetwork::new(self.node_array[self.answer.unwrap()].len() as i32, distinguising_hidden_neurons, 2, distinguising_hidden_layers, distinguising_activation);
183        let mut outputs: Vec<Input> = vec![];
184        for _i in 0..max_cycles {
185            let mse: f32;
186            data.shuffle(&mut thread_rng());
187            for line in 0..data.len() {
188                dbg_println!("Training Checkpoint One Passed");
189                self.push_downstream(data, line as i32);
190                let mut output = vec![];
191                for i in 0..self.node_array[self.answer.unwrap()].len() {
192                    output.push(self.node_array[self.answer.unwrap()][i].output(&self.activation_function));
193                }
194                outputs.push(Input::new(output, Some(Boolean(false)))); // false indicates not real data
195                data[line].answer = Some(Boolean(true));
196                outputs.push(data[line].clone());
197            }
198            // Do we train a new one from scratch or do we continue training the old one
199            // We still need to figure out how to accurately deal with distinguishing error affecting the generative model
200            let mut new_model: CatNetwork = CatNetwork::new(self.node_array[self.answer.unwrap()].len() as i32, distinguising_hidden_neurons, 2, distinguising_hidden_layers, distinguising_activation);
201            match new_model.learn(
202                data,
203                vec![Boolean(true), Boolean(false)],
204                distinguising_learning_rate,
205                &("distinguishing".to_owned() + &name), distinguishing_target_err_percent, false) 
206                {
207                    Ok((_name, _err_percent, errmse)) => mse = errmse,
208                    Err(error) => return Err(DarjeelingError::DisinguishingModelError(error.to_string()))
209                };
210
211            unsafe { distinguishing_model.write(new_model) };
212
213            self.backpropogate(learning_rate, mse);
214            epochs += 1.0;
215            println!("Epoch: {:?}", epochs);
216        }
217        #[allow(unused_mut)]
218        let mut model_name: String;
219        match self.write_model(&name) {
220            Ok(m_name) => {
221                model_name = m_name;
222            },
223            Err(error) => return Err(error)
224        }
225        Ok(model_name)
226    }
227
228    pub fn test(&mut self, data: &mut Vec<Input>) -> Result<Vec<Input>, DarjeelingError> {
229        data.shuffle(&mut thread_rng());
230        let mut outputs: Vec<Input> = vec![];
231        for i in 0..data.len() {
232            self.push_downstream(data, i as i32);
233            let mut output = vec![];
234            for i in 0..self.node_array[self.answer.unwrap()].len() {
235                output.push(self.node_array[self.answer.unwrap()][i].output(&self.activation_function));
236            }
237            outputs.push(Input::new(output, None)); // false indicates not real data
238        }
239        Ok(outputs)
240    }
241
242    /// Passes in data to the sensors, pushs data 'downstream' through the network
243    fn push_downstream(&mut self, data: &mut Vec<Input>, line: i32) {
244
245        // Passes in data for input layer
246        for i in 0..self.node_array[self.sensor.unwrap()].len() {
247            let input  = data[line as usize].inputs[i];
248
249            self.node_array[self.sensor.unwrap()][i].cached_output = Some(input);
250        }
251
252        // Feed-forward values for hidden and output layers
253        for layer in 1..self.node_array.len() {
254
255            for node in 0..self.node_array[layer].len() {
256
257                for prev_node in 0..self.node_array[layer-1].len() {
258                    
259                    // self.node_array[layer][node].link_vals.push(self.node_array[layer-1][prev_node].cached_output.unwrap());
260                    self.node_array[layer][node].link_vals[prev_node] = Some(self.node_array[layer-1][prev_node].cached_output.unwrap());
261                    // I think this line needs to be un-commented
262                    self.node_array[layer][node].output(&self.activation_function);
263                    if DEBUG { if layer == self.answer.unwrap() { println!("Ran output on answer {:?}", self.node_array[layer][node].cached_output) } }
264                }
265                self.node_array[layer][node].output(&self.activation_function);
266            }
267        }
268    }
269
270    /// Finds the index and the brightest node in an array and returns it
271    fn largest_node(&self) -> usize {
272        let mut largest_node = 0;
273        (0..self.node_array[self.answer.unwrap()].len())
274            .into_iter()
275            .for_each(|node| {
276                if self.node_array[self.answer.unwrap()][node].cached_output > self.node_array[self.answer.unwrap()][largest_node].cached_output {
277                    largest_node = node;
278                }
279            });
280        largest_node
281    }
282    /// Goes back through the network adjusting the weights of the all the neurons based on their error signal
283    fn backpropogate(&mut self, learning_rate: f32, mse: f32) {
284        let hidden_layers = (self.node_array.len() - 2) as i32;
285        (self.node_array[self.answer.unwrap()])
286            .par_iter_mut()
287            .for_each(|answer_node| {
288                println!("Node: {:?}", answer_node);
289                answer_node.compute_answer_err_sig_gen(mse, &self.activation_function);
290                dbg_println!("Error: {:?}", answer_node.err_sig.unwrap());
291            });
292        self.adjust_hidden_weights(learning_rate, hidden_layers);
293        // Adjusts weights for answer neurons
294        (self.node_array[self.answer.unwrap()])
295            .par_iter_mut()
296            .for_each(|node| {
297                node.adjust_weights(learning_rate);
298            });
299    }
300
301    #[allow(non_snake_case)]
302    /// Adjusts the weights of all the hidden neurons in a network
303    fn adjust_hidden_weights(&mut self, learning_rate: f32, hidden_layers: i32) {
304        // HIDDEN represents the layer, while hidden represents the node of the layer
305        (1..(hidden_layers + 1) as usize)
306        .into_iter()
307        .for_each(|HIDDEN| {
308            (0..self.node_array[HIDDEN].len())
309            .into_iter()
310            .for_each(|hidden| {
311                self.node_array[HIDDEN][hidden].err_sig = Some(0.0);
312                (0..self.node_array[HIDDEN + 1 ].len())
313                .into_iter()
314                .for_each(|next_layer| {
315                    let next_weight = self.node_array[HIDDEN + 1][next_layer].link_weights[hidden];
316                    self.node_array[HIDDEN + 1][next_layer].err_sig = match self.node_array[HIDDEN + 1][next_layer].err_sig.is_none() {
317                        true => {
318                            Some(0.0)
319                        }, 
320                        false => {
321                            self.node_array[HIDDEN + 1][next_layer].err_sig
322                        }
323                    };
324                    // This changes based on the activation function
325                    self.node_array[HIDDEN][hidden].err_sig = Some(self.node_array[HIDDEN][hidden].err_sig.unwrap() + (self.node_array[HIDDEN + 1][next_layer].err_sig.unwrap() * next_weight));
326                    
327                    let hidden_result = self.node_array[HIDDEN][hidden].cached_output.unwrap();
328                    let multiplied_value = self.node_array[HIDDEN][hidden].err_sig.unwrap() * (hidden_result) * (1.0 - hidden_result);
329                    self.node_array[HIDDEN][hidden].err_sig = Some(multiplied_value);
330                    dbg_println!("next err sig {:?}\nnext weight {:?}\nnew hidden errsig multiply: {:?}\n\nLayer: {:?}\nNode: {:?}\n", self.node_array[HIDDEN + 1][next_layer].err_sig.unwrap(), next_weight, multiplied_value, HIDDEN, hidden);
331                    self.node_array[HIDDEN][hidden].adjust_weights(learning_rate);
332                });
333                    
334            });   
335        });
336    }
337
338    /// Not needed for now
339    /// Analyses the chosen answer node's result.
340    /// Also increments sum and count
341    /// Err if string requested and float exceeds u8 limit (fix by parsing the floats and slicing them)
342    fn self_analysis<'b>(&'b self, epochs: &mut Option<f32>, sum: &'b mut f32, count: &'b mut f32, data: &mut Vec<Input>, line: usize, expected_type: Types) -> Result<Vec<Types>, DarjeelingError> {
343        // println!("answer {}", self.answer.unwrap());
344        // println!("largest index {}", self.largest_node());
345        // println!("{:?}", self);
346        let brightest_node: &Node = &self.node_array[self.answer.unwrap()][self.largest_node()];
347        let brightness: f32 = brightest_node.cached_output.unwrap();
348
349        if !(epochs.is_none()) { // This lets us use the same function for testing and training 
350            if epochs.unwrap() % 10.0 == 0.0 && epochs.unwrap() != 0.0 {
351                println!("\n-------------------------\n");
352                println!("Epoch: {:?}", epochs);
353                println!("Category: {:?} \nBrightness: {:?}", brightest_node.category.as_ref().unwrap(), brightness);
354                if DEBUG {
355                    let dimest_node: &Node = &self.node_array[self.answer.unwrap()][self.node_array[self.answer.unwrap()].len()-1-self.largest_node()];
356                    println!("Chosen category: {:?} \nDimest Brightness: {:?}", dimest_node.category.as_ref().unwrap(), dimest_node.cached_output.unwrap());
357                }
358            }
359        }
360
361        dbg_println!("Category: {:?} \nBrightness: {:?}", brightest_node.category.as_ref().unwrap(), brightness);
362        if brightest_node.category.as_ref().unwrap().eq(&data[line].answer.as_ref().unwrap()) { 
363            dbg_println!("Correct Answer Chosen\nSum++");
364            *sum += 1.0;
365        }
366        *count += 1.0;
367        let mut ret: Vec<Types> = vec![];
368        match expected_type {
369            Types::Integer(_) => {
370                let _ = &self.node_array[self.answer.unwrap()]
371                .iter()
372                .for_each(|node| {
373                    let int = node.cached_output.unwrap() as i32;
374                    ret.push(Types::Integer(int));
375                });
376            }
377            Types::Boolean(_) => {
378                let _ = &self.node_array[self.answer.unwrap()]
379                .iter()
380                .for_each(|node| {
381                    let bool = node.cached_output.unwrap() > 0.0;
382                    ret.push(Types::Boolean(bool));
383                });
384            }
385            Types::Float(_) => {
386                let _ = &self.node_array[self.answer.unwrap()]
387                .iter()
388                .for_each(|node| {
389                    ret.push(Types::Float(node.cached_output.unwrap()));  
390                });
391            }
392            Types::String(_) => {
393                for node in &self.node_array[self.answer.unwrap()] {
394                    let inputs = vec![(node.cached_output.unwrap() as u8)];
395                    let buff = match decimals_to_string(&inputs) {
396                        Ok(val) => val,
397                        Err(err) => return Err(DarjeelingError::SelfAnalysisStringConversion(err))
398                    };
399                    ret.push(Types::String(buff));
400                }
401            }
402        };
403        Ok(ret)
404    }
405
406    /// Serializes a trained model as a .darj file so it can be used later
407    /// 
408    /// ## Returns
409    /// The name of the model
410    /// 
411    /// ## Error
412    /// ### WriteModelFailed:
413    /// Writing to the file failed
414    /// 
415    /// Wraps the models name
416    /// ### UnknownError:
417    /// Something else went wrong
418    /// 
419    /// Wraps error
420    pub fn write_model(&mut self, name: &str) -> Result<String, DarjeelingError> {
421        let mut rng = rand::thread_rng();
422        let file_num: u32 = rng.gen();
423        let model_name: String = format!("model_{}_{}.darj", name, file_num);
424
425        match Path::new(&model_name).try_exists() {
426
427            Ok(false) => {
428                let _file: fs::File = fs::File::create(&model_name).unwrap();
429                let mut serialized = "".to_string();
430                println!("write, length: {}", self.node_array.len());
431                for i in 0..self.node_array.len() {
432                    if i != 0 {
433                        let _ = serialized.push_str("lb\n");
434                    }
435                    for j in 0..self.node_array[i].len() {
436                        for k in 0..self.node_array[i][j].link_weights.len() {
437                            print!("{}", self.node_array[i][j].link_weights[k]);
438                            if k == self.node_array[i][j].link_weights.len() - 1 {
439                                let _ = serialized.push_str(format!("{}", self.node_array[i][j].link_weights[k]).as_str());
440                            } else {
441                                let _ = serialized.push_str(format!("{},", self.node_array[i][j].link_weights[k]).as_str());
442                            }                        
443                        }
444                        let _ = serialized.push_str(format!(";{}", self.node_array[i][j].b_weight.unwrap().to_string()).as_str()); 
445                        let _ = serialized.push_str("\n");
446                    }
447                }
448                serialized.push_str("lb\n");                    
449                serialized.push_str(format!("{}", self.activation_function).as_str());
450                // println!("Serialized: {:?}", serialized);
451                match fs::write(&model_name, serialized) {
452                    Ok(()) => {
453                        println!("Model {:?} Saved", file_num);
454                        Ok(model_name)
455                    },
456                    Err(_error) => {
457                        Err(DarjeelingError::WriteModelFailed(model_name))
458                    }
459                }
460            },
461            Ok(true) => {
462                return self.write_model(name);
463            },
464            Err(error) => Err(DarjeelingError::UnknownError(error.to_string()))
465        }
466    }
467
468    /// Reads a serizalized Neural Network
469    /// 
470    /// ## Params
471    /// - Model Name: The name(or more helpfully the path) of the model to be read
472    /// 
473    /// ## Returns
474    /// A neural network read from a serialized .darj file
475    /// 
476    /// ## Err
477    /// If the file cannnot be read, or if the file does not contain a valid serialized Neural Network
478    pub fn read_model(model_name: String) -> Result<GenNetwork, DarjeelingError> {
479
480        println!("Loading model");
481        
482        // Err if the file reading fails
483        let serialized_net: String = match fs::read_to_string(&model_name) {
484            
485            Ok(serizalized_net) => serizalized_net,
486            Err(error) => return Err(DarjeelingError::ReadModelFailed(model_name.clone() + ";" +  &error.to_string()))
487        };
488 
489        let mut node_array: Vec<Vec<Node>> = vec![];
490        let mut layer: Vec<Node> = vec![];
491        let mut activation: Option<ActivationFunction> = None;
492        for i in serialized_net.lines() {
493            match i {
494                "sigmoid" => activation = Some(ActivationFunction::Sigmoid),
495
496                "linear" => activation = Some(ActivationFunction::Linear),
497
498                "tanh" => activation = Some(ActivationFunction::Tanh),
499 
500                // "step" => activation = Some(ActivationFunction::Step),
501
502                _ => {
503                
504                    if i.trim() == "lb" {
505                        node_array.push(layer.clone());
506                        // println!("pushed layer {:?}", layer.clone());
507                        layer = vec![];
508                        continue;
509                    }
510                    #[allow(unused_mut)]
511                    let mut node: Option<Node>;
512                    if node_array.len() == 0 {
513                        let b_weight: Vec<&str> = i.split(";").collect();
514                        // println!("b_weight: {:?}", b_weight);
515                        node = Some(Node::new(&vec![], Some(b_weight[1].parse().unwrap())));
516                    } else {
517                        let node_data: Vec<&str> = i.trim().split(";").collect();
518                        let str_weight_array: Vec<&str> = node_data[0].split(",").collect();
519                        let mut weight_array: Vec<f32> = vec![];
520                        let b_weight: &str = node_data[1];
521                        // println!("node_data: {:?}", node_data);
522                        // println!("array {:?}", str_weight_array);
523                        for weight in 0..str_weight_array.len() {
524                            // println!("testing here {:?}", str_weight_array[weight]);
525                            let val: f32 = str_weight_array[weight].parse().unwrap();
526                            weight_array.push(val);
527                        }
528                        // print!("{}", b_weight);
529                        node = Some(Node::new(&weight_array, Some(b_weight.parse().unwrap())));
530                    }
531                    
532                    layer.push(node.expect("Both cases provide a Some value for node"));
533                    // println!("layer: {:?}", layer.clone())
534                }
535            }
536            
537        }
538        // println!("node array size {}", node_array.len());
539        let sensor: Option<usize> = Some(0);
540        let answer: Option<usize> = Some(node_array.len() - 1);
541        
542        let net = GenNetwork {
543            node_array,
544            sensor,
545            answer,
546            parameters: None,
547            activation_function: activation.unwrap()
548        };
549        // println!("node array {:?}", net.node_array);
550
551        Ok(net)
552    }
553
554    pub fn add_hidden_layer_with_size(&mut self, size: usize) {
555        let mut rng = rand::thread_rng();
556        let a = self.answer.expect("initialized network");
557        self.node_array.push(self.node_array[a].clone());
558        self.node_array[a] = vec![];
559        let links = self.node_array[a - 1].len();
560        (0..size).into_iter().for_each(|i| {
561            self.node_array[a].push(Node::new(&vec![], Some(rng.gen_range(-0.5..0.5))));
562            self.node_array[a][i].links = links;
563            (0..self.node_array[a][i].links).into_iter().for_each(|_| {
564                self.node_array[a][i].link_weights.push(rng.gen_range(-0.5..0.5));
565                self.node_array[a][i].link_vals.push(None);
566            })
567        });
568        self.answer = Some(a + 1);
569    }
570}
571