#[cfg(feature = "argmin")]
use ndarray::{Array1, array};
/// Selects the subproblem method used by [`LocalSolverConfig::TrustRegion`].
///
/// NOTE(review): presumably mapped to argmin's Cauchy-point / Steihaug
/// subproblem solvers where the config is consumed — confirm at the call site.
///
/// `Eq` is derived in addition to `PartialEq` because the enum is fieldless,
/// so equality is total.
#[cfg(feature = "argmin")]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "checkpointing", derive(serde::Serialize, serde::Deserialize))]
pub enum TrustRegionRadiusMethod {
    /// Cauchy-point method; this is what `TrustRegionBuilder::default()` selects.
    Cauchy,
    /// Steihaug method.
    Steihaug,
}
/// Parameters for one of the supported local optimization methods.
///
/// The `LBFGS`, `NelderMead`, `SteepestDescent`, `TrustRegion` and `NewtonCG`
/// variants are only compiled when the `argmin` feature is enabled; `COBYLA`
/// is always available. Each variant has a matching builder, reachable via the
/// associated constructors on this type (e.g. `LocalSolverConfig::cobyla()`).
#[cfg_attr(feature = "checkpointing", derive(serde::Serialize, serde::Deserialize))]
pub enum LocalSolverConfig {
    /// L-BFGS settings, normally produced by `LBFGSBuilder`.
    #[cfg(feature = "argmin")]
    LBFGS {
        /// Iteration limit (builder default: 300).
        max_iter: u64,
        /// Gradient tolerance (builder default: `sqrt(f64::EPSILON)`).
        tolerance_grad: f64,
        /// Cost tolerance (builder default: `f64::EPSILON`).
        tolerance_cost: f64,
        /// History length (builder default: 10).
        history_size: usize,
        /// Optional L1 coefficient (builder default: `None`).
        /// NOTE(review): presumably enables L1-regularised (OWL-QN style) updates — confirm.
        l1_coefficient: Option<f64>,
        /// Line search used by the solver (builder default: More-Thuente).
        line_search_params: LineSearchParams,
    },
    /// Nelder-Mead settings, normally produced by `NelderMeadBuilder`.
    #[cfg(feature = "argmin")]
    NelderMead {
        /// Simplex step (builder default: 0.1).
        simplex_delta: f64,
        /// Standard-deviation tolerance (builder default: `f64::EPSILON`).
        sd_tolerance: f64,
        /// Iteration limit (builder default: 300).
        max_iter: u64,
        /// Builder default: 1.0.
        alpha: f64,
        /// Builder default: 2.0.
        gamma: f64,
        /// Builder default: 0.5.
        rho: f64,
        /// Builder default: 0.5.
        sigma: f64,
    },
    /// Steepest-descent settings, normally produced by `SteepestDescentBuilder`.
    #[cfg(feature = "argmin")]
    SteepestDescent {
        /// Iteration limit (builder default: 300).
        max_iter: u64,
        /// Line search used by the solver (builder default: More-Thuente).
        line_search_params: LineSearchParams,
    },
    /// Trust-region settings, normally produced by `TrustRegionBuilder`.
    #[cfg(feature = "argmin")]
    TrustRegion {
        /// Subproblem method (builder default: `Cauchy`).
        trust_region_radius_method: TrustRegionRadiusMethod,
        /// Iteration limit (builder default: 300).
        max_iter: u64,
        /// Initial radius (builder default: 1.0).
        radius: f64,
        /// Radius cap (builder default: 100.0).
        max_radius: f64,
        /// Builder default: 0.125.
        eta: f64,
    },
    /// Newton-CG settings, normally produced by `NewtonCGBuilder`.
    #[cfg(feature = "argmin")]
    NewtonCG {
        /// Iteration limit (builder default: 300).
        max_iter: u64,
        /// Curvature threshold (builder default: 0.0).
        curvature_threshold: f64,
        /// Tolerance (builder default: `f64::EPSILON`).
        tolerance: f64,
        /// Line search used by the solver (builder default: More-Thuente).
        line_search_params: LineSearchParams,
    },
    /// COBYLA settings, normally produced by `COBYLABuilder`; available without `argmin`.
    COBYLA {
        /// Iteration limit (builder default: 300).
        max_iter: u64,
        /// Initial step size (builder default: 0.5).
        initial_step_size: f64,
        /// Relative function tolerance (1e-6 unless set on the builder).
        ftol_rel: f64,
        /// Absolute function tolerance (1e-8 unless set on the builder).
        ftol_abs: f64,
        /// Relative x tolerance (0.0 unless set on the builder).
        xtol_rel: f64,
        /// Absolute x tolerances; empty unless set on the builder.
        /// NOTE(review): presumably one entry per coordinate — length is not checked here.
        xtol_abs: Vec<f64>,
    },
}
/// Formats the configuration with every parameter of the active variant.
///
/// Each variant is rendered through `debug_struct` under the solver's name and
/// lists all of its fields. Previously the argmin-backed variants printed only
/// `Name { .. }`, which hid the values actually in use from logs and
/// diagnostics. Every field type here implements `Debug`, so nothing needs to
/// be elided.
impl std::fmt::Debug for LocalSolverConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            #[cfg(feature = "argmin")]
            LocalSolverConfig::LBFGS {
                max_iter,
                tolerance_grad,
                tolerance_cost,
                history_size,
                l1_coefficient,
                line_search_params,
            } => f
                .debug_struct("LBFGS")
                .field("max_iter", max_iter)
                .field("tolerance_grad", tolerance_grad)
                .field("tolerance_cost", tolerance_cost)
                .field("history_size", history_size)
                .field("l1_coefficient", l1_coefficient)
                .field("line_search_params", line_search_params)
                .finish(),
            #[cfg(feature = "argmin")]
            LocalSolverConfig::NelderMead {
                simplex_delta,
                sd_tolerance,
                max_iter,
                alpha,
                gamma,
                rho,
                sigma,
            } => f
                .debug_struct("NelderMead")
                .field("simplex_delta", simplex_delta)
                .field("sd_tolerance", sd_tolerance)
                .field("max_iter", max_iter)
                .field("alpha", alpha)
                .field("gamma", gamma)
                .field("rho", rho)
                .field("sigma", sigma)
                .finish(),
            #[cfg(feature = "argmin")]
            LocalSolverConfig::SteepestDescent { max_iter, line_search_params } => f
                .debug_struct("SteepestDescent")
                .field("max_iter", max_iter)
                .field("line_search_params", line_search_params)
                .finish(),
            #[cfg(feature = "argmin")]
            LocalSolverConfig::TrustRegion {
                trust_region_radius_method,
                max_iter,
                radius,
                max_radius,
                eta,
            } => f
                .debug_struct("TrustRegion")
                .field("trust_region_radius_method", trust_region_radius_method)
                .field("max_iter", max_iter)
                .field("radius", radius)
                .field("max_radius", max_radius)
                .field("eta", eta)
                .finish(),
            #[cfg(feature = "argmin")]
            LocalSolverConfig::NewtonCG {
                max_iter,
                curvature_threshold,
                tolerance,
                line_search_params,
            } => f
                .debug_struct("NewtonCG")
                .field("max_iter", max_iter)
                .field("curvature_threshold", curvature_threshold)
                .field("tolerance", tolerance)
                .field("line_search_params", line_search_params)
                .finish(),
            LocalSolverConfig::COBYLA {
                max_iter,
                initial_step_size,
                ftol_rel,
                ftol_abs,
                xtol_rel,
                xtol_abs,
            } => f
                .debug_struct("COBYLA")
                .field("max_iter", max_iter)
                .field("initial_step_size", initial_step_size)
                .field("ftol_rel", ftol_rel)
                .field("ftol_abs", ftol_abs)
                .field("xtol_rel", xtol_rel)
                .field("xtol_abs", xtol_abs)
                .finish(),
        }
    }
}
impl Clone for LocalSolverConfig {
fn clone(&self) -> Self {
match self {
#[cfg(feature = "argmin")]
LocalSolverConfig::LBFGS {
max_iter,
tolerance_grad,
tolerance_cost,
history_size,
l1_coefficient,
line_search_params,
} => LocalSolverConfig::LBFGS {
max_iter: *max_iter,
tolerance_grad: *tolerance_grad,
tolerance_cost: *tolerance_cost,
history_size: *history_size,
l1_coefficient: *l1_coefficient,
line_search_params: line_search_params.clone(),
},
#[cfg(feature = "argmin")]
LocalSolverConfig::NelderMead {
simplex_delta,
sd_tolerance,
max_iter,
alpha,
gamma,
rho,
sigma,
} => LocalSolverConfig::NelderMead {
simplex_delta: *simplex_delta,
sd_tolerance: *sd_tolerance,
max_iter: *max_iter,
alpha: *alpha,
gamma: *gamma,
rho: *rho,
sigma: *sigma,
},
#[cfg(feature = "argmin")]
LocalSolverConfig::SteepestDescent { max_iter, line_search_params } => {
LocalSolverConfig::SteepestDescent {
max_iter: *max_iter,
line_search_params: line_search_params.clone(),
}
}
#[cfg(feature = "argmin")]
LocalSolverConfig::TrustRegion {
trust_region_radius_method,
max_iter,
radius,
max_radius,
eta,
} => LocalSolverConfig::TrustRegion {
trust_region_radius_method: trust_region_radius_method.clone(),
max_iter: *max_iter,
radius: *radius,
max_radius: *max_radius,
eta: *eta,
},
#[cfg(feature = "argmin")]
LocalSolverConfig::NewtonCG {
max_iter,
curvature_threshold,
tolerance,
line_search_params,
} => LocalSolverConfig::NewtonCG {
max_iter: *max_iter,
curvature_threshold: *curvature_threshold,
tolerance: *tolerance,
line_search_params: line_search_params.clone(),
},
LocalSolverConfig::COBYLA {
max_iter,
initial_step_size,
ftol_rel,
ftol_abs,
xtol_rel,
xtol_abs,
} => LocalSolverConfig::COBYLA {
max_iter: *max_iter,
initial_step_size: *initial_step_size,
ftol_rel: *ftol_rel,
ftol_abs: *ftol_abs,
xtol_rel: *xtol_rel,
xtol_abs: xtol_abs.clone(),
},
}
}
}
impl LocalSolverConfig {
#[cfg(feature = "argmin")]
pub fn lbfgs() -> LBFGSBuilder {
LBFGSBuilder::default()
}
#[cfg(feature = "argmin")]
pub fn neldermead() -> NelderMeadBuilder {
NelderMeadBuilder::default()
}
#[cfg(feature = "argmin")]
pub fn steepestdescent() -> SteepestDescentBuilder {
SteepestDescentBuilder::default()
}
#[cfg(feature = "argmin")]
pub fn trustregion() -> TrustRegionBuilder {
TrustRegionBuilder::default()
}
#[cfg(feature = "argmin")]
pub fn newton_cg() -> NewtonCGBuilder {
NewtonCGBuilder::default()
}
pub fn cobyla() -> COBYLABuilder {
COBYLABuilder::default()
}
}
/// Builder for [`LocalSolverConfig::LBFGS`].
///
/// `Default` gives `max_iter = 300`, `tolerance_grad = sqrt(f64::EPSILON)`,
/// `tolerance_cost = f64::EPSILON`, `history_size = 10`,
/// `l1_coefficient = None` and `LineSearchParams::default()` (More-Thuente).
#[cfg(feature = "argmin")]
#[derive(Debug, Clone)]
pub struct LBFGSBuilder {
    max_iter: u64,
    tolerance_grad: f64,
    tolerance_cost: f64,
    history_size: usize,
    // `None` leaves the L1 coefficient unset in the built config.
    l1_coefficient: Option<f64>,
    line_search_params: LineSearchParams,
}
/// Construction, finalisation and fluent setters for [`LBFGSBuilder`].
///
/// Every setter consumes the builder and returns an updated copy, so calls chain.
#[cfg(feature = "argmin")]
impl LBFGSBuilder {
    /// Creates a builder with every parameter supplied explicitly.
    pub fn new(
        max_iter: u64,
        tolerance_grad: f64,
        tolerance_cost: f64,
        history_size: usize,
        l1_coefficient: Option<f64>,
        line_search_params: LineSearchParams,
    ) -> Self {
        Self {
            max_iter,
            tolerance_grad,
            tolerance_cost,
            history_size,
            l1_coefficient,
            line_search_params,
        }
    }

