pub trait NumParams<E: Dtype, D: Device<E>>: TensorCollection<E, D> {
// Provided method
fn num_trainable_params(&self) -> usize { ... }
}
Get the number of trainable parameters in a model.
use dfdx::prelude::*;
let dev: Cpu = Default::default();
type Model = Linear<2, 5>;
let model = dev.build_module::<Model, f32>();
assert_eq!(model.num_trainable_params(), 2 * 5 + 5);
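The expected value above is plain arithmetic: a fully connected layer with `I` inputs and `O` outputs holds an `O x I` weight matrix plus an `O`-element bias, so `Linear<2, 5>` has 2 * 5 + 5 = 15 parameters. The sketch below reproduces that count by hand. `linear_params` and `stack_params` are hypothetical helpers written for illustration; they are not part of dfdx, and they assume every layer in the stack is a linear layer with a bias (parameter-free layers such as ReLU simply contribute nothing).

```rust
/// Hypothetical helper (not a dfdx API): parameter count of a fully connected
/// layer with `inp` inputs and `out` outputs, i.e. an `out x inp` weight
/// matrix plus an `out`-element bias vector.
fn linear_params(inp: usize, out: usize) -> usize {
    inp * out + out
}

/// Hypothetical helper (not a dfdx API): total parameter count of a stack of
/// linear layers given as (inputs, outputs) pairs. Layers without parameters,
/// such as ReLU, are simply left out of the list.
fn stack_params(layers: &[(usize, usize)]) -> usize {
    layers.iter().map(|&(i, o)| linear_params(i, o)).sum()
}
```

For example, `linear_params(2, 5)` gives 15, matching the assertion above, and a two-layer stack `[(2, 5), (5, 1)]` gives 15 + 6 = 21.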
Provided Methods
fn num_trainable_params(&self) -> usize
Returns the total number of trainable parameters in the model, that is, the number of scalar elements summed across all of its trainable tensors.
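Because `num_trainable_params` is a provided (default) method on a trait whose supertrait exposes the model's tensors, any type that lists its tensors gets the count for free. The following is a minimal, self-contained sketch of that pattern under stated assumptions; it does not reproduce dfdx's actual visitor machinery. `ParamTensor`, `TensorList`, `CountParams`, and `ToyLinear` are hypothetical stand-ins, and the per-tensor `trainable` flag is an assumption used to show that non-trainable tensors are excluded from the count.

```rust
/// Hypothetical stand-in for a parameter tensor: only its shape and whether
/// it is trainable are tracked here (assumption, not dfdx's Tensor type).
struct ParamTensor {
    shape: Vec<usize>,
    trainable: bool,
}

impl ParamTensor {
    /// Number of scalar elements: the product of all dimensions.
    fn num_elements(&self) -> usize {
        self.shape.iter().product()
    }
}

/// Hypothetical stand-in for `TensorCollection`: a model exposes its tensors.
trait TensorList {
    fn tensors(&self) -> Vec<&ParamTensor>;
}

/// Mirrors the shape of `NumParams`: a single provided (default) method.
trait CountParams: TensorList {
    fn num_trainable_params(&self) -> usize {
        self.tensors()
            .into_iter()
            .filter(|t| t.trainable) // skip tensors marked non-trainable
            .map(|t| t.num_elements())
            .sum()
    }
}

// Blanket impl: every TensorList automatically gets the provided method.
impl<T: TensorList> CountParams for T {}

/// Hypothetical linear layer: weight of shape [out, inp], bias of shape [out].
struct ToyLinear {
    weight: ParamTensor,
    bias: ParamTensor,
}

impl ToyLinear {
    fn new(inp: usize, out: usize) -> Self {
        ToyLinear {
            weight: ParamTensor { shape: vec![out, inp], trainable: true },
            bias: ParamTensor { shape: vec![out], trainable: true },
        }
    }
}

impl TensorList for ToyLinear {
    fn tensors(&self) -> Vec<&ParamTensor> {
        vec![&self.weight, &self.bias]
    }
}
```

With this sketch, `ToyLinear::new(2, 5).num_trainable_params()` returns 15, the same figure as the `Linear<2, 5>` example; setting `bias.trainable = false` on that layer drops the count to 10. The blanket `impl` is the design choice that matters here: implementors only describe their tensors, and the counting logic lives in one place.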