// File: example1/example1.rs

use ad_trait::AD;
use ad_trait::function_engine::FunctionEngine;
use ad_trait::differentiable_function::{
    DifferentiableFunctionTrait, FiniteDifferencing, ForwardAD, ForwardADMulti, ReverseAD,
};
use ad_trait::forward_ad::adfn::adfn;
use ad_trait::reverse_ad::adr::adr;

7#[derive(Clone)]
8pub struct Test<T: AD> {
9    coeff: T
10}
11impl<T: AD> DifferentiableFunctionTrait<T> for Test<T> {
12    const NAME: &'static str = "Test";
13
14    fn call(&self, inputs: &[T], _freeze: bool) -> Vec<T> {
15        vec![ self.coeff*inputs[0].sin() + inputs[1].cos() ]
16    }
17
18    fn num_inputs(&self) -> usize {
19        2
20    }
21
22    fn num_outputs(&self) -> usize {
23        1
24    }
25}
26impl<T: AD> Test<T> {
27    pub fn to_other_ad_type<T2: AD>(&self) -> Test<T2> {
28        Test { coeff: self.coeff.to_other_ad_type::<T2>() }
29    }
30}
31
32
33fn main() {
34    let inputs = vec![1., 2.];
35
36    // Reverse AD //////////////////////////////////////////////////////////////////////////////////
37    let function_standard = Test { coeff: 2.0 };
38    let function_derivative = function_standard.to_other_ad_type::<adr>();
39    let differentiable_block = FunctionEngine::new(function_standard, function_derivative, ReverseAD::new());
40
41    let (f_res, derivative_res) = differentiable_block.derivative(&inputs);
42    println!("Reverse AD: ");
43    println!("  f_res: {}", f_res[0]);
44    println!("  derivative: {}", derivative_res);
45    println!("//////////////");
46    println!();
47
48    // Forward AD //////////////////////////////////////////////////////////////////////////////////
49    let function_standard = Test { coeff: 2.0 };
50    let function_derivative = function_standard.to_other_ad_type::<adfn<1>>();
51    let differentiable_block = FunctionEngine::new(function_standard, function_derivative, ForwardAD::new());
52
53    let (f_res, derivative_res) = differentiable_block.derivative(&inputs);
54    println!("Forward AD: ");
55    println!("  f_res: {}", f_res[0]);
56    println!("  derivative: {}", derivative_res);
57    println!("//////////////");
58    println!();
59
60    // Forward AD Multi ////////////////////////////////////////////////////////////////////////////
61    let function_standard = Test { coeff: 2.0 };
62    let function_derivative = function_standard.to_other_ad_type::<adfn<2>>();
63    let differentiable_block = FunctionEngine::new(function_standard, function_derivative, ForwardADMulti::new());
64
65    let (f_res, derivative_res) = differentiable_block.derivative(&inputs);
66    println!("Forward AD Multi: ");
67    println!("  f_res: {}", f_res[0]);
68    println!("  derivative: {}", derivative_res);
69    println!("//////////////");
70    println!();
71
72    // Finite Differencing /////////////////////////////////////////////////////////////////////////
73    let function_standard = Test { coeff: 2.0 };
74    let function_derivative = function_standard.clone();
75    let differentiable_block = FunctionEngine::new(function_standard, function_derivative, FiniteDifferencing::new());
76
77    let (f_res, derivative_res) = differentiable_block.derivative(&inputs);
78    println!("Finite Differencing: ");
79    println!("  f_res: {}", f_res[0]);
80    println!("  derivative: {}", derivative_res);
81    println!("//////////////");
82    println!();
83
84}