1use crate::Point;
11
/// A 2-D affine transform held as the six free coefficients of the 3×3 matrix
///
/// ```text
/// | a  b  0 |
/// | c  d  0 |
/// | h  v  1 |
/// ```
///
/// Points are treated as row vectors `[x y 1]` multiplied on the left, so a
/// point `(x, y)` maps to `(a*x + c*y + h, b*x + d*y + v)`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Matrix2D {
    /// Weight of the input `x` in the output `x`.
    pub a: f64,
    /// Weight of the input `x` in the output `y`.
    pub b: f64,
    /// Weight of the input `y` in the output `x`.
    pub c: f64,
    /// Weight of the input `y` in the output `y`.
    pub d: f64,
    /// Translation added to the output `x`.
    pub h: f64,
    /// Translation added to the output `y`.
    pub v: f64,
}
22
23impl Matrix2D {
24 pub fn new(a: f64, b: f64, c: f64, d: f64, h: f64, v: f64) -> Self {
26 Self { a, b, c, d, h, v }
27 }
28
29 pub fn identity() -> Self {
31 Self::new(1.0, 0.0, 0.0, 1.0, 0.0, 0.0)
32 }
33
34 pub fn translation(h: f64, v: f64) -> Self {
36 Self::new(1.0, 0.0, 0.0, 1.0, h, v)
37 }
38
39 pub fn scale(sx: f64, sy: f64) -> Self {
41 Self::new(sx, 0.0, 0.0, sy, 0.0, 0.0)
42 }
43
44 pub fn rotation(angle: f64) -> Self {
46 let cos = angle.cos();
47 let sin = angle.sin();
48 Self::new(cos, sin, -sin, cos, 0.0, 0.0)
49 }
50
51 pub fn multiply(&self, other: &Matrix2D) -> Matrix2D {
54 Matrix2D {
55 a: self.a * other.a + self.b * other.c,
56 b: self.a * other.b + self.b * other.d,
57 c: self.c * other.a + self.d * other.c,
58 d: self.c * other.b + self.d * other.d,
59 h: self.h * other.a + self.v * other.c + other.h,
60 v: self.h * other.b + self.v * other.d + other.v,
61 }
62 }
63
64 pub fn concat(&self, other: &Matrix2D) -> Matrix2D {
66 other.multiply(self)
67 }
68
69 pub fn determinant(&self) -> f64 {
71 self.a * self.d - self.b * self.c
72 }
73
74 pub fn inverse(&self) -> Option<Matrix2D> {
76 let det = self.determinant();
77 if det.abs() < 1e-14 {
78 return None;
79 }
80 let inv_det = 1.0 / det;
81 Some(Matrix2D {
82 a: self.d * inv_det,
83 b: -self.b * inv_det,
84 c: -self.c * inv_det,
85 d: self.a * inv_det,
86 h: (self.c * self.v - self.d * self.h) * inv_det,
87 v: (self.b * self.h - self.a * self.v) * inv_det,
88 })
89 }
90
91 pub fn transform_point(&self, x: f64, y: f64) -> Point {
93 Point {
94 x: self.a * x + self.c * y + self.h,
95 y: self.b * x + self.d * y + self.v,
96 }
97 }
98}
99
100impl Default for Matrix2D {
101 fn default() -> Self {
102 Self::identity()
103 }
104}
105
106impl std::ops::Mul for Matrix2D {
107 type Output = Matrix2D;
108 fn mul(self, rhs: Matrix2D) -> Matrix2D {
109 self.multiply(&rhs)
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Float comparison with an absolute tolerance; several tests go through
    /// trigonometry or division, so exact equality would be brittle.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-10
    }

    /// Asserts that each coefficient of `actual` approximately equals the
    /// matching coefficient of `expected`, naming the first one that differs.
    fn assert_matrix_approx_eq(actual: &Matrix2D, expected: &Matrix2D) {
        let pairs = [
            ("a", actual.a, expected.a),
            ("b", actual.b, expected.b),
            ("c", actual.c, expected.c),
            ("d", actual.d, expected.d),
            ("h", actual.h, expected.h),
            ("v", actual.v, expected.v),
        ];
        for &(name, got, want) in pairs.iter() {
            assert!(
                approx_eq(got, want),
                "coefficient {}: got {}, expected {}",
                name,
                got,
                want
            );
        }
    }

    #[test]
    fn test_identity() {
        let m = Matrix2D::identity();
        let p = m.transform_point(3.0, 4.0);
        assert!(approx_eq(p.x, 3.0));
        assert!(approx_eq(p.y, 4.0));
    }

    #[test]
    fn test_translation() {
        let m = Matrix2D::translation(10.0, 20.0);
        let p = m.transform_point(5.0, 5.0);
        assert!(approx_eq(p.x, 15.0));
        assert!(approx_eq(p.y, 25.0));
    }

    #[test]
    fn test_scale() {
        let m = Matrix2D::scale(2.0, 3.0);
        let p = m.transform_point(5.0, 5.0);
        assert!(approx_eq(p.x, 10.0));
        assert!(approx_eq(p.y, 15.0));
    }

    #[test]
    fn test_inverse() {
        let m = Matrix2D::new(2.0, 1.0, 1.0, 3.0, 5.0, 7.0);
        let inv = m
            .inverse()
            .expect("matrix with determinant 5 must be invertible");
        // A true inverse cancels `m` from both sides, not just one.
        assert_matrix_approx_eq(&m.multiply(&inv), &Matrix2D::identity());
        assert_matrix_approx_eq(&inv.multiply(&m), &Matrix2D::identity());
    }

    #[test]
    fn test_singular_matrix() {
        let m = Matrix2D::new(1.0, 2.0, 2.0, 4.0, 0.0, 0.0);
        assert!(m.inverse().is_none());
    }

    #[test]
    fn test_rotation() {
        let m = Matrix2D::rotation(std::f64::consts::FRAC_PI_2);
        let p = m.transform_point(1.0, 0.0);
        assert!(approx_eq(p.x, 0.0));
        assert!(approx_eq(p.y, 1.0));
    }

    #[test]
    fn test_multiply_associativity() {
        let a = Matrix2D::translation(1.0, 2.0);
        let b = Matrix2D::scale(2.0, 3.0);
        let c = Matrix2D::rotation(0.5);
        assert_matrix_approx_eq(&((a * b) * c), &(a * (b * c)));
    }

    #[test]
    fn test_multiply_applies_self_first() {
        // Translate by (10, 0), then scale by 2: (1, 1) -> (11, 1) -> (22, 2).
        let m = Matrix2D::translation(10.0, 0.0).multiply(&Matrix2D::scale(2.0, 2.0));
        let p = m.transform_point(1.0, 1.0);
        assert!(approx_eq(p.x, 22.0));
        assert!(approx_eq(p.y, 2.0));
    }

    #[test]
    fn test_concat_applies_other_first() {
        let t = Matrix2D::translation(10.0, 0.0);
        let s = Matrix2D::scale(2.0, 2.0);
        // Scale by 2, then translate by (10, 0): (1, 1) -> (2, 2) -> (12, 2).
        let p = t.concat(&s).transform_point(1.0, 1.0);
        assert!(approx_eq(p.x, 12.0));
        assert!(approx_eq(p.y, 2.0));
        assert_eq!(t.concat(&s), s.multiply(&t));
    }

    #[test]
    fn test_mul_operator_matches_multiply() {
        let a = Matrix2D::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
        let b = Matrix2D::rotation(0.25);
        assert_eq!(a * b, a.multiply(&b));
    }

    #[test]
    fn test_default_is_identity() {
        assert_eq!(Matrix2D::default(), Matrix2D::identity());
    }

    #[test]
    fn test_determinant() {
        assert!(approx_eq(Matrix2D::scale(2.0, 3.0).determinant(), 6.0));
        // The translation part never contributes to the determinant.
        assert!(approx_eq(
            Matrix2D::translation(4.0, -9.0).determinant(),
            1.0
        ));
    }
}