concision_core/nn/traits/model.rs
1/*
2 appellation: models <module>
3 authors: @FL03
4*/
5use crate::config::ModelConfiguration;
6use crate::{DeepModelParams, LayoutExt, RawModelLayout};
7use concision_params::Params;
8use concision_traits::Predict;
9
10/// The [`Model`] trait defines the core interface for all models; implementors will need to
11/// provide the type of configuration used by the model, the type of layout used by the model,
12/// and the type of parameters used by the model. The crate provides standard, or default,
13/// definitions of both the configuration and layout types, however, for
14pub trait Model<T = f32> {
15 /// The type of configuration used for the model
16 type Config: ModelConfiguration<T>;
17 /// The type of [`ModelLayout`] used by the model for this implementation.
18 type Layout: LayoutExt;
19 /// returns an immutable reference to the models configuration; this is typically used to
20 /// access the models hyperparameters (i.e. learning rate, momentum, etc.) and other
21 /// related control parameters.
22 fn config(&self) -> &Self::Config;
23 /// returns a mutable reference to the models configuration; useful for setting hyperparams
24 fn config_mut(&mut self) -> &mut Self::Config;
25 /// returns a copy of the model's current layout (features); a type providing the model
26 /// with a particular number of features for the various layers of a deep neural network.
27 ///
28 /// the layout is used in everything from creation and initialization routines to
29 /// validating the dimensionality of the model's inputs, outputs, training data, etc.
30 fn layout(&self) -> &Self::Layout;
31 /// returns an immutable reference to the model parameters
32 fn params(&self) -> &DeepModelParams<T>;
33 /// returns a mutable reference to the model's parameters
34 fn params_mut(&mut self) -> &mut DeepModelParams<T>;
35 /// propagates the input through the model; each layer is applied in sequence meaning that
36 /// the output of each previous layer is the input to the next layer. This pattern
37 /// repeats until the output layer returns the final result.
38 ///
39 /// By default, the trait simply passes each output from one layer to the next, however,
40 /// custom models will likely override this method to inject activation methods and other
41 /// related logic
42 fn predict<U, V>(&self, inputs: &U) -> V
43 where
44 Self: Predict<U, Output = V>,
45 {
46 Predict::predict(self, inputs)
47 }
48}
49
50pub trait ModelExt<T>: Model<T> {
51 /// [`replace`](core::mem::replace) the current configuration and returns the old one;
52 fn replace_config(&mut self, config: Self::Config) -> Self::Config {
53 core::mem::replace(self.config_mut(), config)
54 }
55 /// [`replace`](core::mem::replace) the current model parameters and returns the old one
56 fn replace_params(&mut self, params: DeepModelParams<T>) -> DeepModelParams<T> {
57 core::mem::replace(self.params_mut(), params)
58 }
59 /// overrides the current configuration and returns a mutable reference to the model
60 fn set_config(&mut self, config: Self::Config) -> &mut Self {
61 *self.config_mut() = config;
62 self
63 }
64 /// overrides the current model parameters and returns a mutable reference to the model
65 fn set_params(&mut self, params: DeepModelParams<T>) -> &mut Self {
66 *self.params_mut() = params;
67 self
68 }
69 /// returns an immutable reference to the input layer;
70 #[inline]
71 fn input_layer(&self) -> &Params<T> {
72 self.params().input()
73 }
74 /// returns a mutable reference to the input layer;
75 #[inline]
76 fn input_layer_mut(&mut self) -> &mut Params<T> {
77 self.params_mut().input_mut()
78 }
79 /// returns an immutable reference to the hidden layer(s);
80 #[inline]
81 fn hidden_layers(&self) -> &Vec<Params<T>> {
82 self.params().hidden()
83 }
84 /// returns a mutable reference to the hidden layer(s);
85 #[inline]
86 fn hidden_layers_mut(&mut self) -> &mut Vec<Params<T>> {
87 self.params_mut().hidden_mut()
88 }
89 /// returns an immutable reference to the output layer;
90 #[inline]
91 fn output_layer(&self) -> &Params<T> {
92 self.params().output()
93 }
94 /// returns a mutable reference to the output layer;
95 #[inline]
96 fn output_layer_mut(&mut self) -> &mut Params<T> {
97 self.params_mut().output_mut()
98 }
99 #[inline]
100 fn set_input_layer(&mut self, layer: Params<T>) -> &mut Self {
101 self.params_mut().set_input(layer);
102 self
103 }
104 #[inline]
105 fn set_hidden_layers(&mut self, layers: Vec<Params<T>>) -> &mut Self {
106 self.params_mut().set_hidden(layers);
107 self
108 }
109 #[inline]
110 fn set_output_layer(&mut self, layer: Params<T>) -> &mut Self {
111 self.params_mut().set_output(layer);
112 self
113 }
114 /// returns a 2-tuple representing the dimensions of the input layer; (input, hidden)
115 fn input_dim(&self) -> (usize, usize) {
116 self.layout().dim_input()
117 }
118 /// returns a 2-tuple representing the dimensions of the hidden layers; (hidden, hidden)
119 fn hidden_dim(&self) -> (usize, usize) {
120 self.layout().dim_hidden()
121 }
122 /// returns the total number of hidden layers in the model;
123 fn hidden_layers_count(&self) -> usize {
124 self.layout().depth()
125 }
126 /// returns a 2-tuple representing the dimensions of the output layer; (hidden, output)
127 fn output_dim(&self) -> (usize, usize) {
128 self.layout().dim_output()
129 }
130}
131
132impl<M, T> ModelExt<T> for M
133where
134 M: Model<T>,
135 M::Layout: LayoutExt,
136{
137}