Trait burn::module::Parameter

pub trait Parameter: Clone + Debug + Send {
    type Device: Clone;

    // Required methods
    fn device(&self) -> Self::Device;
    fn is_require_grad(&self) -> bool;
    fn set_require_grad(self, require_grad: bool) -> Self;
}

Trait that defines what is necessary for a type to be a parameter.
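The sketch below shows what an implementation involves. It re-declares the trait locally with the signature shown above, so it compiles without the burn crate, and implements it for a hypothetical ToyParam type that uses an equally hypothetical ToyDevice enum. It illustrates those assumptions only and is not burn's own implementation.

```rust
use std::fmt::Debug;

// Local mirror of the documented trait signature (a stand-in so this
// snippet compiles without depending on the `burn` crate).
pub trait Parameter: Clone + Debug + Send {
    type Device: Clone;

    fn device(&self) -> Self::Device;
    fn is_require_grad(&self) -> bool;
    fn set_require_grad(self, require_grad: bool) -> Self;
}

/// Hypothetical device type used only for this illustration.
#[derive(Clone, Debug, PartialEq)]
pub enum ToyDevice {
    Cpu,
    Gpu(usize),
}

/// Hypothetical parameter: a flat vector of weights plus bookkeeping.
#[derive(Clone, Debug)]
pub struct ToyParam {
    pub values: Vec<f32>,
    pub device: ToyDevice,
    pub require_grad: bool,
}

impl Parameter for ToyParam {
    type Device = ToyDevice;

    fn device(&self) -> Self::Device {
        // Return an owned copy; the `Device: Clone` bound makes this possible.
        self.device.clone()
    }

    fn is_require_grad(&self) -> bool {
        self.require_grad
    }

    fn set_require_grad(mut self, require_grad: bool) -> Self {
        // Consumes `self`, updates the flag, and hands the value back.
        self.require_grad = require_grad;
        self
    }
}
```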

Required Associated Types

type Device: Clone

The device type returned by device().

Required Methods

fn device(&self) -> Self::Device

Returns the device on which the parameter is stored.
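The following sketch reads the device generically. It assumes a local copy of the trait signature and a hypothetical Weight type whose device is a plain index. The helper all_on_same_device is also hypothetical. Because the trait only promises Device: Clone, comparing devices needs an extra PartialEq bound supplied by the caller.

```rust
use std::fmt::Debug;

// Local mirror of the documented `Parameter` signature (not the burn crate).
pub trait Parameter: Clone + Debug + Send {
    type Device: Clone;
    fn device(&self) -> Self::Device;
    fn is_require_grad(&self) -> bool;
    fn set_require_grad(self, require_grad: bool) -> Self;
}

/// Hypothetical parameter type whose device is a plain index.
#[derive(Clone, Debug)]
pub struct Weight {
    pub device: usize,
    pub require_grad: bool,
}

impl Parameter for Weight {
    type Device = usize;
    fn device(&self) -> usize {
        self.device
    }
    fn is_require_grad(&self) -> bool {
        self.require_grad
    }
    fn set_require_grad(mut self, require_grad: bool) -> Self {
        self.require_grad = require_grad;
        self
    }
}

/// Returns true when every parameter reports the same device.
/// An empty slice trivially satisfies the check.
pub fn all_on_same_device<P>(params: &[P]) -> bool
where
    P: Parameter,
    P::Device: PartialEq, // extra bound: the trait itself only requires Clone
{
    match params.first() {
        None => true,
        Some(first) => {
            let reference = first.device();
            params.iter().all(|p| p.device() == reference)
        }
    }
}
```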

fn is_require_grad(&self) -> bool

Returns whether the parameter requires gradients.
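The next sketch queries the flag across many parameters at once. Both count_trainable and the Named type are hypothetical, and the trait is re-declared locally from the signature above so the snippet stands alone.

```rust
use std::fmt::Debug;

// Local mirror of the documented `Parameter` signature (not the burn crate).
pub trait Parameter: Clone + Debug + Send {
    type Device: Clone;
    fn device(&self) -> Self::Device;
    fn is_require_grad(&self) -> bool;
    fn set_require_grad(self, require_grad: bool) -> Self;
}

/// Hypothetical parameter that only carries a name and its gradient flag.
#[derive(Clone, Debug)]
pub struct Named {
    pub name: &'static str,
    pub require_grad: bool,
}

impl Parameter for Named {
    // Hypothetical: every `Named` lives on a device labelled "cpu".
    type Device = &'static str;
    fn device(&self) -> Self::Device {
        "cpu"
    }
    fn is_require_grad(&self) -> bool {
        self.require_grad
    }
    fn set_require_grad(mut self, require_grad: bool) -> Self {
        self.require_grad = require_grad;
        self
    }
}

/// Counts the parameters whose `is_require_grad()` returns true.
pub fn count_trainable<P: Parameter>(params: &[P]) -> usize {
    params.iter().filter(|p| p.is_require_grad()).count()
}
```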

fn set_require_grad(self, require_grad: bool) -> Self

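Sets whether the parameter requires gradients. The method takes self by value, so it consumes the parameter and returns the updated one.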
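Because set_require_grad consumes self, callers either rebind the result (let p = p.set_require_grad(true);) or map over owned values. The sketch below assumes a locally re-declared trait and a hypothetical Weights type. freeze_all is an illustrative helper, not a burn API.

```rust
use std::fmt::Debug;

// Local mirror of the documented `Parameter` signature (not the burn crate).
pub trait Parameter: Clone + Debug + Send {
    type Device: Clone;
    fn device(&self) -> Self::Device;
    fn is_require_grad(&self) -> bool;
    fn set_require_grad(self, require_grad: bool) -> Self;
}

/// Hypothetical parameter holding raw data and a gradient flag.
#[derive(Clone, Debug)]
pub struct Weights {
    pub data: Vec<f32>,
    pub require_grad: bool,
}

impl Parameter for Weights {
    type Device = u8; // hypothetical device index; always 0 here
    fn device(&self) -> u8 {
        0
    }
    fn is_require_grad(&self) -> bool {
        self.require_grad
    }
    fn set_require_grad(mut self, require_grad: bool) -> Self {
        self.require_grad = require_grad;
        self
    }
}

/// Turns off gradients for every parameter. Each value is moved into
/// `set_require_grad` and the returned value is collected back.
pub fn freeze_all<P: Parameter>(params: Vec<P>) -> Vec<P> {
    params
        .into_iter()
        .map(|p| p.set_require_grad(false))
        .collect()
}
```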

Object Safety

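This trait is not object safe: its Clone supertrait requires Sized, and set_require_grad returns Self by value, so it cannot be used as a dyn Parameter trait object.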
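When a trait cannot be a trait object, enum dispatch over a closed set of concrete types is one common workaround. The sketch below assumes a local copy of the trait and hypothetical FloatParam, IntParam, and AnyParam types. It shows one possible design, not how burn itself stores parameters.

```rust
use std::fmt::Debug;

// Local mirror of the documented `Parameter` signature (not the burn crate).
pub trait Parameter: Clone + Debug + Send {
    type Device: Clone;
    fn device(&self) -> Self::Device;
    fn is_require_grad(&self) -> bool;
    fn set_require_grad(self, require_grad: bool) -> Self;
}

// This would NOT compile: `Clone` requires `Sized`, and `set_require_grad`
// returns `Self`, so no `dyn Parameter` type exists.
// let params: Vec<Box<dyn Parameter<Device = u8>>> = Vec::new();

/// Two hypothetical concrete parameter kinds sharing one device type.
#[derive(Clone, Debug)]
pub struct FloatParam {
    pub device: u8,
    pub require_grad: bool,
}

#[derive(Clone, Debug)]
pub struct IntParam {
    pub device: u8,
    pub require_grad: bool,
}

impl Parameter for FloatParam {
    type Device = u8;
    fn device(&self) -> u8 { self.device }
    fn is_require_grad(&self) -> bool { self.require_grad }
    fn set_require_grad(mut self, require_grad: bool) -> Self {
        self.require_grad = require_grad;
        self
    }
}

impl Parameter for IntParam {
    type Device = u8;
    fn device(&self) -> u8 { self.device }
    fn is_require_grad(&self) -> bool { self.require_grad }
    fn set_require_grad(mut self, require_grad: bool) -> Self {
        self.require_grad = require_grad;
        self
    }
}

/// Enum dispatch: a closed set of concrete kinds behind one `Sized` type,
/// which can itself implement `Parameter` and be stored in a `Vec`.
#[derive(Clone, Debug)]
pub enum AnyParam {
    Float(FloatParam),
    Int(IntParam),
}

impl Parameter for AnyParam {
    type Device = u8;
    fn device(&self) -> u8 {
        match self {
            AnyParam::Float(p) => p.device(),
            AnyParam::Int(p) => p.device(),
        }
    }
    fn is_require_grad(&self) -> bool {
        match self {
            AnyParam::Float(p) => p.is_require_grad(),
            AnyParam::Int(p) => p.is_require_grad(),
        }
    }
    fn set_require_grad(self, require_grad: bool) -> Self {
        // Forward to the wrapped value and re-wrap the returned value.
        match self {
            AnyParam::Float(p) => AnyParam::Float(p.set_require_grad(require_grad)),
            AnyParam::Int(p) => AnyParam::Int(p.set_require_grad(require_grad)),
        }
    }
}
```

When every element has the same concrete type, plain generics (a function taking P: Parameter) avoid the enum altogether.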

Implementors

impl<B, const D: usize> Parameter for Tensor<B, D>
where B: Backend,

type Device = <B as Backend>::Device

impl<B, const D: usize> Parameter for Tensor<B, D, Bool>
where B: Backend,

type Device = <B as Backend>::Device

impl<B, const D: usize> Parameter for Tensor<B, D, Int>
where B: Backend,

type Device = <B as Backend>::Device
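All three implementors follow one pattern: a single generic impl over the backend B and the rank D, with Device projected from the backend. The sketch below reproduces that pattern using hypothetical stand-ins (Backend, ToyTensor, CpuBackend) so that it compiles on its own. The real burn Tensor and Backend types are not reproduced here.

```rust
use std::fmt::Debug;
use std::marker::PhantomData;

// Local mirror of the documented `Parameter` signature (not the burn crate).
pub trait Parameter: Clone + Debug + Send {
    type Device: Clone;
    fn device(&self) -> Self::Device;
    fn is_require_grad(&self) -> bool;
    fn set_require_grad(self, require_grad: bool) -> Self;
}

/// Hypothetical stand-in for a backend trait that names its device type.
pub trait Backend: Clone + Debug + Send {
    type Device: Clone + Debug + Send;
    fn default_device() -> Self::Device;
}

/// Hypothetical tensor tagged with a backend `B` and a rank `D`.
#[derive(Clone, Debug)]
pub struct ToyTensor<B: Backend, const D: usize> {
    pub shape: [usize; D],
    pub device: B::Device,
    pub require_grad: bool,
    _backend: PhantomData<B>,
}

impl<B: Backend, const D: usize> ToyTensor<B, D> {
    /// Creates a tensor description on the backend's default device.
    pub fn new(shape: [usize; D]) -> Self {
        Self {
            shape,
            device: B::default_device(),
            require_grad: false,
            _backend: PhantomData,
        }
    }
}

// One generic impl covers every backend and every rank; the device type
// is projected from the backend, as in the implementors listed above.
impl<B, const D: usize> Parameter for ToyTensor<B, D>
where
    B: Backend,
{
    type Device = <B as Backend>::Device;

    fn device(&self) -> Self::Device {
        self.device.clone()
    }
    fn is_require_grad(&self) -> bool {
        self.require_grad
    }
    fn set_require_grad(mut self, require_grad: bool) -> Self {
        self.require_grad = require_grad;
        self
    }
}

/// A hypothetical CPU-only backend whose device is just a label.
#[derive(Clone, Debug)]
pub struct CpuBackend;

impl Backend for CpuBackend {
    type Device = String;
    fn default_device() -> String {
        "cpu".to_string()
    }
}
```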