// autodiff/autotuple.rs

1use crate::traits::{InstOne, InstZero};
2use num::complex::Complex;
3use num::traits::{Num, NumOps, One, Pow, Signed, Zero};
4use paste::paste;
5use std::ops::{Add, Deref, Div, Mul, Neg, Rem, Sub};
6use crate::gradienttype::GradientType;
7use crate::forward::ForwardMul;
8
/// Newtype wrapper around a tuple value.
///
/// The impls in this module give the wrapper element-wise arithmetic,
/// `Zero`/`One`/`Signed`, `Default`, `GradientType` and `ForwardMul` for tuple
/// payloads of 1 to 16 elements. The wrapped tuple is public and can also be
/// reached through `Deref`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct AutoTuple<Tuple>(pub Tuple)
where
    Tuple: Clone + PartialEq;
13
14impl<Tuple> AutoTuple<Tuple>
15where
16    Tuple: Clone + PartialEq,
17{
18    pub fn new(tuple: Tuple) -> Self {
19        Self(tuple)
20    }
21}
22
23impl<Tuple> Deref for AutoTuple<Tuple>
24where
25    Tuple: Clone + PartialEq,
26{
27    type Target = Tuple;
28
29    fn deref(&self) -> &Self::Target {
30        &self.0
31    }
32}
33
34// impl Num to always return an error
35impl<Tuple> Num for AutoTuple<Tuple>
36where
37    Tuple: Clone + PartialEq,
38    AutoTuple<Tuple>: NumOps + One + Zero + PartialEq,
39{
40    type FromStrRadixErr = ();
41
42    fn from_str_radix(_str: &str, _radix: u32) -> Result<Self, Self::FromStrRadixErr> {
43        Err(())
44    }
45}
46
47// impl From for autotuple
48macro_rules! autotuple_from {
49    ($($idx:literal),+) =>
50    {
51        paste! {
52            impl<$([<T $idx>],)+> From<($([<T $idx>],)+)> for AutoTuple<($([<T $idx>],)+)>
53            where
54                ($([<T $idx>],)+): Clone + PartialEq,
55            {
56                fn from(tuple: ($([<T $idx>],)+)) -> Self {
57                    AutoTuple::new(tuple)
58                }
59            }
60        }
61    }
62}
63
64// implement From for autotuples up to length 16
65autotuple_from!(0);
66autotuple_from!(0, 1);
67autotuple_from!(0, 1, 2);
68autotuple_from!(0, 1, 2, 3);
69autotuple_from!(0, 1, 2, 3, 4);
70autotuple_from!(0, 1, 2, 3, 4, 5);
71autotuple_from!(0, 1, 2, 3, 4, 5, 6);
72autotuple_from!(0, 1, 2, 3, 4, 5, 6, 7);
73autotuple_from!(0, 1, 2, 3, 4, 5, 6, 7, 8);
74autotuple_from!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
75autotuple_from!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
76autotuple_from!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
77autotuple_from!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
78autotuple_from!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13);
79autotuple_from!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
80autotuple_from!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
81
82// macro for implementing AutoTuple From<T> for T, used for constants
83// so AutoTuple::From(f32) -> AutoTuple<(f32,)>
84// i.e. AutoTuple::From(f32) = AutoTuple::From((f32,))
85
86macro_rules! autotuple_from_primitive {
87    ($($type:ty),+) =>
88    {
89        $(
90            impl From<$type> for AutoTuple<($type,)> {
91                fn from(t: $type) -> Self {
92                    AutoTuple::new((t,))
93                }
94            }
95        )+
96    }
97}
98
99autotuple_from_primitive!(
100    f32,
101    f64,
102    i8,
103    i16,
104    i32,
105    i64,
106    i128,
107    isize,
108    u8,
109    u16,
110    u32,
111    u64,
112    u128,
113    usize,
114    Complex<f32>,
115    Complex<f64>
116);
117
// Generates element-wise binary operator impls between an `AutoTuple` and a
// right-hand side with the same number of elements.
//
// Arguments: the operator trait, its method name, then the tuple field indices.
// Example: `autotuple_binary_op!(Add, add, 0, 1, 2);` implements
//   * `Add<AutoTuple<(U0, U1, U2)>> for AutoTuple<(T0, T1, T2)>` and
//   * `Add<(U0, U1, U2)>            for AutoTuple<(T0, T1, T2)>`
// where every `Ti: Add<Ui, Output = Ti>`. `paste!` is used to mint the
// `T0, T1, ...` / `U0, U1, ...` type parameter names from the indices.
macro_rules! autotuple_binary_op {
    ($trt:ident, $mth:ident, $($idx:literal),+) => {
        paste! {
            // AutoTuple op AutoTuple
            impl<$([<T $idx>],)+ $([<U $idx>],)+> $trt<AutoTuple<($([<U $idx>],)+)>>
                for AutoTuple<($([<T $idx>],)+)>
            where
                $([<T $idx>]: $trt<[<U $idx>], Output = [<T $idx>]>,)+
                ($([<T $idx>],)+): Clone + PartialEq,
                ($([<U $idx>],)+): Clone + PartialEq,
            {
                type Output = AutoTuple<($([<T $idx>],)+)>;

                fn $mth(self, rhs: AutoTuple<($([<U $idx>],)+)>) -> Self::Output {
                    AutoTuple::new(($(self.0.$idx.$mth(rhs.0.$idx),)+))
                }
            }

            // AutoTuple op (U0, U1, ...)
            //
            // `rhs` is an owned plain tuple, so each element is moved out of it
            // directly; no `Clone`/`PartialEq` bound is needed on the `U`s.
            impl<$([<T $idx>],)+ $([<U $idx>],)+> $trt<($([<U $idx>],)+)>
                for AutoTuple<($([<T $idx>],)+)>
            where
                $([<T $idx>]: $trt<[<U $idx>], Output = [<T $idx>]>,)+
                ($([<T $idx>],)+): Clone + PartialEq,
            {
                type Output = AutoTuple<($([<T $idx>],)+)>;

                fn $mth(self, rhs: ($([<U $idx>],)+)) -> Self::Output {
                    AutoTuple::new(($(self.0.$idx.$mth(rhs.$idx),)+))
                }
            }
        }
    };
}
157
158// macro for implementing all binary ops between autotuples with
159// specified length
160macro_rules! autotuple_binary_ops {
161    ($($idx:literal),+) =>
162    {
163        autotuple_binary_op!(Add, add, $($idx),+);
164        autotuple_binary_op!(Sub, sub, $($idx),+);
165        autotuple_binary_op!(Mul, mul, $($idx),+);
166        autotuple_binary_op!(Div, div, $($idx),+);
167        autotuple_binary_op!(Rem, rem, $($idx),+);
168        autotuple_binary_op!(Pow, pow, $($idx),+);
169    }
170}
171
172// implement all binary ops for autotuples up to length 16
173autotuple_binary_ops!(0);
174autotuple_binary_ops!(0, 1);
175autotuple_binary_ops!(0, 1, 2);
176autotuple_binary_ops!(0, 1, 2, 3);
177autotuple_binary_ops!(0, 1, 2, 3, 4);
178autotuple_binary_ops!(0, 1, 2, 3, 4, 5);
179autotuple_binary_ops!(0, 1, 2, 3, 4, 5, 6);
180autotuple_binary_ops!(0, 1, 2, 3, 4, 5, 6, 7);
181autotuple_binary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8);
182autotuple_binary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
183autotuple_binary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
184autotuple_binary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
185autotuple_binary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
186autotuple_binary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13);
187autotuple_binary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
188autotuple_binary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
189
// Generates binary operator impls between an `AutoTuple` of the given arity
// and an `AutoTuple` holding a single element, in both operand orders. The
// single element is cloned once per element of the longer tuple, which is what
// allows a constant to be combined with an `AutoTuple` of any length.
//
// Arguments: the operator trait, its method name, then the tuple field indices.
macro_rules! autotuple_const_op {
    ($op_trait:ident, $op_fn:ident, $($n:literal),+) => {
        paste! {
            // (T0, T1, ...) op (U,): every `Ti` is combined with a clone of the
            // single right-hand value; element `i` of the result is `Vi`.
            impl<$([<T $n>],)+ U, $([<V $n>],)+> $op_trait<AutoTuple<(U,)>>
                for AutoTuple<($([<T $n>],)+)>
            where
                $([<T $n>]: $op_trait<U, Output = [<V $n>]>,)+
                ($([<T $n>],)+): Clone + PartialEq,
                ($([<V $n>],)+): Clone + PartialEq,
                U: Clone + PartialEq,
            {
                type Output = AutoTuple<($([<V $n>],)+)>;

                fn $op_fn(self, rhs: AutoTuple<(U,)>) -> Self::Output {
                    AutoTuple(($(self.0.$n.$op_fn(rhs.0.0.clone()),)+))
                }
            }

            // (T,) op (U0, U1, ...): a clone of the single left-hand value is
            // combined with every `Ui`; element `i` of the result is `Vi`.
            impl<$([<U $n>],)+ T, $([<V $n>],)+> $op_trait<AutoTuple<($([<U $n>],)+)>>
                for AutoTuple<(T,)>
            where
                $(T: $op_trait<[<U $n>], Output = [<V $n>]>,)+
                ($([<U $n>],)+): Clone + PartialEq,
                ($([<V $n>],)+): Clone + PartialEq,
                T: Clone + PartialEq,
            {
                type Output = AutoTuple<($([<V $n>],)+)>;

                fn $op_fn(self, rhs: AutoTuple<($([<U $n>],)+)>) -> Self::Output {
                    AutoTuple(($(self.0.0.clone().$op_fn(rhs.0.$n),)+))
                }
            }
        }
    };
}
229
230// macro for implementing all binary ops between autotuples with
231// specified length
232macro_rules! autotuple_const_ops {
233    ($($idx:literal),+) =>
234    {
235        autotuple_const_op!(Add, add, $($idx),+);
236        autotuple_const_op!(Sub, sub, $($idx),+);
237        autotuple_const_op!(Mul, mul, $($idx),+);
238        autotuple_const_op!(Div, div, $($idx),+);
239        autotuple_const_op!(Rem, rem, $($idx),+);
240        autotuple_const_op!(Pow, pow, $($idx),+);
241    }
242}
243
244// implement all binary ops for autotuples up to length 16
245//autotuple_const_ops!(0);
246autotuple_const_ops!(0, 1);
247autotuple_const_ops!(0, 1, 2);
248autotuple_const_ops!(0, 1, 2, 3);
249autotuple_const_ops!(0, 1, 2, 3, 4);
250autotuple_const_ops!(0, 1, 2, 3, 4, 5);
251autotuple_const_ops!(0, 1, 2, 3, 4, 5, 6);
252autotuple_const_ops!(0, 1, 2, 3, 4, 5, 6, 7);
253autotuple_const_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8);
254autotuple_const_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
255autotuple_const_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
256autotuple_const_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
257autotuple_const_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
258autotuple_const_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13);
259autotuple_const_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
260autotuple_const_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
261
262// macro for implementing unary operations on autotuples
263macro_rules! autotuple_unary_ops {
264    ($($idx:literal),+) =>
265    {
266        paste! {
267            impl<$([<T $idx>],)+> Neg for AutoTuple<($([<T $idx>],)+)>
268            where
269                $([<T $idx>]: Neg<Output=[<T $idx>]>,)+
270                ($([<T $idx>],)+): Clone + PartialEq,
271            {
272                type Output = AutoTuple<($([<T $idx>],)+)>;
273
274                fn neg(self) -> Self::Output {
275                    AutoTuple::new(($( self.0.$idx.neg(), )+))
276                }
277            }
278            impl<$([<T $idx>],)+> Zero for AutoTuple<($([<T $idx>],)+)>
279            where
280                $([<T $idx>]: Zero + Add<[<T $idx>], Output=[<T $idx>]>,)+
281                ($([<T $idx>],)+): Clone + PartialEq,
282                AutoTuple<($([<T $idx>],)+)>: Add<AutoTuple<($([<T $idx>],)+)>, Output=AutoTuple<($([<T $idx>],)+)>>,
283            {
284                fn zero() -> Self {
285                    AutoTuple::new(($([<T $idx>]::zero(), )+))
286                }
287                fn is_zero(&self) -> bool {
288                    $(self.0.$idx.is_zero() && )+ true
289                }
290            }
291            impl<$([<T $idx>],)+> InstZero for AutoTuple<($([<T $idx>],)+)>
292            where
293                $([<T $idx>]: InstZero,)+
294                ($([<T $idx>],)+): Clone + PartialEq,
295                AutoTuple<($([<T $idx>],)+)>: Add<AutoTuple<($([<T $idx>],)+)>, Output=AutoTuple<($([<T $idx>],)+)>>,
296            {
297                fn zero(&self) -> Self {
298                    AutoTuple::new(($( self.0.$idx.zero(), )+))
299                }
300                fn is_zero(&self) -> bool {
301                    $(self.0.$idx.is_zero() && )+ true
302                }
303            }
304            impl<$([<T $idx>],)+> One for AutoTuple<($([<T $idx>],)+)>
305            where
306                $([<T $idx>]: One + Mul<[<T $idx>], Output=[<T $idx>]>,)+
307                ($([<T $idx>],)+): Clone + PartialEq,
308                AutoTuple<($([<T $idx>],)+)>: Mul<AutoTuple<($([<T $idx>],)+)>, Output=AutoTuple<($([<T $idx>],)+)>>,
309            {
310                fn one() -> Self {
311                    AutoTuple::new(($( [<T $idx>]::one(), )+))
312                }
313            }
314            impl<$([<T $idx>],)+> InstOne for AutoTuple<($([<T $idx>],)+)>
315            where
316                $([<T $idx>]: InstOne + Mul<[<T $idx>], Output=[<T $idx>]>,)+
317                ($([<T $idx>],)+): Clone + PartialEq,
318                AutoTuple<($([<T $idx>],)+)>: Mul<AutoTuple<($([<T $idx>],)+)>, Output=AutoTuple<($([<T $idx>],)+)>>,
319            {
320                fn one(&self) -> Self {
321                    AutoTuple::new(($( self.0.$idx.one(), )+))
322                }
323            }
324            impl<$([<T $idx>],)+> Signed for AutoTuple<($([<T $idx>],)+)>
325            where
326                $([<T $idx>]: Signed + Num + Neg<Output = [<T $idx>]>,)+
327                ($([<T $idx>],)+): Clone + PartialEq,
328                AutoTuple<($([<T $idx>],)+)>: NumOps,
329            {
330                fn abs(&self) -> Self {
331                    AutoTuple::new(($( self.0.$idx.abs(), )+))
332                }
333                fn abs_sub(&self, other: &Self) -> Self {
334                    AutoTuple::new(($( self.0.$idx.abs_sub(&other.0.$idx), )+))
335                }
336                fn signum(&self) -> Self {
337                    AutoTuple::new(($( self.0.$idx.signum(), )+))
338                }
339                fn is_positive(&self) -> bool {
340                    $(self.0.$idx.is_positive() && )+ true
341                }
342                fn is_negative(&self) -> bool {
343                    $(self.0.$idx.is_negative() && )+ true
344                }
345            }
346        }
347    }
348}
349
350// implement for tuples of length 1-16
351autotuple_unary_ops!(0);
352autotuple_unary_ops!(0, 1);
353autotuple_unary_ops!(0, 1, 2);
354autotuple_unary_ops!(0, 1, 2, 3);
355autotuple_unary_ops!(0, 1, 2, 3, 4);
356autotuple_unary_ops!(0, 1, 2, 3, 4, 5);
357autotuple_unary_ops!(0, 1, 2, 3, 4, 5, 6);
358autotuple_unary_ops!(0, 1, 2, 3, 4, 5, 6, 7);
359autotuple_unary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8);
360autotuple_unary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
361autotuple_unary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
362autotuple_unary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
363autotuple_unary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
364autotuple_unary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13);
365autotuple_unary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
366autotuple_unary_ops!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
367
368
// Generates `GradientType` for two `AutoTuple`s of the same arity: when every
// `Ti: GradientType<Ui, GradientType = Gi>`, the gradient type of
// `AutoTuple<(T0, ...)>` with respect to `AutoTuple<(U0, ...)>` is
// `AutoTuple<(G0, ...)>`. The arguments are the tuple field indices.
macro_rules! autotuple_gradient_type {
    ($($n:literal),+) => {
        paste! {
            impl<$([<T $n>],)+ $([<U $n>],)+ $([<G $n>],)+>
                GradientType<AutoTuple<($([<U $n>],)+)>>
                for AutoTuple<($([<T $n>],)+)>
            where
                $([<T $n>]: GradientType<[<U $n>], GradientType = [<G $n>]>,)+
                ($([<T $n>],)+): Clone + PartialEq,
                ($([<U $n>],)+): Clone + PartialEq,
                ($([<G $n>],)+): Clone + PartialEq,
            {
                type GradientType = AutoTuple<($([<G $n>],)+)>;
            }
        }
    };
}
387
388macro_rules! size_1_autotuple_gradient_type {
389    ($($idx:literal),+) => {
390        paste! {
391
392            // size 1 input size n output, size n gradient
393            impl<T, $([<U $idx>],)+ $([<G $idx>],)+> GradientType<AutoTuple<($([<U $idx>],)+)>> for AutoTuple<(T,)>
394            where
395                $(T: GradientType<[<U $idx>], GradientType = [<G $idx>]>,)+
396                (T,): Clone + PartialEq,
397                ($([<U $idx>],)+): Clone + PartialEq,
398                ($([<G $idx>],)+): Clone + PartialEq,
399            {
400                type GradientType = AutoTuple<($([<G $idx>],)+)>;
401            }
402
403            // size n input size 1 output, size n gradient
404            impl<$([<T $idx>],)+ U, $([<G $idx>],)+> GradientType<AutoTuple<(U,)>> for AutoTuple<($([<T $idx>],)+)>
405            where
406                $([<T $idx>]: GradientType<U, GradientType = [<G $idx>]>,)+
407                ($([<T $idx>],)+): Clone + PartialEq,
408                (U,): Clone + PartialEq,
409                ($([<G $idx>],)+): Clone + PartialEq,
410            {
411                type GradientType = AutoTuple<($([<G $idx>],)+)>;
412            }
413        }
414    }
415}
416
417// implement for tuples of length 1-16
418autotuple_gradient_type!(0);
419autotuple_gradient_type!(0, 1);
420autotuple_gradient_type!(0, 1, 2);
421autotuple_gradient_type!(0, 1, 2, 3);
422autotuple_gradient_type!(0, 1, 2, 3, 4);
423autotuple_gradient_type!(0, 1, 2, 3, 4, 5);
424autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6);
425autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7);
426autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8);
427autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
428autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
429autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
430autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
431autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13);
432autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
433autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
434size_1_autotuple_gradient_type!(0, 1);
435size_1_autotuple_gradient_type!(0, 1, 2);
436size_1_autotuple_gradient_type!(0, 1, 2, 3);
437size_1_autotuple_gradient_type!(0, 1, 2, 3, 4);
438size_1_autotuple_gradient_type!(0, 1, 2, 3, 4, 5);
439size_1_autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6);
440size_1_autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7);
441size_1_autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8);
442size_1_autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
443size_1_autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
444size_1_autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
445size_1_autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
446size_1_autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13);
447size_1_autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
448size_1_autotuple_gradient_type!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
449
450// macro to implement Default for tuples of length 1-16 whose elements implement Default
451macro_rules! autotuple_default {
452    ($($idx:literal),+) => {
453        paste! {
454            impl<$([<T $idx>],)+> Default for AutoTuple<($([<T $idx>],)+)>
455            where
456                ($([<T $idx>],)+): Clone + PartialEq,
457                $([<T $idx>]: Default,)+
458            {
459                fn default() -> Self {
460                    Self::new(($([<T $idx>]::default(),)+))
461                }
462            }
463        }
464    }
465}
466
467// implement Default for tuples of length 1-16
468autotuple_default!(0);
469autotuple_default!(0, 1);
470autotuple_default!(0, 1, 2);
471autotuple_default!(0, 1, 2, 3);
472autotuple_default!(0, 1, 2, 3, 4);
473autotuple_default!(0, 1, 2, 3, 4, 5);
474autotuple_default!(0, 1, 2, 3, 4, 5, 6);
475autotuple_default!(0, 1, 2, 3, 4, 5, 6, 7);
476autotuple_default!(0, 1, 2, 3, 4, 5, 6, 7, 8);
477autotuple_default!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
478autotuple_default!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
479autotuple_default!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
480autotuple_default!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
481autotuple_default!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13);
482autotuple_default!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
483autotuple_default!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
484
485// macro to implement ForwardMul for tuples of length 1-16
486macro_rules! autotuple_forward_mul {
487    ($($idx:literal),+) => {
488        paste! {
489            impl<$([<S $idx>],)+ $([<I $idx>],)+ $([<OG $idx>],)+ $([<RG $idx>],)+> ForwardMul<
490                    AutoTuple<($([<I $idx>],)+)>,
491                    AutoTuple<($([<OG $idx>],)+)>,
492                    >
493                for AutoTuple<($([<S $idx>],)+)>
494            where
495                ($([<S $idx>],)+): Clone + PartialEq,
496                ($([<I $idx>],)+): Clone + PartialEq,
497                ($([<OG $idx>],)+): Clone + PartialEq,
498                ($([<RG $idx>],)+): Clone + PartialEq,
499                $(
500                    [<S $idx>] : ForwardMul<[<I $idx>], [<OG $idx>], ResultGrad = [<RG $idx>]>,
501                )+
502            {
503                type ResultGrad = AutoTuple<($([<RG $idx>],)+)>;
504                fn forward_mul(&self, other: &AutoTuple<($([<OG $idx>],)+)>) -> Self::ResultGrad {
505                    AutoTuple::new(($(
506                                (&self.0).$idx.forward_mul(&(*other).0.$idx),
507                    )+))
508                }
509            }
510        }
511    }
512}
513
514// implement ForwardMul for tuples of length 1-16
515autotuple_forward_mul!(0);
516autotuple_forward_mul!(0, 1);
517autotuple_forward_mul!(0, 1, 2);
518autotuple_forward_mul!(0, 1, 2, 3);
519autotuple_forward_mul!(0, 1, 2, 3, 4);
520autotuple_forward_mul!(0, 1, 2, 3, 4, 5);
521autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6);
522autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7);
523autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8);
524autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
525autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
526autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
527autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
528autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13);
529autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
530autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
531
532// forward mul for size 1 AutoTuples on size 1-16 AutoTuples
533macro_rules! size_1_autotuple_forward_mul {
534    ($($idx:literal),+) => {
535        paste! {
536
537            // size-1 I
538            impl<$([<S $idx>],)+ I, $([<OG $idx>],)+ $([<RG $idx>],)+> ForwardMul<
539                AutoTuple<(I,)>,
540                //AutoTuple<($([<O $idx>],)+)>,
541                AutoTuple<($([<OG $idx>],)+)>,
542                //AutoTuple<($([<RG $idx>],)+)>,
543                >
544                for AutoTuple<($([<S $idx>],)+)>
545            where
546                (I,): Clone + PartialEq,
547                ($([<S $idx>],)+): Clone + PartialEq,
548                //($([<O $idx>],)+): Clone + PartialEq,
549                ($([<OG $idx>],)+): Clone + PartialEq,
550                ($([<RG $idx>],)+): Clone + PartialEq,
551                $(
552                    //I : GradientType<[<O $idx>], GradientType = [<S $idx>]>,
553                    [<S $idx>] : ForwardMul<I, [<OG $idx>], ResultGrad = [<RG $idx>]>,
554                )+
555            {
556                type ResultGrad = AutoTuple<($([<RG $idx>],)+)>;
557                fn forward_mul(&self, other: &AutoTuple<($([<OG $idx>],)+)>) -> Self::ResultGrad {
558                    AutoTuple::new(($(
559                            self.0.$idx.forward_mul(&(*other).0.$idx),
560                    )+))
561                }
562            }
563
564            // size-1 OG
565            impl<$([<S $idx>],)+ $([<I $idx>],)+ OG, $([<RG $idx>],)+> ForwardMul<
566                AutoTuple<($([<I $idx>],)+)>,
567                AutoTuple<(OG,)>,
568                >
569                for AutoTuple<($([<S $idx>],)+)>
570            where
571                ($([<I $idx>],)+): Clone + PartialEq,
572                ($([<S $idx>],)+): Clone + PartialEq,
573                (OG,): Clone + PartialEq,
574                ($([<RG $idx>],)+): Clone + PartialEq,
575                $(
576                    [<S $idx>] : ForwardMul<[<I $idx>], OG, ResultGrad = [<RG $idx>]>,
577                )+
578            {
579                type ResultGrad = AutoTuple<($([<RG $idx>],)+)>;
580                fn forward_mul(&self, other: &AutoTuple<(OG,)>) -> Self::ResultGrad {
581                    AutoTuple::new(($(
582                            self.0.$idx.forward_mul(&(*other).0.0),
583                    )+))
584                }
585            }
586
587            /*
588            // size-1 O and RG
589            impl<$([<S $idx>],)+ $([<I $idx>],)+ O, $([<OG $idx>],)+ RG> ForwardMul<
590                AutoTuple<($([<I $idx>],)+)>,
591                AutoTuple<(O,)>,
592                AutoTuple<($([<OG $idx>],)+)>,
593                AutoTuple<(RG,)>,
594                >
595                for AutoTuple<($([<S $idx>],)+)>
596            where
597                S0: Clone,
598                OG0: Clone,
599                ($([<I $idx>],)+): Clone + PartialEq,
600                ($([<S $idx>],)+): Clone + PartialEq,
601                (O,): Clone + PartialEq,
602                ($([<OG $idx>],)+): Clone + PartialEq,
603                (RG,): Clone + PartialEq,
604                $(
605                    [<I $idx>] : GradientType<O, GradientType = [<S $idx>]>,
606                    [<S $idx>] : ForwardMul<[<I $idx>], O, [<OG $idx>], RG>,
607                )+
608                RG: Add<RG, Output = RG> + InstZero,
609            {
610                fn forward_mul(&self, other: &AutoTuple<($([<OG $idx>],)+)>) -> AutoTuple<(RG,)> {
611                    // forward mul between each pair then sum
612                    let sum = self.0.0.clone().forward_mul(&(*other).0.0.clone()).zero()
613                        $( +
614                            self.0.$idx.forward_mul(&(*other).0.$idx)
615                        )+;
616
617                    AutoTuple::new((sum,))
618                }
619            }*/
620        }
621    }
622}
623
624// implement ForwardMul for tuples of length 1-16
625//size_1_autotuple_forward_mul!(0);
626size_1_autotuple_forward_mul!(0, 1);
627size_1_autotuple_forward_mul!(0, 1, 2);
628size_1_autotuple_forward_mul!(0, 1, 2, 3);
629size_1_autotuple_forward_mul!(0, 1, 2, 3, 4);
630size_1_autotuple_forward_mul!(0, 1, 2, 3, 4, 5);
631size_1_autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6);
632size_1_autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7);
633size_1_autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8);
634size_1_autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
635size_1_autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
636size_1_autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
637size_1_autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
638size_1_autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13);
639size_1_autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
640size_1_autotuple_forward_mul!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
641
#[test]
fn test_autotuple() {
    // Same-length addition, with the right-hand side given both as an
    // `AutoTuple` and as a bare tuple.
    let lhs = AutoTuple::new((1u32, 1.0_f64));
    let rhs_raw = (2u32, -1.0_f64);
    let rhs = AutoTuple::new(rhs_raw);
    let sum_wrapped = lhs + rhs;
    let sum_raw = lhs + rhs_raw;
    assert_eq!(sum_wrapped, AutoTuple::new((3, 0.0)));
    assert_eq!(sum_raw, AutoTuple::new((3, 0.0)));

    // `Deref` exposes the wrapped tuple's elements.
    let (first, second) = *sum_wrapped;
    assert_eq!(first, 3);
    assert_eq!(second, 0.0);

    // A one-element `AutoTuple` acts as a constant applied to every element.
    let pair = AutoTuple::new((-1.0_f64, 1.0_f64));
    let two = AutoTuple::<(f64,)>::from(2.0_f64);
    let product = pair * two;
    let shifted = pair + two;
    assert_eq!(product, AutoTuple::new((-2.0, 2.0)));
    assert_eq!(shifted, AutoTuple::new((1.0, 3.0)));
}