Trait dfdx::nn::NumParams

pub trait NumParams<E: Dtype, D: Device<E>>: TensorCollection<E, D> {
    // Provided method
    fn num_trainable_params(&self) -> usize { ... }
}

Get the number of trainable parameters in a model.

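use dfdx::prelude::*;
let dev: Cpu = Default::default();
type Model = Linear<2, 5>;
let model = dev.build_module::<Model, f32>();
// 2 * 5 weight entries plus 5 bias entries.
assert_eq!(model.num_trainable_params(), 2 * 5 + 5);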
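As a rough check on the arithmetic in the example above: a single `Linear<I, O>` layer holds an `O x I` weight matrix plus an `O`-element bias, so it has `I * O + O` parameters, and a sequential stack sums over its layers. The sketch below is plain Rust with no dfdx dependency; `linear_params` and `mlp_params` are hypothetical helpers for illustration only, and it assumes every layer in the stack is a biased linear layer (parameter-free layers such as activations would contribute 0).

```rust
/// Hypothetical helper: parameter count of a fully connected layer mapping
/// `inputs` features to `outputs` features (weights plus biases).
fn linear_params(inputs: usize, outputs: usize) -> usize {
    inputs * outputs + outputs
}

/// Hypothetical helper: total parameters of a stack of linear layers given
/// their widths, e.g. [2, 5, 3] means a 2 -> 5 layer followed by a 5 -> 3 layer.
fn mlp_params(widths: &[usize]) -> usize {
    widths
        .windows(2)
        .map(|pair| linear_params(pair[0], pair[1]))
        .sum()
}

fn main() {
    // Same count as the Linear<2, 5> example: 2 * 5 weights + 5 biases.
    assert_eq!(linear_params(2, 5), 15);
    // Two layers: (2 * 5 + 5) + (5 * 3 + 3) = 15 + 18 = 33.
    println!("{}", mlp_params(&[2, 5, 3]));
}
```

Summing per-layer counts works because each layer owns its own tensors; nothing is shared between layers in this sketch.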

Provided Methods

fn num_trainable_params(&self) -> usize

Returns the total number of trainable parameters in the model, i.e. the number of scalar elements summed across all of its trainable parameter tensors.

Implementors

impl<E: Dtype, D: Device<E>, M: TensorCollection<E, D>> NumParams<E, D> for M
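Because of this blanket impl, every type that implements `TensorCollection<E, D>` gets `num_trainable_params` automatically, with no per-model code. The plain-Rust sketch below illustrates that pattern under simplified assumptions: `Collection`, `CountParams`, `Linear`, and `RunningMean` are hypothetical stand-ins, not dfdx types, and each model simply reports every tensor's element count together with a flag saying whether it is trainable. It is an illustration of the blanket-impl design, not dfdx's actual implementation.

```rust
/// Hypothetical stand-in for `TensorCollection`: a model reports each of its
/// tensors as (number of elements, is it trainable?).
trait Collection {
    fn visit_params(&self, f: &mut dyn FnMut(usize, bool));
}

/// Hypothetical analogue of `NumParams`, with a provided method.
trait CountParams: Collection {
    fn num_trainable_params(&self) -> usize {
        let mut total = 0;
        self.visit_params(&mut |numel: usize, trainable: bool| {
            if trainable {
                total += numel;
            }
        });
        total
    }
}

// Blanket impl: every Collection gets CountParams for free, mirroring
// `impl<E, D, M: TensorCollection<E, D>> NumParams<E, D> for M`.
impl<M: Collection> CountParams for M {}

/// Hypothetical linear layer: an outputs x inputs weight plus a bias.
struct Linear {
    inputs: usize,
    outputs: usize,
}

impl Collection for Linear {
    fn visit_params(&self, f: &mut dyn FnMut(usize, bool)) {
        f(self.outputs * self.inputs, true); // weight
        f(self.outputs, true); // bias
    }
}

/// Hypothetical layer holding only non-trainable state (e.g. running statistics).
struct RunningMean {
    n: usize,
}

impl Collection for RunningMean {
    fn visit_params(&self, f: &mut dyn FnMut(usize, bool)) {
        f(self.n, false); // visited, but not counted as trainable
    }
}

// Sequential models as tuples: visit each child in order.
impl<A: Collection, B: Collection> Collection for (A, B) {
    fn visit_params(&self, f: &mut dyn FnMut(usize, bool)) {
        self.0.visit_params(f);
        self.1.visit_params(f);
    }
}

fn main() {
    let model = (Linear { inputs: 2, outputs: 5 }, RunningMean { n: 5 });
    // Only the Linear weight and bias count: 2 * 5 + 5 = 15.
    println!("{}", model.num_trainable_params());
}
```

The design choice here is that models only describe *which* tensors they own; counting lives in one provided method, so adding a new layer type never requires touching the counting logic.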