use std::fmt;
/// A 3×3 matrix of `f64` values stored row-major (`elements[row][col]`).
///
/// As the name suggests the type is meant to hold rotation matrices, but
/// [`RotationMatrix3::from_array`] and [`RotationMatrix3::set`] store whatever
/// they are given without validation. Use
/// [`RotationMatrix3::is_rotation_matrix`] to check the invariant.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RotationMatrix3 {
    /// Matrix entries, indexed as `elements[row][col]`.
    elements: [[f64; 3]; 3],
}

impl RotationMatrix3 {
    /// Returns the identity matrix.
    pub fn identity() -> Self {
        Self {
            elements: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
        }
    }

    /// Wraps a row-major array of entries.
    ///
    /// No check is made that `elements` actually describes a rotation.
    pub fn from_array(elements: [[f64; 3]; 3]) -> Self {
        Self { elements }
    }

    /// Returns the entry at (`row`, `col`).
    ///
    /// # Panics
    ///
    /// Panics if `row` or `col` is not in `0..3`.
    pub fn get(&self, row: usize, col: usize) -> f64 {
        self.elements[row][col]
    }

    /// Overwrites the entry at (`row`, `col`) with `value`.
    ///
    /// The matrix is not re-validated afterwards.
    ///
    /// # Panics
    ///
    /// Panics if `row` or `col` is not in `0..3`.
    pub fn set(&mut self, row: usize, col: usize, value: f64) {
        self.elements[row][col] = value;
    }

    /// Borrows the underlying row-major array.
    pub fn elements(&self) -> &[[f64; 3]; 3] {
        &self.elements
    }

    /// Replaces rows `i` and `j` in place with
    /// `cos·rowᵢ + sin·rowⱼ` and `-sin·rowᵢ + cos·rowⱼ` respectively.
    ///
    /// Shared kernel of [`Self::rotate_x`], [`Self::rotate_y`] and
    /// [`Self::rotate_z`]; every other row is left untouched.
    fn rotate_row_pair(&mut self, i: usize, j: usize, sin: f64, cos: f64) {
        // Snapshot both rows first: each output row depends on both inputs.
        let (row_i, row_j) = (self.elements[i], self.elements[j]);
        for (k, (&a, &b)) in row_i.iter().zip(&row_j).enumerate() {
            self.elements[i][k] = cos * a + sin * b;
            self.elements[j][k] = -sin * a + cos * b;
        }
    }

    /// Pre-multiplies this matrix in place by the x-axis rotation
    ///
    /// ```text
    /// | 1    0       0     |
    /// | 0   cos φ   sin φ  |
    /// | 0  -sin φ   cos φ  |
    /// ```
    ///
    /// where `phi` (φ) is in radians.
    pub fn rotate_x(&mut self, phi: f64) {
        let (s, c) = phi.sin_cos();
        self.rotate_row_pair(1, 2, s, c);
    }

    /// Pre-multiplies this matrix in place by the z-axis rotation
    ///
    /// ```text
    /// |  cos ψ   sin ψ   0 |
    /// | -sin ψ   cos ψ   0 |
    /// |   0       0      1 |
    /// ```
    ///
    /// where `psi` (ψ) is in radians. With this sign convention,
    /// `rotate_z(π/2)` applied to the identity maps the vector `(1, 0, 0)`
    /// to `(0, -1, 0)`.
    pub fn rotate_z(&mut self, psi: f64) {
        let (s, c) = psi.sin_cos();
        self.rotate_row_pair(0, 1, s, c);
    }

    /// Pre-multiplies this matrix in place by the y-axis rotation
    ///
    /// ```text
    /// | cos θ   0  -sin θ |
    /// |  0      1    0    |
    /// | sin θ   0   cos θ |
    /// ```
    ///
    /// where `theta` (θ) is in radians.
    pub fn rotate_y(&mut self, theta: f64) {
        let (s, c) = theta.sin_cos();
        // Same kernel as the x/z rotations, but the sine terms carry the
        // opposite sign in the (0, 2) row pair.
        self.rotate_row_pair(0, 2, -s, c);
    }

    /// Returns the matrix product `self · other`.
    pub fn multiply(&self, other: &Self) -> Self {
        let mut result = [[0.0; 3]; 3];
        for (i, row) in result.iter_mut().enumerate() {
            for (j, cell) in row.iter_mut().enumerate() {
                for k in 0..3 {
                    *cell += self.elements[i][k] * other.elements[k][j];
                }
            }
        }
        Self::from_array(result)
    }

