Trait dfdx::nn::tensor_collection::TensorVisitor
/// Something that can visit tensors; used together with `RecursiveWalker`.
/// `E`/`D` are the dtype and device of the tensors being visited.
source · pub trait TensorVisitor<E: Dtype, D: Device<E>> {
/// How each visited tensor is presented to `visit` (e.g. `ViewTensorRef`, `ViewTensorMut`).
type Viewer: TensorViewer;
/// Error type returned by `visit` and `visit_scalar`.
type Err;
/// Dtype of the tensors that `visit` may produce.
type E2: Dtype;
/// Device of the tensors that `visit` may produce.
type D2: Device<Self::E2>;
// Required method
/// Called for each tensor. Return `Ok(Some(_))` to contribute a tensor to a newly
/// constructed output, or `Ok(None)` when this visitor does not construct one.
fn visit<S: Shape>(
&mut self,
opts: TensorOptions<S, E, D>,
t: <Self::Viewer as TensorViewer>::View<'_, Tensor<S, E, D>>
) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err>;
// Provided method
/// Called for each scalar value (described by `opts`) rather than a tensor.
/// NOTE(review): the default body is elided on this page; confirm its behavior in the source.
fn visit_scalar<N: NumCast>(
&mut self,
opts: ScalarOptions<N>,
_h: <Self::Viewer as TensorViewer>::View<'_, N>
) -> Result<Option<N>, Self::Err> { ... }
}
Expand description
Something that can visit tensors. Used in conjunction with `RecursiveWalker`.
Example implementation that adds two modules together:
// A `TensorVisitor` that sums two modules tensor-by-tensor, producing a new module.
struct Adder;

impl<E: Dtype, D: Device<E>> TensorVisitor<E, D> for Adder {
    // Each visit receives the matching pair of tensors as shared references.
    type Viewer = (ViewTensorRef, ViewTensorRef);
    type Err = D::Err;
    // The output keeps the dtype and device of the inputs.
    type E2 = E;
    type D2 = D;

    fn visit<S: Shape>(
        &mut self,
        _opts: TensorOptions<S, E, D>,
        (lhs, rhs): (&Tensor<S, E, D>, &Tensor<S, E, D>),
    ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err> {
        // Add owned clones of the borrowed inputs (presumably `try_add` takes its
        // operands by value); any device error is propagated with `?`.
        let sum = lhs.clone().try_add(rhs.clone())?;
        // `Some` asks the walker to build an output module; `None` would skip construction.
        Ok(Some(sum))
    }
}
// Two independently initialized models sharing one architecture.
// NOTE(review): `dev` is defined in setup lines not shown here (rustdoc hides them);
// presumably a device such as `Cpu` — confirm.
type Model = Linear<2, 5>;
let model1 = dev.build_module::<Model, f32>();
let model2 = dev.build_module::<Model, f32>();

// Walk both models in lockstep, letting `Adder` combine each pair of tensors.
let mut walker = RecursiveWalker {
    m: (&model1, &model2),
    f: &mut Adder,
};
// Unwrap the walk's `Result`, then the `Option` holding the constructed module.
let model3 = TensorCollection::iter_tensors(&mut walker)
    .unwrap()
    .unwrap();

// Each parameter of the result must equal the elementwise sum of the inputs.
let expected_weight = (model1.weight.clone() + model2.weight.clone()).array();
assert_eq!(expected_weight, model3.weight.array());
let expected_bias = (model1.bias.clone() + model2.bias.clone()).array();
assert_eq!(expected_bias, model3.bias.array());
Required Associated Types§
sourcetype Viewer: TensorViewer
type Viewer: TensorViewer
The viewer this visitor uses to access each tensor, e.g. `ViewTensorRef` or `ViewTensorMut`.
type Err
The error type returned by `visit` and `visit_scalar`.
Required Methods§
sourcefn visit<S: Shape>(
&mut self,
opts: TensorOptions<S, E, D>,
t: <Self::Viewer as TensorViewer>::View<'_, Tensor<S, E, D>>
) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err>
fn visit<S: Shape>( &mut self, opts: TensorOptions<S, E, D>, t: <Self::Viewer as TensorViewer>::View<'_, Tensor<S, E, D>> ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err>
What to do when visiting each tensor. Return `Ok(Some(_))` if this visitor should
construct a new module each time it is used, and `Ok(None)` if it should not.