easy_ml/numeric.rs
1/*!
2 * Numerical type definitions.
3 *
4 * `Numeric` together with `where for<'a> &'a T: NumericRef<T>`
5 * expresses the operations in [`NumericByValue`] for
 * all 4 combinations of by value and by reference. [`Numeric`]
 * also adds some constraints that are only needed by value on an implementing
 * type, such as `PartialOrd`, [`ZeroOne`] and
 * [`FromUsize`].
10 *
11 * For additional operations for real valued numbers see [Real](crate::numeric::extra::Real)
12 */
13
14use std::cmp::PartialOrd;
15use std::fmt::Debug;
16use std::iter::Sum;
17use std::marker::Sized;
18use std::num::{Saturating, Wrapping};
19use std::ops::Add;
20use std::ops::Div;
21use std::ops::Mul;
22use std::ops::Neg;
23use std::ops::Sub;
24
/**
 * A trait defining what a numeric type is in terms of by value
 * numerical operations matrices need their types to support for
 * math operations.
 *
 * The requirements are Add, Sub, Mul, Div, Neg and Sized. Note that
 * unsigned integers do not implement Neg unless they are wrapped by
 * [Wrapping].
 *
 * `Rhs` is the right hand side type accepted by the four binary operators
 * and `Output` is the type that all five operators must produce. Both default
 * to `Self`, so a plain `NumericByValue` bound expresses `T op T -> T`.
 *
 * The trait has no methods of its own; it only bundles the operator traits.
 */
pub trait NumericByValue<Rhs = Self, Output = Self>:
    Add<Rhs, Output = Output>
    + Sub<Rhs, Output = Output>
    + Mul<Rhs, Output = Output>
    + Div<Rhs, Output = Output>
    + Neg<Output = Output>
    + Sized
{
}
43
/**
 * Anything which implements all the super traits will automatically implement this trait too.
 * This covers primitives such as f32, f64, signed integers and
 * [Wrapped unsigned integers](std::num::Wrapping),
 * [Saturating unsigned integers](std::num::Saturating),
 * as well as [Traces](super::differentiation::Trace) and
 * [Records](super::differentiation::Record) of those types.
 *
 * It will not include Matrix because Matrix does not implement Div.
 * Similarly, unwrapped unsigned integers do not implement Neg so are not included.
 */
// NOTE(review): std appears to implement `Neg` only for `Saturating` of signed
// integers. Confirm that unsigned `Saturating` types really satisfy the bounds
// below before relying on the "Saturating unsigned integers" claim above.
impl<T, Rhs, Output> NumericByValue<Rhs, Output> for T where
    // Div is first here because Matrix does not implement it.
    // If Add, Sub or Mul are first the rust compiler gets stuck
    // in an infinite loop considering arbitrarily nested matrix
    // types, even though any level of nested Matrix types will
    // never implement Div so shouldn't be considered for
    // implementing NumericByValue.
    // Do not reorder these bounds without re-checking compile behaviour.
    T: Div<Rhs, Output = Output>
        + Add<Rhs, Output = Output>
        + Sub<Rhs, Output = Output>
        + Mul<Rhs, Output = Output>
        + Neg<Output = Output>
        + Sized
{
}
70
/**
 * The trait to define `&T op T` and `&T op &T` versions for NumericByValue,
 * based off the MIT/Apache 2.0 licensed code from num-traits 0.2.10:
 *
 * - [http://opensource.org/licenses/MIT](http://opensource.org/licenses/MIT)
 * - [https://docs.rs/num-traits/0.2.10/src/num_traits/lib.rs.html#112](https://docs.rs/num-traits/0.2.10/src/num_traits/lib.rs.html#112)
 *
 * **This trait is not ever used directly for users of this library. You
 * don't need to deal with it unless
 * [implementing custom numeric types](super::using_custom_types)
 * and even then it will be implemented automatically.**
 *
 * The trick is that all types implementing this trait will be references,
 * so the first constraint expresses some &T which can be operated on with
 * some right hand side type T to yield a value of type T.
 *
 * In a similar way the second constraint expresses `&T op &T -> T` operations.
 * The second constraint is higher ranked (`for<'a>`), so it must hold for a
 * right hand side reference of any lifetime.
 */
