Trait Model

pub trait Model<T = f32> {
    type Config: NetworkConfig<T>;

    // Required methods
    fn config(&self) -> &Self::Config;
    fn config_mut(&mut self) -> &mut Self::Config;
    fn features(&self) -> ModelFeatures;
    fn params(&self) -> &ModelParams<T>;
    fn params_mut(&mut self) -> &mut ModelParams<T>;

    // Provided methods
    fn predict<U, V>(&self, inputs: &U) -> Result<V>
       where Self: Forward<U, Output = V> { ... }
    fn forward<U, V>(&self, inputs: &U) -> Result<V>
       where Self: Forward<U, Output = V> { ... }
    fn train<U, V>(
        &mut self,
        dataset: Dataset<U, V>,
    ) -> Trainer<'_, Self, T, Dataset<U, V>>
       where Self: Sized,
             T: Default { ... }
}

This trait defines the base interface for all models, providing access to the model's configuration, its features (layout), and its learned parameters.

Required Associated Types

type Config: NetworkConfig<T>

The type of the model's configuration; it must implement NetworkConfig<T>.

Required Methods

fn config(&self) -> &Self::Config

Returns an immutable reference to the model's configuration. This is typically used to read the model's hyperparameters (e.g. learning rate, momentum) and other related control parameters.

fn config_mut(&mut self) -> &mut Self::Config

Returns a mutable reference to the model's configuration; useful for setting hyperparameters.

fn features(&self) -> ModelFeatures

Returns a copy of the model's features (or layout), which define the structure of the model and its constituents.
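
Since the features describe the model's layout, layer shapes can be derived from them. A minimal sketch, assuming a hypothetical ModelFeatures with input, hidden and output widths (the crate's actual fields may differ) and hypothetical weight_shapes and param_count helpers:

```rust
/// Hypothetical stand-in for ModelFeatures: a network with one hidden layer.
/// Deriving Copy lets an accessor like `features()` hand out a cheap copy.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ModelFeatures {
    pub input: usize,
    pub hidden: usize,
    pub output: usize,
}

impl ModelFeatures {
    /// (rows, cols) of each weight matrix: input -> hidden, then hidden -> output.
    pub fn weight_shapes(&self) -> [(usize, usize); 2] {
        [(self.input, self.hidden), (self.hidden, self.output)]
    }

    /// Learnable scalars: every weight plus one bias per output unit of each layer.
    pub fn param_count(&self) -> usize {
        self.weight_shapes()
            .iter()
            .map(|&(rows, cols)| rows * cols + cols)
            .sum()
    }
}
```

For a 3-4-2 layout this gives weight shapes (3, 4) and (4, 2) and 26 parameters (12 + 4 + 8 + 2).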

fn params(&self) -> &ModelParams<T>

Returns an immutable reference to the model's parameters.

fn params_mut(&mut self) -> &mut ModelParams<T>

Returns a mutable reference to the model's parameters.
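
The required methods are plain accessors, so an implementation usually just returns the corresponding fields. A minimal sketch, assuming hypothetical stand-ins for NetworkConfig, ModelFeatures and ModelParams (the crate's real types are richer) and a hypothetical SimpleModel with a SimpleConfig holding its hyperparameters:

```rust
/// Hypothetical stand-ins for the crate's NetworkConfig, ModelFeatures and ModelParams.
pub trait NetworkConfig<T> {
    fn learning_rate(&self) -> T;
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ModelFeatures {
    pub input: usize,
    pub hidden: usize,
    pub output: usize,
}

#[derive(Clone, Debug, Default)]
pub struct ModelParams<T> {
    pub weights: Vec<T>,
}

/// Only the required part of the trait, as listed above.
pub trait Model<T = f32> {
    type Config: NetworkConfig<T>;

    fn config(&self) -> &Self::Config;
    fn config_mut(&mut self) -> &mut Self::Config;
    fn features(&self) -> ModelFeatures;
    fn params(&self) -> &ModelParams<T>;
    fn params_mut(&mut self) -> &mut ModelParams<T>;
}

/// Hypothetical hyperparameter bundle.
#[derive(Clone, Debug)]
pub struct SimpleConfig {
    pub lr: f32,
    pub momentum: f32,
}

impl NetworkConfig<f32> for SimpleConfig {
    fn learning_rate(&self) -> f32 {
        self.lr
    }
}

/// Hypothetical model that stores its configuration, layout and parameters as fields.
pub struct SimpleModel {
    config: SimpleConfig,
    features: ModelFeatures,
    params: ModelParams<f32>,
}

impl SimpleModel {
    pub fn new(config: SimpleConfig, features: ModelFeatures) -> Self {
        // One zero-initialised weight per connection in an input -> hidden -> output layout.
        let n = features.input * features.hidden + features.hidden * features.output;
        Self { config, features, params: ModelParams { weights: vec![0.0; n] } }
    }
}

// Uses the default type parameter, i.e. Model<f32>.
impl Model for SimpleModel {
    type Config = SimpleConfig;

    fn config(&self) -> &SimpleConfig { &self.config }
    fn config_mut(&mut self) -> &mut SimpleConfig { &mut self.config }
    fn features(&self) -> ModelFeatures { self.features } // copied, since ModelFeatures: Copy
    fn params(&self) -> &ModelParams<f32> { &self.params }
    fn params_mut(&mut self) -> &mut ModelParams<f32> { &mut self.params }
}
```

With this shape, setting a hyperparameter reads as model.config_mut().lr = 0.1, and the new value is visible through config().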

Provided Methods

fn predict<U, V>(&self, inputs: &U) -> Result<V>
where Self: Forward<U, Output = V>,

Propagates the input through the model. Layers are applied in sequence, so the output of each layer becomes the input to the next; this repeats until the output layer returns the final result.

By default, the trait simply passes each layer's output on to the next layer. Custom models will likely override this method to inject activation functions and other related logic.
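
The sequential pattern described above can be sketched as a fold over the layers, with an activation injected between them as a custom model might do. This is a minimal sketch, not the crate's implementation: Forward here is a hypothetical stand-in for the crate's trait, Dense and Sequential are hypothetical types, error handling is omitted (the documented predict returns Result<V>), and ReLU is an arbitrary choice of activation.

```rust
/// Hypothetical stand-in for the crate's Forward trait.
pub trait Forward<Input> {
    type Output;
    fn forward(&self, input: &Input) -> Self::Output;
}

/// Hypothetical dense layer computing y = W x + b, with one row of W per output unit.
pub struct Dense {
    pub weights: Vec<Vec<f32>>,
    pub bias: Vec<f32>,
}

impl Forward<Vec<f32>> for Dense {
    type Output = Vec<f32>;

    fn forward(&self, x: &Vec<f32>) -> Vec<f32> {
        self.weights
            .iter()
            .zip(&self.bias)
            .map(|(row, b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + b)
            .collect()
    }
}

/// Hypothetical sequential model: each layer's output is the next layer's input.
pub struct Sequential {
    pub layers: Vec<Dense>,
}

impl Forward<Vec<f32>> for Sequential {
    type Output = Vec<f32>;

    fn forward(&self, x: &Vec<f32>) -> Vec<f32> {
        let last = self.layers.len().saturating_sub(1);
        self.layers.iter().enumerate().fold(x.clone(), |h, (i, layer)| {
            let z = layer.forward(&h);
            // Custom logic injected between layers: ReLU on hidden layers only,
            // so the output layer stays linear.
            if i < last {
                z.into_iter().map(|v| v.max(0.0)).collect()
            } else {
                z
            }
        })
    }
}
```

For example, a 2-2-1 model whose first layer has weights [[1.0, -1.0], [0.5, 0.5]] maps the input [1.0, 2.0] to hidden values [-1.0, 1.5], which the ReLU turns into [0.0, 1.5] before the output layer is applied.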

fn forward<U, V>(&self, inputs: &U) -> Result<V>
where Self: Forward<U, Output = V>,

Deprecated since 0.1.17: use predict instead.
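
A deprecated alias of this kind is usually declared with Rust's #[deprecated] attribute and delegates to its replacement. The sketch below assumes forward simply forwards to predict, which the documentation does not state; the Predict trait and Doubler type are hypothetical.

```rust
/// Hypothetical trait with a deprecated alias kept for backwards compatibility.
pub trait Predict {
    /// The primary entry point.
    fn predict(&self, input: f32) -> f32;

    /// Callers get a compile-time deprecation warning, but the call still works.
    #[deprecated(since = "0.1.17", note = "use predict instead")]
    fn forward(&self, input: f32) -> f32 {
        self.predict(input)
    }
}

/// Hypothetical model that doubles its input.
pub struct Doubler;

impl Predict for Doubler {
    fn predict(&self, input: f32) -> f32 {
        input * 2.0
    }
}
```

Code that must keep calling the old name can silence the warning locally with #[allow(deprecated)].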

fn train<U, V>(&mut self, dataset: Dataset<U, V>) -> Trainer<'_, Self, T, Dataset<U, V>>
where Self: Sized, T: Default,

Returns a trainer prepared to train the model. This is a convenience method that creates a new trainer instance and returns it. Trainers are lazily evaluated: training does not begin until the caller invokes the trainer's begin method.
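
The lazy-trainer pattern can be sketched as a builder that only borrows the model and holds the data, with all work deferred to begin. A minimal sketch, assuming hypothetical Dataset, Linear and Trainer types (the crate's Trainer has a different signature) and plain full-batch gradient descent on mean squared error as the training step:

```rust
/// Hypothetical dataset of paired inputs and targets.
pub struct Dataset<U, V> {
    pub inputs: Vec<U>,
    pub targets: Vec<V>,
}

/// Hypothetical one-parameter model y = w * x.
pub struct Linear {
    pub w: f32,
    pub learning_rate: f32,
}

/// Holds a mutable borrow of the model; constructing it performs no training.
pub struct Trainer<'a> {
    model: &'a mut Linear,
    dataset: Dataset<f32, f32>,
    epochs: usize,
}

impl Linear {
    /// Mirrors the shape of Model::train: borrow the model, return a trainer.
    pub fn train(&mut self, dataset: Dataset<f32, f32>) -> Trainer<'_> {
        Trainer { model: self, dataset, epochs: 1 }
    }
}

impl Trainer<'_> {
    /// Builder-style configuration; still no training happens here.
    pub fn epochs(mut self, epochs: usize) -> Self {
        self.epochs = epochs;
        self
    }

    /// Runs gradient descent and returns the final mean squared error.
    pub fn begin(self) -> f32 {
        let data = &self.dataset;
        let n = data.inputs.len() as f32;
        for _ in 0..self.epochs {
            // d/dw of mean((w*x - y)^2) is mean(2 * (w*x - y) * x).
            let grad = data
                .inputs
                .iter()
                .zip(&data.targets)
                .map(|(x, y)| 2.0 * (self.model.w * x - y) * x)
                .sum::<f32>()
                / n;
            self.model.w -= self.model.learning_rate * grad;
        }
        data.inputs
            .iter()
            .zip(&data.targets)
            .map(|(x, y)| (self.model.w * x - y).powi(2))
            .sum::<f32>()
            / n
    }
}
```

Because the trainer holds the only mutable borrow of the model, the model cannot be inspected until the trainer is consumed by begin or dropped; dropping it without calling begin leaves the model unchanged.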

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
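
In practice this means code that should accept any model is written with generic bounds (static dispatch) rather than Box<dyn Model>. Under Rust's rules, a generic method such as predict<U, V> without a Self: Sized bound is enough to rule out trait objects. A minimal sketch with a hypothetical Scale trait that has the same property, and a hypothetical apply function:

```rust
/// Hypothetical trait: like Model::predict<U, V>, its generic method has no
/// `where Self: Sized` bound, so `dyn Scale` cannot be formed.
pub trait Scale {
    fn factor(&self) -> f64;

    fn scale<U: Into<f64>>(&self, x: U) -> f64 {
        let x: f64 = x.into();
        x * self.factor()
    }
}

pub struct Half;
pub struct Triple;

impl Scale for Half {
    fn factor(&self) -> f64 {
        0.5
    }
}

impl Scale for Triple {
    fn factor(&self) -> f64 {
        3.0
    }
}

// `fn apply(model: &dyn Scale, ..)` would be rejected with error E0038;
// a generic bound is instead monomorphised for each concrete model type.
pub fn apply<M: Scale>(model: &M, xs: &[i32]) -> Vec<f64> {
    xs.iter().map(|&x| model.scale(x)).collect()
}
```

When a heterogeneous collection of models is genuinely needed, an enum over the concrete types is a common alternative to trait objects.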

Implementors