1use crate::CoordinateType;
9
10use crate::vector::Vector;
11
12#[derive(Clone, Hash, PartialEq, Eq, Debug)]
18#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
19pub struct Matrix2d<T: CoordinateType> {
20 pub(crate) m11: T,
22 pub(crate) m12: T,
24 pub(crate) m21: T,
26 pub(crate) m22: T,
28}
29
30impl<T> Matrix2d<T>
31where
32 T: CoordinateType,
33{
34 pub fn new(m11: T, m12: T, m21: T, m22: T) -> Self {
40 Matrix2d { m11, m12, m21, m22 }
41 }
42
43 pub fn identity() -> Self {
45 Self::new(T::one(), T::zero(), T::zero(), T::one())
46 }
47
48 pub fn mul_scalar(&self, rhs: T) -> Self {
50 Matrix2d::new(
51 self.m11 * rhs,
52 self.m12 * rhs,
53 self.m21 * rhs,
54 self.m22 * rhs,
55 )
56 }
57
58 pub fn mul_column_vector(&self, rhs: Vector<T>) -> Vector<T> {
61 Vector::new(
62 rhs.x * self.m11 + rhs.y * self.m12,
63 rhs.x * self.m21 + rhs.y * self.m22,
64 )
65 }
66
67 pub fn mul_matrix(&self, rhs: &Self) -> Self {
69 let a = self;
70 let b = rhs;
71 let c11 = a.m11 * b.m11 + a.m12 * b.m21;
72 let c12 = a.m11 * b.m12 + a.m12 * b.m22;
73 let c21 = a.m21 * b.m11 + a.m22 * b.m21;
74 let c22 = a.m21 * b.m12 + a.m22 * b.m22;
75 Self::new(c11, c12, c21, c22)
76 }
77
78 pub fn transpose(&self) -> Self {
80 Self::new(self.m11, self.m21, self.m12, self.m22)
81 }
82
83 pub fn determinant(&self) -> T {
85 self.m11 * self.m22 - self.m12 * self.m21
86 }
87
88 pub fn is_identity(&self) -> bool {
90 self == &Self::identity()
91 }
92
93 pub fn is_unitary(&self) -> bool {
95 self.mul_matrix(&self.transpose()).is_identity()
96 }
97
98 pub fn try_inverse(&self) -> Option<Self> {
102 let det = self.determinant();
104 if !det.is_zero() {
105 let z = T::zero();
106 Some(Self::new(
107 self.m22 / det,
108 z - self.m12 / det,
109 z - self.m21 / det,
110 self.m11 / det,
111 ))
112 } else {
113 None
114 }
115 }
116}
117
118impl<T: CoordinateType> Default for Matrix2d<T> {
119 fn default() -> Self {
120 Self::identity()
121 }
122}
123
#[test]
fn test_matrix_multiplication() {
    let lhs = Matrix2d::new(1.0, 2.0, 3.0, 4.0);
    let rhs = Matrix2d::new(5.0, 6.0, 7.0, 8.0);
    let identity = Matrix2d::identity();

    // Multiplying by the identity on either side leaves a matrix unchanged.
    assert_eq!(identity.mul_matrix(&identity), identity);
    assert_eq!(rhs.mul_matrix(&identity), rhs);
    assert_eq!(identity.mul_matrix(&rhs), rhs);

    // [1 2]   [5 6]   [1*5 + 2*7  1*6 + 2*8]   [19 22]
    // [3 4] · [7 8] = [3*5 + 4*7  3*6 + 4*8] = [43 50]
    let expected = Matrix2d::new(19.0, 22.0, 43.0, 50.0);
    assert_eq!(lhs.mul_matrix(&rhs), expected);
}
137
#[test]
fn test_inverse() {
    // det = 2*8 - 1*4 = 12, so the matrix is invertible.
    let matrix = Matrix2d::new(2.0, 1.0, 4.0, 8.0);
    let inverse = matrix
        .try_inverse()
        .expect("determinant is 12, so an inverse must exist");

    // The inverse must cancel the matrix from both sides.
    let identity = Matrix2d::identity();
    assert_eq!(matrix.mul_matrix(&inverse), identity);
    assert_eq!(inverse.mul_matrix(&matrix), identity);
}