1use super::{Fixed, Point};
4use core::ops::{Add, Mul, MulAssign};
5
/// A 2D affine transformation matrix with coefficients of type `T`.
///
/// The six coefficients describe the mapping performed by `transform`:
///
/// ```text
/// x' = xx * x + xy * y + dx
/// y' = yx * x + yy * y + dy
/// ```
///
/// The field declaration order matches the array order used by
/// `from_elements` / `elements` (`[xx, yx, xy, yy, dx, dy]`). Because of
/// `#[repr(C)]` it also determines the in-memory layout, so reordering the
/// fields would change both.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "bytemuck", derive(bytemuck::AnyBitPattern))]
#[repr(C)]
pub struct Matrix<T> {
    /// Coefficient multiplying the input `x` in the output `x'`.
    pub xx: T,
    /// Coefficient multiplying the input `x` in the output `y'`.
    pub yx: T,
    /// Coefficient multiplying the input `y` in the output `x'`.
    pub xy: T,
    /// Coefficient multiplying the input `y` in the output `y'`.
    pub yy: T,
    /// Constant added to the output `x'`.
    pub dx: T,
    /// Constant added to the output `y'`.
    pub dy: T,
}
46
47impl<T: MatrixElement> Default for Matrix<T> {
48 fn default() -> Self {
49 Self::IDENTITY
50 }
51}
52
53impl<T: Copy> Matrix<T> {
54 pub const fn from_elements(elements: [T; 6]) -> Self {
57 Self {
58 xx: elements[0],
59 yx: elements[1],
60 xy: elements[2],
61 yy: elements[3],
62 dx: elements[4],
63 dy: elements[5],
64 }
65 }
66
67 pub const fn elements(&self) -> [T; 6] {
70 [self.xx, self.yx, self.xy, self.yy, self.dx, self.dy]
71 }
72
73 #[inline(always)]
75 pub fn map<U: Copy>(self, f: impl FnMut(T) -> U) -> Matrix<U> {
76 Matrix::from_elements(self.elements().map(f))
77 }
78}
79
80impl<T: MatrixElement> Matrix<T> {
81 pub const IDENTITY: Self = Self {
83 xx: T::ONE,
84 yx: T::ZERO,
85 xy: T::ZERO,
86 yy: T::ONE,
87 dx: T::ZERO,
88 dy: T::ZERO,
89 };
90
91 pub fn transform(&self, x: T, y: T) -> (T, T) {
93 (
94 self.xx * x + self.xy * y + self.dx,
95 self.yx * x + self.yy * y + self.dy,
96 )
97 }
98}
99
100impl<T: MatrixElement> Mul for Matrix<T> {
101 type Output = Self;
102
103 fn mul(self, rhs: Self) -> Self::Output {
104 Self {
105 xx: self.xx * rhs.xx + self.xy * rhs.yx,
106 yx: self.yx * rhs.xx + self.yy * rhs.yx,
107 xy: self.xx * rhs.xy + self.xy * rhs.yy,
108 yy: self.yx * rhs.xy + self.yy * rhs.yy,
109 dx: self.xx * rhs.dx + self.xy * rhs.dy + self.dx,
110 dy: self.yx * rhs.dx + self.yy * rhs.dy + self.dy,
111 }
112 }
113}
114
115impl<T: MatrixElement> MulAssign for Matrix<T> {
116 fn mul_assign(&mut self, rhs: Self) {
117 *self = *self * rhs;
118 }
119}
120
121impl<T: MatrixElement> Mul<Point<T>> for Matrix<T> {
122 type Output = Point<T>;
123
124 fn mul(self, rhs: Point<T>) -> Self::Output {
125 let (x, y) = self.transform(rhs.x, rhs.y);
126 Point::new(x, y)
127 }
128}
129
/// A numeric type usable as a [`Matrix`] coefficient.
///
/// Implementors must be `Copy`, closed under `+` and `*`, and provide the
/// constants used to build [`Matrix::IDENTITY`].
pub trait MatrixElement
where
    Self: Copy + Add<Output = Self> + Mul<Output = Self>,
{
    /// The additive identity.
    const ZERO: Self;
    /// The multiplicative identity.
    const ONE: Self;
}
135
136impl MatrixElement for Fixed {
137 const ONE: Self = Fixed::ONE;
138 const ZERO: Self = Fixed::ZERO;
139}
140
141impl MatrixElement for f32 {
142 const ONE: Self = 1.0;
143 const ZERO: Self = 0.0;
144}
145
146impl MatrixElement for f64 {
147 const ONE: Self = 1.0;
148 const ZERO: Self = 0.0;
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mul_matrix_identity_and_known_product() {
        let a = Matrix::from_elements([0.5f32, 1.0, -2.0, 0.25, 7.0, -3.0]);
        assert_eq!(Matrix::IDENTITY * a, a);
        assert_eq!(a * Matrix::IDENTITY, a);
        let translate = Matrix::from_elements([1.0, 0.0, 0.0, 1.0, 10.0, 20.0]);
        let scale = Matrix::from_elements([2.0, 0.0, 0.0, 3.0, 0.0, 0.0]);
        assert_eq!(
            (translate * scale).elements(),
            [2.0, 0.0, 0.0, 3.0, 10.0, 20.0]
        );
        assert_eq!(
            (scale * translate).elements(),
            [2.0, 0.0, 0.0, 3.0, 20.0, 60.0]
        );
    }

    #[test]
    fn transform_points() {
        let translate = Matrix::from_elements([1.0f32, 0.0, 0.0, 1.0, 10.0, 20.0]);
        let scale = Matrix::from_elements([2.0, 0.0, 0.0, 3.0, 0.0, 0.0]);
        let translate_scale = translate * scale;
        let scale_translate = scale * translate;
        let p = Point::new(5.0, -22.0);
        assert_eq!(translate * p, Point::new(15.0, -2.0));
        assert_eq!(scale * p, Point::new(10.0, -66.0));
        assert_eq!(translate_scale * p, Point::new(20.0, -46.0));
        assert_eq!(scale_translate * p, Point::new(30.0, -6.0));
    }

    #[test]
    fn default_is_identity() {
        assert_eq!(Matrix::<f32>::default(), Matrix::IDENTITY);
        assert_eq!(
            Matrix::<f64>::default().elements(),
            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]
        );
    }

    #[test]
    fn elements_round_trip_and_field_order() {
        let raw = [0.5f32, 1.0, -2.0, 0.25, 7.0, -3.0];
        let m = Matrix::from_elements(raw);
        assert_eq!(m.elements(), raw);
        // Pins the documented `[xx, yx, xy, yy, dx, dy]` ordering.
        assert_eq!(
            (m.xx, m.yx, m.xy, m.yy, m.dx, m.dy),
            (0.5, 1.0, -2.0, 0.25, 7.0, -3.0)
        );
    }

    #[test]
    fn mul_assign_matches_mul() {
        let translate = Matrix::from_elements([1.0f32, 0.0, 0.0, 1.0, 10.0, 20.0]);
        let scale = Matrix::from_elements([2.0, 0.0, 0.0, 3.0, 0.0, 0.0]);
        let mut m = translate;
        m *= scale;
        assert_eq!(m, translate * scale);
        // `*=` must keep the operand order of `*`, not commute it.
        assert_ne!(m, scale * translate);
    }

    #[test]
    fn map_converts_each_element() {
        let m = Matrix::from_elements([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let doubled = m.map(|v| f64::from(v) * 2.0);
        assert_eq!(doubled.elements(), [2.0f64, 4.0, 6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn off_diagonal_terms_transform_and_compose() {
        // Exercises the `xy` / `yx` coefficients, which the diagonal-only
        // matrices above never touch.
        let rotate = Matrix::from_elements([0.0f32, 1.0, -1.0, 0.0, 0.0, 0.0]);
        assert_eq!(rotate.transform(2.0, 3.0), (-3.0, 2.0));
        let twice = rotate * rotate;
        assert_eq!(twice.elements(), [-1.0, 0.0, 0.0, -1.0, 0.0, 0.0]);
        assert_eq!(twice * Point::new(2.0, 3.0), Point::new(-2.0, -3.0));
    }
}