use crate::{manifolds::utils::identity_in_last_two, prelude::*};
/// The Stiefel manifold: `n x k` matrices `X` whose columns are orthonormal (`X^T X = I_k`),
/// realised with tensors from the backend `B`.
///
/// The type holds no runtime data; the backend parameter is only remembered through a
/// zero-sized marker. (The type name keeps its existing spelling so callers keep compiling.)
#[derive(Clone, Debug, Default)]
pub struct SteifielsManifold<B: Backend> {
    _backend: core::marker::PhantomData<B>,
}
impl<B: Backend> Manifold<B> for SteifielsManifold<B> {
    // A single point is a matrix stored in the last two dimensions of a tensor; any
    // leading dimensions index independent points.
    const RANK_PER_POINT: usize = 2;

    fn new() -> Self {
        SteifielsManifold {
            _backend: std::marker::PhantomData,
        }
    }

    fn name() -> &'static str {
        "Steifels"
    }

    /// Orthogonally projects `direction` (`Z`) onto the tangent space at `point` (`X`).
    ///
    /// Computes `Z - X sym(X^T Z)` with `sym(A) = (A + A^T) / 2`. Because
    /// `(X^T Z)^T = Z^T X`, the symmetric part is `(X^T Z + Z^T X) / 2`. Transposing
    /// `Z^T X` again would collapse it back to `X^T Z` and wrongly discard the
    /// skew-symmetric component of tangent vectors.
    fn project<const D: usize>(point: Tensor<B, D>, direction: Tensor<B, D>) -> Tensor<B, D> {
        let xtd = point.clone().transpose().matmul(direction.clone());
        let dtx = direction.clone().transpose().matmul(point.clone());
        let symmetric_part = (xtd + dtx) * 0.5;
        direction - point.matmul(symmetric_part)
    }

    /// Moves from `point` along `direction` and maps the result back onto the manifold.
    ///
    /// This is the Gram–Schmidt (QR-type) retraction: `point + direction` is
    /// re-orthonormalised with [`Self::proj`].
    fn retract<const D: usize>(point: Tensor<B, D>, direction: Tensor<B, D>) -> Tensor<B, D> {
        debug_assert!(point.dims().len() >= Self::RANK_PER_POINT);
        debug_assert!(direction.dims().len() >= Self::RANK_PER_POINT);
        Self::proj(point + direction)
    }

    /// Euclidean (Frobenius) inner product `sum_ij u_ij * v_ij`, reduced over the last two
    /// dimensions. The metric does not depend on the base point, so `_point` is unused.
    fn inner<const D: usize>(
        _point: Tensor<B, D>,
        u: Tensor<B, D>,
        v: Tensor<B, D>,
    ) -> Tensor<B, D> {
        (u * v).sum_dim(D - 1).sum_dim(D - 2)
    }

    /// Tests whether `vector` lies in the tangent space at `point`.
    ///
    /// The tangent condition is `X^T V + V^T X = 0`. The largest absolute entry of
    /// that matrix (over the last two dims) is compared against an absolute tolerance
    /// of `1e-6`.
    fn is_tangent_at<const D: usize>(
        point: Tensor<B, D>,
        vector: Tensor<B, D>,
    ) -> Tensor<B, D, burn::tensor::Bool> {
        let xtv = point.clone().transpose().matmul(vector.clone());
        let vtx = vector.transpose().matmul(point);
        let symmetric = xtv + vtx;
        let max_violation = symmetric.abs().max_dim(D - 1).max_dim(D - 2);
        max_violation.lower_elem(1e-6)
    }

    /// Maps an arbitrary full-column-rank matrix onto the manifold by orthonormalising its
    /// columns with Gram–Schmidt.
    ///
    /// `gram_schmidt` treats dims 0 and 1 as the matrix axes. For batched inputs the last
    /// two dims are therefore swapped to the front, and the swaps are undone in reverse
    /// order afterwards.
    fn proj<const D: usize>(point: Tensor<B, D>) -> Tensor<B, D> {
        debug_assert!(point.dims().len() >= Self::RANK_PER_POINT);
        if point.dims().len() <= Self::RANK_PER_POINT {
            return gram_schmidt(&point);
        }
        let matrix_first = point.swap_dims(0, D - 2).swap_dims(1, D - 1);
        gram_schmidt(&matrix_first)
            .swap_dims(1, D - 1)
            .swap_dims(0, D - 2)
    }

    /// Checks `X^T X ≈ I` (via `is_close` with default tolerances) and reduces the
    /// element-wise result over the last two dimensions.
    fn is_in_manifold<const D: usize>(point: Tensor<B, D>) -> Tensor<B, D, burn::tensor::Bool> {
        debug_assert!(D >= Self::RANK_PER_POINT);
        let gram = point.clone().transpose().matmul(point);
        let identity = identity_in_last_two(&gram);
        gram.is_close(identity, None, None)
            .all_dim(D - 1)
            .all_dim(D - 2)
    }

    /// Accepts `[n, k, ..]` with `0 < k <= n` (at most `n` orthonormal columns fit in
    /// `R^n`). Slices with fewer than two entries are rejected instead of panicking.
    fn acceptable_dims(a_is: &[usize]) -> bool {
        match a_is {
            [n, k, ..] => *n > 0 && *k > 0 && k <= n,
            _ => false,
        }
    }
}
/// Orthonormalises the columns of `v` with (modified) Gram–Schmidt.
///
/// The matrix axes are the FIRST two dimensions: dim 0 indexes rows (`n`) and dim 1 indexes
/// columns (`k`). Any further dimensions are independent batch entries. This is the layout
/// `SteifielsManifold::proj`/`retract` produce by swapping the last two dims to the front.
/// All dot products and norms are reductions over dim 0 that broadcast over the batch
/// dims; `transpose`/`matmul` are avoided because they act on the *last* two dims.
///
/// There is no guard against rank deficiency: a column whose residual norm is zero
/// produces non-finite values.
fn gram_schmidt<B: Backend, const D: usize>(v: &Tensor<B, D>) -> Tensor<B, D> {
    let k = v.dims()[1];
    if k == 0 {
        // Nothing to orthonormalise (and `cat` of an empty list would panic).
        return v.clone();
    }
    let mut columns: Vec<Tensor<B, D>> = Vec::with_capacity(k);
    for i in 0..k {
        let mut ui = v.clone().narrow(1, i, 1);
        // Subtract projections onto the already-finished columns, re-using the updated
        // `ui` each time (modified Gram–Schmidt, same update order as before).
        for uj in &columns {
            let coeff = (uj.clone() * ui.clone()).sum_dim(0);
            ui = ui - uj.clone() * coeff;
        }
        let norm = (ui.clone() * ui.clone()).sum_dim(0).sqrt();
        columns.push(ui / norm);
    }
    Tensor::cat(columns, 1)
}
#[cfg(test)]
mod test {
    use crate::manifolds::utils::test::{assert_matrix_close, create_test_matrix};
    use crate::optimizers::LessSimpleOptimizer;
    use super::*;
    use burn::{
        backend::{Autodiff, NdArray},
        optim::SimpleOptimizer,
    };

