Trait dfdx::nn::Module

source ·
pub trait Module<Input> {
    type Output;
    type Error: Debug;

    // Required method
    fn try_forward(&self, input: Input) -> Result<Self::Output, Self::Error>;

    // Provided method
    fn forward(&self, input: Input) -> Self::Output { ... }
}
Expand description

Immutable forward pass of `Input` that produces `Module::Output`. See `ModuleMut` for a mutable forward pass.

Required Associated Types§

source

type Output

The type that this unit produces given Input.

source

type Error: Debug

The error type returned by `try_forward` when the forward pass fails.

Required Methods§

source

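fn try_forward(&self, input: Input) -> Result<Self::Output, Self::Error>

Forward `Input` through the module, returning `Err(Self::Error)` instead of `Module::Output` if the forward pass fails.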

Provided Methods§

source

fn forward(&self, input: Input) -> Self::Output

Forward Input through the module and produce Module::Output.

See ModuleMut::forward_mut() for a version that can mutate self.

Implementations on Foreign Types§

source§

impl<Input, M4: Module<M3::Output, Error = M3::Error>, M3: Module<M2::Output, Error = M2::Error>, M2: Module<M1::Output, Error = M1::Error>, M1: Module<Input>> Module<Input> for (M1, M2, M3, M4)

source§

fn try_forward(&self, x: Input) -> Result<Self::Output, Self::Error>

Calls forward sequentially on each module in the tuple.

§

type Output = <M4 as Module<<M3 as Module<<M2 as Module<<M1 as Module<Input>>::Output>>::Output>>::Output>>::Output

§

type Error = <M4 as Module<<M3 as Module<<M2 as Module<<M1 as Module<Input>>::Output>>::Output>>::Output>>::Error

source§

impl<Input, M3: Module<M2::Output, Error = M2::Error>, M2: Module<M1::Output, Error = M1::Error>, M1: Module<Input>> Module<Input> for (M1, M2, M3)

source§

fn try_forward(&self, x: Input) -> Result<Self::Output, Self::Error>

Calls forward sequentially on each module in the tuple.

§

type Output = <M3 as Module<<M2 as Module<<M1 as Module<Input>>::Output>>::Output>>::Output

§

type Error = <M3 as Module<<M2 as Module<<M1 as Module<Input>>::Output>>::Output>>::Error

source§

impl<Input, M5: Module<M4::Output, Error = M4::Error>, M4: Module<M3::Output, Error = M3::Error>, M3: Module<M2::Output, Error = M2::Error>, M2: Module<M1::Output, Error = M1::Error>, M1: Module<Input>> Module<Input> for (M1, M2, M3, M4, M5)

source§

fn try_forward(&self, x: Input) -> Result<Self::Output, Self::Error>

Calls forward sequentially on each module in the tuple.

§

type Output = <M5 as Module<<M4 as Module<<M3 as Module<<M2 as Module<<M1 as Module<Input>>::Output>>::Output>>::Output>>::Output>>::Output

§

type Error = <M5 as Module<<M4 as Module<<M3 as Module<<M2 as Module<<M1 as Module<Input>>::Output>>::Output>>::Output>>::Output>>::Error

source§

impl<Input, M2: Module<M1::Output, Error = M1::Error>, M1: Module<Input>> Module<Input> for (M1, M2)

source§

fn try_forward(&self, x: Input) -> Result<Self::Output, Self::Error>

Calls forward sequentially on each module in the tuple.

§

type Output = <M2 as Module<<M1 as Module<Input>>::Output>>::Output

§

type Error = <M2 as Module<<M1 as Module<Input>>::Output>>::Error

source§

impl<Input, M6: Module<M5::Output, Error = M5::Error>, M5: Module<M4::Output, Error = M4::Error>, M4: Module<M3::Output, Error = M3::Error>, M3: Module<M2::Output, Error = M2::Error>, M2: Module<M1::Output, Error = M1::Error>, M1: Module<Input>> Module<Input> for (M1, M2, M3, M4, M5, M6)

source§

fn try_forward(&self, x: Input) -> Result<Self::Output, Self::Error>

Calls forward sequentially on each module in the tuple.

§

type Output = <M6 as Module<<M5 as Module<<M4 as Module<<M3 as Module<<M2 as Module<<M1 as Module<Input>>::Output>>::Output>>::Output>>::Output>>::Output>>::Output

§

type Error = <M6 as Module<<M5 as Module<<M4 as Module<<M3 as Module<<M2 as Module<<M1 as Module<Input>>::Output>>::Output>>::Output>>::Output>>::Output>>::Error

Implementors§

source§

impl<Ax: Axes, S, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Softmax where S: Shape<LastAxis = Ax> + ReduceShape<Ax>,

§

type Output = Tensor<S, E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<B: Dim, C: Dim, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(B, C, H, W), E, D, T>> for AvgPoolGlobal

§

type Output = Tensor<(B, C), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<B: Dim, C: Dim, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(B, C, H, W), E, D, T>> for MaxPoolGlobal

§

type Output = Tensor<(B, C), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<B: Dim, C: Dim, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(B, C, H, W), E, D, T>> for MinPoolGlobal

§

type Output = Tensor<(B, C), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<B: Dim, S: Dim, const M: usize, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(B, S, Const<M>), E, D, T>> for LayerNorm1D<M, E, D>

§

type Output = Tensor<(B, S, Const<M>), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<B: Dim, const C: usize, E: Dtype, D: Device<E>> Module<Tensor<(B, Const<C>), E, D, NoneTape>> for BatchNorm1D<C, E, D>

§

type Output = Tensor<(B, Const<C>), E, D, NoneTape>

§

type Error = <D as HasErr>::Err

source§

impl<B: Dim, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>> Module<Tensor<(B, Const<C>, H, W), E, D, NoneTape>> for BatchNorm2D<C, E, D>

§

type Output = Tensor<(B, Const<C>, H, W), E, D, NoneTape>

§

type Error = <D as HasErr>::Err

source§

impl<B: Dim, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(B, Const<C>, H, W), E, D, T>> for Bias2D<C, E, D>

§

type Output = Tensor<(B, Const<C>, H, W), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<B: Dim, const C: usize, L: Dim, E: Dtype, D: Device<E>> Module<Tensor<(B, Const<C>, L), E, D, NoneTape>> for BatchNorm1D<C, E, D>

§

type Output = Tensor<(B, Const<C>, L), E, D, NoneTape>

§

type Error = <D as HasErr>::Err

source§

impl<B: Dim, const M: usize, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(B, Const<M>), E, D, T>> for LayerNorm1D<M, E, D>

§

type Output = Tensor<(B, Const<M>), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<C: Dim, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(C, H, W), E, D, T>> for AvgPoolGlobal

§

type Output = Tensor<(C,), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<C: Dim, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(C, H, W), E, D, T>> for MaxPoolGlobal

§

type Output = Tensor<(C,), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<C: Dim, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(C, H, W), E, D, T>> for MinPoolGlobal

§

type Output = Tensor<(C,), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<Input, T: Module<Input, Output = Input>, const N: usize> Module<Input> for Repeated<T, N>

§

type Output = <T as Module<Input>>::Output

§

type Error = <T as Module<Input>>::Error

source§

impl<Input: SplitTape, A: Module<Input, Error = B::Error>, B: Module<Input>> Module<Input> for SplitInto<(A, B)> where A::Output: SplitTape<Tape = Input::Tape>,

§

type Output = (<<A as Module<Input>>::Output as SplitTape>::NoTape, <B as Module<Input>>::Output)

§

type Error = <B as Module<Input>>::Error

source§

impl<Input: SplitTape, A: Module<Input, Error = C::Error>, B: Module<Input, Error = C::Error>, C: Module<Input>> Module<Input> for SplitInto<(A, B, C)> where A::Output: SplitTape<Tape = Input::Tape>, B::Output: SplitTape<Tape = Input::Tape>,

§

type Output = (<<A as Module<Input>>::Output as SplitTape>::NoTape, <<B as Module<Input>>::Output as SplitTape>::NoTape, <C as Module<Input>>::Output)

§

type Error = <C as Module<Input>>::Error

source§

impl<Input: SplitTape, A: Module<Input, Error = D::Error>, B: Module<Input, Error = D::Error>, C: Module<Input, Error = D::Error>, D: Module<Input>> Module<Input> for SplitInto<(A, B, C, D)> where A::Output: SplitTape<Tape = Input::Tape>, B::Output: SplitTape<Tape = Input::Tape>, C::Output: SplitTape<Tape = Input::Tape>,

§

type Output = (<<A as Module<Input>>::Output as SplitTape>::NoTape, <<B as Module<Input>>::Output as SplitTape>::NoTape, <<C as Module<Input>>::Output as SplitTape>::NoTape, <D as Module<Input>>::Output)

§

type Error = <D as Module<Input>>::Error

source§

impl<Input: SplitTape, A: Module<Input, Error = E::Error>, B: Module<Input, Error = E::Error>, C: Module<Input, Error = E::Error>, D: Module<Input, Error = E::Error>, E: Module<Input>> Module<Input> for SplitInto<(A, B, C, D, E)> where A::Output: SplitTape<Tape = Input::Tape>, B::Output: SplitTape<Tape = Input::Tape>, C::Output: SplitTape<Tape = Input::Tape>, D::Output: SplitTape<Tape = Input::Tape>,

§

type Output = (<<A as Module<Input>>::Output as SplitTape>::NoTape, <<B as Module<Input>>::Output as SplitTape>::NoTape, <<C as Module<Input>>::Output as SplitTape>::NoTape, <<D as Module<Input>>::Output as SplitTape>::NoTape, <E as Module<Input>>::Output)

§

type Error = <E as Module<Input>>::Error

source§

impl<Input: SplitTape, A: Module<Input, Error = F::Error>, B: Module<Input, Error = F::Error>, C: Module<Input, Error = F::Error>, D: Module<Input, Error = F::Error>, E: Module<Input, Error = F::Error>, F: Module<Input>> Module<Input> for SplitInto<(A, B, C, D, E, F)> where A::Output: SplitTape<Tape = Input::Tape>, B::Output: SplitTape<Tape = Input::Tape>, C::Output: SplitTape<Tape = Input::Tape>, D::Output: SplitTape<Tape = Input::Tape>, E::Output: SplitTape<Tape = Input::Tape>,

§

type Output = (<<A as Module<Input>>::Output as SplitTape>::NoTape, <<B as Module<Input>>::Output as SplitTape>::NoTape, <<C as Module<Input>>::Output as SplitTape>::NoTape, <<D as Module<Input>>::Output as SplitTape>::NoTape, <<E as Module<Input>>::Output as SplitTape>::NoTape, <F as Module<Input>>::Output)

§

type Error = <F as Module<Input>>::Error

source§

impl<Out: Add<Out, Output = Out>, Ai, Bi, A: Module<Ai, Output = Out>, B: Module<Bi, Output = Out, Error = A::Error>> Module<(Ai, Bi)> for AddInto<(A, B)>

§

type Output = Out

§

type Error = <A as Module<Ai>>::Error

source§

impl<Out: Add<Out, Output = Out>, Ai, Bi, Ci, A: Module<Ai, Output = Out>, B: Module<Bi, Output = Out, Error = A::Error>, C: Module<Ci, Output = Out, Error = A::Error>> Module<(Ai, Bi, Ci)> for AddInto<(A, B, C)>

§

type Output = Out

§

type Error = <A as Module<Ai>>::Error

source§

impl<Out: Add<Out, Output = Out>, Ai, Bi, Ci, Di, A: Module<Ai, Output = Out>, B: Module<Bi, Output = Out, Error = A::Error>, C: Module<Ci, Output = Out, Error = A::Error>, D: Module<Di, Output = Out, Error = A::Error>> Module<(Ai, Bi, Ci, Di)> for AddInto<(A, B, C, D)>

§

type Output = Out

§

type Error = <A as Module<Ai>>::Error

source§

impl<Out: Add<Out, Output = Out>, Ai, Bi, Ci, Di, Ei, A: Module<Ai, Output = Out>, B: Module<Bi, Output = Out, Error = A::Error>, C: Module<Ci, Output = Out, Error = A::Error>, D: Module<Di, Output = Out, Error = A::Error>, E: Module<Ei, Output = Out, Error = A::Error>> Module<(Ai, Bi, Ci, Di, Ei)> for AddInto<(A, B, C, D, E)>

§

type Output = Out

§

type Error = <A as Module<Ai>>::Error

source§

impl<Out: Add<Out, Output = Out>, Ai, Bi, Ci, Di, Ei, Fi, A: Module<Ai, Output = Out>, B: Module<Bi, Output = Out, Error = A::Error>, C: Module<Ci, Output = Out, Error = A::Error>, D: Module<Di, Output = Out, Error = A::Error>, E: Module<Ei, Output = Out, Error = A::Error>, F: Module<Fi, Output = Out, Error = A::Error>> Module<(Ai, Bi, Ci, Di, Ei, Fi)> for AddInto<(A, B, C, D, E, F)>

§

type Output = Out

§

type Error = <A as Module<Ai>>::Error

source§

impl<S: Shape, E: Dtype, D: Device<E>> Module<Tensor<S, E, D, NoneTape>> for Dropout

§

type Output = Tensor<S, E, D, NoneTape>

§

type Error = <D as HasErr>::Err

source§

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Abs

§

type Output = Tensor<S, E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Cos

§

type Output = Tensor<S, E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Exp

§

type Output = Tensor<S, E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for GeLU

§

type Output = Tensor<S, E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Ln

§

type Output = Tensor<S, E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for ReLU

§

type Output = Tensor<S, E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Sigmoid

§

type Output = Tensor<S, E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Sin

§

type Output = Tensor<S, E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Sqrt

§

type Output = Tensor<S, E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Square

§

type Output = Tensor<S, E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Tanh

§

type Output = Tensor<S, E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<T: WithEmptyTape + TryAdd<T>, F: Module<T, Output = T, Error = T::Err>> Module<T> for Residual<F>

§

type Output = T

§

type Error = <F as Module<T>>::Error

source§

impl<T: WithEmptyTape, F: Module<T>, R: Module<T, Output = F::Output, Error = F::Error>> Module<T> for GeneralizedResidual<F, R> where F::Output: TryAdd<F::Output> + HasErr<Err = F::Error>,

§

type Output = <F as Module<T>>::Output

§

type Error = <F as Module<T>>::Error

source§

impl<const B: usize, const C: usize, const H: usize, const W: usize, D, E: Dtype, T> Module<Tensor<(Const<B>, Const<C>, Const<H>, Const<W>), E, D, T>> for Flatten2D where D: Device<E>, T: Tape<E, D>, Rank2<B, { _ }>: Sized,

§

type Output = Tensor<(Const<B>, Const<{ C * H * W }>), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>> Module<Tensor<(Const<C>, H, W), E, D, NoneTape>> for BatchNorm2D<C, E, D>

§

type Output = Tensor<(Const<C>, H, W), E, D, NoneTape>

§

type Error = <D as HasErr>::Err

source§

impl<const C: usize, H: Dim, W: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(Const<C>, H, W), E, D, T>> for Bias2D<C, E, D>

§

type Output = Tensor<(Const<C>, H, W), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<const C: usize, const H: usize, const W: usize, D: Device<E>, E: Dtype, T: Tape<E, D>> Module<Tensor<(Const<C>, Const<H>, Const<W>), E, D, T>> for Flatten2D where Rank1<{ _ }>: Sized,

§

type Output = Tensor<(Const<{ C * H * W }>,), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<const C: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D, Img> Module<Img> for Conv2D<C, O, K, S, P, E, D> where E: Dtype, D: Device<E>, Img: TryConv2DTo<Tensor<Rank4<O, C, K, K>, E, D>, S, P> + HasErr<Err = D::Err>,

§

type Output = <Img as TryConv2DTo<Tensor<(Const<O>, Const<C>, Const<K>, Const<K>), E, D, NoneTape>, S, P>>::Output

§

type Error = <D as HasErr>::Err

source§

impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T> Module<T> for Linear<I, O, E, D> where T: SplitTape + TryMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>> + HasErr<Err = D::Err>, T::Tape: Tape<E, D>, for<'a> Bias1D<'a, O, E, D>: Module<T::Output, Output = T::Output, Error = D::Err>,

§

type Output = <T as TryMatMul<Tensor<(Const<I>, Const<O>), E, D, <T as SplitTape>::Tape>>>::Output

§

type Error = <D as HasErr>::Err

source§

impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T> Module<T> for UnbiasedLinear<I, O, E, D> where T: SplitTape + TryMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>> + HasErr<Err = D::Err>, T::Tape: Tape<E, D>,

§

type Output = <T as TryMatMul<Tensor<(Const<I>, Const<O>), E, D, <T as SplitTape>::Tape>>>::Output

§

type Error = <D as HasErr>::Err

source§

impl<const K: usize, const S: usize, const P: usize, Img: ConstAvgPool2D<K, S, P>> Module<Img> for AvgPool2D<K, S, P>

§

type Output = <Img as ConstAvgPool2D<K, S, P>>::Output

§

type Error = <Img as HasErr>::Err

source§

impl<const K: usize, const S: usize, const P: usize, Img: ConstMaxPool2D<K, S, P>> Module<Img> for MaxPool2D<K, S, P>

§

type Output = <Img as ConstMaxPool2D<K, S, P>>::Output

§

type Error = <Img as HasErr>::Err

source§

impl<const K: usize, const S: usize, const P: usize, Img: ConstMinPool2D<K, S, P>> Module<Img> for MinPool2D<K, S, P>

§

type Output = <Img as ConstMinPool2D<K, S, P>>::Output

§

type Error = <Img as HasErr>::Err

source§

impl<const M: usize, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(Const<M>,), E, D, T>> for LayerNorm1D<M, E, D>

§

type Output = Tensor<(Const<M>,), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<const M: usize, const H: usize, const EL: usize, const DL: usize, const F: usize, E: Dtype, D: Device<E>, Src: SplitTape, Tgt: PutTape<Src::Tape>> Module<(Src, Tgt)> for Transformer<M, H, EL, DL, F, E, D> where TransformerEncoder<M, H, F, EL, E, D>: Module<Src, Output = Src, Error = D::Err>, TransformerDecoder<M, H, F, DL, E, D>: Module<(<Tgt as PutTape<Src::Tape>>::Output, Src::NoTape), Output = <Tgt as PutTape<Src::Tape>>::Output, Error = D::Err>,

§

type Output = <Tgt as PutTape<<Src as SplitTape>::Tape>>::Output

§

type Error = <D as HasErr>::Err

source§

impl<const M: usize, const H: usize, const F: usize, E: Dtype, D: Device<E>, Src> Module<Src> for TransformerEncoderBlock<M, H, F, E, D> where Src: SplitTape + Add<Src::NoTape, Output = Src>, MultiHeadAttention<M, H, M, M, E, D>: Module<Src, Output = Src, Error = D::Err>, LayerNorm1D<M, E, D>: Module<Src, Output = Src, Error = D::Err>, Residual<(Linear<M, F, E, D>, ReLU, Linear<F, M, E, D>)>: Module<Src, Output = Src, Error = D::Err>,

§

type Output = Src

§

type Error = <D as HasErr>::Err

source§

impl<const M: usize, const H: usize, const F: usize, E: Dtype, D: Device<E>, Tgt, Mem> Module<(Tgt, Mem)> for TransformerDecoderBlock<M, H, F, E, D> where Tgt: SplitTape + TryAdd<Tgt::NoTape> + HasErr<Err = D::Err>, Mem: Clone, MultiHeadAttention<M, H, M, M, E, D>: Module<Tgt, Output = Tgt, Error = D::Err> + Module<(Tgt, Mem, Mem), Output = Tgt, Error = D::Err>, LayerNorm1D<M, E, D>: Module<Tgt, Output = Tgt, Error = D::Err>, Residual<(Linear<M, F, E, D>, ReLU, Linear<F, M, E, D>)>: Module<Tgt, Output = Tgt, Error = D::Err>,

§

type Output = Tgt

§

type Error = <D as HasErr>::Err

source§

impl<const M: usize, const H: usize, const F: usize, const L: usize, E, D, Tgt, Mem: Clone> Module<(Tgt, Mem)> for TransformerDecoder<M, H, F, L, E, D> where E: Dtype, D: Device<E>, TransformerDecoderBlock<M, H, F, E, D>: Module<(Tgt, Mem), Output = Tgt, Error = D::Err>,

§

type Output = Tgt

§

type Error = <D as HasErr>::Err

source§

impl<const M: usize, const H: usize, const K: usize, const V: usize, E, D, B, S1, S2, T> Module<(Tensor<(B, S1, Const<M>), E, D, T>, Tensor<(B, S2, Const<M>), E, D, NoneTape>, Tensor<(B, S2, Const<M>), E, D, NoneTape>)> for MultiHeadAttention<M, H, K, V, E, D> where E: Dtype + Float, D: Device<E>, B: Dim, S1: Dim, S2: Dim, T: Tape<E, D>,

§

type Output = Tensor<(B, S1, Const<M>), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<const M: usize, const H: usize, const K: usize, const V: usize, E, D, S1, S2, T> Module<(Tensor<(S1, Const<M>), E, D, T>, Tensor<(S2, Const<M>), E, D, NoneTape>, Tensor<(S2, Const<M>), E, D, NoneTape>)> for MultiHeadAttention<M, H, K, V, E, D> where E: Dtype + Float, D: Device<E>, S1: Dim, S2: Dim, T: Tape<E, D>,

§

type Output = Tensor<(S1, Const<M>), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<const M: usize, const H: usize, const K: usize, const V: usize, E, D, Src> Module<Src> for MultiHeadAttention<M, H, K, V, E, D> where E: Dtype, D: Device<E>, Src: SplitTape, Self: Module<(Src, Src::NoTape, Src::NoTape), Output = Src, Error = D::Err>,

§

type Output = Src

§

type Error = <D as HasErr>::Err

source§

impl<const N: usize, S: Shape, E: Dtype, D: Device<E>> Module<Tensor<S, E, D, NoneTape>> for DropoutOneIn<N>

§

type Output = Tensor<S, E, D, NoneTape>

§

type Error = <D as HasErr>::Err

source§

impl<const V: usize, const M: usize, SEQ: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(SEQ,), usize, D, T>> for Embedding<V, M, E, D>

§

type Output = Tensor<(SEQ, Const<M>), E, D, T>

§

type Error = <D as HasErr>::Err

source§

impl<const VOCAB: usize, const DIM: usize, BATCH: Dim, SEQ: Dim, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<(BATCH, SEQ), usize, D, T>> for Embedding<VOCAB, DIM, E, D>

§

type Output = Tensor<(BATCH, SEQ, Const<DIM>), E, D, T>

§

type Error = <D as HasErr>::Err