easy_ml/
numeric.rs

1/*!
2 * Numerical type definitions.
3 *
4 * `Numeric` together with `where for<'a> &'a T: NumericRef<T>`
5 * expresses the operations in [`NumericByValue`] for
 * all 4 combinations of by value and by reference. [`Numeric`]
 * also adds some constraints that are only needed by value on an implementing
 * type, such as `PartialOrd`, [`ZeroOne`] and [`FromUsize`].
10 *
 * For additional operations on real-valued numbers, see [Real](crate::numeric::extra::Real).
12 */
13
14use std::cmp::PartialOrd;
15use std::fmt::Debug;
16use std::iter::Sum;
17use std::marker::Sized;
18use std::num::{Saturating, Wrapping};
19use std::ops::Add;
20use std::ops::Div;
21use std::ops::Mul;
22use std::ops::Neg;
23use std::ops::Sub;
24
/**
 * A trait defining what a numeric type is in terms of the by value
 * numerical operations that matrices need their types to support for
 * math operations.
 *
 * The requirements are Add, Sub, Mul, Div, Neg and Sized. Every binary
 * operation takes a right hand side of type `Rhs`, and every operation
 * yields a value of type `Output`. Both `Rhs` and `Output` default to
 * `Self`, so a plain `T: NumericByValue` bound means `T op T -> T`.
 *
 * Note that unsigned integers do not implement Neg unless they are
 * wrapped by [Wrapping].
 *
 * The trait has no methods of its own. It only names this group of
 * operator bounds, and the blanket implementation in this module makes
 * every type satisfying them implement it automatically.
 */
pub trait NumericByValue<Rhs = Self, Output = Self>:
    Add<Rhs, Output = Output>
    + Sub<Rhs, Output = Output>
    + Mul<Rhs, Output = Output>
    + Div<Rhs, Output = Output>
    + Neg<Output = Output>
    + Sized
{
}
43
/**
 * Blanket implementation: anything which implements all the super traits will
 * automatically implement this trait too.
 * This covers primitives such as f32, f64, signed integers and
 * [Wrapped unsigned integers](std::num::Wrapping),
 * [Saturating integers](std::num::Saturating),
 * as well as [Traces](super::differentiation::Trace) and
 * [Records](super::differentiation::Record) of those types.
 *
 * NOTE(review): this impl requires `Neg`. The standard library appears to
 * implement `Neg` for `Saturating` only over signed integers, so confirm that
 * unsigned `Saturating` values are really covered before documenting them as such.
 *
 * It will not include Matrix because Matrix does not implement Div.
 * Similarly, unwrapped unsigned integers do not implement Neg so are not included.
 */
impl<T, Rhs, Output> NumericByValue<Rhs, Output> for T where
    // Div is first here because Matrix does not implement it.
    // If Add, Sub or Mul are first, the rust compiler gets stuck
    // in an infinite loop considering arbitrarily nested matrix
    // types, even though any level of nested Matrix types will
    // never implement Div so shouldn't be considered for
    // implementing NumericByValue. Do not reorder these bounds.
    T: Div<Rhs, Output = Output>
        + Add<Rhs, Output = Output>
        + Sub<Rhs, Output = Output>
        + Mul<Rhs, Output = Output>
        + Neg<Output = Output>
        + Sized
{
}
70
/**
 * The trait to define `&T op T` and `&T op &T` versions for NumericByValue,
 * based off the MIT/Apache 2.0 licensed code from num-traits 0.2.10:
 *
 * **Users of this library never use this trait directly. You
 * don't need to deal with it unless
 * [implementing custom numeric types](super::using_custom_types),
 * and even then it will be implemented automatically.**
 *
 * - [http://opensource.org/licenses/MIT](http://opensource.org/licenses/MIT)
 * - [https://docs.rs/num-traits/0.2.10/src/num_traits/lib.rs.html#112](https://docs.rs/num-traits/0.2.10/src/num_traits/lib.rs.html#112)
 *
 * The trick is that all types implementing this trait will be references
 * (`Self` is some `&T`). The first constraint therefore expresses some `&T`
 * which can be operated on with a right hand side of type `T` to yield a
 * value of type `T`.
 *
 * In a similar way, the second constraint uses a higher-ranked bound over
 * every lifetime `'a` to express `&T op &'a T -> T` operations.
 */
pub trait NumericRef<T>:
    // &T op T -> T
    NumericByValue<T, T>
    // &T op &T -> T
    + for<'a> NumericByValue<&'a T, T> {}