    type TestBackend = Autodiff<NdArray>;
    const TOLERANCE: f32 = 1e-6;

    #[test]
    fn test_manifold_creation() {
        let _manifold = SteifielsManifold::<TestBackend>::new();
        assert_eq!(SteifielsManifold::<TestBackend>::name(), "Steifels");
    }

    #[test]
    fn test_gram_schmidt_orthogonalization() {
        let input = create_test_matrix::<TestBackend>(3, 2, vec![1.0, 1.0, 1.0, 0.0, 0.0, 1.0]);
        let result = gram_schmidt(&input);
        let q1 = result.clone().slice([0..3, 0..1]);
        let q2 = result.clone().slice([0..3, 1..2]);
        let dot_product = q1.clone().transpose().matmul(q2.clone());
        let orthogonality_error = dot_product.abs().into_scalar();
        assert!(
            orthogonality_error < TOLERANCE,
            "Columns are not orthogonal: dot product = {}",
            orthogonality_error
        );
        let norm1 = q1
            .clone()
            .transpose()
            .matmul(q1.clone())
            .sqrt()
            .into_scalar();
        let norm2 = q2
            .clone()
            .transpose()
            .matmul(q2.clone())
            .sqrt()
            .into_scalar();
        assert!(
            (norm1 - 1.0).abs() < TOLERANCE,
            "First column not normalized: norm = {}",
            norm1
        );
        assert!(
            (norm2 - 1.0).abs() < TOLERANCE,
            "Second column not normalized: norm = {}",
            norm2
        );
    }

    #[test]
    fn test_gram_schmidt_single_column() {
        let input = create_test_matrix::<TestBackend>(3, 1, vec![3.0, 4.0, 0.0]);
        let result = gram_schmidt(&input);
        let norm = result
            .clone()
            .transpose()
            .matmul(result.clone())
            .sqrt()
            .into_scalar();
        assert!(
            (norm - 1.0).abs() < TOLERANCE,
            "Single column not normalized: norm = {}",
            norm
        );
        let expected = create_test_matrix::<TestBackend>(3, 1, vec![0.6, 0.8, 0.0]);
        assert_matrix_close(&result, &expected, TOLERANCE);
    }