    /// Returns the matrix–vector product `self · vector`, treating `vector`
    /// as a column vector.
    pub fn apply_to_vector(&self, vector: [f64; 3]) -> [f64; 3] {
        [
            self.elements[0][0] * vector[0]
                + self.elements[0][1] * vector[1]
                + self.elements[0][2] * vector[2],
            self.elements[1][0] * vector[0]
                + self.elements[1][1] * vector[1]
                + self.elements[1][2] * vector[2],
            self.elements[2][0] * vector[0]
                + self.elements[2][1] * vector[1]
                + self.elements[2][2] * vector[2],
        ]
    }

    /// Returns the determinant, computed by cofactor expansion along row 0.
    pub fn determinant(&self) -> f64 {
        let m = &self.elements;
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    }

    /// Returns the transpose (rows and columns swapped).
    pub fn transpose(&self) -> Self {
        Self::from_array([
            [
                self.elements[0][0],
                self.elements[1][0],
                self.elements[2][0],
            ],
            [
                self.elements[0][1],
                self.elements[1][1],
                self.elements[2][1],
            ],
            [
                self.elements[0][2],
                self.elements[1][2],
                self.elements[2][2],
            ],
        ])
    }

    /// Checks whether this matrix is a proper rotation to within `tolerance`.
    ///
    /// Returns `true` only if the determinant is within `tolerance` of `1`
    /// and every entry of `self · selfᵀ` is within `tolerance` of the
    /// corresponding identity entry.
    ///
    /// Any NaN involved — in the matrix entries or in `tolerance` itself —
    /// makes the result `false`.
    pub fn is_rotation_matrix(&self, tolerance: f64) -> bool {
        // `<=` is false whenever either operand is NaN, so NaN fails the
        // check. A "reject if error > tolerance" test would let it through.
        let within = |error: f64| error.abs() <= tolerance;

        if !within(self.determinant() - 1.0) {
            return false;
        }

        let product = self.multiply(&self.transpose());
        let identity = Self::identity();
        product
            .elements
            .iter()
            .flatten()
            .zip(identity.elements.iter().flatten())
            .all(|(&actual, &expected)| within(actual - expected))
    }

    /// Returns the largest absolute element-wise difference between `self`
    /// and `other`.
    ///
    /// The running maximum starts at `0.0` and is combined with `f64::max`,
    /// which returns the non-NaN operand; a NaN difference is therefore
    /// skipped rather than propagated.
    pub fn max_difference(&self, other: &Self) -> f64 {
        let mut max_diff: f64 = 0.0;
        for (lhs_row, rhs_row) in self.elements.iter().zip(&other.elements) {
            for (lhs, rhs) in lhs_row.iter().zip(rhs_row) {
                max_diff = max_diff.max((lhs - rhs).abs());
            }
        }
        max_diff
    }

    /// Applies the matrix to a direction given as spherical angles.
    ///
    /// `ra` is the angle in the x–y plane measured from the x-axis and `dec`
    /// is the elevation from that plane, both in radians. The direction is
    /// converted to the unit vector
    /// `(cos dec · cos ra, cos dec · sin ra, sin dec)`, multiplied by the
    /// matrix, and converted back.
    ///
    /// Returns `(new_ra, new_dec)`. `new_ra` comes from `atan2`, so it lies
    /// in `[-π, π]`; `new_dec` comes from `asin`, so it lies in
    /// `[-π/2, π/2]`.
    pub fn transform_spherical(&self, ra: f64, dec: f64) -> (f64, f64) {
        let (sin_ra, cos_ra) = ra.sin_cos();
        let (sin_dec, cos_dec) = dec.sin_cos();
        let vector = [cos_dec * cos_ra, cos_dec * sin_ra, sin_dec];
        let transformed = self.apply_to_vector(vector);
        let new_ra = transformed[1].atan2(transformed[0]);
        let norm = (transformed[0] * transformed[0]
            + transformed[1] * transformed[1]
            + transformed[2] * transformed[2])
            .sqrt();
        // A zero-length result has no direction; report zero elevation
        // instead of dividing by zero. Otherwise normalise and clamp so
        // rounding error cannot push the `asin` argument outside [-1, 1].
        let z = if norm == 0.0 {
            0.0
        } else {
            (transformed[2] / norm).clamp(-1.0, 1.0)
        };
        let new_dec = z.asin();
        (new_ra, new_dec)
    }
}
impl std::ops::Mul for RotationMatrix3 {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
self.multiply(&rhs)
}
}
impl std::ops::Mul<&RotationMatrix3> for RotationMatrix3 {
type Output = RotationMatrix3;
fn mul(self, rhs: &RotationMatrix3) -> RotationMatrix3 {
self.multiply(rhs)
}
}
impl std::ops::Mul<RotationMatrix3> for &RotationMatrix3 {
type Output = RotationMatrix3;
fn mul(self, rhs: RotationMatrix3) -> RotationMatrix3 {
self.multiply(&rhs)
}
}
impl std::ops::Mul<&RotationMatrix3> for &RotationMatrix3 {
type Output = RotationMatrix3;
fn mul(self, rhs: &RotationMatrix3) -> RotationMatrix3 {
self.multiply(rhs)
}
}
/// Read access by `(row, col)` tuple: `m[(r, c)]`.
///
/// Panics if either index is outside `0..3`.
impl std::ops::Index<(usize, usize)> for RotationMatrix3 {
    type Output = f64;