    /// Consumes the builder, yielding a [`LocalSolverConfig::LBFGS`] with the same values.
    pub fn build(self) -> LocalSolverConfig {
        let Self {
            max_iter,
            tolerance_grad,
            tolerance_cost,
            history_size,
            l1_coefficient,
            line_search_params,
        } = self;
        LocalSolverConfig::LBFGS {
            max_iter,
            tolerance_grad,
            tolerance_cost,
            history_size,
            l1_coefficient,
            line_search_params,
        }
    }

    /// Replaces `max_iter`.
    pub fn max_iter(self, max_iter: u64) -> Self {
        Self { max_iter, ..self }
    }

    /// Replaces `tolerance_grad`.
    pub fn tolerance_grad(self, tolerance_grad: f64) -> Self {
        Self { tolerance_grad, ..self }
    }

    /// Replaces `tolerance_cost`.
    pub fn tolerance_cost(self, tolerance_cost: f64) -> Self {
        Self { tolerance_cost, ..self }
    }

    /// Replaces `history_size`.
    pub fn history_size(self, history_size: usize) -> Self {
        Self { history_size, ..self }
    }

    /// Replaces the line-search configuration.
    pub fn line_search_params(self, line_search_params: LineSearchParams) -> Self {
        Self { line_search_params, ..self }
    }

    /// Replaces `l1_coefficient`; pass `None` to clear it.
    pub fn l1_coefficient(self, l1_coefficient: Option<f64>) -> Self {
        Self { l1_coefficient, ..self }
    }
}
/// Default L-BFGS parameters, routed through [`LBFGSBuilder::new`]:
/// 300 iterations, gradient tolerance `sqrt(EPSILON)`, cost tolerance `EPSILON`,
/// history 10, no L1 coefficient and the default line search.
#[cfg(feature = "argmin")]
impl Default for LBFGSBuilder {
    fn default() -> Self {
        Self::new(
            300,
            f64::EPSILON.sqrt(),
            f64::EPSILON,
            10,
            None,
            LineSearchParams::default(),
        )
    }
}
/// Builder for [`LocalSolverConfig::NelderMead`].
///
/// `Default` gives `simplex_delta = 0.1`, `sd_tolerance = f64::EPSILON`,
/// `max_iter = 300`, `alpha = 1.0`, `gamma = 2.0`, `rho = 0.5`, `sigma = 0.5`.
#[cfg(feature = "argmin")]
#[derive(Debug, Clone)]
pub struct NelderMeadBuilder {
    simplex_delta: f64,
    sd_tolerance: f64,
    max_iter: u64,
    // NOTE(review): alpha/gamma/rho/sigma are presumably the reflection,
    // expansion, contraction and shrink coefficients — confirm where consumed.
    alpha: f64,
    gamma: f64,
    rho: f64,
    sigma: f64,
}
/// Construction, finalisation and fluent setters for [`NelderMeadBuilder`].
#[cfg(feature = "argmin")]
impl NelderMeadBuilder {
    /// Creates a builder with every parameter supplied explicitly.
    pub fn new(
        simplex_delta: f64,
        sd_tolerance: f64,
        max_iter: u64,
        alpha: f64,
        gamma: f64,
        rho: f64,
        sigma: f64,
    ) -> Self {
        Self {
            simplex_delta,
            sd_tolerance,
            max_iter,
            alpha,
            gamma,
            rho,
            sigma,
        }
    }