    #[test]
    fn test_projection_tangent_space() {
        let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
        let direction = create_test_matrix(3, 2, vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]);
        let projected = SteifielsManifold::<TestBackend>::project(point.clone(), direction.clone());
        // A tangent vector V at X has X^T V skew-symmetric, i.e. zero symmetric part.
        let product = point.clone().transpose().matmul(projected.clone());
        let symmetric_part = (product.clone() + product.clone().transpose()) * 0.5;
        let max_symmetric = symmetric_part.abs().max().into_scalar();
        assert!(
            max_symmetric < TOLERANCE,
            "Projected direction not in tangent space: max symmetric component = {}",
            max_symmetric
        );
    }

    #[test]
    fn test_projection_preserves_tangent_vectors() {
        let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
        let tangent = create_test_matrix(3, 2, vec![0.0, 0.0, 0.0, 0.0, 1.0, -1.0]);
        let projected = SteifielsManifold::<TestBackend>::project(point.clone(), tangent.clone());
        assert_matrix_close(&projected, &tangent, 1e-6);
        // The tangent-space condition is X^T V + V^T X = 0. Transposing V^T X would just
        // give X^T V again and test the stronger (wrong) condition X^T V = 0.
        let xtv = point.clone().transpose().matmul(tangent.clone());
        let vtx = tangent.clone().transpose().matmul(point.clone());
        let symmetric = xtv + vtx;
        let max_violation = symmetric.abs().max().into_scalar();
        assert!(
            max_violation < 1e-6,
            "Tangent space property violated: max skew = {}",
            max_violation
        );
        assert!(
            SteifielsManifold::is_tangent_at(point, tangent).into_scalar(),
            "Tangent space property violated: max skew unknown"
        )
    }

    #[test]
    fn test_retraction_preserves_stiefel_property() {
        let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
        let direction = create_test_matrix(3, 2, vec![0.0, 0.1, 0.0, -0.1, 0.2, 0.3]);
        let step = 0.1;
        let retracted =
            SteifielsManifold::<TestBackend>::retract(point.clone(), direction.clone() * step);
        let q1 = retracted.clone().slice([0..3, 0..1]);
        let q2 = retracted.clone().slice([0..3, 1..2]);
        let dot_product = q1.clone().transpose().matmul(q2.clone()).into_scalar();
        assert!(
            dot_product.abs() < TOLERANCE,
            "Retracted point columns not orthogonal: dot product = {}",
            dot_product
        );
        let norm1 = q1
            .clone()
            .transpose()
            .matmul(q1.clone())
            .sqrt()
            .into_scalar();
        let norm2 = q2
            .clone()
            .transpose()
            .matmul(q2.clone())
            .sqrt()
            .into_scalar();
        assert!(
            (norm1 - 1.0).abs() < TOLERANCE,
            "First column not normalized after retraction: norm = {}",
            norm1
        );
        assert!(
            (norm2 - 1.0).abs() < TOLERANCE,
            "Second column not normalized after retraction: norm = {}",
            norm2
        );
        assert!(SteifielsManifold::<TestBackend>::is_in_manifold(retracted)
            .all()
            .into_scalar());
    }

    #[test]
    fn test_gram_schmidt_identity_matrix() {
        let identity = create_test_matrix::<TestBackend>(
            3,
            3,
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
        );
        let result = gram_schmidt(&identity);
        assert_matrix_close(&result, &identity, TOLERANCE);
    }

    #[test]
    fn test_manifold_properties() {
        let sqrt_half = (0.5_f32).sqrt();
        let point = create_test_matrix(
            4,
            2,
            vec![
                sqrt_half, sqrt_half, sqrt_half, -sqrt_half, 0.0, 0.0, 0.0, 0.0,
            ],
        );
        let gram_matrix = point.clone().transpose().matmul(point.clone());
        let identity = create_test_matrix(2, 2, vec![1.0, 0.0, 0.0, 1.0]);
        assert_matrix_close(&gram_matrix, &identity, TOLERANCE);
        let direction = create_test_matrix(4, 2, vec![0.1, 0.0, 0.0, 0.1, 0.2, 0.3, -0.1, 0.2]);
        let projected = SteifielsManifold::<TestBackend>::project(point.clone(), direction.clone());
        let retracted = SteifielsManifold::<TestBackend>::retract(point.clone(), projected * 0.1);
        let retracted_gram = retracted.clone().transpose().matmul(retracted.clone());
        assert_matrix_close(&retracted_gram, &identity, TOLERANCE);
    }

    #[test]
    fn test_optimiser() {
        let optimiser = ManifoldRGD::<SteifielsManifold<TestBackend>, TestBackend>::default();
        let a = create_test_matrix(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0]);
        let mut x = Tensor::<TestBackend, 2>::random(
            [3, 3],
            burn::tensor::Distribution::Normal(1., 1.),
            &a.device(),
        )
        .require_grad();
        for _i in 0..100 {
            let loss = x
                .clone()
                .transpose()
                .matmul(a.clone())
                .matmul(x.clone())
                .sum();
            let grads = loss.backward();
            let x_grad = x
                .grad(&grads)
                .expect("The gradients do exist we just did loss.backwards()");
            // Re-wrap the inner-backend gradient as an autodiff tensor for `step`.
            let x_grad_data = x_grad.to_data();
            let x_grad_ad = Tensor::<TestBackend, 2>::from_data(x_grad_data, &x.device());
            let x_clone = x.clone();
            let (new_x, _) = optimiser.step(0.1, x_clone, x_grad_ad, None);
            // Detach so the next iteration builds a fresh graph from the new point.
            x = new_x.detach().require_grad();
            println!("Loss: {}", loss);
        }
        println!("Optimised tensor: {}", x);
    }

    #[test]
    fn test_optimiser_remove() {
        let optimiser = ManifoldRGD::<SteifielsManifold<TestBackend>, TestBackend>::default();
        let a = create_test_matrix(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0]);
        let mut x = Tensor::<TestBackend, 2>::random(
            [3, 3],
            burn::tensor::Distribution::Normal(1., 1.),
            &a.device(),
        )
        .require_grad();
        for _i in 0..100 {
            let loss = x
                .clone()
                .transpose()
                .matmul(a.clone())
                .matmul(x.clone())
                .sum();
            let mut grads = loss.backward();
            let x_grad = x
                .grad_remove(&mut grads)
                .expect("The gradients do exist we just did loss.backwards()");
            let x_grad_data = x_grad.to_data();
            let x_grad_ad = Tensor::<TestBackend, 2>::from_data(x_grad_data, &x.device());
            let x_clone = x.clone();
            let (new_x, _) = optimiser.step(0.1, x_clone, x_grad_ad, None);
            x = new_x.detach().require_grad();
            println!("Loss: {}", loss);
        }
        println!("Optimised tensor: {}", x);
    }

    #[test]
    fn test_optimiser_many() {
        let optimiser = ManifoldRGD::<SteifielsManifold<TestBackend>, TestBackend>::default();
        let a = create_test_matrix(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0]);
        let mut x = Tensor::<TestBackend, 2>::random(
            [3, 3],
            burn::tensor::Distribution::Normal(1., 1.),
            &a.device(),
        )
        .require_grad();
        fn grad_fn(
            x: Tensor<Autodiff<NdArray>, 2>,
            a: Tensor<Autodiff<NdArray>, 2>,
        ) -> Tensor<Autodiff<NdArray>, 2> {
            let loss = x.clone().transpose().matmul(a).matmul(x.clone()).sum();
            let mut grads = loss.backward();
            let x_grad = x
                .grad_remove(&mut grads)
                .expect("The gradients do exist we just did loss.backwards()");
            let x_grad_ad = Tensor::<TestBackend, 2>::from_data(x_grad.to_data(), &x.device());
            x_grad_ad
        }
        let mut state = None;
        // Evaluate the loss on the plain (non-autodiff) backend for before/after comparison.
        let x_original: Tensor<NdArray, 2> =
            Tensor::<NdArray, 2>::from_data(x.to_data(), &Default::default());
        let a_original: Tensor<NdArray, 2> =
            Tensor::<NdArray, 2>::from_data(a.to_data(), &Default::default());
        let unoptimised_loss = x_original
            .clone()
            .transpose()
            .matmul(a_original.clone())
            .matmul(x_original.clone())
            .sum()
            .into_scalar();
        println!(
            "Unoptimised tensor: {} with loss {}",
            x_original, unoptimised_loss
        );
        (x, state) = optimiser.many_steps(|_| 0.1, 100, |x| grad_fn(x, a.clone()), x, state);
        assert!(state.is_none());
        let x_optimised: Tensor<NdArray, 2> =
            Tensor::<NdArray, 2>::from_data(x.to_data(), &Default::default());
        let optimised_loss = x_optimised
            .clone()
            .transpose()
            .matmul(a_original)
            .matmul(x_optimised.clone())
            .sum()
            .into_scalar();
        println!(
            "Optimised tensor: {} with loss {}",
            x_optimised, optimised_loss
        );
        assert!(optimised_loss <= unoptimised_loss,
            "The optimimisation should have lowered the loss function. It was {unoptimised_loss} before and {optimised_loss} after");
    }

    #[test]
    fn test_simple_optimizer_step() {
        let optimiser = ManifoldRGD::<SteifielsManifold<TestBackend>, TestBackend>::default();
        let point = create_test_matrix(3, 2, vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
        let grad = create_test_matrix(3, 2, vec![0.1, 0.1, 0.1, 0.1, 0.1, 0.1]);
        let (_result, _) = optimiser.step(0.1, point, grad, None);
    }
}