use burn::{prelude::Backend, tensor::Tensor};
/// Builds a tensor with the same shape as `example` whose trailing two axes hold an identity
/// matrix.
///
/// An `n x n` identity (with `n` the size of the last axis) is created on the device of
/// `example` and expanded to the full shape of `example`. Only the shape and the device of
/// `example` are read; its values are ignored.
///
/// In debug builds this asserts that `D >= 2` and that the last two axes have equal size.
pub(crate) fn identity_in_last_two<B: Backend, const D: usize>(
    example: &Tensor<B, D>,
) -> Tensor<B, D> {
    let target_shape = example.shape();
    let dims: [usize; D] = target_shape.dims();
    debug_assert!(D >= 2);
    debug_assert_eq!(dims[D - 1], dims[D - 2]);
    let side = dims[D - 1];
    Tensor::eye(side, &example.device()).expand(target_shape)
}
/// Builds a tensor shaped like `example` whose trailing two axes form a diagonal matrix.
///
/// Entry `(i, i)` of the trailing two axes is `diag_fun(i)` and every other entry is zero; the
/// same diagonal is written for every index of the leading axes. Only the shape and the device
/// of `example` are read. `diag_fun` is called once per diagonal index, in increasing order.
///
/// In debug builds this asserts that `D >= 2` and that the last two axes have equal size.
#[allow(dead_code)]
pub(crate) fn diag_i<B: Backend, const D: usize>(
    example: &Tensor<B, D>,
    diag_fun: impl Fn(usize) -> f32,
) -> Tensor<B, D> {
    let dims: [usize; D] = example.shape().dims();
    debug_assert!(D >= 2);
    debug_assert_eq!(dims[D - 1], dims[D - 2]);
    let side = dims[D - 1];
    let batch_axes = D - 2;

    // A patch of ones covering a single diagonal entry: the leading axes keep their size and
    // the trailing two axes collapse to length 1.
    let patch_dims: [usize; D] =
        std::array::from_fn(|axis| if axis < batch_axes { dims[axis] } else { 1 });
    let unit_patch = Tensor::<B, D>::ones(patch_dims, &example.device());

    // Start from zeros and write one scaled patch per diagonal position.
    (0..side).fold(example.zeros_like(), |acc, k| {
        let window: [_; D] = std::array::from_fn(|axis| {
            if axis < batch_axes {
                0..dims[axis]
            } else {
                k..k + 1
            }
        });
        acc.slice_assign(window, unit_patch.clone().mul_scalar(diag_fun(k)))
    })
}
/// Sums the diagonal shared by axes `a` and `b` of `t` (a trace over that pair of axes).
///
/// The two axes are moved to the trailing positions, the tensor is multiplied by an identity
/// matrix in those positions, both trailing axes are summed, and the axes are moved back.
/// The result has the same rank as `t`: axes `a` and `b` have size 1 and every other axis
/// keeps its size and position.
///
/// In debug builds this asserts `a < D`, `b < D`, `a != b`, and that axes `a` and `b` have the
/// same size.
pub(crate) fn ein_sum<B: Backend, const D: usize>(
    mut t: Tensor<B, D>,
    a: usize,
    b: usize,
) -> Tensor<B, D> {
    debug_assert!(a < D);
    debug_assert!(b < D);
    debug_assert_ne!(a, b);
    let t_shape: [usize; D] = t.shape().dims();
    debug_assert_eq!(t_shape[a], t_shape[b]);

    // The diagonal sum is symmetric in the two axes, so the order in which they are moved is
    // free. If `b` already sits at `D - 2`, moving `a` there first would push the data of `b`
    // out to position `a`, and the second swap would then grab the wrong axis. Handling `b`
    // first in that case keeps it in place and only `a` has to travel.
    let (first, second) = if b == D - 2 { (b, a) } else { (a, b) };
    let move_first = first != D - 2;
    let move_second = second != D - 1;

    if move_first {
        t = t.swap_dims(first, D - 2);
    }
    if move_second {
        t = t.swap_dims(second, D - 1);
    }

    // Keep only the diagonal of the trailing two axes, then collapse both to length 1.
    let identity_last_two = identity_in_last_two(&t);
    t = t.mul(identity_last_two);
    t = t.sum_dim(D - 1).sum_dim(D - 2);

    // Undo the swaps in reverse order so every axis returns to its original position.
    if move_second {
        t = t.swap_dims(second, D - 1);
    }
    if move_first {
        t = t.swap_dims(first, D - 2);
    }
    t
}
/// Sums the diagonal of the last two axes of `t` by calling `ein_sum` with axes `D - 2` and
/// `D - 1`.
#[allow(dead_code)]
pub(crate) fn trace<B: Backend, const D: usize>(t: Tensor<B, D>) -> Tensor<B, D> {
    let (row_axis, col_axis) = (D - 2, D - 1);
    ein_sum(t, row_axis, col_axis)
}
#[cfg(test)]
pub(crate) mod test {
use burn::{backend::NdArray, prelude::Backend, tensor::Tensor};
use crate::manifolds::utils::{diag_i, ein_sum, identity_in_last_two};
/// Panics unless the largest absolute element-wise difference between `a` and `b` is strictly
/// below `tol`.
///
/// The panic message reports the observed maximum difference and the tolerance.
pub(crate) fn assert_matrix_close<TestBackend>(
    a: &Tensor<TestBackend, 2>,
    b: &Tensor<TestBackend, 2>,
    tol: f32,
) where
    TestBackend: Backend,
    <TestBackend as Backend>::FloatElem: PartialOrd<f32>,
{
    // Reduce |a - b| to its single largest entry before comparing against the tolerance.
    let max_diff = a.clone().sub(b.clone()).abs().max().into_scalar();
    assert!(
        max_diff < tol,
        "Tensors differ by {}, tolerance: {}",
        max_diff,
        tol
    );
}
/// Builds a `rows x cols` test matrix on the default device from a flat list of values.
///
/// For `rows >= cols` the values are read in row-major order: entry `(i, j)` is
/// `values[i * cols + j]`. Only the first `rows * cols` values are used; extra values are
/// ignored.
///
/// For `rows < cols` the matrix is built as `cols x rows` from the same values and then
/// transposed, so in that case the values are effectively read in column-major order.
///
/// Any shape is accepted, not just a fixed list of small shapes.
///
/// # Panics
///
/// Panics with `Invalid {rows}x{cols} matrix data` if fewer than `rows * cols` values are
/// supplied. In debug builds it also asserts that `rows` and `cols` are non-zero.
pub(crate) fn create_test_matrix<TestBackend: Backend>(
    rows: usize,
    cols: usize,
    values: Vec<f32>,
) -> Tensor<TestBackend, 2> {
    debug_assert_ne!(rows, 0);
    debug_assert_ne!(cols, 0);
    if rows < cols {
        return create_test_matrix(cols, rows, values).transpose();
    }
    let needed = rows * cols;
    assert!(
        values.len() >= needed,
        "Invalid {}x{} matrix data",
        rows,
        cols
    );
    let device = Default::default();
    // Load the used prefix as a flat rank-1 tensor and let `reshape` impose the row-major
    // `rows x cols` layout.
    Tensor::<TestBackend, 1>::from_floats(&values[..needed], &device).reshape([rows, cols])
}
// Summing the diagonal of a 3x3 matrix over axes 0 and 1 must give its trace (3 + 7 - 1 = 9),
// returned as a tensor of shape [1, 1].
#[test]
fn small_einsum() {
    let entries = vec![3.0, 4.0, 5.0, 6.0, 7.0, 3.0, -10.0, -4.0, -1.0];
    let mat = create_test_matrix::<NdArray>(3, 3, entries);
    let ein_summed = ein_sum(mat, 0, 1);
    assert_eq!(ein_summed.shape().dims(), [1, 1]);
    let trace_value = ein_summed.into_scalar();
    assert!((trace_value - 9.0).abs() <= 1e-6, "{}", trace_value);
}
// `identity_in_last_two` must yield a 3x3 identity in the trailing two axes, both for a plain
// rank-2 input and for every leading index of rank-5 inputs.
#[test]
fn identity_test() {
    // Every case starts from the same 3x3 matrix; only its shape and device matter.
    let base_matrix = || {
        create_test_matrix::<NdArray>(
            3,
            3,
            vec![3.0, 4.0, 5.0, 6.0, 7.0, 3.0, -10.0, -4.0, -1.0],
        )
    };

    // Rank 2: the whole result is a single identity matrix.
    {
        let mat = base_matrix().expand([3, 3]);
        let identity_mat = identity_in_last_two(&mat);
        let expected = Tensor::eye(3, &identity_mat.device());
        assert_matrix_close(&identity_mat, &expected, 1e-6);
    }

    // Rank 5: every trailing 3x3 slice must be an identity matrix.
    for expanded_shape in [[3, 3, 3, 3, 3], [29, 483, 2, 3, 3]] {
        let mat = base_matrix().expand(expanded_shape);
        let identity_mat = identity_in_last_two(&mat);
        for outer in 0..expanded_shape[0] {
            for middle in 0..expanded_shape[1] {
                for inner in 0..expanded_shape[2] {
                    let block = identity_mat
                        .clone()
                        .slice([
                            outer..outer + 1,
                            middle..middle + 1,
                            inner..inner + 1,
                            0..3,
                            0..3,
                        ])
                        .reshape([3, 3]);
                    let expected = Tensor::eye(3, &block.device());
                    assert_matrix_close(&block, &expected, 1e-6);
                }
            }
        }
    }
}
// `diag_i` must write the diagonal (2, 7, 9) into every trailing 3x3 slice of rank-5 inputs
// and leave all off-diagonal entries at zero.
#[test]
fn diag_test() {
    let diag_entries = [2.0, 7.0, 9.0];
    let expected =
        create_test_matrix::<NdArray>(3, 3, vec![2.0, 0.0, 0.0, 0.0, 7.0, 0.0, 0.0, 0.0, 9.0]);

    for expanded_shape in [[3, 3, 3, 3, 3], [10, 9, 20, 3, 3]] {
        let mat = expected.clone().expand(expanded_shape);
        let diagonal = diag_i(&mat, |i| diag_entries[i]);
        for outer in 0..expanded_shape[0] {
            for middle in 0..expanded_shape[1] {
                for inner in 0..expanded_shape[2] {
                    let block = diagonal
                        .clone()
                        .slice([
                            outer..outer + 1,
                            middle..middle + 1,
                            inner..inner + 1,
                            0..3,
                            0..3,
                        ])
                        .reshape([3, 3]);
                    assert_matrix_close(&block, &expected, 1e-6);
                }
            }
        }
    }
}
}