    /// Consumes the builder, yielding a [`LocalSolverConfig::NelderMead`].
    pub fn build(self) -> LocalSolverConfig {
        let Self { simplex_delta, sd_tolerance, max_iter, alpha, gamma, rho, sigma } = self;
        LocalSolverConfig::NelderMead { simplex_delta, sd_tolerance, max_iter, alpha, gamma, rho, sigma }
    }

    /// Replaces `simplex_delta`.
    pub fn simplex_delta(self, simplex_delta: f64) -> Self {
        Self { simplex_delta, ..self }
    }

    /// Replaces `sd_tolerance`.
    pub fn sd_tolerance(self, sd_tolerance: f64) -> Self {
        Self { sd_tolerance, ..self }
    }

    /// Replaces `max_iter`.
    pub fn max_iter(self, max_iter: u64) -> Self {
        Self { max_iter, ..self }
    }

    /// Replaces `alpha`.
    pub fn alpha(self, alpha: f64) -> Self {
        Self { alpha, ..self }
    }

    /// Replaces `gamma`.
    pub fn gamma(self, gamma: f64) -> Self {
        Self { gamma, ..self }
    }

    /// Replaces `rho`.
    pub fn rho(self, rho: f64) -> Self {
        Self { rho, ..self }
    }

    /// Replaces `sigma`.
    pub fn sigma(self, sigma: f64) -> Self {
        Self { sigma, ..self }
    }
}
/// Default Nelder-Mead parameters, routed through [`NelderMeadBuilder::new`]
/// (argument order: simplex_delta, sd_tolerance, max_iter, alpha, gamma, rho, sigma).
#[cfg(feature = "argmin")]
impl Default for NelderMeadBuilder {
    fn default() -> Self {
        Self::new(0.1, f64::EPSILON, 300, 1.0, 2.0, 0.5, 0.5)
    }
}
/// Builder for [`LocalSolverConfig::SteepestDescent`].
///
/// `Default` gives `max_iter = 300` and `LineSearchParams::default()` (More-Thuente).
#[cfg(feature = "argmin")]
#[derive(Debug, Clone)]
pub struct SteepestDescentBuilder {
    max_iter: u64,
    line_search_params: LineSearchParams,
}
/// Construction, finalisation and fluent setters for [`SteepestDescentBuilder`].
#[cfg(feature = "argmin")]
impl SteepestDescentBuilder {
    /// Creates a builder from an iteration limit and a line-search configuration.
    pub fn new(max_iter: u64, line_search_params: LineSearchParams) -> Self {
        Self {
            max_iter,
            line_search_params,
        }
    }

    /// Consumes the builder, yielding a [`LocalSolverConfig::SteepestDescent`].
    pub fn build(self) -> LocalSolverConfig {
        let Self { max_iter, line_search_params } = self;
        LocalSolverConfig::SteepestDescent { max_iter, line_search_params }
    }

    /// Replaces `max_iter`.
    pub fn max_iter(self, max_iter: u64) -> Self {
        Self { max_iter, ..self }
    }

    /// Replaces the line-search configuration.
    pub fn line_search_params(self, line_search_params: LineSearchParams) -> Self {
        Self { line_search_params, ..self }
    }
}
/// Default steepest-descent parameters: 300 iterations and the default line search.
#[cfg(feature = "argmin")]
impl Default for SteepestDescentBuilder {
    fn default() -> Self {
        Self::new(300, LineSearchParams::default())
    }
}
/// Builder for [`LocalSolverConfig::TrustRegion`].
///
/// `Default` gives the `Cauchy` method, `max_iter = 300`, `radius = 1.0`,
/// `max_radius = 100.0` and `eta = 0.125`.
#[cfg(feature = "argmin")]
#[derive(Debug, Clone)]
pub struct TrustRegionBuilder {
    // Set through the `method` setter (the setter name differs from the field name).
    trust_region_radius_method: TrustRegionRadiusMethod,
    max_iter: u64,
    radius: f64,
    max_radius: f64,
    eta: f64,
}
/// Construction, finalisation and fluent setters for [`TrustRegionBuilder`].
#[cfg(feature = "argmin")]
impl TrustRegionBuilder {
    /// Creates a builder with every parameter supplied explicitly.
    pub fn new(
        trust_region_radius_method: TrustRegionRadiusMethod,
        max_iter: u64,
        radius: f64,
        max_radius: f64,
        eta: f64,
    ) -> Self {
        Self {
            trust_region_radius_method,
            max_iter,
            radius,
            max_radius,
            eta,
        }
    }

    /// Consumes the builder, yielding a [`LocalSolverConfig::TrustRegion`].
    pub fn build(self) -> LocalSolverConfig {
        let Self { trust_region_radius_method, max_iter, radius, max_radius, eta } = self;
        LocalSolverConfig::TrustRegion { trust_region_radius_method, max_iter, radius, max_radius, eta }
    }

    /// Replaces the subproblem method (stored as `trust_region_radius_method`).
    pub fn method(self, method: TrustRegionRadiusMethod) -> Self {
        Self {
            trust_region_radius_method: method,
            ..self
        }
    }

    /// Replaces `max_iter`.
    pub fn max_iter(self, max_iter: u64) -> Self {
        Self { max_iter, ..self }
    }

    /// Replaces the initial `radius`.
    pub fn radius(self, radius: f64) -> Self {
        Self { radius, ..self }
    }

    /// Replaces `max_radius`.
    pub fn max_radius(self, max_radius: f64) -> Self {
        Self { max_radius, ..self }
    }

    /// Replaces `eta`.
    pub fn eta(self, eta: f64) -> Self {
        Self { eta, ..self }
    }
}
/// Default trust-region parameters, routed through [`TrustRegionBuilder::new`]:
/// Cauchy method, 300 iterations, radius 1.0, max radius 100.0, eta 0.125.
#[cfg(feature = "argmin")]
impl Default for TrustRegionBuilder {
    fn default() -> Self {
        Self::new(TrustRegionRadiusMethod::Cauchy, 300, 1.0, 100.0, 0.125)
    }
}
/// Builder for [`LocalSolverConfig::NewtonCG`].
///
/// `Default` gives `max_iter = 300`, `curvature_threshold = 0.0`,
/// `tolerance = f64::EPSILON` and `LineSearchParams::default()` (More-Thuente).
#[cfg(feature = "argmin")]
#[derive(Debug, Clone)]
pub struct NewtonCGBuilder {
    max_iter: u64,
    curvature_threshold: f64,
    tolerance: f64,
    line_search_params: LineSearchParams,
}
/// Construction, finalisation and fluent setters for [`NewtonCGBuilder`].
#[cfg(feature = "argmin")]
impl NewtonCGBuilder {
    /// Creates a builder with every parameter supplied explicitly.
    pub fn new(
        max_iter: u64,
        curvature_threshold: f64,
        tolerance: f64,
        line_search_params: LineSearchParams,
    ) -> Self {
        Self {
            max_iter,
            curvature_threshold,
            tolerance,
            line_search_params,
        }
    }

