1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231
#![allow(clippy::type_complexity)]
use crate::{
shapes::{ConstShape, Dtype, Shape},
tensor::{OneFillStorage, Tensor, ZeroFillStorage},
tensor_ops::Device,
};
use super::{ModuleField, ModuleFields, TensorField};
/// A collection of named tensors. Implementing this trait will enable anything
/// that operates on tensors, including resetting, counting number of params, updating gradients,
/// building model, and converting models between devices or dtypes.
///
/// Example implementation:
/// ```rust
/// # use dfdx::prelude::*;
/// use dfdx::nn::modules::Linear;
///
/// struct Mlp<E: Dtype, D: Device<E>> {
/// l1: Linear<10, 10, E, D>,
/// l2: Linear<10, 2, E, D>,
/// relu: ReLU,
/// }
///
/// impl<E: Dtype, D: Device<E>> TensorCollection<E, D> for Mlp<E, D> {
/// type To<E2: Dtype, D2: Device<E2>> = Mlp<E2, D2>;
///
/// fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
/// visitor: &mut V,
/// ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
/// visitor.visit_fields(
/// (
/// // Define name of each field and how to access it, using ModuleField for Modules,
/// // and TensorField for Tensors.
/// Self::module("l1", |s| &s.l1, |s| &mut s.l1),
/// Self::module("l2", |s| &s.l2, |s| &mut s.l2),
/// ),
/// // Define how to construct the collection given its fields in the order they are given
/// // above. This conversion is done using the ModuleFields trait.
/// |(l1, l2)| Mlp { l1, l2, relu: Default::default() }
/// )
/// }
/// }
///
/// let dev = Cpu::default();
/// let model = Mlp::<f32, Cpu>::build(&dev);
/// assert_eq!(132, model.num_trainable_params());
///
/// ```
pub trait TensorCollection<E: Dtype, D: Device<E>>: Sized {
/// Type alias that specifies the how a module's type changes when using a different dtype and/or
/// device.
type To<E2: Dtype, D2: Device<E2>>;
/// Specifies how to iterate through tensors or modules containted within this module, and how
/// to contruct this module given values for its fields. Returns `Err(_)` to indicate an error,
/// `Ok(None)` to indicate that there is no error and a module has not been built, and
/// `Ok(Some(_))` contains `Self::Output<E2, D2>`
fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>;
/// Creates a [ModuleFields] that represents a field that may contain one or more tensors.
///
/// See also: [ModuleField], [TensorCollection].
fn module<F1, F2, Field>(
name: &str,
get_ref: F1,
get_mut: F2,
) -> ModuleField<F1, F2, Self, Field>
where
F1: FnMut(&Self) -> &Field,
F2: FnMut(&mut Self) -> &mut Field,
Field: TensorCollection<E, D>,
{
ModuleField {
name,
get_ref,
get_mut,
m: Default::default(),
f: Default::default(),
}
}
/// Creates a [ModuleFields] that represents a tensor field.
///
/// See also: [TensorField], [TensorCollection], [TensorOptions].
fn tensor<F1, F2, S>(
name: &str,
get_ref: F1,
get_mut: F2,
options: TensorOptions<S, E, D>,
) -> TensorField<F1, F2, Self, S, E, D>
where
F1: FnMut(&Self) -> &Tensor<S, E, D>,
F2: FnMut(&mut Self) -> &mut Tensor<S, E, D>,
S: Shape,
{
TensorField {
name,
get_ref,
get_mut,
options,
m: Default::default(),
}
}
}
/// An object that can visit [TensorCollection]s and [Tensor]s recursively.
pub trait ModuleVisitor<T: TensorCollection<E, D>, E: Dtype, D: Device<E>>: Sized {
type Err;
type E2: Dtype;
type D2: Device<Self::E2>;
/// Visit a [TensorCollection]. Do not use this; use visit_fields instead.
fn visit_module<Field, GetRef, GetMut>(
&mut self,
name: &str,
get_refs: GetRef,
get_muts: GetMut,
) -> Result<Option<Field::To<Self::E2, Self::D2>>, Self::Err>
where
GetRef: FnMut(&T) -> &Field,
GetMut: FnMut(&mut T) -> &mut Field,
Field: TensorCollection<E, D>;
/// Visits an actual named [Tensor]. Do not use this; use visit_fields instead.
fn visit_tensor<S: Shape, GetRef, GetMut>(
&mut self,
name: &str,
get_refs: GetRef,
get_muts: GetMut,
opts: TensorOptions<S, E, D>,
) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err>
where
GetRef: FnMut(&T) -> &Tensor<S, E, D>,
GetMut: FnMut(&mut T) -> &mut Tensor<S, E, D>;
/// Takes something that implements [ModuleFields] and function that takes
/// [ModuleFields::Output] and returns an instance of T.
fn visit_fields<M: ModuleFields<T, E, D>>(
&mut self,
fields: M,
builder: impl FnOnce(M::Output<Self::E2, Self::D2>) -> T::To<Self::E2, Self::D2>,
) -> Result<Option<T::To<Self::E2, Self::D2>>, Self::Err>;
}
impl<S: ConstShape, E: Dtype, D: Device<E>> TensorCollection<E, D> for Tensor<S, E, D> {
type To<E2: Dtype, D2: Device<E2>> = Tensor<S, E2, D2>;
fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
visitor.visit_tensor(
"",
|s| s,
|s| s,
TensorOptions {
do_gradient_update: true,
reset: |_| Ok(()),
shape: Default::default(),
},
)
}
}
/// Options to change behavior of [ModuleVisitor]
#[non_exhaustive]
pub struct TensorOptions<S: Shape, E: Dtype, D: Device<E>> {
/// Whether the tensor should be updated with gradients
pub do_gradient_update: bool,
/// How to reset the tensor in the future.
pub reset: fn(&'_ mut Tensor<S, E, D>) -> Result<(), D::Err>,
/// The [Shape] that BuildModule uses to construct the tensor
pub shape: S,
}
impl<S: Shape, E: Dtype, D: Device<E>> TensorOptions<S, E, D> {
/// A tensor that should be updated with gradients & reset to 0
pub fn reset_to_zeros() -> Self
where
D: ZeroFillStorage<E>,
S: ConstShape,
{
TensorOptions {
do_gradient_update: true,
reset: |t| t.try_fill_with_zeros(),
shape: S::default(),
}
}
/// A tensor that should be updated with gradients & reset to 1
pub fn reset_to_ones() -> Self
where
D: OneFillStorage<E>,
S: ConstShape,
{
TensorOptions {
do_gradient_update: true,
reset: |t| t.try_fill_with_ones(),
shape: S::default(),
}
}
/// A tensor that should be updated with gradients & reset with the fn passed in
pub fn reset_with(reset: fn(&mut Tensor<S, E, D>) -> Result<(), D::Err>) -> Self
where
S: ConstShape,
{
TensorOptions {
do_gradient_update: true,
reset,
shape: S::default(),
}
}
/// A tensor that should **NOT** be updated with gradients & reset with the fn passed in
pub fn detached(reset: fn(&mut Tensor<S, E, D>) -> Result<(), D::Err>) -> Self
where
S: ConstShape,
{
TensorOptions {
do_gradient_update: false,
reset,
shape: S::default(),
}
}
}