concision_neural/traits/
models.rs

1/*
2    appellation: models <module>
3    authors: @FL03
4*/
5use crate::config::NetworkConfig;
6use crate::{DeepModelParams, ModelLayout};
7use crate::{Predict, Train};
8use concision_core::params::Params;
9use concision_data::DatasetBase;
10
11/// The [`Model`] trait defines the core interface for all models; implementors will need to
12/// provide the type of configuration used by the model, the type of layout used by the model,
13/// and the type of parameters used by the model. The crate provides standard, or default,
14/// definitions of both the configuration and layout types, however, for
15pub trait Model<T = f32> {
16    /// The type of configuration used for the model
17    type Config: NetworkConfig<T>;
18    /// The type of [`ModelLayout`] used by the model for this implementation.
19    type Layout: ModelLayout;
20    /// returns an immutable reference to the models configuration; this is typically used to
21    /// access the models hyperparameters (i.e. learning rate, momentum, etc.) and other
22    /// related control parameters.
23    fn config(&self) -> &Self::Config;
24    /// returns a mutable reference to the models configuration; useful for setting hyperparams
25    fn config_mut(&mut self) -> &mut Self::Config;
26    /// returns a copy of the model's current layout (features); a type providing the model
27    /// with a particular number of features for the various layers of a deep neural network.
28    ///
29    /// the layout is used in everything from creation and initialization routines to
30    /// validating the dimensionality of the model's inputs, outputs, training data, etc.
31    fn layout(&self) -> Self::Layout;
32    /// returns an immutable reference to the model parameters
33    fn params(&self) -> &DeepModelParams<T>;
34    /// returns a mutable reference to the model's parameters
35    fn params_mut(&mut self) -> &mut DeepModelParams<T>;
36    /// propagates the input through the model; each layer is applied in sequence meaning that
37    /// the output of each previous layer is the input to the next layer. This pattern
38    /// repeats until the output layer returns the final result.
39    ///
40    /// By default, the trait simply passes each output from one layer to the next, however,
41    /// custom models will likely override this method to inject activation methods and other
42    /// related logic
43    fn predict<U, V>(&self, inputs: &U) -> crate::ModelResult<V>
44    where
45        Self: Predict<U, Output = V>,
46    {
47        Predict::predict(self, inputs)
48    }
49    /// a convience method that trains the model using the provided dataset; this method
50    /// requires that the model implements the [`Train`] trait and that the dataset
51    fn train<U, V, W>(&mut self, dataset: &DatasetBase<U, V>) -> crate::ModelResult<W>
52    where
53        Self: Train<U, V, Output = W>,
54    {
55        Train::train(self, dataset.records(), dataset.targets())
56    }
57}
58
59pub trait ModelExt<T>: Model<T> {
60    /// [`replace`](core::mem::replace) the current configuration and returns the old one;
61    fn replace_config(&mut self, config: Self::Config) -> Self::Config {
62        core::mem::replace(self.config_mut(), config)
63    }
64    /// [`replace`](core::mem::replace) the current model parameters and returns the old one
65    fn replace_params(&mut self, params: DeepModelParams<T>) -> DeepModelParams<T> {
66        core::mem::replace(self.params_mut(), params)
67    }
68    /// overrides the current configuration and returns a mutable reference to the model
69    fn set_config(&mut self, config: Self::Config) -> &mut Self {
70        *self.config_mut() = config;
71        self
72    }
73    /// overrides the current model parameters and returns a mutable reference to the model
74    fn set_params(&mut self, params: DeepModelParams<T>) -> &mut Self {
75        *self.params_mut() = params;
76        self
77    }
78    /// returns an immutable reference to the input layer;
79    #[inline]
80    fn input_layer(&self) -> &Params<T> {
81        self.params().input()
82    }
83    /// returns a mutable reference to the input layer;
84    #[inline]
85    fn input_layer_mut(&mut self) -> &mut Params<T> {
86        self.params_mut().input_mut()
87    }
88    /// returns an immutable reference to the hidden layer(s);
89    #[inline]
90    fn hidden_layers(&self) -> &Vec<Params<T>> {
91        self.params().hidden()
92    }
93    /// returns a mutable reference to the hidden layer(s);
94    #[inline]
95    fn hidden_layers_mut(&mut self) -> &mut Vec<Params<T>> {
96        self.params_mut().hidden_mut()
97    }
98    /// returns an immutable reference to the output layer;
99    #[inline]
100    fn output_layer(&self) -> &Params<T> {
101        self.params().output()
102    }
103    /// returns a mutable reference to the output layer;
104    #[inline]
105    fn output_layer_mut(&mut self) -> &mut Params<T> {
106        self.params_mut().output_mut()
107    }
108    #[inline]
109    fn set_input_layer(&mut self, layer: Params<T>) -> &mut Self {
110        self.params_mut().set_input(layer);
111        self
112    }
113    #[inline]
114    fn set_hidden_layers(&mut self, layers: Vec<Params<T>>) -> &mut Self {
115        self.params_mut().set_hidden(layers);
116        self
117    }
118    #[inline]
119    fn set_output_layer(&mut self, layer: Params<T>) -> &mut Self {
120        self.params_mut().set_output(layer);
121        self
122    }
123    /// returns a 2-tuple representing the dimensions of the input layer; (input, hidden)
124    fn input_dim(&self) -> (usize, usize) {
125        self.layout().dim_input()
126    }
127    /// returns a 2-tuple representing the dimensions of the hidden layers; (hidden, hidden)
128    fn hidden_dim(&self) -> (usize, usize) {
129        self.layout().dim_hidden()
130    }
131    /// returns the total number of hidden layers in the model;
132    fn hidden_layers_count(&self) -> usize {
133        self.layout().layers()
134    }
135    /// returns a 2-tuple representing the dimensions of the output layer; (hidden, output)
136    fn output_dim(&self) -> (usize, usize) {
137        self.layout().dim_output()
138    }
139}
140
141/// The [`DeepNeuralNetwork`] trait is a specialization of the [`Model`] trait that
142/// provides additional functionality for deep neural networks. This trait is
143pub trait DeepNeuralNetwork<T = f32>: Model<T> {}
144
145impl<M, T> ModelExt<T> for M
146where
147    M: Model<T>,
148    M::Layout: ModelLayout,
149{
150}