    /// Consumes the builder, yielding a [`LocalSolverConfig::NewtonCG`].
    pub fn build(self) -> LocalSolverConfig {
        let Self { max_iter, curvature_threshold, tolerance, line_search_params } = self;
        LocalSolverConfig::NewtonCG { max_iter, curvature_threshold, tolerance, line_search_params }
    }

    /// Replaces `max_iter`.
    pub fn max_iter(self, max_iter: u64) -> Self {
        Self { max_iter, ..self }
    }

    /// Replaces `curvature_threshold`.
    pub fn curvature_threshold(self, curvature_threshold: f64) -> Self {
        Self { curvature_threshold, ..self }
    }

    /// Replaces `tolerance`.
    pub fn tolerance(self, tolerance: f64) -> Self {
        Self { tolerance, ..self }
    }

    /// Replaces the line-search configuration.
    pub fn line_search_params(self, line_search_params: LineSearchParams) -> Self {
        Self { line_search_params, ..self }
    }
}
/// Default Newton-CG parameters, routed through [`NewtonCGBuilder::new`]:
/// 300 iterations, zero curvature threshold, `EPSILON` tolerance and the
/// default line search.
#[cfg(feature = "argmin")]
impl Default for NewtonCGBuilder {
    fn default() -> Self {
        Self::new(300, 0.0, f64::EPSILON, LineSearchParams::default())
    }
}
/// Builder for [`LocalSolverConfig::COBYLA`]; available without the `argmin` feature.
///
/// The tolerances are optional here. Any left as `None` are replaced with
/// fallback values when the builder is finalised.
///
/// Derives `Debug` and `Clone` like every other builder in this module, so a
/// partially configured builder can be logged or reused as a template.
#[derive(Debug, Clone)]
pub struct COBYLABuilder {
    max_iter: u64,
    initial_step_size: f64,
    ftol_rel: Option<f64>,
    ftol_abs: Option<f64>,
    xtol_rel: Option<f64>,
    xtol_abs: Option<Vec<f64>>,
}
/// Construction, finalisation and fluent setters for [`COBYLABuilder`].
impl COBYLABuilder {
    /// Creates a builder with the given iteration limit and initial step size.
    ///
    /// Every tolerance starts unset. [`COBYLABuilder::build`] then substitutes
    /// `ftol_rel = 1e-6`, `ftol_abs = 1e-8`, `xtol_rel = 0.0` and an empty
    /// `xtol_abs`.
    pub fn new(max_iter: u64, initial_step_size: f64) -> Self {
        Self {
            max_iter,
            initial_step_size,
            ftol_rel: None,
            ftol_abs: None,
            xtol_rel: None,
            xtol_abs: None,
        }
    }

    /// Consumes the builder, yielding a [`LocalSolverConfig::COBYLA`] and
    /// filling any tolerance that was never set with its fallback value.
    pub fn build(self) -> LocalSolverConfig {
        let Self { max_iter, initial_step_size, ftol_rel, ftol_abs, xtol_rel, xtol_abs } = self;
        LocalSolverConfig::COBYLA {
            max_iter,
            initial_step_size,
            ftol_rel: ftol_rel.unwrap_or(1e-6),
            ftol_abs: ftol_abs.unwrap_or(1e-8),
            xtol_rel: xtol_rel.unwrap_or(0.0),
            xtol_abs: xtol_abs.unwrap_or_default(),
        }
    }

    /// Replaces `max_iter`.
    pub fn max_iter(self, max_iter: u64) -> Self {
        Self { max_iter, ..self }
    }

    /// Replaces `initial_step_size`.
    pub fn initial_step_size(self, initial_step_size: f64) -> Self {
        Self { initial_step_size, ..self }
    }

    /// Sets the relative function tolerance.
    pub fn ftol_rel(self, ftol_rel: f64) -> Self {
        Self {
            ftol_rel: Some(ftol_rel),
            ..self
        }
    }

    /// Sets the absolute function tolerance.
    pub fn ftol_abs(self, ftol_abs: f64) -> Self {
        Self {
            ftol_abs: Some(ftol_abs),
            ..self
        }
    }

    /// Sets the relative x tolerance.
    pub fn xtol_rel(self, xtol_rel: f64) -> Self {
        Self {
            xtol_rel: Some(xtol_rel),
            ..self
        }
    }

    /// Sets the absolute x tolerances.
    pub fn xtol_abs(self, xtol_abs: Vec<f64>) -> Self {
        Self {
            xtol_abs: Some(xtol_abs),
            ..self
        }
    }
}
impl Default for COBYLABuilder {
fn default() -> Self {
COBYLABuilder {
max_iter: 300,
initial_step_size: 0.5,
ftol_rel: Some(1e-6),
ftol_abs: Some(1e-8),
xtol_rel: None,
xtol_abs: None,
}
}
}
/// Line-search algorithm together with its parameters; wrapped by [`LineSearchParams`].
#[cfg(feature = "argmin")]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "checkpointing", derive(serde::Serialize, serde::Deserialize))]
pub enum LineSearchMethod {
    /// More-Thuente line search. Builder defaults: `c1 = 1e-4`, `c2 = 0.9`,
    /// `width_tolerance = 1e-10`, `bounds = [sqrt(f64::EPSILON), f64::INFINITY]`.
    MoreThuente {
        c1: f64,
        c2: f64,
        width_tolerance: f64,
        // NOTE(review): defaults are two-element arrays, presumably [min step, max step];
        // the length is not validated here.
        bounds: Array1<f64>,
    },
    /// Hager-Zhang line search. Builder defaults: `delta = 0.1`, `sigma = 0.9`,
    /// `epsilon = 1e-6`, `theta = 0.5`, `gamma = 0.66`, `eta = 0.01`,
    /// `bounds = [f64::EPSILON, 1e5]`.
    HagerZhang {
        delta: f64,
        sigma: f64,
        epsilon: f64,
        theta: f64,
        gamma: f64,
        eta: f64,
        // NOTE(review): presumably [min step, max step]; length not validated here.
        bounds: Array1<f64>,
    },
}
/// Line-search configuration shared by the gradient-based local solvers.
///
/// `Default` selects More-Thuente with its builder defaults. Use
/// `LineSearchParams::morethuente()` / `LineSearchParams::hagerzhang()` to
/// customise the parameters.
#[cfg(feature = "argmin")]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "checkpointing", derive(serde::Serialize, serde::Deserialize))]
pub struct LineSearchParams {
    /// The selected algorithm and its parameters.
    pub method: LineSearchMethod,
}
/// Builder entry points for [`LineSearchParams`].
#[cfg(feature = "argmin")]
impl LineSearchParams {
    /// Starts a [`MoreThuenteBuilder`] from its default parameters.
    pub fn morethuente() -> MoreThuenteBuilder {
        Default::default()
    }

