Trait dfdx::nn::tensor_collection::TensorCollection
source · pub trait TensorCollection<E: Dtype, D: Device<E>>: Sized {
// Required associated type
type To<E2: Dtype, D2: Device<E2>>;
// Required method
fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>;
// Provided methods
fn module<F1, F2, Field>(
name: &str,
get_ref: F1,
get_mut: F2
) -> ModuleField<'_, F1, F2, Self, Field>
where F1: FnMut(&Self) -> &Field,
F2: FnMut(&mut Self) -> &mut Field,
Field: TensorCollection<E, D> { ... }
fn tensor<F1, F2, S>(
name: &str,
get_ref: F1,
get_mut: F2,
options: TensorOptions<S, E, D>
) -> TensorField<'_, F1, F2, Self, S, E, D>
where F1: FnMut(&Self) -> &Tensor<S, E, D>,
F2: FnMut(&mut Self) -> &mut Tensor<S, E, D>,
S: Shape { ... }
fn scalar<F1, F2, N>(
name: &str,
get_ref: F1,
get_mut: F2,
options: ScalarOptions<N>
) -> ScalarField<'_, F1, F2, Self, N>
where F1: FnMut(&Self) -> &N,
F2: FnMut(&mut Self) -> &mut N,
N: NumCast { ... }
}
Expand description
A collection of named tensors. Implementing this trait will enable anything that operates on tensors, including resetting, counting number of params, updating gradients, building model, and converting models between devices or dtypes.
Example implementation:
use dfdx::nn::modules::Linear;
// Example model: two Linear layers plus a ReLU field.
struct Mlp<E: Dtype, D: Device<E>> {
l1: Linear<10, 10, E, D>,
l2: Linear<10, 2, E, D>,
relu: ReLU,
}
impl<E: Dtype + num_traits::Float + rand_distr::uniform::SampleUniform, D: Device<E>> TensorCollection<E, D> for Mlp<E, D> {
// Converting the model to another dtype/device yields the same struct with new parameters.
type To<E2: Dtype, D2: Device<E2>> = Mlp<E2, D2>;
fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
visitor.visit_fields(
(
// Define the name of each field and how to access it: use `Self::module` (which
// creates a ModuleField) for sub-modules, `Self::tensor` (a TensorField) for tensors,
// and `Self::scalar` (a ScalarField) for scalars.
// `relu` is not listed here, so the visitor never sees it (presumably because it
// holds no tensors).
Self::module("l1", |s| &s.l1, |s| &mut s.l1),
Self::module("l2", |s| &s.l2, |s| &mut s.l2),
),
// Define how to construct the collection given its fields in the order they are given
// above. This conversion is done using the ModuleFields trait.
// Fields that were not visited (here `relu`) must be created some other way,
// in this case with `Default::default()`.
|(l1, l2)| Mlp { l1, l2, relu: Default::default() }
)
}
}
// Build the model on the CPU and check how many trainable parameters it has.
let dev = Cpu::default();
let model = Mlp::<f32, Cpu>::build(&dev);
// Presumably (10*10 weights + 10 biases) for l1 + (10*2 weights + 2 biases) for l2 = 110 + 22.
assert_eq!(132, model.num_trainable_params());
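Required Associated Types§
type To<E2: Dtype, D2: Device<E2>>
The type of this collection after its tensors are converted to dtype E2 on device D2. For example, Mlp<E, D> uses Mlp<E2, D2>. iter_tensors returns a value of this type when it builds a module.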
Required Methods§
sourcefn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>
fn iter_tensors<V: ModuleVisitor<Self, E, D>>( visitor: &mut V ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>
Specifies how to iterate through the tensors or modules contained within this module, and how
to construct this module given values for its fields. Returns
Err(_)
to indicate an error,
Ok(None)
to indicate that there was no error but no module was built, and
Ok(Some(_))
containing the built module, of type Self::To<V::E2, V::D2>.
Provided Methods§
sourcefn module<F1, F2, Field>(
name: &str,
get_ref: F1,
get_mut: F2
) -> ModuleField<'_, F1, F2, Self, Field>where
F1: FnMut(&Self) -> &Field,
F2: FnMut(&mut Self) -> &mut Field,
Field: TensorCollection<E, D>,
fn module<F1, F2, Field>( name: &str, get_ref: F1, get_mut: F2 ) -> ModuleField<'_, F1, F2, Self, Field>where F1: FnMut(&Self) -> &Field, F2: FnMut(&mut Self) -> &mut Field, Field: TensorCollection<E, D>,
Creates a ModuleField (usable as ModuleFields) that represents a sub-module field that may contain one or more tensors.
See also: ModuleField, TensorCollection.
sourcefn tensor<F1, F2, S>(
name: &str,
get_ref: F1,
get_mut: F2,
options: TensorOptions<S, E, D>
) -> TensorField<'_, F1, F2, Self, S, E, D>where
F1: FnMut(&Self) -> &Tensor<S, E, D>,
F2: FnMut(&mut Self) -> &mut Tensor<S, E, D>,
S: Shape,
fn tensor<F1, F2, S>( name: &str, get_ref: F1, get_mut: F2, options: TensorOptions<S, E, D> ) -> TensorField<'_, F1, F2, Self, S, E, D>where F1: FnMut(&Self) -> &Tensor<S, E, D>, F2: FnMut(&mut Self) -> &mut Tensor<S, E, D>, S: Shape,
Creates a TensorField (usable as ModuleFields) that represents a tensor field.
See also: TensorField, TensorCollection, TensorOptions.
sourcefn scalar<F1, F2, N>(
name: &str,
get_ref: F1,
get_mut: F2,
options: ScalarOptions<N>
) -> ScalarField<'_, F1, F2, Self, N>where
F1: FnMut(&Self) -> &N,
F2: FnMut(&mut Self) -> &mut N,
N: NumCast,
fn scalar<F1, F2, N>( name: &str, get_ref: F1, get_mut: F2, options: ScalarOptions<N> ) -> ScalarField<'_, F1, F2, Self, N>where F1: FnMut(&Self) -> &N, F2: FnMut(&mut Self) -> &mut N, N: NumCast,
Creates a ScalarField (usable as ModuleFields) that represents a scalar field.
See also: ScalarField, TensorCollection, ScalarOptions.