multicalc/numerical_derivative/
hessian.rs

1use crate::numerical_derivative::derivator::DerivatorMultiVariable;
2
3use num_complex::ComplexFloat;
4
5///computes the hessian matrix for a given function
6/// Can handle single and multivariable equations of any complexity or size
/// Computes the hessian matrix (the matrix of second-order partial derivatives) of a function.
/// Can handle single and multivariable functions of any complexity or size.
///
/// The actual numerical differentiation is delegated to the wrapped derivator `D`.
pub struct Hessian<D>
{
    derivator: D
}
11
12impl<D: DerivatorMultiVariable> Default for Hessian<D>
13{
14    ///the default constructor, optimal for most generic cases
15    fn default() -> Self 
16    {
17        return Hessian { derivator: D::default() };    
18    }
19}
20
21impl<D: DerivatorMultiVariable> Hessian<D>
22{
23    ///custom constructor, optimal for fine tuning
24    /// You can create a custom multivariable derivator from this crate
25    /// or supply your own by implementing the base traits yourself 
26    pub fn from_derivator(derivator: D) -> Self
27    {
28        return Hessian {derivator}
29    }
30
31    /// Returns the hessian matrix for a given function
32    /// Can handle multivariable functions of any order or complexity
33    /// 
34    /// The 2-D matrix returned has the structure [[d2f/d2var1, d2f/dvar1*dvar2, ... , d2f/dvar1*dvarN], 
35    ///                                            [                   ...                            ], 
36    ///                                            [d2f/dvar1*dvarN, d2f/dvar2*dvarN, ... , dfM/d2varN]]
37    /// where 'N' is the total number of variables
38    /// 
39    /// NOTE: Returns a Result<T, &'static str>
40    /// Possible &'static str are:
41    /// NUMBER_OF_DERIVATIVE_STEPS_CANNOT_BE_ZERO -> if the derivative step size is zero
42    /// 
43    /// assume our function is y*sin(x) + 2*x*e^y. First define the function
44    /// ```
45    /// use multicalc::numerical_derivative::finite_difference::MultiVariableSolver;
46    /// use multicalc::numerical_derivative::hessian::Hessian;
47    ///    let my_func = | args: &[f64; 2] | -> f64 
48    ///    { 
49    ///        return args[1]*args[0].sin() + 2.0*args[0]*args[1].exp();
50    ///    };
51    /// 
52    /// let points = [1.0, 2.0]; //the point around which we want the hessian matrix
53    /// let hessian = Hessian::<MultiVariableSolver>::default();
54    /// 
55    /// let result = hessian.get(&my_func, &points).unwrap();
56    /// ```
57    /// 
58    pub fn get<T: ComplexFloat, const NUM_VARS: usize>(&self, function: &dyn Fn(&[T; NUM_VARS]) -> T, vector_of_points: &[T; NUM_VARS]) -> Result<[[T; NUM_VARS]; NUM_VARS], &'static str>
59    {
60        let mut result = [[T::from(f64::NAN).unwrap(); NUM_VARS]; NUM_VARS];
61
62        for row_index in 0..NUM_VARS
63        {
64            for col_index in 0..NUM_VARS
65            {
66                if result[row_index][col_index].is_nan()
67                {
68                    result[row_index][col_index] = self.derivator.get_double_partial(function, &[row_index, col_index], vector_of_points)?;
69
70                    result[col_index][row_index] = result[row_index][col_index]; //exploit the fact that a hessian is a symmetric matrix
71                }
72            }
73        }
74
75        return Ok(result);
76    }
77
78}