use crate::prelude::*;
use burn::module::AutodiffModule;
use burn::optim::{adaptor::OptimizerAdaptor, LrDecayState, SimpleOptimizer};
use burn::record::Record;
use burn::tensor::backend::AutodiffBackend;
use burn::optim::LearningRate;
use std::marker::PhantomData;
pub mod hessian_optimizer;
pub mod many_steps;
pub mod multiple;
pub use many_steps::LessSimpleOptimizer;
/// Configuration for [`ManifoldRGD`], Riemannian gradient descent on the manifold `M`.
///
/// It carries no hyperparameters; the type parameters alone select the manifold and backend.
#[derive(Debug)]
pub struct ManifoldRGDConfig<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    _manifold: PhantomData<M>,
    _backend: PhantomData<B>,
}

impl<M: Manifold<B>, B: Backend> Default for ManifoldRGDConfig<M, B> {
    /// Builds the (data-free) configuration.
    fn default() -> Self {
        ManifoldRGDConfig {
            _backend: PhantomData,
            _manifold: PhantomData,
        }
    }
}
/// Riemannian gradient-descent optimizer on the manifold `M`.
///
/// The optimizer stores no data; its behavior comes entirely from `M`'s associated functions.
#[derive(Debug, Clone)]
pub struct ManifoldRGD<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    _manifold: PhantomData<M>,
    _backend: PhantomData<B>,
}

impl<M: Manifold<B>, B: Backend> Default for ManifoldRGD<M, B> {
    /// Builds the (data-free) optimizer.
    fn default() -> Self {
        ManifoldRGD {
            _backend: PhantomData,
            _manifold: PhantomData,
        }
    }
}
/// Per-tensor state record for [`ManifoldRGD`].
///
/// It only wraps a learning-rate decay state. `ManifoldRGD::step` never creates or
/// modifies one: it hands back whatever state it was given.
#[derive(Record, Clone)]
pub struct ManifoldRGDState<B: Backend, const D: usize> {
// Learning-rate decay bookkeeping; not read by `ManifoldRGD::step`.
lr_decay: LrDecayState<B, D>,
}
impl<M, B> SimpleOptimizer<B> for ManifoldRGD<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    type State<const D: usize> = ManifoldRGDState<B, D>;

    /// Performs one Riemannian gradient-descent step.
    ///
    /// `-grad` is passed through `M::project` at `tensor` (presumably the tangent-space
    /// projection; confirm against the `Manifold` trait). The result is scaled by `lr`, and
    /// `M::retract` maps `tensor` along that direction back onto the manifold.
    ///
    /// The optimizer is stateless: `state` is returned untouched, so a `None` input yields
    /// `None`.
    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let descent = M::project(tensor.clone(), -grad);
        let result = M::retract(tensor, descent * lr);
        (result, state)
    }

    /// Moves `state` (its learning-rate decay record) to `device`.
    ///
    /// Previously a hard-coded `false` switch made this return the state unmoved, silently
    /// ignoring `device`. The state is now moved, matching the `RiemannianAdam` optimizer.
    fn to_device<const D: usize>(
        state: Self::State<D>,
        device: &<B as Backend>::Device,
    ) -> Self::State<D> {
        ManifoldRGDState {
            lr_decay: state.lr_decay.to_device(device),
        }
    }
}
impl<M: Manifold<B>, B: Backend> ManifoldRGDConfig<M, B> {
    /// Creates an [`OptimizerAdaptor`] that runs [`ManifoldRGD`] on the inner
    /// (non-autodiff) backend of `Back`, for optimizing modules of type `Mod`.
    #[must_use]
    pub fn init<Back: AutodiffBackend, Mod: AutodiffModule<Back>>(
        &self,
    ) -> OptimizerAdaptor<ManifoldRGD<M, Back::InnerBackend>, Mod, Back>
    where
        M: Manifold<Back::InnerBackend>,
    {
        ManifoldRGD::<M, Back::InnerBackend>::default().into()
    }
}
/// Hyperparameters for [`RiemannianAdam`] on the manifold `M`.
///
/// Start from [`Default`] and override individual fields with the `with_*` builders.
#[derive(Debug, Clone)]
pub struct RiemannianAdamConfig<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    /// Learning rate read by `RiemannianAdam::step` (default `1e-3`).
    pub lr: f64,
    /// Decay rate of the first-moment estimate (default `0.9`).
    pub beta1: f64,
    /// Decay rate of the second-moment estimate (default `0.999`).
    pub beta2: f64,
    /// Constant added to the square-rooted second moment before dividing (default `1e-8`).
    pub eps: f64,
    /// Coefficient of the `tensor * weight_decay` term added to the gradient. It has no
    /// effect unless it is positive (default `0.0`).
    pub weight_decay: f64,
    /// Divide by the running maximum of the second moment (AMSGrad) (default `false`).
    pub amsgrad: bool,
    /// Optional stabilization period, in steps (default `None`).
    /// NOTE(review): confirm in `RiemannianAdam::step` how this is applied.
    pub stabilize: Option<usize>,
    _manifold: PhantomData<M>,
    _backend: PhantomData<B>,
}

