/// A collection of named tensors.
///
/// Implementing this trait enables anything that operates on tensors, including
/// resetting, counting the number of parameters, updating gradients, building
/// models, and converting models between devices or dtypes.
pub trait TensorCollection<E: Dtype, D: Device<E>>: Sized {
    /// The type this collection becomes when its dtype and/or device are changed
    /// to `E2` / `D2` (e.g. `Linear<I, O, E, D>` becomes `Linear<I, O, E2, D2>`).
    type To<E2: Dtype, D2: Device<E2>>;

    // Required method
    /// Specifies how to iterate through the tensors or modules contained within
    /// this module, and how to construct this module given values for its fields.
    ///
    /// Returns `Err(_)` to indicate an error, `Ok(None)` to indicate that there was
    /// no error but no module was built, and `Ok(Some(_))` holding the built
    /// `Self::To<V::E2, V::D2>`.
    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
        visitor: &mut V
    ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>;

    // Provided methods
    /// Creates a [`ModuleField`] describing a sub-module field named `name`, which
    /// may itself contain one or more tensors.
    ///
    /// `get_ref` / `get_mut` select the field from a shared / mutable `Self`.
    fn module<F1, F2, Field>(
        name: &str,
        get_ref: F1,
        get_mut: F2
    ) -> ModuleField<'_, F1, F2, Self, Field>
       where F1: FnMut(&Self) -> &Field,
             F2: FnMut(&mut Self) -> &mut Field,
             Field: TensorCollection<E, D> { ... }
    /// Creates a [`TensorField`] describing a tensor field named `name`.
    ///
    /// `get_ref` / `get_mut` select the tensor from a shared / mutable `Self`;
    /// `options` is a [`TensorOptions`] for this tensor.
    fn tensor<F1, F2, S>(
        name: &str,
        get_ref: F1,
        get_mut: F2,
        options: TensorOptions<S, E, D>
    ) -> TensorField<'_, F1, F2, Self, S, E, D>
       where F1: FnMut(&Self) -> &Tensor<S, E, D>,
             F2: FnMut(&mut Self) -> &mut Tensor<S, E, D>,
             S: Shape { ... }
    /// Creates a [`ScalarField`] describing a scalar (`N: NumCast`) field named `name`.
    ///
    /// `get_ref` / `get_mut` select the scalar from a shared / mutable `Self`;
    /// `options` is a [`ScalarOptions`] for this scalar.
    fn scalar<F1, F2, N>(
        name: &str,
        get_ref: F1,
        get_mut: F2,
        options: ScalarOptions<N>
    ) -> ScalarField<'_, F1, F2, Self, N>
       where F1: FnMut(&Self) -> &N,
             F2: FnMut(&mut Self) -> &mut N,
             N: NumCast { ... }
}
Expand description

A collection of named tensors. Implementing this trait enables anything that operates on tensors, including resetting, counting the number of parameters, updating gradients, building models, and converting models between devices or dtypes.

Example implementation:

use dfdx::nn::modules::Linear;

// A two-layer MLP. Only `l1` and `l2` are visited by `iter_tensors` below;
// `relu` is not listed there and is rebuilt from its default value.
struct Mlp<E: Dtype, D: Device<E>> {
    l1: Linear<10, 10, E, D>,
    l2: Linear<10, 2, E, D>,
    relu: ReLU,
}

// The `Float + SampleUniform` bounds on `E` mirror the bounds required by
// `Linear`'s own `TensorCollection` impl, since `Mlp` delegates to its `Linear` fields.
impl<E: Dtype + num_traits::Float + rand_distr::uniform::SampleUniform, D: Device<E>> TensorCollection<E, D> for Mlp<E, D> {
    // Changing dtype/device maps `Mlp<E, D>` to `Mlp<E2, D2>`.
    type To<E2: Dtype, D2: Device<E2>> = Mlp<E2, D2>;

    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
        visitor: &mut V,
    ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
        visitor.visit_fields(
            (
                // Define the name of each field and how to access it: use `Self::module`
                // (which creates a ModuleField) for sub-modules, and `Self::tensor`
                // (which creates a TensorField) for tensors.
                Self::module("l1", |s| &s.l1, |s| &mut s.l1),
                Self::module("l2", |s| &s.l2, |s| &mut s.l2),
            ),
            // Define how to construct the collection given its fields, in the order they are
            // given above. This conversion is done using the ModuleFields trait.
            // `relu` was not visited, so it is rebuilt with `Default::default()`.
            |(l1, l2)| Mlp { l1, l2, relu: Default::default() }
        )
    }
}

