use core::marker::PhantomData;
use core::ops::Mul;
use crate::ops::Rotation3;
/// A 3×3 matrix tagged at the type level with the reference frame `F` it is expressed in.
///
/// Entries are stored as `data[row][col]` with element type `T` (default `f64`).
///
/// `F` is only a type-level tag held in [`PhantomData`], so `Clone`, `Copy`, `Debug` and
/// `PartialEq` are implemented by hand with bounds on `T` alone. `#[derive]` would also
/// require `F` to implement each trait. That would make, for example, `==` or `{:?}`
/// unavailable for frame markers that do not derive those traits.
pub struct FrameMatrix3<F, T = f64> {
    data: [[T; 3]; 3],
    _frame: PhantomData<F>,
}

impl<F, T: Clone> Clone for FrameMatrix3<F, T> {
    #[inline]
    fn clone(&self) -> Self {
        Self {
            data: self.data.clone(),
            _frame: PhantomData,
        }
    }
}

impl<F, T: Copy> Copy for FrameMatrix3<F, T> {}

impl<F, T: core::fmt::Debug> core::fmt::Debug for FrameMatrix3<F, T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Same field layout as the former derived impl.
        f.debug_struct("FrameMatrix3")
            .field("data", &self.data)
            .field("_frame", &self._frame)
            .finish()
    }
}

impl<F, T: PartialEq> PartialEq for FrameMatrix3<F, T> {
    /// Two matrices in the same frame are equal when all nine entries are equal.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.data == other.data
    }
}
impl<F, T: Copy + Default> FrameMatrix3<F, T> {
#[inline]
pub fn from_array(data: [[T; 3]; 3]) -> Self {
Self {
data,
_frame: PhantomData,
}
}
#[inline]
pub fn zero() -> Self {
Self::from_array([[T::default(); 3]; 3])
}
#[inline]
pub fn as_array(&self) -> &[[T; 3]; 3] {
&self.data
}
pub fn transpose(&self) -> Self {
let mut out = [[T::default(); 3]; 3];
for (i, row) in out.iter_mut().enumerate() {
for (j, slot) in row.iter_mut().enumerate() {
*slot = self.data[j][i];
}
}
Self::from_array(out)
}
#[inline]
pub fn relabel<G>(self) -> FrameMatrix3<G, T> {
FrameMatrix3 {
data: self.data,
_frame: PhantomData,
}
}
}
impl<F> FrameMatrix3<F> {
    /// Returns the 3×3 identity matrix.
    pub fn identity() -> Self {
        Self::from_diagonal([1.0; 3])
    }

    /// Builds a matrix with `diag` on the main diagonal and `0.0` everywhere else.
    pub fn from_diagonal(diag: [f64; 3]) -> Self {
        let mut d = [[0.0_f64; 3]; 3];
        for (i, (row, &value)) in d.iter_mut().zip(diag.iter()).enumerate() {
            row[i] = value;
        }
        Self::from_array(d)
    }

    /// Returns the matrix product `self · rhs`.
    ///
    /// Zero entries of `self` are skipped. A zero coefficient therefore never multiplies a
    /// non-finite entry of `rhs`, which would otherwise contribute `NaN`.
    pub fn mat_mul(&self, rhs: &Self) -> Self {
        Self::from_array(Self::mul_raw(&self.data, &rhs.data))
    }

    /// Adds `other` to `self` element by element.
    pub fn add_in_place(&mut self, other: &FrameMatrix3<F>) {
        for (row, other_row) in self.data.iter_mut().zip(other.data.iter()) {
            for (v, &o) in row.iter_mut().zip(other_row.iter()) {
                *v += o;
            }
        }
    }

    /// Multiplies every element by `s`.
    pub fn scale_in_place(&mut self, s: f64) {
        for row in &mut self.data {
            for v in row.iter_mut() {
                *v *= s;
            }
        }
    }

    /// Adds the outer product `a · bᵀ`, i.e. `self[i][j] += a[i] * b[j]`.
    pub fn add_outer_product_in_place(&mut self, a: [f64; 3], b: [f64; 3]) {
        for (i, data_row) in self.data.iter_mut().enumerate() {
            for (j, data_elt) in data_row.iter_mut().enumerate() {
                *data_elt += a[i] * b[j];
            }
        }
    }