impl<M: Manifold<B>, B: Backend> Default for RiemannianAdamConfig<M, B> {
    /// Standard Adam defaults: `lr = 1e-3`, betas `(0.9, 0.999)`, `eps = 1e-8`, no weight
    /// decay, no AMSGrad, no stabilization.
    fn default() -> Self {
        Self {
            _manifold: PhantomData,
            _backend: PhantomData,
            stabilize: None,
            amsgrad: false,
            weight_decay: 0.0,
            eps: 1e-8,
            beta2: 0.999,
            beta1: 0.9,
            lr: 1e-3,
        }
    }
}
impl<M: Manifold<B>, B: Backend> RiemannianAdamConfig<M, B> {
    /// Creates a configuration holding the [`Default`] hyperparameters.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns this configuration with the learning rate set to `lr`.
    #[must_use]
    pub fn with_lr(self, lr: f64) -> Self {
        Self { lr, ..self }
    }

    /// Returns this configuration with the first-moment decay rate set to `beta1`.
    #[must_use]
    pub fn with_beta1(self, beta1: f64) -> Self {
        Self { beta1, ..self }
    }

    /// Returns this configuration with the second-moment decay rate set to `beta2`.
    #[must_use]
    pub fn with_beta2(self, beta2: f64) -> Self {
        Self { beta2, ..self }
    }

    /// Returns this configuration with the denominator epsilon set to `eps`.
    #[must_use]
    pub fn with_eps(self, eps: f64) -> Self {
        Self { eps, ..self }
    }

    /// Returns this configuration with the weight-decay coefficient set to `weight_decay`.
    #[must_use]
    pub fn with_weight_decay(self, weight_decay: f64) -> Self {
        Self {
            weight_decay,
            ..self
        }
    }

    /// Returns this configuration with AMSGrad switched on or off.
    #[must_use]
    pub fn with_amsgrad(self, amsgrad: bool) -> Self {
        Self { amsgrad, ..self }
    }

    /// Returns this configuration with the stabilization period set to `stabilize`.
    #[must_use]
    pub fn with_stabilize(self, stabilize: Option<usize>) -> Self {
        Self { stabilize, ..self }
    }
}
/// Riemannian Adam optimizer on the manifold `M`, driven by a [`RiemannianAdamConfig`].
#[derive(Debug, Clone)]
pub struct RiemannianAdam<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    config: RiemannianAdamConfig<M, B>,
}

