autodiff/
traits.rs

1use num::complex::Complex;
2use num::rational::Ratio;
3use num::{Integer, Num, One, Zero};
4use std::num::Wrapping;
5use std::ops::{Add, Mul};
6use crate::gradienttype::GradientType;
7
/// Instance-based access to a zero value.
///
/// Unlike an associated-function constructor, the zero value is produced from
/// an existing instance (`&self`). Implementors must also support
/// `Self + Self -> Self`.
pub trait InstZero: Sized + Add<Output = Self> {
    /// Returns the zero value that corresponds to `self`.
    fn zero(&self) -> Self;

    /// Reports whether `self` is the zero value.
    fn is_zero(&self) -> bool;

    /// Overwrites `self` in place with the value returned by `zero`.
    fn set_zero(&mut self) {
        let identity = self.zero();
        *self = identity;
    }
}
19
/// Instance-based access to the multiplicative identity.
///
/// The identity is produced from an existing instance (`&self`).
/// Implementors must also support `Self * Self -> Self`.
pub trait InstOne: Sized + Mul<Output = Self> {
    /// Returns the multiplicative identity of `Self`, 1,
    /// i.e. `self * self.one() == self`.
    fn one(&self) -> Self;

    /// Overwrites `self` in place with the value returned by `one`.
    fn set_one(&mut self) {
        let identity = self.one();
        *self = identity;
    }

