concision_neural/model/
mod.rs

1/*
2    Appellation: model <module>
3    Contrib: @FL03
4*/
5//! This module provides the scaffolding for creating models and layers in a neural network.
6
7#[doc(inline)]
8pub use self::{config::StandardModelConfig, store::ModelParams};
9
10pub mod config;
11pub mod store;
12
13pub(crate) mod prelude {
14    pub use super::store::*;
15}
16
17use crate::{ModelFeatures, NetworkConfig};
18use cnc::data::Dataset;
19
20/// This trait defines the base interface for all models, providing access to the models
21/// configuration, layout, and learned parameters.
22pub trait Model<T = f32> {
23    type Config: NetworkConfig<T>;
24    /// returns an immutable reference to the models configuration; this is typically used to
25    /// access the models hyperparameters (i.e. learning rate, momentum, etc.) and other
26    /// related control parameters.
27    fn config(&self) -> &Self::Config;
28    /// returns a mutable reference to the models configuration; useful for setting hyperparams
29    fn config_mut(&mut self) -> &mut Self::Config;
30    /// returns a copy of the models features (or layout); this is used to define the structure
31    /// of the model and its consituents.
32    fn features(&self) -> ModelFeatures;
33    /// returns an immutable reference to the model parameters
34    fn params(&self) -> &ModelParams<T>;
35    /// returns a mutable reference to the model's parameters
36    fn params_mut(&mut self) -> &mut ModelParams<T>;
37    /// propagates the input through the model; each layer is applied in sequence meaning that
38    /// the output of each previous layer is the input to the next layer. This pattern
39    /// repeats until the output layer returns the final result.
40    ///
41    /// By default, the trait simply passes each output from one layer to the next, however,
42    /// custom models will likely override this method to inject activation methods and other
43    /// related logic
44    fn predict<U, V>(&self, inputs: &U) -> cnc::Result<V>
45    where
46        Self: cnc::Forward<U, Output = V>,
47    {
48        <Self as cnc::Forward<U>>::forward(self, inputs)
49    }
50    #[deprecated(since = "0.1.17", note = "use predict instead")]
51    fn forward<U, V>(&self, inputs: &U) -> cnc::Result<V>
52    where
53        Self: cnc::Forward<U, Output = V>,
54    {
55        <Self as cnc::Forward<U>>::forward(self, inputs)
56    }
57    /// returns a model trainer prepared to train the model; this is a convenience method
58    /// that creates a new trainer instance and returns it. Trainers are lazily evaluated
59    /// meaning that the training process won't begin until the user calls the `begin` method.
60    fn train<U, V>(
61        &mut self,
62        dataset: Dataset<U, V>,
63    ) -> crate::train::trainer::Trainer<'_, Self, T, Dataset<U, V>>
64    where
65        Self: Sized,
66        T: Default,
67    {
68        crate::train::trainer::Trainer::new(self, dataset)
69    }
70}