    /// Returns `self · m · selfᵀ`, tagged with frame `G`.
    ///
    /// `self` plays the role of the transform `R`, and `m` may be any (not necessarily
    /// symmetric) matrix. For an orthogonal `R` this is the change of basis of `m`.
    pub fn similarity_general<G>(&self, m: &FrameMatrix3<F>) -> FrameMatrix3<G> {
        FrameMatrix3::from_array(Self::sandwich(&self.data, &m.data))
    }

    /// Returns `self · m · selfᵀ` as a symmetric matrix tagged with frame `G`.
    ///
    /// The raw product `X` is replaced by `(X + Xᵀ) / 2`. This makes the stored result
    /// exactly symmetric even when rounding leaves `X` slightly asymmetric.
    pub fn similarity<G>(&self, m: &SymmetricFrameMatrix3<F>) -> SymmetricFrameMatrix3<G> {
        let raw = Self::sandwich(&self.data, &m.data);
        SymmetricFrameMatrix3 {
            data: Self::symmetrize(&raw),
            _frame: PhantomData,
        }
    }

    /// Returns `R · self · Rᵀ` tagged with frame `G`, where `R` is the matrix of `r`.
    pub fn rotated_by<G>(&self, r: &Rotation3) -> FrameMatrix3<G> {
        FrameMatrix3::from_array(Self::sandwich(r.as_matrix(), &self.data))
    }

    /// Raw product `a · b`; zero entries of `a` are skipped.
    fn mul_raw(a: &[[f64; 3]; 3], b: &[[f64; 3]; 3]) -> [[f64; 3]; 3] {
        let mut out = [[0.0_f64; 3]; 3];
        for (out_row, a_row) in out.iter_mut().zip(a.iter()) {
            // `a_row[k]` is paired with `b[k]`, so this accumulates `a[i][k] * b[k][j]`.
            for (&aik, b_row) in a_row.iter().zip(b.iter()) {
                if aik == 0.0 {
                    continue;
                }
                for (out_elt, &bkj) in out_row.iter_mut().zip(b_row.iter()) {
                    *out_elt += aik * bkj;
                }
            }
        }
        out
    }

    /// Raw product `a · bᵀ`; zero entries of `a` are skipped.
    fn mul_transpose_raw(a: &[[f64; 3]; 3], b: &[[f64; 3]; 3]) -> [[f64; 3]; 3] {
        let mut out = [[0.0_f64; 3]; 3];
        for (out_row, a_row) in out.iter_mut().zip(a.iter()) {
            for (k, &aik) in a_row.iter().enumerate() {
                if aik == 0.0 {
                    continue;
                }
                // Column j of `bᵀ` is row j of `b`, so this accumulates `a[i][k] * b[j][k]`.
                for (out_elt, b_row) in out_row.iter_mut().zip(b.iter()) {
                    *out_elt += aik * b_row[k];
                }
            }
        }
        out
    }

    /// `r · m · rᵀ`, evaluated as `(r · m) · rᵀ`.
    fn sandwich(r: &[[f64; 3]; 3], m: &[[f64; 3]; 3]) -> [[f64; 3]; 3] {
        Self::mul_transpose_raw(&Self::mul_raw(r, m), r)
    }

