macro_rules! get_gradient {
($f:ident, $func_name:ident) => { ... };
}
Expand description
Get a function that returns the gradient of the provided multivariate, scalar-valued function.
The gradient is computed using forward-mode automatic differentiation.
§Arguments
- `f` - Multivariate, scalar-valued function, $f:\mathbb{R}^{n}\to\mathbb{R}$.
- `func_name` - Name of the function that will return the gradient of $f(\mathbf{x})$ at any point $\mathbf{x}\in\mathbb{R}^{n}$.
§Warning
`f` cannot be defined as a closure. It must be defined as a function.
§Note
The function produced by this macro will perform $n$ evaluations of $f(\mathbf{x})$ to evaluate its gradient.
§Examples
§Basic Example
Compute the gradient of
$$f(\mathbf{x})=x_{0}^{5}+\sin^{3}{x_{1}}$$
at $\mathbf{x}=(5,8)^{T}$, and compare the result to the true result of
$$ \nabla f\left((5,8)^{T}\right)= \begin{bmatrix} 3125 \\ 3\sin^{2}{(8)}\cos{(8)} \end{bmatrix} $$
§Using standard vectors
use linalg_traits::{Scalar, Vector};
use numtest::*;
use numdiff::{get_gradient, Dual, DualVector};
// Define the function, f(x) = x0^5 + sin^3(x1). It is generic over the scalar type S and the
// vector type V; the parameter slice is not used in this example.
fn f<S: Scalar, V: Vector<S>>(x: &V, _p: &[f64]) -> S {
    let x0_term = x.vget(0).powi(5);
    let x1_term = x.vget(1).sin().powi(3);
    x0_term + x1_term
}
// Point at which the gradient is evaluated.
let x0 = vec![5.0, 8.0];
// Parameter vector (left empty, since f(x) does not use its parameters in this example).
let p = [];
// Generate the function "g", which evaluates the gradient of f(x) at any point x.
get_gradient!(f, g);
// Analytical gradient of f(x), used as the reference solution.
let g_true = |x: &Vec<f64>| {
    let df_dx0 = 5.0 * x[0].powi(4);
    let df_dx1 = 3.0 * x[1].sin().powi(2) * x[1].cos();
    vec![df_dx0, df_dx1]
};
// Evaluate the gradient at x0 using the generated function "g".
let g_eval: Vec<f64> = g(&x0, &p);
// Verify that the gradient function obtained using get_gradient! computes the gradient
// correctly.
assert_arrays_equal_to_decimal!(g(&x0, &p), g_true(&x0), 15);
§Using other vector types
The function produced by get_gradient! can accept any type for x0, as long as it
implements the linalg_traits::Vector trait.
use faer::Mat;
use linalg_traits::{Scalar, Vector};
use nalgebra::{dvector, DVector, SVector};
use ndarray::{array, Array1};
use numdiff::{get_gradient, Dual, DualVector};
// Define the function, f(x) = x0^5 + sin^3(x1). It is generic over the scalar type S and the
// vector type V; the parameter slice is not used in this example.
fn f<S: Scalar, V: Vector<S>>(x: &V, _p: &[f64]) -> S {
    let x0_term = x.vget(0).powi(5);
    let x1_term = x.vget(1).sin().powi(3);
    x0_term + x1_term
}
// Parameter vector (empty for this example).
let p = [];
// Autogenerate the function "g" that can be used to compute the gradient of f(x) at any point
// x. The same "g" is called with every vector type below.
get_gradient!(f, g);
// nalgebra::DVector
// (x0 and g_eval are re-bound for each vector type; in each case the gradient is annotated with
// the same type as the evaluation point.)
let x0: DVector<f64> = dvector![5.0, 8.0];
let g_eval: DVector<f64> = g(&x0, &p);
// nalgebra::SVector
let x0: SVector<f64, 2> = SVector::from_slice(&[5.0, 8.0]);
let g_eval: SVector<f64, 2> = g(&x0, &p);
// ndarray::Array1
let x0: Array1<f64> = array![5.0, 8.0];
let g_eval: Array1<f64> = g(&x0, &p);
// faer::Mat
let x0: Mat<f64> = Mat::from_slice(&[5.0, 8.0]);
let g_eval: Mat<f64> = g(&x0, &p);
§Example Passing Runtime Parameters
Compute the gradient of a parameterized function
$$f(\mathbf{x})=ax_{0}^{2}+bx_{1}^{2}+cx_{0}x_{1}+d$$
where $a$, $b$, $c$, and $d$ are runtime parameters. Compare the result against the true gradient of
$$\nabla f=\begin{bmatrix}2ax_{0}+cx_{1}\\2bx_{1}+cx_{0}\end{bmatrix}$$
use linalg_traits::{Scalar, Vector};
use numtest::*;
use numdiff::{get_gradient, Dual, DualVector};
// Define the function, f(x) = a*x0^2 + b*x1^2 + c*x0*x1 + d, where the coefficients are read
// from the runtime parameter slice as p = [a, b, c, d].
fn f<S: Scalar, V: Vector<S>>(x: &V, p: &[f64]) -> S {
    // Convert each runtime parameter into the scalar type S.
    let a = S::new(p[0]);
    let b = S::new(p[1]);
    let c = S::new(p[2]);
    let d = S::new(p[3]);
    // Build the terms one at a time, then sum them left to right.
    let x0_squared_term = a * x.vget(0).powi(2);
    let x1_squared_term = b * x.vget(1).powi(2);
    let cross_term = c * x.vget(0) * x.vget(1);
    x0_squared_term + x1_squared_term + cross_term + d
}
// Runtime parameters a, b, c, and d, packed into the parameter vector p.
let (a, b, c, d) = (2.0, 1.5, 0.8, -3.0);
let p = [a, b, c, d];
// Point at which the gradient is evaluated.
let x0 = vec![1.0, -2.0];
// Generate the gradient function "g" from f(x).
get_gradient!(f, g);
// Analytical gradient of f(x), used as the reference solution.
let g_true = |x: &Vec<f64>| {
    let df_dx0 = 2.0 * a * x[0] + c * x[1];
    let df_dx1 = 2.0 * b * x[1] + c * x[0];
    vec![df_dx0, df_dx1]
};
// Evaluate the gradient with both the generated function and the analytical reference, then
// check that the two agree.
let g_eval: Vec<f64> = g(&x0, &p);
let g_eval_true: Vec<f64> = g_true(&x0);
assert_arrays_equal_to_decimal!(g_eval, g_eval_true, 15);