autodiff/gradienttype.rs
1use num::complex::Complex;
2use num::rational::Ratio;
3use num::{Integer, Num};
4use std::num::Wrapping;
5
/// Type-level mapping from an (input, output) type pair to the type of the
/// gradient.
///
/// `Self` is the input type of the function being differentiated and
/// `OutputType` is its output type. The resulting gradient type is read off
/// the associated type:
///
/// `<InputType as GradientType<OutputType>>::GradientType`
///
/// In most cases the gradient type is simply the output type. Functions of
/// several parameters (for example, functions of arrays) are the exception,
/// which is why the mapping is expressed as a trait rather than a fixed alias.
pub trait GradientType<OutputType> {
    /// Gradient type of a function taking `Self` and returning `OutputType`.
    type GradientType;
}
18
// Implements `GradientType<T>` for every listed "simple" input type, i.e. a
// type whose gradient type is exactly the output type `T`.
//
// Takes a comma-separated list of types. A trailing comma is accepted so that
// multi-line invocations can be formatted in the usual rustfmt style.
macro_rules! impl_simple_types {
    ($($type:ty),* $(,)?) => {
        $(
            // One blanket impl over the output type `T` per input type.
            impl<T> GradientType<T> for $type {
                type GradientType = T;
            }
        )*
    };
}
31
32// implement GradientType for all primitive types as well as complex numbers
33impl_simple_types!(
34 f32,
35 f64,
36 i8,
37 i16,
38 i32,
39 i64,
40 u8,
41 u16,
42 u32,
43 u64,
44 u128,
45 i128,
46 usize,
47 isize,
48 num::BigInt,
49 num::BigUint,
50 Complex<f32>,
51 Complex<f64>
52);
53
54// generic impls
55impl<T, U> GradientType<U> for Ratio<T>
56where
57 T: Integer + Clone,
58 U: Num + Clone,
59{
60 type GradientType = Ratio<U>;
61}
62
63impl<T, U> GradientType<U> for Wrapping<T>
64where
65 T: Integer + Clone,
66 U: Num + Clone,
67{
68 type GradientType = Wrapping<U>;
69}