    /// Starts a [`HagerZhangBuilder`] from its default parameters.
    pub fn hagerzhang() -> HagerZhangBuilder {
        Default::default()
    }
}
/// The default line search is More-Thuente with `MoreThuenteBuilder`'s defaults
/// (`c1 = 1e-4`, `c2 = 0.9`, `width_tolerance = 1e-10`,
/// `bounds = [sqrt(EPSILON), INFINITY]`), so both defaults are kept in one place.
#[cfg(feature = "argmin")]
impl Default for LineSearchParams {
    fn default() -> Self {
        MoreThuenteBuilder::default().build()
    }
}
/// Builder for a More-Thuente [`LineSearchParams`].
///
/// `Default` gives `c1 = 1e-4`, `c2 = 0.9`, `width_tolerance = 1e-10` and
/// `bounds = [sqrt(f64::EPSILON), f64::INFINITY]`.
#[cfg(feature = "argmin")]
#[derive(Debug, Clone)]
pub struct MoreThuenteBuilder {
    c1: f64,
    c2: f64,
    width_tolerance: f64,
    bounds: Array1<f64>,
}
/// Construction, finalisation and fluent setters for [`MoreThuenteBuilder`].
#[cfg(feature = "argmin")]
impl MoreThuenteBuilder {
    /// Creates a builder with every parameter supplied explicitly.
    pub fn new(c1: f64, c2: f64, width_tolerance: f64, bounds: Array1<f64>) -> Self {
        Self {
            c1,
            c2,
            width_tolerance,
            bounds,
        }
    }

    /// Consumes the builder, yielding `LineSearchParams` with `LineSearchMethod::MoreThuente`.
    pub fn build(self) -> LineSearchParams {
        let Self { c1, c2, width_tolerance, bounds } = self;
        let method = LineSearchMethod::MoreThuente { c1, c2, width_tolerance, bounds };
        LineSearchParams { method }
    }

    /// Replaces `c1`.
    pub fn c1(self, c1: f64) -> Self {
        Self { c1, ..self }
    }

    /// Replaces `c2`.
    pub fn c2(self, c2: f64) -> Self {
        Self { c2, ..self }
    }

    /// Replaces `width_tolerance`.
    pub fn width_tolerance(self, width_tolerance: f64) -> Self {
        Self { width_tolerance, ..self }
    }

    /// Replaces `bounds`.
    pub fn bounds(self, bounds: Array1<f64>) -> Self {
        Self { bounds, ..self }
    }
}
/// Default More-Thuente parameters, routed through [`MoreThuenteBuilder::new`].
#[cfg(feature = "argmin")]
impl Default for MoreThuenteBuilder {
    fn default() -> Self {
        Self::new(1e-4, 0.9, 1e-10, array![f64::EPSILON.sqrt(), f64::INFINITY])
    }
}
/// Builder for a Hager-Zhang [`LineSearchParams`].
///
/// `Default` gives `delta = 0.1`, `sigma = 0.9`, `epsilon = 1e-6`,
/// `theta = 0.5`, `gamma = 0.66`, `eta = 0.01` and `bounds = [f64::EPSILON, 1e5]`.
#[cfg(feature = "argmin")]
#[derive(Debug, Clone)]
pub struct HagerZhangBuilder {
    delta: f64,
    sigma: f64,
    epsilon: f64,
    theta: f64,
    gamma: f64,
    eta: f64,
    bounds: Array1<f64>,
}
/// Construction, finalisation and fluent setters for [`HagerZhangBuilder`].
#[cfg(feature = "argmin")]
impl HagerZhangBuilder {
    /// Creates a builder with every parameter supplied explicitly.
    pub fn new(
        delta: f64,
        sigma: f64,
        epsilon: f64,
        theta: f64,
        gamma: f64,
        eta: f64,
        bounds: Array1<f64>,
    ) -> Self {
        Self {
            delta,
            sigma,
            epsilon,
            theta,
            gamma,
            eta,
            bounds,
        }
    }

    /// Consumes the builder, yielding `LineSearchParams` with `LineSearchMethod::HagerZhang`.
    pub fn build(self) -> LineSearchParams {
        let Self { delta, sigma, epsilon, theta, gamma, eta, bounds } = self;
        let method = LineSearchMethod::HagerZhang { delta, sigma, epsilon, theta, gamma, eta, bounds };
        LineSearchParams { method }
    }

    /// Replaces `delta`.
    pub fn delta(self, delta: f64) -> Self {
        Self { delta, ..self }
    }

    /// Replaces `sigma`.
    pub fn sigma(self, sigma: f64) -> Self {
        Self { sigma, ..self }
    }

    /// Replaces `epsilon`.
    pub fn epsilon(self, epsilon: f64) -> Self {
        Self { epsilon, ..self }
    }

    /// Replaces `theta`.
    pub fn theta(self, theta: f64) -> Self {
        Self { theta, ..self }
    }

    /// Replaces `gamma`.
    pub fn gamma(self, gamma: f64) -> Self {
        Self { gamma, ..self }
    }

    /// Replaces `eta`.
    pub fn eta(self, eta: f64) -> Self {
        Self { eta, ..self }
    }

    /// Replaces `bounds`.
    pub fn bounds(self, bounds: Array1<f64>) -> Self {
        Self { bounds, ..self }
    }
}
/// Default Hager-Zhang parameters, routed through [`HagerZhangBuilder::new`]
/// (argument order: delta, sigma, epsilon, theta, gamma, eta, bounds).
#[cfg(feature = "argmin")]
impl Default for HagerZhangBuilder {
    fn default() -> Self {
        Self::new(0.1, 0.9, 1e-6, 0.5, 0.66, 0.01, array![f64::EPSILON, 1e5])
    }
}
// Unit tests for the solver and line-search builders.
//
// The argmin-backed builders are only tested when the `argmin` feature is
// enabled. The COBYLA tests always run. Their catch-all arm is unreachable
// without `argmin` (COBYLA is then the only variant), hence the scoped `allow`.
#[cfg(test)]
mod tests_builders {
    use super::*;

    #[cfg(feature = "argmin")]
    mod argmin_tests {
        use super::*;

        #[test]
        fn test_default_lbfgs() {
            let lbfgs: LocalSolverConfig = LBFGSBuilder::default().build();
            match lbfgs {
                LocalSolverConfig::LBFGS {
                    max_iter,
                    tolerance_grad,
                    tolerance_cost,
                    history_size,
                    l1_coefficient,
                    line_search_params,
                } => {
                    assert_eq!(max_iter, 300);
                    assert_eq!(tolerance_grad, f64::EPSILON.sqrt());
                    assert_eq!(tolerance_cost, f64::EPSILON);
                    assert_eq!(history_size, 10);
                    assert_eq!(l1_coefficient, None);
                    match line_search_params.method {
                        LineSearchMethod::MoreThuente { c1, c2, width_tolerance, bounds } => {
                            assert_eq!(c1, 1e-4);
                            assert_eq!(c2, 0.9);
                            assert_eq!(width_tolerance, 1e-10);
                            assert_eq!(bounds, array![f64::EPSILON.sqrt(), f64::INFINITY]);
                        }
                        _ => panic!("Expected MoreThuente line search method"),
                    }
                }
                _ => panic!("Expected L-BFGS local solver"),
            }
        }

