// brique/model_builder.rs

use crate::{
    checkpoint::Checkpoint, layers::Layer, matrix::Matrix, model::Model, optimizer::Optimizer,
};

5const DEFAULT_LAMBDA: f64 = 0.001;
6const DEFAULT_OPTIMIZER: Optimizer = Optimizer::SGD {
7    learning_step: 0.01,
8};
9const DEFAULT_PRINT_FREQUENCY: usize = 100;
10const DEFAULT_SILENT_MODE: bool = false;
11const DEFAULT_DEBUG: bool = false;
12
13#[derive(Clone)]
14pub struct ModelBuilder {
15    layers: Vec<Layer>,
16    user_defined_lambda: Option<f64>,
17    user_defined_optimizer: Option<Optimizer>,
18    checkpoint: Option<Checkpoint>,
19    user_defined_print_frequency: Option<usize>,
20    user_defined_debug: Option<bool>,
21    user_defined_silent_mode: Option<bool>,
22}
23
24impl ModelBuilder {
25    pub fn new() -> ModelBuilder {
26        ModelBuilder {
27            layers: vec![],
28            user_defined_debug: None,
29            user_defined_silent_mode: None,
30            user_defined_print_frequency: None,
31            user_defined_optimizer: None,
32            user_defined_lambda: None,
33            checkpoint: None,
34        }
35    }
36
37    pub fn add_layer(mut self, layer: Layer) -> ModelBuilder {
38        self.layers.push(layer);
39        self
40    }
41
42    pub fn optimizer(mut self, optimizer: Optimizer) -> ModelBuilder {
43        self.user_defined_optimizer = Some(optimizer);
44        self
45    }
46
47    pub fn l2_reg(mut self, lambda: f64) -> ModelBuilder {
48        self.user_defined_lambda = Some(lambda);
49        self
50    }
51
52    pub fn checkpoint(mut self, checkpoint: Checkpoint) -> ModelBuilder {
53        self.checkpoint = Some(checkpoint);
54        self
55    }
56
57    pub fn verbose(mut self, print_frequency: usize, silent_mode: bool) -> ModelBuilder {
58        self.user_defined_print_frequency = Some(print_frequency);
59        self.user_defined_silent_mode = Some(silent_mode);
60        self
61    }
62
63    pub fn debug(mut self, debug: bool) -> ModelBuilder {
64        self.user_defined_debug = Some(debug);
65        self
66    }
67
68    pub fn build(self) -> Model {
69        assert_ne!(
70            self.layers.len(),
71            0,
72            "Error : No layers have been added to the model"
73        );
74
75        let optimizer: Optimizer = match &self.user_defined_optimizer {
76            Some(optimizer) => optimizer.clone(),
77            None => DEFAULT_OPTIMIZER,
78        };
79
80        let lambda: f64 = match self.user_defined_lambda {
81            Some(lambda) => lambda,
82            None => DEFAULT_LAMBDA,
83        };
84
85        Model::init(self.layers.clone(), optimizer, lambda)
86    }
87
88    pub fn build_and_train(
89        self,
90        data: &Matrix,
91        labels: &Matrix,
92        batch_size: u32,
93        epochs: u32,
94        validation_dataset_size: usize,
95    ) {
96        let print_frequency: usize = match &self.user_defined_print_frequency {
97            Some(v) => *v,
98            None => DEFAULT_PRINT_FREQUENCY,
99        };
100
101        let silent_mode: bool = match self.user_defined_silent_mode {
102            Some(v) => v,
103            None => DEFAULT_SILENT_MODE,
104        };
105
106        let debug: bool = match self.user_defined_debug {
107            Some(v) => v,
108            None => DEFAULT_DEBUG,
109        };
110
111        let checkpoint = self.checkpoint.clone();
112
113        let mut model: Model = self.build();
114        model.train(
115            data,
116            labels,
117            batch_size,
118            epochs,
119            validation_dataset_size,
120            checkpoint,
121            print_frequency,
122            debug,
123            silent_mode,
124        );
125    }
126}