/// An object that can visit [`TensorCollection`]s and [`Tensor`]s recursively.
///
/// `T` is the collection being visited; its tensors have element dtype `E`
/// and live on device `D`. Values produced by the visitor are expressed with
/// the associated dtype [`Self::E2`] and device [`Self::D2`]. For example, a
/// visited tensor yields a `Tensor<S, Self::E2, Self::D2>`, and the whole
/// collection yields a `T::To<Self::E2, Self::D2>`.
pub trait ModuleVisitor<T: TensorCollection<E, D>, E: Dtype, D: Device<E>>: Sized {
    /// Error type returned (inside `Result`) by every visit method of this trait.
    type Err;
    /// Element dtype of the values this visitor produces; it appears as the
    /// dtype parameter of the returned `Tensor<S, Self::E2, Self::D2>` and of
    /// `T::To<Self::E2, Self::D2>`.
    type E2: Dtype;
    /// Device of the values this visitor produces; it must implement
    /// `Device<Self::E2>`.
    type D2: Device<Self::E2>;

    // Required methods

    /// Visits a nested [`TensorCollection`] field of `T` called `name`.
    ///
    /// Do not use this directly; use [`ModuleVisitor::visit_fields`] instead.
    ///
    /// `get_refs` and `get_muts` project a shared and a mutable reference,
    /// respectively, to the field out of a `T`. On success this returns an
    /// optional `Field::To<Self::E2, Self::D2>`.
    /// NOTE(review): when `None` is returned is implementation-defined and not
    /// visible here (presumably "no new value was built"); confirm against implementors.
    fn visit_module<Field, GetRef, GetMut>(
        &mut self,
        name: &str,
        get_refs: GetRef,
        get_muts: GetMut
    ) -> Result<Option<Field::To<Self::E2, Self::D2>>, Self::Err>
       where GetRef: FnMut(&T) -> &Field,
             GetMut: FnMut(&mut T) -> &mut Field,
             Field: TensorCollection<E, D>;
    /// Visits a named [`Tensor`] field of `T` with shape `S`.
    ///
    /// Do not use this directly; use [`ModuleVisitor::visit_fields`] instead.
    ///
    /// `get_refs` and `get_muts` project a shared and a mutable reference,
    /// respectively, to the `Tensor<S, E, D>` field out of a `T`. `opts`
    /// carries the `TensorOptions<S, E, D>` for this tensor; its contents are
    /// not visible here. On success this returns an optional
    /// `Tensor<S, Self::E2, Self::D2>`.
    /// NOTE(review): the meaning of `None` is implementation-defined; confirm.
    fn visit_tensor<S: Shape, GetRef, GetMut>(
        &mut self,
        name: &str,
        get_refs: GetRef,
        get_muts: GetMut,
        opts: TensorOptions<S, E, D>
    ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err>
       where GetRef: FnMut(&T) -> &Tensor<S, E, D>,
             GetMut: FnMut(&mut T) -> &mut Tensor<S, E, D>;
    /// Visits a named scalar field of `T` of type `N`, where `N: NumCast`.
    ///
    /// `get_refs` and `get_muts` project a shared and a mutable reference,
    /// respectively, to the `N` field out of a `T`. `opts` is the
    /// `ScalarOptions<N>` for this field. On success this returns an optional `N`.
    /// NOTE(review): the meaning of `None` is implementation-defined; confirm.
    fn visit_scalar<N, GetRef, GetMut>(
        &mut self,
        name: &str,
        get_refs: GetRef,
        get_muts: GetMut,
        opts: ScalarOptions<N>
    ) -> Result<Option<N>, Self::Err>
       where N: NumCast,
             GetRef: FnMut(&T) -> &N,
             GetMut: FnMut(&mut T) -> &mut N;
    /// Visits every field described by `fields`, a [`ModuleFields`]
    /// implementation for `T`.
    ///
    /// `builder` maps the fields' combined `M::Output<Self::E2, Self::D2>` to a
    /// `T::To<Self::E2, Self::D2>`. This is the entry point implementors of
    /// `TensorCollection` are expected to use instead of `visit_module` and
    /// `visit_tensor`.
    /// NOTE(review): whether and when `builder` is invoked, and when `None` is
    /// returned, is up to each implementor; confirm.
    fn visit_fields<M: ModuleFields<T, E, D>>(
        &mut self,
        fields: M,
        builder: impl FnOnce(M::Output<Self::E2, Self::D2>) -> T::To<Self::E2, Self::D2>
    ) -> Result<Option<T::To<Self::E2, Self::D2>>, Self::Err>;
}
Description

An object that can visit TensorCollections and Tensors recursively.

Required Associated Types

source

type Err

source

type E2: Dtype

source

type D2: Device<Self::E2>

Required Methods

source

fn visit_module<Field, GetRef, GetMut>(&mut self, name: &str, get_refs: GetRef, get_muts: GetMut) -> Result<Option<Field::To<Self::E2, Self::D2>>, Self::Err> where GetRef: FnMut(&T) -> &Field, GetMut: FnMut(&mut T) -> &mut Field, Field: TensorCollection<E, D>

Visits a nested TensorCollection field. Do not use this; use visit_fields instead.

source

fn visit_tensor<S: Shape, GetRef, GetMut>(&mut self, name: &str, get_refs: GetRef, get_muts: GetMut, opts: TensorOptions<S, E, D>) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err> where GetRef: FnMut(&T) -> &Tensor<S, E, D>, GetMut: FnMut(&mut T) -> &mut Tensor<S, E, D>

Visits an actual named Tensor. Do not use this; use visit_fields instead.

source

fn visit_scalar<N, GetRef, GetMut>(&mut self, name: &str, get_refs: GetRef, get_muts: GetMut, opts: ScalarOptions<N>) -> Result<Option<N>, Self::Err> where N: NumCast, GetRef: FnMut(&T) -> &N, GetMut: FnMut(&mut T) -> &mut N

Visits a named scalar field of type N (where N: NumCast).

source

fn visit_fields<M: ModuleFields<T, E, D>>( &mut self, fields: M, builder: impl FnOnce(M::Output<Self::E2, Self::D2>) -> T::To<Self::E2, Self::D2> ) -> Result<Option<T::To<Self::E2, Self::D2>>, Self::Err>

Takes a value that implements ModuleFields and a builder function that takes ModuleFields::Output and returns an instance of T::To<Self::E2, Self::D2>.

Implementors

source§

impl<'a, T: TensorCollection<E, D>, E: Dtype, D: Device<E>, F: TensorVisitor<E, D>> ModuleVisitor<T, E, D> for RecursiveWalker<'a, <F::Viewer as TensorViewer>::View<'a, T>, F>

§

type Err = <F as TensorVisitor<E, D>>::Err

§

type E2 = <F as TensorVisitor<E, D>>::E2

§

type D2 = <F as TensorVisitor<E, D>>::D2