1use std::collections::HashMap;
2use std::vec::Vec;
3
4use ndarray_stats::QuantileExt;
5use serde::Serialize;
6use serde_yaml;
7
8use log::{debug, error, info, warn};
9
10use std::cell::RefCell;
11use std::sync::Arc;
12
13use std::fs::File;
14use std::io::{ErrorKind, Write};
15use std::time::Instant;
16
17use ndarray::{Axis, Zip};
18
19use std::fs::OpenOptions;
20
21use crossbeam::channel;
22use std::thread;
23
24use crate::dataloader::*;
25use crate::err::CustomError;
26use crate::util::*;
27use crate::cpu_params::*;
28
29use crate::models::Model;
30
/// Action requested by a user callback after a training iteration.
pub enum CallbackReturnAction {
    /// Continue training normally.
    None,
    /// Stop the training loop without saving.
    Stop,
    /// Stop the training loop and save an intermediate model state.
    StopAndSave,
}
36
/// Messages exchanged between the training loop and the prefetching
/// dataloader thread (see `train_for_error_or_iter`).
enum DataloaderMsg {
    /// A prefetched minibatch (thread -> training loop).
    Batch(MiniBatch),
    /// Acknowledge the batch and request the next one (loop -> thread).
    DoNext,
    /// Current dataset position before the batch was taken (thread -> loop).
    Pos(usize),
    /// Ask the dataloader thread to shut down (loop -> thread).
    Stop,
}
43
/// Training/evaluation orchestrator owning a training copy and a test copy
/// of a model plus the dataloaders that feed them.
pub struct Orchestra<T>
where
    T: Model + Serialize + Clone,
{
    /// Dataloader for training; moved into the loader thread while training runs.
    pub train_dl: Option<Box<dyn DataLoader + Send>>,
    /// Dataloader used for periodic validation.
    pub test_dl: Option<Box<dyn DataLoader + Send>>,
    /// Running sum of per-iteration training error; reset once per epoch.
    test_err_accum: f64,
    /// Number of samples fed to the test model per validation pass.
    test_batch_size: usize,
    /// When true, epoch errors are appended to "err.log".
    is_write_test_err: bool,
    /// Model being optimized (None in eval-only mode).
    train_model: Option<T>,
    /// Model clone used for validation/inference.
    test_model: Option<T>,
    /// Save a snapshot every N iterations (0 disables snapshots).
    snap_iter: usize,
    /// Run validation every N iterations.
    test_iter: usize,
    /// Error of the most recent training step.
    cur_iter_err: f32,
    /// Accuracy of the most recent training step.
    cur_iter_acc: f64,
    /// Multiplier applied to the learning rate every `decay_step` iterations.
    learn_rate_decay: f32,
    /// Iteration period for learning-rate decay (0 disables decay).
    decay_step: usize,
    /// Log validation accuracy during testing.
    show_accuracy: bool,
    /// Save a final model state when training finishes normally.
    save_on_finish: bool,
    /// Base name used when composing snapshot filenames.
    pub name: String,
    /// Per-iteration callbacks invoked with (iteration, error, accuracy).
    callbacks: Vec<Box<dyn FnMut(usize, f32, f64) -> CallbackReturnAction>>,
}
68
69impl<T> Orchestra<T>
70where
71 T: Model + Serialize + Clone,
72{
73 pub fn new(model: T) -> Self {
74 let mut test_model = model.clone();
75 test_model.set_batch_size_for_tests(model.batch_size());
78
79 Orchestra {
80 train_dl: None,
81 test_dl: None,
82 test_err_accum: 0.0,
83 test_batch_size: 10,
84 is_write_test_err: true,
85 train_model: Some(model),
86 test_model: Some(test_model),
87 snap_iter: 0,
88 test_iter: 100,
89 cur_iter_acc: 0.0,
90 cur_iter_err: 0.0,
91 learn_rate_decay: 1.0,
92 decay_step: 0,
93 show_accuracy: true,
94 save_on_finish: true,
95 name: "network".to_owned(),
96 callbacks: Vec::new(),
97 }
98 }
99
100 pub fn new_for_eval(model: T) -> Self {
101 let tbs = model.batch_size();
102 let test_net = Orchestra {
103 train_dl: None,
104 test_dl: None,
105 test_err_accum: 0.0,
106 test_batch_size: 1,
107 is_write_test_err: true,
108 train_model: None,
109 test_model: Some(model),
110 snap_iter: 0,
111 test_iter: 100,
112 cur_iter_err: 0.0,
113 cur_iter_acc: 0.0,
114 learn_rate_decay: 1.0,
115 decay_step: 0,
116 show_accuracy: true,
117 save_on_finish: true,
118 name: "network".to_owned(),
119 callbacks: Vec::new(),
120 };
121 test_net.test_batch_size(tbs)
122 }
123
124 pub fn test_batch_size(mut self, batch_size: usize) -> Self {
125 if let Some(test_model) = self.test_model.as_mut() {
126 test_model.set_batch_size_for_tests(batch_size);
127 }
128
129 self.test_batch_size = batch_size;
130
131 self
132 }
133
    /// In-place setter for the validation batch size.
    ///
    /// NOTE(review): this calls `set_batch_size` on the test model, while the
    /// builder variant `test_batch_size` calls `set_batch_size_for_tests` —
    /// confirm the asymmetry is intentional.
    pub fn set_test_batch_size(&mut self, batch_size: usize) {
        self.test_batch_size = batch_size;

        if let Some(test_model) = self.test_model.as_mut() {
            test_model.set_batch_size(self.test_batch_size);
        }
    }
141
142 pub fn test_dataloader(mut self, test_dl: Box<dyn DataLoader + Send>) -> Self {
143 self.test_dl = Some(test_dl);
144 let s = self.test_batch_size(1);
145 s
146 }
147
    /// Registers a per-iteration callback `(iteration, error, accuracy) -> action`.
    pub fn add_callback(&mut self, c: Box<dyn FnMut(usize, f32, f64) -> CallbackReturnAction>) {
        self.callbacks.push(c);
    }
151
    /// Builder-style toggle for appending epoch errors to "err.log".
    pub fn write_err_to_file(mut self, state: bool) -> Self {
        self.is_write_test_err = state;
        self
    }
156
    /// In-place toggle for appending epoch errors to "err.log".
    pub fn set_write_err_to_file(&mut self, state: bool) {
        self.is_write_test_err = state;
    }
160
161 pub fn train_batch_size(&self) -> Option<usize> {
162 if let Some(train_model) = &self.train_model {
163 return Some(train_model.batch_size());
164 } else {
165 return None;
166 }
167 }
168
169 pub fn set_train_batch_size(&mut self, batch_size: usize) {
170 if let Some(train_model) = &mut self.train_model {
171 train_model.set_batch_size(batch_size);
172 } else {
173 warn!("Attempting to set batch size to non-existing train model");
174 }
175 }
176
    /// Saves the network configuration to `_path`.
    ///
    /// Not implemented yet — calling this panics via `todo!()`.
    pub fn save_network_cfg(&mut self, _path: &str) -> std::io::Result<()> {
        todo!()
    }
180
    /// Builder-style setter for the snapshot period (0 disables snapshots).
    pub fn snap_iter(mut self, snap_each_iter: usize) -> Self {
        self.snap_iter = snap_each_iter;
        self
    }
185
    /// In-place setter for the snapshot period (0 disables snapshots).
    pub fn set_snap_iter(&mut self, snap_each_iter: usize) {
        self.snap_iter = snap_each_iter;
    }
189
    /// Builder-style setter for the validation period (iterations between tests).
    pub fn test_iter(mut self, test_iter: usize) -> Self {
        self.test_iter = test_iter;
        self
    }
194
    /// In-place setter for the validation period (iterations between tests).
    pub fn set_test_iter(&mut self, test_iter: usize) {
        self.test_iter = test_iter;
    }
198
199 pub fn save_model_state(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
200 if let Some(train_model) = self.train_model.as_ref() {
201 return train_model.save_state(path);
202 }
203 Err(Box::new(CustomError::Other))
204 }
205
    /// Installs the dataloader used for training.
    pub fn set_train_dataset(&mut self, data: Box<dyn DataLoader + Send>) {
        self.train_dl = Some(data)
    }
209
    /// Installs the dataloader used for validation.
    pub fn set_test_dataset(&mut self, data: Box<dyn DataLoader + Send>) {
        self.test_dl = Some(data)
    }
213
    /// Sets the multiplier applied to the learning rate at each decay step.
    pub fn set_learn_rate_decay(&mut self, decay: f32) {
        self.learn_rate_decay = decay
    }
217
    /// Sets how many iterations pass between learning-rate decays (0 disables).
    pub fn set_learn_rate_decay_step(&mut self, step: usize) {
        self.decay_step = step;
    }
221
222 fn infer_train_error(&mut self, divider_iter: f64) -> f64 {
223 let err = self.test_err_accum / divider_iter; self.test_err_accum = 0.0;
225 return err;
226 }
227
    /// Runs one validation batch through the test model and returns the mean
    /// per-row RMS error over the batch; optionally logs accuracy.
    fn test_net(&mut self) -> f64 {
        let test_dl = self.test_dl.as_mut().expect("Test dataset isn't set");
        let mut err = 0.0;

        // Forward-pass a single validation batch through the test model.
        let test_batch = test_dl.next_batch(self.test_batch_size);
        self.test_model
            .as_mut()
            .unwrap()
            .feedforward(test_batch.input);

        let lr = self.test_model.as_ref().unwrap().output_params();
        let out = lr.get_2d_buf_t(TypeBuffer::Output);
        let out = out.borrow();

        let mut accuracy_cnt = 0.0;

        // Walk output rows and expected rows in lockstep.
        Zip::from(out.rows())
            .and(test_batch.output.rows())
            .for_each(|out_r, exp_r| {
                let mut local_err = 0.0;

                // Classification hit: predicted argmax equals expected argmax.
                if out_r.argmax() == exp_r.argmax() {
                    accuracy_cnt += 1.0;
                }

                // Sum of squared differences across the row ...
                for i in 0..out_r.shape()[0] {
                    local_err += (exp_r[i] - out_r[i]).powf(2.0);
                }

                // ... folded into a per-row RMS contribution.
                err += (local_err / out_r.shape()[0] as f32).sqrt();
            });

        // NOTE(review): divides by the configured batch size rather than the
        // actual number of rows returned — assumes the loader always fills
        // the batch; confirm for the final (possibly short) batch.
        accuracy_cnt = accuracy_cnt / self.test_batch_size as f32;

        if self.show_accuracy {
            info!("Validation accuracy : {}", accuracy_cnt);
        }

        err as f64 / self.test_batch_size as f64
    }
269
270 pub fn eval_one(
271 &mut self,
272 data: DataVec,
273 ) -> Result<Arc<RefCell<Array2D>>, Box<dyn std::error::Error>> {
274 if self.test_batch_size != 1 {
275 error!("Invalid batch size {} for test model", self.test_batch_size);
276 return Err(Box::new(CustomError::Other));
277 }
278
279 if let Some(test_model) = self.test_model.as_mut() {
280 let data_len = data.len();
281 let cvt = data.into_shape((1, data_len)).unwrap();
282 test_model.feedforward(cvt);
283
284 let last_lp = test_model.output_params();
285
286 return Ok(last_lp.get_2d_buf_t(TypeBuffer::Output));
287 }
288
289 return Err(Box::new(CustomError::Other));
290 }
291
292 pub fn eval(
293 &mut self,
294 train_data: Array2D,
295 ) -> Result<Arc<RefCell<Array2D>>, Box<dyn std::error::Error>> {
296 if let Some(test_model) = self.test_model.as_mut() {
297 test_model.feedforward(train_data);
298
299 let last_lp = test_model.output_params();
300
301 return Ok(last_lp.get_2d_buf_t(TypeBuffer::Output));
302 }
303
304 if let Some(train_model) = self.train_model.as_mut() {
305 train_model.feedforward(train_data);
306
307 let last_lp = train_model.output_params();
308
309 return Ok(last_lp.get_2d_buf_t(TypeBuffer::Output));
310 }
311
312 error!("Error evaluation !!!");
313
314 return Err(Box::new(CustomError::Other));
315 }
316
317 fn calc_avg_err(last_layer_lr: &CpuParams) -> f32 {
318 let err = last_layer_lr.get_2d_buf_t(TypeBuffer::NeuGrad);
319 let err = err.borrow();
320
321 let sq_sum = err.fold(0.0, |mut sq_sum, el| {
322 sq_sum += el.powf(2.0);
323 return sq_sum;
324 });
325
326 let test_err = (sq_sum / err.nrows() as f32).sqrt();
327 return test_err;
328 }
329
330 fn calc_accuracy(metrics: Option<&Metrics>) -> f64 {
331 if let Some(metrics) = metrics {
332 return metrics["accuracy"];
333 } else {
334 return 0.0; }
336 }
337
    /// Controls whether a final model state is saved when training completes.
    pub fn set_save_on_finish_flag(&mut self, state: bool) {
        self.save_on_finish = state;
    }
341
342 fn perform_learn_rate_decay(&mut self) {
343 let optim = self.train_model.as_mut().unwrap().optimizer_mut();
344 let optim_prm = optim.cfg();
345
346 if let Some(lr) = optim_prm.get("learning_rate") {
347 if let Variant::Float(lr) = lr {
348 let decayed_lr = lr * self.learn_rate_decay;
349
350 info!(
351 "Perfoming learning rate decay : from {}, to {}",
352 lr, decayed_lr
353 );
354
355 let mut m = HashMap::new();
356 m.insert("learning_rate".to_owned(), Variant::Float(decayed_lr));
357 optim.set_cfg(&m);
358 }
359 } else {
360 warn!("Coudln't perform learning rate decay due to cfg entry miss");
361 }
362 }
363
    /// Executes one training iteration on `mb`: forward pass, backprop,
    /// optimizer step, then error/accuracy bookkeeping. No-op in eval-only mode.
    fn perform_step(&mut self, mb: MiniBatch) {
        if let Some(train_model) = self.train_model.as_mut() {
            train_model.feedforward(mb.input);
            train_model.backpropagate(mb.output);

            train_model.optimize();

            let lr = train_model.output_params();

            // Record this iteration's error/accuracy for logging and callbacks.
            self.cur_iter_err = Self::calc_avg_err(&lr);
            self.cur_iter_acc = Self::calc_accuracy(train_model.last_layer_metrics());

            // Accumulated error is averaged per epoch by infer_train_error().
            self.test_err_accum += self.cur_iter_err as f64;
        }
    }
380
    /// Trains for exactly `times` iterations; the 0.0 error threshold means
    /// the iteration cap is the only stop condition (besides callbacks).
    pub fn train_for_n_times(&mut self, times: usize) -> Result<(), Box<dyn std::error::Error>> {
        self.train_for_error_or_iter(0.0, times)
    }
384
385 fn create_empty_error_file(&self) -> Result<File, Box<dyn std::error::Error>> {
386 let file = OpenOptions::new()
387 .write(true)
388 .create(true)
389 .open("err.log")?;
390 Ok(file)
391 }
392
393 fn append_error(&self, f: &mut File, err: f64) -> Result<(), Box<dyn std::error::Error>> {
394 write!(f, "{:.6}\n", err)?;
395 Ok(())
396 }
397
398 pub fn update_test_model(&mut self) {
399 let test_mdl = self.train_model.as_ref().unwrap().clone();
400 self.test_model = Some(test_mdl);
401 }
402
    /// Trains until the epoch error drops below `err`; no iteration cap.
    pub fn train_for_error(&mut self, err: f64) -> Result<(), Box<dyn std::error::Error>> {
        self.train_for_error_or_iter(err, 0)
    }
406
407 pub fn train_epochs_or_error(
408 &mut self,
409 epochs: usize,
410 err: f64,
411 ) -> Result<(), Box<dyn std::error::Error>> {
412 let train_batch_size = self.train_model.as_ref().unwrap().batch_size();
413 let iter = epochs
414 * self
415 .train_dl
416 .as_ref()
417 .unwrap()
418 .len()
419 .expect("Train dataset has not length!")
420 / train_batch_size;
421 self.train_for_error_or_iter(err, iter)
422 }
423
    /// Main training loop: runs until the epoch error drops below `err`, or
    /// `max_iter` iterations elapse (0 = no cap), or a callback requests a stop.
    ///
    /// Batches are prefetched on a dedicated thread; the training dataloader
    /// is moved into that thread for the duration of the loop and restored
    /// before returning. Message protocol per iteration: the thread sends
    /// `Pos` then `Batch`; the loop replies `DoNext` (or `Stop` at the end).
    pub fn train_for_error_or_iter(
        &mut self,
        err: f64,
        max_iter: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut iter_num = 0;

        let mut err_file = None;

        if self.is_write_test_err {
            err_file = Some(self.create_empty_error_file()?);
        }

        let mut bench_time = Instant::now();
        let mut test_err = 0.0;

        // Set by callbacks to break the loop / save an intermediate state.
        let mut flag_stop = false;
        let mut flag_save = false;

        let mut epoch_cnt = 1;
        let ds_len = self.train_dl.as_ref().unwrap().len().unwrap();
        let train_batch_size = self.train_model.as_ref().unwrap().batch_size();

        // Progress is logged in 10%-of-epoch increments.
        let ten_perc_metric = ds_len as f64 * 0.1;
        let mut prev_pos = 0;
        let mut ten_perc_num = 0;

        let mut accuracy_sum = 0.0;

        // Bounded channels (capacity 2) pair the loop with the loader thread:
        // tx_thr/rx_cur carries positions+batches, tx_cur/rx_thr carries acks.
        let (tx_thr, rx_cur) = channel::bounded(2);
        let (tx_cur, rx_thr) = channel::bounded(2);

        // Move the dataloader into the prefetch thread; it is returned on join.
        let mut train_dl_to_thr = std::mem::replace(&mut self.train_dl, None);
        let thread_join = thread::spawn(move || {
            loop {
                let ds_pos = train_dl_to_thr.as_ref().unwrap().pos().unwrap();

                let batch = train_dl_to_thr
                    .as_mut()
                    .unwrap()
                    .next_batch(train_batch_size);

                // Position first, then the batch — the loop relies on this order.
                tx_thr
                    .send(DataloaderMsg::Pos(ds_pos))
                    .expect("Failed to send dataset position from thread");
                tx_thr.send(DataloaderMsg::Batch(batch)).unwrap();

                let resp = rx_thr.recv().unwrap();

                match resp {
                    DataloaderMsg::Stop => break,
                    _ => continue,
                };
            }
            return train_dl_to_thr;
        });

        loop {
            {
                // First message of each iteration is the dataset position.
                let ds_pos = match rx_cur.recv().unwrap() {
                    DataloaderMsg::Pos(p) => p,
                    _ => panic!("Invalid message"),
                };

                accuracy_sum += self.cur_iter_acc;

                // Epoch-progress reporting only makes sense when an epoch is
                // at least ~10 batches; `prev_pos > ds_pos` detects wrap-around.
                if train_batch_size * 10 < ds_len
                    && (ds_pos >= (ten_perc_num + 1) as usize * ten_perc_metric as usize
                        || prev_pos > ds_pos)
                {
                    info!(
                        "Done {}% of {} epoch, error : {:.5}",
                        (ten_perc_num + 1) * 10,
                        epoch_cnt,
                        self.test_err_accum
                            / (ten_perc_metric * (ten_perc_num + 1) as f64
                                / train_batch_size as f64),
                    );

                    if accuracy_sum != 0.0 {
                        info!(
                            "Accuracy : {:.4}",
                            accuracy_sum
                                / (ten_perc_metric * (ten_perc_num + 1) as f64
                                    / train_batch_size as f64)
                        );
                    }

                    ten_perc_num += 1;

                    // Ten 10%-chunks done => full epoch: fold the accumulated
                    // error into test_err and reset the per-epoch counters.
                    if ten_perc_num > 9 {
                        test_err = self.infer_train_error((ds_len / train_batch_size) as f64);
                        if test_err < err {
                            info!("Reached satisfying error value");
                            break;
                        }

                        let elapsed = bench_time.elapsed();

                        info!(
                            "Epoch {} for {:.3} seconds, error : {:.6}",
                            epoch_cnt,
                            elapsed.as_secs_f64(),
                            test_err,
                        );

                        epoch_cnt += 1;
                        ten_perc_num = 0;
                        prev_pos = 0;
                        accuracy_sum = 0.0;

                        bench_time = Instant::now();
                    } else {
                        prev_pos = ds_pos;
                    }
                } else {
                    test_err = self.cur_iter_err as f64;
                }
            }

            // Periodic validation against the test dataloader.
            if self.test_dl.is_some() {
                if iter_num % self.test_iter == 0 && iter_num != 0 {
                    info!("Testing net on {} iteration", iter_num);

                    let val_test_err = self.test_net();

                    info!("Validation error value : {:.5}", val_test_err);

                    if val_test_err < err {
                        info!("Reached satisfying error value on validation dataset!");
                        break;
                    }
                }

                if self.test_iter != 0 && iter_num % self.test_iter == 0 && iter_num != 0 {
                    if self.is_write_test_err {
                        self.append_error(err_file.as_mut().unwrap(), test_err)?;
                    }
                }
            }

            if max_iter != 0 && iter_num >= max_iter {
                info!("Reached max iteration");
                break;
            }

            // Second message is the batch itself; ack it before stepping so
            // the loader thread can prefetch the next batch concurrently.
            if let DataloaderMsg::Batch(minibatch) = rx_cur.recv().unwrap() {
                tx_cur.send(DataloaderMsg::DoNext).unwrap();
                self.perform_step(minibatch);
            } else {
                todo!("Handle");
            }

            if iter_num != 0 && self.decay_step != 0 && iter_num % self.decay_step == 0 {
                self.perform_learn_rate_decay();
            }

            // Periodic snapshot of the model state.
            if self.snap_iter != 0 && iter_num % self.snap_iter == 0 && iter_num != 0 {
                let filename = format!("{}_{}.state", self.name, iter_num);
                self.save_model_state(&filename)?;
            }

            // User callbacks may request a stop (optionally with a save).
            for it_cb in self.callbacks.iter_mut() {
                let out = it_cb(iter_num, self.cur_iter_err, self.cur_iter_acc);

                match out {
                    CallbackReturnAction::None => (),
                    CallbackReturnAction::Stop => {
                        info!("Stopping training loop on {} iteration...", iter_num);
                        info!("Last test error : {}", test_err);
                        flag_stop = true;
                    }
                    CallbackReturnAction::StopAndSave => {
                        info!("Stopping training loop on {} iteration...", iter_num);
                        info!("Last test error : {}", test_err);
                        flag_save = true;
                        flag_stop = true;
                    }
                }
            }

            if flag_stop {
                break;
            }

            iter_num += 1;
        }

        if flag_save {
            let filename = format!("{}_{}_int.state", self.name, iter_num);
            info!("Saving net to file {}", filename);
            self.save_model_state(&filename)?;
        }

        // Shut the loader thread down and reclaim the dataloader.
        tx_cur.send(DataloaderMsg::Stop).unwrap();
        let train_dl = thread_join
            .join()
            .expect("Failed to join dataloader thread");
        self.train_dl = train_dl;

        // Callback-initiated stop skips the final save/logging below.
        if flag_stop {
            return Ok(());
        }

        info!("Training finished !");
        info!("Trained for error : {}", test_err);
        info!("Iterations : {}", iter_num);

        if self.save_on_finish {
            let filename = format!("{}_{}_final.state", self.name, iter_num);
            self.save_model_state(&filename)?;
        }

        Ok(())
    }
647
    /// Shared reference to the training model, if any.
    pub fn train_model(&self) -> Option<&T> {
        self.train_model.as_ref()
    }
651
    /// Shared reference to the test model, if any.
    pub fn test_model(&self) -> Option<&T> {
        self.test_model.as_ref()
    }
655
    /// Mutable reference to the training model, if any.
    pub fn train_model_mut(&mut self) -> Option<&mut T> {
        self.train_model.as_mut()
    }
659
    /// Mutable reference to the test model, if any.
    pub fn test_model_mut(&mut self) -> Option<&mut T> {
        self.test_model.as_mut()
    }
663}
664
665pub fn save_model_cfg<S: Model + Serialize>(solver: &S, path: &str) -> std::io::Result<()> {
667 let yaml_str_result = serde_yaml::to_string(solver);
668
669 let mut output = File::create(path)?;
670
671 match yaml_str_result {
672 Ok(yaml_str) => {
673 output.write_all(yaml_str.as_bytes())?;
674 }
675 Err(x) => {
676 error!("Error (serde-yaml) serializing net layers !!!");
677 return Err(std::io::Error::new(ErrorKind::Other, x));
678 }
679 }
680
681 Ok(())
682}