// Build the model on the CPU and check its parameter count:
// 132 = (10*10 + 10) + (10*2 + 2), presumably weights plus biases of each Linear layer.
let dev = Cpu::default();
let model = Mlp::<f32, Cpu>::build(&dev);
assert_eq!(132, model.num_trainable_params());

Required Associated Types§

source

type To<E2: Dtype, D2: Device<E2>>

Type alias that specifies how a module’s type changes when using a different dtype and/or device.

Required Methods§

source

fn iter_tensors<V: ModuleVisitor<Self, E, D>>( visitor: &mut V ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>

Specifies how to iterate through the tensors or modules contained within this module, and how to construct this module given values for its fields. Returns Err(_) to indicate an error, Ok(None) to indicate that there is no error but no module has been built, and Ok(Some(_)) containing the built Self::To<V::E2, V::D2>.

Provided Methods§

source

fn module<F1, F2, Field>( name: &str, get_ref: F1, get_mut: F2 ) -> ModuleField<'_, F1, F2, Self, Field>where F1: FnMut(&Self) -> &Field, F2: FnMut(&mut Self) -> &mut Field, Field: TensorCollection<E, D>,

Creates a ModuleFields that represents a field that may contain one or more tensors.

See also: ModuleField, TensorCollection.

source

fn tensor<F1, F2, S>( name: &str, get_ref: F1, get_mut: F2, options: TensorOptions<S, E, D> ) -> TensorField<'_, F1, F2, Self, S, E, D>where F1: FnMut(&Self) -> &Tensor<S, E, D>, F2: FnMut(&mut Self) -> &mut Tensor<S, E, D>, S: Shape,

Creates a ModuleFields that represents a tensor field.

See also: TensorField, TensorCollection, TensorOptions.

source

fn scalar<F1, F2, N>( name: &str, get_ref: F1, get_mut: F2, options: ScalarOptions<N> ) -> ScalarField<'_, F1, F2, Self, N>where F1: FnMut(&Self) -> &N, F2: FnMut(&mut Self) -> &mut N, N: NumCast,

Creates a ModuleFields that represents a scalar field.

See also: ScalarField, TensorCollection, ScalarOptions.

Implementations on Foreign Types§

source§

impl<E: Dtype, D: Device<E>, M1: TensorCollection<E, D>, M2: TensorCollection<E, D>, M3: TensorCollection<E, D>, M4: TensorCollection<E, D>> TensorCollection<E, D> for (M1, M2, M3, M4)

§

type To<E2: Dtype, D2: Device<E2>> = (<M1 as TensorCollection<E, D>>::To<E2, D2>, <M2 as TensorCollection<E, D>>::To<E2, D2>, <M3 as TensorCollection<E, D>>::To<E2, D2>, <M4 as TensorCollection<E, D>>::To<E2, D2>)

source§

fn iter_tensors<V: ModuleVisitor<Self, E, D>>( visitor: &mut V ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>

source§

impl<E: Dtype, D: Device<E>, M1: TensorCollection<E, D>> TensorCollection<E, D> for (M1,)

§

type To<E2: Dtype, D2: Device<E2>> = (<M1 as TensorCollection<E, D>>::To<E2, D2>,)

source§

fn iter_tensors<V: ModuleVisitor<Self, E, D>>( visitor: &mut V ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>

source§

impl<E: Dtype, D: Device<E>, M1: TensorCollection<E, D>, M2: TensorCollection<E, D>, M3: TensorCollection<E, D>, M4: TensorCollection<E, D>, M5: TensorCollection<E, D>> TensorCollection<E, D> for (M1, M2, M3, M4, M5)

§

type To<E2: Dtype, D2: Device<E2>> = (<M1 as TensorCollection<E, D>>::To<E2, D2>, <M2 as TensorCollection<E, D>>::To<E2, D2>, <M3 as TensorCollection<E, D>>::To<E2, D2>, <M4 as TensorCollection<E, D>>::To<E2, D2>, <M5 as TensorCollection<E, D>>::To<E2, D2>)

source§

fn iter_tensors<V: ModuleVisitor<Self, E, D>>( visitor: &mut V ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>

source§

impl<E: Dtype, D: Device<E>, M1: TensorCollection<E, D>, M2: TensorCollection<E, D>, M3: TensorCollection<E, D>> TensorCollection<E, D> for (M1, M2, M3)

§

type To<E2: Dtype, D2: Device<E2>> = (<M1 as TensorCollection<E, D>>::To<E2, D2>, <M2 as TensorCollection<E, D>>::To<E2, D2>, <M3 as TensorCollection<E, D>>::To<E2, D2>)

source§

fn iter_tensors<V: ModuleVisitor<Self, E, D>>( visitor: &mut V ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>

source§

impl<E: Dtype, D: Device<E>, M1: TensorCollection<E, D>, M2: TensorCollection<E, D>> TensorCollection<E, D> for (M1, M2)

§

type To<E2: Dtype, D2: Device<E2>> = (<M1 as TensorCollection<E, D>>::To<E2, D2>, <M2 as TensorCollection<E, D>>::To<E2, D2>)

source§

fn iter_tensors<V: ModuleVisitor<Self, E, D>>( visitor: &mut V ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>

source§

impl<E: Dtype, D: Device<E>> TensorCollection<E, D> for ()

§

type To<E2: Dtype, D2: Device<E2>> = ()

source§

fn iter_tensors<V: ModuleVisitor<Self, E, D>>( _: &mut V ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>

source§

impl<E: Dtype, D: Device<E>, M1: TensorCollection<E, D>, M2: TensorCollection<E, D>, M3: TensorCollection<E, D>, M4: TensorCollection<E, D>, M5: TensorCollection<E, D>, M6: TensorCollection<E, D>> TensorCollection<E, D> for (M1, M2, M3, M4, M5, M6)

§

type To<E2: Dtype, D2: Device<E2>> = (<M1 as TensorCollection<E, D>>::To<E2, D2>, <M2 as TensorCollection<E, D>>::To<E2, D2>, <M3 as TensorCollection<E, D>>::To<E2, D2>, <M4 as TensorCollection<E, D>>::To<E2, D2>, <M5 as TensorCollection<E, D>>::To<E2, D2>, <M6 as TensorCollection<E, D>>::To<E2, D2>)

source§

fn iter_tensors<V: ModuleVisitor<Self, E, D>>( visitor: &mut V ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>

Implementors§

source§

impl<C: ConstDim, E: Dtype, D: Device<E>> TensorCollection<E, D> for PReLU1D<C, E, D>

§

type To<E2: Dtype, D2: Device<E2>> = PReLU1D<C, E2, D2>

source§

impl<E: Dtype, D: Device<E>> TensorCollection<E, D> for Dropout

§

type To<E2: Dtype, D2: Device<E2>> = Dropout

source§

impl<E: Dtype, D: Device<E>> TensorCollection<E, D> for PReLU<E, D>

§

type To<E2: Dtype, D2: Device<E2>> = PReLU<E2, D2>

source§

impl<E: Dtype, D: Device<E>, F: TensorCollection<E, D>> TensorCollection<E, D> for Residual<F>

§

type To<E2: Dtype, D2: Device<E2>> = Residual<<F as TensorCollection<E, D>>::To<E2, D2>>

source§

impl<E: Dtype, D: Device<E>, F: TensorCollection<E, D>, R: TensorCollection<E, D>> TensorCollection<E, D> for GeneralizedResidual<F, R>

§

type To<E2: Dtype, D2: Device<E2>> = GeneralizedResidual<<F as TensorCollection<E, D>>::To<E2, D2>, <R as TensorCollection<E, D>>::To<E2, D2>>

source§

impl<E: Dtype, D: Device<E>, T: ZeroSizedModule> TensorCollection<E, D> for T

§

type To<E2: Dtype, D2: Device<E2>> = T

source§

impl<E: Dtype, D: Device<E>, T: TensorCollection<E, D>> TensorCollection<E, D> for AddInto<T>

§

type To<E2: Dtype, D2: Device<E2>> = AddInto<<T as TensorCollection<E, D>>::To<E2, D2>>

source§

impl<E: Dtype, D: Device<E>, T: TensorCollection<E, D>> TensorCollection<E, D> for SplitInto<T>

§

type To<E2: Dtype, D2: Device<E2>> = SplitInto<<T as TensorCollection<E, D>>::To<E2, D2>>

source§

impl<E: Dtype, D: Device<E>, T: TensorCollection<E, D>, const N: usize> TensorCollection<E, D> for Repeated<T, N>

§

type To<E2: Dtype, D2: Device<E2>> = Repeated<<T as TensorCollection<E, D>>::To<E2, D2>, N>

source§

impl<S: ConstShape, E: Dtype, D: Device<E>> TensorCollection<E, D> for Tensor<S, E, D>

§

type To<E2: Dtype, D2: Device<E2>> = Tensor<S, E2, D2, NoneTape>

source§

impl<const C: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for BatchNorm1D<C, E, D>

§

type To<E2: Dtype, D2: Device<E2>> = BatchNorm1D<C, E2, D2>

source§

impl<const C: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for BatchNorm2D<C, E, D>

§

type To<E2: Dtype, D2: Device<E2>> = BatchNorm2D<C, E2, D2>

source§

impl<const C: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for Bias2D<C, E, D>

§

type To<E2: Dtype, D2: Device<E2>> = Bias2D<C, E2, D2>

source§

impl<const C: usize, const M: usize, E: Dtype + Float + SampleUniform, D: Device<E>> TensorCollection<E, D> for Embedding<C, M, E, D>

§

type To<E2: Dtype, D2: Device<E2>> = Embedding<C, M, E2, D2>

source§

impl<const I: usize, const O: usize, E: Dtype + Float + SampleUniform, D: Device<E>> TensorCollection<E, D> for Linear<I, O, E, D>

§

type To<E2: Dtype, D2: Device<E2>> = Linear<I, O, E2, D2>

source§

impl<const I: usize, const O: usize, E: Dtype + Float + SampleUniform, D: Device<E>> TensorCollection<E, D> for UnbiasedLinear<I, O, E, D>

§

type To<E2: Dtype, D2: Device<E2>> = UnbiasedLinear<I, O, E2, D2>

source§

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, const L: usize, const G: usize, E, D> TensorCollection<E, D> for Conv2D<I, O, K, S, P, L, G, E, D>where Const<{ _ }>: Sized, E: Dtype + Float + SampleUniform, D: Device<E>,

§

type To<E2: Dtype, D2: Device<E2>> = Conv2D<I, O, K, S, P, L, G, E2, D2>

source§

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, const L: usize, const G: usize, E, D> TensorCollection<E, D> for ConvTrans2D<I, O, K, S, P, L, G, E, D>where E: Dtype + Float + SampleUniform, D: Device<E>, Const<{ _ }>: Sized,

§

type To<E2: Dtype, D2: Device<E2>> = ConvTrans2D<I, O, K, S, P, L, G, E2, D2>

source§

impl<const M: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for LayerNorm1D<M, E, D>

§

type To<E2: Dtype, D2: Device<E2>> = LayerNorm1D<M, E2, D2>

source§

impl<const M: usize, const H: usize, const A: usize, const B: usize, const F: usize, E, D> TensorCollection<E, D> for Transformer<M, H, A, B, F, E, D>where E: Dtype + Float + SampleUniform, D: Device<E>,

§

type To<E2: Dtype, D2: Device<E2>> = Transformer<M, H, A, B, F, E2, D2>

source§

impl<const M: usize, const H: usize, const F: usize, E, D: Device<E>> TensorCollection<E, D> for TransformerEncoderBlock<M, H, F, E, D>where E: Dtype + Float + SampleUniform,

§

type To<E2: Dtype, D2: Device<E2>> = TransformerEncoderBlock<M, H, F, E2, D2>

source§

impl<const M: usize, const H: usize, const F: usize, const L: usize, E, D: Device<E>> TensorCollection<E, D> for TransformerDecoder<M, H, F, L, E, D>where E: Dtype + Float + SampleUniform,

§

type To<E2: Dtype, D2: Device<E2>> = TransformerDecoder<M, H, F, L, E2, D2>

source§

impl<const M: usize, const H: usize, const K: usize, const V: usize, E, D: Device<E>> TensorCollection<E, D> for MultiHeadAttention<M, H, K, V, E, D>where E: Dtype + Float + SampleUniform,

§

type To<E2: Dtype, D2: Device<E2>> = MultiHeadAttention<M, H, K, V, E2, D2>

source§

impl<const M: usize, const N: usize, const F: usize, E, D: Device<E>> TensorCollection<E, D> for TransformerDecoderBlock<M, N, F, E, D>where E: Dtype + Float + SampleUniform,

§

type To<E2: Dtype, D2: Device<E2>> = TransformerDecoderBlock<M, N, F, E2, D2>