    /// Returns `true` when `self` compares equal to `self.one()`.
    ///
    /// Only available when `Self: PartialEq`.
    fn is_one(&self) -> bool
    where
        Self: PartialEq,
    {
        // Keep `self` on the left-hand side of the comparison.
        PartialEq::eq(self, &self.one())
    }
}
39
40pub trait GradientIdentity: Sized
41where
42    Self: GradientType<Self>
43{
44    /// Returns the gradient identity for a function with input Self,
45    /// output Self, and gradient type `<Self as GradientType<Self>>::GradientType`
46    /// for primitive types, this is the same as one()
47    /// for Array types, this is more complicated
48    ///
49    /// This cannot be implemented for complex numbers
50    fn grad_identity(&self) -> <Self as GradientType<Self>>::GradientType;
51}
52
53// implementation for InstZero for all the types that implement Zero from num
54// u32, i128, i16, u128, f64, usize, i32, i8, f32, i64, u16, Wrapping<T: Zero>, isize, u8, u64,
55// BigInt, BigUint, Ratio<T: Integer>, Complex<T: Num>
56
57// a macro to do this for any type that can call ::zero()
// Implements `InstZero` for every listed type by forwarding to that type's
// `Zero` implementation:
//   * `zero(&self)` ignores the instance and returns `<T as Zero>::zero()`;
//   * `is_zero(&self)` returns `<T as Zero>::is_zero(self)`.
//
// Input is a comma-separated list of types. A trailing comma is accepted so
// that multi-line (rustfmt-style) invocations compile.
macro_rules! impl_zero {
    ($($t:ty),* $(,)?) => {
        $(
            impl InstZero for $t {
                fn zero(&self) -> Self {
                    <$t as Zero>::zero()
                }

                fn is_zero(&self) -> bool {
                    <$t as Zero>::is_zero(self)
                }
            }
        )*
    };
}
71
72impl_zero!(u32, i128, i16, u128, f64, usize, i32, i8, f32, i64, u16, isize, u8, u64);
73impl_zero!(num::BigInt, num::BigUint);
74
75// generic implementations done here
76impl<T> InstZero for Wrapping<T>
77where
78    T: Zero,
79    Wrapping<T>: Add<Wrapping<T>, Output = Wrapping<T>>,
80{
81    fn zero(&self) -> Wrapping<T> {
82        <Wrapping<T> as Zero>::zero()
83    }
84
85    fn is_zero(&self) -> bool {
86        <Wrapping<T> as Zero>::is_zero(self)
87    }
88}
89impl<T> InstZero for Ratio<T>
90where
91    T: Clone + Integer,
92{
93    fn zero(&self) -> Ratio<T> {
94        <Ratio<T> as Zero>::zero()
95    }
96
97    fn is_zero(&self) -> bool {
98        <Ratio<T> as Zero>::is_zero(self)
99    }
100}
101impl<T> InstZero for Complex<T>
102where
103    T: Clone + Num,
104{
105    fn zero(&self) -> Complex<T> {
106        <Complex<T> as Zero>::zero()
107    }
108
109    fn is_zero(&self) -> bool {
110        <Complex<T> as Zero>::is_zero(self)
111    }
112}
113
114// implementation for InstOne for all the types that implement One from num
115// Wrapping<T: One>, i64, u128, f32, u16, u32, i16, f64, isize, i32, u8, u64, usize, i128, i8
116// BigInt, BigUint, Ratio<T: Integer>, Complex<T: Num>
117
118// a macro to do this for any type that can call ::one()
// Implements `InstOne` for every listed type by forwarding to that type's
// `One` implementation: `one(&self)` ignores the instance and returns
// `<T as One>::one()`.
//
// Input is a comma-separated list of types. A trailing comma is accepted so
// that multi-line (rustfmt-style) invocations compile.
macro_rules! impl_one {
    ($($t:ty),* $(,)?) => {
        $(
            impl InstOne for $t {
                fn one(&self) -> Self {
                    <$t as One>::one()
                }
            }
        )*
    };
}
128
129impl_one!(i64, u128, f32, u16, u32, i16, f64, isize, i32, u8, u64, usize, i128, i8);
130impl_one!(num::BigInt, num::BigUint);
131
132// generic implementations done here
133impl<T> InstOne for Wrapping<T>
134where
135    T: One,
136    Wrapping<T>: Mul<Wrapping<T>, Output = Wrapping<T>>,
137{
138    fn one(&self) -> Wrapping<T> {
139        <Wrapping<T> as One>::one()
140    }
141}
142
143impl<T> InstOne for Ratio<T>
144where
145    T: Clone + Integer,
146{
147    fn one(&self) -> Ratio<T> {
148        <Ratio<T> as One>::one()
149    }
150}
151
152impl<T> InstOne for Complex<T>
153where
154    T: Clone + Num,
155{
156    fn one(&self) -> Complex<T> {
157        <Complex<T> as One>::one()
158    }
159}
160
// macro for implementing GradientIdentity for concrete types whose gradient type implements One
// Implements `GradientIdentity` for every listed type. The impl applies when
// the type implements `GradientType<Self>` and the associated gradient type
// `G` implements `One`; `grad_identity` then returns `G::one()` and does not
// consult the instance.
//
// The gradient type is named through the helper parameter `G`, so the bound
// is written as `GradientType<Self, GradientType = G>`.
//
// Input is a comma-separated list of types. A trailing comma is accepted so
// that multi-line (rustfmt-style) invocations compile.
macro_rules! impl_grad_identity {
    ($($t:ty),* $(,)?) => {
        $(
            impl<G> GradientIdentity for $t
            where
                Self: GradientType<Self, GradientType = G>,
                G: One,
            {
                fn grad_identity(&self) -> G {
                    <G as One>::one()
                }
            }
        )*
    };
}
176
177impl_grad_identity!(i64, u128, f32, u16, u32, i16, f64, isize, i32, u8, u64, usize, i128, i8);
178impl_grad_identity!(num::BigInt, num::BigUint);
179
180// generic implementations done here
181impl<T, G> GradientIdentity for Wrapping<T>
182where
183    Self: GradientType<Self, GradientType = G>,
184    G: One,
185{
186    fn grad_identity(&self) -> G {
187        G::one()
188    }
189}
190
191impl<T, G> GradientIdentity for Ratio<T>
192where
193    Self: GradientType<Self, GradientType = G>,
194    T: Clone + Integer,
195    G: One,
196{
197    fn grad_identity(&self) -> G {
198        G::one()
199    }
200}
201
202impl<T, G> GradientIdentity for Complex<T>
203where
204    Self: GradientType<Self, GradientType = G>,
205    T: Clone + Num,
206    G: One,
207{
208    fn grad_identity(&self) -> G {
209        G::one()
210    }
211}