easy_ml/differentiation/
record_operations.rs

1/*!
2 * Operator implementations for Records.
3 *
4 * These implementations are written here but Rust docs will display them on the
5 * [Record] struct page.
6 *
7 * Records of any Numeric type (provided the type also implements the operations by reference
8 * as described in the [numeric](super::super::numeric) module) implement all the standard
9 * library traits for addition, subtraction, multiplication and division, so you can
10 * use the normal `+ - * /` operators as you can with normal number types. As a convenience,
 * these operations can also be used with a Record on the left hand side and the same type
12 * that the Record is generic over on the right hand side, so you can do
13 *
14 * ```
15 * use easy_ml::differentiation::{Record, WengertList};
16 * let list = WengertList::new();
17 * let x: Record<f32> = Record::variable(2.0, &list);
18 * let y: f32 = 2.0;
19 * let z: Record<f32> = x * y;
20 * assert_eq!(z.number, 4.0);
21 * ```
22 *
23 * or more succinctly
24 *
25 * ```
26 * use easy_ml::differentiation::{Record, WengertList};
27 * assert_eq!((Record::variable(2.0, &WengertList::new()) * 2.0).number, 4.0);
28 * ```
29 *
 * Records of a [Real] type (provided the type also implements the operations by reference as
 * described in the [numeric::extra](super::super::numeric::extra) module) also implement
32 * all of those extra traits and operations. Note that to use a method defined in a trait
33 * you have to import the trait as well as have a type that implements it!
34 */
35
36use crate::differentiation::functions::{
37    Addition, Cosine, Division, Exponential, FunctionDerivative, Multiplication, NaturalLogarithm,
38    Power, Sine, SquareRoot, Subtraction, UnaryFunctionDerivative,
39};
40use crate::differentiation::{Primitive, Record, WengertList};
41use crate::numeric::extra::{Cos, Exp, Ln, Pi, Pow, Real, RealRef, Sin, Sqrt};
42use crate::numeric::{FromUsize, Numeric, NumericRef, ZeroOne};
43use std::cmp::Ordering;
44use std::iter::Sum;
45use std::ops::{Add, Div, Mul, Neg, Sub};
46
47/**
48 * A record is displayed by showing its number component.
49 */
50impl<'a, T: std::fmt::Display + Primitive> std::fmt::Display for Record<'a, T> {
51    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
52        write!(f, "{}", self.number)
53    }
54}
55
56/**
57 * Record implements ZeroOne by returning constants.
58 */
59impl<'a, T: Numeric + Primitive> ZeroOne for Record<'a, T> {
60    #[inline]
61    fn zero() -> Record<'a, T> {
62        Record::constant(T::zero())
63    }
64    #[inline]
65    fn one() -> Record<'a, T> {
66        Record::constant(T::one())
67    }
68}
69
70impl<'a, T: Numeric + Primitive> FromUsize for Record<'a, T> {
71    #[inline]
72    fn from_usize(n: usize) -> Option<Record<'a, T>> {
73        Some(Record::constant(T::from_usize(n)?))
74    }
75}
76
77/**
78 * Any record of a Cloneable type implements clone
79 */
80impl<'a, T: Clone + Primitive> Clone for Record<'a, T> {
81    #[inline]
82    fn clone(&self) -> Self {
83        Record {
84            number: self.number.clone(),
85            history: self.history,
86            index: self.index,
87        }
88    }
89}
90
91/**
92 * Any record of a Copy type implements Copy
93 */
94impl<'a, T: Copy + Primitive> Copy for Record<'a, T> {}
95
96/**
97 * Compares two record's referenced WengertLists.
98 *
99 * If either Record is missing a reference to a WengertList then
100 * this is trivially 'true', in so far as we will use the WengertList of
101 * the other one.
102 *
103 * If both records have a WengertList, then checks that the lists are
104 * the same.
105 */
106pub(crate) fn same_list<T: Primitive>(a: &Record<T>, b: &Record<T>) -> bool {
107    match (a.history, b.history) {
108        (None, None) => true,
109        (Some(_), None) => true,
110        (None, Some(_)) => true,
111        (Some(list_a), Some(list_b)) => same_lists(list_a, list_b),
112    }
113}
114
115/// Compares two WengertList references directly.
116pub(crate) fn same_lists<T: Primitive>(list_a: &WengertList<T>, list_b: &WengertList<T>) -> bool {
117    std::ptr::eq(list_a, list_b)
118}
119
120/// Compares two Options of WengertList references directly.
121pub(crate) fn are_same_list<T: Primitive>(
122    list_a: Option<&WengertList<T>>,
123    list_b: Option<&WengertList<T>>,
124) -> bool {
125    match (list_a, list_b) {
126        (None, None) => true,
127        (Some(_), None) => true,
128        (None, Some(_)) => true,
129        (Some(list_a), Some(list_b)) => same_lists(list_a, list_b),
130    }
131}
132
133/// Compares two Options of WengertList references directly, returning false when exactly one is
134/// a constant.
135pub(crate) fn are_exact_same_list<T: Primitive>(
136    list_a: Option<&WengertList<T>>,
137    list_b: Option<&WengertList<T>>,
138) -> bool {
139    match (list_a, list_b) {
140        (None, None) => true,
141        (Some(_), None) => false,
142        (None, Some(_)) => false,
143        (Some(list_a), Some(list_b)) => same_lists(list_a, list_b),
144    }
145}
146
147/**
148 * Addition for two records of the same type with both referenced and
149 * both using the same WengertList.
150 */
151impl<'a, 'l, 'r, T: Numeric + Primitive> Add<&'r Record<'a, T>> for &'l Record<'a, T>
152where
153    for<'t> &'t T: NumericRef<T>,
154{
155    type Output = Record<'a, T>;
156    #[track_caller]
157    #[inline]
158    fn add(self, rhs: &Record<'a, T>) -> Self::Output {
159        assert!(
160            same_list(self, rhs),
161            "Records must be using the same WengertList"
162        );
163        match (self.history, rhs.history) {
164            // If neither inputs have a WengertList then we don't need to record
165            // the computation graph at this point because neither are inputs to
166            // the overall function.
167            // eg f(x, y) = ((1 + 1) * x) + (2 * (1 + y)) needs the records
168            // for 2x + (2 * (1 + y)) to be stored, but we don't care about the derivatives
169            // for 1 + 1, because neither were inputs to f.
170            (None, None) => Record {
171                number: Addition::<T>::function(self.number.clone(), rhs.number.clone()),
172                history: None,
173                index: 0,
174            },
175            // If only one input has a WengertList treat the other as a constant
176            (Some(_), None) => self + &rhs.number,
177            (None, Some(_)) => rhs + &self.number,
178            (Some(history), Some(_)) => Record {
179                number: Addition::<T>::function(self.number.clone(), rhs.number.clone()),
180                history: Some(history),
181                index: history.append_binary(
182                    self.index,
183                    Addition::<T>::d_function_dx(self.number.clone(), rhs.number.clone()),
184                    rhs.index,
185                    Addition::<T>::d_function_dy(self.number.clone(), rhs.number.clone()),
186                ),
187            },
188        }
189    }
190}
191
192/**
193 * Addition for a record and a constant of the same type with both referenced.
194 */
195impl<'a, T: Numeric + Primitive> Add<&T> for &Record<'a, T>
196where
197    for<'t> &'t T: NumericRef<T>,
198{
199    type Output = Record<'a, T>;
200    #[track_caller]
201    #[inline]
202    fn add(self, rhs: &T) -> Self::Output {
203        match self.history {
204            None => Record {
205                number: Addition::<T>::function(self.number.clone(), rhs.clone()),
206                history: None,
207                index: 0,
208            },
209            Some(history) => Record {
210                number: Addition::<T>::function(self.number.clone(), rhs.clone()),
211                history: Some(history),
212                index: history.append_unary(
213                    self.index,
214                    Addition::<T>::d_function_dx(self.number.clone(), rhs.clone()),
215                ),
216            },
217        }
218    }
219}
220
221macro_rules! record_operator_impl_value_value {
222    (impl $op:tt for Record { fn $method:ident }) => {
223        /**
224         * Operation for two records of the same type.
225         */
226        impl<'a, T: Numeric + Primitive> $op for Record<'a, T>
227        where
228            for<'t> &'t T: NumericRef<T>,
229        {
230            type Output = Record<'a, T>;
231            #[track_caller]
232            #[inline]
233            fn $method(self, rhs: Record<'a, T>) -> Self::Output {
234                (&self).$method(&rhs)
235            }
236        }
237    };
238}
239
240macro_rules! record_operator_impl_value_reference {
241    (impl $op:tt for Record { fn $method:ident }) => {
242        /**
243         * Operation for two records of the same type with the right referenced.
244         */
245        impl<'a, T: Numeric + Primitive> $op<&Record<'a, T>> for Record<'a, T>
246        where
247            for<'t> &'t T: NumericRef<T>,
248        {
249            type Output = Record<'a, T>;
250            #[track_caller]
251            #[inline]
252            fn $method(self, rhs: &Record<'a, T>) -> Self::Output {
253                (&self).$method(rhs)
254            }
255        }
256    };
257}
258
259macro_rules! record_operator_impl_reference_value {
260    (impl $op:tt for Record { fn $method:ident }) => {
261        /**
262         * Operation for two records of the same type with the left referenced.
263         */
264        impl<'a, T: Numeric + Primitive> $op<Record<'a, T>> for &Record<'a, T>
265        where
266            for<'t> &'t T: NumericRef<T>,
267        {
268            type Output = Record<'a, T>;
269            #[track_caller]
270            #[inline]
271            fn $method(self, rhs: Record<'a, T>) -> Self::Output {
272                self.$method(&rhs)
273            }
274        }
275    };
276}
277
278record_operator_impl_value_value!(impl Add for Record { fn add });
279record_operator_impl_reference_value!(impl Add for Record { fn add });
280record_operator_impl_value_reference!(impl Add for Record { fn add });
281
282macro_rules! record_number_operator_impl_value_value {
283    (impl $op:tt for Record { fn $method:ident }) => {
284        /**
285         * Operation for a record and a constant of the same type.
286         */
287        impl<'a, T: Numeric + Primitive> $op<T> for Record<'a, T>
288        where
289            for<'t> &'t T: NumericRef<T>,
290        {
291            type Output = Record<'a, T>;
292            #[inline]
293            fn $method(self, rhs: T) -> Self::Output {
294                (&self).$method(&rhs)
295            }
296        }
297    };
298}
299
300macro_rules! record_number_operator_impl_value_reference {
301    (impl $op:tt for Record { fn $method:ident }) => {
302        /**
303         * Operation for a record and a constant of the same type with the right referenced.
304         */
305        impl<'a, T: Numeric + Primitive> $op<&T> for Record<'a, T>
306        where
307            for<'t> &'t T: NumericRef<T>,
308        {
309            type Output = Record<'a, T>;
310            #[inline]
311            fn $method(self, rhs: &T) -> Self::Output {
312                (&self).$method(rhs)
313            }
314        }
315    };
316}
317
318macro_rules! record_number_operator_impl_reference_value {
319    (impl $op:tt for Record { fn $method:ident }) => {
320        /**
321         * Operation for a record and a constant of the same type with the left referenced.
322         */
323        impl<'a, T: Numeric + Primitive> $op<T> for &Record<'a, T>
324        where
325            for<'t> &'t T: NumericRef<T>,
326        {
327            type Output = Record<'a, T>;
328            #[inline]
329            fn $method(self, rhs: T) -> Self::Output {
330                self.$method(&rhs)
331            }
332        }
333    };
334}
335
336record_number_operator_impl_value_value!(impl Add for Record { fn add });
337record_number_operator_impl_reference_value!(impl Add for Record { fn add });
338record_number_operator_impl_value_reference!(impl Add for Record { fn add });
339
340/**
341 * Multiplication for two records of the same type with both referenced and
342 * both using the same WengertList.
343 */
344impl<'a, 'l, 'r, T: Numeric + Primitive> Mul<&'r Record<'a, T>> for &'l Record<'a, T>
345where
346    for<'t> &'t T: NumericRef<T>,
347{
348    type Output = Record<'a, T>;
349    #[track_caller]
350    #[inline]
351    fn mul(self, rhs: &Record<'a, T>) -> Self::Output {
352        assert!(
353            same_list(self, rhs),
354            "Records must be using the same WengertList"
355        );
356        match (self.history, rhs.history) {
357            (None, None) => Record {
358                number: Multiplication::<T>::function(self.number.clone(), rhs.number.clone()),
359                history: None,
360                index: 0,
361            },
362            // If only one input has a WengertList treat the other as a constant
363            (Some(_), None) => self * &rhs.number,
364            (None, Some(_)) => rhs * &self.number,
365            (Some(history), Some(_)) => Record {
366                number: Multiplication::<T>::function(self.number.clone(), rhs.number.clone()),
367                history: Some(history),
368                index: history.append_binary(
369                    self.index,
370                    Multiplication::<T>::d_function_dx(self.number.clone(), rhs.number.clone()),
371                    rhs.index,
372                    Multiplication::<T>::d_function_dy(self.number.clone(), rhs.number.clone()),
373                ),
374            },
375        }
376    }
377}
378
379record_operator_impl_value_value!(impl Mul for Record { fn mul });
380record_operator_impl_reference_value!(impl Mul for Record { fn mul });
381record_operator_impl_value_reference!(impl Mul for Record { fn mul });
382
383/**
384 * Multiplication for a record and a constant of the same type with both referenced.
385 */
386impl<'a, T: Numeric + Primitive> Mul<&T> for &Record<'a, T>
387where
388    for<'t> &'t T: NumericRef<T>,
389{
390    type Output = Record<'a, T>;
391    #[track_caller]
392    #[inline]
393    fn mul(self, rhs: &T) -> Self::Output {
394        match self.history {
395            None => Record {
396                number: Multiplication::<T>::function(self.number.clone(), rhs.clone()),
397                history: None,
398                index: 0,
399            },
400            Some(history) => Record {
401                number: Multiplication::<T>::function(self.number.clone(), rhs.clone()),
402                history: Some(history),
403                index: history.append_unary(
404                    self.index,
405                    Multiplication::<T>::d_function_dx(self.number.clone(), rhs.clone()),
406                ),
407            },
408        }
409    }
410}
411
412record_number_operator_impl_value_value!(impl Mul for Record { fn mul });
413record_number_operator_impl_reference_value!(impl Mul for Record { fn mul });
414record_number_operator_impl_value_reference!(impl Mul for Record { fn mul });
415
416/**
417 * Subtraction for two records of the same type with both referenced and
418 * both using the same WengertList.
419 */
420impl<'a, 'l, 'r, T: Numeric + Primitive> Sub<&'r Record<'a, T>> for &'l Record<'a, T>
421where
422    for<'t> &'t T: NumericRef<T>,
423{
424    type Output = Record<'a, T>;
425    #[track_caller]
426    #[inline]
427    fn sub(self, rhs: &Record<'a, T>) -> Self::Output {
428        assert!(
429            same_list(self, rhs),
430            "Records must be using the same WengertList"
431        );
432        match (self.history, rhs.history) {
433            // If neither inputs have a WengertList then we don't need to record
434            // the computation graph at this point because neither are inputs to
435            // the overall function.
436            // eg f(x, y) = ((1 + 1) * x) + (2 * (1 + y)) needs the records
437            // for 2x + (2 * (1 + y)) to be stored, but we don't care about the derivatives
438            // for 1 + 1, because neither were inputs to f.
439            (None, None) => Record {
440                number: Subtraction::<T>::function(self.number.clone(), rhs.number.clone()),
441                history: None,
442                index: 0,
443            },
444            // If only one input has a WengertList treat the other as a constant
445            (Some(_), None) => self - &rhs.number,
446            // Record::constant can't be used here as that would cause an infinite loop,
447            // so use the swapped version of Sub
448            (None, Some(_)) => rhs.sub_swapped(self.number.clone()),
449            (Some(history), Some(_)) => Record {
450                number: Subtraction::<T>::function(self.number.clone(), rhs.number.clone()),
451                history: Some(history),
452                index: history.append_binary(
453                    self.index,
454                    Subtraction::<T>::d_function_dx(self.number.clone(), rhs.number.clone()),
455                    rhs.index,
456                    Subtraction::<T>::d_function_dy(self.number.clone(), rhs.number.clone()),
457                ),
458            },
459        }
460    }
461}
462
463record_operator_impl_value_value!(impl Sub for Record { fn sub });
464record_operator_impl_reference_value!(impl Sub for Record { fn sub });
465record_operator_impl_value_reference!(impl Sub for Record { fn sub });
466
467/**
468 * Subtraction for a record and a constant of the same type with both referenced.
469 */
470impl<'a, T: Numeric + Primitive> Sub<&T> for &Record<'a, T>
471where
472    for<'t> &'t T: NumericRef<T>,
473{
474    type Output = Record<'a, T>;
475    #[inline]
476    fn sub(self, rhs: &T) -> Self::Output {
477        match self.history {
478            None => Record {
479                number: Subtraction::<T>::function(self.number.clone(), rhs.clone()),
480                history: None,
481                index: 0,
482            },
483            Some(history) => Record {
484                number: Subtraction::<T>::function(self.number.clone(), rhs.clone()),
485                history: Some(history),
486                index: history.append_unary(
487                    self.index,
488                    Subtraction::<T>::d_function_dx(self.number.clone(), rhs.clone()),
489                ),
490            },
491        }
492    }
493}
494
495record_number_operator_impl_value_value!(impl Sub for Record { fn sub });
496record_number_operator_impl_reference_value!(impl Sub for Record { fn sub });
497record_number_operator_impl_value_reference!(impl Sub for Record { fn sub });
498
/**
 * A trait which defines subtraction and division with the arguments
 * swapped around, ie 5.sub_swapped(7) would equal 2. This trait is
 * only implemented for Records and constant operations.
 *
 * Addition and Multiplication are not included because argument order
 * doesn't matter for those operations, so you can just swap the left and
 * right and get the same result.
 *
 * Implementations for Trace are not included because you can just lift
 * a constant to a Trace with ease. While you can lift constants to Records
 * with ease too, these operations allow for the avoidance of storing the
 * constant on the WengertList which saves memory.
 *
 * ```
 * use easy_ml::differentiation::{Record, RecordTensor, WengertList};
 * use easy_ml::differentiation::record_operations::SwappedOperations;
 * use easy_ml::tensors::Tensor;
 *
 * let history = WengertList::new();
 *
 * let x = Record::variable(-1.0, &history);
 * let z = x.sub_swapped(10.0);
 * assert_eq!(z.number, 11.0);
 *
 * let X = RecordTensor::variables(
 *     &history,
 *     Tensor::from_fn([("x", 2), ("y", 2)], |[r, c]| ((r + 4) * (c + 1)) as f64)
 * );
 * let Z = X.div_swapped(100.0);
 * assert_eq!(
 *     Z.view().map(|(x, _)| x),
 *     Tensor::from([("x", 2), ("y", 2)], vec![ 25.0, 12.5, 20.0, 10.0 ])
 * );
 * ```
 */
