use crate::core::{ArgminFloat, SerializeAlias};
use argmin_math::{ArgminDot, ArgminL2Norm, ArgminSub};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// Rule for computing the scalar `beta` parameter of a nonlinear conjugate gradient (NLCG)
/// iteration.
///
/// Each implementor in this module encodes one classic formula (Fletcher–Reeves,
/// Polak–Ribière, Polak–Ribière-plus, Hestenes–Stiefel). In the standard NLCG scheme the next
/// search direction is `p_{k+1} = -∇f_{k+1} + beta * p_k`; how `beta` is actually consumed is
/// up to the caller.
///
/// Type parameters: `G` is the gradient type, `P` the search-direction type and `F` the
/// floating point type of the returned `beta`.
pub trait NLCGBetaUpdate<G, P, F>: SerializeAlias {
    /// Compute `beta` for the next NLCG iteration.
    ///
    /// # Arguments
    ///
    /// * `nabla_f_k` - gradient at the current iterate `x_k`.
    /// * `nabla_f_k_p_1` - gradient at the next iterate `x_{k+1}`.
    /// * `p_k` - search direction of iteration `k` (not every rule uses it).
    fn update(&self, nabla_f_k: &G, nabla_f_k_p_1: &G, p_k: &P) -> F;
}
/// Fletcher–Reeves rule for the NLCG `beta` parameter:
///
/// `beta = ‖∇f_{k+1}‖² / ‖∇f_k‖²`
///
/// The type carries no state; `new()` and `default()` produce the same value.
#[derive(Default, Copy, Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct FletcherReeves {}

impl FletcherReeves {
    /// Construct a `FletcherReeves` rule (identical to `FletcherReeves::default()`).
    pub fn new() -> Self {
        Self {}
    }
}
impl<G, P, F> NLCGBetaUpdate<G, P, F> for FletcherReeves
where
    G: ArgminDot<G, F>,
    F: ArgminFloat,
{
    /// Fletcher–Reeves update: `beta = ∇f_{k+1}ᵀ∇f_{k+1} / ∇f_kᵀ∇f_k`.
    ///
    /// Both squared gradient norms are taken as self dot products. The search direction
    /// `_pk` is not part of this formula and is ignored.
    fn update(&self, dfk: &G, dfk1: &G, _pk: &P) -> F {
        let new_grad_sq: F = dfk1.dot(dfk1);
        let old_grad_sq: F = dfk.dot(dfk);
        new_grad_sq / old_grad_sq
    }
}
/// Polak–Ribière rule for the NLCG `beta` parameter:
///
/// `beta = ∇f_{k+1}ᵀ (∇f_{k+1} − ∇f_k) / ‖∇f_k‖²`
///
/// The type carries no state; `new()` and `default()` produce the same value.
#[derive(Default, Copy, Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct PolakRibiere {}

impl PolakRibiere {
    /// Construct a `PolakRibiere` rule (identical to `PolakRibiere::default()`).
    pub fn new() -> Self {
        Self {}
    }
}
impl<G, P, F> NLCGBetaUpdate<G, P, F> for PolakRibiere
where
    // `ArgminL2Norm<F>` is no longer needed by the body; it is kept so the impl's bounds
    // (and therefore where it applies) stay exactly as before.
    G: ArgminDot<G, F> + ArgminSub<G, G> + ArgminL2Norm<F>,
    F: ArgminFloat,
{
    /// Polak–Ribière update:
    ///
    /// `beta = ∇f_{k+1}ᵀ (∇f_{k+1} − ∇f_k) / (∇f_kᵀ ∇f_k)`
    ///
    /// The search direction `_pk` is not part of this formula and is ignored.
    fn update(&self, dfk: &G, dfk1: &G, _pk: &P) -> F {
        // ‖∇f_k‖² is taken directly as a dot product, the same way `FletcherReeves` computes
        // it. Squaring `l2_norm()` instead takes a needless sqrt and adds the rounding error
        // of the sqrt→square round trip.
        let dfk_norm_sq: F = dfk.dot(dfk);
        dfk1.dot(&dfk1.sub(dfk)) / dfk_norm_sq
    }
}
/// Polak–Ribière-plus rule for the NLCG `beta` parameter: the Polak–Ribière value clipped
/// from below at zero,
///
/// `beta = max(0, ∇f_{k+1}ᵀ (∇f_{k+1} − ∇f_k) / ‖∇f_k‖²)`
///
/// The type carries no state; `new()` and `default()` produce the same value.
#[derive(Default, Copy, Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct PolakRibierePlus {}

