manopt_rs/manifolds/
steifiel.rs

1use crate::prelude::*;
2
3#[derive(Debug, Clone, Default)]
4pub struct SteifielsManifold<B: Backend> {
5    _backend: std::marker::PhantomData<B>,
6}
7
8impl<B: Backend> Manifold<B> for SteifielsManifold<B> {
9    fn new() -> Self {
10        SteifielsManifold {
11            _backend: std::marker::PhantomData,
12        }
13    }
14
15    fn name() -> &'static str {
16        "Steifels"
17    }
18
19    /// Project direction onto tangent space at point
20    /// For Stiefel manifold: P_X(Z) = Z - X(X^T Z + Z^T X)/2
21    fn project<const D: usize>(point: Tensor<B, D>, direction: Tensor<B, D>) -> Tensor<B, D> {
22        let xtd = point.clone().transpose().matmul(direction.clone());
23        let dtx = direction.clone().transpose().matmul(point.clone());
24        let symmetric_part = (xtd + dtx.transpose()) * 0.5;
25        direction - point.matmul(symmetric_part)
26    }
27
28    fn retract<const D: usize>(
29        point: Tensor<B, D>,
30        direction: Tensor<B, D>,
31    ) -> Tensor<B, D> {
32        let s = point + direction;
33        gram_schmidt(&s)
34    }
35
36    fn inner<const D: usize>(
37        _point: Tensor<B, D>,
38        u: Tensor<B, D>,
39        v: Tensor<B, D>,
40    ) -> Tensor<B, D> {
41        // For Stiefel manifold, we use the standard Euclidean inner product
42        u * v
43    }
44}
45
46fn gram_schmidt<B: Backend, const D: usize>(v: &Tensor<B, D>) -> Tensor<B, D> {
47    let n = v.dims()[0];
48    let k = v.dims()[1];
49
50    let mut u = Tensor::zeros_like(v);
51    let v1 = v.clone().slice([0..n, 0..1]);
52    let norm = v1.clone().transpose().matmul(v1.clone()).sqrt();
53    u = u.slice_assign([0..n, 0..1], v1.clone() / norm);
54
55    for i in 1..k {
56        u = u.slice_assign([0..n, i..i + 1], v.clone().slice([0..n, i..i + 1]));
57        for j in 0..i {
58            let uj = u.clone().slice([0..n, j..j + 1]);
59            let ui = u.clone().slice([0..n, i..i + 1]);
60            let ui = ui.clone() - (uj.clone().transpose().matmul(ui.clone())) * uj;
61            u = u.slice_assign([0..n, i..i + 1], ui);
62        }
63        // Normalize the vector
64        let ui = u.clone().slice([0..n, i..i + 1]);
65        let norm = ui.clone().transpose().matmul(ui.clone()).sqrt();
66        u = u.slice_assign([0..n, i..i + 1], ui / norm);
67    }
68    u
69}
70
#[cfg(test)]
mod test {
    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        optim::SimpleOptimizer,
    };

    type TestBackend = Autodiff<NdArray>;
    type TestTensor = Tensor<TestBackend, 2>;

    /// Absolute tolerance shared by the f32 comparisons in this module.
    const TOLERANCE: f32 = 1e-6;

    /// Panics unless the largest absolute entry of `a - b` is below `tol`.
    fn assert_tensor_close(a: &TestTensor, b: &TestTensor, tol: f32) {
        let diff = (a.clone() - b.clone()).abs();
        let max_diff = diff.max().into_scalar();
        assert!(
            max_diff < tol,
            "Tensors differ by {}, tolerance: {}",
            max_diff,
            tol
        );
    }

    /// Builds a `rows x cols` tensor from `values` given in row-major order.
    ///
    /// Works for any shape; panics if `values` does not hold exactly `rows * cols`
    /// entries.
    fn create_test_matrix(rows: usize, cols: usize, values: Vec<f32>) -> TestTensor {
        assert_eq!(
            values.len(),
            rows * cols,
            "Invalid {}x{} matrix data",
            rows,
            cols
        );
        let device = Default::default();
        // Load the flat buffer as a vector, then view it as a row-major matrix.
        Tensor::<TestBackend, 1>::from_floats(values.as_slice(), &device).reshape([rows, cols])
    }

    /// Euclidean norm of a single-column tensor, computed as `sqrt(c^T c)`.
    fn column_norm(column: &TestTensor) -> f32 {
        column
            .clone()
            .transpose()
            .matmul(column.clone())
            .sqrt()
            .into_scalar()
    }

    #[test]
    fn test_manifold_creation() {
        let _manifold = SteifielsManifold::<TestBackend>::new();
        assert_eq!(SteifielsManifold::<TestBackend>::name(), "Steifels");
    }

    #[test]
    fn test_gram_schmidt_orthogonalization() {
        // Test with a simple 3x2 matrix
        let input = create_test_matrix(3, 2, vec![1.0, 1.0, 1.0, 0.0, 0.0, 1.0]);

        let result = gram_schmidt(&input);

        // Check that the result has orthonormal columns
        let q1 = result.clone().slice([0..3, 0..1]);
        let q2 = result.clone().slice([0..3, 1..2]);

        // Check orthogonality: q1^T * q2 should be close to 0
        let dot_product = q1.clone().transpose().matmul(q2.clone());
        let orthogonality_error = dot_product.abs().into_scalar();
        assert!(
            orthogonality_error < TOLERANCE,
            "Columns are not orthogonal: dot product = {}",
            orthogonality_error
        );

        // Check normalization: ||q1|| = ||q2|| = 1
        let norm1 = column_norm(&q1);
        let norm2 = column_norm(&q2);

        assert!(
            (norm1 - 1.0).abs() < TOLERANCE,
            "First column not normalized: norm = {}",
            norm1
        );
        assert!(
            (norm2 - 1.0).abs() < TOLERANCE,
            "Second column not normalized: norm = {}",
            norm2
        );
    }

    #[test]
    fn test_gram_schmidt_single_column() {
        // Test with a single column vector
        let input = create_test_matrix(3, 1, vec![3.0, 4.0, 0.0]);
        let result = gram_schmidt(&input);

        // Should be normalized to unit length
        let norm = column_norm(&result);
        assert!(
            (norm - 1.0).abs() < TOLERANCE,
            "Single column not normalized: norm = {}",
            norm
        );

        // Should be proportional to original vector
        let expected = create_test_matrix(3, 1, vec![0.6, 0.8, 0.0]);
        assert_tensor_close(&result, &expected, TOLERANCE);
    }

    #[test]
    fn test_projection_tangent_space() {
        // Create a point on the Stiefel manifold (matrix with orthonormal columns)
        let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]);

        // Create a direction vector
        let direction = create_test_matrix(3, 2, vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]);

        let projected = SteifielsManifold::<TestBackend>::project(point.clone(), direction.clone());

        // A tangent vector V at X satisfies: X^T V is skew-symmetric,
        // i.e. the symmetric part of point^T * projected vanishes.
        let product = point.clone().transpose().matmul(projected.clone());
        let symmetric_part = (product.clone() + product.clone().transpose()) * 0.5;

        // The symmetric part should be close to zero
        let max_symmetric = symmetric_part.abs().max().into_scalar();
        assert!(
            max_symmetric < TOLERANCE,
            "Projected direction not in tangent space: max symmetric component = {}",
            max_symmetric
        );
    }

    #[test]
    fn test_projection_preserves_tangent_vectors() {
        // Use a true tangent vector at the identity block
        let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
        // Tangent vector: only the (3,1) and (3,2) entries are nonzero
        let tangent = create_test_matrix(3, 2, vec![0.0, 0.0, 0.0, 0.0, 1.0, -1.0]);
        // Project the tangent vector again
        let projected = SteifielsManifold::<TestBackend>::project(point.clone(), tangent.clone());
        // Should be unchanged (idempotent)
        assert_tensor_close(&projected, &tangent, 1e-6);
        // Check the tangent space property: X^T V + V^T X = 0.
        // `vtx` already is V^T X; transposing it would turn the sum into 2 * X^T V,
        // which is a different (stricter) condition than the one being tested.
        let xtv = point.clone().transpose().matmul(tangent.clone());
        let vtx = tangent.clone().transpose().matmul(point.clone());
        let symmetric_sum = xtv + vtx;
        let max_skew = symmetric_sum.abs().max().into_scalar();
        assert!(
            max_skew < 1e-6,
            "Tangent space property violated: max skew = {}",
            max_skew
        );
    }

    #[test]
    fn test_retraction_preserves_stiefel_property() {
        // Start with a point on the Stiefel manifold
        let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]);

        // Create a tangent direction
        let direction = create_test_matrix(3, 2, vec![0.0, 0.1, 0.0, -0.1, 0.2, 0.3]);

        let step = 0.1;
        let retracted =
            SteifielsManifold::<TestBackend>::retract(point.clone(), direction.clone() * step);

        // Check that the result has orthonormal columns
        let q1 = retracted.clone().slice([0..3, 0..1]);
        let q2 = retracted.clone().slice([0..3, 1..2]);

        // Check orthogonality
        let dot_product = q1.clone().transpose().matmul(q2.clone()).into_scalar();
        assert!(
            dot_product.abs() < TOLERANCE,
            "Retracted point columns not orthogonal: dot product = {}",
            dot_product
        );

        // Check normalization
        let norm1 = column_norm(&q1);
        let norm2 = column_norm(&q2);

        assert!(
            (norm1 - 1.0).abs() < TOLERANCE,
            "First column not normalized after retraction: norm = {}",
            norm1
        );
        assert!(
            (norm2 - 1.0).abs() < TOLERANCE,
            "Second column not normalized after retraction: norm = {}",
            norm2
        );
    }

    #[test]
    fn test_gram_schmidt_identity_matrix() {
        // Identity matrix should remain unchanged
        let identity = create_test_matrix(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);

        let result = gram_schmidt(&identity);
        assert_tensor_close(&result, &identity, TOLERANCE);
    }

    #[test]
    fn test_manifold_properties() {
        // Test that the manifold preserves the Stiefel property: X^T * X = I
        let sqrt_half = (0.5_f32).sqrt();
        let point = create_test_matrix(
            4,
            2,
            vec![
                sqrt_half, sqrt_half, sqrt_half, -sqrt_half, 0.0, 0.0, 0.0, 0.0,
            ],
        );

        // Verify it's on the manifold
        let gram_matrix = point.clone().transpose().matmul(point.clone());
        let identity = create_test_matrix(2, 2, vec![1.0, 0.0, 0.0, 1.0]);

        assert_tensor_close(&gram_matrix, &identity, TOLERANCE);

        // Test projection and retraction preserve this property
        let direction = create_test_matrix(4, 2, vec![0.1, 0.0, 0.0, 0.1, 0.2, 0.3, -0.1, 0.2]);

        let projected = SteifielsManifold::<TestBackend>::project(point.clone(), direction.clone());
        let retracted = SteifielsManifold::<TestBackend>::retract(point.clone(), projected * 0.1);

        let retracted_gram = retracted.clone().transpose().matmul(retracted.clone());
        assert_tensor_close(&retracted_gram, &identity, TOLERANCE);
    }

    #[test]
    fn test_optimiser() {
        let optimiser = ManifoldRGD::<SteifielsManifold<TestBackend>, TestBackend>::default();

        let a = create_test_matrix(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0]);

        let mut x = Tensor::<TestBackend, 2>::random(
            [3, 3],
            burn::tensor::Distribution::Normal(1., 1.),
            &a.device(),
        )
        .require_grad();
        // Smoke test: run 100 steps on sum(x^T a x) and make sure nothing panics.
        for _i in 0..100 {
            let loss = x
                .clone()
                .transpose()
                .matmul(a.clone())
                .matmul(x.clone())
                .sum();
            let grads = loss.backward();
            let x_grad = x.grad(&grads).unwrap();
            // Convert gradient to autodiff backend and ensure independent tensor
            let x_grad_data = x_grad.to_data();
            let x_grad_ad = Tensor::<TestBackend, 2>::from_data(x_grad_data, &x.device());
            // Clone x to ensure independent tensor for optimizer
            let x_clone = x.clone();
            let (new_x, _) = optimiser.step(0.1, x_clone, x_grad_ad, None);
            // Cut the graph so the next iteration differentiates only its own loss.
            x = new_x.detach().require_grad();
            println!("Loss: {}", loss);
        }
        println!("Optimised tensor: {}", x);
    }

    #[test]
    fn test_simple_optimizer_step() {
        let optimiser = ManifoldRGD::<SteifielsManifold<TestBackend>, TestBackend>::default();

        // Create simple test tensors
        let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]);

        let grad = create_test_matrix(3, 2, vec![0.1, 0.1, 0.1, 0.1, 0.1, 0.1]);

        // Test one optimizer step
        let (_result, _) = optimiser.step(0.1, point, grad, None);
    }
}