1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
use num_traits::NumCast;

use crate::{
    shapes::{Dtype, Shape},
    tensor::Tensor,
    tensor_ops::Device,
};

use super::{ModuleVisitor, ScalarOptions, TensorCollection, TensorOptions};

/// The stock [ModuleVisitor]: walks a module and runs the visitor `F` on each [Tensor]
/// it encounters.
///
/// `m` is the thing being walked — a module, or a view of one such as a tuple of module
/// references — and `f` is the visitor to run. `F` must implement [TensorVisitor].
#[derive(Debug)]
pub struct RecursiveWalker<'v, M, F> {
    /// The module (or view of modules) being walked.
    pub m: M,
    /// The visitor applied to every tensor reached while walking `m`.
    pub f: &'v mut F,
}

/// Something that can visit [Tensor]s. Used in conjunction with [RecursiveWalker].
///
/// [TensorVisitor::visit] receives each tensor's [TensorOptions] together with a view of the
/// tensor, whose form is chosen by [TensorVisitor::Viewer]. [TensorVisitor::visit_scalar] is the
/// corresponding hook for scalar (`N: NumCast`) values.
///
/// Example implementation to add two Modules together:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// // A TensorVisitor that will add two Modules together, returning the resulting module.
/// struct Adder;
///
/// impl<E: Dtype, D: Device<E>> TensorVisitor<E, D> for Adder {
///     // Take a tuple of references to tensors
///     type Viewer = (ViewTensorRef, ViewTensorRef);
///     type Err = D::Err;
///
///     // Output with the device and dtype that are given
///     type E2 = E;
///     type D2 = D;
///
///     fn visit<S: Shape>(
///         &mut self,
///         _opts: TensorOptions<S, E, D>,
///         (a, b): (&Tensor<S, E, D>, &Tensor<S, E, D>),
///     ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err> {
///         // Returns Ok(Some(_)) to construct an output module. Return Ok(None) to not construct
///         // an output
///         Ok(Some(a.clone().try_add(b.clone())?))
///     }
/// }
///
/// type Model = Linear<2, 5>;
/// let model1 = dev.build_module::<Model, f32>();
/// let model2 = dev.build_module::<Model, f32>();
/// let model3 = TensorCollection::iter_tensors(&mut RecursiveWalker {
///     m: (&model1, &model2),
///     f: &mut Adder,
/// }).unwrap().unwrap();
///
/// assert_eq!(
///     (model1.weight.clone() + model2.weight.clone()).array(),
///     model3.weight.array()
/// );
/// assert_eq!(
///     (model1.bias.clone() + model2.bias.clone()).array(),
///     model3.bias.array()
/// );
/// ```
pub trait TensorVisitor<E: Dtype, D: Device<E>> {
    /// How tensors are viewed while visiting, e.g. [ViewTensorRef] for `&Tensor` or
    /// [ViewTensorMut] for `&mut Tensor`.
    type Viewer: TensorViewer;
    /// The error type returned by [TensorVisitor::visit] and [TensorVisitor::visit_scalar].
    type Err;
    /// The dtype of tensors returned by [TensorVisitor::visit].
    type E2: Dtype;
    /// The device of tensors returned by [TensorVisitor::visit].
    type D2: Device<Self::E2>;

    /// What to do when visiting each Tensor. Return `Ok(None)` if this visitor should not
    /// construct a new module each time it is used, and `Ok(Some(_))` if it should.
    // The nested `Result<Option<Tensor<..>>>` return type is part of the public trait
    // signature, so clippy's type_complexity lint is allowed rather than introducing an alias.
    #[allow(clippy::type_complexity)]
    fn visit<S: Shape>(
        &mut self,
        opts: TensorOptions<S, E, D>,
        t: <Self::Viewer as TensorViewer>::View<'_, Tensor<S, E, D>>,
    ) -> Result<Option<Tensor<S, Self::E2, Self::D2>>, Self::Err>;

    /// What to do when visiting each scalar value. Returns an optional scalar in the same way
    /// [TensorVisitor::visit] returns an optional tensor.
    ///
    /// The default implementation ignores the viewed value (`_h`) and always returns
    /// `Ok(Some(opts.default))`. Visitors that need the current value must override this method.
    fn visit_scalar<N: NumCast>(
        &mut self,
        opts: ScalarOptions<N>,
        _h: <Self::Viewer as TensorViewer>::View<'_, N>,
    ) -> Result<Option<N>, Self::Err> {
        Ok(Some(opts.default))
    }
}

/// Something that can view [Tensor]s in different ways. For example,
/// [ViewTensorRef] views a `&Tensor`, and [ViewTensorMut] views a `&mut Tensor`.
///
/// An implementor decides what a "view" of a value looks like ([TensorViewer::View]) and how a
/// view of a module is narrowed down to a view of one of its fields
/// ([TensorViewer::view_field]).
pub trait TensorViewer: 'static {
    /// The form a view of a value of type `Mod` takes, valid for the lifetime `'a`.
    type View<'a, Mod: 'a>
    where
        Self: 'a;

    /// Given a view of a module, returns a view of one of that module's fields.
    ///
    /// - `module`: the current view of the parent module.
    /// - `name`: the name of the field being viewed.
    /// - `get_ref`: projects a shared reference to the module onto the field.
    /// - `get_mut`: projects a mutable reference to the module onto the field.
    ///
    /// Both accessors are always supplied; which one (if either) is used is up to the
    /// implementor.
    fn view_field<'a, Mod, Field, GetRef, GetMut>(
        module: &'a mut Self::View<'_, Mod>,
        name: &str,
        get_ref: &mut GetRef,
        get_mut: &mut GetMut,
    ) -> Self::View<'a, Field>
    where
        GetRef: FnMut(&Mod) -> &Field,
        GetMut: FnMut(&mut Mod) -> &mut Field;
}