impl PolakRibierePlus {
    /// Construct a `PolakRibierePlus` rule (identical to `PolakRibierePlus::default()`).
    pub fn new() -> Self {
        Self {}
    }
}
impl<G, P, F> NLCGBetaUpdate<G, P, F> for PolakRibierePlus
where
    // `ArgminL2Norm<F>` is no longer needed by the body; it is kept so the impl's bounds
    // (and therefore where it applies) stay exactly as before.
    G: ArgminDot<G, F> + ArgminSub<G, G> + ArgminL2Norm<F>,
    F: ArgminFloat,
{
    /// Polak–Ribière-plus update:
    ///
    /// `beta = max(0, ∇f_{k+1}ᵀ (∇f_{k+1} − ∇f_k) / (∇f_kᵀ ∇f_k))`
    ///
    /// The search direction `_pk` is not part of this formula and is ignored.
    fn update(&self, dfk: &G, dfk1: &G, _pk: &P) -> F {
        // ‖∇f_k‖² is taken directly as a dot product, the same way `FletcherReeves` computes
        // it. Squaring `l2_norm()` instead takes a needless sqrt and adds the rounding error
        // of the sqrt→square round trip.
        let dfk_norm_sq: F = dfk.dot(dfk);
        let beta: F = dfk1.dot(&dfk1.sub(dfk)) / dfk_norm_sq;
        // Clip at zero. Assuming the caller forms `p_{k+1} = -∇f_{k+1} + beta * p_k`, this
        // turns a negative Polak–Ribière value into a steepest-descent restart.
        float!(0.0).max(beta)
    }
}
/// Hestenes–Stiefel rule for the NLCG `beta` parameter:
///
/// `beta = ∇f_{k+1}ᵀ y_k / (y_kᵀ p_k)` with `y_k = ∇f_{k+1} − ∇f_k`
///
/// The type carries no state; `new()` and `default()` produce the same value.
#[derive(Default, Copy, Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct HestenesStiefel {}

impl HestenesStiefel {
    /// Construct a `HestenesStiefel` rule (identical to `HestenesStiefel::default()`).
    pub fn new() -> Self {
        Self {}
    }
}
impl<G, P, F> NLCGBetaUpdate<G, P, F> for HestenesStiefel
where
    G: ArgminDot<G, F> + ArgminDot<P, F> + ArgminSub<G, G>,
    F: ArgminFloat,
{
    /// Hestenes–Stiefel update:
    ///
    /// `beta = ∇f_{k+1}ᵀ y_k / (y_kᵀ p_k)` with `y_k = ∇f_{k+1} − ∇f_k`.
    ///
    /// Unlike the other rules in this module, this one uses the search direction `pk`.
    fn update(&self, dfk: &G, dfk1: &G, pk: &P) -> F {
        // Gradient change between the two iterates; the numerator and the denominator share it.
        let grad_change = dfk1.sub(dfk);
        let numerator: F = dfk1.dot(&grad_change);
        let denominator: F = grad_change.dot(pk);
        numerator / denominator
    }
}
// Runs the crate-wide `test_trait_impl!` checks for every beta-update rule in this module.
// NOTE(review): the macro is defined elsewhere in the crate. Presumably it asserts that each
// type implements common marker traits (e.g. Send/Sync/Clone) — confirm against its definition.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_trait_impl;

    test_trait_impl!(fletcher_reeves, FletcherReeves);
    test_trait_impl!(polak_ribiere, PolakRibiere);
    test_trait_impl!(polak_ribiere_plus, PolakRibierePlus);
    test_trait_impl!(hestenes_stiefel, HestenesStiefel);
}