    fn index(&self, position: (usize, usize)) -> &Self::Output {
        let (row, col) = position;
        &self.elements[row][col]
    }
}
/// Write access by `(row, col)` tuple: `m[(r, c)] = value`.
///
/// Panics if either index is outside `0..3`.
impl std::ops::IndexMut<(usize, usize)> for RotationMatrix3 {
    fn index_mut(&mut self, position: (usize, usize)) -> &mut f64 {
        let (row, col) = position;
        &mut self.elements[row][col]
    }
}
/// `matrix * vector` with an owned matrix.
///
/// The vector's `x`, `y`, `z` fields are passed through
/// [`RotationMatrix3::apply_to_vector`] and the result is rebuilt with
/// `Vector3::from_array`.
impl std::ops::Mul<super::Vector3> for RotationMatrix3 {
    type Output = super::Vector3;

    fn mul(self, vec: super::Vector3) -> Self::Output {
        super::Vector3::from_array(self.apply_to_vector([vec.x, vec.y, vec.z]))
    }
}
/// `&matrix * vector` with a borrowed matrix.
///
/// The vector's `x`, `y`, `z` fields are passed through
/// [`RotationMatrix3::apply_to_vector`] and the result is rebuilt with
/// `Vector3::from_array`.
impl std::ops::Mul<super::Vector3> for &RotationMatrix3 {
    type Output = super::Vector3;

