1use crate::prelude::*;
2
/// Stiefel manifold St(n, k): the set of n×k matrices with orthonormal
/// columns (XᵀX = I), represented here as rank-2 `Tensor`s.
///
/// NOTE(review): the type name is a misspelling of "Stiefel", but it is part
/// of the public API (and matched by tests), so it is left unchanged.
#[derive(Debug, Clone, Default)]
pub struct SteifielsManifold<B: Backend> {
    // Zero-sized marker tying the manifold to backend `B` without storing one.
    _backend: std::marker::PhantomData<B>,
}
7
8impl<B: Backend> Manifold<B> for SteifielsManifold<B> {
9 fn new() -> Self {
10 SteifielsManifold {
11 _backend: std::marker::PhantomData,
12 }
13 }
14
15 fn name() -> &'static str {
16 "Steifels"
17 }
18
19 fn project<const D: usize>(point: Tensor<B, D>, direction: Tensor<B, D>) -> Tensor<B, D> {
22 let xtd = point.clone().transpose().matmul(direction.clone());
23 let dtx = direction.clone().transpose().matmul(point.clone());
24 let symmetric_part = (xtd + dtx.transpose()) * 0.5;
25 direction - point.matmul(symmetric_part)
26 }
27
28 fn retract<const D: usize>(
29 point: Tensor<B, D>,
30 direction: Tensor<B, D>,
31 ) -> Tensor<B, D> {
32 let s = point + direction;
33 gram_schmidt(&s)
34 }
35
36 fn inner<const D: usize>(
37 _point: Tensor<B, D>,
38 u: Tensor<B, D>,
39 v: Tensor<B, D>,
40 ) -> Tensor<B, D> {
41 u * v
43 }
44}
45
46fn gram_schmidt<B: Backend, const D: usize>(v: &Tensor<B, D>) -> Tensor<B, D> {
47 let n = v.dims()[0];
48 let k = v.dims()[1];
49
50 let mut u = Tensor::zeros_like(v);
51 let v1 = v.clone().slice([0..n, 0..1]);
52 let norm = v1.clone().transpose().matmul(v1.clone()).sqrt();
53 u = u.slice_assign([0..n, 0..1], v1.clone() / norm);
54
55 for i in 1..k {
56 u = u.slice_assign([0..n, i..i + 1], v.clone().slice([0..n, i..i + 1]));
57 for j in 0..i {
58 let uj = u.clone().slice([0..n, j..j + 1]);
59 let ui = u.clone().slice([0..n, i..i + 1]);
60 let ui = ui.clone() - (uj.clone().transpose().matmul(ui.clone())) * uj;
61 u = u.slice_assign([0..n, i..i + 1], ui);
62 }
63 let ui = u.clone().slice([0..n, i..i + 1]);
65 let norm = ui.clone().transpose().matmul(ui.clone()).sqrt();
66 u = u.slice_assign([0..n, i..i + 1], ui / norm);
67 }
68 u
69}
70
#[cfg(test)]
mod test {
    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        optim::SimpleOptimizer,
        tensor::TensorData,
    };

    type TestBackend = Autodiff<NdArray>;
    type TestTensor = Tensor<TestBackend, 2>;

    /// Absolute tolerance for f32 comparisons throughout these tests.
    const TOLERANCE: f32 = 1e-6;

    /// Asserts that two tensors agree element-wise within `tol`.
    fn assert_tensor_close(a: &TestTensor, b: &TestTensor, tol: f32) {
        let diff = (a.clone() - b.clone()).abs();
        let max_diff = diff.max().into_scalar();
        assert!(
            max_diff < tol,
            "Tensors differ by {}, tolerance: {}",
            max_diff,
            tol
        );
    }

    /// Builds a `rows` x `cols` test tensor from `values` in row-major order.
    ///
    /// Generalized from the previous hard-coded match over five specific
    /// shapes (which panicked for anything else): any shape is now accepted
    /// as long as exactly `rows * cols` values are supplied. All existing
    /// call sites already pass exact row-major data, so behavior at those
    /// sites is unchanged.
    fn create_test_matrix(rows: usize, cols: usize, values: Vec<f32>) -> TestTensor {
        assert_eq!(
            values.len(),
            rows * cols,
            "expected {} values for a {}x{} matrix, got {}",
            rows * cols,
            rows,
            cols,
            values.len()
        );
        let device = Default::default();
        Tensor::from_data(TensorData::new(values, [rows, cols]), &device)
    }

    /// `new()` constructs and `name()` reports the (historically misspelled) name.
    #[test]
    fn test_manifold_creation() {
        let _manifold = SteifielsManifold::<TestBackend>::new();
        assert_eq!(SteifielsManifold::<TestBackend>::name(), "Steifels");
    }

    /// Gram–Schmidt output columns must be mutually orthogonal and unit-norm.
    #[test]
    fn test_gram_schmidt_orthogonalization() {
        let input = create_test_matrix(3, 2, vec![1.0, 1.0, 1.0, 0.0, 0.0, 1.0]);

        let result = gram_schmidt(&input);

        let q1 = result.clone().slice([0..3, 0..1]);
        let q2 = result.clone().slice([0..3, 1..2]);

        // q1ᵀ q2 is a 1×1 tensor; its single entry must vanish.
        let dot_product = q1.clone().transpose().matmul(q2.clone());
        let orthogonality_error = dot_product.abs().into_scalar();
        assert!(
            orthogonality_error < TOLERANCE,
            "Columns are not orthogonal: dot product = {}",
            orthogonality_error
        );

        let norm1 = q1
            .clone()
            .transpose()
            .matmul(q1.clone())
            .sqrt()
            .into_scalar();
        let norm2 = q2
            .clone()
            .transpose()
            .matmul(q2.clone())
            .sqrt()
            .into_scalar();

        assert!(
            (norm1 - 1.0).abs() < TOLERANCE,
            "First column not normalized: norm = {}",
            norm1
        );
        assert!(
            (norm2 - 1.0).abs() < TOLERANCE,
            "Second column not normalized: norm = {}",
            norm2
        );
    }

    /// A single column is just normalized: (3,4,0) → (0.6, 0.8, 0.0).
    #[test]
    fn test_gram_schmidt_single_column() {
        let input = create_test_matrix(3, 1, vec![3.0, 4.0, 0.0]);
        let result = gram_schmidt(&input);

        let norm = result
            .clone()
            .transpose()
            .matmul(result.clone())
            .sqrt()
            .into_scalar();
        assert!(
            (norm - 1.0).abs() < TOLERANCE,
            "Single column not normalized: norm = {}",
            norm
        );

        let expected = create_test_matrix(3, 1, vec![0.6, 0.8, 0.0]);
        assert_tensor_close(&result, &expected, TOLERANCE);
    }

    /// After projection, Xᵀ·P(D) must have no symmetric component
    /// (the tangent-space condition for the Stiefel manifold).
    #[test]
    fn test_projection_tangent_space() {
        let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]);

        let direction = create_test_matrix(3, 2, vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]);

        let projected = SteifielsManifold::<TestBackend>::project(point.clone(), direction.clone());

        let product = point.clone().transpose().matmul(projected.clone());
        let symmetric_part = (product.clone() + product.clone().transpose()) * 0.5;

        let max_symmetric = symmetric_part.abs().max().into_scalar();
        assert!(
            max_symmetric < TOLERANCE,
            "Projected direction not in tangent space: max symmetric component = {}",
            max_symmetric
        );
    }

    /// A vector already in the tangent space (here XᵀV = 0) must be unchanged
    /// by the projection.
    #[test]
    fn test_projection_preserves_tangent_vectors() {
        let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
        let tangent = create_test_matrix(3, 2, vec![0.0, 0.0, 0.0, 0.0, 1.0, -1.0]);
        let projected = SteifielsManifold::<TestBackend>::project(point.clone(), tangent.clone());
        assert_tensor_close(&projected, &tangent, 1e-6);
        // Sanity check of the tangent condition itself: XᵀV + (VᵀX)ᵀ = 0.
        let xtv = point.clone().transpose().matmul(tangent.clone());
        let vtx = tangent.clone().transpose().matmul(point.clone());
        let skew = xtv + vtx.transpose();
        let max_skew = skew.abs().max().into_scalar();
        assert!(
            max_skew < 1e-6,
            "Tangent space property violated: max skew = {}",
            max_skew
        );
    }

    /// Retraction must land back on the manifold: orthogonal, unit-norm columns.
    #[test]
    fn test_retraction_preserves_stiefel_property() {
        let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]);

        let direction = create_test_matrix(3, 2, vec![0.0, 0.1, 0.0, -0.1, 0.2, 0.3]);

        let step = 0.1;
        let retracted =
            SteifielsManifold::<TestBackend>::retract(point.clone(), direction.clone() * step);

        let q1 = retracted.clone().slice([0..3, 0..1]);
        let q2 = retracted.clone().slice([0..3, 1..2]);

        let dot_product = q1.clone().transpose().matmul(q2.clone()).into_scalar();
        assert!(
            dot_product.abs() < TOLERANCE,
            "Retracted point columns not orthogonal: dot product = {}",
            dot_product
        );

        let norm1 = q1
            .clone()
            .transpose()
            .matmul(q1.clone())
            .sqrt()
            .into_scalar();
        let norm2 = q2
            .clone()
            .transpose()
            .matmul(q2.clone())
            .sqrt()
            .into_scalar();

        assert!(
            (norm1 - 1.0).abs() < TOLERANCE,
            "First column not normalized after retraction: norm = {}",
            norm1
        );
        assert!(
            (norm2 - 1.0).abs() < TOLERANCE,
            "Second column not normalized after retraction: norm = {}",
            norm2
        );
    }

    /// An already-orthonormal input (the identity) is a fixed point of
    /// Gram–Schmidt.
    #[test]
    fn test_gram_schmidt_identity_matrix() {
        let identity = create_test_matrix(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);

        let result = gram_schmidt(&identity);
        assert_tensor_close(&result, &identity, TOLERANCE);
    }

    /// End-to-end: a Stiefel point stays on the manifold through
    /// project-then-retract (XᵀX = I before and after).
    #[test]
    fn test_manifold_properties() {
        let sqrt_half = (0.5_f32).sqrt();
        let point = create_test_matrix(
            4,
            2,
            vec![
                sqrt_half, sqrt_half, sqrt_half, -sqrt_half, 0.0, 0.0, 0.0, 0.0,
            ],
        );

        // Verify the starting point is actually on the manifold.
        let gram_matrix = point.clone().transpose().matmul(point.clone());
        let identity = create_test_matrix(2, 2, vec![1.0, 0.0, 0.0, 1.0]);

        assert_tensor_close(&gram_matrix, &identity, TOLERANCE);

        let direction = create_test_matrix(4, 2, vec![0.1, 0.0, 0.0, 0.1, 0.2, 0.3, -0.1, 0.2]);

        let projected = SteifielsManifold::<TestBackend>::project(point.clone(), direction.clone());
        let retracted = SteifielsManifold::<TestBackend>::retract(point.clone(), projected * 0.1);

        let retracted_gram = retracted.clone().transpose().matmul(retracted.clone());
        assert_tensor_close(&retracted_gram, &identity, TOLERANCE);
    }

    /// Smoke test: run 100 Riemannian gradient-descent steps on the quadratic
    /// loss tr(XᵀAX) and check the loop completes (convergence is not asserted).
    #[test]
    fn test_optimiser() {
        let optimiser = ManifoldRGD::<SteifielsManifold<TestBackend>, TestBackend>::default();

        let a = create_test_matrix(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0]);

        let mut x = Tensor::<TestBackend, 2>::random(
            [3, 3],
            burn::tensor::Distribution::Normal(1., 1.),
            &a.device(),
        )
        .require_grad();
        for _i in 0..100 {
            let loss = x
                .clone()
                .transpose()
                .matmul(a.clone())
                .matmul(x.clone())
                .sum();
            let grads = loss.backward();
            let x_grad = x.grad(&grads).unwrap();
            // Re-wrap the (inner-backend) gradient as an autodiff tensor so it
            // can be fed to the optimiser alongside `x`.
            let x_grad_data = x_grad.to_data();
            let x_grad_ad = Tensor::<TestBackend, 2>::from_data(x_grad_data, &x.device());
            let x_clone = x.clone();
            let (new_x, _) = optimiser.step(0.1, x_clone, x_grad_ad, None);
            // Detach from the old graph and mark the new point as trainable.
            x = new_x.detach().require_grad();
            println!("Loss: {}", loss);
        }
        println!("Optimised tensor: {}", x);
    }

    /// A single optimiser step must run without panicking.
    #[test]
    fn test_simple_optimizer_step() {
        let optimiser = ManifoldRGD::<SteifielsManifold<TestBackend>, TestBackend>::default();

        let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]);

        let grad = create_test_matrix(3, 2, vec![0.1, 0.1, 0.1, 0.1, 0.1, 0.1]);

        let (_result, _) = optimiser.step(0.1, point, grad, None);
    }
}