1use ad_trait::AD;
2use ad_trait::function_engine::FunctionEngine;
3use ad_trait::differentiable_function::{DifferentiableFunctionTrait, FiniteDifferencing, ForwardAD, ForwardADMulti, ReverseAD};
4use ad_trait::forward_ad::adfn::adfn;
5use ad_trait::reverse_ad::adr::adr;
6
/// Example differentiable function `f(x0, x1) = coeff * sin(x0) + cos(x1)`.
///
/// `T` is the scalar type the function is evaluated with: `f64` for plain
/// evaluation, or an AD number type when a derivative is being computed.
/// The `AD` bound is placed on the `impl` blocks that need it rather than on the
/// struct definition, so the type itself stays maximally reusable.
#[derive(Clone, Debug)]
pub struct Test<T> {
    /// Scalar multiplier applied to `sin(x0)`.
    coeff: T,
}
11impl<T: AD> DifferentiableFunctionTrait<T> for Test<T> {
12 const NAME: &'static str = "Test";
13
14 fn call(&self, inputs: &[T], _freeze: bool) -> Vec<T> {
15 vec![ self.coeff*inputs[0].sin() + inputs[1].cos() ]
16 }
17
18 fn num_inputs(&self) -> usize {
19 2
20 }
21
22 fn num_outputs(&self) -> usize {
23 1
24 }
25}
26impl<T: AD> Test<T> {
27 pub fn to_other_ad_type<T2: AD>(&self) -> Test<T2> {
28 Test { coeff: self.coeff.to_other_ad_type::<T2>() }
29 }
30}
31
32
33fn main() {
34 let inputs = vec![1., 2.];
35
36 let function_standard = Test { coeff: 2.0 };
38 let function_derivative = function_standard.to_other_ad_type::<adr>();
39 let differentiable_block = FunctionEngine::new(function_standard, function_derivative, ReverseAD::new());
40
41 let (f_res, derivative_res) = differentiable_block.derivative(&inputs);
42 println!("Reverse AD: ");
43 println!(" f_res: {}", f_res[0]);
44 println!(" derivative: {}", derivative_res);
45 println!("//////////////");
46 println!();
47
48 let function_standard = Test { coeff: 2.0 };
50 let function_derivative = function_standard.to_other_ad_type::<adfn<1>>();
51 let differentiable_block = FunctionEngine::new(function_standard, function_derivative, ForwardAD::new());
52
53 let (f_res, derivative_res) = differentiable_block.derivative(&inputs);
54 println!("Forward AD: ");
55 println!(" f_res: {}", f_res[0]);
56 println!(" derivative: {}", derivative_res);
57 println!("//////////////");
58 println!();
59
60 let function_standard = Test { coeff: 2.0 };
62 let function_derivative = function_standard.to_other_ad_type::<adfn<2>>();
63 let differentiable_block = FunctionEngine::new(function_standard, function_derivative, ForwardADMulti::new());
64
65 let (f_res, derivative_res) = differentiable_block.derivative(&inputs);
66 println!("Forward AD Multi: ");
67 println!(" f_res: {}", f_res[0]);
68 println!(" derivative: {}", derivative_res);
69 println!("//////////////");
70 println!();
71
72 let function_standard = Test { coeff: 2.0 };
74 let function_derivative = function_standard.clone();
75 let differentiable_block = FunctionEngine::new(function_standard, function_derivative, FiniteDifferencing::new());
76
77 let (f_res, derivative_res) = differentiable_block.derivative(&inputs);
78 println!("Finite Differencing: ");
79 println!(" f_res: {}", f_res[0]);
80 println!(" derivative: {}", derivative_res);
81 println!("//////////////");
82 println!();
83
84}