1use crate::{
2 checkpoint::Checkpoint, layers::Layer, matrix::Matrix, model::Model, optimizer::Optimizer,
3};
4
5const DEFAULT_LAMBDA: f64 = 0.001;
6const DEFAULT_OPTIMIZER: Optimizer = Optimizer::SGD {
7 learning_step: 0.01,
8};
9const DEFAULT_PRINT_FREQUENCY: usize = 100;
10const DEFAULT_SILENT_MODE: bool = false;
11const DEFAULT_DEBUG: bool = false;
12
13#[derive(Clone)]
14pub struct ModelBuilder {
15 layers: Vec<Layer>,
16 user_defined_lambda: Option<f64>,
17 user_defined_optimizer: Option<Optimizer>,
18 checkpoint: Option<Checkpoint>,
19 user_defined_print_frequency: Option<usize>,
20 user_defined_debug: Option<bool>,
21 user_defined_silent_mode: Option<bool>,
22}
23
24impl ModelBuilder {
25 pub fn new() -> ModelBuilder {
26 ModelBuilder {
27 layers: vec![],
28 user_defined_debug: None,
29 user_defined_silent_mode: None,
30 user_defined_print_frequency: None,
31 user_defined_optimizer: None,
32 user_defined_lambda: None,
33 checkpoint: None,
34 }
35 }
36
37 pub fn add_layer(mut self, layer: Layer) -> ModelBuilder {
38 self.layers.push(layer);
39 self
40 }
41
42 pub fn optimizer(mut self, optimizer: Optimizer) -> ModelBuilder {
43 self.user_defined_optimizer = Some(optimizer);
44 self
45 }
46
47 pub fn l2_reg(mut self, lambda: f64) -> ModelBuilder {
48 self.user_defined_lambda = Some(lambda);
49 self
50 }
51
52 pub fn checkpoint(mut self, checkpoint: Checkpoint) -> ModelBuilder {
53 self.checkpoint = Some(checkpoint);
54 self
55 }
56
57 pub fn verbose(mut self, print_frequency: usize, silent_mode: bool) -> ModelBuilder {
58 self.user_defined_print_frequency = Some(print_frequency);
59 self.user_defined_silent_mode = Some(silent_mode);
60 self
61 }
62
63 pub fn debug(mut self, debug: bool) -> ModelBuilder {
64 self.user_defined_debug = Some(debug);
65 self
66 }
67
68 pub fn build(self) -> Model {
69 assert_ne!(
70 self.layers.len(),
71 0,
72 "Error : No layers have been added to the model"
73 );
74
75 let optimizer: Optimizer = match &self.user_defined_optimizer {
76 Some(optimizer) => optimizer.clone(),
77 None => DEFAULT_OPTIMIZER,
78 };
79
80 let lambda: f64 = match self.user_defined_lambda {
81 Some(lambda) => lambda,
82 None => DEFAULT_LAMBDA,
83 };
84
85 Model::init(self.layers.clone(), optimizer, lambda)
86 }
87
88 pub fn build_and_train(
89 self,
90 data: &Matrix,
91 labels: &Matrix,
92 batch_size: u32,
93 epochs: u32,
94 validation_dataset_size: usize,
95 ) {
96 let print_frequency: usize = match &self.user_defined_print_frequency {
97 Some(v) => *v,
98 None => DEFAULT_PRINT_FREQUENCY,
99 };
100
101 let silent_mode: bool = match self.user_defined_silent_mode {
102 Some(v) => v,
103 None => DEFAULT_SILENT_MODE,
104 };
105
106 let debug: bool = match self.user_defined_debug {
107 Some(v) => v,
108 None => DEFAULT_DEBUG,
109 };
110
111 let checkpoint = self.checkpoint.clone();
112
113 let mut model: Model = self.build();
114 model.train(
115 data,
116 labels,
117 batch_size,
118 epochs,
119 validation_dataset_size,
120 checkpoint,
121 print_frequency,
122 debug,
123 silent_mode,
124 );
125 }
126}