manopt_rs/
optimisers.rs

//! Riemannian optimizers for manifold-constrained optimization.
//!
//! This module provides Riemannian optimization algorithms that work on manifolds,
//! extending classical optimization methods to handle geometric constraints.
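//!
//! Two optimizers are provided: [`ManifoldRGD`], a plain Riemannian gradient
//! descent, and [`RiemannianAdam`], a Riemannian variant of Adam. Both implement
//! burn's [`SimpleOptimizer`] and are wrapped in an [`OptimizerAdaptor`] through
//! their config's `init` method.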

use burn::module::AutodiffModule;
use burn::optim::adaptor::OptimizerAdaptor;
use burn::optim::{LrDecayState, SimpleOptimizer};
use burn::record::Record;
use burn::tensor::backend::AutodiffBackend;
use burn::LearningRate;
use std::marker::PhantomData;

use crate::manifolds::Manifold;
use crate::prelude::*;

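/// Configuration for the [`ManifoldRGD`] (Riemannian gradient descent) optimizer.
///
/// Plain Riemannian gradient descent has no hyperparameters of its own: the
/// learning rate is supplied by burn at each step, so this config only pins down
/// the manifold and backend types.
///
/// # Example
///
/// A minimal sketch, assuming the prelude re-exports this config and the
/// `Euclidean` manifold as it does for [`RiemannianAdamConfig`]:
///
/// ```rust
/// use manopt_rs::prelude::*;
///
/// let config = ManifoldRGDConfig::<Euclidean, burn::backend::NdArray>::default();
/// ```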
#[derive(Debug)]
pub struct ManifoldRGDConfig<M, B> {
    _manifold: PhantomData<M>,
    _backend: PhantomData<B>,
}

impl<M, B> Default for ManifoldRGDConfig<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    fn default() -> Self {
        Self {
            _manifold: PhantomData,
            _backend: PhantomData,
        }
    }
}

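/// Riemannian gradient descent optimizer.
///
/// Each step projects the negative Euclidean gradient onto the tangent space at
/// the current point and retracts the scaled direction back onto the manifold.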
#[derive(Debug, Clone)]
pub struct ManifoldRGD<M: Manifold<B>, B: Backend> {
    _manifold: PhantomData<M>,
    _backend: PhantomData<B>,
}

impl<M, B> Default for ManifoldRGD<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    fn default() -> Self {
        Self {
            _manifold: PhantomData,
            _backend: PhantomData,
        }
    }
}

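/// State for the [`ManifoldRGD`] optimizer.
///
/// The current step implementation passes any state through unchanged and never
/// creates one; this record exists to satisfy the [`SimpleOptimizer`] interface.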
#[derive(Record, Clone)]
pub struct ManifoldRGDState<B: Backend, const D: usize> {
    lr_decay: LrDecayState<B, D>,
}

impl<M, B> SimpleOptimizer<B> for ManifoldRGD<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    type State<const D: usize> = ManifoldRGDState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        // Project the negative Euclidean gradient onto the tangent space at `tensor`,
        // then retract the step of length `lr` back onto the manifold.
        let direction = M::project(tensor.clone(), -grad);
        let result = M::retract(tensor, direction * lr);
        (result, state)
    }

    fn to_device<const D: usize>(
        mut state: Self::State<D>,
        device: &<B as Backend>::Device,
    ) -> Self::State<D> {
        state.lr_decay = LrDecayState::to_device(state.lr_decay, device);
        state
    }
}

impl<M, B> ManifoldRGDConfig<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
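    /// Builds a [`ManifoldRGD`] optimizer on the inner backend and wraps it in an
    /// [`OptimizerAdaptor`] so it can drive an [`AutodiffModule`] through burn's
    /// `Optimizer` interface.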
    pub fn init<Back: AutodiffBackend, Mod: AutodiffModule<Back>>(
        &self,
    ) -> OptimizerAdaptor<ManifoldRGD<M, Back::InnerBackend>, Mod, Back>
    where
        M: Manifold<Back::InnerBackend>,
    {
        let optim = ManifoldRGD::<M, Back::InnerBackend>::default();

        OptimizerAdaptor::from(optim)
    }
}

/// Configuration for the Riemannian Adam optimizer.
///
/// This optimizer extends the Adam algorithm to work on Riemannian manifolds,
/// following the approach described in "Riemannian adaptive optimization methods"
/// (Bécigneul & Ganea, 2018).
///
/// # Example
///
/// ```rust
/// use manopt_rs::prelude::*;
///
/// let config = RiemannianAdamConfig::<Euclidean, burn::backend::NdArray>::new()
///     .with_lr(0.001)
///     .with_beta1(0.9)
///     .with_beta2(0.999)
///     .with_eps(1e-8)
///     .with_amsgrad(true);
/// ```
#[derive(Debug, Clone)]
pub struct RiemannianAdamConfig<M, B> {
    /// Learning rate.
    pub lr: f64,
    /// Exponential decay rate for the first moment estimate.
    pub beta1: f64,
    /// Exponential decay rate for the second moment estimate.
    pub beta2: f64,
    /// Term added to the denominator for numerical stability.
    pub eps: f64,
    /// Weight decay coefficient (L2 penalty added to the gradient).
    pub weight_decay: f64,
    /// Whether to use the AMSGrad variant (keeps a running maximum of the
    /// second moment estimate).
    pub amsgrad: bool,
    /// Optional stabilization interval; not currently used by the `step` implementation.
    pub stabilize: Option<usize>,
    _manifold: PhantomData<M>,
    _backend: PhantomData<B>,
}

impl<M, B> Default for RiemannianAdamConfig<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    fn default() -> Self {
        Self {
            lr: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            stabilize: None,
            _manifold: PhantomData,
            _backend: PhantomData,
        }
    }
}

impl<M, B> RiemannianAdamConfig<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_lr(mut self, lr: f64) -> Self {
        self.lr = lr;
        self
    }

    pub fn with_beta1(mut self, beta1: f64) -> Self {
        self.beta1 = beta1;
        self
    }

    pub fn with_beta2(mut self, beta2: f64) -> Self {
        self.beta2 = beta2;
        self
    }

    pub fn with_eps(mut self, eps: f64) -> Self {
        self.eps = eps;
        self
    }

    pub fn with_weight_decay(mut self, weight_decay: f64) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    pub fn with_amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    pub fn with_stabilize(mut self, stabilize: Option<usize>) -> Self {
        self.stabilize = stabilize;
        self
    }
}

/// Riemannian Adam optimizer.
#[derive(Debug, Clone)]
pub struct RiemannianAdam<M: Manifold<B>, B: Backend> {
    config: RiemannianAdamConfig<M, B>,
}

impl<M, B> RiemannianAdam<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    pub fn new(config: RiemannianAdamConfig<M, B>) -> Self {
        Self { config }
    }
}

/// State for the Riemannian Adam optimizer.
#[derive(Record, Clone)]
pub struct RiemannianAdamState<B: Backend, const D: usize> {
    /// Number of steps taken so far.
    pub step: usize,
    /// Exponential moving average of the Riemannian gradient (first moment).
    pub exp_avg: Tensor<B, D>,
    /// Exponential moving average of the squared gradient, measured with the
    /// manifold inner product (second moment).
    pub exp_avg_sq: Tensor<B, D>,
    /// Running maximum of `exp_avg_sq`, kept only when AMSGrad is enabled.
    pub max_exp_avg_sq: Option<Tensor<B, D>>,
    lr_decay: LrDecayState<B, D>,
}

impl<M, B> SimpleOptimizer<B> for RiemannianAdam<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
    type State<const D: usize> = RiemannianAdamState<B, D>;

    fn step<const D: usize>(
        &self,
        _lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        // The learning rate passed in by burn is ignored; the configured `lr` is used.
        let learning_rate = self.config.lr;

        // Apply weight decay if specified
        let grad = if self.config.weight_decay > 0.0 {
            grad + tensor.clone() * self.config.weight_decay
        } else {
            grad
        };

        // Convert Euclidean gradient to Riemannian gradient
        let rgrad = M::egrad2rgrad(tensor.clone(), grad);

        let mut state = match state {
            Some(mut state) => {
                state.step += 1;
                state
            }
            None => RiemannianAdamState {
                step: 1,
                exp_avg: Tensor::zeros_like(&tensor),
                exp_avg_sq: Tensor::zeros_like(&tensor),
                max_exp_avg_sq: if self.config.amsgrad {
                    Some(Tensor::zeros_like(&tensor))
                } else {
                    None
                },
                lr_decay: LrDecayState::new(0, tensor.clone()),
            },
        };

        // Update exponential moving averages
        state.exp_avg =
            state.exp_avg.clone() * self.config.beta1 + rgrad.clone() * (1.0 - self.config.beta1);

        let inner_product = M::inner(tensor.clone(), rgrad.clone(), rgrad.clone());
        state.exp_avg_sq = state.exp_avg_sq.clone() * self.config.beta2
            + inner_product * (1.0 - self.config.beta2);

        // Compute denominator
        let denom = if self.config.amsgrad {
            let max_exp_avg_sq = state.max_exp_avg_sq.as_ref().unwrap();
            let new_max = Tensor::max_pair(max_exp_avg_sq.clone(), state.exp_avg_sq.clone());
            state.max_exp_avg_sq = Some(new_max.clone());
            new_max.sqrt() + self.config.eps
        } else {
            state.exp_avg_sq.clone().sqrt() + self.config.eps
        };

        // Bias correction: instead of correcting `exp_avg` and `exp_avg_sq` directly,
        // fold sqrt(1 - beta2^t) / (1 - beta1^t) into the step size (the usual Adam
        // reformulation).
        let bias_correction1 = 1.0 - self.config.beta1.powi(state.step as i32);
        let bias_correction2 = 1.0 - self.config.beta2.powi(state.step as i32);
        let step_size = learning_rate * bias_correction2.sqrt() / bias_correction1;

        // Compute direction
        let direction = state.exp_avg.clone() / denom;

        // Move on manifold using exponential map
        let new_point = M::expmap(tensor.clone(), direction * (-step_size));
        let new_point = M::proj(new_point);

        // Parallel transport the exponential average to the new point
        let exp_avg_new = M::parallel_transport(tensor, new_point.clone(), state.exp_avg);
        state.exp_avg = exp_avg_new;

        (new_point, Some(state))
    }

    fn to_device<const D: usize>(
        mut state: Self::State<D>,
        device: &<B as Backend>::Device,
    ) -> Self::State<D> {
        state.exp_avg = state.exp_avg.to_device(device);
        state.exp_avg_sq = state.exp_avg_sq.to_device(device);
        if let Some(ref max_exp_avg_sq) = state.max_exp_avg_sq {
            state.max_exp_avg_sq = Some(max_exp_avg_sq.clone().to_device(device));
        }
        state.lr_decay = LrDecayState::to_device(state.lr_decay, device);
        state
    }
}

impl<M, B> RiemannianAdamConfig<M, B>
where
    M: Manifold<B>,
    B: Backend,
{
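    /// Builds a [`RiemannianAdam`] optimizer with this configuration on the inner
    /// backend and wraps it in an [`OptimizerAdaptor`] so it can drive an
    /// [`AutodiffModule`] through burn's `Optimizer` interface.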
    pub fn init<Back: AutodiffBackend, Mod: AutodiffModule<Back>>(
        &self,
    ) -> OptimizerAdaptor<RiemannianAdam<M, Back::InnerBackend>, Mod, Back>
    where
        M: Manifold<Back::InnerBackend>,
    {
        let optim = RiemannianAdam::<M, Back::InnerBackend>::new(RiemannianAdamConfig {
            lr: self.lr,
            beta1: self.beta1,
            beta2: self.beta2,
            eps: self.eps,
            weight_decay: self.weight_decay,
            amsgrad: self.amsgrad,
            stabilize: self.stabilize,
            _manifold: PhantomData,
            _backend: PhantomData,
        });

        OptimizerAdaptor::from(optim)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::optim::SimpleOptimizer;

    type TestBackend = NdArray;

    #[test]
    fn test_riemannian_adam_basic() {
        let config = RiemannianAdamConfig::<Euclidean, TestBackend>::new()
            .with_lr(0.1)
            .with_beta1(0.9)
            .with_beta2(0.999);

        let optimizer = RiemannianAdam::new(config);

        // Create test tensors
        let tensor = Tensor::<TestBackend, 1>::zeros([3], &Default::default());
        let grad = Tensor::<TestBackend, 1>::ones([3], &Default::default());

        // Perform one step
        let (new_tensor, state) = optimizer.step(1.0, tensor.clone(), grad, None);

        // Check that the tensor moved in the negative gradient direction
        let scalar_value = new_tensor.slice([0..1]).into_scalar();
        assert!(
            scalar_value < 0.0,
            "Should move in negative gradient direction"
        );
        assert!(state.is_some(), "State should be initialized");
    }
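
    #[test]
    fn test_manifold_rgd_basic() {
        // A minimal sketch exercising ManifoldRGD directly. It assumes, as the
        // RiemannianAdam tests in this module do, that `Euclidean` implements
        // `Manifold<NdArray>` with the identity tangent projection and an
        // additive retraction.
        let optimizer = ManifoldRGD::<Euclidean, TestBackend>::default();

        let tensor = Tensor::<TestBackend, 1>::zeros([3], &Default::default());
        let grad = Tensor::<TestBackend, 1>::ones([3], &Default::default());

        // One RGD step with lr = 0.1 from the origin along a positive gradient.
        let (new_tensor, state) = optimizer.step(0.1, tensor, grad, None);

        let sum_value = new_tensor.sum().into_scalar();
        assert!(
            sum_value < 0.0,
            "RGD should move in the negative gradient direction"
        );
        assert!(state.is_none(), "RGD passes the (absent) state through unchanged");
    }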

    #[test]
    fn test_riemannian_adam_convergence() {
        let config = RiemannianAdamConfig::<Euclidean, TestBackend>::new().with_lr(0.1);

        let optimizer = RiemannianAdam::new(config);

        // Target optimization: minimize ||x - target||^2
        let target = Tensor::<TestBackend, 1>::from_floats([1.0, -0.5, 2.0], &Default::default());
        let mut x = Tensor::<TestBackend, 1>::zeros([3], &Default::default());
        let mut state = None;

        let initial_loss = (x.clone() - target.clone()).powf_scalar(2.0).sum();

        // Run optimization for several steps
        for _ in 0..50 {
            let grad = (x.clone() - target.clone()) * 2.0;
            let (new_x, new_state) = optimizer.step(1.0, x, grad, state);
            x = new_x;
            state = new_state;
        }

        let final_loss = (x.clone() - target.clone()).powf_scalar(2.0).sum();

        // Check convergence
        assert!(
            final_loss.into_scalar() < initial_loss.into_scalar(),
            "Loss should decrease"
        );
    }

    #[test]
    fn test_riemannian_adam_amsgrad() {
        let config = RiemannianAdamConfig::<Euclidean, TestBackend>::new()
            .with_lr(0.1)
            .with_amsgrad(true);

        let optimizer = RiemannianAdam::new(config);

        let tensor = Tensor::<TestBackend, 1>::zeros([2], &Default::default());
        let grad = Tensor::<TestBackend, 1>::ones([2], &Default::default());

        let (_, state) = optimizer.step(1.0, tensor, grad, None);

        // Check that AMSGrad state is initialized
        assert!(state.is_some());
        let state = state.unwrap();
        assert!(
            state.max_exp_avg_sq.is_some(),
            "AMSGrad should initialize max_exp_avg_sq"
        );
    }

    #[test]
    fn test_riemannian_adam_weight_decay() {
        let config = RiemannianAdamConfig::<Euclidean, TestBackend>::new()
            .with_lr(0.1)
            .with_weight_decay(0.1);

        let optimizer = RiemannianAdam::new(config);

        let tensor = Tensor::<TestBackend, 1>::ones([2], &Default::default());
        let grad = Tensor::<TestBackend, 1>::zeros([2], &Default::default());

        let (new_tensor, _) = optimizer.step(1.0, tensor.clone(), grad, None);

        // With weight decay and zero gradient, the tensor should shrink
        let original_norm = tensor.powf_scalar(2.0).sum().sqrt();
        let new_norm = new_tensor.powf_scalar(2.0).sum().sqrt();

        assert!(
            new_norm.into_scalar() < original_norm.into_scalar(),
            "Weight decay should reduce tensor magnitude"
        );
    }

    #[test]
    fn test_riemannian_adam_state_persistence() {
        let config = RiemannianAdamConfig::<Euclidean, TestBackend>::new().with_lr(0.1);

        let optimizer = RiemannianAdam::new(config);

        let tensor = Tensor::<TestBackend, 1>::zeros([2], &Default::default());
        let grad = Tensor::<TestBackend, 1>::ones([2], &Default::default());

        // First step
        let (tensor1, state1) = optimizer.step(1.0, tensor, grad.clone(), None);
        assert!(state1.is_some());
        let state1 = state1.unwrap();
        assert_eq!(state1.step, 1);

        // Second step with state
        let (_, state2) = optimizer.step(1.0, tensor1, grad, Some(state1));
        assert!(state2.is_some());
        let state2 = state2.unwrap();
        assert_eq!(state2.step, 2);
    }
476}