/// Something that can visit tensors; used in conjunction with `RecursiveWalker`.
///
/// Each call to `visit` may optionally return a new tensor. That is how a visitor
/// constructs an output module: for example, an adder that sums two modules
/// tensor-by-tensor.
pub trait TensorVisitor<E: Dtype, D: Device<E>> {
    /// How this visitor views each tensor, e.g. `ViewTensorMut` or `ViewTensorRef`.
    /// Tuples of viewers such as `(ViewTensorMut, ViewTensorName)` are also used.
    type Viewer: TensorViewer;
    /// Error type returned by `visit` and `visit_scalar`.
    type Err;
    /// Dtype of the tensors returned by `visit`.
    type E2: Dtype;
    /// Device of the tensors returned by `visit`.
    type D2: Device<Self::E2>;

    // Required method
    /// Called for each tensor. Return `Ok(Some(_))` to supply a tensor for the
    /// output module, or `Ok(None)` if this visitor does not construct one.
    fn visit<S: Shape>(
        &mut self,
        opts: TensorOptions<S, E, D>,
        t: <Self::Viewer as TensorViewer>::View<'_, Tensor<S, E, D>>
    ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err>;

    // Provided method
    /// Called for each scalar value of numeric type `N`.
    ///
    /// NOTE(review): the default body is elided in this rendering (`{ ... }`);
    /// confirm its behavior in the crate source before relying on it.
    fn visit_scalar<N: NumCast>(
        &mut self,
        opts: ScalarOptions<N>,
        _h: <Self::Viewer as TensorViewer>::View<'_, N>
    ) -> Result<Option<N>, Self::Err> { ... }
}
Description

Something that can visit Tensors. Used in conjunction with RecursiveWalker.

Example implementation to add two Modules together:

// A TensorVisitor that adds two Modules together, returning the resulting module.
struct Adder;

impl<E: Dtype, D: Device<E>> TensorVisitor<E, D> for Adder {
    // View each pair of corresponding tensors through shared references.
    type Viewer = (ViewTensorRef, ViewTensorRef);
    type Err = D::Err;

    // Output tensors keep the input dtype and device.
    type E2 = E;
    type D2 = D;

    fn visit<S: Shape>(
        &mut self,
        // Unused here; the leading underscore silences the `unused_variables` lint.
        _opts: TensorOptions<S, E, D>,
        (a, b): (&Tensor<S, E, D>, &Tensor<S, E, D>),
    ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err> {
        // Return Ok(Some(_)) to construct an output module, or Ok(None) to skip
        // constructing an output. `?` propagates any `D::Err` from the addition.
        Ok(Some(a.clone().try_add(b.clone())?))
    }
}

type Model = Linear<2, 5>;
let lhs = dev.build_module::<Model, f32>();
let rhs = dev.build_module::<Model, f32>();

// Walk both models together, handing each pair of matching tensors to `Adder`.
let mut walker = RecursiveWalker {
    m: (&lhs, &rhs),
    f: &mut Adder,
};
let sum = TensorCollection::iter_tensors(&mut walker)
    .unwrap() // Result: the walk itself must succeed
    .unwrap(); // Option: `Adder` always returns Some, so an output module is built

// Every parameter of `sum` is the element-wise sum of the inputs' parameters.
let expected_weight = (lhs.weight.clone() + rhs.weight.clone()).array();
assert_eq!(expected_weight, sum.weight.array());
let expected_bias = (lhs.bias.clone() + rhs.bias.clone()).array();
assert_eq!(expected_bias, sum.bias.array());

Required Associated Types§

source

type Viewer: TensorViewer

How this visitor views each tensor, e.g. ViewTensorMut or ViewTensorRef. Tuples of viewers, such as (ViewTensorMut, ViewTensorName), are also allowed.

source

type Err

The error type returned by visit and visit_scalar.

source

type E2: Dtype

The dtype of the tensors returned by visit.

source

type D2: Device<Self::E2>

The device of the tensors returned by visit.

Required Methods§

source

fn visit<S: Shape>( &mut self, opts: TensorOptions<S, E, D>, t: <Self::Viewer as TensorViewer>::View<'_, Tensor<S, E, D>> ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err>

What to do when visiting each Tensor. Return Ok(Some(_)) to supply a tensor for the output module, or Ok(None) if this visitor should not construct an output module.

Provided Methods§

source

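fn visit_scalar<N: NumCast>( &mut self, opts: ScalarOptions<N>, _h: <Self::Viewer as TensorViewer>::View<'_, N> ) -> Result<Option<N>, Self::Err>

What to do when visiting each scalar value of numeric type N. A default implementation is provided.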

Implementations on Foreign Types§

source§

impl<E: Dtype, D: Device<E>, M> TensorVisitor<E, D> for (&mut Sgd<M, E, D>, &Gradients<E, D>, UnusedTensors)

§

type Viewer = ViewTensorMut

§

type Err = <D as HasErr>::Err

§

type E2 = E

§

type D2 = D

source§

fn visit<S: Shape>( &mut self, opts: TensorOptions<S, E, D>, p: &mut Tensor<S, E, D> ) -> Result<Option<Tensor<S, E, D>>, Self::Err>

source§

impl<'data, E: Dtype + SafeDtype, D: Device<E>> TensorVisitor<E, D> for SafeTensors<'data>

§

type Viewer = (ViewTensorMut, ViewTensorName)

§

type Err = Error

§

type E2 = E

§

type D2 = D

source§

fn visit<S: Shape>( &mut self, _: TensorOptions<S, E, D>, (t, full_path): (&mut Tensor<S, E, D>, String) ) -> Result<Option<Tensor<S, E, D>>, Self::Err>

source§

fn visit_scalar<N: NumCast>( &mut self, opts: ScalarOptions<N>, (n, full_path): (&mut N, String) ) -> Result<Option<N>, Self::Err>

source§

impl<R: Read + Seek, E: Dtype + NumpyDtype, D: Device<E>> TensorVisitor<E, D> for ZipArchive<R>

§

type Viewer = (ViewTensorMut, ViewTensorName)

§

type Err = NpzError

§

type E2 = E

§

type D2 = D

source§

fn visit<S: Shape>( &mut self, _: TensorOptions<S, E, D>, (t, full_path): (&mut Tensor<S, E, D>, String) ) -> Result<Option<Tensor<S, E, D>>, Self::Err>

source§

fn visit_scalar<N: NumCast>( &mut self, opts: ScalarOptions<N>, (n, full_path): (&mut N, String) ) -> Result<Option<N>, Self::Err>

source§

impl<W: Write + Seek, E: Dtype + NumpyDtype, D: Device<E>> TensorVisitor<E, D> for ZipWriter<W>

§

type Viewer = (ViewTensorRef, ViewTensorName)

§

type Err = ZipError

§

type E2 = E

§

type D2 = D

source§

fn visit<S: Shape>( &mut self, _: TensorOptions<S, E, D>, (t, full_path): (&Tensor<S, E, D>, String) ) -> Result<Option<Tensor<S, E, D>>, Self::Err>

source§

fn visit_scalar<N: NumCast>( &mut self, _opts: ScalarOptions<N>, (n, full_path): (&N, String) ) -> Result<Option<N>, Self::Err>

source§

impl<M, E: Dtype, D: Device<E>> TensorVisitor<E, D> for (&mut RMSprop<M, E, D>, &Gradients<E, D>, UnusedTensors)

§

type Viewer = ViewTensorMut

§

type Err = <D as HasErr>::Err

§

type E2 = E

§

type D2 = D

source§

fn visit<S: Shape>( &mut self, opts: TensorOptions<S, E, D>, p: &mut Tensor<S, E, D> ) -> Result<Option<Tensor<S, E, D>>, Self::Err>

source§

impl<M, D: Device<E>, E: Dtype> TensorVisitor<E, D> for (&mut Adam<M, E, D>, &Gradients<E, D>, UnusedTensors)

§

type Viewer = ViewTensorMut

§

type Err = <D as HasErr>::Err

§

type E2 = E

§

type D2 = D

source§

fn visit<S: Shape>( &mut self, opts: TensorOptions<S, E, D>, p: &mut Tensor<S, E, D> ) -> Result<Option<Tensor<S, E, D>>, Self::Err>

Implementors§