use crate::tensor::DenseTensor;
use crate::tensor::TensorBase;
use crate::tensor::TensorError;
/// Approximates the matrix exponential of a square matrix by scaling and squaring.
///
/// The input `A` is scaled to `B = A / 2^s`, where `s` is chosen from the largest
/// absolute entry of `A` so that the largest scaled entry is at most about `0.5`.
/// `exp(B)` is then approximated by the first-order diagonal Padé approximant
/// `(I - B/2)^-1 (I + B/2)`, and the result is squared `s` times.
///
/// # Errors
///
/// * `TensorError::DimensionMismatch` if `algebra` is not a rank-2 square tensor.
///   NOTE(review): for a non-square rank-2 input this reports `expected: 2, got: 2`.
/// * `TensorError::BlasError` if any entry is NaN or infinite, or if the linear
///   solve for `I - B/2` reports a singular or near-singular matrix.
pub fn lie_exponential(algebra: &DenseTensor) -> Result<DenseTensor, TensorError> {
    let shape = algebra.shape();
    if shape.len() != 2 || shape[0] != shape[1] {
        return Err(TensorError::DimensionMismatch {
            expected: 2,
            got: shape.len(),
        });
    }
    let n = shape[0];
    let data = algebra.data();

    // An infinite entry would make `norm` infinite: the cast below then saturates
    // to `i32::MAX` and the `+ 1` overflows. NaN entries are skipped by `f64::max`
    // and would silently poison the result. Reject both before doing any work.
    if data.iter().any(|x| !x.is_finite()) {
        return Err(TensorError::BlasError {
            code: 0,
            description: "Non-finite entry in matrix exponential input".to_string(),
        });
    }

    // Largest absolute entry, used as a cheap size estimate for the scaling step.
    let norm: f64 = data.iter().map(|x| x.abs()).fold(0.0, f64::max);

    // s = ceil(log2(norm)) + 1. The `ln / ln 2` form is kept (rather than `log2`)
    // so the chosen `s`, and hence the numerical result, matches the previous
    // behavior for every finite input.
    let s = if norm > 0.5 {
        ((norm.ln() / 2.0_f64.ln()).ceil() as i32) + 1
    } else {
        0
    };
    let scale = 2.0_f64.powi(-s);

    // m_mat = I - B/2 and n_mat = I + B/2, filled in a single pass.
    let mut m_mat = vec![0.0; n * n];
    let mut n_mat = vec![0.0; n * n];
    for i in 0..n {
        for j in 0..n {
            let idx = i * n + j;
            let half = data[idx] * scale / 2.0;
            if i == j {
                m_mat[idx] = 1.0 - half;
                n_mat[idx] = 1.0 + half;
            } else {
                m_mat[idx] = -half;
                n_mat[idx] = half;
            }
        }
    }

    // Solve (I - B/2) X = (I + B/2) for X ~= exp(B).
    let mut exp_a = solve_matrix_equation(&m_mat, &n_mat, n)?;

    // Undo the scaling: exp(A) = exp(B)^(2^s).
    for _ in 0..s {
        exp_a = matrix_multiply(&exp_a, &exp_a, n);
    }
    Ok(DenseTensor::from_vec(exp_a, vec![n, n]))
}
/// Maps a square group matrix `G` to an algebra element using a first-order rule.
///
/// The function forms `G - I` (the input with `1.0` subtracted from every
/// diagonal entry) and passes that buffer through
/// `skew_symmetric_projection_inplace`. No series expansion or iteration is
/// performed, so this is only a linearised stand-in for a matrix logarithm.
///
/// # Errors
///
/// Returns `TensorError::DimensionMismatch` if `group` is not a rank-2 square
/// tensor.
pub fn lie_logarithm(group: &DenseTensor) -> Result<DenseTensor, TensorError> {
    let dims = group.shape();
    let is_square = dims.len() == 2 && dims[0] == dims[1];
    if !is_square {
        return Err(TensorError::DimensionMismatch {
            expected: 2,
            got: dims.len(),
        });
    }
    let n = dims[0];
    let data = group.data();

    // Copy the n*n entries, then subtract the identity along the diagonal.
    let mut algebra: Vec<f64> = (0..n * n).map(|idx| data[idx]).collect();
    for d in 0..n {
        algebra[d * n + d] -= 1.0;
    }

    skew_symmetric_projection_inplace(&mut algebra, n);
    Ok(DenseTensor::from_vec(algebra, vec![n, n]))
}
/// Returns the skew-symmetric part `(M - M^T) / 2` of a square matrix as a new tensor.
///
/// The input tensor is left untouched; the result is written to a freshly
/// allocated `n x n` buffer.
///
/// # Errors
///
/// Returns `TensorError::DimensionMismatch` if `matrix` is not a rank-2 square
/// tensor.
pub fn skew_symmetric_projection(matrix: &DenseTensor) -> Result<DenseTensor, TensorError> {
    let dims = matrix.shape();
    let is_square = dims.len() == 2 && dims[0] == dims[1];
    if !is_square {
        return Err(TensorError::DimensionMismatch {
            expected: 2,
            got: dims.len(),
        });
    }
    let side = dims[0];
    let mut projected = vec![0.0; side * side];
    skew_symmetric_projection_with(matrix.data(), &mut projected, side);
    Ok(DenseTensor::from_vec(projected, vec![side, side]))
}
/// Overwrites a row-major `n x n` matrix with its skew-symmetric part `(M - M^T) / 2`.
///
/// Each mirrored pair `(i, j)` / `(j, i)` is handled together: both original
/// values are read before either is written. Updating one entry at a time would
/// make the second entry of every pair see the already-projected first entry and
/// produce a result that is not skew-symmetric.
///
/// # Panics
///
/// Panics if `data` holds fewer than `n * n` elements.
fn skew_symmetric_projection_inplace(data: &mut [f64], n: usize) {
    for i in 0..n {
        // `j` starts at `i`: the diagonal case reads the same entry twice and
        // stores `(d - d) / 2`, matching the out-of-place projection.
        for j in i..n {
            let upper = data[i * n + j];
            let lower = data[j * n + i];
            data[i * n + j] = (upper - lower) / 2.0;
            data[j * n + i] = (lower - upper) / 2.0;
        }
    }
}
/// Writes the skew-symmetric part `(M - M^T) / 2` of `input` into `output`.
///
/// Both slices are treated as row-major `n x n` matrices. `input` is only read,
/// so every output entry is computed from unmodified source values.
///
/// # Panics
///
/// Panics if either slice holds fewer than `n * n` elements.
fn skew_symmetric_projection_with(input: &[f64], output: &mut [f64], n: usize) {
    // Walk the matrix through a single flat index; the range is empty when n == 0,
    // so the division below never sees a zero divisor.
    for flat in 0..n * n {
        let row = flat / n;
        let col = flat % n;
        let mirrored = col * n + row;
        output[flat] = (input[flat] - input[mirrored]) / 2.0;
    }
}
/// Multiplies two row-major `n x n` matrices and returns the row-major product.
///
/// Each output entry is accumulated from `0.0` in increasing `k` order.
///
/// # Panics
///
/// Panics if `a` or `b` holds fewer than `n * n` elements.
fn matrix_multiply(a: &[f64], b: &[f64], n: usize) -> Vec<f64> {
    let mut product = Vec::with_capacity(n * n);
    for row in 0..n {
        for col in 0..n {
            // Explicit fold from 0.0 keeps the summation order of a plain `+=` loop.
            let entry = (0..n).fold(0.0_f64, |acc, k| acc + a[row * n + k] * b[k * n + col]);
            product.push(entry);
        }
    }
    product
}
/// Solves `A X = B` for `X`, where `a`, `b` and the result are row-major `n x n` matrices.
///
/// Gaussian elimination with partial (row) pivoting is run once on the augmented
/// matrix `[A | B]`, so all `n` right-hand-side columns share a single
/// elimination pass. Back substitution then recovers each column of `X`.
/// Pivot choice and row operations depend only on the `A` part, so each column
/// goes through the same arithmetic as if it were eliminated on its own.
///
/// # Errors
///
/// Returns `TensorError::BlasError` when the largest available pivot magnitude in
/// some column is below `1e-12`.
///
/// # Panics
///
/// Panics if `a` or `b` holds fewer than `n * n` elements.
fn solve_matrix_equation(a: &[f64], b: &[f64], n: usize) -> Result<Vec<f64>, TensorError> {
    // Row layout of the augmented matrix: n columns of A followed by n columns of B.
    let width = 2 * n;
    let mut aug = vec![0.0; n * width];
    for i in 0..n {
        aug[i * width..i * width + n].copy_from_slice(&a[i * n..(i + 1) * n]);
        aug[i * width + n..(i + 1) * width].copy_from_slice(&b[i * n..(i + 1) * n]);
    }

    // Forward elimination with partial pivoting.
    for k in 0..n {
        // Pick the row at or below k with the largest magnitude in column k.
        let mut max_row = k;
        let mut max_val = aug[k * width + k].abs();
        for row in (k + 1)..n {
            let val = aug[row * width + k].abs();
            if val > max_val {
                max_val = val;
                max_row = row;
            }
        }
        if max_val < 1e-12 {
            return Err(TensorError::BlasError {
                code: 0,
                description: "Singular or near-singular matrix".to_string(),
            });
        }
        if max_row != k {
            for j in 0..width {
                aug.swap(k * width + j, max_row * width + j);
            }
        }

        // Eliminate column k from every row below the pivot row. Columns left of k
        // are skipped because earlier steps already eliminated them; they are never
        // read again.
        let pivot = aug[k * width + k];
        for row in (k + 1)..n {
            let factor = aug[row * width + k] / pivot;
            for j in k..width {
                aug[row * width + j] -= factor * aug[k * width + j];
            }
        }
    }

    // Back substitution, one right-hand-side column at a time.
    let mut x = vec![0.0; n * n];
    for col in 0..n {
        for i in (0..n).rev() {
            let mut sum = aug[i * width + n + col];
            for j in (i + 1)..n {
                sum -= aug[i * width + j] * x[j * n + col];
            }
            x[i * n + col] = sum / aug[i * width + i];
        }
    }
    Ok(x)
}
/// Builds the `n x n` matrix with `+1` at `(i, j)`, `-1` at `(j, i)` and zeros elsewhere.
///
/// # Errors
///
/// Returns `TensorError::SliceError` if either index is out of range for an
/// `n x n` matrix or if `i == j`.
pub fn so_n_generator(n: usize, i: usize, j: usize) -> Result<DenseTensor, TensorError> {
    let indices_valid = i < n && j < n && i != j;
    if !indices_valid {
        return Err(TensorError::SliceError {
            description: format!("Invalid indices ({}, {}) for SO({})", i, j, n),
        });
    }

    // `i != j` guarantees the two flat positions are distinct.
    let plus_at = i * n + j;
    let minus_at = j * n + i;
    let entries: Vec<f64> = (0..n * n)
        .map(|flat| {
            if flat == plus_at {
                1.0
            } else if flat == minus_at {
                -1.0
            } else {
                0.0
            }
        })
        .collect();
    Ok(DenseTensor::from_vec(entries, vec![n, n]))
}
/// Checks whether a square matrix satisfies `M + M^T = 0` within `tolerance`.
///
/// Returns `false` for anything that is not a rank-2 square tensor. Otherwise the
/// matrix passes unless some pair has `|m_ij + m_ji| > tolerance`; a NaN sum never
/// compares greater, so it does not count as a violation.
pub fn is_skew_symmetric(matrix: &DenseTensor, tolerance: f64) -> bool {
    let dims = matrix.shape();
    let is_square = dims.len() == 2 && dims[0] == dims[1];
    if !is_square {
        return false;
    }
    let n = dims[0];
    let entries = matrix.data();

    // Search for the first violating pair, scanning row by row.
    let violated = (0..n).any(|row| {
        (0..n).any(|col| (entries[row * n + col] + entries[col * n + row]).abs() > tolerance)
    });
    !violated
}
pub fn is_orthogonal(matrix: &DenseTensor, tolerance: f64) -> bool {
let shape = matrix.shape();
if shape.len() != 2 || shape[0] != shape[1] {
return false;
}
let n = shape[0];
let data = matrix.data();
for i in 0..n {
for j in 0..n {
let mut dot = 0.0;
for k in 0..n {
dot += data[k * n + i] * data[k * n + j];
}
let expected = if i == j { 1.0 } else { 0.0 };
if (dot - expected).abs() > tolerance {
return false;
}
}
}
true
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 2x2 matrix `[[0, -theta], [theta, 0]]`.
    fn planar_generator(theta: f64) -> DenseTensor {
        DenseTensor::from_vec(vec![0.0, -theta, theta, 0.0], vec![2, 2])
    }

    /// Projecting a general 2x2 matrix must yield a skew-symmetric result.
    #[test]
    fn test_skew_symmetric_projection() {
        let input = DenseTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let projected = skew_symmetric_projection(&input).unwrap();
        assert!(is_skew_symmetric(&projected, 1e-6));
    }

    /// The single 2x2 generator is skew-symmetric.
    #[test]
    fn test_so2_generator() {
        let generator = so_n_generator(2, 0, 1).unwrap();
        assert!(is_skew_symmetric(&generator, 1e-6));
    }

    /// The exponential of a skew-symmetric matrix is orthogonal, and the
    /// logarithm of that result is skew-symmetric again.
    #[test]
    fn test_lie_exponential_logarithm() {
        let element = planar_generator(0.1);
        let rotation = lie_exponential(&element).unwrap();
        assert!(is_orthogonal(&rotation, 1e-5));
        let recovered = lie_logarithm(&rotation).unwrap();
        assert!(is_skew_symmetric(&recovered, 1e-5));
    }

    /// The exponential of the planar generator matches the 2D rotation matrix.
    #[test]
    fn test_lie_exponential_rotation() {
        let theta: f64 = 0.1;
        let rotation = lie_exponential(&planar_generator(theta)).unwrap();
        let data = rotation.data();
        let expected = [theta.cos(), -theta.sin(), theta.sin(), theta.cos()];
        for (idx, &want) in expected.iter().enumerate() {
            let got = data[idx];
            assert!((got - want).abs() < 1e-3, "Expected {}, got {}", want, got);
        }
        assert!(is_orthogonal(&rotation, 1e-5));
    }

    /// All three 3x3 generators are skew-symmetric.
    #[test]
    fn test_so3_generator() {
        let generators: Vec<DenseTensor> = [(0, 1), (1, 2), (0, 2)]
            .iter()
            .map(|&(i, j)| so_n_generator(3, i, j).unwrap())
            .collect();
        for generator in &generators {
            assert!(is_skew_symmetric(generator, 1e-6));
        }
    }
}