impl<M: Manifold<B>, B: Backend> RiemannianAdam<M, B> {
    /// Creates an optimizer that uses the hyperparameters in `config`.
    #[must_use]
    pub fn new(config: RiemannianAdamConfig<M, B>) -> Self {
        RiemannianAdam { config }
    }
}
/// Per-tensor state record for [`RiemannianAdam`], carried from one `step` call to the next.
#[derive(Record, Clone)]
pub struct RiemannianAdamState<B: Backend, const D: usize> {
/// Number of steps taken so far (`1` after the first step).
pub step: usize,
/// Exponential moving average of the Riemannian gradient (first moment). After each
/// step it is parallel-transported to the new point.
pub exp_avg: Tensor<B, D>,
/// Exponential moving average of `M::inner(x, rgrad, rgrad)` (second moment).
pub exp_avg_sq: Tensor<B, D>,
/// Running element-wise maximum of `exp_avg_sq`. It is maintained only when the
/// optimizer's `amsgrad` flag is enabled; otherwise it is `None`.
pub max_exp_avg_sq: Option<Tensor<B, D>>,
// Created as `LrDecayState::new(0, tensor)` and not read by `step`.
// NOTE(review): appears to be a placeholder; confirm whether anything uses it.
lr_decay: LrDecayState<B, D>,
}
impl<M, B> SimpleOptimizer<B> for RiemannianAdam<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    type State<const D: usize> = RiemannianAdamState<B, D>;

    /// Performs one Riemannian Adam update of `tensor`, given its Euclidean gradient `grad`.
    ///
    /// The step size comes from `self.config.lr`; the `_lr` argument supplied by the caller
    /// (e.g. an [`OptimizerAdaptor`] scheduler) is ignored.
    ///
    /// Outline:
    /// 1. If `weight_decay > 0`, add `tensor * weight_decay` to `grad` (a coupled L2 term).
    ///    Then convert the result with `M::egrad2rgrad`.
    /// 2. Update the first moment (`exp_avg`) and the second moment (`exp_avg_sq`, built from
    ///    `M::inner`). With AMSGrad, also track the running maximum of the second moment.
    /// 3. Move along `-step_size * exp_avg / denom` with `M::expmap`, then apply `M::proj`
    ///    to the result.
    /// 4. Parallel-transport `exp_avg` to the new point. If `config.stabilize == Some(k)`
    ///    with `k > 0`, re-project it with `M::project` at the new point every `k`-th step.
    ///
    /// Always returns `Some(state)`; a `None` input starts a fresh state at step 1.
    fn step<const D: usize>(
        &self,
        _lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        let learning_rate = self.config.lr;
        let grad = if self.config.weight_decay > 0.0 {
            grad + tensor.clone() * self.config.weight_decay
        } else {
            grad
        };
        let rgrad = M::egrad2rgrad(tensor.clone(), grad);

        let mut state = match state {
            Some(mut state) => {
                state.step += 1;
                state
            }
            None => RiemannianAdamState {
                step: 1,
                exp_avg: Tensor::zeros_like(&tensor),
                exp_avg_sq: Tensor::zeros_like(&tensor),
                max_exp_avg_sq: if self.config.amsgrad {
                    Some(Tensor::zeros_like(&tensor))
                } else {
                    None
                },
                lr_decay: LrDecayState::new(0, tensor.clone()),
            },
        };

        // First- and second-moment running averages.
        state.exp_avg =
            state.exp_avg * self.config.beta1 + rgrad.clone() * (1.0 - self.config.beta1);
        let inner_product = M::inner::<D>(tensor.clone(), rgrad.clone(), rgrad);
        state.exp_avg_sq =
            state.exp_avg_sq * self.config.beta2 + inner_product * (1.0 - self.config.beta2);

        let second_moment = if self.config.amsgrad {
            // An incoming state produced without AMSGrad (e.g. restored from a non-AMSGrad
            // run) has no running maximum. Seed it from the current estimate instead of
            // panicking. This matches `max(0, exp_avg_sq)` for a freshly zeroed maximum.
            let running_max = match state.max_exp_avg_sq.take() {
                Some(previous_max) => Tensor::max_pair(previous_max, state.exp_avg_sq.clone()),
                None => state.exp_avg_sq.clone(),
            };
            state.max_exp_avg_sq = Some(running_max.clone());
            running_max
        } else {
            state.exp_avg_sq.clone()
        };
        let denom = second_moment.sqrt() + self.config.eps;

        // `powi` takes an `i32`. Step counts never realistically approach `i32::MAX`, so the
        // narrowing cast is accepted here.
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        let bias_correction1 = 1.0 - self.config.beta1.powi(state.step as i32);
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        let bias_correction2 = 1.0 - self.config.beta2.powi(state.step as i32);
        let step_size = learning_rate * bias_correction2.sqrt() / bias_correction1;

        let direction = state.exp_avg.clone() / denom;
        let new_point = M::proj(M::expmap(tensor.clone(), direction * (-step_size)));

        // Carry the momentum to the new point. Periodically re-project it to counter
        // numerical drift when `stabilize` is set. `checked_rem` yields `None` for a
        // period of 0, which disables stabilization instead of panicking.
        let transported = M::parallel_transport(tensor, new_point.clone(), state.exp_avg);
        state.exp_avg = match self.config.stabilize {
            Some(period) if state.step.checked_rem(period) == Some(0) => {
                M::project(new_point.clone(), transported)
            }
            _ => transported,
        };

        (new_point, Some(state))
    }

    /// Moves every tensor held by `state` (including the optional AMSGrad maximum and the
    /// learning-rate decay record) to `device`.
    fn to_device<const D: usize>(
        mut state: Self::State<D>,
        device: &<B as Backend>::Device,
    ) -> Self::State<D> {
        state.exp_avg = state.exp_avg.to_device(device);
        state.exp_avg_sq = state.exp_avg_sq.to_device(device);
        state.max_exp_avg_sq = state
            .max_exp_avg_sq
            .map(|running_max| running_max.to_device(device));
        state.lr_decay = state.lr_decay.to_device(device);
        state
    }
}
impl<M, B> RiemannianAdamConfig<M, B>
where
M: Manifold<B>,
B: Backend,
{
#[must_use]
pub fn init<Back: AutodiffBackend, Mod: AutodiffModule<Back>>(
&self,
) -> OptimizerAdaptor<RiemannianAdam<M, Back::InnerBackend>, Mod, Back>
where
M: Manifold<Back::InnerBackend>,
{
let optim = RiemannianAdam::<M, Back::InnerBackend>::new(RiemannianAdamConfig {
lr: self.lr,
beta1: self.beta1,
beta2: self.beta2,
eps: self.eps,
weight_decay: self.weight_decay,
amsgrad: self.amsgrad,
stabilize: self.stabilize,
_manifold: PhantomData,
_backend: PhantomData,
});
OptimizerAdaptor::from(optim)
}
}
#[cfg(test)]
mod tests {
    // Checks of `RiemannianAdam` on the `Euclidean` manifold with the NdArray backend.
    use super::*;
    use burn::backend::NdArray;
    use burn::optim::SimpleOptimizer;

