nevermind_neu/
orchestra.rs

1use std::collections::HashMap;
2use std::vec::Vec;
3
4use ndarray_stats::QuantileExt;
5use serde::Serialize;
6use serde_yaml;
7
8use log::{debug, error, info, warn};
9
10use std::cell::RefCell;
11use std::sync::Arc;
12
13use std::fs::File;
14use std::io::{ErrorKind, Write};
15use std::time::Instant;
16
17use ndarray::{Axis, Zip};
18
19use std::fs::OpenOptions;
20
21use crossbeam::channel;
22use std::thread;
23
24use crate::dataloader::*;
25use crate::err::CustomError;
26use crate::util::*;
27use crate::cpu_params::*;
28
29use crate::models::Model;
30
/// Action requested by a user callback after a training iteration.
pub enum CallbackReturnAction {
    /// Continue training normally.
    None,
    /// Stop the training loop without saving an extra snapshot.
    Stop,
    /// Stop the training loop and save the model state to disk.
    StopAndSave,
}
36
/// Messages exchanged between the training loop and the dataloader thread.
enum DataloaderMsg {
    /// A prepared minibatch, sent from the dataloader thread.
    Batch(MiniBatch),
    /// Ask the dataloader thread to prepare the next batch.
    DoNext,
    /// Current position inside the dataset, sent from the dataloader thread.
    Pos(usize),
    /// Ask the dataloader thread to terminate.
    Stop,
}
43
/// Neural-Network learning orchestrator
pub struct Orchestra<T>
where
    T: Model + Serialize + Clone,
{
    /// Dataloader providing training minibatches.
    pub train_dl: Option<Box<dyn DataLoader + Send>>,
    /// Dataloader providing validation/test minibatches.
    pub test_dl: Option<Box<dyn DataLoader + Send>>,
    /// Running sum of per-iteration training errors; reset once per epoch.
    test_err_accum: f64,
    /// Number of samples fed to the test model per evaluation.
    test_batch_size: usize,
    /// When true, test errors are appended to "err.log" during training.
    is_write_test_err: bool,
    /// Model being trained (absent in eval-only orchestrators).
    train_model: Option<T>,
    /// Clone of the train model used for evaluation; has its own output
    /// buffers while sharing the weights buffer with the train model.
    test_model: Option<T>,
    /// Save a model snapshot every `snap_iter` iterations (0 = disabled).
    snap_iter: usize,
    /// Run validation every `test_iter` iterations (0 = disabled).
    test_iter: usize,
    /// Loss of the most recent training iteration.
    cur_iter_err: f32,
    /// Accuracy of the most recent training iteration.
    cur_iter_acc: f64,
    /// Multiplier applied to the learning rate on decay (1.0 = no decay).
    learn_rate_decay: f32,
    /// Apply learning-rate decay every `decay_step` iterations (0 = never).
    decay_step: usize,
    /// Log validation accuracy after each test run.
    show_accuracy: bool,
    /// Save the final model state when training finishes normally.
    save_on_finish: bool,
    /// Base name used when composing snapshot file names.
    pub name: String,
    // callback fn args : (iteration_number, current iteration loss, accuracy)
    callbacks: Vec<Box<dyn FnMut(usize, f32, f64) -> CallbackReturnAction>>,
}
68
69impl<T> Orchestra<T>
70where
71    T: Model + Serialize + Clone,
72{
73    pub fn new(model: T) -> Self {
74        let mut test_model = model.clone();
75        // Creates separate output buffer for test network,
76        // weights buffer is the same as in train model
77        test_model.set_batch_size_for_tests(model.batch_size());
78
79        Orchestra {
80            train_dl: None,
81            test_dl: None,
82            test_err_accum: 0.0,
83            test_batch_size: 10,
84            is_write_test_err: true,
85            train_model: Some(model),
86            test_model: Some(test_model),
87            snap_iter: 0,
88            test_iter: 100,
89            cur_iter_acc: 0.0,
90            cur_iter_err: 0.0,
91            learn_rate_decay: 1.0,
92            decay_step: 0,
93            show_accuracy: true,
94            save_on_finish: true,
95            name: "network".to_owned(),
96            callbacks: Vec::new(),
97        }
98    }
99
100    pub fn new_for_eval(model: T) -> Self {
101        let tbs = model.batch_size();
102        let test_net = Orchestra {
103            train_dl: None,
104            test_dl: None,
105            test_err_accum: 0.0,
106            test_batch_size: 1,
107            is_write_test_err: true,
108            train_model: None,
109            test_model: Some(model),
110            snap_iter: 0,
111            test_iter: 100,
112            cur_iter_err: 0.0,
113            cur_iter_acc: 0.0,
114            learn_rate_decay: 1.0,
115            decay_step: 0,
116            show_accuracy: true,
117            save_on_finish: true,
118            name: "network".to_owned(),
119            callbacks: Vec::new(),
120        };
121        test_net.test_batch_size(tbs)
122    }
123
124    pub fn test_batch_size(mut self, batch_size: usize) -> Self {
125        if let Some(test_model) = self.test_model.as_mut() {
126            test_model.set_batch_size_for_tests(batch_size);
127        }
128
129        self.test_batch_size = batch_size;
130
131        self
132    }
133
    /// In-place setter for the test batch size.
    // NOTE(review): this calls `set_batch_size`, while the builder
    // `test_batch_size` calls `set_batch_size_for_tests` — confirm the
    // difference is intentional.
    pub fn set_test_batch_size(&mut self, batch_size: usize) {
        self.test_batch_size = batch_size;

        if let Some(test_model) = self.test_model.as_mut() {
            test_model.set_batch_size(self.test_batch_size);
        }
    }
141
142    pub fn test_dataloader(mut self, test_dl: Box<dyn DataLoader + Send>) -> Self {
143        self.test_dl = Some(test_dl);
144        let s = self.test_batch_size(1);
145        s
146    }
147
    /// Registers a callback invoked once per training iteration with
    /// `(iteration_number, current_iteration_loss, accuracy)`.
    pub fn add_callback(&mut self, c: Box<dyn FnMut(usize, f32, f64) -> CallbackReturnAction>) {
        self.callbacks.push(c);
    }
151
    /// Builder-style toggle for writing test errors to the error log file.
    pub fn write_err_to_file(mut self, state: bool) -> Self {
        self.is_write_test_err = state;
        self
    }
156
    /// In-place toggle for writing test errors to the error log file.
    pub fn set_write_err_to_file(&mut self, state: bool) {
        self.is_write_test_err = state;
    }
160
161    pub fn train_batch_size(&self) -> Option<usize> {
162        if let Some(train_model) = &self.train_model {
163            return Some(train_model.batch_size());
164        } else {
165            return None;
166        }
167    }
168
    /// Sets the train model's batch size; warns when no train model exists.
    pub fn set_train_batch_size(&mut self, batch_size: usize) {
        if let Some(train_model) = &mut self.train_model {
            train_model.set_batch_size(batch_size);
        } else {
            warn!("Attempting to set batch size to non-existing train model");
        }
    }
176
    /// Not implemented yet: should persist the network configuration
    /// (layers and optimizer cfg) to `_path`. Currently always panics.
    pub fn save_network_cfg(&mut self, _path: &str) -> std::io::Result<()> {
        todo!() // TODO : need to save net.cfg with layers_cfg and optimizer_cfg
    }
180
    /// Builder-style: save a model snapshot every `snap_each_iter`
    /// iterations (0 disables snapshots).
    pub fn snap_iter(mut self, snap_each_iter: usize) -> Self {
        self.snap_iter = snap_each_iter;
        self
    }
185
    /// In-place setter for the snapshot interval (0 disables snapshots).
    pub fn set_snap_iter(&mut self, snap_each_iter: usize) {
        self.snap_iter = snap_each_iter;
    }
189
    /// Builder-style: run validation every `test_iter` iterations.
    pub fn test_iter(mut self, test_iter: usize) -> Self {
        self.test_iter = test_iter;
        self
    }
194
    /// In-place setter for the validation interval.
    pub fn set_test_iter(&mut self, test_iter: usize) {
        self.test_iter = test_iter;
    }
198
199    pub fn save_model_state(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
200        if let Some(train_model) = self.train_model.as_ref() {
201            return train_model.save_state(path);
202        }
203        Err(Box::new(CustomError::Other))
204    }
205
    /// Attaches the training dataloader.
    pub fn set_train_dataset(&mut self, data: Box<dyn DataLoader + Send>) {
        self.train_dl = Some(data)
    }
209
    /// Attaches the validation/test dataloader.
    pub fn set_test_dataset(&mut self, data: Box<dyn DataLoader + Send>) {
        self.test_dl = Some(data)
    }
213
    /// Sets the learning-rate decay multiplier (1.0 = no decay).
    pub fn set_learn_rate_decay(&mut self, decay: f32) {
        self.learn_rate_decay = decay
    }
217
    /// Sets how often (in iterations) learning-rate decay is applied
    /// (0 = never).
    pub fn set_learn_rate_decay_step(&mut self, step: usize) {
        self.decay_step = step;
    }
221
222    fn infer_train_error(&mut self, divider_iter: f64) -> f64 {
223        let err = self.test_err_accum / divider_iter; // error average
224        self.test_err_accum = 0.0;
225        return err;
226    }
227
    /// Test net and returns and average error
    ///
    /// Pulls one batch of `test_batch_size` samples from the test
    /// dataloader, runs it through the test model, and returns the
    /// batch-averaged per-sample RMS error. Also logs classification
    /// accuracy (argmax agreement) when `show_accuracy` is set.
    fn test_net(&mut self) -> f64 {
        let test_dl = self.test_dl.as_mut().expect("Test dataset isn't set");
        let mut err = 0.0;

        let test_batch = test_dl.next_batch(self.test_batch_size);
        self.test_model
            .as_mut()
            .unwrap()
            .feedforward(test_batch.input);

        // Borrow the last layer's output buffer.
        let lr = self.test_model.as_ref().unwrap().output_params();
        let out = lr.get_2d_buf_t(TypeBuffer::Output);
        let out = out.borrow();

        let mut accuracy_cnt = 0.0;

        // Compare each predicted row against the corresponding expected row.
        Zip::from(out.rows())
            .and(test_batch.output.rows())
            .for_each(|out_r, exp_r| {
                let mut local_err = 0.0;

                // A sample counts as correct when predicted and expected
                // argmax indices agree.
                if out_r.argmax() == exp_r.argmax() {
                    accuracy_cnt += 1.0;
                }

                // Sum of squared differences over the row.
                for i in 0..out_r.shape()[0] {
                    local_err += (exp_r[i] - out_r[i]).powf(2.0);
                }

                // RMS error of this sample, accumulated over the batch.
                err += (local_err / out_r.shape()[0] as f32).sqrt();
            });

        accuracy_cnt = accuracy_cnt / self.test_batch_size as f32;

        if self.show_accuracy {
            info!("Validation accuracy : {}", accuracy_cnt);
        }

        // Average RMS error over the batch.
        err as f64 / self.test_batch_size as f64
    }
269
270    pub fn eval_one(
271        &mut self,
272        data: DataVec,
273    ) -> Result<Arc<RefCell<Array2D>>, Box<dyn std::error::Error>> {
274        if self.test_batch_size != 1 {
275            error!("Invalid batch size {} for test model", self.test_batch_size);
276            return Err(Box::new(CustomError::Other));
277        }
278
279        if let Some(test_model) = self.test_model.as_mut() {
280            let data_len = data.len();
281            let cvt = data.into_shape((1, data_len)).unwrap();
282            test_model.feedforward(cvt);
283
284            let last_lp = test_model.output_params();
285
286            return Ok(last_lp.get_2d_buf_t(TypeBuffer::Output));
287        }
288
289        return Err(Box::new(CustomError::Other));
290    }
291
292    pub fn eval(
293        &mut self,
294        train_data: Array2D,
295    ) -> Result<Arc<RefCell<Array2D>>, Box<dyn std::error::Error>> {
296        if let Some(test_model) = self.test_model.as_mut() {
297            test_model.feedforward(train_data);
298
299            let last_lp = test_model.output_params();
300
301            return Ok(last_lp.get_2d_buf_t(TypeBuffer::Output));
302        }
303
304        if let Some(train_model) = self.train_model.as_mut() {
305            train_model.feedforward(train_data);
306
307            let last_lp = train_model.output_params();
308
309            return Ok(last_lp.get_2d_buf_t(TypeBuffer::Output));
310        }
311
312        error!("Error evaluation !!!");
313
314        return Err(Box::new(CustomError::Other));
315    }
316
317    fn calc_avg_err(last_layer_lr: &CpuParams) -> f32 {
318        let err = last_layer_lr.get_2d_buf_t(TypeBuffer::NeuGrad);
319        let err = err.borrow();
320
321        let sq_sum = err.fold(0.0, |mut sq_sum, el| {
322            sq_sum += el.powf(2.0);
323            return sq_sum;
324        });
325
326        let test_err = (sq_sum / err.nrows() as f32).sqrt();
327        return test_err;
328    }
329
330    fn calc_accuracy(metrics: Option<&Metrics>) -> f64 {
331        if let Some(metrics) = metrics {
332            return metrics["accuracy"];
333        } else {
334            return 0.0; // none ?
335        }
336    }
337
    /// Controls whether the final model state is saved when training
    /// finishes normally.
    pub fn set_save_on_finish_flag(&mut self, state: bool) {
        self.save_on_finish = state;
    }
341
342    fn perform_learn_rate_decay(&mut self) {
343        let optim = self.train_model.as_mut().unwrap().optimizer_mut();
344        let optim_prm = optim.cfg();
345
346        if let Some(lr) = optim_prm.get("learning_rate") {
347            if let Variant::Float(lr) = lr {
348                let decayed_lr = lr * self.learn_rate_decay;
349
350                info!(
351                    "Perfoming learning rate decay : from {}, to {}",
352                    lr, decayed_lr
353                );
354
355                let mut m = HashMap::new();
356                m.insert("learning_rate".to_owned(), Variant::Float(decayed_lr));
357                optim.set_cfg(&m);
358            }
359        } else {
360            warn!("Coudln't perform learning rate decay due to cfg entry miss");
361        }
362    }
363
    /// One training step: forward pass, backpropagation, optimizer update,
    /// then records the iteration's loss and accuracy. No-op when there is
    /// no train model.
    fn perform_step(&mut self, mb: MiniBatch) {
        if let Some(train_model) = self.train_model.as_mut() {
            train_model.feedforward(mb.input);
            train_model.backpropagate(mb.output);

            train_model.optimize();

            // store current iteration loss and accuracy
            let lr = train_model.output_params();

            self.cur_iter_err = Self::calc_avg_err(&lr);
            self.cur_iter_acc = Self::calc_accuracy(train_model.last_layer_metrics());

            // Accumulated for the epoch-average error report.
            self.test_err_accum += self.cur_iter_err as f64;
        }
    }
380
    /// Trains for exactly `times` iterations (error threshold disabled).
    pub fn train_for_n_times(&mut self, times: usize) -> Result<(), Box<dyn std::error::Error>> {
        self.train_for_error_or_iter(0.0, times)
    }
384
385    fn create_empty_error_file(&self) -> Result<File, Box<dyn std::error::Error>> {
386        let file = OpenOptions::new()
387            .write(true)
388            .create(true)
389            .open("err.log")?;
390        Ok(file)
391    }
392
393    fn append_error(&self, f: &mut File, err: f64) -> Result<(), Box<dyn std::error::Error>> {
394        write!(f, "{:.6}\n", err)?;
395        Ok(())
396    }
397
    /// Replaces the test model with a fresh clone of the train model.
    // NOTE(review): the clone keeps the train model's configuration as-is —
    // `set_batch_size_for_tests(test_batch_size)` is not re-applied here;
    // confirm this is intentional.
    pub fn update_test_model(&mut self) {
        let test_mdl = self.train_model.as_ref().unwrap().clone();
        self.test_model = Some(test_mdl);
    }
402
    /// Trains until the average error drops below `err` (no iteration cap).
    pub fn train_for_error(&mut self, err: f64) -> Result<(), Box<dyn std::error::Error>> {
        self.train_for_error_or_iter(err, 0)
    }
406
407    pub fn train_epochs_or_error(
408        &mut self,
409        epochs: usize,
410        err: f64,
411    ) -> Result<(), Box<dyn std::error::Error>> {
412        let train_batch_size = self.train_model.as_ref().unwrap().batch_size();
413        let iter = epochs
414            * self
415                .train_dl
416                .as_ref()
417                .unwrap()
418                .len()
419                .expect("Train dataset has not length!")
420            / train_batch_size;
421        self.train_for_error_or_iter(err, iter)
422    }
423
    /// Trains till error becomes lower than err or
    /// train iteration more then max_iter.
    /// If err is 0, it will ignore the error threshold.
    /// If max_iter is 0, it will ignore max_iter argument.
    ///
    /// The training dataloader is moved into a dedicated thread that
    /// prepares batches asynchronously; it is moved back into `self` when
    /// the loop ends. Progress is logged in 10% steps of each epoch, and
    /// validation / snapshots / callbacks run on their configured
    /// intervals.
    pub fn train_for_error_or_iter(
        &mut self,
        err: f64,
        max_iter: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut iter_num = 0;

        let mut err_file = None;

        if self.is_write_test_err {
            err_file = Some(self.create_empty_error_file()?);
        }

        let mut bench_time = Instant::now();
        let mut test_err = 0.0;

        // Flags set by user callbacks to stop (and optionally snapshot).
        let mut flag_stop = false;
        let mut flag_save = false;

        let mut epoch_cnt = 1;
        let ds_len = self.train_dl.as_ref().unwrap().len().unwrap();
        let train_batch_size = self.train_model.as_ref().unwrap().batch_size();

        let ten_perc_metric = ds_len as f64 * 0.1; // for 10% , 20% done displaying
        let mut prev_pos = 0;
        let mut ten_perc_num = 0;

        let mut accuracy_sum = 0.0;

        // Bounded channels between this thread and the dataloader thread.
        let (tx_thr, rx_cur) = channel::bounded(2);
        let (tx_cur, rx_thr) = channel::bounded(2);

        let mut train_dl_to_thr = std::mem::replace(&mut self.train_dl, None); // we need to move dataloader to another thread for async batch preparing

        // Dataloader thread: repeatedly sends (position, batch) pairs until
        // it receives DataloaderMsg::Stop, then returns the dataloader so
        // it can be moved back into `self`.
        let thread_join = thread::spawn(move || {
            loop {
                let ds_pos = train_dl_to_thr.as_ref().unwrap().pos().unwrap();

                let batch = train_dl_to_thr
                    .as_mut()
                    .unwrap()
                    .next_batch(train_batch_size);

                tx_thr
                    .send(DataloaderMsg::Pos(ds_pos))
                    .expect("Failed to send dataset position from thread");
                tx_thr.send(DataloaderMsg::Batch(batch)).unwrap();

                let resp = rx_thr.recv().unwrap();

                match resp {
                    DataloaderMsg::Stop => break,
                    _ => continue,
                };
            }
            return train_dl_to_thr;
        });

        loop {
            // Error calc for each 10%
            {
                let ds_pos = match rx_cur.recv().unwrap() {
                    DataloaderMsg::Pos(p) => p,
                    _ => panic!("Invalid message"), // TODO : handle without panic
                };

                accuracy_sum += self.cur_iter_acc;

                // `prev_pos > ds_pos` detects a dataset wrap-around.
                if train_batch_size * 10 < ds_len // for small datasets do not display percentages
                    && (ds_pos >= (ten_perc_num + 1) as usize * ten_perc_metric as usize
                        || prev_pos > ds_pos)
                {
                    info!(
                        "Done {}% of {} epoch, error : {:.5}",
                        (ten_perc_num + 1) * 10,
                        epoch_cnt,
                        self.test_err_accum
                            / (ten_perc_metric * (ten_perc_num + 1) as f64
                                / train_batch_size as f64),
                    );

                    if accuracy_sum != 0.0 {
                        info!(
                            "Accuracy : {:.4}",
                            accuracy_sum
                                / (ten_perc_metric * (ten_perc_num + 1) as f64
                                    / train_batch_size as f64)
                        );
                    }

                    ten_perc_num += 1;

                    // A full epoch (10 x 10%) has been processed.
                    if ten_perc_num > 9 {
                        test_err = self.infer_train_error((ds_len / train_batch_size) as f64); // average error on train dataset

                        if test_err < err {
                            info!("Reached satisfying error value");
                            break;
                        }

                        let elapsed = bench_time.elapsed();

                        info!(
                            "Epoch {} for {:.3} seconds, error : {:.6}",
                            epoch_cnt,
                            elapsed.as_secs_f64(),
                            test_err,
                        );

                        // Reset the per-epoch counters.
                        epoch_cnt += 1;
                        ten_perc_num = 0;
                        prev_pos = 0;
                        accuracy_sum = 0.0;

                        bench_time = Instant::now();
                    } else {
                        prev_pos = ds_pos;
                    }
                } else {
                    test_err = self.cur_iter_err as f64;
                }
            }

            // Validation dataset or (testing dataset)
            if self.test_dl.is_some() {
                if iter_num % self.test_iter == 0 && iter_num != 0 {
                    info!("Testing net on {} iteration", iter_num);

                    let val_test_err = self.test_net();

                    info!("Validation error value : {:.5}", val_test_err);

                    if val_test_err < err {
                        info!("Reached satisfying error value on validation dataset!");
                        break;
                    }
                }

                if self.test_iter != 0 && iter_num % self.test_iter == 0 && iter_num != 0 {
                    if self.is_write_test_err {
                        self.append_error(err_file.as_mut().unwrap(), test_err)?;
                    }
                }
            }

            if max_iter != 0 && iter_num >= max_iter {
                info!("Reached max iteration");
                break;
            }

            // Receive the prepared batch and immediately ask the dataloader
            // thread for the next one (pipelined batch preparation).
            if let DataloaderMsg::Batch(minibatch) = rx_cur.recv().unwrap() {
                tx_cur.send(DataloaderMsg::DoNext).unwrap();
                self.perform_step(minibatch);
            } else {
                todo!("Handle");
            }

            if iter_num != 0 && self.decay_step != 0 && iter_num % self.decay_step == 0 {
                self.perform_learn_rate_decay();
            }

            // Periodic snapshot of the model state.
            if self.snap_iter != 0 && iter_num % self.snap_iter == 0 && iter_num != 0 {
                let filename = format!("{}_{}.state", self.name, iter_num);
                self.save_model_state(&filename)?;
            }

            // Give every registered callback a chance to stop the loop.
            for it_cb in self.callbacks.iter_mut() {
                let out = it_cb(iter_num, self.cur_iter_err, self.cur_iter_acc);

                match out {
                    CallbackReturnAction::None => (),
                    CallbackReturnAction::Stop => {
                        info!("Stopping training loop on {} iteration...", iter_num);
                        info!("Last test error : {}", test_err);
                        flag_stop = true;
                    }
                    CallbackReturnAction::StopAndSave => {
                        info!("Stopping training loop on {} iteration...", iter_num);
                        info!("Last test error : {}", test_err);
                        flag_save = true;
                        flag_stop = true;
                    }
                }
            }

            if flag_stop {
                break;
            }

            iter_num += 1;
        }

        if flag_save {
            let filename = format!("{}_{}_int.state", self.name, iter_num);
            info!("Saving net to file {}", filename);
            self.save_model_state(&filename)?;
        }

        // Shut down the dataloader thread and take the dataloader back.
        tx_cur.send(DataloaderMsg::Stop).unwrap();
        let train_dl = thread_join
            .join()
            .expect("Failed to join dataloader thread");
        self.train_dl = train_dl;

        if flag_stop {
            return Ok(());
        }

        info!("Training finished !");
        info!("Trained for error : {}", test_err);
        info!("Iterations : {}", iter_num);

        if self.save_on_finish {
            let filename = format!("{}_{}_final.state", self.name, iter_num);
            self.save_model_state(&filename)?;
        }

        Ok(())
    }
647
    /// Shared reference to the train model, if any.
    pub fn train_model(&self) -> Option<&T> {
        self.train_model.as_ref()
    }
651
    /// Shared reference to the test model, if any.
    pub fn test_model(&self) -> Option<&T> {
        self.test_model.as_ref()
    }
655
    /// Mutable reference to the train model, if any.
    pub fn train_model_mut(&mut self) -> Option<&mut T> {
        self.train_model.as_mut()
    }
659
    /// Mutable reference to the test model, if any.
    pub fn test_model_mut(&mut self) -> Option<&mut T> {
        self.test_model.as_mut()
    }
663}
664
665/// TODO : rename this method. it make a confusion with save_network_cfg ?
666pub fn save_model_cfg<S: Model + Serialize>(solver: &S, path: &str) -> std::io::Result<()> {
667    let yaml_str_result = serde_yaml::to_string(solver);
668
669    let mut output = File::create(path)?;
670
671    match yaml_str_result {
672        Ok(yaml_str) => {
673            output.write_all(yaml_str.as_bytes())?;
674        }
675        Err(x) => {
676            error!("Error (serde-yaml) serializing net layers !!!");
677            return Err(std::io::Error::new(ErrorKind::Other, x));
678        }
679    }
680
681    Ok(())
682}