        #[test]
        fn test_default_neldermead() {
            let neldermead: LocalSolverConfig = NelderMeadBuilder::default().build();
            match neldermead {
                LocalSolverConfig::NelderMead {
                    simplex_delta,
                    sd_tolerance,
                    max_iter,
                    alpha,
                    gamma,
                    rho,
                    sigma,
                } => {
                    assert_eq!(simplex_delta, 0.1);
                    assert_eq!(sd_tolerance, f64::EPSILON);
                    assert_eq!(max_iter, 300);
                    assert_eq!(alpha, 1.0);
                    assert_eq!(gamma, 2.0);
                    assert_eq!(rho, 0.5);
                    assert_eq!(sigma, 0.5);
                }
                _ => panic!("Expected Nelder-Mead local solver"),
            }
        }

        #[test]
        fn test_default_steepestdescent() {
            let steepestdescent: LocalSolverConfig = SteepestDescentBuilder::default().build();
            match steepestdescent {
                LocalSolverConfig::SteepestDescent { max_iter, line_search_params } => {
                    assert_eq!(max_iter, 300);
                    match line_search_params.method {
                        LineSearchMethod::MoreThuente { c1, c2, width_tolerance, bounds } => {
                            assert_eq!(c1, 1e-4);
                            assert_eq!(c2, 0.9);
                            assert_eq!(width_tolerance, 1e-10);
                            assert_eq!(bounds, array![f64::EPSILON.sqrt(), f64::INFINITY]);
                        }
                        _ => panic!("Expected MoreThuente line search method"),
                    }
                }
                _ => panic!("Expected Steepest Descent local solver"),
            }
        }

        #[test]
        fn test_default_trustregion() {
            let trustregion: LocalSolverConfig = TrustRegionBuilder::default().build();
            match trustregion {
                LocalSolverConfig::TrustRegion {
                    trust_region_radius_method,
                    max_iter,
                    radius,
                    max_radius,
                    eta,
                } => {
                    assert_eq!(trust_region_radius_method, TrustRegionRadiusMethod::Cauchy);
                    assert_eq!(max_iter, 300);
                    assert_eq!(radius, 1.0);
                    assert_eq!(max_radius, 100.0);
                    assert_eq!(eta, 0.125);
                }
                _ => panic!("Expected Trust Region local solver"),
            }
        }

        #[test]
        fn test_default_newton_cg() {
            let newtoncg: LocalSolverConfig = NewtonCGBuilder::default().build();
            match newtoncg {
                LocalSolverConfig::NewtonCG {
                    max_iter,
                    curvature_threshold,
                    tolerance,
                    line_search_params,
                } => {
                    assert_eq!(max_iter, 300);
                    assert_eq!(curvature_threshold, 0.0);
                    assert_eq!(tolerance, f64::EPSILON);
                    match line_search_params.method {
                        LineSearchMethod::MoreThuente { c1, c2, width_tolerance, bounds } => {
                            assert_eq!(c1, 1e-4);
                            assert_eq!(c2, 0.9);
                            assert_eq!(width_tolerance, 1e-10);
                            assert_eq!(bounds, array![f64::EPSILON.sqrt(), f64::INFINITY]);
                        }
                        _ => panic!("Expected MoreThuente line search method"),
                    }
                }
                _ => panic!("Expected Newton-CG local solver"),
            }
        }

        #[test]
        fn test_default_morethuente() {
            let morethuente: LineSearchParams = MoreThuenteBuilder::default().build();
            match morethuente.method {
                LineSearchMethod::MoreThuente { c1, c2, width_tolerance, bounds } => {
                    assert_eq!(c1, 1e-4);
                    assert_eq!(c2, 0.9);
                    assert_eq!(width_tolerance, 1e-10);
                    assert_eq!(bounds, array![f64::EPSILON.sqrt(), f64::INFINITY]);
                }
                _ => panic!("Expected MoreThuente line search method"),
            }
        }

        #[test]
        fn test_default_hagerzhang() {
            let hagerzhang: LineSearchParams = HagerZhangBuilder::default().build();
            match hagerzhang.method {
                LineSearchMethod::HagerZhang {
                    delta,
                    sigma,
                    epsilon,
                    theta,
                    gamma,
                    eta,
                    bounds,
                } => {
                    assert_eq!(delta, 0.1);
                    assert_eq!(sigma, 0.9);
                    assert_eq!(epsilon, 1e-6);
                    assert_eq!(theta, 0.5);
                    assert_eq!(gamma, 0.66);
                    assert_eq!(eta, 0.01);
                    assert_eq!(bounds, array![f64::EPSILON, 1e5]);
                }
                _ => panic!("Expected HagerZhang line search method"),
            }
        }

        #[test]
        fn change_params_lbfgs() {
            let linesearch: LineSearchParams =
                MoreThuenteBuilder::default().c1(1e-5).c2(0.8).build();
            let lbfgs: LocalSolverConfig = LBFGSBuilder::default()
                .max_iter(500)
                .tolerance_grad(1e-8)
                .tolerance_cost(1e-8)
                .history_size(5)
                .line_search_params(linesearch)
                .build();
            match lbfgs {
                LocalSolverConfig::LBFGS {
                    max_iter,
                    tolerance_grad,
                    tolerance_cost,
                    history_size,
                    l1_coefficient,
                    line_search_params,
                } => {
                    assert_eq!(max_iter, 500);
                    assert_eq!(tolerance_grad, 1e-8);
                    assert_eq!(tolerance_cost, 1e-8);
                    assert_eq!(history_size, 5);
                    assert_eq!(l1_coefficient, None);
                    match line_search_params.method {
                        LineSearchMethod::MoreThuente { c1, c2, width_tolerance, bounds } => {
                            assert_eq!(c1, 1e-5);
                            assert_eq!(c2, 0.8);
                            assert_eq!(width_tolerance, 1e-10);
                            assert_eq!(bounds, array![f64::EPSILON.sqrt(), f64::INFINITY]);
                        }
                        _ => panic!("Expected MoreThuente line search method"),
                    }
                }
                _ => panic!("Expected L-BFGS local solver"),
            }
        }

        #[test]
        fn change_params_neldermead() {
            let neldermead: LocalSolverConfig = NelderMeadBuilder::default()
                .simplex_delta(0.5)
                .sd_tolerance(1e-5)
                .max_iter(1000)
                .alpha(1.5)
                .gamma(3.0)
                .rho(0.6)
                .sigma(0.6)
                .build();
            match neldermead {
                LocalSolverConfig::NelderMead {
                    simplex_delta,
                    sd_tolerance,
                    max_iter,
                    alpha,
                    gamma,
                    rho,
                    sigma,
                } => {
                    assert_eq!(simplex_delta, 0.5);
                    assert_eq!(sd_tolerance, 1e-5);
                    assert_eq!(max_iter, 1000);
                    assert_eq!(alpha, 1.5);
                    assert_eq!(gamma, 3.0);
                    assert_eq!(rho, 0.6);
                    assert_eq!(sigma, 0.6);
                }
                _ => panic!("Expected Nelder-Mead local solver"),
            }
        }

        #[test]
        fn change_params_steepestdescent() {
            let linesearch: LineSearchParams =
                MoreThuenteBuilder::default().c1(1e-5).c2(0.8).build();
            let steepestdescent: LocalSolverConfig = SteepestDescentBuilder::default()
                .max_iter(500)
                .line_search_params(linesearch)
                .build();
            match steepestdescent {
                LocalSolverConfig::SteepestDescent { max_iter, line_search_params } => {
                    assert_eq!(max_iter, 500);
                    match line_search_params.method {
                        LineSearchMethod::MoreThuente { c1, c2, width_tolerance, bounds } => {
                            assert_eq!(c1, 1e-5);
                            assert_eq!(c2, 0.8);
                            assert_eq!(width_tolerance, 1e-10);
                            assert_eq!(bounds, array![f64::EPSILON.sqrt(), f64::INFINITY]);
                        }
                        _ => panic!("Expected MoreThuente line search method"),
                    }
                }
                _ => panic!("Expected Steepest Descent local solver"),
            }
        }

