Trait burn::module::ADModule

source ·
pub trait ADModule<B>: Module<B> + Send + Sync + Debug
where
    B: ADBackend,
{
    type InnerModule: Module<<B as ADBackend>::InnerBackend>;

    // Required method
    fn valid(&self) -> Self::InnerModule;
}
Expand description

Module with auto-differentiation backend.

Required Associated Types§

source

type InnerModule: Module<<B as ADBackend>::InnerBackend>

Inner module without auto-differentiation.

Required Methods§

source

fn valid(&self) -> Self::InnerModule

Get the same module, but on the inner backend without auto-differentiation.

Implementations on Foreign Types§

source§

impl<B> ADModule<B> for usize where B: ADBackend,

source§

impl<B> ADModule<B> for i16 where B: ADBackend,

source§

impl<T, B> ADModule<B> for Vec<T, Global> where T: ADModule<B> + Debug + Send + Sync + Clone, B: ADBackend,

source§

impl<B> ADModule<B> for f32 where B: ADBackend,

source§

impl<B> ADModule<B> for u32 where B: ADBackend,

source§

impl<B> ADModule<B> for i8 where B: ADBackend,

§

type InnerModule = i8

source§

fn valid(&self) -> <i8 as ADModule<B>>::InnerModule

source§

impl<B> ADModule<B> for bool where B: ADBackend,

source§

impl<B> ADModule<B> for u8 where B: ADBackend,

§

type InnerModule = u8

source§

fn valid(&self) -> <u8 as ADModule<B>>::InnerModule

source§

impl<B> ADModule<B> for i32 where B: ADBackend,

source§

impl<const N: usize, T, B> ADModule<B> for [T; N] where T: ADModule<B> + Debug + Send + Sync + Clone + Copy, <T as ADModule<B>>::InnerModule: Copy + Debug, <<T as ADModule<B>>::InnerModule as Module<<B as ADBackend>::InnerBackend>>::Record: Debug, <T as Module<B>>::Record: Debug, B: ADBackend,

§

type InnerModule = [<T as ADModule<B>>::InnerModule; N]

source§

fn valid(&self) -> <[T; N] as ADModule<B>>::InnerModule

source§

impl<T, B> ADModule<B> for Option<T> where T: ADModule<B> + Debug + Send + Sync + Clone, B: ADBackend,

source§

impl<B> ADModule<B> for u64 where B: ADBackend,

source§

impl<B> ADModule<B> for String where B: ADBackend,

source§

impl<B> ADModule<B> for i64 where B: ADBackend,

source§

impl<B> ADModule<B> for f64 where B: ADBackend,

source§

impl<B> ADModule<B> for u16 where B: ADBackend,

source§

impl<B> ADModule<B> for PhantomData<B> where B: ADBackend,

Implementors§

source§

impl<B> ADModule<B> for PaddingConfig1d where B: ADBackend,

source§

impl<B> ADModule<B> for PaddingConfig2d where B: ADBackend,

source§

impl<B> ADModule<B> for MultiHeadAttention<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for Conv1d<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for Conv2d<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for ConvTranspose1d<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for ConvTranspose2d<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for Gru<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for BinaryCrossEntropyLoss<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for CrossEntropyLoss<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for AdaptiveAvgPool1d where B: ADBackend,

source§

impl<B> ADModule<B> for AdaptiveAvgPool2d where B: ADBackend,

source§

impl<B> ADModule<B> for AvgPool1d where B: ADBackend,

source§

impl<B> ADModule<B> for AvgPool2d where B: ADBackend,

source§

impl<B> ADModule<B> for MaxPool1d where B: ADBackend,

source§

impl<B> ADModule<B> for MaxPool2d where B: ADBackend,

source§

impl<B> ADModule<B> for Dropout where B: ADBackend,

source§

impl<B> ADModule<B> for Embedding<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for GELU where B: ADBackend,

source§

impl<B> ADModule<B> for GateController<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for LayerNorm<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for Linear<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for Lstm<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for PositionalEncoding<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for ReLU where B: ADBackend,

source§

impl<B> ADModule<B> for PositionWiseFeedForward<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for TransformerDecoder<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for TransformerDecoderLayer<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for TransformerEncoder<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for TransformerEncoderLayer<B> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<B> ADModule<B> for bf16 where B: ADBackend,

source§

impl<B> ADModule<B> for f16 where B: ADBackend,

source§

impl<B, const D: usize> ADModule<B> for BatchNorm<B, D> where B: Backend + ADBackend, <B as ADBackend>::InnerBackend: Backend,

source§

impl<const D: usize, B> ADModule<B> for Tensor<B, D, Float> where B: ADBackend,

source§

impl<const D: usize, B> ADModule<B> for Param<Tensor<B, D, Float>> where B: ADBackend,

source§

impl<const D: usize, B> ADModule<B> for RunningState<Tensor<B, D, Float>> where B: ADBackend,