pub trait NumericRef<T>:
    // &T op T -> T
    NumericByValue<T, T>
    // &T op &T -> T
    + for<'a> NumericByValue<&'a T, T> {}
94
/**
 * Anything which implements all the super traits will automatically implement this trait too.
 * This covers references to primitives such as `&f32` and `&f64`, ie a type like `&i8` is
 * `NumericRef<i8>`, as well as [Traces](super::differentiation::Trace) and
 * [Records](super::differentiation::Record) of those types.
 *
 * References to unwrapped unsigned integers such as `&u8` are not covered, because
 * [NumericByValue] requires Neg.
 */
impl<RefT, T> NumericRef<T> for RefT where
    RefT: NumericByValue<T, T> + for<'a> NumericByValue<&'a T, T>
{
}
105
/**
 * A general purpose numeric trait that defines all the behaviour numerical
 * matrices need their types to support for math operations.
 *
 * This trait extends the constraints in [NumericByValue]
 * to types which also support the operations with a right hand side type
 * by reference, and adds some additional constraints needed only
 * by value on types: `Clone`, [ZeroOne], [FromUsize], `Sum`, `PartialOrd`
 * and `Debug`.
 *
 * When used together with [NumericRef] this
 * expresses all 4 by value and by reference combinations for the
 * operations using the following syntax:
 *
 * ```ignore
 * fn function_name<T: Numeric>()
 * where for<'a> &'a T: NumericRef<T> {
 *
 * }
 * ```
 *
 * This pair of constraints is used nearly everywhere some numeric
 * type is needed, so although this trait does not require reference
 * type methods by itself, in practice you won't be able to call many
 * functions in this library with a numeric type that doesn't.
 */
pub trait Numeric:
    // T op T -> T
    NumericByValue
    // T op &T -> T
    + for<'a> NumericByValue<&'a Self>
    + Clone
    + ZeroOne
    + FromUsize
    + Sum
    + PartialOrd
    + Debug {}
142
/**
 * All types implementing the operations in NumericByValue by value and with a
 * right hand side type by reference, which are also `Clone`, [ZeroOne],
 * [FromUsize], `Sum`, `PartialOrd` and `Debug`, are Numeric.
 *
 * This covers primitives such as f32, f64, signed integers and
 * [Wrapped unsigned integers](std::num::Wrapping),
 * [Saturating unsigned integers](std::num::Saturating),
 * as well as [Traces](super::differentiation::Trace) and
 * [Records](super::differentiation::Record) of those types.
 */
// NOTE(review): NumericByValue requires Neg, and std appears to implement Neg
// only for Saturating of signed integers. Confirm the "Saturating unsigned
// integers" claim above.
impl<T> Numeric for T where
    T: NumericByValue
        + for<'a> NumericByValue<&'a T>
        + Clone
        + ZeroOne
        + FromUsize
        + Sum
        + PartialOrd
        + Debug
{
}
164
/**
 * A trait defining how to obtain 0 and 1 for every implementing type.
 *
 * The boilerplate implementations for primitives are generated with macros.
 * If a primitive type is missing from this list, please open an issue to add it in.
 */
pub trait ZeroOne: Sized {
    /// Returns this type's value for 0.
    fn zero() -> Self;
    /// Returns this type's value for 1.
    fn one() -> Self;
}

/// A `Wrapping<T>` obtains its 0 and 1 by wrapping the inner type's 0 and 1.
impl<T: ZeroOne> ZeroOne for Wrapping<T> {
    #[inline]
    fn zero() -> Self {
        Wrapping(<T as ZeroOne>::zero())
    }
    #[inline]
    fn one() -> Self {
        Wrapping(<T as ZeroOne>::one())
    }
}