/// A list of a Module's fields. Used in [ModuleVisitor::visit_fields].
///
/// [ModuleFields::visit_fields] produces an [ModuleFields::Options] list, where each entry may be
/// absent. [ModuleFields::handle_options] turns that into an [ModuleFields::Output] list, where
/// every entry is present, or returns `None`.
pub trait ModuleFields<M: TensorCollection<E, D>, E: Dtype, D: Device<E>> {
    /// A list of optional instances of each field, with output dtype `E2` and device `D2`.
    type Options<E2: Dtype, D2: Device<E2>>;

    /// A list of instances of each field, with output dtype `E2` and device `D2`.
    type Output<E2: Dtype, D2: Device<E2>>;

    /// Calls [ModuleVisitor::visit_module] or [ModuleVisitor::visit_tensor] for each field,
    /// and returns optionally constructed fields.
    ///
    /// The returned fields use the visitor's output dtype (`V::E2`) and device (`V::D2`).
    /// Errors are reported as `V::Err`.
    fn visit_fields<V: ModuleVisitor<M, E, D>>(
        self,
        visitor: &mut V,
    ) -> Result<Self::Options<V::E2, V::D2>, V::Err>;

    /// If any optional fields are None, returns None. Otherwise returns instances of all fields.
    fn handle_options<E2: Dtype, D2: Device<E2>>(
        options: Self::Options<E2, D2>,
    ) -> Option<Self::Output<E2, D2>>;
}

/// A [ModuleFields] entry for a field that holds one or more Tensors.
///
/// It records the field's `name` and a pair of accessors that project the parent module
/// `Mod` onto the field `Field`: `get_ref` for shared access, `get_mut` for mutable access.
pub struct ModuleField<'a, F1, F2, Mod, Field>
where
    F1: FnMut(&Mod) -> &Field,
    F2: FnMut(&mut Mod) -> &mut Field,
{
    /// The field's name.
    pub(super) name: &'a str,
    /// Projects `&Mod` onto `&Field`.
    pub(super) get_ref: F1,
    /// Projects `&mut Mod` onto `&mut Field`.
    pub(super) get_mut: F2,
    /// Marks the parent module type; no `Mod` value is stored.
    pub(super) m: core::marker::PhantomData<Mod>,
    /// Marks the field type; no `Field` value is stored.
    pub(super) f: core::marker::PhantomData<Field>,
}

/// A [ModuleFields] entry for a field that holds exactly one [Tensor].
///
/// Alongside the field's `name` and the accessors that reach the tensor from the parent
/// module (`get_ref` / `get_mut`), it carries the [TensorOptions] for that tensor.
pub struct TensorField<'a, F1, F2, Mod, S: Shape, E: Dtype, D: Device<E>>
where
    F1: FnMut(&Mod) -> &Tensor<S, E, D>,
    F2: FnMut(&mut Mod) -> &mut Tensor<S, E, D>,
{
    /// The field's name.
    pub(super) name: &'a str,
    /// Borrows the tensor from a shared reference to the module.
    pub(super) get_ref: F1,
    /// Borrows the tensor from a mutable reference to the module.
    pub(super) get_mut: F2,
    /// Options attached to this tensor.
    pub(super) options: TensorOptions<S, E, D>,
    /// Marks the parent module type; no `Mod` value is stored.
    pub(super) m: core::marker::PhantomData<Mod>,
}

/// A [ModuleFields] entry for a field holding a single scalar value that should be serialized.
///
/// It records the field's `name`, the accessors that reach the scalar from the parent
/// module (`get_ref` / `get_mut`), and the [ScalarOptions] for that scalar.
pub struct ScalarField<'a, F1, F2, Mod, N>
where
    N: NumCast,
    F1: FnMut(&Mod) -> &N,
    F2: FnMut(&mut Mod) -> &mut N,
{
    /// The field's name.
    pub(super) name: &'a str,
    /// Borrows the scalar from a shared reference to the module.
    pub(super) get_ref: F1,
    /// Borrows the scalar from a mutable reference to the module.
    pub(super) get_mut: F2,
    /// Options attached to this scalar.
    pub(super) options: ScalarOptions<N>,
    /// Marks the parent module type; no `Mod` value is stored.
    pub(super) m: core::marker::PhantomData<Mod>,
}

/// [TensorViewer] marker for viewing tensors through a shared reference, `&Tensor`.
///
/// Uninhabited: it is only ever used as a type parameter, never as a value.
#[derive(Debug)]
pub enum ViewTensorRef {}

/// [TensorViewer] marker for viewing tensors through a mutable reference, `&mut Tensor`.
///
/// Uninhabited: it is only ever used as a type parameter, never as a value.
#[derive(Debug)]
pub enum ViewTensorMut {}

/// [TensorViewer] marker whose view of a Tensor is the tensor's name, as a `String`.
///
/// Uninhabited: it is only ever used as a type parameter, never as a value.
#[derive(Debug)]
pub enum ViewTensorName {}