    fn mul(self, vec: super::Vector3) -> Self::Output {
        super::Vector3::from_array(self.apply_to_vector([vec.x, vec.y, vec.z]))
    }
}
/// Multi-line rendering: a `RotationMatrix3:` header line followed by one
/// bracketed line per row, each entry printed with width 12 and 9 decimals.
impl fmt::Display for RotationMatrix3 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "RotationMatrix3:")?;
        // `try_for_each` stops at the first formatter error, like `?` in a loop.
        self.elements
            .iter()
            .try_for_each(|row| writeln!(f, " [{:12.9} {:12.9} {:12.9}]", row[0], row[1], row[2]))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::HALF_PI;

    /// Identity with a 0.1 shear in entry (0, 1); determinant 1 but not orthogonal.
    const SHEARED: [[f64; 3]; 3] = [[1.0, 0.1, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

    /// Asserts that `|actual - expected| < tolerance`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// Asserts that each component of `actual` is within `tolerance` of `expected`.
    #[track_caller]
    fn assert_vec_close(actual: [f64; 3], expected: [f64; 3], tolerance: f64) {
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert_close(*a, *e, tolerance);
        }
    }

    // Diagonal entries of the identity are 1, an off-diagonal entry is 0.
    #[test]
    fn test_identity_and_get() {
        let identity = RotationMatrix3::identity();
        for i in 0..3 {
            assert_eq!(identity.get(i, i), 1.0);
        }
        assert_eq!(identity.get(0, 1), 0.0);
    }

    // `set` followed by `get` returns the stored value.
    #[test]
    fn test_set() {
        let mut matrix = RotationMatrix3::identity();
        matrix.set(0, 1, 0.5);
        assert_eq!(matrix.get(0, 1), 0.5);
    }

    // A quarter turn about z maps the x unit vector to (0, -1, 0).
    #[test]
    fn test_rotate_z() {
        let mut rotation = RotationMatrix3::identity();
        rotation.rotate_z(HALF_PI);
        let rotated = rotation.apply_to_vector([1.0, 0.0, 0.0]);
        assert_vec_close(rotated, [0.0, -1.0, 0.0], 1e-15);
    }

    // A quarter turn about x maps the y unit vector to (0, 0, -1).
    #[test]
    fn test_rotate_x() {
        let mut rotation = RotationMatrix3::identity();
        rotation.rotate_x(HALF_PI);
        let rotated = rotation.apply_to_vector([0.0, 1.0, 0.0]);
        assert_vec_close(rotated, [0.0, 0.0, -1.0], 1e-15);
    }

    // A quarter turn about y maps the z unit vector to (-1, 0, 0).
    #[test]
    fn test_rotate_y() {
        let mut rotation = RotationMatrix3::identity();
        rotation.rotate_y(HALF_PI);
        let rotated = rotation.apply_to_vector([0.0, 0.0, 1.0]);
        assert_vec_close(rotated, [-1.0, 0.0, 0.0], 1e-15);
    }

    // A single elementary rotation passes the validity check.
    #[test]
    fn test_is_rotation_matrix_valid() {
        let mut rotation = RotationMatrix3::identity();
        rotation.rotate_z(0.5);
        assert!(rotation.is_rotation_matrix(1e-14));
    }

    // Determinant 2 is rejected.
    #[test]
    fn test_is_rotation_matrix_bad_determinant() {
        let scaled =
            RotationMatrix3::from_array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
        assert!(!scaled.is_rotation_matrix(1e-15));
    }

    // Determinant 1 but non-orthogonal rows is rejected.
    #[test]
    fn test_is_rotation_matrix_not_orthogonal() {
        let sheared = RotationMatrix3::from_array(SHEARED);
        assert!(!sheared.is_rotation_matrix(1e-15));
    }

    // The identity leaves spherical angles unchanged.
    #[test]
    fn test_transform_spherical_identity() {
        let (ra, dec) = (1.0, 0.5);
        let (out_ra, out_dec) = RotationMatrix3::identity().transform_spherical(ra, dec);
        assert_close(out_ra, ra, 1e-14);
        assert_close(out_dec, dec, 1e-14);
    }

    // A quarter turn about z moves (ra, dec) = (0, 0) to ra = -π/2.
    #[test]
    fn test_transform_spherical_rotation() {
        let mut rotation = RotationMatrix3::identity();
        rotation.rotate_z(HALF_PI);
        let (out_ra, out_dec) = rotation.transform_spherical(0.0, 0.0);
        assert_close(out_ra, -HALF_PI, 1e-14);
        assert_close(out_dec, 0.0, 1e-14);
    }

    // A zero matrix collapses the vector; the result must still be finite.
    #[test]
    fn test_transform_spherical_zero_norm() {
        let zero_matrix = RotationMatrix3::from_array([[0.0; 3]; 3]);
        let (_, dec) = zero_matrix.transform_spherical(0.0, 0.0);
        assert!(dec.is_finite());
    }

    // All four owned/borrowed operand combinations agree on entry (0, 0).
    #[test]
    fn test_mul_matrix_matrix() {
        let mut a = RotationMatrix3::identity();
        a.rotate_x(0.1);
        let mut b = RotationMatrix3::identity();
        b.rotate_y(0.2);

        let products = [a * b, a * &b, &a * b, &a * &b];
        for pair in products.windows(2) {
            assert_eq!(pair[0].get(0, 0), pair[1].get(0, 0));
        }
    }

    // Tuple indexing reads and writes individual entries.
    #[test]
    fn test_index_operators() {
        let mut matrix = RotationMatrix3::identity();
        assert_eq!(matrix[(0, 0)], 1.0);
        assert_eq!(matrix[(0, 1)], 0.0);
        matrix[(0, 1)] = 0.5;
        assert_eq!(matrix[(0, 1)], 0.5);
    }

    // The identity leaves a Vector3 unchanged for owned and borrowed matrices.
    #[test]
    fn test_mul_matrix_vector() {
        use crate::Vector3;

        let identity = RotationMatrix3::identity();
        let v = Vector3::new(1.0, 2.0, 3.0);
        let by_value = identity * v;
        let by_reference = &identity * v;
        assert_eq!(by_value, v);
        assert_eq!(by_reference, v);
    }

    // Display output contains the header and at least one bracketed row.
    #[test]
    fn test_display() {
        let mut rotation = RotationMatrix3::identity();
        rotation.rotate_z(0.1);
        let rendered = format!("{}", rotation);
        for needle in ["RotationMatrix3:", "["] {
            assert!(rendered.contains(needle));
        }
    }

    // The only differing entry is the 0.1 shear.
    #[test]
    fn test_max_difference() {
        let identity = RotationMatrix3::identity();
        let sheared = RotationMatrix3::from_array(SHEARED);
        assert_close(identity.max_difference(&sheared), 0.1, 1e-15);
    }

    // `elements` exposes the backing array.
    #[test]
    fn test_elements() {
        let identity = RotationMatrix3::identity();
        let entries = identity.elements();
        assert_eq!(entries[0][0], 1.0);
        assert_eq!(entries[1][1], 1.0);
    }
}