1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
//! Automatic differentiation at compile time.
//!
//! Dependencies:
//! ```text
//!        *Grad
//!       /     \
//! grad::Own grad::Ref
//!       \     /
//!     grad::Typed
//!          |
//!        *Eval
//! ```

use crate::eval::Eval;

/// Output type of differentiation.
///
/// Associates an expression type with the type produced by differentiating it.
/// Every implementor must also implement [`Eval`].
#[const_trait]
pub trait Typed: ~const Eval {
    /// Type returned by `grad` (i.e. `grad(&x) -> Self::Differentiated`).
    ///
    /// The result must itself implement [`Eval`].
    type Differentiated: ~const Eval;
}

/// Implementation taking `self` (moved).
#[const_trait]
pub trait Own: ~const Typed {
    /// Differentiate this expression, consuming it, and return the
    /// [`Typed::Differentiated`] result.
    ///
    /// `x` is presumably the variable to differentiate with respect to; the
    /// trait places no bounds on `U`, so confirm against implementors.
    fn grad<U>(self, x: &U) -> Self::Differentiated;
}

/// Implementation taking `&self` (not moved).
#[const_trait]
pub trait Ref: ~const Typed {
    /// Differentiate this expression without consuming it, and return the
    /// [`Typed::Differentiated`] result.
    ///
    /// `x` is presumably the variable to differentiate with respect to; the
    /// trait places no bounds on `U`, so confirm against implementors.
    fn grad<U>(&self, x: &U) -> Self::Differentiated;
}

/// Automatically differentiate an expression, optionally at compile time (if evaluated into a `const`).
///
/// This trait declares no items of its own: it only requires both [`Own`]
/// (by-value `grad`) and [`Ref`] (by-reference `grad`).
#[const_trait]
pub trait Grad: ~const Own + ~const Ref {}

/// Automagically implement `Grad`, together with `grad::Typed`, `grad::Own` and `grad::Ref`.
///
/// Three invocation forms are accepted:
///
/// ```text
/// implement_grad!(Name >-> Output: |self, x| body);
/// implement_grad!(T: Bound, ... => Name >-> Output: |self, x| body);
/// implement_grad!(T: Bound, ... => Name >-> Output: |self, x| where own { by_value } else { by_ref });
/// ```
///
/// * `Name` is the type receiving the impls; `Output` becomes its `Differentiated` type.
/// * Each `T: Bound` pair becomes an impl generic bounded by `~const Bound`.
/// * The two identifiers between the bars name the receiver and the `&U` argument of `grad`;
///   the caller supplies them so the body expression can refer to them.
/// * With a single `body`, the same expression is used for both the by-value (`Own`) and the
///   by-reference (`Ref`) `grad`. The `where own { .. } else { .. }` form supplies one
///   expression for each.
#[macro_export]
macro_rules! implement_grad {
    // Non-generic type; one body shared by the `Own` and `Ref` impls.
    ($target:ty >-> $derivative:ty: |$this:ident, $arg:ident| $shared:expr) => {
        impl const $crate::grad::Typed for $target {
            type Differentiated = $derivative;
        }
        impl const $crate::grad::Own for $target {
            #[inline(always)]
            fn grad<U>($this, $arg: &U) -> $derivative {
                $shared
            }
        }
        impl const $crate::grad::Ref for $target {
            #[inline(always)]
            fn grad<U>(&$this, $arg: &U) -> $derivative {
                $shared
            }
        }
        impl const $crate::grad::Grad for $target {}
    };
    // Generic type; one body shared by the `Own` and `Ref` impls.
    // Forwards to the split form below, using the same expression on both sides.
    ($($param:ident: $bound:path),+ => $target:ty >-> $derivative:ty: |$this:ident, $arg:ident| $shared:expr) => {
        $crate::implement_grad!(
            $($param: $bound),+ => $target >-> $derivative: |$this, $arg| where own { $shared } else { $shared }
        );
    };
    // It would be so nice to have `where` as a parallel to C++'s `if constexpr`...
    // Generic type; separate bodies for the by-value and by-reference impls.
    ($($param:ident: $bound:path),+ => $target:ty >-> $derivative:ty: |$this:ident, $arg:ident| where own { $by_value:expr } else { $by_ref:expr }) => {
        impl<$($param: ~const $bound),+> const $crate::grad::Typed for $target {
            type Differentiated = $derivative;
        }
        impl<$($param: ~const $bound),+> const $crate::grad::Own for $target {
            #[inline(always)]
            fn grad<U>($this, $arg: &U) -> $derivative {
                $by_value
            }
        }
        impl<$($param: ~const $bound),+> const $crate::grad::Ref for $target {
            #[inline(always)]
            fn grad<U>(&$this, $arg: &U) -> $derivative {
                $by_ref
            }
        }
        impl<$($param: ~const $bound),+> const $crate::grad::Grad for $target {}
    };
}