use cartan_core::{Manifold, ParallelTransport, Real, Retraction};
use crate::result::OptResult;
/// Formula used by [`minimize_rcg`] to compute the conjugate-gradient
/// coefficient `beta` that mixes the transported previous direction into the
/// new search direction.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CgVariant {
    /// Fletcher–Reeves: `beta = |g_new|^2 / |g_old|^2`.
    FletcherReeves,
    /// Polak–Ribière clamped at zero:
    /// `beta = max(0, <g_new, g_new - T(g_old)> / |g_old|^2)`, where `T` is
    /// parallel transport to the new point. This is the default variant.
    #[default]
    PolakRibiere,
}
/// Tuning parameters for [`minimize_rcg`].
///
/// Start from [`RCGConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `RCGConfig { max_iters: 200, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct RCGConfig {
    /// Maximum number of outer (conjugate-gradient) iterations.
    pub max_iters: usize,
    /// Convergence threshold: the solver stops once the Riemannian gradient
    /// norm is strictly below this value.
    pub grad_tol: Real,
    /// First trial step length `t` tried by the line search in each iteration.
    pub init_step: Real,
    /// Armijo sufficient-decrease constant `c`: a trial step is accepted when
    /// `f(R_x(t * p)) <= f(x) + c * t * <grad f(x), p>`, with `R_x` the retraction.
    pub armijo_c: Real,
    /// Factor by which the trial step `t` is multiplied after each rejected trial.
    pub armijo_beta: Real,
    /// Maximum number of step-length reductions per line search.
    pub max_ls_iters: usize,
    /// Formula for the conjugate-gradient coefficient `beta`.
    pub variant: CgVariant,
    /// When non-zero, `beta` is forced to zero (a steepest-descent restart)
    /// every `restart_every` iterations; `0` disables periodic restarts.
    pub restart_every: usize,
}
impl Default for RCGConfig {
fn default() -> Self {
Self {
max_iters: 1000,
grad_tol: 1e-6,
init_step: 1.0,
armijo_c: 1e-4,
armijo_beta: 0.5,
max_ls_iters: 50,
variant: CgVariant::PolakRibiere,
restart_every: 0,
}
}
}
/// Minimizes `cost` over `manifold` with Riemannian nonlinear conjugate gradient.
///
/// Starting from `x0`, each iteration:
/// 1. stops with `converged = true` if the Riemannian gradient norm is strictly
///    below `config.grad_tol`;
/// 2. replaces the search direction `p` by steepest descent (`-grad`) if `p` is
///    not a descent direction;
/// 3. runs a backtracking Armijo line search along `p` through
///    `manifold.retract`;
/// 4. builds the next direction `p = -grad + beta * T(p)`, where `T` is
///    `manifold.transport` and `beta` follows `config.variant`
///    (Fletcher–Reeves, or Polak–Ribière clamped at zero).
///
/// `beta` is set to zero (a steepest-descent restart) every
/// `config.restart_every` iterations when that value is non-zero, when the
/// previous squared gradient norm is below `1e-30`, or when a parallel
/// transport fails.
///
/// # Arguments
/// - `manifold`: search space; must provide retraction and parallel transport.
/// - `cost`: objective function.
/// - `rgrad`: Riemannian gradient of `cost`.
/// - `x0`: starting point.
/// - `config`: solver parameters, see [`RCGConfig`].
///
/// # Returns
/// An [`OptResult`] holding the final point, its cost and gradient norm, the
/// number of completed iterations, and `converged`, which is `true` iff the
/// returned point's gradient norm is below `config.grad_tol`. The solver also
/// stops early with `converged = false` when the line search cannot find any
/// step that does not increase the cost (e.g. the cost becomes NaN); the last
/// accepted point is then returned unchanged.
pub fn minimize_rcg<M, F, G>(
    manifold: &M,
    cost: F,
    rgrad: G,
    x0: M::Point,
    config: &RCGConfig,
) -> OptResult<M::Point>
where
    M: Manifold + Retraction + ParallelTransport,
    F: Fn(&M::Point) -> Real,
    G: Fn(&M::Point) -> M::Tangent,
{
    /// Square root usable both with and without the `std` feature.
    fn real_sqrt(v: Real) -> Real {
        #[cfg(feature = "std")]
        {
            v.sqrt()
        }
        #[cfg(not(feature = "std"))]
        {
            libm::sqrt(v)
        }
    }

    let mut x = x0;
    let mut f_x = cost(&x);
    let mut g = rgrad(&x);
    let mut g_sq = manifold.inner(&x, &g, &g);
    let mut g_norm = real_sqrt(g_sq);
    // The very first direction is steepest descent.
    let mut p = -g.clone();

    for iter in 0..config.max_iters {
        if g_norm < config.grad_tol {
            return OptResult {
                point: x,
                value: f_x,
                grad_norm: g_norm,
                iterations: iter,
                converged: true,
            };
        }

        // Directional derivative along `p`; fall back to steepest descent when
        // `p` does not point downhill.
        let mut slope = manifold.inner(&x, &g, &p);
        if slope >= 0.0 {
            p = -g.clone();
            slope = manifold.inner(&x, &g, &p);
        }

        // Backtracking Armijo line search. Every trial step, including the
        // last one, is tested for sufficient decrease.
        let mut t = config.init_step;
        let mut x_new = manifold.retract(&x, &(p.clone() * t));
        let mut f_new = cost(&x_new);
        let mut sufficient = f_new <= f_x + config.armijo_c * t * slope;
        for _ in 0..config.max_ls_iters {
            if sufficient {
                break;
            }
            t *= config.armijo_beta;
            x_new = manifold.retract(&x, &(p.clone() * t));
            f_new = cost(&x_new);
            sufficient = f_new <= f_x + config.armijo_c * t * slope;
        }

        // If Armijo never held, the final trial is still acceptable as long as
        // it does not increase the cost. A higher (or NaN) cost means no usable
        // step was found along `p`: stop instead of moving uphill.
        let accept = sufficient || f_new <= f_x;
        if !accept {
            return OptResult {
                point: x,
                value: f_x,
                grad_norm: g_norm,
                iterations: iter,
                converged: false,
            };
        }

        // Move to the new point, keeping the previous point and gradient
        // (without cloning) for the transports below.
        let x_prev = core::mem::replace(&mut x, x_new);
        f_x = f_new;
        let g_prev = core::mem::replace(&mut g, rgrad(&x));
        let g_sq_prev = g_sq;
        g_sq = manifold.inner(&x, &g, &g);
        g_norm = real_sqrt(g_sq);

        let periodic_restart =
            config.restart_every > 0 && (iter + 1) % config.restart_every == 0;
        let beta = if periodic_restart || g_sq_prev < 1e-30 {
            0.0
        } else {
            match config.variant {
                CgVariant::FletcherReeves => g_sq / g_sq_prev,
                // Clamped at zero; a failed transport of the old gradient also
                // yields beta = 0 (restart).
                CgVariant::PolakRibiere => match manifold.transport(&x_prev, &x, &g_prev).ok() {
                    Some(g_prev_moved) => {
                        let grad_change = g.clone() - g_prev_moved;
                        (manifold.inner(&x, &g, &grad_change) / g_sq_prev).max(0.0)
                    }
                    None => 0.0,
                },
            }
        };

        // p <- -g + beta * T(p). A negligible beta or a failed transport of the
        // old direction leaves plain steepest descent.
        let p_moved = if beta.abs() < 1e-30 {
            None
        } else {
            manifold.transport(&x_prev, &x, &p).ok()
        };
        p = match p_moved {
            Some(p_t) => -g.clone() + p_t * beta,
            None => -g.clone(),
        };
    }

    OptResult {
        point: x,
        value: f_x,
        grad_norm: g_norm,
        iterations: config.max_iters,
        // The gradient of the final iterate is computed inside the loop but
        // never re-tested there, so check it here.
        converged: g_norm < config.grad_tol,
    }
}