epsilon/
lib.rs

1//! # `epsilon` - Fast autograd using dual numbers
2//!
3//! Dual numbers are a straightforward way of doing forward gradient
4//! propagation, i.e. keep track of derivatives for all expressions, to
5//! automatically differentiate a function without storing a computation graph.
6//!
7//! Using dual numbers, one can augment their numbers with a "dual" part,
8//! representing the derivative of the term with respect to some input variable.
9//! The input variable has a unit dual part of 1, and each resulting expression
10//! has a dual part with the derivative up to that point.
11//!
12//! This can be trivially extended to multiple variables by storing one dual
13//! part per input variable.
14//!
15//! One can find a more in-depth explanation at [Wikipedia](https://en.wikipedia.org/wiki/Dual_number)
16//!
17//! ---
18//!
19//! This crate statically generates code for dual numbers using macros, meaning
20//! one can provide names for each dependent variable, and corresponding
21//! methods will be generated with names reflecting the name of the variable.
22//!
23//! The interface exposed by the types generated by this crate is very similar
24//! to that of the standard numerical rust types, meaning most code using
25//! [f64](f64) should be very straightforward to convert to using dual numbers.
26//!
27//! Example usage:
28//!
29//! ```
30//! use epsilon::make_dual;
31//! // We want to compute dz/dx and dz/dy for z = x^2+y*sin(y) at x=5, y=7
32//! make_dual! { MyDual, x, y } // Create a dual number with terms `x` and `y`
33//!
34//! let (x, y) = (MyDual::x(5.), MyDual::y(7.)); // Perform the calculations, and compute the derivative at x=5, y=7
35//!
36//! let z = x.powf(2.) + y * y.sin();
37//!
38//! let dzdx = z.d_dx();
39//! let dzdy = z.d_dy();
40//!
41//! assert_eq!(dzdx, 10.); // 2 * x
42//! assert_eq!(dzdy, 5.934302379121921); // y*cos(y) + sin(y)
43//! ```
44
45// Reexport to access from macros
46#[doc(hidden)]
47pub use paste::paste;
48
/// Minimal interface of a real-like number.
///
/// Both `f64` and every type generated by `make_dual!` implement this trait,
/// so code written against `Numerical` works unchanged with plain reals and
/// with dual numbers.
pub trait Numerical:
    Copy
    + std::fmt::Display
    + std::fmt::Debug
    + std::ops::Add<Output = Self>
    + std::ops::AddAssign
    + std::ops::Sub<Output = Self>
    + std::ops::SubAssign
    + std::ops::Mul<Output = Self>
    + std::ops::MulAssign
    + std::ops::Div<Output = Self>
    + std::ops::DivAssign
{
    /// Raise `self` to the power `pow`.
    fn powf(self, pow: f64) -> Self;

    /// Multiplicative inverse of `self` (`1 / self`).
    fn invert(self) -> Self;

    /// Sine of `self`.
    fn sin(self) -> Self;

    /// Cosine of `self`.
    fn cos(self) -> Self;

    /// Tangent of `self`.
    fn tan(self) -> Self;
}
71
72impl Numerical for f64 {
73    fn powf(self, pow: f64) -> Self {
74        f64::powf(self, pow)
75    }
76
77    fn invert(self) -> Self {
78        1. / self
79    }
80
81    fn sin(self) -> Self {
82        f64::sin(self)
83    }
84
85    fn cos(self) -> Self {
86        f64::cos(self)
87    }
88
89    fn tan(self) -> Self {
90        f64::tan(self)
91    }
92}
93
94
/// # Create a dual number
///
/// `$name` specifies the name of the generated type, `$inner` specifies the
/// backing type (either `f32` or `f64`, defaults to `f64`), and each `$comp`
/// is a dual component of the type. A trailing comma after the last component
/// is accepted.
///
/// For example, `make_dual! { SampleXYZ: f64, x, y, z, }` will generate a struct
/// ```
/// pub struct SampleXYZ {
///     pub real: f64,
///     pub eps_x: f64,
///     pub eps_y: f64,
///     pub eps_z: f64,
/// }
/// ```
///
/// An instance can be created using either the `eps_$comp` function, giving a
/// specified real and dual part, or using the `$comp` function, giving a
/// specified real part and a unit dual part. `From<$inner>` creates a value
/// whose dual parts are all zero. The derivative with respect to `$comp` is
/// read back with `d_d$comp` (e.g. `d_dx`).
///
/// The generated methods (`powf`, `invert`, `sin`, `cos`, `tan`) and the
/// operator implementations (`Add`, `Sub`, `Mul`, `Div`, `Neg` and the
/// `*Assign` forms, against both `$name` and `$inner`) all propagate the dual
/// parts of the numbers. The type also implements `Display`, `PartialOrd` and
/// [`Numerical`](crate::Numerical).
///
/// The expansion also defines two helper macros, `impl_reverse` and
/// `impl_inplace`, in the invoking scope.
#[macro_export]
macro_rules! make_dual {
    // Backing type omitted: default to `f64`. `$crate::` keeps the recursive
    // call resolvable when the macro is invoked by path rather than imported.
    ($name:ident, $($comp:ident),+ $(,)?) => {
        $crate::make_dual! { $name: f64, $($comp),+ }
    };
    ($name:ident: $inner:ty, $($comp:ident),+ $(,)?) => { $crate::paste! {
        // `$inner <op> $t`: promote the scalar left-hand side to a dual with
        // zero dual parts, then defer to the dual-dual operator.
        macro_rules! impl_reverse {
            ($t:ty, $op:ident, $fn:ident) => {
                impl ::std::ops::$op<$t> for $inner {
                    type Output = $t;

                    fn $fn(self, other: $t) -> Self::Output {
                        <$t as ::std::ops::$op>::$fn(<$t as From<$inner>>::from(self), other)
                    }
                }
            };
        }

        // `$t <op>= rhs` for scalar and dual right-hand sides, expressed
        // through the corresponding out-of-place operator.
        macro_rules! impl_inplace {
            ($t:ty, $op_inplace:ident, $fn_inplace:ident, $op_outofplace:ident, $fn_outofplace:ident) => {
                impl ::std::ops::$op_inplace<$inner> for $t {
                    fn $fn_inplace(&mut self, other: $inner) {
                        *self = <Self as ::std::ops::$op_outofplace<$inner>>::$fn_outofplace(*self, other);
                    }
                }
                impl ::std::ops::$op_inplace<$t> for $t {
                    fn $fn_inplace(&mut self, other: $t) {
                        *self = <Self as ::std::ops::$op_outofplace<$t>>::$fn_outofplace(*self, other);
                    }
                }
            };
        }

        /// Dual type
        #[derive(Copy, Clone, PartialEq, Debug)]
        pub struct $name {
            /// The real value of the dual type
            pub real: $inner,
            $(
                /// Dual component
                pub [< eps_ $comp >]: $inner,
            )+
        }

        // NOTE: ordering looks at the real part only, while the derived
        // `PartialEq` compares every field. Two values with equal real parts
        // but different dual parts therefore order as `Equal` yet are `!=`.
        impl ::std::cmp::PartialOrd for $name {
            fn partial_cmp(&self, other: &Self) -> Option<::std::cmp::Ordering> {
                self.real.partial_cmp(&other.real)
            }
        }

        impl ::std::fmt::Display for $name {
            fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
                write!(fmt, "{}", self.real)?;
                $(
                    let v = self.[< eps_ $comp >];
                    if v < 0. {
                        write!(fmt, " - {} eps_{}", -v, stringify!($comp))?;
                    } else if v != 0. {
                        // Positive or NaN. Zero dual parts are omitted, but a
                        // NaN derivative must stay visible.
                        write!(fmt, " + {} eps_{}", v, stringify!($comp))?;
                    }
                )+
                Ok(())
            }
        }

        impl From<$inner> for $name {
            /// Create an instance with the given real part and all dual parts zero
            fn from(real: $inner) -> Self {
                $name {
                    real,
                    $(
                        [<eps_$comp>]: 0.,
                    )+
                }
            }
        }

        impl $name {
            $(
                /// Create instance with specified real and dual part
                pub fn [<eps_ $comp>](real: $inner, [<eps_ $comp>]: $inner) -> Self {
                    Self {
                        [<eps_ $comp>]: [<eps_ $comp>],
                        .. Self::from(real)
                    }
                }
            )+
            $(
                /// Create instance with specified real part and unit dual part
                pub fn $comp(real: $inner) -> Self {
                    Self::[<eps_ $comp>](real, 1.)
                }
            )+

            $(
                /// Derivative with respect to component
                /// Shorthand for self.eps_$comp
                pub fn [<d_d $comp>](self) -> $inner {
                    self.[<eps_ $comp>]
                }
            )+

            /// Raise `self` to `pow`
            ///
            /// Each dual part is computed as `eps * pow * real^(pow - 1)`, so
            /// it is non-finite or NaN wherever `real^(pow - 1)` is.
            pub fn powf(self, pow: $inner) -> Self {
                // power rule: d/dx [x^p] = p x^(p-1)
                // The x^(p-1) factor is shared by every component.
                let real_pow_minus_one = self.real.powf(pow - 1.);
                Self {
                    real: self.real.powf(pow),
                    $(
                        [<eps_ $comp>]: self.[<eps_ $comp>] * pow * real_pow_minus_one,
                    )+
                }
            }

            /// Invert `self` (`1./self`)
            pub fn invert(self) -> Self {
                self.powf(-1.)
            }

            /// Sine of `self`
            pub fn sin(self) -> Self {
                // chain rule: d/dx [sin(x)] = cos(x)
                let r = self.real.sin();
                let dr = self.real.cos();

                Self {
                    real: r,
                    $(
                        [<eps_ $comp>]: self.[<eps_ $comp>] * dr,
                    )+
                }
            }

            /// Cosine of `self`
            pub fn cos(self) -> Self {
                // chain rule: d/dx [cos(x)] = -sin(x)
                let r = self.real.cos();
                let dr = -self.real.sin();

                Self {
                    real: r,
                    $(
                        [<eps_ $comp>]: self.[<eps_ $comp>] * dr,
                    )+
                }
            }

            /// Tangent of `self`, computed as `sin / cos`
            pub fn tan(self) -> Self {
                self.sin() / self.cos()
            }
        }

        impl ::std::ops::Add<$inner> for $name {
            type Output = Self;

            // Adding a constant shifts the real part and leaves derivatives untouched.
            fn add(mut self, other: $inner) -> Self::Output {
                self.real += other;
                self
            }
        }

        impl ::std::ops::Add<$name> for $name {
            type Output = Self;

            fn add(mut self, other: Self) -> Self::Output {
                self.real += other.real;
                $(
                    self.[<eps_ $comp>] += other.[<eps_ $comp>];
                )+
                self
            }
        }

        impl ::std::ops::Mul<$inner> for $name {
            type Output = Self;

            // Scaling by a constant scales the real part and every derivative.
            fn mul(mut self, other: $inner) -> Self::Output {
                self.real *= other;
                $(
                    self.[<eps_ $comp>] *= other;
                )+
                self
            }
        }

        impl ::std::ops::Neg for $name {
            type Output = Self;

            fn neg(self) -> Self::Output {
                self * -1.
            }
        }

        impl ::std::ops::Mul<$name> for $name {
            type Output = Self;

            fn mul(self, other: Self) -> $name {
                // product rule: d(uv) = du * v + dv * u
                Self {
                    real: self.real * other.real,
                    $(
                        [<eps_ $comp>]: self.[<eps_ $comp>] * other.real + other.[<eps_ $comp>] * self.real,
                    )+
                }
            }
        }

        impl ::std::ops::Div<$name> for $name {
            type Output = Self;

            fn div(self, other: Self) -> Self::Output {
                self * other.invert()
            }
        }

        impl ::std::ops::Sub<$inner> for $name {
            type Output = Self;

            fn sub(self, other: $inner) -> Self::Output {
                self + -other
            }
        }

        impl ::std::ops::Sub<$name> for $name {
            type Output = Self;

            fn sub(self, other: $name) -> Self::Output {
                self + -other
            }
        }

        impl ::std::ops::Div<$inner> for $name {
            type Output = Self;

            fn div(self, other: $inner) -> Self::Output {
                self * (1./other)
            }
        }

        impl_reverse!{$name, Add, add}
        impl_reverse!{$name, Sub, sub}
        impl_reverse!{$name, Mul, mul}
        impl_reverse!{$name, Div, div}
        impl_inplace!{$name, AddAssign, add_assign, Add, add}
        impl_inplace!{$name, SubAssign, sub_assign, Sub, sub}
        impl_inplace!{$name, MulAssign, mul_assign, Mul, mul}
        impl_inplace!{$name, DivAssign, div_assign, Div, div}

        impl $crate::Numerical for $name {
            fn powf(self, pow: f64) -> Self {
                // The trait fixes the exponent to `f64`; convert to the backing type.
                $name::powf(self, pow as $inner)
            }

            fn invert(self) -> Self {
                $name::invert(self)
            }

            fn sin(self) -> Self {
                $name::sin(self)
            }

            fn cos(self) -> Self {
                $name::cos(self)
            }

            fn tan(self) -> Self {
                $name::tan(self)
            }
        }

    } };
}
386
/// # Sample type
///
/// Dual types only exist once [`make_dual`](crate::make_dual) has been invoked, so this module holds one such invocation to show what a generated type looks like.
///
/// The generated type is named `SampleXYZ`, is backed by `f32`, and has the components `x`, `y` and `z`. Generated function names such as `eps_x` are derived from the component names.
#[cfg(any(test, doc))]
pub mod sample {
    crate::make_dual! { SampleXYZ: f32, x, y, z, }
}
396
#[cfg(test)]
mod tests {
    use super::sample::SampleXYZ;