/// A `Saturating<T>` obtains its 0 and 1 by wrapping the inner type's 0 and 1.
impl<T: ZeroOne> ZeroOne for Saturating<T> {
    #[inline]
    fn zero() -> Self {
        Saturating(<T as ZeroOne>::zero())
    }
    #[inline]
    fn one() -> Self {
        Saturating(<T as ZeroOne>::one())
    }
}

// Generates `ZeroOne` for an integer primitive using the literals `0` and `1`.
macro_rules! zero_one_integral {
    ($T:ty) => {
        impl ZeroOne for $T {
            #[inline]
            fn zero() -> Self {
                0
            }
            #[inline]
            fn one() -> Self {
                1
            }
        }
    };
}

// Generates `ZeroOne` for a floating point primitive using the literals `0.0` and `1.0`.
macro_rules! zero_one_float {
    ($T:ty) => {
        impl ZeroOne for $T {
            #[inline]
            fn zero() -> Self {
                0.0
            }
            #[inline]
            fn one() -> Self {
                1.0
            }
        }
    };
}

// Unsigned integers.
zero_one_integral!(u8);
zero_one_integral!(u16);
zero_one_integral!(u32);
zero_one_integral!(u64);
zero_one_integral!(u128);
zero_one_integral!(usize);
// Signed integers.
zero_one_integral!(i8);
zero_one_integral!(i16);
zero_one_integral!(i32);
zero_one_integral!(i64);
zero_one_integral!(i128);
zero_one_integral!(isize);
// Floating point numbers.
zero_one_float!(f32);
zero_one_float!(f64);
242
/**
 * Specifies how to obtain an instance of this numeric type
 * equal to the usize primitive. If the number is too large to
 * represent in this type, `None` should be returned instead.
 *
 * The boilerplate implementations for primitives are generated with macros.
 * If a primitive type is missing from this list, please open an issue to add it in.
 */
pub trait FromUsize: Sized {
    /// Converts `n` into this type, or returns `None` if `n` cannot be represented.
    fn from_usize(n: usize) -> Option<Self>;
}

/// Converts via the inner type and wraps the result, so `None` is returned
/// exactly when the inner type returns `None`.
impl<T: FromUsize> FromUsize for Wrapping<T> {
    #[inline]
    fn from_usize(n: usize) -> Option<Self> {
        T::from_usize(n).map(Wrapping)
    }
}

/// Converts via the inner type and wraps the result, so `None` is returned
/// exactly when the inner type returns `None`.
impl<T: FromUsize> FromUsize for Saturating<T> {
    #[inline]
    fn from_usize(n: usize) -> Option<Self> {
        T::from_usize(n).map(Saturating)
    }
}

// Generates `FromUsize` for an integer primitive.
macro_rules! from_usize_integral {
    ($T:ty) => {
        impl FromUsize for $T {
            #[inline]
            fn from_usize(n: usize) -> Option<$T> {
                // A checked conversion yields `None` for any `n` that does not
                // fit in the target type, without relying on how `as` casts
                // truncate `MAX` for types wider than usize.
                <$T as ::std::convert::TryFrom<usize>>::try_from(n).ok()
            }
        }
    };
}

// Generates `FromUsize` for a floating point primitive.
macro_rules! from_usize_float {
    ($T:ty) => {
        impl FromUsize for $T {
            #[inline]
            fn from_usize(n: usize) -> Option<$T> {
                // The `as` cast always produces a float, so this never returns
                // `None`. Values a float cannot hold exactly are rounded.
                Some(n as $T)
            }
        }
    };
}

