autodiff/
gradienttype.rs

1use num::complex::Complex;
2use num::rational::Ratio;
3use num::{Integer, Num};
4use std::num::Wrapping;
5
/// Type-level mapping, resolved at compile time, from a function's input and
/// output types to the type of its gradient.
///
/// The trait is implemented on the *input* type and parameterised by the
/// *output* type; the associated type then names the gradient type:
///
/// `<InputType as GradientType<OutputType>>::GradientType`
///
/// In most cases the gradient type is identical to the output type.
/// Multi-parameter functions, such as functions of arrays, are the exception.
pub trait GradientType<Output> {
    /// Gradient type of a function whose input type is `Self` and whose
    /// output type is `Output`.
    type GradientType;
}
18
// Implements `GradientType<T>` for each listed type and for every output
// type `T`, with the gradient type equal to that output type `T`.
//
// Takes a comma-separated list of types. A trailing comma is accepted so that
// invocations can be laid out one type per line.
macro_rules! impl_simple_types {
    ($($type:ty),* $(,)?) => {
        $(
            // simple value input: the gradient type is the same as the output type
            impl<T> GradientType<T> for $type {
                type GradientType = T;
            }
        )*
    };
}
31
32// implement GradientType for all primitive types as well as complex numbers
33impl_simple_types!(
34    f32,
35    f64,
36    i8,
37    i16,
38    i32,
39    i64,
40    u8,
41    u16,
42    u32,
43    u64,
44    u128,
45    i128,
46    usize,
47    isize,
48    num::BigInt,
49    num::BigUint,
50    Complex<f32>,
51    Complex<f64>
52);
53
54// generic impls
55impl<T, U> GradientType<U> for Ratio<T>
56where
57    T: Integer + Clone,
58    U: Num + Clone,
59{
60    type GradientType = Ratio<U>;
61}
62
63impl<T, U> GradientType<U> for Wrapping<T>
64where
65    T: Integer + Clone,
66    U: Num + Clone,
67{
68    type GradientType = Wrapping<U>;
69}