94
/**
 * Anything which implements all the super traits will automatically implement this trait too.
 * This covers references to primitives such as `&f32`, `&f64` and signed integers, ie a type
 * like `&i32` is `NumericRef<i32>`, as well as [Traces](super::differentiation::Trace) and
 * [Records](super::differentiation::Record) of those types.
 *
 * References to unwrapped unsigned integers such as `&u8` are not covered, because
 * unsigned integers do not implement Neg and so fail the NumericByValue bounds.
 */
impl<RefT, T> NumericRef<T> for RefT where
    RefT: NumericByValue<T, T> + for<'a> NumericByValue<&'a T, T>
{
}
105
/**
 * A general purpose numeric trait that defines all the behaviour numerical
 * matrices need their types to support for math operations.
 *
 * This trait extends the constraints in [NumericByValue]
 * to types which also support the operations with a right hand side type
 * by reference. It also adds some constraints needed only by value on
 * types: Clone, [ZeroOne], [FromUsize], Sum, PartialOrd and Debug.
 *
 * When used together with [NumericRef] this
 * expresses all 4 by value and by reference combinations for the
 * operations using the following syntax:
 *
 * ```ignore
 * fn function_name<T: Numeric>()
 * where for<'a> &'a T: NumericRef<T> {
 *
 * }
 * ```
 *
 * This pair of constraints is used nearly everywhere some numeric
 * type is needed. So although this trait does not require reference
 * type methods by itself, in practice you won't be able to call many
 * functions in this library with a numeric type that doesn't have them.
 */
pub trait Numeric:
    // T op T -> T
    NumericByValue
    // T op &T -> T
    + for<'a> NumericByValue<&'a Self>
    + Clone
    + ZeroOne
    + FromUsize
    + Sum
    + PartialOrd
    + Debug {}
142
/**
 * All types are Numeric if they implement the operations in NumericByValue both
 * by value and with a right hand side type by reference, and are also Clone,
 * ZeroOne, FromUsize, Sum, PartialOrd and Debug.
 *
 * This covers primitives such as f32, f64, signed integers and
 * [Wrapped unsigned integers](std::num::Wrapping),
 * [Saturating integers](std::num::Saturating),
 * as well as [Traces](super::differentiation::Trace) and
 * [Records](super::differentiation::Record) of those types.
 *
 * NOTE(review): Numeric requires Neg through NumericByValue. Confirm that
 * `Saturating` wrappers of unsigned integers implement Neg before documenting
 * them as covered, since std may only provide it for signed types.
 */
impl<T> Numeric for T where
    T: NumericByValue
        + for<'a> NumericByValue<&'a T>
        + Clone
        + ZeroOne
        + FromUsize
        + Sum
        + PartialOrd
        + Debug
{
}
164
/**
 * Gives access to the additive identity (zero) and the multiplicative
 * identity (one) of an implementing type.
 *
 * Primitive integer and float types get their implementations from the
 * macros below. The `Wrapping` and `Saturating` wrappers delegate to the
 * type they wrap. If a primitive type is missing from this list, please
 * open an issue to add it in.
 */
pub trait ZeroOne: Sized {
    fn zero() -> Self;
    fn one() -> Self;
}

// Implements ZeroOne for std numeric wrapper tuple structs by wrapping the
// inner type's zero and one.
macro_rules! zero_one_wrapper {
    ($($Wrapper:ident),+ $(,)?) => {
        $(
            impl<T: ZeroOne> ZeroOne for $Wrapper<T> {
                #[inline]
                fn zero() -> $Wrapper<T> {
                    $Wrapper(T::zero())
                }
                #[inline]
                fn one() -> $Wrapper<T> {
                    $Wrapper(T::one())
                }
            }
        )+
    };
}

// Implements ZeroOne for each listed primitive integer type using the
// integer literals 0 and 1. Accepts one or more comma separated types.
macro_rules! zero_one_integral {
    ($($T:ty),+ $(,)?) => {
        $(
            impl ZeroOne for $T {
                #[inline]
                fn zero() -> $T {
                    0
                }
                #[inline]
                fn one() -> $T {
                    1
                }
            }
        )+
    };
}

// Implements ZeroOne for each listed primitive float type using the float
// literals 0.0 and 1.0. Accepts one or more comma separated types.
macro_rules! zero_one_float {
    ($($T:ty),+ $(,)?) => {
        $(
            impl ZeroOne for $T {
                #[inline]
                fn zero() -> $T {
                    0.0
                }
                #[inline]
                fn one() -> $T {
                    1.0
                }
            }
        )+
    };
}