        #[test]
        fn change_params_trustregion() {
            let trustregion: LocalSolverConfig = TrustRegionBuilder::default()
                .method(TrustRegionRadiusMethod::Steihaug)
                .max_iter(500)
                .radius(2.0)
                .max_radius(200.0)
                .eta(0.1)
                .build();
            match trustregion {
                LocalSolverConfig::TrustRegion {
                    trust_region_radius_method,
                    max_iter,
                    radius,
                    max_radius,
                    eta,
                } => {
                    assert_eq!(trust_region_radius_method, TrustRegionRadiusMethod::Steihaug);
                    assert_eq!(max_iter, 500);
                    assert_eq!(radius, 2.0);
                    assert_eq!(max_radius, 200.0);
                    assert_eq!(eta, 0.1);
                }
                _ => panic!("Expected Trust Region local solver"),
            }
        }

        #[test]
        fn change_params_newton_cg() {
            let linesearch: LineSearchParams =
                MoreThuenteBuilder::default().c1(1e-5).c2(0.8).build();
            let newtoncg: LocalSolverConfig = NewtonCGBuilder::default()
                .max_iter(500)
                .curvature_threshold(0.1)
                .tolerance(1e-7)
                .line_search_params(linesearch)
                .build();
            match newtoncg {
                LocalSolverConfig::NewtonCG {
                    max_iter,
                    curvature_threshold,
                    tolerance,
                    line_search_params,
                } => {
                    assert_eq!(max_iter, 500);
                    assert_eq!(curvature_threshold, 0.1);
                    assert_eq!(tolerance, 1e-7);
                    match line_search_params.method {
                        LineSearchMethod::MoreThuente { c1, c2, width_tolerance, bounds } => {
                            assert_eq!(c1, 1e-5);
                            assert_eq!(c2, 0.8);
                            assert_eq!(width_tolerance, 1e-10);
                            assert_eq!(bounds, array![f64::EPSILON.sqrt(), f64::INFINITY]);
                        }
                        _ => panic!("Expected MoreThuente line search method"),
                    }
                }
                _ => panic!("Expected Newton-CG local solver"),
            }
        }

        #[test]
        fn change_params_morethuente() {
            let morethuente: LineSearchParams = MoreThuenteBuilder::default()
                .c1(1e-5)
                .c2(0.8)
                .width_tolerance(1e-8)
                .bounds(array![1e-5, 1e5])
                .build();
            match morethuente.method {
                LineSearchMethod::MoreThuente { c1, c2, width_tolerance, bounds } => {
                    assert_eq!(c1, 1e-5);
                    assert_eq!(c2, 0.8);
                    assert_eq!(width_tolerance, 1e-8);
                    assert_eq!(bounds, array![1e-5, 1e5]);
                }
                _ => panic!("Expected MoreThuente line search method"),
            }
        }

        #[test]
        fn change_params_hagerzhang() {
            let hagerzhang = HagerZhangBuilder::default()
                .delta(0.2)
                .sigma(0.8)
                .epsilon(1e-7)
                .theta(0.6)
                .gamma(0.7)
                .eta(0.05)
                .bounds(array![1e-6, 1e6])
                .build();
            match hagerzhang.method {
                LineSearchMethod::HagerZhang {
                    delta,
                    sigma,
                    epsilon,
                    theta,
                    gamma,
                    eta,
                    bounds,
                } => {
                    assert_eq!(delta, 0.2);
                    assert_eq!(sigma, 0.8);
                    assert_eq!(epsilon, 1e-7);
                    assert_eq!(theta, 0.6);
                    assert_eq!(gamma, 0.7);
                    assert_eq!(eta, 0.05);
                    assert_eq!(bounds, array![1e-6, 1e6]);
                }
                _ => panic!("Expected HagerZhang line search method"),
            }
        }

        #[test]
        fn test_lbfgs_new() {
            let ls = LineSearchParams::morethuente().c1(1e-5).c2(0.8).build();
            let lbfgs = LBFGSBuilder::new(500, 1e-8, 1e-8, 5, None, ls).build();
            match lbfgs {
                LocalSolverConfig::LBFGS {
                    max_iter,
                    tolerance_grad,
                    tolerance_cost,
                    history_size,
                    l1_coefficient,
                    line_search_params,
                } => {
                    assert_eq!(max_iter, 500);
                    assert_eq!(tolerance_grad, 1e-8);
                    assert_eq!(tolerance_cost, 1e-8);
                    assert_eq!(history_size, 5);
                    assert_eq!(l1_coefficient, None);
                    match line_search_params.method {
                        LineSearchMethod::MoreThuente { c1, c2, width_tolerance, bounds } => {
                            assert_eq!(c1, 1e-5);
                            assert_eq!(c2, 0.8);
                            assert_eq!(width_tolerance, 1e-10);
                            assert_eq!(bounds, array![f64::EPSILON.sqrt(), f64::INFINITY]);
                        }
                        _ => panic!("Expected MoreThuente line search method"),
                    }
                }
                _ => panic!("Expected L-BFGS local solver"),
            }
        }

        #[test]
        fn test_neldermead_new() {
            let nm = NelderMeadBuilder::new(0.5, 1e-5, 1000, 1.5, 3.0, 0.6, 0.6).build();
            match nm {
                LocalSolverConfig::NelderMead {
                    simplex_delta,
                    sd_tolerance,
                    max_iter,
                    alpha,
                    gamma,
                    rho,
                    sigma,
                } => {
                    assert_eq!(simplex_delta, 0.5);
                    assert_eq!(sd_tolerance, 1e-5);
                    assert_eq!(max_iter, 1000);
                    assert_eq!(alpha, 1.5);
                    assert_eq!(gamma, 3.0);
                    assert_eq!(rho, 0.6);
                    assert_eq!(sigma, 0.6);
                }
                _ => panic!("Expected Nelder-Mead local solver"),
            }
        }

        #[test]
        fn test_steepestdescent_new() {
            let ls = LineSearchParams::morethuente().c1(1e-5).c2(0.8).build();
            let sd = SteepestDescentBuilder::new(500, ls).build();
            match sd {
                LocalSolverConfig::SteepestDescent { max_iter, line_search_params } => {
                    assert_eq!(max_iter, 500);
                    match line_search_params.method {
                        LineSearchMethod::MoreThuente { c1, c2, width_tolerance, bounds } => {
                            assert_eq!(c1, 1e-5);
                            assert_eq!(c2, 0.8);
                            assert_eq!(width_tolerance, 1e-10);
                            assert_eq!(bounds, array![f64::EPSILON.sqrt(), f64::INFINITY]);
                        }
                        _ => panic!("Expected MoreThuente line search method"),
                    }
                }
                _ => panic!("Expected Steepest Descent local solver"),
            }
        }

        #[test]
        fn test_trustregion_new() {
            let tr =
                TrustRegionBuilder::new(TrustRegionRadiusMethod::Steihaug, 500, 2.0, 200.0, 0.1)
                    .build();
            match tr {
                LocalSolverConfig::TrustRegion {
                    trust_region_radius_method,
                    max_iter,
                    radius,
                    max_radius,
                    eta,
                } => {
                    assert_eq!(trust_region_radius_method, TrustRegionRadiusMethod::Steihaug);
                    assert_eq!(max_iter, 500);
                    assert_eq!(radius, 2.0);
                    assert_eq!(max_radius, 200.0);
                    assert_eq!(eta, 0.1);
                }
                _ => panic!("Expected Trust Region local solver"),
            }
        }