    type TestBackend = NdArray;

    #[test]
    fn test_riemannian_adam_basic() {
        let device = Default::default();
        let config = RiemannianAdamConfig::<Euclidean, TestBackend>::new()
            .with_lr(0.1)
            .with_beta1(0.9)
            .with_beta2(0.999);
        let optimizer = RiemannianAdam::new(config);

        let origin = Tensor::<TestBackend, 1>::zeros([3], &device);
        let unit_grad = Tensor::<TestBackend, 1>::ones([3], &device);
        let (updated, state) = optimizer.step(1.0, origin, unit_grad, None);

        let first_coord = updated.slice([0; 1]).into_scalar();
        assert!(
            first_coord < 0.0,
            "Should move in negative gradient direction"
        );
        assert!(state.is_some(), "State should be initialized");
    }

    #[test]
    fn test_riemannian_adam_convergence() {
        let device = Default::default();
        let optimizer =
            RiemannianAdam::new(RiemannianAdamConfig::<Euclidean, TestBackend>::new().with_lr(0.1));
        let target = Tensor::<TestBackend, 1>::from_floats([1.0, -0.5, 2.0], &device);
        let squared_error =
            |point: Tensor<TestBackend, 1>| (point - target.clone()).powf_scalar(2.0).sum();

        let mut x = Tensor::<TestBackend, 1>::zeros([3], &device);
        let mut state = None;
        let initial_loss = squared_error(x.clone());
        for _ in 0..50 {
            let grad = (x.clone() - target.clone()) * 2.0;
            (x, state) = optimizer.step(1.0, x, grad, state);
        }
        let final_loss = squared_error(x);

        assert!(
            final_loss.into_scalar() < initial_loss.into_scalar(),
            "Loss should decrease"
        );
    }

    #[test]
    fn test_riemannian_adam_amsgrad() {
        let device = Default::default();
        let optimizer = RiemannianAdam::new(
            RiemannianAdamConfig::<Euclidean, TestBackend>::new()
                .with_lr(0.1)
                .with_amsgrad(true),
        );
        let (_, state) = optimizer.step(
            1.0,
            Tensor::<TestBackend, 1>::zeros([2], &device),
            Tensor::<TestBackend, 1>::ones([2], &device),
            None,
        );

        assert!(state.is_some());
        let state =
            state.expect("RiemannianAdam optimizer always gives back an initialized state on step");
        assert!(
            state.max_exp_avg_sq.is_some(),
            "AMSGrad should initialize max_exp_avg_sq. See the explanation around the compute denominator part of step"
        );
    }

    #[test]
    fn test_riemannian_adam_weight_decay() {
        let device = Default::default();
        let optimizer = RiemannianAdam::new(
            RiemannianAdamConfig::<Euclidean, TestBackend>::new()
                .with_lr(0.1)
                .with_weight_decay(0.1),
        );
        let ones = Tensor::<TestBackend, 1>::ones([2], &device);
        let zero_grad = Tensor::<TestBackend, 1>::zeros([2], &device);
        let (decayed, _) = optimizer.step(1.0, ones.clone(), zero_grad, None);

        let l2_norm = |t: Tensor<TestBackend, 1>| t.powf_scalar(2.0).sum().sqrt();
        assert!(
            l2_norm(decayed).into_scalar() < l2_norm(ones).into_scalar(),
            "Weight decay should reduce tensor magnitude"
        );
    }

    #[test]
    fn test_riemannian_adam_state_persistence() {
        let device = Default::default();
        let optimizer =
            RiemannianAdam::new(RiemannianAdamConfig::<Euclidean, TestBackend>::new().with_lr(0.1));
        let start = Tensor::<TestBackend, 1>::zeros([2], &device);
        let grad = Tensor::<TestBackend, 1>::ones([2], &device);

        let (after_first, first_state) = optimizer.step(1.0, start, grad.clone(), None);
        assert!(first_state.is_some());
        let first_state = first_state.expect(
            "RiemannianAdam optimizer always gives back an initialized state on step even with None initial state");
        assert_eq!(first_state.step, 1);

        let (_, second_state) = optimizer.step(1.0, after_first, grad, Some(first_state));
        assert!(second_state.is_some());
        let second_state = second_state.expect(
            "There was an input state so RiemannianAdam optimizer's step modifies that and returns it"
        );
        assert_eq!(second_state.step, 2);
    }
}