use burn::module::AutodiffModule;
use burn::optim::adaptor::OptimizerAdaptor;
use burn::optim::{LrDecayState, SimpleOptimizer};
use burn::record::Record;
use burn::tensor::backend::AutodiffBackend;
use burn::LearningRate;
use std::marker::PhantomData;

use crate::manifolds::Manifold;
use crate::prelude::*;

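/// Configuration for [`ManifoldRGD`].
///
/// Plain Riemannian gradient descent has no hyperparameters of its own (the
/// learning rate is supplied per step), so this config only carries the
/// manifold and backend type parameters.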
#[derive(Debug)]
pub struct ManifoldRGDConfig<M, B> {
    _manifold: PhantomData<M>,
    _backend: PhantomData<B>,
}

impl<M, B> Default for ManifoldRGDConfig<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    fn default() -> Self {
        Self {
            _manifold: PhantomData,
            _backend: PhantomData,
        }
    }
}

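/// Riemannian gradient descent (RGD) over a [`Manifold`].
///
/// Implemented as a Burn [`SimpleOptimizer`]: each step projects the negative
/// Euclidean gradient onto the tangent space and retracts the scaled direction
/// back onto the manifold.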
#[derive(Debug, Clone)]
pub struct ManifoldRGD<M: Manifold<B>, B: Backend> {
    _manifold: PhantomData<M>,
    _backend: PhantomData<B>,
}

impl<M, B> Default for ManifoldRGD<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    fn default() -> Self {
        Self {
            _manifold: PhantomData,
            _backend: PhantomData,
        }
    }
}

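/// Per-parameter state for [`ManifoldRGD`].
///
/// Plain RGD keeps no running statistics; the wrapped [`LrDecayState`] record
/// is carried through [`SimpleOptimizer::step`] unchanged.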
#[derive(Record, Clone)]
pub struct ManifoldRGDState<B: Backend, const D: usize> {
    lr_decay: LrDecayState<B, D>,
}

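// One RGD step: d = project_x(-grad) is the Euclidean gradient projected onto
// the tangent space at x, and the update is x_{t+1} = retract_x(lr * d).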
impl<M, B> SimpleOptimizer<B> for ManifoldRGD<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    type State<const D: usize> = ManifoldRGDState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        // Project the negative Euclidean gradient onto the tangent space at `tensor`,
        // then retract the lr-scaled direction back onto the manifold.
        let direction = M::project(tensor.clone(), -grad);
        let result = M::retract(tensor, direction * lr);
        (result, state)
    }

    fn to_device<const D: usize>(
        state: Self::State<D>,
        _device: &<B as Backend>::Device,
    ) -> Self::State<D> {
        state
    }
}

impl<M, B> ManifoldRGDConfig<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    pub fn init<Back: AutodiffBackend, Mod: AutodiffModule<Back>>(
        &self,
    ) -> OptimizerAdaptor<ManifoldRGD<M, Back::InnerBackend>, Mod, Back>
    where
        M: Manifold<Back::InnerBackend>,
    {
        let optim = ManifoldRGD::<M, Back::InnerBackend>::default();

        OptimizerAdaptor::from(optim)
    }
}

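/// Configuration for [`RiemannianAdam`].
///
/// Defaults mirror the usual Adam settings: `lr = 1e-3`, `beta1 = 0.9`,
/// `beta2 = 0.999`, `eps = 1e-8`, no weight decay, AMSGrad disabled.
/// `stabilize` is carried in the config but is not read by the current
/// `step` implementation.
///
/// A minimal usage sketch, assuming the `Euclidean` manifold exported from the
/// prelude and an `NdArray` backend (`MyModel` is a placeholder for any
/// `AutodiffModule<Autodiff<NdArray>>`):
///
/// ```ignore
/// use burn::backend::{Autodiff, NdArray};
///
/// let config = RiemannianAdamConfig::<Euclidean, NdArray>::new()
///     .with_lr(1e-2)
///     .with_amsgrad(true);
/// let optimizer = config.init::<Autodiff<NdArray>, MyModel>();
/// ```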
#[derive(Debug, Clone)]
pub struct RiemannianAdamConfig<M, B> {
    pub lr: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub eps: f64,
    pub weight_decay: f64,
    pub amsgrad: bool,
    pub stabilize: Option<usize>,
    _manifold: PhantomData<M>,
    _backend: PhantomData<B>,
}

impl<M, B> Default for RiemannianAdamConfig<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    fn default() -> Self {
        Self {
            lr: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            stabilize: None,
            _manifold: PhantomData,
            _backend: PhantomData,
        }
    }
}

impl<M, B> RiemannianAdamConfig<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_lr(mut self, lr: f64) -> Self {
        self.lr = lr;
        self
    }

    pub fn with_beta1(mut self, beta1: f64) -> Self {
        self.beta1 = beta1;
        self
    }

    pub fn with_beta2(mut self, beta2: f64) -> Self {
        self.beta2 = beta2;
        self
    }

    pub fn with_eps(mut self, eps: f64) -> Self {
        self.eps = eps;
        self
    }

    pub fn with_weight_decay(mut self, weight_decay: f64) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    pub fn with_amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    pub fn with_stabilize(mut self, stabilize: Option<usize>) -> Self {
        self.stabilize = stabilize;
        self
    }
}

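/// Riemannian Adam over a [`Manifold`].
///
/// Adam moments are maintained in the tangent space: the first moment is an
/// exponential moving average of the Riemannian gradient (parallel-transported
/// to each new point), and the second moment tracks the manifold inner product
/// of the gradient with itself. Steps are taken with the exponential map.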
#[derive(Debug, Clone)]
pub struct RiemannianAdam<M: Manifold<B>, B: Backend> {
    config: RiemannianAdamConfig<M, B>,
}

impl<M, B> RiemannianAdam<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    pub fn new(config: RiemannianAdamConfig<M, B>) -> Self {
        Self { config }
    }
}

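/// Per-parameter state for [`RiemannianAdam`]: the step counter, the first and
/// second moment estimates, and (when AMSGrad is enabled) the running maximum
/// of the second moment.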
#[derive(Record, Clone)]
pub struct RiemannianAdamState<B: Backend, const D: usize> {
    pub step: usize,
    pub exp_avg: Tensor<B, D>,
    pub exp_avg_sq: Tensor<B, D>,
    pub max_exp_avg_sq: Option<Tensor<B, D>>,
    lr_decay: LrDecayState<B, D>,
}

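// One Riemannian Adam step, for point x with Euclidean gradient g:
//   r   = egrad2rgrad_x(g)                          (Riemannian gradient)
//   m_t = beta1 * m_{t-1} + (1 - beta1) * r
//   v_t = beta2 * v_{t-1} + (1 - beta2) * <r, r>_x  (manifold inner product)
//   x_t = proj(expmap_x(-step_size * m_t / (sqrt(v_t) + eps)))
//   m_t <- parallel_transport(x -> x_t, m_t)
// with the usual Adam bias corrections folded into `step_size`.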
impl<M, B> SimpleOptimizer<B> for RiemannianAdam<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    type State<const D: usize> = RiemannianAdamState<B, D>;

    fn step<const D: usize>(
        &self,
        _lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        // The adaptor-provided `_lr` is ignored in favour of the configured learning rate.
        let learning_rate = self.config.lr;

        // Apply (Euclidean) weight decay before converting the gradient.
        let grad = if self.config.weight_decay > 0.0 {
            grad + tensor.clone() * self.config.weight_decay
        } else {
            grad
        };

        // Convert the Euclidean gradient to a Riemannian gradient at `tensor`.
        let rgrad = M::egrad2rgrad(tensor.clone(), grad);

        // Initialise the state on the first step, otherwise advance the step counter.
        let mut state = match state {
            Some(mut state) => {
                state.step += 1;
                state
            }
            None => RiemannianAdamState {
                step: 1,
                exp_avg: Tensor::zeros_like(&tensor),
                exp_avg_sq: Tensor::zeros_like(&tensor),
                max_exp_avg_sq: if self.config.amsgrad {
                    Some(Tensor::zeros_like(&tensor))
                } else {
                    None
                },
                lr_decay: LrDecayState::new(0, tensor.clone()),
            },
        };

        // Biased first moment: exponential moving average of the Riemannian gradient.
        state.exp_avg =
            state.exp_avg.clone() * self.config.beta1 + rgrad.clone() * (1.0 - self.config.beta1);

        // Biased second moment: moving average of <rgrad, rgrad> under the manifold metric.
        let inner_product = M::inner(tensor.clone(), rgrad.clone(), rgrad.clone());
        state.exp_avg_sq = state.exp_avg_sq.clone() * self.config.beta2
            + inner_product * (1.0 - self.config.beta2);

        // AMSGrad keeps the elementwise maximum of the second moment estimates.
        let denom = if self.config.amsgrad {
            let max_exp_avg_sq = state.max_exp_avg_sq.as_ref().unwrap();
            let new_max = Tensor::max_pair(max_exp_avg_sq.clone(), state.exp_avg_sq.clone());
            state.max_exp_avg_sq = Some(new_max.clone());
            new_max.sqrt() + self.config.eps
        } else {
            state.exp_avg_sq.clone().sqrt() + self.config.eps
        };

        // Fold the Adam bias corrections into the step size.
        let bias_correction1 = 1.0 - self.config.beta1.powi(state.step as i32);
        let bias_correction2 = 1.0 - self.config.beta2.powi(state.step as i32);
        let step_size = learning_rate * bias_correction2.sqrt() / bias_correction1;

        let direction = state.exp_avg.clone() / denom;

        // Move along the manifold via the exponential map, then project back onto it.
        let new_point = M::expmap(tensor.clone(), direction * (-step_size));
        let new_point = M::proj(new_point);

        // Transport the first moment into the tangent space at the new point.
        let exp_avg_new = M::parallel_transport(tensor, new_point.clone(), state.exp_avg);
        state.exp_avg = exp_avg_new;

        (new_point, Some(state))
    }

    fn to_device<const D: usize>(
        mut state: Self::State<D>,
        device: &<B as Backend>::Device,
    ) -> Self::State<D> {
        state.exp_avg = state.exp_avg.to_device(device);
        state.exp_avg_sq = state.exp_avg_sq.to_device(device);
        if let Some(ref max_exp_avg_sq) = state.max_exp_avg_sq {
            state.max_exp_avg_sq = Some(max_exp_avg_sq.clone().to_device(device));
        }
        state.lr_decay = LrDecayState::to_device(state.lr_decay, device);
        state
    }
}

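// `init` wraps the optimizer in an `OptimizerAdaptor`, which adapts a
// `SimpleOptimizer` on the inner backend into Burn's `Optimizer` trait for an
// `AutodiffModule`.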
impl<M, B> RiemannianAdamConfig<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    pub fn init<Back: AutodiffBackend, Mod: AutodiffModule<Back>>(
        &self,
    ) -> OptimizerAdaptor<RiemannianAdam<M, Back::InnerBackend>, Mod, Back>
    where
        M: Manifold<Back::InnerBackend>,
    {
        let optim = RiemannianAdam::<M, Back::InnerBackend>::new(RiemannianAdamConfig {
            lr: self.lr,
            beta1: self.beta1,
            beta2: self.beta2,
            eps: self.eps,
            weight_decay: self.weight_decay,
            amsgrad: self.amsgrad,
            stabilize: self.stabilize,
            _manifold: PhantomData,
            _backend: PhantomData,
        });

        OptimizerAdaptor::from(optim)
    }
}

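// A sketch of how the adaptor is typically driven inside a Burn training loop
// (names such as `config`, `model`, `loss`, `lr`, and `MyModel` are
// placeholders; the exact loop depends on the surrounding training code):
//
//     let mut optimizer = config.init::<Autodiff<NdArray>, MyModel>();
//     let grads = loss.backward();
//     let grads = GradientsParams::from_grads(grads, &model);
//     model = optimizer.step(lr, model, grads);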
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::optim::SimpleOptimizer;

    type TestBackend = NdArray;

    #[test]
    fn test_riemannian_adam_basic() {
        let config = RiemannianAdamConfig::<Euclidean, TestBackend>::new()
            .with_lr(0.1)
            .with_beta1(0.9)
            .with_beta2(0.999);

        let optimizer = RiemannianAdam::new(config);

        let tensor = Tensor::<TestBackend, 1>::zeros([3], &Default::default());
        let grad = Tensor::<TestBackend, 1>::ones([3], &Default::default());

        let (new_tensor, state) = optimizer.step(1.0, tensor.clone(), grad, None);

        let scalar_value = new_tensor.slice([0..1]).into_scalar();
        assert!(
            scalar_value < 0.0,
            "Should move in negative gradient direction"
        );
        assert!(state.is_some(), "State should be initialized");
    }

    #[test]
    fn test_riemannian_adam_convergence() {
        let config = RiemannianAdamConfig::<Euclidean, TestBackend>::new().with_lr(0.1);

        let optimizer = RiemannianAdam::new(config);

        let target = Tensor::<TestBackend, 1>::from_floats([1.0, -0.5, 2.0], &Default::default());
        let mut x = Tensor::<TestBackend, 1>::zeros([3], &Default::default());
        let mut state = None;

        let initial_loss = (x.clone() - target.clone()).powf_scalar(2.0).sum();

        // Minimise the squared distance to `target` with analytic gradients.
        for _ in 0..50 {
            let grad = (x.clone() - target.clone()) * 2.0;
            let (new_x, new_state) = optimizer.step(1.0, x, grad, state);
            x = new_x;
            state = new_state;
        }

        let final_loss = (x.clone() - target.clone()).powf_scalar(2.0).sum();

        assert!(
            final_loss.into_scalar() < initial_loss.into_scalar(),
            "Loss should decrease"
        );
    }

    #[test]
    fn test_riemannian_adam_amsgrad() {
        let config = RiemannianAdamConfig::<Euclidean, TestBackend>::new()
            .with_lr(0.1)
            .with_amsgrad(true);

        let optimizer = RiemannianAdam::new(config);

        let tensor = Tensor::<TestBackend, 1>::zeros([2], &Default::default());
        let grad = Tensor::<TestBackend, 1>::ones([2], &Default::default());

        let (_, state) = optimizer.step(1.0, tensor, grad, None);

        assert!(state.is_some());
        let state = state.unwrap();
        assert!(
            state.max_exp_avg_sq.is_some(),
            "AMSGrad should initialize max_exp_avg_sq"
        );
    }

    #[test]
    fn test_riemannian_adam_weight_decay() {
        let config = RiemannianAdamConfig::<Euclidean, TestBackend>::new()
            .with_lr(0.1)
            .with_weight_decay(0.1);

        let optimizer = RiemannianAdam::new(config);

        let tensor = Tensor::<TestBackend, 1>::ones([2], &Default::default());
        let grad = Tensor::<TestBackend, 1>::zeros([2], &Default::default());

        let (new_tensor, _) = optimizer.step(1.0, tensor.clone(), grad, None);

        let original_norm = tensor.powf_scalar(2.0).sum().sqrt();
        let new_norm = new_tensor.powf_scalar(2.0).sum().sqrt();

        assert!(
            new_norm.into_scalar() < original_norm.into_scalar(),
            "Weight decay should reduce tensor magnitude"
        );
    }

    #[test]
    fn test_riemannian_adam_state_persistence() {
        let config = RiemannianAdamConfig::<Euclidean, TestBackend>::new().with_lr(0.1);

        let optimizer = RiemannianAdam::new(config);

        let tensor = Tensor::<TestBackend, 1>::zeros([2], &Default::default());
        let grad = Tensor::<TestBackend, 1>::ones([2], &Default::default());

        let (tensor1, state1) = optimizer.step(1.0, tensor, grad.clone(), None);
        assert!(state1.is_some());
        let state1 = state1.unwrap();
        assert_eq!(state1.step, 1);

        let (_, state2) = optimizer.step(1.0, tensor1, grad, Some(state1));
        assert!(state2.is_some());
        let state2 = state2.unwrap();
        assert_eq!(state2.step, 2);
    }
}