Trait dfdx::nn::tensor_collection::ModuleVisitor
source · pub trait ModuleVisitor<T: TensorCollection<E, D>, E: Dtype, D: Device<E>>: Sized {
type Err;
type E2: Dtype;
type D2: Device<Self::E2>;
// Required methods
fn visit_module<Field, GetRef, GetMut>(
&mut self,
name: &str,
get_refs: GetRef,
get_muts: GetMut
) -> Result<Option<Field::To<Self::E2, Self::D2>>, Self::Err>
where GetRef: FnMut(&T) -> &Field,
GetMut: FnMut(&mut T) -> &mut Field,
Field: TensorCollection<E, D>;
fn visit_tensor<S: Shape, GetRef, GetMut>(
&mut self,
name: &str,
get_refs: GetRef,
get_muts: GetMut,
opts: TensorOptions<S, E, D>
) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err>
where GetRef: FnMut(&T) -> &Tensor<S, E, D>,
GetMut: FnMut(&mut T) -> &mut Tensor<S, E, D>;
fn visit_fields<M: ModuleFields<T, E, D>>(
&mut self,
fields: M,
builder: impl FnOnce(M::Output<Self::E2, Self::D2>) -> T::To<Self::E2, Self::D2>
) -> Result<Option<T::To<Self::E2, Self::D2>>, Self::Err>;
}
Description
An object that can visit TensorCollections and Tensors recursively.
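Required Associated Types
type Err — the error type returned by every visit method (visit_module, visit_tensor, visit_fields).
type E2: Dtype — the dtype of the tensors/collections returned by the visit methods.
type D2: Device<Self::E2> — the device of the tensors/collections returned by the visit methods.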
Required Methods
fn visit_module<Field, GetRef, GetMut>(
    &mut self,
    name: &str,
    get_refs: GetRef,
    get_muts: GetMut,
) -> Result<Option<Field::To<Self::E2, Self::D2>>, Self::Err>
where
    GetRef: FnMut(&T) -> &Field,
    GetMut: FnMut(&mut T) -> &mut Field,
    Field: TensorCollection<E, D>,

Visits a TensorCollection. Do not call this directly; use visit_fields instead.
fn visit_tensor<S: Shape, GetRef, GetMut>(
    &mut self,
    name: &str,
    get_refs: GetRef,
    get_muts: GetMut,
    opts: TensorOptions<S, E, D>,
) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err>
where
    GetRef: FnMut(&T) -> &Tensor<S, E, D>,
    GetMut: FnMut(&mut T) -> &mut Tensor<S, E, D>,

Visits an actual named Tensor. Do not call this directly; use visit_fields instead.
fn visit_fields<M: ModuleFields<T, E, D>>(
    &mut self,
    fields: M,
    builder: impl FnOnce(M::Output<Self::E2, Self::D2>) -> T::To<Self::E2, Self::D2>,
) -> Result<Option<T::To<Self::E2, Self::D2>>, Self::Err>

Takes something that implements ModuleFields, and a function (builder) that takes ModuleFields::Output and returns an instance of T::To<Self::E2, Self::D2>.