zero_one_wrapper!(Wrapping, Saturating);
zero_one_integral!(u8, i8, u16, i16, u32, i32, u64, i64, u128, i128, usize, isize);
zero_one_float!(f32, f64);
242
/**
 * Specifies how to obtain an instance of this numeric type
 * equal to a given usize value. If the number is too large to
 * represent in this type, `None` should be returned instead.
 *
 * The boilerplate implementations for primitives is performed with a macro.
 * If a primitive type is missing from this list, please open an issue to add it in.
 */
pub trait FromUsize: Sized {
    fn from_usize(n: usize) -> Option<Self>;
}

// The wrapper types convert through their inner type, so they are exactly as
// range-limited as that type.
impl<T: FromUsize> FromUsize for Wrapping<T> {
    #[inline]
    fn from_usize(n: usize) -> Option<Wrapping<T>> {
        T::from_usize(n).map(Wrapping)
    }
}

impl<T: FromUsize> FromUsize for Saturating<T> {
    #[inline]
    fn from_usize(n: usize) -> Option<Saturating<T>> {
        T::from_usize(n).map(Saturating)
    }
}

// Implements FromUsize for each listed primitive integer type.
// `TryFrom<usize>` performs an exact, platform independent range check. A
// comparison against `<$T>::MAX as usize` would silently truncate `MAX` for
// types wider than usize (such as `u128`, or `u64` on 32-bit targets), so this
// avoids relying on that truncation.
macro_rules! from_usize_integral {
    ($($T:ty),+ $(,)?) => {
        $(
            impl FromUsize for $T {
                #[inline]
                fn from_usize(n: usize) -> Option<$T> {
                    <$T as ::std::convert::TryFrom<usize>>::try_from(n).ok()
                }
            }
        )+
    };
}

// Implements FromUsize for each listed primitive float type.
// A usize to float `as` cast rounds to a representable value instead of
// failing, so large inputs may lose precision but every input yields `Some`.
macro_rules! from_usize_float {
    ($($T:ty),+ $(,)?) => {
        $(
            impl FromUsize for $T {
                #[inline]
                fn from_usize(n: usize) -> Option<$T> {
                    Some(n as $T)
                }
            }
        )+
    };
}