    /// Symmetric part `(x + xᵀ) / 2`.
    fn symmetrize(x: &[[f64; 3]; 3]) -> [[f64; 3]; 3] {
        let mut out = [[0.0_f64; 3]; 3];
        for (i, row) in out.iter_mut().enumerate() {
            for (j, slot) in row.iter_mut().enumerate() {
                *slot = 0.5 * (x[i][j] + x[j][i]);
            }
        }
        out
    }
}
impl<F> Mul<FrameMatrix3<F>> for FrameMatrix3<F> {
type Output = FrameMatrix3<F>;
#[inline]
fn mul(self, rhs: FrameMatrix3<F>) -> FrameMatrix3<F> {
self.mat_mul(&rhs)
}
}
impl<F> Mul<&FrameMatrix3<F>> for &FrameMatrix3<F> {
type Output = FrameMatrix3<F>;
#[inline]
fn mul(self, rhs: &FrameMatrix3<F>) -> FrameMatrix3<F> {
self.mat_mul(rhs)
}
}
/// A 3×3 matrix in frame `F` that is meant to hold symmetric data (`data[i][j] == data[j][i]`).
///
/// Entries are stored as `data[row][col]` with element type `T` (default `f64`).
///
/// `F` is only a type-level tag held in [`PhantomData`], so `Clone`, `Copy`, `Debug` and
/// `PartialEq` are implemented by hand with bounds on `T` alone. `#[derive]` would also
/// require `F` to implement each trait.
pub struct SymmetricFrameMatrix3<F, T = f64> {
    data: [[T; 3]; 3],
    _frame: PhantomData<F>,
}

impl<F, T: Clone> Clone for SymmetricFrameMatrix3<F, T> {
    #[inline]
    fn clone(&self) -> Self {
        Self {
            data: self.data.clone(),
            _frame: PhantomData,
        }
    }
}

impl<F, T: Copy> Copy for SymmetricFrameMatrix3<F, T> {}

impl<F, T: core::fmt::Debug> core::fmt::Debug for SymmetricFrameMatrix3<F, T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Same field layout as the former derived impl.
        f.debug_struct("SymmetricFrameMatrix3")
            .field("data", &self.data)
            .field("_frame", &self._frame)
            .finish()
    }
}

impl<F, T: PartialEq> PartialEq for SymmetricFrameMatrix3<F, T> {
    /// Two matrices in the same frame are equal when all nine entries are equal.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.data == other.data
    }
}
impl<F, T: Copy + Default> SymmetricFrameMatrix3<F, T> {
    /// Diagonal matrix with `diag` on the main diagonal and `T::default()` elsewhere.
    pub fn from_diagonal(diag: [T; 3]) -> Self {
        let o = T::default();
        let [a, b, c] = diag;
        SymmetricFrameMatrix3 {
            data: [[a, o, o], [o, b, o], [o, o, c]],
            _frame: PhantomData,
        }
    }

    /// Builds a symmetric matrix from the upper triangle (diagonal included) of `upper`.
    ///
    /// Entries strictly below the diagonal of `upper` are ignored and replaced by the
    /// mirrored upper-triangle values.
    pub fn from_upper(upper: [[T; 3]; 3]) -> Self {
        let u = &upper;
        SymmetricFrameMatrix3 {
            data: [
                [u[0][0], u[0][1], u[0][2]],
                [u[0][1], u[1][1], u[1][2]],
                [u[0][2], u[1][2], u[2][2]],
            ],
            _frame: PhantomData,
        }
    }

    /// Borrows the full (both triangles) `[row][col]` array.
    #[inline]
    pub fn as_array(&self) -> &[[T; 3]; 3] {
        let SymmetricFrameMatrix3 { data, .. } = self;
        data
    }

    /// Returns the main diagonal `[m00, m11, m22]`.
    #[inline]
    pub fn diagonal(&self) -> [T; 3] {
        let d = &self.data;
        [d[0][0], d[1][1], d[2][2]]
    }

    /// Returns the stored entries unchanged, since a symmetric matrix is its own transpose.
    #[inline]
    pub fn transpose(&self) -> Self {
        let data = self.data;
        SymmetricFrameMatrix3 {
            _frame: PhantomData,
            data,
        }
    }

    /// Re-tags the same entries with frame `G`; no numeric transformation is applied.
    #[inline]
    pub fn relabel<G>(self) -> SymmetricFrameMatrix3<G, T> {
        let SymmetricFrameMatrix3 { data, .. } = self;
        SymmetricFrameMatrix3 {
            data,
            _frame: PhantomData,
        }
    }
}
impl<F> SymmetricFrameMatrix3<F> {
    /// Returns the 3×3 identity matrix.
    pub fn identity() -> Self {
        Self::from_diagonal([1.0; 3])
    }

    /// Adds `other` to `self` element by element.
    pub fn add_in_place(&mut self, other: &SymmetricFrameMatrix3<F>) {
        self.data
            .iter_mut()
            .flatten()
            .zip(other.data.iter().flatten())
            .for_each(|(dst, &src)| *dst += src);
    }