    /// Spell out every part of a `SampleXYZ` explicitly.
    fn dual(real: f32, eps_x: f32, eps_y: f32, eps_z: f32) -> SampleXYZ {
        SampleXYZ {
            real,
            eps_x,
            eps_y,
            eps_z,
        }
    }

    #[test]
    fn test_const_ops() {
        // Adding or subtracting a scalar only moves the real part.
        assert_eq!(SampleXYZ::x(5.) + 3., SampleXYZ::x(8.));
        assert_eq!(SampleXYZ::y(5.) - 3., SampleXYZ::y(2.));

        // A scalar on the left-hand side flips the sign of the dual part.
        assert_eq!(3. - SampleXYZ::y(5.), SampleXYZ::eps_y(-2., -1.));

        // Scaling multiplies the real part and the derivative alike.
        assert_eq!(SampleXYZ::z(2.) * 3., dual(6., 0., 0., 3.));

        // d/dx x^2 = 2x
        assert_eq!(SampleXYZ::x(10.).powf(2.), dual(100., 20., 0., 0.));

        // d/dy 1/y = -1/y^2
        let inverse = SampleXYZ::y(10.).invert();
        assert_eq!(inverse, dual(0.1, 0., -0.01, 0.));
        assert_eq!(inverse.to_string(), "0.1 - 0.01 eps_y");

        // In-place operators agree with their out-of-place counterparts.
        let mut v = SampleXYZ::x(3.);
        v += 7.;
        assert_eq!(v, SampleXYZ::x(10.));
    }

    #[test]
    fn test_dist() {
        let x = SampleXYZ::x(1.);
        let square_of_sum = (x + 1.) * (x + 1.);

        assert_eq!(square_of_sum, x * x + 2. * x + 1.);
        assert_eq!(square_of_sum, (x + 1.).powf(2.));
        assert_eq!((x + 1.) * (x - 1.), x * x - 1.);
    }

    #[test]
    fn test_diff() {
        let x = SampleXYZ::x(1.);
        let derivative = (x * x).d_dx();
        assert_eq!(derivative, 2.);
    }

    #[test]
    fn test_trig() {
        let origin = SampleXYZ::x(0.);
        assert_eq!(origin.sin(), SampleXYZ::eps_x(0., 1.));
        assert_eq!(origin.cos(), SampleXYZ::eps_x(1., 0.));
        assert_eq!(origin.tan(), SampleXYZ::eps_x(0., 1.));
    }
}
466}