from_usize_integral!(u8, i8, u16, i16, u32, i32, u64, i64, u128, i128, usize, isize);
from_usize_float!(f32, f64);
307
308/**
309 * Additional traits for more complex numerical operations on real numbers.
310 */
311pub mod extra {
312    use crate::numeric::{Numeric, NumericByValue};
313
314    /**
315     * A type which can be square rooted.
316     *
317     * This is implemented by `f32` and `f64` by value and by reference, as well as
318     * [Traces](super::super::differentiation::Trace)
319     * and [Records](super::super::differentiation::Record) of these.
320     */
321    pub trait Sqrt {
322        type Output;
323        fn sqrt(self) -> Self::Output;
324    }
325
326    macro_rules! sqrt_float {
327        ($T:ty) => {
328            impl Sqrt for $T {
329                type Output = $T;
330                #[inline]
331                fn sqrt(self) -> Self::Output {
332                    self.sqrt()
333                }
334            }
335            impl Sqrt for &$T {
336                type Output = $T;
337                #[inline]
338                fn sqrt(self) -> Self::Output {
339                    self.clone().sqrt()
340                }
341            }
342        };
343    }
344
345    sqrt_float!(f32);
346    sqrt_float!(f64);
347
348    /**
349     * A type which can compute e^self.
350     *
351     * This is implemented by `f32` and `f64` by value and by reference, as well as
352     * [Traces](super::super::differentiation::Trace)
353     * and [Records](super::super::differentiation::Record) of these.
354     */
355    pub trait Exp {
356        type Output;
357        fn exp(self) -> Self::Output;
358    }
359
360    macro_rules! exp_float {
361        ($T:ty) => {
362            impl Exp for $T {
363                type Output = $T;
364                #[inline]
365                fn exp(self) -> Self::Output {
366                    self.exp()
367                }
368            }
369            impl Exp for &$T {
370                type Output = $T;
371                #[inline]
372                fn exp(self) -> Self::Output {
373                    self.clone().exp()
374                }
375            }
376        };
377    }
378
379    exp_float!(f32);
380    exp_float!(f64);
381
382    /**
383     * A type which can compute self^rhs.
384     *
385     * This is implemented by `f32` and `f64` for all combinations of
386     * by value and by reference, as well as
387     * [Traces](super::super::differentiation::Trace)
388     * and [Records](super::super::differentiation::Record) of these.
389     *
390     * The Trace and Record implementations also implement versions with the other
391     * argument being a raw `f32` or `f64`, for convenience.
392     */
393    pub trait Pow<Rhs = Self> {
394        type Output;
395        fn pow(self, rhs: Rhs) -> Self::Output;
396    }
397
398    macro_rules! pow_float {
399        ($T:ty) => {
400            // T ^ T
401            impl Pow<$T> for $T {
402                type Output = $T;
403                #[inline]
404                fn pow(self, rhs: Self) -> Self::Output {
405                    self.powf(rhs)
406                }
407            }
408            // T ^ &T
409            impl<'a> Pow<&'a $T> for $T {
410                type Output = $T;
411                #[inline]
412                fn pow(self, rhs: &Self) -> Self::Output {
413                    self.powf(rhs.clone())
414                }
415            }
416            // &T ^ T
417            impl<'a> Pow<$T> for &'a $T {
418                type Output = $T;
419                #[inline]
420                fn pow(self, rhs: $T) -> Self::Output {
421                    self.powf(rhs)
422                }
423            }
424            // &T ^ &T
425            impl<'a, 'b> Pow<&'b $T> for &'a $T {
426                type Output = $T;
427                #[inline]
428                fn pow(self, rhs: &$T) -> Self::Output {
429                    self.powf(rhs.clone())
430                }
431            }
432        };
433    }
434
435    pow_float!(f32);
436    pow_float!(f64);
437
438    /**
439     * A type which can represent Pi.
440     */
441    pub trait Pi {
442        fn pi() -> Self;
443    }
444
445    impl Pi for f32 {
446        fn pi() -> f32 {
447            std::f32::consts::PI
448        }
449    }
450
451    impl Pi for f64 {
452        fn pi() -> f64 {
453            std::f64::consts::PI
454        }
455    }
456
457    /**
458     * A type which can compute the natural logarithm of itself: ln(self).
459     *
460     * This is implemented by `f32` and `f64` by value and by reference, as well as
461     * [Traces](super::super::differentiation::Trace)
462     * and [Records](super::super::differentiation::Record) of these.
463     */
464    pub trait Ln {
465        type Output;
466        fn ln(self) -> Self::Output;
467    }
468
469    macro_rules! ln_float {
470        ($T:ty) => {
471            impl Ln for $T {
472                type Output = $T;
473                #[inline]
474                fn ln(self) -> Self::Output {
475                    self.ln()
476                }
477            }
478            impl Ln for &$T {
479                type Output = $T;
480                #[inline]
481                fn ln(self) -> Self::Output {
482                    self.clone().ln()
483                }
484            }
485        };
486    }
487
488    ln_float!(f32);
489    ln_float!(f64);
490
491    /**
492     * A type which can compute the sine of itself: sin(self)
493     *
494     * This is implemented by `f32` and `f64` by value and by reference, as well as
495     * [Traces](super::super::differentiation::Trace)
496     * and [Records](super::super::differentiation::Record) of these.
497     */
498    pub trait Sin {
499        type Output;
500        fn sin(self) -> Self::Output;
501    }
502
503    macro_rules! sin_float {
504        ($T:ty) => {
505            impl Sin for $T {
506                type Output = $T;
507                #[inline]
508                fn sin(self) -> Self::Output {
509                    self.sin()
510                }
511            }
512            impl Sin for &$T {
513                type Output = $T;
514                #[inline]
515                fn sin(self) -> Self::Output {
516                    self.clone().sin()
517                }
518            }
519        };
520    }
521
522    sin_float!(f32);
523    sin_float!(f64);
524
525    /**
526     * A type which can compute the cosine of itself: cos(self)
527     *
528     * This is implemented by `f32` and `f64` by value and by reference, as well as
529     * [Traces](super::super::differentiation::Trace)
530     * and [Records](super::super::differentiation::Record) of these.
531     */
532    pub trait Cos {
533        type Output;
534        fn cos(self) -> Self::Output;
535    }
536
537    macro_rules! cos_float {
538        ($T:ty) => {
539            impl Cos for $T {
540                type Output = $T;
541                #[inline]
542                fn cos(self) -> Self::Output {
543                    self.cos()
544                }
545            }
546            impl Cos for &$T {
547                type Output = $T;
548                #[inline]
549                fn cos(self) -> Self::Output {
550                    self.clone().cos()
551                }
552            }
553        };
554    }
555
556    cos_float!(f32);
557    cos_float!(f64);
558
559    /**
560     * A trait defining what a real number type is in terms of by value
561     * numerical operations needed on top of operations defined by Numeric
562     * for some functions.
563     *
564     * The requirements on top of [Numeric] are Sqrt, Exp, Pow, Ln, Sin, Cos and Sized.
565     */
566    pub trait RealByValue<Rhs = Self, Output = Self>:
567        Sqrt<Output = Output>
568        + Exp<Output = Output>
569        + Pow<Rhs, Output = Output>
570        + Ln<Output = Output>
571        + Sin<Output = Output>
572        + Cos<Output = Output>
573        + Sized
574        + NumericByValue<Rhs, Output>
575    {
576    }
577
578    /**
579     * Anything which implements all the super traits will automatically implement this trait too.
580     * This covers primitives such as f32 & f64 as well as
581     * [Traces](super::super::differentiation::Trace) and
582     * [Records](super::super::differentiation::Record) of those types.
583     */
584    impl<T, Rhs, Output> RealByValue<Rhs, Output> for T where
585        T: Sqrt<Output = Output>
586            + Exp<Output = Output>
587            + Pow<Rhs, Output = Output>
588            + Ln<Output = Output>
589            + Sin<Output = Output>
590            + Cos<Output = Output>
591            + Sized
592            + NumericByValue<Rhs, Output>
593    {
594    }
595
596    /**
597     * The trait to define `&T op T` and `&T op &T` versions for RealByValue
598     * based off the MIT/Apache 2.0 licensed code from num-traits 0.2.10:
599     *
600     * **This trait is not ever used directly for users of this library. You
601     * don't need to deal with it unless
602     * [implementing custom numeric types](super::super::using_custom_types)
603     * and even then it will be implemented automatically.**
604     *
605     * - [http://opensource.org/licenses/MIT](http://opensource.org/licenses/MIT)
606     * - [https://docs.rs/num-traits/0.2.10/src/num_traits/lib.rs.html#112](https://docs.rs/num-traits/0.2.10/src/num_traits/lib.rs.html#112)
607     *
608     * The trick is that all types implementing this trait will be references,
609     * so the first constraint expresses some &T which can be operated on with
610     * some right hand side type T to yield a value of type T.
611     *
612     * In a similar way the second constraint expresses `&T op &T -> T` operations
613     */
614    pub trait RealRef<T>:
615    // &T op T -> T
616    RealByValue<T, T>
617    // &T op &T -> T
618    + for<'a> RealByValue<&'a T, T> {}
619
620    /**
621     * Anything which implements all the super traits will automatically implement this trait too.
622     * This covers primitives such as `&f32` & `&f64`, ie a type like `&f64` is `RealRef<&f64>`
623     * as well as [Traces](super::super::differentiation::Trace) and
624     * [Records](super::super::differentiation::Record) of those types.
625     */
626    impl<RefT, T> RealRef<T> for RefT where RefT: RealByValue<T, T> + for<'a> RealByValue<&'a T, T> {}
627
628    /**
629     * A general purpose extension to the numeric trait that adds many operations needed
630     * for more complex math operations.
631     *
632     * This trait extends the constraints in [RealByValue]
633     * to types which also support the operations with a right hand side type
634     * by reference, and adds some additional constraints needed only
635     * by value on types.
636     *
637     * When used together with [RealRef] this
638     * expresses all 4 by value and by reference combinations for the
639     * operations using the following syntax:
640     *
641     * ```ignore
642     * fn function_name<T: Real>()
643     * where for<'a> &'a T: RealRef<T> {
644     *
645     * }
646     * ```
647     *
648     * This pair of constraints is used where any real number type is needed, so although this
649     * trait does not require reference type methods by itself, in practise you won’t be able to
650     * call many functions in this library with a real type that doesn’t.
651     *
652     * In version 2.0 of Easy ML it now inherits from [Numeric] directly, old code depending on a
653     * previous version of Easy ML that also specified the Numeric traits such as:
654     *
655     * ```ignore
656     * fn function_name<T: Numeric + Real>()
657     * where for<'a> &'a T: NumericRef<T> + RealRef<T> {
658     *
659     * }
660     * ```
661     *
662     * can be updated when using Easy ML 2.0 or later to the following:
663     *
664     * ```ignore
665     * fn function_name<T: Real>()
666     * where for<'a> &'a T: RealRef<T> {
667     *
668     * }
669     * ```
670     */
671    pub trait Real:
672    // T op T -> T
673    RealByValue
674    // T op &T -> T
675    + for<'a> RealByValue<&'a Self>
676    + Pi
677    + Numeric {}
678
679    /**
680     * All types implemeting the operations in RealByValue with a right hand
681     * side type by reference are Real.
682     *
683     * This covers primitives such as f32 & f64 as well as
684     * [Traces](super::super::differentiation::Trace) and
685     * [Records](super::super::differentiation::Record) of those types.
686     */
687    impl<T> Real for T where T: RealByValue + for<'a> RealByValue<&'a T> + Pi + Numeric {}
688}