from_usize_integral!(u8);
from_usize_integral!(i8);
from_usize_integral!(u16);
from_usize_integral!(i16);
from_usize_integral!(u32);
from_usize_integral!(i32);
from_usize_integral!(u64);
from_usize_integral!(i64);
from_usize_integral!(u128);
from_usize_integral!(i128);
from_usize_float!(f32);
from_usize_float!(f64);
from_usize_integral!(usize);
from_usize_integral!(isize);
307
308/**
309 * Additional traits for more complex numerical operations on real numbers.
310 */
311pub mod extra {
312 use crate::numeric::{Numeric, NumericByValue};
313
314 /**
315 * A type which can be square rooted.
316 *
317 * This is implemented by `f32` and `f64` by value and by reference, as well as
318 * [Traces](super::super::differentiation::Trace)
319 * and [Records](super::super::differentiation::Record) of these.
320 */
321 pub trait Sqrt {
322 type Output;
323 fn sqrt(self) -> Self::Output;
324 }
325
326 macro_rules! sqrt_float {
327 ($T:ty) => {
328 impl Sqrt for $T {
329 type Output = $T;
330 #[inline]
331 fn sqrt(self) -> Self::Output {
332 self.sqrt()
333 }
334 }
335 impl Sqrt for &$T {
336 type Output = $T;
337 #[inline]
338 fn sqrt(self) -> Self::Output {
339 self.clone().sqrt()
340 }
341 }
342 };
343 }
344
345 sqrt_float!(f32);
346 sqrt_float!(f64);
347
348 /**
349 * A type which can compute e^self.
350 *
351 * This is implemented by `f32` and `f64` by value and by reference, as well as
352 * [Traces](super::super::differentiation::Trace)
353 * and [Records](super::super::differentiation::Record) of these.
354 */
355 pub trait Exp {
356 type Output;
357 fn exp(self) -> Self::Output;
358 }
359
360 macro_rules! exp_float {
361 ($T:ty) => {
362 impl Exp for $T {
363 type Output = $T;
364 #[inline]
365 fn exp(self) -> Self::Output {
366 self.exp()
367 }
368 }
369 impl Exp for &$T {
370 type Output = $T;
371 #[inline]
372 fn exp(self) -> Self::Output {
373 self.clone().exp()
374 }
375 }
376 };
377 }
378
379 exp_float!(f32);
380 exp_float!(f64);
381
382 /**
383 * A type which can compute self^rhs.
384 *
385 * This is implemented by `f32` and `f64` for all combinations of
386 * by value and by reference, as well as
387 * [Traces](super::super::differentiation::Trace)
388 * and [Records](super::super::differentiation::Record) of these.
389 *
390 * The Trace and Record implementations also implement versions with the other
391 * argument being a raw `f32` or `f64`, for convenience.
392 */
393 pub trait Pow<Rhs = Self> {
394 type Output;
395 fn pow(self, rhs: Rhs) -> Self::Output;
396 }
397
398 macro_rules! pow_float {
399 ($T:ty) => {
400 // T ^ T
401 impl Pow<$T> for $T {
402 type Output = $T;
403 #[inline]
404 fn pow(self, rhs: Self) -> Self::Output {
405 self.powf(rhs)
406 }
407 }
408 // T ^ &T
409 impl<'a> Pow<&'a $T> for $T {
410 type Output = $T;
411 #[inline]
412 fn pow(self, rhs: &Self) -> Self::Output {
413 self.powf(rhs.clone())
414 }
415 }
416 // &T ^ T
417 impl<'a> Pow<$T> for &'a $T {
418 type Output = $T;
419 #[inline]
420 fn pow(self, rhs: $T) -> Self::Output {
421 self.powf(rhs)
422 }
423 }
424 // &T ^ &T
425 impl<'a, 'b> Pow<&'b $T> for &'a $T {
426 type Output = $T;
427 #[inline]
428 fn pow(self, rhs: &$T) -> Self::Output {
429 self.powf(rhs.clone())
430 }
431 }
432 };
433 }
434
435 pow_float!(f32);
436 pow_float!(f64);
437
438 /**
439 * A type which can represent Pi.
440 */
441 pub trait Pi {
442 fn pi() -> Self;
443 }
444
445 impl Pi for f32 {
446 fn pi() -> f32 {
447 std::f32::consts::PI
448 }
449 }
450
451 impl Pi for f64 {
452 fn pi() -> f64 {
453 std::f64::consts::PI
454 }
455 }
456
457 /**
458 * A type which can compute the natural logarithm of itself: ln(self).
459 *
460 * This is implemented by `f32` and `f64` by value and by reference, as well as
461 * [Traces](super::super::differentiation::Trace)
462 * and [Records](super::super::differentiation::Record) of these.
463 */
464 pub trait Ln {
465 type Output;
466 fn ln(self) -> Self::Output;
467 }
468
469 macro_rules! ln_float {
470 ($T:ty) => {
471 impl Ln for $T {
472 type Output = $T;
473 #[inline]
474 fn ln(self) -> Self::Output {
475 self.ln()
476 }
477 }
478 impl Ln for &$T {
479 type Output = $T;
480 #[inline]
481 fn ln(self) -> Self::Output {
482 self.clone().ln()
483 }
484 }
485 };
486 }
487
488 ln_float!(f32);
489 ln_float!(f64);
490
491 /**
492 * A type which can compute the sine of itself: sin(self)
493 *
494 * This is implemented by `f32` and `f64` by value and by reference, as well as
495 * [Traces](super::super::differentiation::Trace)
496 * and [Records](super::super::differentiation::Record) of these.
497 */
498 pub trait Sin {
499 type Output;
500 fn sin(self) -> Self::Output;
501 }
502
503 macro_rules! sin_float {
504 ($T:ty) => {
505 impl Sin for $T {
506 type Output = $T;
507 #[inline]
508 fn sin(self) -> Self::Output {
509 self.sin()
510 }
511 }
512 impl Sin for &$T {
513 type Output = $T;
514 #[inline]
515 fn sin(self) -> Self::Output {
516 self.clone().sin()
517 }
518 }
519 };
520 }
521
522 sin_float!(f32);
523 sin_float!(f64);
524
525 /**
526 * A type which can compute the cosine of itself: cos(self)
527 *
528 * This is implemented by `f32` and `f64` by value and by reference, as well as
529 * [Traces](super::super::differentiation::Trace)
530 * and [Records](super::super::differentiation::Record) of these.
531 */
532 pub trait Cos {
533 type Output;
534 fn cos(self) -> Self::Output;
535 }
536
537 macro_rules! cos_float {
538 ($T:ty) => {
539 impl Cos for $T {
540 type Output = $T;
541 #[inline]
542 fn cos(self) -> Self::Output {
543 self.cos()
544 }
545 }
546 impl Cos for &$T {
547 type Output = $T;
548 #[inline]
549 fn cos(self) -> Self::Output {
550 self.clone().cos()
551 }
552 }
553 };
554 }
555
556 cos_float!(f32);
557 cos_float!(f64);
558
    /**
     * A trait defining what a real number type is in terms of by value
     * numerical operations needed on top of operations defined by [NumericByValue]
     * for some functions.
     *
     * The requirements on top of [NumericByValue] are Sqrt, Exp, Pow, Ln, Sin, Cos and Sized.
     *
     * `Rhs` is the right hand side type used for [Pow] and passed through to
     * [NumericByValue]. `Output` is the type every operation must produce.
     * Both default to `Self`.
     */
    pub trait RealByValue<Rhs = Self, Output = Self>:
        Sqrt<Output = Output>
        + Exp<Output = Output>
        + Pow<Rhs, Output = Output>
        + Ln<Output = Output>
        + Sin<Output = Output>
        + Cos<Output = Output>
        + Sized
        + NumericByValue<Rhs, Output>
    {
    }
577
    /**
     * Anything which implements all the super traits will automatically implement this trait too.
     * This covers primitives such as f32 & f64 as well as
     * [Traces](super::super::differentiation::Trace) and
     * [Records](super::super::differentiation::Record) of those types.
     *
     * The bounds mirror the super traits of [RealByValue] exactly, so no type
     * needs to implement it by hand.
     */
    impl<T, Rhs, Output> RealByValue<Rhs, Output> for T where
        T: Sqrt<Output = Output>
            + Exp<Output = Output>
            + Pow<Rhs, Output = Output>
            + Ln<Output = Output>
            + Sin<Output = Output>
            + Cos<Output = Output>
            + Sized
            + NumericByValue<Rhs, Output>
    {
    }
595
    /**
     * The trait to define `&T op T` and `&T op &T` versions for RealByValue,
     * based off the MIT/Apache 2.0 licensed code from num-traits 0.2.10:
     *
     * - [http://opensource.org/licenses/MIT](http://opensource.org/licenses/MIT)
     * - [https://docs.rs/num-traits/0.2.10/src/num_traits/lib.rs.html#112](https://docs.rs/num-traits/0.2.10/src/num_traits/lib.rs.html#112)
     *
     * **This trait is not ever used directly for users of this library. You
     * don't need to deal with it unless
     * [implementing custom numeric types](super::super::using_custom_types)
     * and even then it will be implemented automatically.**
     *
     * The trick is that all types implementing this trait will be references,
     * so the first constraint expresses some &T which can be operated on with
     * some right hand side type T to yield a value of type T.
     *
     * In a similar way the second constraint expresses `&T op &T -> T` operations.
     * The second constraint is higher ranked (`for<'a>`), so it must hold for a
     * right hand side reference of any lifetime.
     */
    pub trait RealRef<T>:
        // &T op T -> T
        RealByValue<T, T>
        // &T op &T -> T
        + for<'a> RealByValue<&'a T, T> {}
619
    /**
     * Anything which implements all the super traits will automatically implement this trait too.
     * This covers references to primitives such as `&f32` & `&f64`, ie a type like `&f64` is
     * `RealRef<f64>`, as well as [Traces](super::super::differentiation::Trace) and
     * [Records](super::super::differentiation::Record) of those types.
     */
    impl<RefT, T> RealRef<T> for RefT where RefT: RealByValue<T, T> + for<'a> RealByValue<&'a T, T> {}
627
    /**
     * A general purpose extension to the numeric trait that adds many operations needed
     * for more complex math operations.
     *
     * This trait extends the constraints in [RealByValue]
     * to types which also support the operations with a right hand side type
     * by reference, and adds some additional constraints needed only
     * by value on types: [Pi] and [Numeric].
     *
     * When used together with [RealRef] this
     * expresses all 4 by value and by reference combinations for the
     * operations using the following syntax:
     *
     * ```ignore
     * fn function_name<T: Real>()
     * where for<'a> &'a T: RealRef<T> {
     *
     * }
     * ```
     *
     * This pair of constraints is used where any real number type is needed, so although this
     * trait does not require reference type methods by itself, in practice you won't be able to
     * call many functions in this library with a real type that doesn't.
     *
     * In version 2.0 of Easy ML it now inherits from [Numeric] directly, old code depending on a
     * previous version of Easy ML that also specified the Numeric traits such as:
     *
     * ```ignore
     * fn function_name<T: Numeric + Real>()
     * where for<'a> &'a T: NumericRef<T> + RealRef<T> {
     *
     * }
     * ```
     *
     * can be updated when using Easy ML 2.0 or later to the following:
     *
     * ```ignore
     * fn function_name<T: Real>()
     * where for<'a> &'a T: RealRef<T> {
     *
     * }
     * ```
     */
    pub trait Real:
        // T op T -> T
        RealByValue
        // T op &T -> T
        + for<'a> RealByValue<&'a Self>
        + Pi
        + Numeric {}
678
    /**
     * All types implementing the operations in RealByValue by value and with a
     * right hand side type by reference, which also implement [Pi] and [Numeric],
     * are Real.
     *
     * This covers primitives such as f32 & f64 as well as
     * [Traces](super::super::differentiation::Trace) and
     * [Records](super::super::differentiation::Record) of those types.
     */
    impl<T> Real for T where T: RealByValue + for<'a> RealByValue<&'a T> + Pi + Numeric {}
688}