autodiff/autodiffable.rs
1use crate::gradienttype::GradientType;
2
3// re-export Diffable<StaticArgs>
4pub use crate::diffable::Diffable as Diffable;
5
6pub trait AutoDiffable<StaticArgs>: Diffable<StaticArgs>
7where
8 <Self as Diffable<StaticArgs>>::Input: GradientType<<Self as Diffable<StaticArgs>>::Output>,
9{
10 /// Evaluate the function and its gradient for a given input and static arguments.
11 /// Returns `(f(x, static_args): <Self as Diffable<StaticArgs>>::Output, df/dx(x, static_args): <<Self as Diffable<StaticArgs>>::Input as GradientType<<Self as Diffable<StaticArgs>>::Output>>::GradientType)`
12 fn eval_grad(
13 &self,
14 x: &<Self as Diffable<StaticArgs>>::Input,
15 static_args: &StaticArgs,
16 ) -> (<Self as Diffable<StaticArgs>>::Output, <<Self as Diffable<StaticArgs>>::Input as GradientType<<Self as Diffable<StaticArgs>>::Output>>::GradientType);
17
18 /// Evaluate the function for a given input and static arguments.
19 /// Returns `f(x, static_args): <Self as Diffable<StaticArgs>>::Output`
20 fn eval(&self, x: &<Self as Diffable<StaticArgs>>::Input, static_args: &StaticArgs) -> <Self as Diffable<StaticArgs>>::Output
21 {
22 self.eval_grad(x, static_args).0
23 }
24
25 /// Evaluate the gradient for a given input and static arguments.
26 fn grad(&self, x: &<Self as Diffable<StaticArgs>>::Input, static_args: &StaticArgs) -> <<Self as Diffable<StaticArgs>>::Input as GradientType<<Self as Diffable<StaticArgs>>::Output>>::GradientType {
27 self.eval_grad(x, static_args).1
28 }
29}
30
31pub trait ForwardDiffable<StaticArgs>: Diffable<StaticArgs> {
32 /// Evaluate the function and its gradient in forward mode for a given input `x`, derivative `dx`, and static arguments
33 /// Returns `(f(x, static_args): <Self as Diffable<StaticArgs>>::Output, df(x, dx, static_args): <Self as Diffable<StaticArgs>>::Output)`
34 /// By default, `df = df/dx * dx`. However, this can be overridden in cases where this equality
35 /// does not hold (e.g. complex valued functions), or where a more efficient implementation is possible (e.g. functions whose arguments and return types are arrays)
36 /// NOTE: The multiplication here is not the same as normal multiplication. Instead in reality
37 /// `df = (df/dx).forward_mul(dx)`. For many types, this is equivalent to normal multiplication (all primitives which implement `Mul`). However, for arrays this is tensor contraction over the last few axes, such that the number of dimensions of `df` match that of `f`.
38 /// Similarly, this cannot be implemented for complex numbers, which will require a custom
39 /// eval_forward_grad implementation.
40 fn eval_forward_grad(
41 &self,
42 x: &<Self as Diffable<StaticArgs>>::Input,
43 dx: &<Self as Diffable<StaticArgs>>::Input,
44 static_args: &StaticArgs,
45 ) -> (<Self as Diffable<StaticArgs>>::Output, <Self as Diffable<StaticArgs>>::Output);
46
47 /// Evaluate the function for a given input `x` and static arguments
48 fn eval_forward(&self, x: &<Self as Diffable<StaticArgs>>::Input, static_args: &StaticArgs) -> <Self as Diffable<StaticArgs>>::Output {
49 self.eval_forward_grad(x, x, static_args).0
50 }
51
52 /// Evaluate the gradient in forward mode for a given input `x`, derivative `dx`, and static arguments
53 fn forward_grad(
54 &self,
55 x: &<Self as Diffable<StaticArgs>>::Input,
56 dx: &<Self as Diffable<StaticArgs>>::Input,
57 static_args: &StaticArgs,
58 ) -> <Self as Diffable<StaticArgs>>::Output {
59 self.eval_forward_grad(x, dx, static_args).1
60 }
61}