pub trait SwappedOperations<Lhs = Self> {
    /// The type produced by both swapped operations.
    type Output;

    /// Computes `lhs - self`, with the receiver on the right hand side.
    fn sub_swapped(self, lhs: Lhs) -> Self::Output;

    /// Computes `lhs / self`, with the receiver on the right hand side.
    fn div_swapped(self, lhs: Lhs) -> Self::Output;
}
540
541impl<'a, T: Numeric + Primitive> SwappedOperations<&T> for &Record<'a, T>
542where
543    for<'t> &'t T: NumericRef<T>,
544{
545    type Output = Record<'a, T>;
546    /**
547     * Subtraction for a record and a constant, where the constant
548     * is the left hand side, ie C - record.
549     */
550    #[inline]
551    fn sub_swapped(self, lhs: &T) -> Self::Output {
552        match self.history {
553            None => Record {
554                number: Subtraction::<T>::function(lhs.clone(), self.number.clone()),
555                history: None,
556                index: 0,
557            },
558            Some(history) => {
559                Record {
560                    number: Subtraction::<T>::function(lhs.clone(), self.number.clone()),
561                    history: Some(history),
562                    index: history.append_unary(
563                        self.index,
564                        // We want with respect to y because it is the right hand side here that we
565                        // need the derivative for (since left is a constant).
566                        Subtraction::<T>::d_function_dy(lhs.clone(), self.number.clone()),
567                    ),
568                }
569            }
570        }
571    }
572
573    /**
574     * Division for a record and a constant, where the constant
575     * is the left hand side, ie C / record.
576     */
577    #[inline]
578    fn div_swapped(self, lhs: &T) -> Self::Output {
579        match self.history {
580            None => Record {
581                number: Division::<T>::function(lhs.clone(), self.number.clone()),
582                history: None,
583                index: 0,
584            },
585            Some(history) => {
586                Record {
587                    number: Division::<T>::function(lhs.clone(), self.number.clone()),
588                    history: Some(history),
589                    index: history.append_unary(
590                        self.index,
591                        // We want with respect to y because it is the right hand side here that we
592                        // need the derivative for (since left is a constant).
593                        Division::<T>::d_function_dy(lhs.clone(), self.number.clone()),
594                    ),
595                }
596            }
597        }
598    }
599}
600
601impl<'a, T: Numeric + Primitive> SwappedOperations<T> for &Record<'a, T>
602where
603    for<'t> &'t T: NumericRef<T>,
604{
605    type Output = Record<'a, T>;
606    /**
607     * Subtraction for a record and a constant, where the constant
608     * is the left hand side, ie C - record.
609     */
610    #[inline]
611    fn sub_swapped(self, lhs: T) -> Self::Output {
612        self.sub_swapped(&lhs)
613    }
614
615    /**
616     * Division for a record and a constant, where the constant
617     * is the left hand side, ie C / record.
618     */
619    #[inline]
620    fn div_swapped(self, lhs: T) -> Self::Output {
621        self.div_swapped(&lhs)
622    }
623}
624
625impl<'a, T: Numeric + Primitive> SwappedOperations<T> for Record<'a, T>
626where
627    for<'t> &'t T: NumericRef<T>,
628{
629    type Output = Record<'a, T>;
630    /**
631     * Subtraction for a record and a constant, where the constant
632     * is the left hand side, ie C - record.
633     */
634    #[inline]
635    fn sub_swapped(self, lhs: T) -> Self::Output {
636        (&self).sub_swapped(&lhs)
637    }
638
639    /**
640     * Division for a record and a constant, where the constant
641     * is the left hand side, ie C / record.
642     */
643    #[inline]
644    fn div_swapped(self, lhs: T) -> Self::Output {
645        (&self).div_swapped(&lhs)
646    }
647}
648
649impl<'a, T: Numeric + Primitive> SwappedOperations<&T> for Record<'a, T>
650where
651    for<'t> &'t T: NumericRef<T>,
652{
653    type Output = Record<'a, T>;
654    /**
655     * Subtraction for a record and a constant, where the constant
656     * is the left hand side, ie C - record.
657     */
658    #[inline]
659    fn sub_swapped(self, lhs: &T) -> Self::Output {
660        (&self).sub_swapped(lhs)
661    }
662
663    /**
664     * Division for a record and a constant, where the constant
665     * is the left hand side, ie C / record.
666     */
667    #[inline]
668    fn div_swapped(self, lhs: &T) -> Self::Output {
669        (&self).div_swapped(lhs)
670    }
671}
672
673/**
674 * Dvision for two records of the same type with both referenced and
675 * both using the same WengertList.
676 */
677impl<'a, 'l, 'r, T: Numeric + Primitive> Div<&'r Record<'a, T>> for &'l Record<'a, T>
678where
679    for<'t> &'t T: NumericRef<T>,
680{
681    type Output = Record<'a, T>;
682    #[track_caller]
683    #[inline]
684    fn div(self, rhs: &Record<'a, T>) -> Self::Output {
685        assert!(
686            same_list(self, rhs),
687            "Records must be using the same WengertList"
688        );
689        match (self.history, rhs.history) {
690            (None, None) => Record {
691                number: Division::<T>::function(self.number.clone(), rhs.number.clone()),
692                history: None,
693                index: 0,
694            },
695            // If only one input has a WengertList treat the other as a constant
696            (Some(_), None) => self / &rhs.number,
697            // Record::constant can't be used here as that would cause an infinite loop,
698            // so use the swapped version of Div
699            (None, Some(_)) => rhs.div_swapped(self.number.clone()),
700            (Some(history), Some(_)) => Record {
701                number: Division::<T>::function(self.number.clone(), rhs.number.clone()),
702                history: Some(history),
703                index: history.append_binary(
704                    self.index,
705                    Division::<T>::d_function_dx(self.number.clone(), rhs.number.clone()),
706                    rhs.index,
707                    Division::<T>::d_function_dy(self.number.clone(), rhs.number.clone()),
708                ),
709            },
710        }
711    }
712}
713
714record_operator_impl_value_value!(impl Div for Record { fn div });
715record_operator_impl_reference_value!(impl Div for Record { fn div });
716record_operator_impl_value_reference!(impl Div for Record { fn div });
717
718/**
719 * Division for a record and a constant of the same type with both referenced.
720 */
721impl<'a, T: Numeric + Primitive> Div<&T> for &Record<'a, T>
722where
723    for<'t> &'t T: NumericRef<T>,
724{
725    type Output = Record<'a, T>;
726    #[track_caller]
727    #[inline]
728    fn div(self, rhs: &T) -> Self::Output {
729        match self.history {
730            None => Record {
731                number: Division::<T>::function(self.number.clone(), rhs.clone()),
732                history: None,
733                index: 0,
734            },
735            Some(history) => Record {
736                number: Division::<T>::function(self.number.clone(), rhs.clone()),
737                history: Some(history),
738                index: history.append_unary(
739                    self.index,
740                    Division::<T>::d_function_dx(self.number.clone(), rhs.clone()),
741                ),
742            },
743        }
744    }
745}
746
747record_number_operator_impl_value_value!(impl Div for Record { fn div });
748record_number_operator_impl_reference_value!(impl Div for Record { fn div });
749record_number_operator_impl_value_reference!(impl Div for Record { fn div });
750
751/**
752 * Negation of a record by reference.
753 */
754impl<'a, T: Numeric + Primitive> Neg for &Record<'a, T>
755where
756    for<'t> &'t T: NumericRef<T>,
757{
758    type Output = Record<'a, T>;
759    #[inline]
760    fn neg(self) -> Self::Output {
761        match self.history {
762            None => Record {
763                number: -self.number.clone(),
764                history: None,
765                index: 0,
766            },
767            Some(_) => Record::constant(T::zero()) - self,
768        }
769    }
770}
771
772/**
773 * Negation of a record by value.
774 */
775impl<'a, T: Numeric + Primitive> Neg for Record<'a, T>
776where
777    for<'t> &'t T: NumericRef<T>,
778{
779    type Output = Record<'a, T>;
780    #[inline]
781    fn neg(self) -> Self::Output {
782        match self.history {
783            None => Record {
784                number: -self.number,
785                history: None,
786                index: 0,
787            },
788            Some(_) => Record::constant(T::zero()) - self,
789        }
790    }
791}
792
793/**
794 * Any record of a PartialEq type implements PartialEq
795 *
796 * Note that as a Record is intended to be substitutable with its
797 * type T only the number parts of the record are compared.
798 */
799impl<'a, T: PartialEq + Primitive> PartialEq for Record<'a, T> {
800    #[inline]
801    fn eq(&self, other: &Self) -> bool {
802        self.number == other.number
803    }
804}
805
806/**
807 * Any record of a PartialOrd type implements PartialOrd
808 *
809 * Note that as a Record is intended to be substitutable with its
810 * type T only the number parts of the record are compared.
811 */
812impl<'a, T: PartialOrd + Primitive> PartialOrd for Record<'a, T> {
813    #[inline]
814    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
815        self.number.partial_cmp(&other.number)
816    }
817}
818
819/**
820 * Any record of a Numeric type implements Sum, which is
821 * the same as adding a bunch of Record types together.
822 */
823impl<'a, T: Numeric + Primitive> Sum for Record<'a, T> {
824    #[track_caller]
825    fn sum<I>(mut iter: I) -> Record<'a, T>
826    where
827        I: Iterator<Item = Record<'a, T>>,
828    {
829        let mut total = Record::<'a, T>::zero();
830        loop {
831            match iter.next() {
832                None => return total,
833                Some(next) => {
834                    total = match (total.history, next.history) {
835                        (None, None) => Record {
836                            number: total.number.clone() + next.number.clone(),
837                            history: None,
838                            index: 0,
839                        },
840                        // If only one input has a WengertList treat the other as a constant
841                        (Some(history), None) => {
842                            Record {
843                                number: total.number.clone() + next.number.clone(),
844                                history: Some(history),
845                                index: history.append_unary(
846                                    total.index,
847                                    // δ(total + next) / δtotal = 1
848                                    T::one(),
849                                ),
850                            }
851                        }
852                        (None, Some(history)) => {
853                            Record {
854                                number: total.number.clone() + next.number.clone(),
855                                history: Some(history),
856                                index: history.append_unary(
857                                    next.index,
858                                    // δ(next + total) / δnext = 1
859                                    T::one(),
860                                ),
861                            }
862                        }
863                        (Some(history), Some(_)) => {
864                            assert!(
865                                same_list(&total, &next),
866                                "Records must be using the same WengertList"
867                            );
868                            Record {
869                                number: total.number.clone() + next.number.clone(),
870                                history: Some(history),
871                                index: history.append_binary(
872                                    total.index,
873                                    // δ(total + next) / δtotal = 1
874                                    T::one(),
875                                    next.index,
876                                    // δ(total + next) / δnext = 1
877                                    T::one(),
878                                ),
879                            }
880                        }
881                    }
882                }
883            }
884        }
885    }
886}
887
888/**
889 * Sine of a Record by reference.
890 */
891impl<'a, T: Real + Primitive> Sin for &Record<'a, T>
892where
893    for<'t> &'t T: RealRef<T>,
894{
895    type Output = Record<'a, T>;
896    #[inline]
897    fn sin(self) -> Self::Output {
898        match self.history {
899            None => Record {
900                number: Sine::<T>::function(self.number.clone()),
901                history: None,
902                index: 0,
903            },
904            Some(history) => Record {
905                number: Sine::<T>::function(self.number.clone()),
906                history: Some(history),
907                index: history
908                    .append_unary(self.index, Sine::<T>::d_function_dx(self.number.clone())),
909            },
910        }
911    }
912}
913
914macro_rules! record_real_operator_impl_value {
915    (impl $op:tt for Record { fn $method:ident }) => {
916        /**
917         * Operation for a record by value.
918         */
919        impl<'a, T: Real + Primitive> $op for Record<'a, T>
920        where
921            for<'t> &'t T: RealRef<T>,
922        {
923            type Output = Record<'a, T>;
924            #[inline]
925            fn $method(self) -> Self::Output {
926                (&self).$method()
927            }
928        }
929    };
930}
931
932record_real_operator_impl_value!(impl Sin for Record { fn sin });
933
934/**
935 * Cosine of a Record by reference.
936 */
937impl<'a, T: Real + Primitive> Cos for &Record<'a, T>
938where
939    for<'t> &'t T: RealRef<T>,
940{
941    type Output = Record<'a, T>;
942    #[inline]
943    fn cos(self) -> Self::Output {
944        match self.history {
945            None => Record {
946                number: Cosine::<T>::function(self.number.clone()),
947                history: None,
948                index: 0,
949            },
950            Some(history) => Record {
951                number: Cosine::<T>::function(self.number.clone()),
952                history: Some(history),
953                index: history
954                    .append_unary(self.index, Cosine::<T>::d_function_dx(self.number.clone())),
955            },
956        }
957    }
958}
959
960record_real_operator_impl_value!(impl Cos for Record { fn cos });
961
962/**
963 * Exponential, ie e<sup>x</sup> of a Record by reference.
964 */
965impl<'a, T: Real + Primitive> Exp for &Record<'a, T>
966where
967    for<'t> &'t T: RealRef<T>,
968{
969    type Output = Record<'a, T>;
970    #[inline]
971    fn exp(self) -> Self::Output {
972        match self.history {
973            None => Record {
974                number: Exponential::<T>::function(self.number.clone()),
975                history: None,
976                index: 0,
977            },
978            Some(history) => Record {
979                number: Exponential::<T>::function(self.number.clone()),
980                history: Some(history),
981                index: history.append_unary(
982                    self.index,
983                    Exponential::<T>::d_function_dx(self.number.clone()),
984                ),
985            },
986        }
987    }
988}
989
990record_real_operator_impl_value!(impl Exp for Record { fn exp });
991
992/**
993 * Natural logarithm, ie ln(x) of a Record by reference.
994 */
995impl<'a, T: Real + Primitive> Ln for &Record<'a, T>
996where
997    for<'t> &'t T: RealRef<T>,
998{
999    type Output = Record<'a, T>;
1000    #[inline]
1001    fn ln(self) -> Self::Output {
1002        match self.history {
1003            None => Record {
1004                number: NaturalLogarithm::<T>::function(self.number.clone()),
1005                history: None,
1006                index: 0,
1007            },
1008            Some(history) => Record {
1009                number: NaturalLogarithm::<T>::function(self.number.clone()),
1010                history: Some(history),
1011                index: history.append_unary(
1012                    self.index,
1013                    NaturalLogarithm::<T>::d_function_dx(self.number.clone()),
1014                ),
1015            },
1016        }
1017    }
1018}
1019
1020record_real_operator_impl_value!(impl Ln for Record { fn ln });
1021
1022/**
1023 * Square root of a Record by reference.
1024 */
1025impl<'a, T: Real + Primitive> Sqrt for &Record<'a, T>
1026where
1027    for<'t> &'t T: RealRef<T>,
1028{
1029    type Output = Record<'a, T>;
1030    #[inline]
1031    fn sqrt(self) -> Self::Output {
1032        match self.history {
1033            None => Record {
1034                number: SquareRoot::<T>::function(self.number.clone()),
1035                history: None,
1036                index: 0,
1037            },
1038            Some(history) => Record {
1039                number: SquareRoot::<T>::function(self.number.clone()),
1040                history: Some(history),
1041                index: history.append_unary(
1042                    self.index,
1043                    SquareRoot::<T>::d_function_dx(self.number.clone()),
1044                ),
1045            },
1046        }
1047    }
1048}
1049
1050record_real_operator_impl_value!(impl Sqrt for Record { fn sqrt });
1051
1052/**
1053 * Power of one Record to another, ie self^rhs for two records of
1054 * the same type with both referenced and both using the same WengertList.
1055 */
1056impl<'a, 'l, 'r, T: Real + Primitive> Pow<&'r Record<'a, T>> for &'l Record<'a, T>
1057where
1058    for<'t> &'t T: RealRef<T>,
1059{
1060    type Output = Record<'a, T>;
1061    #[inline]
1062    #[track_caller]
1063    fn pow(self, rhs: &Record<'a, T>) -> Self::Output {
1064        assert!(
1065            same_list(self, rhs),
1066            "Records must be using the same WengertList"
1067        );
1068        match (self.history, rhs.history) {
1069            (None, None) => Record {
1070                number: Power::<T>::function(self.number.clone(), rhs.number.clone()),
1071                history: None,
1072                index: 0,
1073            },
1074            // If only one input has a WengertList treat the other as a constant
1075            (Some(_), None) => self.pow(&rhs.number),
1076            (None, Some(_)) => (&self.number).pow(rhs),
1077            (Some(history), Some(_)) => Record {
1078                number: Power::<T>::function(self.number.clone(), rhs.number.clone()),
1079                history: Some(history),
1080                index: history.append_binary(
1081                    self.index,
1082                    Power::<T>::d_function_dx(self.number.clone(), rhs.number.clone()),
1083                    rhs.index,
1084                    Power::<T>::d_function_dy(self.number.clone(), rhs.number.clone()),
1085                ),
1086            },
1087        }
1088    }
1089}
1090
macro_rules! record_real_operator_impl_value_value {
    (impl $trait_name:tt for Record { fn $fn_name:ident }) => {
        /**
         * Binary operation for two Records of the same type, both taken by value,
         * implemented by borrowing both and calling the fully referenced implementation.
         */
        impl<'a, T: Real + Primitive> $trait_name for Record<'a, T>
        where
            for<'t> &'t T: RealRef<T>,
        {
            type Output = Record<'a, T>;
            #[track_caller]
            #[inline]
            fn $fn_name(self, rhs: Record<'a, T>) -> Self::Output {
                let (left, right) = (&self, &rhs);
                left.$fn_name(right)
            }
        }
    };
}
1109
macro_rules! record_real_operator_impl_value_reference {
    (impl $trait_name:tt for Record { fn $fn_name:ident }) => {
        /**
         * Binary operation for two Records of the same type where only the right hand side
         * is referenced, implemented by borrowing the left and calling the fully referenced
         * implementation.
         */
        impl<'a, T: Real + Primitive> $trait_name<&Record<'a, T>> for Record<'a, T>
        where
            for<'t> &'t T: RealRef<T>,
        {
            type Output = Record<'a, T>;
            #[track_caller]
            #[inline]
            fn $fn_name(self, rhs: &Record<'a, T>) -> Self::Output {
                let left = &self;
                left.$fn_name(rhs)
            }
        }
    };
}
1128
1129macro_rules! record_real_operator_impl_reference_value {
1130    (impl $op:tt for Record { fn $method:ident }) => {
1131        /**
1132         * Operation for two records of the same type with the left referenced.
1133         */
1134        impl<'a, T: Real + Primitive> $op<Record<'a, T>> for &Record<'a, T>
1135        where
1136            for<'t> &'t T: RealRef<T>,
1137        {
1138            type Output = Record<'a, T>;
1139            #[track_caller]
1140            #[inline]
1141            fn $method(self, rhs: Record<'a, T>) -> Self::Output {
1142                self.$method(&rhs)
1143            }
1144        }
1145    };
1146}
1147
1148record_real_operator_impl_value_value!(impl Pow for Record { fn pow });
1149record_real_operator_impl_reference_value!(impl Pow for Record { fn pow });
1150record_real_operator_impl_value_reference!(impl Pow for Record { fn pow });
1151
1152/**
1153 * Power of one Record to a constant of the same type with both referenced.
1154 */
1155impl<'a, T: Real + Primitive> Pow<&T> for &Record<'a, T>
1156where
1157    for<'t> &'t T: RealRef<T>,
1158{
1159    type Output = Record<'a, T>;
1160    #[inline]
1161    fn pow(self, rhs: &T) -> Self::Output {
1162        match self.history {
1163            None => Record {
1164                number: Power::<T>::function(self.number.clone(), rhs.clone()),
1165                history: None,
1166                index: 0,
1167            },
1168            Some(history) => Record {
1169                number: Power::<T>::function(self.number.clone(), rhs.clone()),
1170                history: Some(history),
1171                index: history.append_unary(
1172                    self.index,
1173                    Power::<T>::d_function_dx(self.number.clone(), rhs.clone()),
1174                ),
1175            },
1176        }
1177    }
1178}
1179
1180/**
1181 * Power of a constant to a Record of the same type with both referenced.
1182 */
1183impl<'a, T: Real + Primitive> Pow<&Record<'a, T>> for &T
1184where
1185    for<'t> &'t T: RealRef<T>,
1186{
1187    type Output = Record<'a, T>;
1188    #[inline]
1189    fn pow(self, rhs: &Record<'a, T>) -> Self::Output {
1190        match rhs.history {
1191            None => Record {
1192                number: Power::<T>::function(self.clone(), rhs.number.clone()),
1193                history: None,
1194                index: 0,
1195            },
1196            Some(history) => {
1197                Record {
1198                    number: Power::<T>::function(self.clone(), rhs.number.clone()),
1199                    history: Some(history),
1200                    index: history.append_unary(
1201                        rhs.index,
1202                        // We want with respect to y because it is the right hand side here that we
1203                        // need the derivative for (since left is a constant).
1204                        Power::<T>::d_function_dy(self.clone(), rhs.number.clone()),
1205                    ),
1206                }
1207            }
1208        }
1209    }
1210}
1211
macro_rules! record_real_number_operator_impl_value_value {
    (impl $trait_name:tt for Record { fn $fn_name:ident }) => {
        /**
         * Binary operation for a Record and a constant of the same type, both taken by
         * value, implemented by borrowing both and calling the fully referenced
         * implementation.
         */
        impl<'a, T: Real + Primitive> $trait_name<T> for Record<'a, T>
        where
            for<'t> &'t T: RealRef<T>,
        {
            type Output = Record<'a, T>;
            #[inline]
            fn $fn_name(self, rhs: T) -> Self::Output {
                let (left, right) = (&self, &rhs);
                left.$fn_name(right)
            }
        }
    };
}
1229
macro_rules! record_real_number_operator_impl_value_reference {
    (impl $trait_name:tt for Record { fn $fn_name:ident }) => {
        /**
         * Binary operation for a Record and a constant of the same type where only the
         * constant is referenced, implemented by borrowing the Record and calling the fully
         * referenced implementation.
         */
        impl<'a, T: Real + Primitive> $trait_name<&T> for Record<'a, T>
        where
            for<'t> &'t T: RealRef<T>,
        {
            type Output = Record<'a, T>;
            #[inline]
            fn $fn_name(self, rhs: &T) -> Self::Output {
                let left = &self;
                left.$fn_name(rhs)
            }
        }
    };
}
1247
1248macro_rules! record_real_number_operator_impl_reference_value {
1249    (impl $op:tt for Record { fn $method:ident }) => {
1250        /**
1251         * Operation for a record and a constant of the same type with the left referenced.
1252         */
1253        impl<'a, T: Real + Primitive> $op<T> for &Record<'a, T>
1254        where
1255            for<'t> &'t T: RealRef<T>,
1256        {
1257            type Output = Record<'a, T>;
1258            #[inline]
1259            fn $method(self, rhs: T) -> Self::Output {
1260                self.$method(&rhs)
1261            }
1262        }
1263    };
1264}
1265
1266record_real_number_operator_impl_value_value!(impl Pow for Record { fn pow });
1267record_real_number_operator_impl_reference_value!(impl Pow for Record { fn pow });
1268record_real_number_operator_impl_value_reference!(impl Pow for Record { fn pow });
1269
macro_rules! real_number_record_operator_impl_value_value {
    (impl $trait_name:tt for Record { fn $fn_name:ident }) => {
        /**
         * Binary operation for a constant and a Record of the same type, both taken by
         * value, implemented by borrowing both and calling the fully referenced
         * implementation.
         */
        impl<'a, T: Real + Primitive> $trait_name<Record<'a, T>> for T
        where
            for<'t> &'t T: RealRef<T>,
        {
            type Output = Record<'a, T>;
            #[inline]
            fn $fn_name(self, rhs: Record<'a, T>) -> Self::Output {
                let (left, right) = (&self, &rhs);
                left.$fn_name(right)
            }
        }
    };
}
1287
macro_rules! real_number_record_operator_impl_value_reference {
    (impl $trait_name:tt for Record { fn $fn_name:ident }) => {
        /**
         * Binary operation for a constant and a Record of the same type where only the
         * Record is referenced, implemented by borrowing the constant and calling the fully
         * referenced implementation.
         */
        impl<'a, T: Real + Primitive> $trait_name<&Record<'a, T>> for T
        where
            for<'t> &'t T: RealRef<T>,
        {
            type Output = Record<'a, T>;
            #[inline]
            fn $fn_name(self, rhs: &Record<'a, T>) -> Self::Output {
                let left = &self;
                left.$fn_name(rhs)
            }
        }
    };
}
1305
1306macro_rules! real_number_record_operator_impl_reference_value {
1307    (impl $op:tt for Record { fn $method:ident }) => {
1308        /**
1309         * Operation for a constant and a record of the same type with the left referenced.
1310         */
1311        impl<'a, T: Real + Primitive> $op<Record<'a, T>> for &T
1312        where
1313            for<'t> &'t T: RealRef<T>,
1314        {
1315            type Output = Record<'a, T>;
1316            #[inline]
1317            fn $method(self, rhs: Record<'a, T>) -> Self::Output {
1318                self.$method(&rhs)
1319            }
1320        }
1321    };
1322}
1323
1324real_number_record_operator_impl_value_value!(impl Pow for Record { fn pow });
1325real_number_record_operator_impl_reference_value!(impl Pow for Record { fn pow });
1326real_number_record_operator_impl_value_reference!(impl Pow for Record { fn pow });
1327
1328impl<'a, T: Real + Primitive> Pi for Record<'a, T> {
1329    #[inline]
1330    fn pi() -> Record<'a, T> {
1331        Record::constant(T::pi())
1332    }
1333}