        #[test]
        fn test_newtoncg_new() {
            let ls = LineSearchParams::morethuente().c1(1e-5).c2(0.8).build();
            let ncg = NewtonCGBuilder::new(500, 0.1, 1e-7, ls).build();
            match ncg {
                LocalSolverConfig::NewtonCG {
                    max_iter,
                    curvature_threshold,
                    tolerance,
                    line_search_params,
                } => {
                    assert_eq!(max_iter, 500);
                    assert_eq!(curvature_threshold, 0.1);
                    assert_eq!(tolerance, 1e-7);
                    match line_search_params.method {
                        LineSearchMethod::MoreThuente { c1, c2, width_tolerance, bounds } => {
                            assert_eq!(c1, 1e-5);
                            assert_eq!(c2, 0.8);
                            assert_eq!(width_tolerance, 1e-10);
                            assert_eq!(bounds, array![f64::EPSILON.sqrt(), f64::INFINITY]);
                        }
                        _ => panic!("Expected MoreThuente line search method"),
                    }
                }
                _ => panic!("Expected Newton-CG local solver"),
            }
        }

        #[test]
        fn test_morethuente_new() {
            let mt = MoreThuenteBuilder::new(1e-5, 0.8, 1e-8, array![1e-5, 1e5]).build();
            match mt.method {
                LineSearchMethod::MoreThuente { c1, c2, width_tolerance, bounds } => {
                    assert_eq!(c1, 1e-5);
                    assert_eq!(c2, 0.8);
                    assert_eq!(width_tolerance, 1e-8);
                    assert_eq!(bounds, array![1e-5, 1e5]);
                }
                _ => panic!("Expected MoreThuente line search method"),
            }
        }

        #[test]
        fn test_hagerzhang_new() {
            let hz =
                HagerZhangBuilder::new(0.2, 0.8, 1e-7, 0.6, 0.7, 0.05, array![1e-6, 1e6]).build();
            match hz.method {
                LineSearchMethod::HagerZhang {
                    delta,
                    sigma,
                    epsilon,
                    theta,
                    gamma,
                    eta,
                    bounds,
                } => {
                    assert_eq!(delta, 0.2);
                    assert_eq!(sigma, 0.8);
                    assert_eq!(epsilon, 1e-7);
                    assert_eq!(theta, 0.6);
                    assert_eq!(gamma, 0.7);
                    assert_eq!(eta, 0.05);
                    assert_eq!(bounds, array![1e-6, 1e6]);
                }
                _ => panic!("Expected HagerZhang line search method"),
            }
        }
    }

    #[test]
    fn test_default_cobyla() {
        let cobyla: LocalSolverConfig = COBYLABuilder::default().build();
        match cobyla {
            LocalSolverConfig::COBYLA {
                max_iter,
                initial_step_size,
                ftol_rel,
                ftol_abs,
                xtol_rel,
                xtol_abs,
            } => {
                assert_eq!(max_iter, 300);
                assert_eq!(initial_step_size, 0.5);
                assert_eq!(ftol_rel, 1e-6);
                assert_eq!(ftol_abs, 1e-8);
                assert_eq!(xtol_rel, 0.0);
                assert_eq!(xtol_abs, Vec::<f64>::new());
            }
            #[cfg_attr(not(feature = "argmin"), allow(unreachable_patterns))]
            _ => panic!("Expected COBYLA local solver"),
        }
    }

    #[test]
    fn change_params_cobyla() {
        let cobyla: LocalSolverConfig =
            COBYLABuilder::default().max_iter(500).initial_step_size(0.1).ftol_rel(1e-10).build();
        match cobyla {
            LocalSolverConfig::COBYLA {
                max_iter,
                initial_step_size,
                ftol_rel,
                ftol_abs,
                xtol_rel,
                xtol_abs,
            } => {
                assert_eq!(max_iter, 500);
                assert_eq!(initial_step_size, 0.1);
                assert_eq!(ftol_rel, 1e-10);
                assert_eq!(ftol_abs, 1e-8);
                assert_eq!(xtol_rel, 0.0);
                assert_eq!(xtol_abs, Vec::<f64>::new());
            }
            #[cfg_attr(not(feature = "argmin"), allow(unreachable_patterns))]
            _ => panic!("Expected COBYLA local solver"),
        }
    }

    #[test]
    fn test_cobyla_new() {
        let cobyla = COBYLABuilder::new(500, 0.5).build();
        match cobyla {
            LocalSolverConfig::COBYLA {
                max_iter,
                initial_step_size,
                ftol_rel,
                ftol_abs,
                xtol_rel,
                xtol_abs,
            } => {
                assert_eq!(max_iter, 500);
                assert_eq!(initial_step_size, 0.5);
                assert_eq!(ftol_rel, 1e-6);
                assert_eq!(ftol_abs, 1e-8);
                assert_eq!(xtol_rel, 0.0);
                assert_eq!(xtol_abs, Vec::<f64>::new());
            }
            #[cfg_attr(not(feature = "argmin"), allow(unreachable_patterns))]
            _ => panic!("Expected COBYLA local solver"),
        }
    }

    #[test]
    fn test_cobyla_vector_xtol_abs() {
        let xtol_vec = vec![1e-6, 1e-8, 1e-10];
        let cobyla: LocalSolverConfig = COBYLABuilder::default()
            .max_iter(1000)
            .initial_step_size(0.1)
            .ftol_rel(1e-8)
            .xtol_abs(xtol_vec.clone())
            .build();
        match cobyla {
            LocalSolverConfig::COBYLA {
                max_iter,
                initial_step_size,
                ftol_rel,
                ftol_abs,
                xtol_rel,
                xtol_abs,
            } => {
                assert_eq!(max_iter, 1000);
                assert_eq!(initial_step_size, 0.1);
                assert_eq!(ftol_rel, 1e-8);
                assert_eq!(ftol_abs, 1e-8);
                assert_eq!(xtol_rel, 0.0);
                assert_eq!(xtol_abs, xtol_vec);
            }
            #[cfg_attr(not(feature = "argmin"), allow(unreachable_patterns))]
            _ => panic!("Expected COBYLA local solver"),
        }
    }

    // Covers the `ftol_abs`/`xtol_rel` setters, the `LocalSolverConfig::cobyla()`
    // entry point, and that cloning a COBYLA config preserves every field.
    #[test]
    fn change_remaining_params_cobyla_and_clone() {
        let cobyla = LocalSolverConfig::cobyla().ftol_abs(1e-12).xtol_rel(1e-4).build();
        match cobyla.clone() {
            LocalSolverConfig::COBYLA {
                max_iter,
                initial_step_size,
                ftol_rel,
                ftol_abs,
                xtol_rel,
                xtol_abs,
            } => {
                assert_eq!(max_iter, 300);
                assert_eq!(initial_step_size, 0.5);
                assert_eq!(ftol_rel, 1e-6);
                assert_eq!(ftol_abs, 1e-12);
                assert_eq!(xtol_rel, 1e-4);
                assert!(xtol_abs.is_empty());
            }
            #[cfg_attr(not(feature = "argmin"), allow(unreachable_patterns))]
            _ => panic!("Expected COBYLA local solver"),
        }
    }
}