// Visitor building blocks for walking the tensors of a module: the recursive walker, the
// per-tensor visitor trait, tensor viewers, and the descriptors used to list a module's fields.
use num_traits::NumCast;
use crate::{
shapes::{Dtype, Shape},
tensor::Tensor,
tensor_ops::Device,
};
use super::{ModuleVisitor, ScalarOptions, TensorCollection, TensorOptions};
/// A standard [ModuleVisitor] that executes `F` on every [Tensor] encountered.
/// `F` must implement [TensorVisitor].
///
/// This declaration only bundles the two pieces of state below; the traversal logic is
/// provided by trait implementations that are not visible in this part of the file.
#[derive(Debug)]
pub struct RecursiveWalker<'a, M, F> {
/// The module view being walked. The [TensorVisitor] example passes a tuple of module
/// references, `(&model1, &model2)`.
pub m: M,
/// The visitor applied to the tensors, borrowed mutably for the lifetime `'a`.
pub f: &'a mut F,
}
/// Something that can visit [Tensor]s. Used in conjunction with [RecursiveWalker].
///
/// Example implementation to add two Modules together:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// // A TensorVisitor that will add two Modules together, returning the resulting module.
/// struct Adder;
///
/// impl<E: Dtype, D: Device<E>> TensorVisitor<E, D> for Adder {
///     // Take a tuple of references to tensors
///     type Viewer = (ViewTensorRef, ViewTensorRef);
///     type Err = D::Err;
///
///     // Output with the device and dtype that are given
///     type E2 = E;
///     type D2 = D;
///
///     fn visit<S: Shape>(
///         &mut self,
///         opts: TensorOptions<S, E, D>,
///         (a, b): (&Tensor<S, E, D>, &Tensor<S, E, D>),
///     ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err> {
///         // Returns Ok(Some(_)) to construct an output module. Return Ok(None) to not construct
///         // an output
///         Ok(Some(a.clone().try_add(b.clone())?))
///     }
/// }
///
/// type Model = Linear<2, 5>;
/// let model1 = dev.build_module::<Model, f32>();
/// let model2 = dev.build_module::<Model, f32>();
/// let model3 = TensorCollection::iter_tensors(&mut RecursiveWalker {
///     m: (&model1, &model2),
///     f: &mut Adder,
/// }).unwrap().unwrap();
///
/// assert_eq!(
///     (model1.weight.clone() + model2.weight.clone()).array(),
///     model3.weight.array()
/// );
/// assert_eq!(
///     (model1.bias.clone() + model2.bias.clone()).array(),
///     model3.bias.array()
/// );
/// ```
pub trait TensorVisitor<E: Dtype, D: Device<E>> {
/// The type of tensor this struct uses. E.g. [ViewTensorMut], or [ViewTensorRef]
type Viewer: TensorViewer;
/// The error type returned by [TensorVisitor::visit] and [TensorVisitor::visit_scalar].
type Err;
/// The dtype to output with
type E2: Dtype;
/// The device to output with
type D2: Device<Self::E2>;
/// What to do when visiting each Tensor. Return `Ok(None)` if this visitor should not
/// construct a new module each time it is used, and `Ok(Some(_))` if it should.
///
/// * `opts` - the [TensorOptions] supplied for the tensor being visited.
/// * `t` - the tensor as seen through [TensorVisitor::Viewer]; in the example above, with a
///   tuple of [ViewTensorRef], this is a tuple of `&Tensor`.
// Silences clippy's `type_complexity` lint for this signature.
#[allow(clippy::type_complexity)]
fn visit<S: Shape>(
&mut self,
opts: TensorOptions<S, E, D>,
t: <Self::Viewer as TensorViewer>::View<'_, Tensor<S, E, D>>,
) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err>;
/// What to do when visiting a scalar field.
///
/// The default implementation ignores the viewed value `_h` and returns
/// `Ok(Some(opts.default))`, i.e. the default carried by the [ScalarOptions].
/// Implementors may override it.
fn visit_scalar<N: NumCast>(
&mut self,
opts: ScalarOptions<N>,
_h: <Self::Viewer as TensorViewer>::View<'_, N>,
) -> Result<Option<N>, Self::Err> {
Ok(Some(opts.default))
}
}
/// Something that can view [Tensor]s in different ways. For example
/// [ViewTensorRef] can view `&Tensor`, and [ViewTensorMut] can view `&mut Tensor`.
pub trait TensorViewer: 'static {
/// The view of a value of type `Mod` that is valid for the lifetime `'a`.
/// Which concrete type this is depends on the implementor; the implementations are
/// not visible in this part of the file.
type View<'a, Mod: 'a>
where
Self: 'a;
/// Given a view of a module, returns a view of one of that module's fields.
///
/// * `module` - the view of the parent module, borrowed mutably for `'a`.
/// * `name` - the name of the field; how it is used is up to the implementor.
/// * `get_ref` - projects `&Mod` to a shared reference to the field.
/// * `get_mut` - projects `&mut Mod` to a mutable reference to the field.
///
/// The returned view carries the lifetime `'a` of the borrow of `module`.
fn view_field<'a, Mod, Field, GetRef, GetMut>(
module: &'a mut Self::View<'_, Mod>,
name: &str,
get_ref: &mut GetRef,
get_mut: &mut GetMut,
) -> Self::View<'a, Field>
where
GetRef: FnMut(&Mod) -> &Field,
GetMut: FnMut(&mut Mod) -> &mut Field;
}
/// A list of a Module's fields. Used in [ModuleVisitor::visit_fields].
pub trait ModuleFields<M: TensorCollection<E, D>, E: Dtype, D: Device<E>> {
/// A list of optional instances of each field, for the output dtype `E2` and device `D2`.
type Options<E2: Dtype, D2: Device<E2>>;
/// A list of instances of each field, for the output dtype `E2` and device `D2`.
type Output<E2: Dtype, D2: Device<E2>>;
/// Calls [ModuleVisitor::visit_module] or [ModuleVisitor::visit_tensor] for each field,
/// and returns optionally constructed fields.
///
/// Consumes `self`. On success the result is `Self::Options` instantiated with the
/// visitor's output types `V::E2` and `V::D2`; on failure it is the visitor's `V::Err`.
fn visit_fields<V: ModuleVisitor<M, E, D>>(
self,
visitor: &mut V,
) -> Result<Self::Options<V::E2, V::D2>, V::Err>;
/// If any optional fields are None, returns None. Otherwise returns instances of all fields.
fn handle_options<E2: Dtype, D2: Device<E2>>(
options: Self::Options<E2, D2>,
) -> Option<Self::Output<E2, D2>>;
}
/// A [ModuleFields] that represents a field that contains one or more Tensors.
///
/// `Mod` is the type that owns the field and `Field` is the type of the field itself.
pub struct ModuleField<'a, F1, F2, Mod, Field>
where
F1: FnMut(&Mod) -> &Field,
F2: FnMut(&mut Mod) -> &mut Field,
{
/// The name of the field.
pub(super) name: &'a str,
/// Accessor that projects `&Mod` to a shared reference to the field.
pub(super) get_ref: F1,
/// Accessor that projects `&mut Mod` to a mutable reference to the field.
pub(super) get_mut: F2,
/// Zero-sized marker so that the `Mod` type parameter is used by the struct.
pub(super) m: std::marker::PhantomData<Mod>,
/// Zero-sized marker so that the `Field` type parameter is used by the struct.
pub(super) f: std::marker::PhantomData<Field>,
}
/// A [ModuleFields] entry describing a field that holds exactly one [Tensor].
///
/// `Mod` is the type that owns the field; `S`, `E` and `D` are the shape, dtype and
/// device of the tensor stored in it.
pub struct TensorField<'a, F1, F2, Mod, S, E, D>
where
    S: Shape,
    E: Dtype,
    D: Device<E>,
    F1: FnMut(&Mod) -> &Tensor<S, E, D>,
    F2: FnMut(&mut Mod) -> &mut Tensor<S, E, D>,
{
    /// The name of the field.
    pub(super) name: &'a str,
    /// Accessor that projects `&Mod` to a shared reference to the tensor.
    pub(super) get_ref: F1,
    /// Accessor that projects `&mut Mod` to a mutable reference to the tensor.
    pub(super) get_mut: F2,
    /// The [TensorOptions] associated with this tensor.
    pub(super) options: TensorOptions<S, E, D>,
    /// Zero-sized marker so that the `Mod` type parameter is used by the struct.
    pub(super) m: std::marker::PhantomData<Mod>,
}
/// A [ModuleFields] that represents a field that contains a scalar value that should be serialized.
///
/// `Mod` is the type that owns the field and `N` is the numeric type of the scalar.
pub struct ScalarField<'a, F1, F2, Mod, N>
where
N: NumCast,
F1: FnMut(&Mod) -> &N,
F2: FnMut(&mut Mod) -> &mut N,
{
/// The name of the field.
pub(super) name: &'a str,
/// Accessor that projects `&Mod` to a shared reference to the scalar.
pub(super) get_ref: F1,
/// Accessor that projects `&mut Mod` to a mutable reference to the scalar.
pub(super) get_mut: F2,
/// The [ScalarOptions] associated with this scalar; the default
/// `TensorVisitor::visit_scalar` returns its `default` value.
pub(super) options: ScalarOptions<N>,
/// Zero-sized marker so that the `Mod` type parameter is used by the struct.
pub(super) m: std::marker::PhantomData<Mod>,
}
/// A [TensorViewer] that represents a `&Tensor`.
///
/// Declared without variants, so no value of this type can be constructed; it is named
/// purely at the type level, e.g. as the `Viewer` in the [TensorVisitor] example.
#[derive(Debug)]
pub enum ViewTensorRef {}
/// A [TensorViewer] that represents a `&mut Tensor`.
///
/// Declared without variants, so no value of this type can be constructed; it is named
/// purely at the type level, e.g. as a [TensorVisitor]'s `Viewer`.
#[derive(Debug)]
pub enum ViewTensorMut {}
/// A [TensorViewer] that represents a Tensor's name as a `String`.
///
/// Declared without variants, so no value of this type can be constructed; it is named
/// purely at the type level. Its [TensorViewer] implementation is not visible in this
/// part of the file.
#[derive(Debug)]
pub enum ViewTensorName {}