use ndarray::{Array1, Array3, ArrayD, ArrayView1, ArrayView2, ArrayViewD, IxDyn};
/// Builds the 2×2 rotation-matrix representation of SO(2) for a batch of angles.
///
/// For every angle `t` in `theta` (interpreted by `f64::sin_cos`, i.e. radians),
/// slice `i` of the result is `[[cos t, -sin t], [sin t, cos t]]`.
///
/// Returns an array of shape `(theta.len(), 2, 2)`.
pub fn rho_so2(theta: ArrayView1<'_, f64>) -> Array3<f64> {
    let mut rotations = Array3::<f64>::zeros((theta.len(), 2, 2));
    // Walk the leading axis so each iteration owns one 2×2 sub-view.
    for (mut block, angle) in rotations.outer_iter_mut().zip(theta.iter()) {
        let (sin_t, cos_t) = angle.sin_cos();
        block[[0, 0]] = cos_t;
        block[[0, 1]] = -sin_t;
        block[[1, 0]] = sin_t;
        block[[1, 1]] = cos_t;
    }
    rotations
}
/// Derivative of the SO(2) rotation matrix with respect to its angle.
///
/// For every angle `t` in `theta`, slice `i` of the result is
/// `[[-sin t, -cos t], [cos t, -sin t]]`, i.e. the entry-wise derivative of
/// `[[cos t, -sin t], [sin t, cos t]]`.
///
/// Returns an array of shape `(theta.len(), 2, 2)`.
pub fn rho_so2_jvp(theta: ArrayView1<'_, f64>) -> Array3<f64> {
    let mut derivatives = Array3::<f64>::zeros((theta.len(), 2, 2));
    for (mut block, &angle) in derivatives.outer_iter_mut().zip(theta) {
        let (sin_t, cos_t) = angle.sin_cos();
        // Entries listed in row-major order: (0,0), (0,1), (1,0), (1,1).
        let entries = [-sin_t, -cos_t, cos_t, -sin_t];
        for (slot, value) in block.iter_mut().zip(entries) {
            *slot = value;
        }
    }
    derivatives
}
/// Evaluates the Rodrigues rotation formula for a single axis-angle vector.
///
/// The rotation angle is the Euclidean norm of `(ox, oy, oz)`, clamped from
/// below to `1.0e-12` so that normalising the axis never divides by zero; the
/// zero vector therefore produces the identity matrix.
///
/// Returns `I + sin(angle) * K + (1 - cos(angle)) * K²` as a row-major
/// `[[f64; 3]; 3]`, where `K` is the skew-symmetric matrix of the unit axis.
pub fn rho_so3_single(ox: f64, oy: f64, oz: f64) -> [[f64; 3]; 3] {
    let angle = (ox * ox + oy * oy + oz * oz).sqrt().max(1.0e-12);
    let (ux, uy, uz) = (ox / angle, oy / angle, oz / angle);
    let skew = [[0.0, -uz, uy], [uz, 0.0, -ux], [-uy, ux, 0.0]];
    // K² via an explicit 3×3 product, accumulated left to right from 0.0.
    let skew_sq: [[f64; 3]; 3] = std::array::from_fn(|i| {
        std::array::from_fn(|j| (0..3).fold(0.0, |acc, r| acc + skew[i][r] * skew[r][j]))
    });
    let sin_term = angle.sin();
    let cos_term = 1.0 - angle.cos();
    std::array::from_fn(|i| {
        std::array::from_fn(|j| {
            let identity = if i == j { 1.0 } else { 0.0 };
            identity + sin_term * skew[i][j] + cos_term * skew_sq[i][j]
        })
    })
}
/// Batched SO(3) representation: one rotation matrix per axis-angle row.
///
/// Each row of `omega` is passed to [`rho_so3_single`] and the resulting 3×3
/// matrix is stored in the matching slice of the output.
///
/// # Errors
///
/// Returns an error message when `omega` does not have exactly three columns.
///
/// On success the result has shape `(omega.nrows(), 3, 3)`.
pub fn rho_so3(omega: ArrayView2<'_, f64>) -> Result<Array3<f64>, String> {
    let cols = omega.ncols();
    if cols != 3 {
        return Err(format!(
            "SO(3) rep input must have shape (N, 3); got {}",
            cols
        ));
    }
    let mut rotations = Array3::<f64>::zeros((omega.nrows(), 3, 3));
    for (axis_angle, mut block) in omega.outer_iter().zip(rotations.outer_iter_mut()) {
        let matrix = rho_so3_single(axis_angle[0], axis_angle[1], axis_angle[2]);
        for ((i, j), slot) in block.indexed_iter_mut() {
            *slot = matrix[i][j];
        }
    }
    Ok(rotations)
}
/// Applies the SO(3) right Jacobian at `omega = (ox, oy, oz)` to `v = (vx, vy, vz)`.
///
/// Computes `v + alpha * (omega × v) + beta * (omega × (omega × v))` with
/// `alpha = -(1 - cos θ) / θ²` and `beta = (θ - sin θ) / θ³`, where
/// `θ = |omega|`.
///
/// Numerical notes:
/// * For `θ < 1.0e-2` both coefficients come from their Taylor series through
///   `θ⁴`; the first omitted term (`θ⁶ / 40320` resp. `θ⁶ / 362880`) is below
///   `f64` rounding there, and the series avoids the cancellation that the
///   closed forms suffer for small angles.
/// * In the closed-form branch `1 - cos θ` is evaluated as `2 sin²(θ / 2)`,
///   which does not cancel.
///
/// Returns the resulting 3-vector. For `omega == 0` the input `v` is returned
/// unchanged.
pub fn so3_right_jacobian_times_vec(
    ox: f64,
    oy: f64,
    oz: f64,
    vx: f64,
    vy: f64,
    vz: f64,
) -> [f64; 3] {
    // Below this angle the truncated series is accurate to f64 rounding.
    const SERIES_THRESHOLD: f64 = 1.0e-2;

    let theta2 = ox * ox + oy * oy + oz * oz;
    let theta = theta2.sqrt();

    // omega × v
    let ox_v = [oy * vz - oz * vy, oz * vx - ox * vz, ox * vy - oy * vx];
    // omega × (omega × v) = omega (omega · v) - |omega|² v
    let omega_dot_v = ox * vx + oy * vy + oz * vz;
    let ox2_v = [
        ox * omega_dot_v - theta2 * vx,
        oy * omega_dot_v - theta2 * vy,
        oz * omega_dot_v - theta2 * vz,
    ];

    let (alpha, beta) = if theta < SERIES_THRESHOLD {
        (
            -0.5 + theta2 / 24.0 - theta2 * theta2 / 720.0,
            1.0 / 6.0 - theta2 / 120.0 + theta2 * theta2 / 5040.0,
        )
    } else {
        // 1 - cos θ == 2 sin²(θ/2); the half-angle form keeps full precision.
        let half_sin = (0.5 * theta).sin();
        (
            -2.0 * half_sin * half_sin / theta2,
            (theta - theta.sin()) / (theta2 * theta),
        )
    };

    [
        vx + alpha * ox_v[0] + beta * ox2_v[0],
        vy + alpha * ox_v[1] + beta * ox2_v[1],
        vz + alpha * ox_v[2] + beta * ox2_v[2],
    ]
}
/// Directional derivative of the batched SO(3) representation.
///
/// For each row, with `R = rho_so3_single(omega_row)` and
/// `s = so3_right_jacobian_times_vec(omega_row, domega_row)`, the output slice
/// is the matrix product `R * hat(s)`, where `hat(s)` is the skew-symmetric
/// matrix `[[0, -s_z, s_y], [s_z, 0, -s_x], [-s_y, s_x, 0]]`.
///
/// # Errors
///
/// Returns an error message when either input does not have three columns or
/// when the two inputs disagree in row count.
///
/// On success the result has shape `(omega.nrows(), 3, 3)`.
pub fn rho_so3_jvp(
    omega: ArrayView2<'_, f64>,
    domega: ArrayView2<'_, f64>,
) -> Result<Array3<f64>, String> {
    if !(omega.ncols() == 3 && domega.ncols() == 3) {
        return Err("SO(3) rep JVP requires (N, 3) inputs".to_string());
    }
    if omega.nrows() != domega.nrows() {
        return Err("SO(3) rep JVP omega/domega must agree in row count".to_string());
    }

    let mut tangents = Array3::<f64>::zeros((omega.nrows(), 3, 3));
    for (row, mut block) in tangents.outer_iter_mut().enumerate() {
        let (ox, oy, oz) = (omega[[row, 0]], omega[[row, 1]], omega[[row, 2]]);
        let rotation = rho_so3_single(ox, oy, oz);
        let [sx, sy, sz] = so3_right_jacobian_times_vec(
            ox,
            oy,
            oz,
            domega[[row, 0]],
            domega[[row, 1]],
            domega[[row, 2]],
        );
        let tangent_hat = [[0.0, -sz, sy], [sz, 0.0, -sx], [-sy, sx, 0.0]];
        // Row-by-column product, accumulated left to right from 0.0.
        for ((i, j), slot) in block.indexed_iter_mut() {
            *slot = (0..3).fold(0.0, |acc, r| acc + rotation[i][r] * tangent_hat[r][j]);
        }
    }
    Ok(tangents)
}
/// Dispatches to the representation of the named group and restores batch shape.
///
/// Supported `group` values and the resulting output shape:
/// * `"SO2"` — every element of `g` is an angle; output shape is
///   `g.shape() + [2, 2]`.
/// * `"SO3"` — the last axis of `g` must have length 3 (an axis-angle vector);
///   output shape is `g.shape()[..-1] + [3, 3]`.
/// * `"R1"` — output is all ones with shape `g.shape() + [1, 1]`.
/// * `"Trivial"` — output is all ones with shape `g.shape()[..-1] + [1, 1]`
///   (or `[1, 1]` when `g` is zero-dimensional).
///
/// Elements are always read in logical (row-major index) order, independent of
/// the memory layout of `g`.
///
/// # Errors
///
/// Returns an error message for an unknown group name, for an `"SO3"` input
/// whose last axis is not of length 3, or if an internal reshape fails.
pub fn rho(group: &str, g: ArrayViewD<'_, f64>) -> Result<ArrayD<f64>, String> {
    match group {
        "SO2" => {
            let mut out_shape = g.shape().to_vec();
            out_shape.push(2);
            out_shape.push(2);
            let flat = Array1::from_vec(g.iter().copied().collect());
            rho_so2(flat.view())
                .into_shape_with_order(IxDyn(&out_shape))
                .map_err(|err| format!("failed to reshape SO(2) representation: {err}"))
        }
        "SO3" => {
            let shape = g.shape().to_vec();
            if shape.last().copied() != Some(3) {
                return Err("SO(3) rep input requires last axis of length 3".to_string());
            }
            let n = g.len() / 3;
            // `to_shape` borrows when the layout allows and otherwise copies in
            // logical order, so column-major or permuted-axes views are
            // accepted just like in the "SO2" branch.
            let flat = g
                .to_shape((n, 3))
                .map_err(|err| format!("failed to flatten SO(3) representation input: {err}"))?;
            let mut out_shape = shape[..shape.len() - 1].to_vec();
            out_shape.push(3);
            out_shape.push(3);
            rho_so3(flat.view())?
                .into_shape_with_order(IxDyn(&out_shape))
                .map_err(|err| format!("failed to reshape SO(3) representation: {err}"))
        }
        "R1" => {
            let mut out_shape = g.shape().to_vec();
            out_shape.push(1);
            out_shape.push(1);
            Ok(ArrayD::<f64>::from_elem(IxDyn(&out_shape), 1.0))
        }
        "Trivial" => {
            let shape = g.shape();
            // Drop the last axis when there is one; a 0-d input has none.
            let mut out_shape = if shape.is_empty() {
                Vec::new()
            } else {
                shape[..shape.len() - 1].to_vec()
            };
            out_shape.push(1);
            out_shape.push(1);
            Ok(ArrayD::<f64>::from_elem(IxDyn(&out_shape), 1.0))
        }
        other => Err(format!("unknown equivariant group {other:?}")),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest absolute entry-wise gap between `actual(i, j)` and `expected`.
    fn max_abs_diff(actual: impl Fn(usize, usize) -> f64, expected: &[[f64; 3]; 3]) -> f64 {
        let mut worst = 0.0_f64;
        for (i, expected_row) in expected.iter().enumerate() {
            for (j, want) in expected_row.iter().enumerate() {
                worst = worst.max((actual(i, j) - want).abs());
            }
        }
        worst
    }

    /// Matrix product `rotation * hat(v)` with `hat(v)` the skew matrix of `v`.
    fn rotation_times_hat(rotation: &[[f64; 3]; 3], v: [f64; 3]) -> [[f64; 3]; 3] {
        let hat = [[0.0, -v[2], v[1]], [v[2], 0.0, -v[0]], [-v[1], v[0], 0.0]];
        let mut product = [[0.0_f64; 3]; 3];
        for (i, product_row) in product.iter_mut().enumerate() {
            for (j, cell) in product_row.iter_mut().enumerate() {
                *cell = (0..3).fold(0.0, |acc, r| acc + rotation[i][r] * hat[r][j]);
            }
        }
        product
    }

    /// Four-point central finite difference of `rho_so3_single` along `dw`.
    fn fd_central_4pt(w: &[f64; 3], dw: &[f64; 3], h: f64) -> [[f64; 3]; 3] {
        let at = |step: f64| {
            rho_so3_single(w[0] + step * dw[0], w[1] + step * dw[1], w[2] + step * dw[2])
        };
        let r_p2 = at(2.0 * h);
        let r_p1 = at(h);
        let r_m1 = at(-h);
        let r_m2 = at(-(2.0 * h));
        let mut stencil = [[0.0_f64; 3]; 3];
        for (i, stencil_row) in stencil.iter_mut().enumerate() {
            for (j, cell) in stencil_row.iter_mut().enumerate() {
                *cell = (-r_p2[i][j] + 8.0 * r_p1[i][j] - 8.0 * r_m1[i][j] + r_m2[i][j])
                    / (12.0 * h);
            }
        }
        stencil
    }

    /// The analytic JVP must agree with a finite-difference derivative of the
    /// exponential map for generic (non-parallel) directions.
    #[test]
    fn so3_jvp_matches_finite_difference() {
        let cases: [([f64; 3], [f64; 3]); 3] = [
            ([0.7, -0.4, 0.3], [1.1, 0.5, -0.8]),
            ([0.2, 0.9, -0.5], [-0.6, 1.2, 0.4]),
            ([-1.0, 0.3, 0.6], [0.5, -0.7, 0.9]),
        ];

        let mut omega_flat = Vec::with_capacity(3 * cases.len());
        let mut domega_flat = Vec::with_capacity(3 * cases.len());
        for (w, dw) in &cases {
            omega_flat.extend_from_slice(w);
            domega_flat.extend_from_slice(dw);
        }
        let omega = ndarray::Array2::from_shape_vec((cases.len(), 3), omega_flat)
            .expect("omega array");
        let domega = ndarray::Array2::from_shape_vec((cases.len(), 3), domega_flat)
            .expect("domega array");

        let jvp = rho_so3_jvp(omega.view(), domega.view())
            .expect("SO(3) JVP must succeed on (N,3) input");

        for (row, (w, dw)) in cases.iter().enumerate() {
            let fd = fd_central_4pt(w, dw, 1.0e-4);
            let max_err = max_abs_diff(|i, j| jvp[[row, i, j]], &fd);
            assert!(
                max_err < 1.0e-7,
                "row {row}: SO(3) JVP - 4pt-FD has max |err| = {max_err:.3e} (omega={w:?}, domega={dw:?})"
            );
        }
    }

    /// When `domega` is parallel to `omega`, the JVP must equal
    /// `R(omega) * hat(domega)` to near machine precision.
    #[test]
    fn so3_jvp_parallel_direction_unaffected_by_right_jacobian_fix() {
        let omega = ndarray::Array2::from_shape_vec((1, 3), vec![0.3, -0.6, 0.4]).expect("omega");
        let scale = 1.7_f64;
        let domega = omega.mapv(|component| scale * component);

        let jvp = rho_so3_jvp(omega.view(), domega.view()).expect("JVP");

        let rg = rho_so3_single(omega[[0, 0]], omega[[0, 1]], omega[[0, 2]]);
        let expected =
            rotation_times_hat(&rg, [domega[[0, 0]], domega[[0, 1]], domega[[0, 2]]]);

        let max_err = max_abs_diff(|i, j| jvp[[0, i, j]], &expected);
        assert!(
            max_err < 1.0e-13,
            "parallel-dω boundary case shifted: max err = {max_err:.3e}"
        );
    }
}