rs_math3d/
matrix.rs

1// Copyright 2020-Present (c) Raja Lehtihet & Wael El Oraiby
2//
3// Redistribution and use in source and binary forms, with or without
4// modification, are permitted provided that the following conditions are met:
5//
6// 1. Redistributions of source code must retain the above copyright notice,
7// this list of conditions and the following disclaimer.
8//
9// 2. Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// 3. Neither the name of the copyright holder nor the names of its contributors
14// may be used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//! Matrix mathematics module providing 2x2, 3x3, and 4x4 matrices.
29//!
30//! This module provides square matrix types commonly used in computer graphics
31//! and linear algebra. Matrices are stored in column-major order for compatibility
32//! with graphics APIs like OpenGL.
33//!
34//! # Examples
35//!
36//! ```
37//! use rs_math3d::matrix::Matrix4;
38//! use rs_math3d::vector::Vector4;
39//! 
40//! let m = Matrix4::<f32>::identity();
41//! let v = Vector4::new(1.0, 2.0, 3.0, 1.0);
42//! let result = m * v; // Transform vector
43//! ```
44
45use crate::scalar::*;
46use crate::vector::*;
47use num_traits::{Zero, One};
48use core::ops::*;
49
50/// A 2x2 matrix stored in column-major order.
51///
52/// # Layout
53/// ```text
54/// [m₀₀ m₀₁]
55/// [m₁₀ m₁₁]
56/// ```
57/// where `col[j][i]` represents element at row i, column j.
58#[repr(C)]
59#[derive(Clone, Copy, Debug)]
60pub struct Matrix2<T: Scalar> {
61    /// Column vectors of the matrix
62    pub col: [Vector2<T>; 2],
63}
64
65/// A 3x3 matrix stored in column-major order.
66///
67/// Commonly used for 2D transformations (with homogeneous coordinates)
68/// and 3D rotations.
69///
70/// # Layout
71/// ```text
72/// [m₀₀ m₀₁ m₀₂]
73/// [m₁₀ m₁₁ m₁₂]
74/// [m₂₀ m₂₁ m₂₂]
75/// ```
76#[repr(C)]
77#[derive(Clone, Copy, Debug)]
78pub struct Matrix3<T: Scalar> {
79    /// Column vectors of the matrix
80    pub col: [Vector3<T>; 3],
81}
82
83/// A 4x4 matrix stored in column-major order.
84///
85/// The standard matrix for 3D transformations using homogeneous coordinates.
86///
87/// # Layout
88/// ```text
89/// [m₀₀ m₀₁ m₀₂ m₀₃]
90/// [m₁₀ m₁₁ m₁₂ m₁₃]
91/// [m₂₀ m₂₁ m₂₂ m₂₃]
92/// [m₃₀ m₃₁ m₃₂ m₃₃]
93/// ```
94#[repr(C)]
95#[derive(Clone, Copy, Debug)]
96pub struct Matrix4<T: Scalar> {
97    /// Column vectors of the matrix
98    pub col: [Vector4<T>; 4],
99}
100
101/******************************************************************************
102 * Matrix2
103 *
104 * i j ------------------->
105 * | [m0 = c0_x | m2 = c1_x]
106 * V [m1 = c0_y | m3 = c1_y]
107 *
108 *  aij => i = row, j = col (yx form)
109 *
110 *****************************************************************************/
111impl<T: Scalar> Matrix2<T> {
112    /// Creates a new 2x2 matrix from individual elements.
113    ///
114    /// Elements are provided in column-major order:
115    /// ```text
116    /// [m0 m2]
117    /// [m1 m3]
118    /// ```
119    pub fn new(m0: T, m1: T, m2: T, m3: T) -> Self {
120        Matrix2 {
121            col: [Vector2::new(m0, m1), Vector2::new(m2, m3)],
122        }
123    }
124
125    /// Returns the 2x2 identity matrix.
126    ///
127    /// ```text
128    /// [1 0]
129    /// [0 1]
130    /// ```
131    pub fn identity() -> Self {
132        Self::new(<T as One>::one(), <T as Zero>::zero(), <T as Zero>::zero(), <T as One>::one())
133    }
134
135    /// Computes the determinant of the matrix.
136    ///
137    /// For a 2x2 matrix:
138    /// ```text
139    /// det(M) = m₀₀m₁₁ - m₀₁m₁₀
140    /// ```
141    pub fn determinant(&self) -> T {
142        let m00 = self.col[0].x;
143        let m10 = self.col[0].y;
144
145        let m01 = self.col[1].x;
146        let m11 = self.col[1].y;
147
148        m00 * m11 - m01 * m10
149    }
150
151    /// Returns the transpose of the matrix.
152    ///
153    /// ```text
154    /// Mᵀ[i,j] = M[j,i]
155    /// ```
156    pub fn transpose(&self) -> Self {
157        let m00 = self.col[0].x;
158        let m10 = self.col[0].y;
159
160        let m01 = self.col[1].x;
161        let m11 = self.col[1].y;
162
163        Self::new(m00, m01, m10, m11)
164    }
165
166    /// Computes the inverse of the matrix.
167    ///
168    /// For a 2x2 matrix:
169    /// ```text
170    /// M⁻¹ = (1/det(M)) * [m₁₁  -m₀₁]
171    ///                    [-m₁₀  m₀₀]
172    /// ```
173    ///
174    /// # Note
175    /// Returns NaN or Inf if the matrix is singular (determinant = 0).
176    pub fn inverse(&self) -> Self {
177        let m00 = self.col[0].x;
178        let m10 = self.col[0].y;
179
180        let m01 = self.col[1].x;
181        let m11 = self.col[1].y;
182
183        let inv_det = <T as One>::one() / (m00 * m11 - m01 * m10);
184
185        let r00 = m11 * inv_det;
186        let r01 = -m01 * inv_det;
187        let r10 = -m10 * inv_det;
188        let r11 = m00 * inv_det;
189
190        Self::new(r00, r10, r01, r11)
191    }
192
193    /// Multiplies two 2x2 matrices.
194    ///
195    /// Matrix multiplication follows the rule:
196    /// ```text
197    /// C[i,j] = Σₖ A[i,k] * B[k,j]
198    /// ```
199    pub fn mul_matrix_matrix(l: &Self, r: &Self) -> Self {
200        let a00 = l.col[0].x;
201        let a10 = l.col[0].y;
202        let a01 = l.col[1].x;
203        let a11 = l.col[1].y;
204
205        let b00 = r.col[0].x;
206        let b10 = r.col[0].y;
207        let b01 = r.col[1].x;
208        let b11 = r.col[1].y;
209
210        let c00 = a00 * b00 + a01 * b10;
211        let c01 = a00 * b01 + a01 * b11;
212        let c10 = a10 * b00 + a11 * b10;
213        let c11 = a10 * b01 + a11 * b11;
214
215        Self::new(c00, c10, c01, c11)
216    }
217
218    /// Multiplies a 2x2 matrix by a 2D vector.
219    ///
220    /// Transforms the vector by the matrix:
221    /// ```text
222    /// v' = M * v
223    /// ```
224    pub fn mul_matrix_vector(l: &Self, r: &Vector2<T>) -> Vector2<T> {
225        Self::mul_vector_matrix(r, &l.transpose())
226    }
227
228    /// Multiplies a 2D vector by a 2x2 matrix (row vector).
229    ///
230    /// ```text
231    /// v' = vᵀ * M
232    /// ```
233    pub fn mul_vector_matrix(l: &Vector2<T>, r: &Self) -> Vector2<T> {
234        Vector2::new(Vector2::dot(l, &r.col[0]), Vector2::dot(l, &r.col[1]))
235    }
236
237    /// Adds two matrices element-wise.
238    ///
239    /// ```text
240    /// C[i,j] = A[i,j] + B[i,j]
241    /// ```
242    pub fn add_matrix_matrix(l: &Self, r: &Self) -> Self {
243        Matrix2 {
244            col: [l.col[0] + r.col[0], l.col[1] + r.col[1]],
245        }
246    }
247
248    /// Subtracts two matrices element-wise.
249    ///
250    /// ```text
251    /// C[i,j] = A[i,j] - B[i,j]
252    /// ```
253    pub fn sub_matrix_matrix(l: &Self, r: &Self) -> Self {
254        Matrix2 {
255            col: [l.col[0] - r.col[0], l.col[1] - r.col[1]],
256        }
257    }
258}
259
260/******************************************************************************
261 * Matrix3
262 *
263 * i j -------------------------------->
264 * | [m0 = c0_x | m3 = c1_x | m6 = c2_x]
265 * | [m1 = c0_y | m4 = c1_y | m7 = c2_y]
266 * V [m2 = c0_z | m5 = c1_z | m8 = c2_z]
267 *
268 *  aij => i = row, j = col (yx form)
269 *
270 *****************************************************************************/
271impl<T: Scalar> Matrix3<T> {
272    /// Creates a new 3x3 matrix from column-major elements.
273    pub fn new(m0: T, m1: T, m2: T, m3: T, m4: T, m5: T, m6: T, m7: T, m8: T) -> Self {
274        Matrix3 {
275            col: [
276                Vector3::new(m0, m1, m2),
277                Vector3::new(m3, m4, m5),
278                Vector3::new(m6, m7, m8),
279            ],
280        }
281    }
282
283    /// Returns the 3x3 identity matrix.
284    ///
285    /// ```text
286    /// [1 0 0]
287    /// [0 1 0]
288    /// [0 0 1]
289    /// ```
290    pub fn identity() -> Self {
291        Self::new(
292            <T as One>::one(),
293            <T as Zero>::zero(),
294            <T as Zero>::zero(),
295            <T as Zero>::zero(),
296            <T as One>::one(),
297            <T as Zero>::zero(),
298            <T as Zero>::zero(),
299            <T as Zero>::zero(),
300            <T as One>::one(),
301        )
302    }
303
304    /// Computes the determinant of the matrix.
305    pub fn determinant(&self) -> T {
306        let m00 = self.col[0].x;
307        let m10 = self.col[0].y;
308        let m20 = self.col[0].z;
309
310        let m01 = self.col[1].x;
311        let m11 = self.col[1].y;
312        let m21 = self.col[1].z;
313
314        let m02 = self.col[2].x;
315        let m12 = self.col[2].y;
316        let m22 = self.col[2].z;
317
318        m00 * m11 * m22 + m01 * m12 * m20 + m02 * m10 * m21
319            - m00 * m12 * m21
320            - m01 * m10 * m22
321            - m02 * m11 * m20
322    }
323
324    /// Returns the transpose of the matrix.
325    ///
326    /// ```text
327    /// Mᵀ[i,j] = M[j,i]
328    /// ```
329    pub fn transpose(&self) -> Self {
330        let m00 = self.col[0].x;
331        let m10 = self.col[0].y;
332        let m20 = self.col[0].z;
333
334        let m01 = self.col[1].x;
335        let m11 = self.col[1].y;
336        let m21 = self.col[1].z;
337
338        let m02 = self.col[2].x;
339        let m12 = self.col[2].y;
340        let m22 = self.col[2].z;
341
342        Self::new(m00, m01, m02, m10, m11, m12, m20, m21, m22)
343    }
344
345    /// Computes the inverse of the matrix.
346    ///
347    /// Uses the adjugate matrix method:
348    /// ```text
349    /// M⁻¹ = (1/det(M)) * adj(M)
350    /// ```
351    ///
352    /// # Note
353    /// Returns NaN or Inf if the matrix is singular (determinant = 0).
354    pub fn inverse(&self) -> Self {
355        let m00 = self.col[0].x;
356        let m10 = self.col[0].y;
357        let m20 = self.col[0].z;
358
359        let m01 = self.col[1].x;
360        let m11 = self.col[1].y;
361        let m21 = self.col[1].z;
362
363        let m02 = self.col[2].x;
364        let m12 = self.col[2].y;
365        let m22 = self.col[2].z;
366
367        let inv_det = <T as One>::one()
368            / (m00 * m11 * m22 + m01 * m12 * m20 + m02 * m10 * m21
369                - m00 * m12 * m21
370                - m01 * m10 * m22
371                - m02 * m11 * m20);
372
373        let r00 = (m11 * m22 - m12 * m21) * inv_det;
374        let r01 = (m02 * m21 - m01 * m22) * inv_det;
375        let r02 = (m01 * m12 - m02 * m11) * inv_det;
376        let r10 = (m12 * m20 - m10 * m22) * inv_det;
377        let r11 = (m00 * m22 - m02 * m20) * inv_det;
378        let r12 = (m02 * m10 - m00 * m12) * inv_det;
379        let r20 = (m10 * m21 - m11 * m20) * inv_det;
380        let r21 = (m01 * m20 - m00 * m21) * inv_det;
381        let r22 = (m00 * m11 - m01 * m10) * inv_det;
382
383        Self::new(r00, r10, r20, r01, r11, r21, r02, r12, r22)
384    }
385
386    /// Multiplies two 3x3 matrices.
387    ///
388    /// Matrix multiplication follows the rule:
389    /// ```text
390    /// C[i,j] = Σₖ A[i,k] * B[k,j]
391    /// ```
392    pub fn mul_matrix_matrix(l: &Self, r: &Self) -> Self {
393        let a00 = l.col[0].x;
394        let a10 = l.col[0].y;
395        let a20 = l.col[0].z;
396
397        let a01 = l.col[1].x;
398        let a11 = l.col[1].y;
399        let a21 = l.col[1].z;
400
401        let a02 = l.col[2].x;
402        let a12 = l.col[2].y;
403        let a22 = l.col[2].z;
404
405        let b00 = r.col[0].x;
406        let b10 = r.col[0].y;
407        let b20 = r.col[0].z;
408
409        let b01 = r.col[1].x;
410        let b11 = r.col[1].y;
411        let b21 = r.col[1].z;
412
413        let b02 = r.col[2].x;
414        let b12 = r.col[2].y;
415        let b22 = r.col[2].z;
416
417        let c00 = a00 * b00 + a01 * b10 + a02 * b20;
418        let c01 = a00 * b01 + a01 * b11 + a02 * b21;
419        let c02 = a00 * b02 + a01 * b12 + a02 * b22;
420
421        let c10 = a10 * b00 + a11 * b10 + a12 * b20;
422        let c11 = a10 * b01 + a11 * b11 + a12 * b21;
423        let c12 = a10 * b02 + a11 * b12 + a12 * b22;
424
425        let c20 = a20 * b00 + a21 * b10 + a22 * b20;
426        let c21 = a20 * b01 + a21 * b11 + a22 * b21;
427        let c22 = a20 * b02 + a21 * b12 + a22 * b22;
428
429        Self::new(c00, c10, c20, c01, c11, c21, c02, c12, c22)
430    }
431
432    /// Multiplies a 3x3 matrix by a 3D vector.
433    ///
434    /// Transforms the vector by the matrix:
435    /// ```text
436    /// v' = M * v
437    /// ```
438    pub fn mul_matrix_vector(l: &Self, r: &Vector3<T>) -> Vector3<T> {
439        Self::mul_vector_matrix(r, &l.transpose())
440    }
441
442    /// Multiplies a 3D vector by a 3x3 matrix (row vector).
443    ///
444    /// ```text
445    /// v' = vᵀ * M
446    /// ```
447    pub fn mul_vector_matrix(l: &Vector3<T>, r: &Self) -> Vector3<T> {
448        Vector3::new(
449            Vector3::dot(l, &r.col[0]),
450            Vector3::dot(l, &r.col[1]),
451            Vector3::dot(l, &r.col[2]),
452        )
453    }
454
455    /// Adds two matrices element-wise.
456    ///
457    /// ```text
458    /// C[i,j] = A[i,j] + B[i,j]
459    /// ```
460    pub fn add_matrix_matrix(l: &Self, r: &Self) -> Self {
461        Matrix3 {
462            col: [
463                l.col[0] + r.col[0],
464                l.col[1] + r.col[1],
465                l.col[2] + r.col[2],
466            ],
467        }
468    }
469
470    /// Subtracts two matrices element-wise.
471    ///
472    /// ```text
473    /// C[i,j] = A[i,j] - B[i,j]
474    /// ```
475    pub fn sub_matrix_matrix(l: &Self, r: &Self) -> Self {
476        Matrix3 {
477            col: [
478                l.col[0] - r.col[0],
479                l.col[1] - r.col[1],
480                l.col[2] - r.col[2],
481            ],
482        }
483    }
484}
485
486impl<T: FloatScalar> Matrix3<T> {
487    /// Creates a 3x3 rotation matrix from an axis and angle.
488    ///
489    /// Uses Rodrigues' rotation formula:
490    /// ```text
491    /// R = I + sin(θ)K + (1 - cos(θ))K²
492    /// ```
493    /// where K is the cross-product matrix of the normalized axis.
494    ///
495    /// # Parameters
496    /// - `axis`: The rotation axis (will be normalized)
497    /// - `angle`: The rotation angle in radians
498    /// - `epsilon`: Minimum axis length to treat as valid
499    ///
500    /// # Returns
501    /// - `Some(matrix)` for a valid axis
502    /// - `None` if the axis length is too small
503    pub fn of_axis_angle(axis: &Vector3<T>, angle: T, epsilon: T) -> Option<Self> {
504        let len_sq = Vector3::dot(axis, axis);
505        if len_sq <= epsilon * epsilon {
506            return None;
507        }
508        let inv_len = <T as One>::one() / len_sq.tsqrt();
509        let n = *axis * inv_len;
510        let c = T::tcos(angle);
511        let s = T::tsin(angle);
512        let ux = n.x;
513        let uy = n.y;
514        let uz = n.z;
515        let uxx = ux * ux;
516        let uyy = uy * uy;
517        let uzz = uz * uz;
518
519        let oc = <T as One>::one() - c;
520
521        let m0 = c + uxx * oc;
522        let m1 = uy * ux * oc + uz * s;
523        let m2 = uz * ux * oc - uy * s;
524
525        let m3 = ux * uy * oc - uz * s;
526        let m4 = c + uyy * oc;
527        let m5 = uz * uy * oc + ux * s;
528
529        let m6 = ux * uz * oc + uy * s;
530        let m7 = uy * uz * oc - ux * s;
531        let m8 = c + uzz * oc;
532
533        Some(Self::new(m0, m1, m2, m3, m4, m5, m6, m7, m8))
534    }
535}
536
537/******************************************************************************
538 * Matrix4
539 *
540 * i j -------------------------------------------->
541 * | [m0 = c0_x | m4 = c1_x | m8 = c2_x | m12= c3_x]
542 * | [m1 = c0_y | m5 = c1_y | m9 = c2_y | m13= c3_y]
543 * | [m2 = c0_z | m6 = c1_z | m10= c2_z | m14= c3_z]
544 * V [m3 = c0_w | m7 = c1_w | m11= c2_w | m15= c3_w]
545 *
546 *  aij => i = row, j = col (yx form)
547 *
548 *****************************************************************************/
549impl<T: Scalar> Matrix4<T> {
550    /// Creates a new 4x4 matrix from column-major elements.
551    pub fn new(
552        m0: T,
553        m1: T,
554        m2: T,
555        m3: T,
556        m4: T,
557        m5: T,
558        m6: T,
559        m7: T,
560        m8: T,
561        m9: T,
562        m10: T,
563        m11: T,
564        m12: T,
565        m13: T,
566        m14: T,
567        m15: T,
568    ) -> Self {
569        Matrix4 {
570            col: [
571                Vector4::new(m0, m1, m2, m3),
572                Vector4::new(m4, m5, m6, m7),
573                Vector4::new(m8, m9, m10, m11),
574                Vector4::new(m12, m13, m14, m15),
575            ],
576        }
577    }
578
579    /// Returns the 4x4 identity matrix.
580    pub fn identity() -> Self {
581        Self::new(
582            <T as One>::one(),
583            <T as Zero>::zero(),
584            <T as Zero>::zero(),
585            <T as Zero>::zero(),
586            <T as Zero>::zero(),
587            <T as One>::one(),
588            <T as Zero>::zero(),
589            <T as Zero>::zero(),
590            <T as Zero>::zero(),
591            <T as Zero>::zero(),
592            <T as One>::one(),
593            <T as Zero>::zero(),
594            <T as Zero>::zero(),
595            <T as Zero>::zero(),
596            <T as Zero>::zero(),
597            <T as One>::one(),
598        )
599    }
600
601    /// Computes the determinant of the matrix.
602    ///
603    /// Uses Laplace expansion along the first column.
604    /// A non-zero determinant indicates the matrix is invertible.
605    pub fn determinant(&self) -> T {
606        let m00 = self.col[0].x;
607        let m10 = self.col[0].y;
608        let m20 = self.col[0].z;
609        let m30 = self.col[0].w;
610
611        let m01 = self.col[1].x;
612        let m11 = self.col[1].y;
613        let m21 = self.col[1].z;
614        let m31 = self.col[1].w;
615
616        let m02 = self.col[2].x;
617        let m12 = self.col[2].y;
618        let m22 = self.col[2].z;
619        let m32 = self.col[2].w;
620
621        let m03 = self.col[3].x;
622        let m13 = self.col[3].y;
623        let m23 = self.col[3].z;
624        let m33 = self.col[3].w;
625
626        m03 * m12 * m21 * m30 - m02 * m13 * m21 * m30 - m03 * m11 * m22 * m30
627            + m01 * m13 * m22 * m30
628            + m02 * m11 * m23 * m30
629            - m01 * m12 * m23 * m30
630            - m03 * m12 * m20 * m31
631            + m02 * m13 * m20 * m31
632            + m03 * m10 * m22 * m31
633            - m00 * m13 * m22 * m31
634            - m02 * m10 * m23 * m31
635            + m00 * m12 * m23 * m31
636            + m03 * m11 * m20 * m32
637            - m01 * m13 * m20 * m32
638            - m03 * m10 * m21 * m32
639            + m00 * m13 * m21 * m32
640            + m01 * m10 * m23 * m32
641            - m00 * m11 * m23 * m32
642            - m02 * m11 * m20 * m33
643            + m01 * m12 * m20 * m33
644            + m02 * m10 * m21 * m33
645            - m00 * m12 * m21 * m33
646            - m01 * m10 * m22 * m33
647            + m00 * m11 * m22 * m33
648    }
649
650    /// Returns the transpose of the matrix.
651    ///
652    /// ```text
653    /// Mᵀ[i,j] = M[j,i]
654    /// ```
655    pub fn transpose(&self) -> Self {
656        let m00 = self.col[0].x;
657        let m10 = self.col[0].y;
658        let m20 = self.col[0].z;
659        let m30 = self.col[0].w;
660
661        let m01 = self.col[1].x;
662        let m11 = self.col[1].y;
663        let m21 = self.col[1].z;
664        let m31 = self.col[1].w;
665
666        let m02 = self.col[2].x;
667        let m12 = self.col[2].y;
668        let m22 = self.col[2].z;
669        let m32 = self.col[2].w;
670
671        let m03 = self.col[3].x;
672        let m13 = self.col[3].y;
673        let m23 = self.col[3].z;
674        let m33 = self.col[3].w;
675
676        Self::new(
677            m00, m01, m02, m03, m10, m11, m12, m13, m20, m21, m22, m23, m30, m31, m32, m33,
678        )
679    }
680
681    /// Returns true if the matrix is affine (last row equals [0, 0, 0, 1]).
682    pub fn is_affine(&self, epsilon: T) -> bool {
683        self.col[0].w.tabs() <= epsilon
684            && self.col[1].w.tabs() <= epsilon
685            && self.col[2].w.tabs() <= epsilon
686            && (self.col[3].w - <T as One>::one()).tabs() <= epsilon
687    }
688
689    /// Computes the inverse of the matrix.
690    ///
691    /// Uses the adjugate matrix method:
692    /// ```text
693    /// M⁻¹ = (1/det(M)) * adj(M)
694    /// ```
695    ///
696    /// # Note
697    /// Returns NaN or Inf if the matrix is singular (determinant = 0).
698    pub fn inverse(&self) -> Self {
699        let m00 = self.col[0].x;
700        let m10 = self.col[0].y;
701        let m20 = self.col[0].z;
702        let m30 = self.col[0].w;
703
704        let m01 = self.col[1].x;
705        let m11 = self.col[1].y;
706        let m21 = self.col[1].z;
707        let m31 = self.col[1].w;
708
709        let m02 = self.col[2].x;
710        let m12 = self.col[2].y;
711        let m22 = self.col[2].z;
712        let m32 = self.col[2].w;
713
714        let m03 = self.col[3].x;
715        let m13 = self.col[3].y;
716        let m23 = self.col[3].z;
717        let m33 = self.col[3].w;
718
719        let denom = m03 * m12 * m21 * m30 - m02 * m13 * m21 * m30 - m03 * m11 * m22 * m30
720            + m01 * m13 * m22 * m30
721            + m02 * m11 * m23 * m30
722            - m01 * m12 * m23 * m30
723            - m03 * m12 * m20 * m31
724            + m02 * m13 * m20 * m31
725            + m03 * m10 * m22 * m31
726            - m00 * m13 * m22 * m31
727            - m02 * m10 * m23 * m31
728            + m00 * m12 * m23 * m31
729            + m03 * m11 * m20 * m32
730            - m01 * m13 * m20 * m32
731            - m03 * m10 * m21 * m32
732            + m00 * m13 * m21 * m32
733            + m01 * m10 * m23 * m32
734            - m00 * m11 * m23 * m32
735            - m02 * m11 * m20 * m33
736            + m01 * m12 * m20 * m33
737            + m02 * m10 * m21 * m33
738            - m00 * m12 * m21 * m33
739            - m01 * m10 * m22 * m33
740            + m00 * m11 * m22 * m33;
741        let inv_det = <T as One>::one() / denom;
742
743        let r00 = (m12 * m23 * m31 - m13 * m22 * m31 + m13 * m21 * m32
744            - m11 * m23 * m32
745            - m12 * m21 * m33
746            + m11 * m22 * m33)
747            * inv_det;
748
749        let r01 = (m03 * m22 * m31 - m02 * m23 * m31 - m03 * m21 * m32
750            + m01 * m23 * m32
751            + m02 * m21 * m33
752            - m01 * m22 * m33)
753            * inv_det;
754
755        let r02 = (m02 * m13 * m31 - m03 * m12 * m31 + m03 * m11 * m32
756            - m01 * m13 * m32
757            - m02 * m11 * m33
758            + m01 * m12 * m33)
759            * inv_det;
760
761        let r03 = (m03 * m12 * m21 - m02 * m13 * m21 - m03 * m11 * m22
762            + m01 * m13 * m22
763            + m02 * m11 * m23
764            - m01 * m12 * m23)
765            * inv_det;
766
767        let r10 = (m13 * m22 * m30 - m12 * m23 * m30 - m13 * m20 * m32
768            + m10 * m23 * m32
769            + m12 * m20 * m33
770            - m10 * m22 * m33)
771            * inv_det;
772
773        let r11 = (m02 * m23 * m30 - m03 * m22 * m30 + m03 * m20 * m32
774            - m00 * m23 * m32
775            - m02 * m20 * m33
776            + m00 * m22 * m33)
777            * inv_det;
778
779        let r12 = (m03 * m12 * m30 - m02 * m13 * m30 - m03 * m10 * m32
780            + m00 * m13 * m32
781            + m02 * m10 * m33
782            - m00 * m12 * m33)
783            * inv_det;
784
785        let r13 = (m02 * m13 * m20 - m03 * m12 * m20 + m03 * m10 * m22
786            - m00 * m13 * m22
787            - m02 * m10 * m23
788            + m00 * m12 * m23)
789            * inv_det;
790
791        let r20 = (m11 * m23 * m30 - m13 * m21 * m30 + m13 * m20 * m31
792            - m10 * m23 * m31
793            - m11 * m20 * m33
794            + m10 * m21 * m33)
795            * inv_det;
796
797        let r21 = (m03 * m21 * m30 - m01 * m23 * m30 - m03 * m20 * m31
798            + m00 * m23 * m31
799            + m01 * m20 * m33
800            - m00 * m21 * m33)
801            * inv_det;
802
803        let r22 = (m01 * m13 * m30 - m03 * m11 * m30 + m03 * m10 * m31
804            - m00 * m13 * m31
805            - m01 * m10 * m33
806            + m00 * m11 * m33)
807            * inv_det;
808
809        let r23 = (m03 * m11 * m20 - m01 * m13 * m20 - m03 * m10 * m21
810            + m00 * m13 * m21
811            + m01 * m10 * m23
812            - m00 * m11 * m23)
813            * inv_det;
814
815        let r30 = (m12 * m21 * m30 - m11 * m22 * m30 - m12 * m20 * m31
816            + m10 * m22 * m31
817            + m11 * m20 * m32
818            - m10 * m21 * m32)
819            * inv_det;
820
821        let r31 = (m01 * m22 * m30 - m02 * m21 * m30 + m02 * m20 * m31
822            - m00 * m22 * m31
823            - m01 * m20 * m32
824            + m00 * m21 * m32)
825            * inv_det;
826
827        let r32 = (m02 * m11 * m30 - m01 * m12 * m30 - m02 * m10 * m31
828            + m00 * m12 * m31
829            + m01 * m10 * m32
830            - m00 * m11 * m32)
831            * inv_det;
832
833        let r33 = (m01 * m12 * m20 - m02 * m11 * m20 + m02 * m10 * m21
834            - m00 * m12 * m21
835            - m01 * m10 * m22
836            + m00 * m11 * m22)
837            * inv_det;
838
839        Self::new(
840            r00, r10, r20, r30, r01, r11, r21, r31, r02, r12, r22, r32, r03, r13, r23, r33,
841        )
842    }
843
844    /// Computes the inverse of an affine matrix (rotation/scale + translation).
845    ///
846    /// Assumes the last row is `[0, 0, 0, 1]`.
847    pub fn inverse_affine(&self) -> Self {
848        let rot = Matrix3::new(
849            self.col[0].x,
850            self.col[0].y,
851            self.col[0].z,
852            self.col[1].x,
853            self.col[1].y,
854            self.col[1].z,
855            self.col[2].x,
856            self.col[2].y,
857            self.col[2].z,
858        );
859        let rot_inv = rot.inverse();
860        let trans = Vector3::new(self.col[3].x, self.col[3].y, self.col[3].z);
861        let trans_inv = -(rot_inv * trans);
862
863        Self::new(
864            rot_inv.col[0].x,
865            rot_inv.col[0].y,
866            rot_inv.col[0].z,
867            <T as Zero>::zero(),
868            rot_inv.col[1].x,
869            rot_inv.col[1].y,
870            rot_inv.col[1].z,
871            <T as Zero>::zero(),
872            rot_inv.col[2].x,
873            rot_inv.col[2].y,
874            rot_inv.col[2].z,
875            <T as Zero>::zero(),
876            trans_inv.x,
877            trans_inv.y,
878            trans_inv.z,
879            <T as One>::one(),
880        )
881    }
882
883    /// Multiplies two 4x4 matrices.
884    ///
885    /// Matrix multiplication follows the rule:
886    /// ```text
887    /// C[i,j] = Σₖ A[i,k] * B[k,j]
888    /// ```
889    pub fn mul_matrix_matrix(l: &Self, r: &Self) -> Self {
890        let a00 = l.col[0].x;
891        let a10 = l.col[0].y;
892        let a20 = l.col[0].z;
893        let a30 = l.col[0].w;
894
895        let a01 = l.col[1].x;
896        let a11 = l.col[1].y;
897        let a21 = l.col[1].z;
898        let a31 = l.col[1].w;
899
900        let a02 = l.col[2].x;
901        let a12 = l.col[2].y;
902        let a22 = l.col[2].z;
903        let a32 = l.col[2].w;
904
905        let a03 = l.col[3].x;
906        let a13 = l.col[3].y;
907        let a23 = l.col[3].z;
908        let a33 = l.col[3].w;
909
910        let b00 = r.col[0].x;
911        let b10 = r.col[0].y;
912        let b20 = r.col[0].z;
913        let b30 = r.col[0].w;
914
915        let b01 = r.col[1].x;
916        let b11 = r.col[1].y;
917        let b21 = r.col[1].z;
918        let b31 = r.col[1].w;
919
920        let b02 = r.col[2].x;
921        let b12 = r.col[2].y;
922        let b22 = r.col[2].z;
923        let b32 = r.col[2].w;
924
925        let b03 = r.col[3].x;
926        let b13 = r.col[3].y;
927        let b23 = r.col[3].z;
928        let b33 = r.col[3].w;
929
930        let c00 = a00 * b00 + a01 * b10 + a02 * b20 + a03 * b30;
931        let c01 = a00 * b01 + a01 * b11 + a02 * b21 + a03 * b31;
932        let c02 = a00 * b02 + a01 * b12 + a02 * b22 + a03 * b32;
933        let c03 = a00 * b03 + a01 * b13 + a02 * b23 + a03 * b33;
934
935        let c10 = a10 * b00 + a11 * b10 + a12 * b20 + a13 * b30;
936        let c11 = a10 * b01 + a11 * b11 + a12 * b21 + a13 * b31;
937        let c12 = a10 * b02 + a11 * b12 + a12 * b22 + a13 * b32;
938        let c13 = a10 * b03 + a11 * b13 + a12 * b23 + a13 * b33;
939
940        let c20 = a20 * b00 + a21 * b10 + a22 * b20 + a23 * b30;
941        let c21 = a20 * b01 + a21 * b11 + a22 * b21 + a23 * b31;
942        let c22 = a20 * b02 + a21 * b12 + a22 * b22 + a23 * b32;
943        let c23 = a20 * b03 + a21 * b13 + a22 * b23 + a23 * b33;
944
945        let c30 = a30 * b00 + a31 * b10 + a32 * b20 + a33 * b30;
946        let c31 = a30 * b01 + a31 * b11 + a32 * b21 + a33 * b31;
947        let c32 = a30 * b02 + a31 * b12 + a32 * b22 + a33 * b32;
948        let c33 = a30 * b03 + a31 * b13 + a32 * b23 + a33 * b33;
949
950        Self::new(
951            c00, c10, c20, c30, c01, c11, c21, c31, c02, c12, c22, c32, c03, c13, c23, c33,
952        )
953    }
954
955    /// Multiplies a 4x4 matrix by a 4D vector.
956    ///
957    /// ```text
958    /// v' = M * v
959    /// ```
960    pub fn mul_matrix_vector(l: &Self, r: &Vector4<T>) -> Vector4<T> {
961        Self::mul_vector_matrix(r, &l.transpose())
962    }
963
964    //
965    //                     [m0 = c0_x | m4 = c1_x | m8 = c2_x | m12= c3_x]
966    // [v_x v_y v_z v_w] * [m1 = c0_y | m5 = c1_y | m9 = c2_y | m13= c3_y] = [dot(v, c0) dot(v, c1) dot(v, c2) dot(v, c3)]
967    //                     [m2 = c0_z | m6 = c1_z | m10= c2_z | m14= c3_z]
968    //                     [m3 = c0_w | m7 = c1_w | m11= c2_w | m15= c3_w]
969    //
970    /// Multiplies a 4D vector by a 4x4 matrix (row vector).
971    ///
972    /// ```text
973    /// v' = vᵀ * M
974    /// ```
975    pub fn mul_vector_matrix(l: &Vector4<T>, r: &Self) -> Vector4<T> {
976        Vector4::new(
977            Vector4::dot(l, &r.col[0]),
978            Vector4::dot(l, &r.col[1]),
979            Vector4::dot(l, &r.col[2]),
980            Vector4::dot(l, &r.col[3]),
981        )
982    }
983
984    /// Adds two matrices element-wise.
985    ///
986    /// ```text
987    /// C[i,j] = A[i,j] + B[i,j]
988    /// ```
989    pub fn add_matrix_matrix(l: &Self, r: &Self) -> Self {
990        Matrix4 {
991            col: [
992                l.col[0] + r.col[0],
993                l.col[1] + r.col[1],
994                l.col[2] + r.col[2],
995                l.col[3] + r.col[3],
996            ],
997        }
998    }
999
1000    /// Subtracts two matrices element-wise.
1001    ///
1002    /// ```text
1003    /// C[i,j] = A[i,j] - B[i,j]
1004    /// ```
1005    pub fn sub_matrix_matrix(l: &Self, r: &Self) -> Self {
1006        Matrix4 {
1007            col: [
1008                l.col[0] - r.col[0],
1009                l.col[1] - r.col[1],
1010                l.col[2] - r.col[2],
1011                l.col[3] - r.col[3],
1012            ],
1013        }
1014    }
1015}
1016
1017/******************************************************************************
1018 * Operator overloading
1019 *****************************************************************************/
1020macro_rules! implMatrixOps {
1021    ($mat:ident, $vec: ident) => {
1022        impl<T: Scalar> Mul<$mat<T>> for $vec<T> {
1023            type Output = $vec<T>;
1024            fn mul(self, rhs: $mat<T>) -> $vec<T> {
1025                $mat::mul_vector_matrix(&self, &rhs)
1026            }
1027        }
1028
1029        impl<T: Scalar> Mul<$vec<T>> for $mat<T> {
1030            type Output = $vec<T>;
1031            fn mul(self, rhs: $vec<T>) -> $vec<T> {
1032                $mat::mul_matrix_vector(&self, &rhs)
1033            }
1034        }
1035
1036        impl<T: Scalar> Mul<$mat<T>> for $mat<T> {
1037            type Output = $mat<T>;
1038            fn mul(self, rhs: $mat<T>) -> $mat<T> {
1039                $mat::mul_matrix_matrix(&self, &rhs)
1040            }
1041        }
1042
1043        impl<T: Scalar> Add<$mat<T>> for $mat<T> {
1044            type Output = $mat<T>;
1045            fn add(self, rhs: $mat<T>) -> $mat<T> {
1046                $mat::add_matrix_matrix(&self, &rhs)
1047            }
1048        }
1049
1050        impl<T: Scalar> Sub<$mat<T>> for $mat<T> {
1051            type Output = $mat<T>;
1052            fn sub(self, rhs: $mat<T>) -> $mat<T> {
1053                $mat::sub_matrix_matrix(&self, &rhs)
1054            }
1055        }
1056    };
1057}
1058
1059implMatrixOps!(Matrix2, Vector2);
1060implMatrixOps!(Matrix3, Vector3);
1061implMatrixOps!(Matrix4, Vector4);
1062
1063impl<T: Scalar> Mul<Matrix4<T>> for Vector3<T> {
1064    type Output = Vector3<T>;
1065    fn mul(self, rhs: Matrix4<T>) -> Vector3<T> {
1066        Matrix4::mul_vector_matrix(&Vector4::new(self.x, self.y, self.z, <T as One>::one()), &rhs).xyz()
1067    }
1068}
1069
1070impl<T: Scalar> Mul<Vector3<T>> for Matrix4<T> {
1071    type Output = Vector3<T>;
1072    fn mul(self, rhs: Vector3<T>) -> Vector3<T> {
1073        Matrix4::mul_matrix_vector(&self, &Vector4::new(rhs.x, rhs.y, rhs.z, <T as One>::one())).xyz()
1074    }
1075}
1076
/// Convenience extension for extracting a 3x3 submatrix.
pub trait Matrix4Extension<T: Scalar> {
    /// Returns the upper-left 3x3 rotation/scale submatrix.
    ///
    /// This is the `x`, `y`, `z` components of the first three columns. The
    /// fourth column and the fourth row are dropped.
    fn mat3(&self) -> Matrix3<T>;
}
1082
1083impl<T: Scalar> Matrix4Extension<T> for Matrix4<T> {
1084    fn mat3(&self) -> Matrix3<T> {
1085        Matrix3::new(
1086            self.col[0].x,
1087            self.col[0].y,
1088            self.col[0].z,
1089            self.col[1].x,
1090            self.col[1].y,
1091            self.col[1].z,
1092            self.col[2].x,
1093            self.col[2].y,
1094            self.col[2].z,
1095        )
1096    }
1097}
1098
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::*;

    /// Absolute tolerance for results that go through divisions (inverses).
    const TOL: f32 = 0.001;

    /// Returns element (`row`, `col`) of a column-major 3x3 matrix.
    fn at3(m: &Matrix3<f32>, row: usize, col: usize) -> f32 {
        let c = &m.col[col];
        [c.x, c.y, c.z][row]
    }

    /// Returns element (`row`, `col`) of a column-major 4x4 matrix.
    fn at4(m: &Matrix4<f32>, row: usize, col: usize) -> f32 {
        let c = &m.col[col];
        [c.x, c.y, c.z, c.w][row]
    }

    /// Identity-matrix entry: 1 on the diagonal, 0 elsewhere.
    fn identity_entry(row: usize, col: usize) -> f32 {
        if row == col {
            1.0
        } else {
            0.0
        }
    }

    #[test]
    fn test_matrix2_identity() {
        let m = Matrix2::<f32>::identity();
        assert_eq!(m.col[0].x, 1.0);
        assert_eq!(m.col[0].y, 0.0);
        assert_eq!(m.col[1].x, 0.0);
        assert_eq!(m.col[1].y, 1.0);
    }

    #[test]
    fn test_matrix2_determinant() {
        let m = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let det = m.determinant();
        assert_eq!(det, -2.0); // 1*4 - 3*2 = -2

        // Test singular matrix
        let m_singular = Matrix2::<f32>::new(1.0, 2.0, 2.0, 4.0);
        let det_singular = m_singular.determinant();
        assert_eq!(det_singular, 0.0);
    }

    #[test]
    fn test_matrix2_inverse() {
        let m = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let m_inv = m.inverse();
        let product = Matrix2::mul_matrix_matrix(&m, &m_inv);

        // Check if product is identity
        assert!((product.col[0].x - 1.0).abs() < TOL);
        assert!((product.col[0].y).abs() < TOL);
        assert!((product.col[1].x).abs() < TOL);
        assert!((product.col[1].y - 1.0).abs() < TOL);
    }

    #[test]
    fn test_matrix2_transpose() {
        let m = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let mt = m.transpose();
        assert_eq!(mt.col[0].x, 1.0);
        assert_eq!(mt.col[0].y, 3.0);
        assert_eq!(mt.col[1].x, 2.0);
        assert_eq!(mt.col[1].y, 4.0);

        // Transpose of transpose should be original
        let mtt = mt.transpose();
        assert_eq!(mtt.col[0].x, m.col[0].x);
        assert_eq!(mtt.col[0].y, m.col[0].y);
        assert_eq!(mtt.col[1].x, m.col[1].x);
        assert_eq!(mtt.col[1].y, m.col[1].y);
    }

    #[test]
    fn test_matrix3_identity() {
        let m = Matrix3::<f32>::identity();
        assert_eq!(m.col[0].x, 1.0);
        assert_eq!(m.col[1].y, 1.0);
        assert_eq!(m.col[2].z, 1.0);
        assert_eq!(m.col[0].y, 0.0);
        assert_eq!(m.col[0].z, 0.0);
    }

    #[test]
    fn test_matrix3_determinant() {
        let m = Matrix3::<f32>::new(
            1.0, 0.0, 0.0,
            0.0, 1.0, 0.0,
            0.0, 0.0, 1.0
        );
        assert_eq!(m.determinant(), 1.0);

        let m2 = Matrix3::<f32>::new(
            2.0, 3.0, 1.0,
            1.0, 0.0, 2.0,
            1.0, 2.0, 1.0
        );
        let det = m2.determinant();
        assert!((det - -3.0).abs() < TOL);
    }

    #[test]
    fn test_matrix3_inverse() {
        let m = Matrix3::<f32>::new(
            2.0, 3.0, 1.0,
            1.0, 0.0, 2.0,
            1.0, 2.0, 1.0
        );
        let m_inv = m.inverse();
        let product = Matrix3::mul_matrix_matrix(&m, &m_inv);

        // Check if product is close to identity
        for i in 0..3 {
            for j in 0..3 {
                assert!((at3(&product, i, j) - identity_entry(i, j)).abs() < TOL);
            }
        }
    }

    #[test]
    fn test_matrix3_axis_angle_zero_axis() {
        let axis = Vector3::<f32>::new(0.0, 0.0, 0.0);
        assert!(Matrix3::of_axis_angle(&axis, 1.0, EPS_F32).is_none());
    }

    #[test]
    fn test_matrix4_identity() {
        let m = Matrix4::<f32>::identity();
        for i in 0..4 {
            for j in 0..4 {
                assert_eq!(at4(&m, i, j), identity_entry(i, j));
            }
        }
    }

    #[test]
    fn test_matrix_vector_multiplication() {
        // Test Matrix2 * Vector2
        let m2 = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let v2 = Vector2::<f32>::new(5.0, 6.0);
        let result2 = m2 * v2;
        assert_eq!(result2.x, 23.0); // 1*5 + 3*6 = 23
        assert_eq!(result2.y, 34.0); // 2*5 + 4*6 = 34

        // Test Matrix3 * Vector3
        let m3 = Matrix3::<f32>::identity();
        let v3 = Vector3::<f32>::new(1.0, 2.0, 3.0);
        let result3 = m3 * v3;
        assert_eq!(result3.x, 1.0);
        assert_eq!(result3.y, 2.0);
        assert_eq!(result3.z, 3.0);

        // Test Matrix4 * Vector4
        let m4 = Matrix4::<f32>::identity();
        let v4 = Vector4::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let result4 = m4 * v4;
        assert_eq!(result4.x, 1.0);
        assert_eq!(result4.y, 2.0);
        assert_eq!(result4.z, 3.0);
        assert_eq!(result4.w, 4.0);
    }

    #[test]
    fn test_matrix4_vector_products() {
        // Columns are (1,2,3,4), (5,6,7,8), (9,10,11,12), (13,14,15,16).
        let m = Matrix4::<f32>::new(
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        );
        let v = Vector4::<f32>::new(1.0, 2.0, 3.0, 4.0);

        // M * v: columns weighted by the components of v.
        let mv = Matrix4::mul_matrix_vector(&m, &v);
        assert_eq!((mv.x, mv.y, mv.z, mv.w), (90.0, 100.0, 110.0, 120.0));

        // vᵀ * M: dot product of v with each column.
        let vm = Matrix4::mul_vector_matrix(&v, &m);
        assert_eq!((vm.x, vm.y, vm.z, vm.w), (30.0, 70.0, 110.0, 150.0));
    }

    #[test]
    fn test_matrix4_vector3_homogeneous() {
        // Pure translation by (1, 2, 3) stored in the last column.
        let translate = Matrix4::<f32>::new(
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
            1.0, 2.0, 3.0, 1.0,
        );
        let p = translate * Vector3::<f32>::new(4.0, 5.0, 6.0);
        assert_eq!((p.x, p.y, p.z), (5.0, 7.0, 9.0));
    }

    #[test]
    fn test_matrix4_mat3() {
        let m = Matrix4::<f32>::new(
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        );
        let m3 = m.mat3();
        assert_eq!((m3.col[0].x, m3.col[0].y, m3.col[0].z), (1.0, 2.0, 3.0));
        assert_eq!((m3.col[1].x, m3.col[1].y, m3.col[1].z), (5.0, 6.0, 7.0));
        assert_eq!((m3.col[2].x, m3.col[2].y, m3.col[2].z), (9.0, 10.0, 11.0));
    }

    #[test]
    fn test_matrix_multiplication() {
        // Test associativity: (A * B) * C == A * (B * C)
        let a = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let b = Matrix2::<f32>::new(5.0, 6.0, 7.0, 8.0);
        let c = Matrix2::<f32>::new(9.0, 10.0, 11.0, 12.0);

        let left = (a * b) * c;
        let right = a * (b * c);

        assert!((left.col[0].x - right.col[0].x).abs() < TOL);
        assert!((left.col[0].y - right.col[0].y).abs() < TOL);
        assert!((left.col[1].x - right.col[1].x).abs() < TOL);
        assert!((left.col[1].y - right.col[1].y).abs() < TOL);
    }

    #[test]
    fn test_matrix_addition_subtraction() {
        let m1 = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let m2 = Matrix2::<f32>::new(5.0, 6.0, 7.0, 8.0);

        let sum = m1 + m2;
        assert_eq!(sum.col[0].x, 6.0);
        assert_eq!(sum.col[0].y, 8.0);
        assert_eq!(sum.col[1].x, 10.0);
        assert_eq!(sum.col[1].y, 12.0);

        let diff = m2 - m1;
        assert_eq!(diff.col[0].x, 4.0);
        assert_eq!(diff.col[0].y, 4.0);
        assert_eq!(diff.col[1].x, 4.0);
        assert_eq!(diff.col[1].y, 4.0);
    }

    #[test]
    fn test_matrix4_inverse() {
        // Test with a known invertible matrix
        let m = Matrix4::<f32>::new(
            2.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 0.5, 0.0,
            1.0, 2.0, 3.0, 1.0
        );

        let m_inv = m.inverse();
        let product = m * m_inv;

        // Check if product is close to identity
        for i in 0..4 {
            for j in 0..4 {
                let expected = identity_entry(i, j);
                let val = at4(&product, i, j);
                assert!((val - expected).abs() < TOL,
                    "Matrix inverse failed at [{}, {}]: expected {}, got {}", i, j, expected, val);
            }
        }
    }

    #[test]
    fn test_matrix4_inverse_affine() {
        let m = Matrix4::<f32>::new(
            0.0, 1.0, 0.0, 0.0,
            -1.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 2.0, 0.0,
            3.0, 4.0, 5.0, 1.0,
        );
        assert!(m.is_affine(EPS_F32));

        let inv_affine = m.inverse_affine();
        let inv_full = m.inverse();

        for row in 0..4 {
            for col in 0..4 {
                assert!((at4(&inv_affine, row, col) - at4(&inv_full, row, col)).abs() < TOL);
            }
        }

        let non_affine = Matrix4::<f32>::new(
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, -1.0,
            0.0, 0.0, 0.0, 0.0,
        );
        assert!(!non_affine.is_affine(EPS_F32));
    }
}