autodiff/
autodiffable.rs

1use crate::gradienttype::GradientType;
2
3// re-export Diffable<StaticArgs>
4pub use crate::diffable::Diffable as Diffable;
5
6pub trait AutoDiffable<StaticArgs>: Diffable<StaticArgs>
7where
8    <Self as Diffable<StaticArgs>>::Input: GradientType<<Self as Diffable<StaticArgs>>::Output>,
9{
10    /// Evaluate the function and its gradient for a given input and static arguments.
11    /// Returns `(f(x, static_args): <Self as Diffable<StaticArgs>>::Output, df/dx(x, static_args): <<Self as Diffable<StaticArgs>>::Input as GradientType<<Self as Diffable<StaticArgs>>::Output>>::GradientType)`
12    fn eval_grad(
13        &self,
14        x: &<Self as Diffable<StaticArgs>>::Input,
15        static_args: &StaticArgs,
16    ) -> (<Self as Diffable<StaticArgs>>::Output, <<Self as Diffable<StaticArgs>>::Input as GradientType<<Self as Diffable<StaticArgs>>::Output>>::GradientType);
17
18    /// Evaluate the function for a given input and static arguments.
19    /// Returns `f(x, static_args): <Self as Diffable<StaticArgs>>::Output`
20    fn eval(&self, x: &<Self as Diffable<StaticArgs>>::Input, static_args: &StaticArgs) -> <Self as Diffable<StaticArgs>>::Output
21    {
22        self.eval_grad(x, static_args).0
23    }
24
25    /// Evaluate the gradient for a given input and static arguments.
26    fn grad(&self, x: &<Self as Diffable<StaticArgs>>::Input, static_args: &StaticArgs) -> <<Self as Diffable<StaticArgs>>::Input as GradientType<<Self as Diffable<StaticArgs>>::Output>>::GradientType {
27        self.eval_grad(x, static_args).1
28    }
29}
30
31pub trait ForwardDiffable<StaticArgs>: Diffable<StaticArgs> {
32    /// Evaluate the function and its gradient in forward mode for a given input `x`, derivative `dx`, and static arguments
33    /// Returns `(f(x, static_args): <Self as Diffable<StaticArgs>>::Output, df(x, dx, static_args): <Self as Diffable<StaticArgs>>::Output)`
34    /// By default, `df = df/dx * dx`. However, this can be overridden in cases where this equality
35    /// does not hold (e.g. complex valued functions), or where a more efficient implementation is possible (e.g. functions whose arguments and return types are arrays)
36    /// NOTE: The multiplication here is not the same as normal multiplication. Instead in reality
37    /// `df = (df/dx).forward_mul(dx)`. For many types, this is equivalent to normal multiplication (all primitives which implement `Mul`). However, for arrays this is tensor contraction over the last few axes, such that the number of dimensions of `df` match that of `f`.
38    /// Similarly, this cannot be implemented for complex numbers, which will require a custom
39    /// eval_forward_grad implementation.
40    fn eval_forward_grad(
41        &self,
42        x: &<Self as Diffable<StaticArgs>>::Input,
43        dx: &<Self as Diffable<StaticArgs>>::Input,
44        static_args: &StaticArgs,
45    ) -> (<Self as Diffable<StaticArgs>>::Output, <Self as Diffable<StaticArgs>>::Output);
46
47    /// Evaluate the function for a given input `x` and static arguments
48    fn eval_forward(&self, x: &<Self as Diffable<StaticArgs>>::Input, static_args: &StaticArgs) -> <Self as Diffable<StaticArgs>>::Output {
49        self.eval_forward_grad(x, x, static_args).0
50    }
51
52    /// Evaluate the gradient in forward mode for a given input `x`, derivative `dx`, and static arguments
53    fn forward_grad(
54        &self,
55        x: &<Self as Diffable<StaticArgs>>::Input,
56        dx: &<Self as Diffable<StaticArgs>>::Input,
57        static_args: &StaticArgs,
58    ) -> <Self as Diffable<StaticArgs>>::Output {
59        self.eval_forward_grad(x, dx, static_args).1
60    }
61}