    /// Multiplies every element by `s`.
    pub fn scale_in_place(&mut self, s: f64) {
        self.data.iter_mut().flatten().for_each(|v| *v *= s);
    }

    /// Adds the outer product `a · aᵀ` (`self[i][j] += a[i] * a[j]`), which keeps symmetry.
    pub fn add_outer_product_in_place(&mut self, a: [f64; 3]) {
        for (row, &ai) in self.data.iter_mut().zip(a.iter()) {
            for (slot, &aj) in row.iter_mut().zip(a.iter()) {
                *slot += ai * aj;
            }
        }
    }

    /// Returns `R · self · Rᵀ` tagged with frame `G`, where `R` is the matrix of `r`.
    ///
    /// Zero coefficients are skipped in both products. The result is stored as
    /// `(X + Xᵀ) / 2` so it is exactly symmetric.
    pub fn rotated_by<G>(&self, r: &Rotation3) -> SymmetricFrameMatrix3<G> {
        let rot = r.as_matrix();

        // left = R · M
        let mut left = [[0.0_f64; 3]; 3];
        for (left_row, rot_row) in left.iter_mut().zip(rot.iter()) {
            for (&coef, m_row) in rot_row.iter().zip(self.data.iter()) {
                if coef == 0.0 {
                    continue;
                }
                for (dst, &mkj) in left_row.iter_mut().zip(m_row.iter()) {
                    *dst += coef * mkj;
                }
            }
        }

        // full = left · Rᵀ
        let mut full = [[0.0_f64; 3]; 3];
        for (full_row, left_row) in full.iter_mut().zip(left.iter()) {
            for (k, &coef) in left_row.iter().enumerate() {
                if coef == 0.0 {
                    continue;
                }
                for (dst, rot_row) in full_row.iter_mut().zip(rot.iter()) {
                    *dst += coef * rot_row[k];
                }
            }
        }

        let sym = |i: usize, j: usize| 0.5 * (full[i][j] + full[j][i]);
        SymmetricFrameMatrix3 {
            data: [
                [sym(0, 0), sym(0, 1), sym(0, 2)],
                [sym(1, 0), sym(1, 1), sym(1, 2)],
                [sym(2, 0), sym(2, 1), sym(2, 2)],
            ],
            _frame: PhantomData,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Copy, Clone)]
    struct F1;
    impl crate::frames::ReferenceFrame for F1 {
        fn frame_name() -> &'static str {
            "F1"
        }
    }

    #[derive(Debug, Copy, Clone)]
    struct F2;
    impl crate::frames::ReferenceFrame for F2 {
        fn frame_name() -> &'static str {
            "F2"
        }
    }

    #[test]
    fn frame_matrix3_round_trip() {
        let data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let m = FrameMatrix3::<F1>::from_array(data);
        assert_eq!(m.as_array(), &data);
    }

    #[test]
    fn frame_matrix3_zero() {
        let m = FrameMatrix3::<F1>::zero();
        for row in m.as_array() {
            for v in row {
                assert_eq!(*v, 0.0);
            }
        }
    }

    #[test]
    fn frame_matrix3_identity() {
        let m = FrameMatrix3::<F1>::identity();
        for (i, row) in m.as_array().iter().enumerate() {
            for (j, value) in row.iter().enumerate() {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_eq!(*value, expected);
            }
        }
    }

    #[test]
    fn frame_matrix3_transpose() {
        let data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let m = FrameMatrix3::<F1>::from_array(data);
        let t = m.transpose();
        for (i, row) in t.as_array().iter().enumerate() {
            for (j, value) in row.iter().enumerate() {
                assert_eq!(*value, data[j][i]);
            }
        }
    }

    #[test]
    fn frame_matrix3_relabel() {
        let data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let m1 = FrameMatrix3::<F1>::from_array(data);
        let m2: FrameMatrix3<F2> = m1.relabel();
        assert_eq!(m2.as_array(), &data);
    }

    #[test]
    fn frame_matrix3_rotated_by_identity_is_noop() {
        let data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let m = FrameMatrix3::<F1>::from_array(data);
        let rotated: FrameMatrix3<F2> = m.rotated_by(&Rotation3::IDENTITY);
        for (i, row) in rotated.as_array().iter().enumerate() {
            for (j, value) in row.iter().enumerate() {
                assert!((*value - data[i][j]).abs() < 1e-14);
            }
        }
    }

    #[test]
    fn frame_matrix3_generic_element_type() {
        let m = FrameMatrix3::<F1, i32>::from_array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]);
        assert_eq!(m.as_array()[1][1], 1_i32);
        assert_eq!(m.transpose().as_array()[0][0], 1_i32);
        let z = FrameMatrix3::<F1, u8>::zero();
        assert_eq!(z.as_array()[2][2], 0_u8);
    }

    #[test]
    fn symmetric_from_diagonal_round_trip() {
        let m = SymmetricFrameMatrix3::<F1>::from_diagonal([1.0, 2.0, 3.0]);
        assert_eq!(m.diagonal(), [1.0, 2.0, 3.0]);
        for (i, row) in m.as_array().iter().enumerate() {
            for (j, value) in row.iter().enumerate() {
                if i != j {
                    assert_eq!(*value, 0.0);
                }
            }
        }
    }

    #[test]
    fn symmetric_from_upper_mirrors_correctly() {
        let upper = [[1.0, 2.0, 3.0], [99.0, 4.0, 5.0], [99.0, 99.0, 6.0]];
        let m = SymmetricFrameMatrix3::<F1>::from_upper(upper);
        let a = m.as_array();
        assert_eq!(a[0][0], 1.0);
        assert_eq!(a[1][1], 4.0);
        assert_eq!(a[2][2], 6.0);
        assert_eq!(a[0][1], 2.0);
        assert_eq!(a[1][0], 2.0);
        assert_eq!(a[0][2], 3.0);
        assert_eq!(a[2][0], 3.0);
        assert_eq!(a[1][2], 5.0);
        assert_eq!(a[2][1], 5.0);
    }

    #[test]
    fn symmetric_transpose_is_self() {
        let m = SymmetricFrameMatrix3::<F1>::from_diagonal([1.0, 2.0, 3.0]);
        assert_eq!(m.transpose().as_array(), m.as_array());
    }

    #[test]
    fn symmetric_relabel() {
        let m1 = SymmetricFrameMatrix3::<F1>::from_diagonal([1.0, 2.0, 3.0]);
        let m2: SymmetricFrameMatrix3<F2> = m1.relabel();
        assert_eq!(m2.diagonal(), [1.0, 2.0, 3.0]);
    }

    #[test]
    fn symmetric_rotated_by_identity_is_noop() {
        let m = SymmetricFrameMatrix3::<F1>::from_diagonal([1.0, 4.0, 9.0]);
        let rotated: SymmetricFrameMatrix3<F2> = m.rotated_by(&Rotation3::IDENTITY);
        assert!((rotated.as_array()[0][0] - 1.0).abs() < 1e-14);
        assert!((rotated.as_array()[1][1] - 4.0).abs() < 1e-14);
        assert!((rotated.as_array()[2][2] - 9.0).abs() < 1e-14);
    }

    #[test]
    fn symmetric_rotated_by_preserves_symmetry() {
        use qtty::angular::Radians;
        let upper = [[4.0, 1.0, 0.0], [99.0, 9.0, 0.0], [99.0, 99.0, 1.0]];
        let m = SymmetricFrameMatrix3::<F1>::from_upper(upper);
        let r = Rotation3::rz(Radians::new(std::f64::consts::FRAC_PI_4));
        let rotated: SymmetricFrameMatrix3<F2> = m.rotated_by(&r);
        let a = rotated.as_array();
        for (i, row) in a.iter().enumerate() {
            for (j, value) in row.iter().enumerate() {
                assert!(
                    (*value - a[j][i]).abs() < 1e-13,
                    "a[{i}][{j}] != a[{j}][{i}]"
                );
            }
        }
        let trace_in = 4.0 + 9.0 + 1.0;
        let trace_out = a[0][0] + a[1][1] + a[2][2];
        assert!(
            (trace_out - trace_in).abs() < 1e-12,
            "trace changed: {trace_out} != {trace_in}"
        );
    }

    #[test]
    fn frame_matrix3_from_diagonal() {
        let m = FrameMatrix3::<F1>::from_diagonal([2.0, 3.0, 5.0]);
        let a = m.as_array();
        assert_eq!(a[0][0], 2.0);
        assert_eq!(a[1][1], 3.0);
        assert_eq!(a[2][2], 5.0);
        assert_eq!(a[0][1], 0.0);
        assert_eq!(a[1][2], 0.0);
    }

    #[test]
    fn frame_matrix3_mat_mul_identity() {
        let data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let a = FrameMatrix3::<F1>::from_array(data);
        let eye = FrameMatrix3::<F1>::identity();
        let result = a.mat_mul(&eye);
        for (i, row) in result.as_array().iter().enumerate() {
            for (j, v) in row.iter().enumerate() {
                assert!(
                    (*v - data[i][j]).abs() < 1e-14,
                    "A*I mismatch at [{i}][{j}]: {v} != {}",
                    data[i][j]
                );
            }
        }
    }

    #[test]
    fn frame_matrix3_mat_mul_known() {
        let a =
            FrameMatrix3::<F1>::from_array([[1.0, 2.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 1.0]]);
        let b =
            FrameMatrix3::<F1>::from_array([[5.0, 6.0, 0.0], [7.0, 8.0, 0.0], [0.0, 0.0, 1.0]]);
        let c = a.mat_mul(&b);
        assert!((c.as_array()[0][0] - 19.0).abs() < 1e-14);
        assert!((c.as_array()[0][1] - 22.0).abs() < 1e-14);
        assert!((c.as_array()[1][0] - 43.0).abs() < 1e-14);
        assert!((c.as_array()[1][1] - 50.0).abs() < 1e-14);
        assert!((c.as_array()[2][2] - 1.0).abs() < 1e-14);
    }

    #[test]
    fn frame_matrix3_mul_operator() {
        let eye = FrameMatrix3::<F1>::identity();
        let result = eye * eye;
        assert_eq!(result.as_array(), eye.as_array());
    }

    #[test]
    fn frame_matrix3_ref_mul_matches_mat_mul() {
        let a =
            FrameMatrix3::<F1>::from_array([[1.0, 2.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 1.0]]);
        let b =
            FrameMatrix3::<F1>::from_array([[5.0, 6.0, 0.0], [7.0, 8.0, 0.0], [0.0, 0.0, 1.0]]);
        let by_ref = &a * &b;
        assert_eq!(by_ref.as_array(), a.mat_mul(&b).as_array());
    }

    #[test]
    fn frame_matrix3_similarity_general_identity_rotation() {
        let data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let m = FrameMatrix3::<F1>::from_array(data);
        let eye = FrameMatrix3::<F1>::identity();
        let result: FrameMatrix3<F2> = eye.similarity_general(&m);
        for (i, row) in result.as_array().iter().enumerate() {
            for (j, v) in row.iter().enumerate() {
                assert!(
                    (*v - data[i][j]).abs() < 1e-14,
                    "sim_general identity failed at [{i}][{j}]"
                );
            }
        }
    }

    #[test]
    fn frame_matrix3_similarity_general_quarter_turn() {
        // Exact +90° turn about z, so every intermediate value is an exact integer.
        let r =
            FrameMatrix3::<F1>::from_array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]);
        let m =
            FrameMatrix3::<F1>::from_array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
        let result: FrameMatrix3<F2> = r.similarity_general(&m);
        assert_eq!(
            result.as_array(),
            &[[5.0, -4.0, -6.0], [-2.0, 1.0, 3.0], [-8.0, 7.0, 9.0]]
        );
    }

    #[test]
    fn frame_matrix3_similarity_quarter_turn_symmetric() {
        let r =
            FrameMatrix3::<F1>::from_array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]);
        let s = SymmetricFrameMatrix3::<F1>::from_upper([
            [1.0, 2.0, 0.0],
            [99.0, 3.0, 0.0],
            [99.0, 99.0, 4.0],
        ]);
        let result: SymmetricFrameMatrix3<F2> = r.similarity(&s);
        assert_eq!(
            result.as_array(),
            &[[3.0, -2.0, 0.0], [-2.0, 1.0, 0.0], [0.0, 0.0, 4.0]]
        );
    }

    #[test]
    fn frame_matrix3_rotated_by_matches_similarity_general() {
        use qtty::angular::Radians;
        let data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let m = FrameMatrix3::<F1>::from_array(data);
        let r = Rotation3::rz(Radians::new(0.3));
        let r_mat = FrameMatrix3::<F1>::from_array(*r.as_matrix());
        let via_rotation: FrameMatrix3<F2> = m.rotated_by(&r);
        let via_similarity: FrameMatrix3<F2> = r_mat.similarity_general(&m);
        for i in 0..3 {
            for j in 0..3 {
                let lhs = via_rotation.as_array()[i][j];
                let rhs = via_similarity.as_array()[i][j];
                assert!(
                    (lhs - rhs).abs() < 1e-14,
                    "rotated_by vs similarity_general at [{i}][{j}]: {lhs} != {rhs}"
                );
            }
        }
    }

    #[test]
    fn frame_matrix3_similarity_round_trip() {
        use qtty::angular::Radians;
        let m = SymmetricFrameMatrix3::<F1>::from_diagonal([1.0, 4.0, 9.0]);
        let r45 = Rotation3::rz(Radians::new(std::f64::consts::FRAC_PI_4));
        let r45_mat = FrameMatrix3::<F1>::from_array(*r45.as_matrix());
        let rotated: SymmetricFrameMatrix3<F2> = r45_mat.similarity(&m);
        let r45_inv = r45.inverse();
        let r45_inv_mat = FrameMatrix3::<F2>::from_array(*r45_inv.as_matrix());
        let back: SymmetricFrameMatrix3<F1> = r45_inv_mat.similarity(&rotated);
        let orig = m.as_array();
        let result = back.as_array();
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (result[i][j] - orig[i][j]).abs() < 1e-12,
                    "round-trip similarity failed at [{i}][{j}]: {} != {}",
                    result[i][j],
                    orig[i][j]
                );
            }
        }
    }

    #[test]
    fn frame_matrix3_in_place_ops() {
        let mut m = FrameMatrix3::<F1>::from_diagonal([1.0, 2.0, 3.0]);
        let other = FrameMatrix3::<F1>::from_diagonal([1.0, 1.0, 1.0]);
        m.add_in_place(&other);
        assert_eq!(m.as_array()[0][0], 2.0);
        assert_eq!(m.as_array()[1][1], 3.0);
        assert_eq!(m.as_array()[2][2], 4.0);
        m.scale_in_place(2.0);
        assert_eq!(m.as_array()[0][0], 4.0);
        m.add_outer_product_in_place([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]);
        assert_eq!(m.as_array()[0][1], 1.0);
    }

    #[test]
    fn symmetric_in_place_ops() {
        let mut m = SymmetricFrameMatrix3::<F1>::from_diagonal([1.0, 2.0, 3.0]);
        let other = SymmetricFrameMatrix3::<F1>::from_diagonal([1.0, 1.0, 1.0]);
        m.add_in_place(&other);
        assert_eq!(m.diagonal(), [2.0, 3.0, 4.0]);
        m.scale_in_place(0.5);
        assert_eq!(m.diagonal(), [1.0, 1.5, 2.0]);
        m.add_outer_product_in_place([1.0, 0.0, 0.0]);
        assert_eq!(m.as_array()[0][0], 2.0);
        assert_eq!(m.as_array()[0][1], 0.0);
    }

    #[test]
    fn symmetric_outer_product_fills_both_triangles() {
        let mut m = SymmetricFrameMatrix3::<F1>::from_diagonal([0.0, 0.0, 0.0]);
        m.add_outer_product_in_place([1.0, 2.0, 0.0]);
        let a = m.as_array();
        assert_eq!(a[0][1], 2.0);
        assert_eq!(a[1][0], 2.0);
        assert_eq!(m.diagonal(), [1.0, 4.0, 0.0]);
    }
}