Skip to main content

rs_math3d/
matrix.rs

1// Copyright 2020-Present (c) Raja Lehtihet & Wael El Oraiby
2//
3// Redistribution and use in source and binary forms, with or without
4// modification, are permitted provided that the following conditions are met:
5//
6// 1. Redistributions of source code must retain the above copyright notice,
7// this list of conditions and the following disclaimer.
8//
9// 2. Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// 3. Neither the name of the copyright holder nor the names of its contributors
14// may be used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
//! Matrix mathematics module providing 2x2, 3x3, and 4x4 matrices.
//!
//! This module provides square matrix types commonly used in computer graphics
//! and linear algebra. Matrices are stored in column-major order, matching the
//! layout expected by graphics APIs such as OpenGL.
//!
//! Integer matrices are supported for storage, addition, subtraction, and
//! multiplication. Operations that require fractional results, such as inversion
//! and axis-angle construction, are available only for floating-point scalars.
//!
//! # Examples
//!
//! ```
//! use rs_math3d::matrix::Matrix4;
//! use rs_math3d::vector::Vector4;
//!
//! let m = Matrix4::<f32>::identity();
//! let v = Vector4::new(1.0, 2.0, 3.0, 1.0);
//! let result = m * v; // Transform vector
//! ```
48
49use crate::scalar::*;
50use crate::vector::*;
51use core::ops::*;
52use num_traits::{One, Zero};
53
54/// A 2x2 matrix stored in column-major order.
55///
56/// # Layout
57/// ```text
58/// [m₀₀ m₀₁]
59/// [m₁₀ m₁₁]
60/// ```
61/// where `col[j][i]` represents element at row i, column j.
62#[repr(C)]
63#[derive(Clone, Copy, Debug)]
64pub struct Matrix2<T: Scalar> {
65    /// Column vectors of the matrix
66    pub col: [Vector2<T>; 2],
67}
68
69/// A 3x3 matrix stored in column-major order.
70///
71/// Commonly used for 2D transformations (with homogeneous coordinates)
72/// and 3D rotations.
73///
74/// # Layout
75/// ```text
76/// [m₀₀ m₀₁ m₀₂]
77/// [m₁₀ m₁₁ m₁₂]
78/// [m₂₀ m₂₁ m₂₂]
79/// ```
80#[repr(C)]
81#[derive(Clone, Copy, Debug)]
82pub struct Matrix3<T: Scalar> {
83    /// Column vectors of the matrix
84    pub col: [Vector3<T>; 3],
85}
86
87/// A 4x4 matrix stored in column-major order.
88///
89/// The standard matrix for 3D transformations using homogeneous coordinates.
90///
91/// # Layout
92/// ```text
93/// [m₀₀ m₀₁ m₀₂ m₀₃]
94/// [m₁₀ m₁₁ m₁₂ m₁₃]
95/// [m₂₀ m₂₁ m₂₂ m₂₃]
96/// [m₃₀ m₃₁ m₃₂ m₃₃]
97/// ```
98#[repr(C)]
99#[derive(Clone, Copy, Debug)]
100pub struct Matrix4<T: Scalar> {
101    /// Column vectors of the matrix
102    pub col: [Vector4<T>; 4],
103}
104
105/******************************************************************************
106 * Matrix2
107 *
108 * i j ------------------->
109 * | [m0 = c0_x | m2 = c1_x]
110 * V [m1 = c0_y | m3 = c1_y]
111 *
112 *  aij => i = row, j = col (yx form)
113 *
114 *****************************************************************************/
115impl<T: Scalar> Matrix2<T> {
116    /// Creates a new 2x2 matrix from individual elements.
117    ///
118    /// Elements are provided in column-major order:
119    /// ```text
120    /// [m0 m2]
121    /// [m1 m3]
122    /// ```
123    pub fn new(m0: T, m1: T, m2: T, m3: T) -> Self {
124        Matrix2 {
125            col: [Vector2::new(m0, m1), Vector2::new(m2, m3)],
126        }
127    }
128
129    /// Returns the 2x2 identity matrix.
130    ///
131    /// ```text
132    /// [1 0]
133    /// [0 1]
134    /// ```
135    pub fn identity() -> Self {
136        Self::new(
137            <T as One>::one(),
138            <T as Zero>::zero(),
139            <T as Zero>::zero(),
140            <T as One>::one(),
141        )
142    }
143
144    /// Attempts to cast each element into another numeric type.
145    pub fn try_cast<U>(&self) -> Option<Matrix2<U>>
146    where
147        T: num_traits::ToPrimitive,
148        U: Scalar + num_traits::NumCast,
149    {
150        Some(Matrix2::new(
151            num_traits::NumCast::from(self.col[0].x)?,
152            num_traits::NumCast::from(self.col[0].y)?,
153            num_traits::NumCast::from(self.col[1].x)?,
154            num_traits::NumCast::from(self.col[1].y)?,
155        ))
156    }
157
158    /// Computes the determinant of the matrix.
159    ///
160    /// For a 2x2 matrix:
161    /// ```text
162    /// det(M) = m₀₀m₁₁ - m₀₁m₁₀
163    /// ```
164    pub fn determinant(&self) -> T {
165        let m00 = self.col[0].x;
166        let m10 = self.col[0].y;
167
168        let m01 = self.col[1].x;
169        let m11 = self.col[1].y;
170
171        m00 * m11 - m01 * m10
172    }
173
174    /// Returns the transpose of the matrix.
175    ///
176    /// ```text
177    /// Mᵀ[i,j] = M[j,i]
178    /// ```
179    pub fn transpose(&self) -> Self {
180        let m00 = self.col[0].x;
181        let m10 = self.col[0].y;
182
183        let m01 = self.col[1].x;
184        let m11 = self.col[1].y;
185
186        Self::new(m00, m01, m10, m11)
187    }
188
189    /// Computes the inverse of the matrix.
190    ///
191    /// For a 2x2 matrix:
192    /// ```text
193    /// M⁻¹ = (1/det(M)) * [m₁₁  -m₀₁]
194    ///                    [-m₁₀  m₀₀]
195    /// ```
196    ///
197    /// # Note
198    /// Returns NaN or Inf if the matrix is singular (determinant = 0).
199    pub fn inverse(&self) -> Self
200    where
201        T: FloatScalar,
202    {
203        let m00 = self.col[0].x;
204        let m10 = self.col[0].y;
205
206        let m01 = self.col[1].x;
207        let m11 = self.col[1].y;
208
209        let inv_det = <T as One>::one() / (m00 * m11 - m01 * m10);
210
211        let r00 = m11 * inv_det;
212        let r01 = -m01 * inv_det;
213        let r10 = -m10 * inv_det;
214        let r11 = m00 * inv_det;
215
216        Self::new(r00, r10, r01, r11)
217    }
218
219    /// Multiplies two 2x2 matrices.
220    ///
221    /// Matrix multiplication follows the rule:
222    /// ```text
223    /// C[i,j] = Σₖ A[i,k] * B[k,j]
224    /// ```
225    pub fn mul_matrix_matrix(l: &Self, r: &Self) -> Self {
226        let a00 = l.col[0].x;
227        let a10 = l.col[0].y;
228        let a01 = l.col[1].x;
229        let a11 = l.col[1].y;
230
231        let b00 = r.col[0].x;
232        let b10 = r.col[0].y;
233        let b01 = r.col[1].x;
234        let b11 = r.col[1].y;
235
236        let c00 = a00 * b00 + a01 * b10;
237        let c01 = a00 * b01 + a01 * b11;
238        let c10 = a10 * b00 + a11 * b10;
239        let c11 = a10 * b01 + a11 * b11;
240
241        Self::new(c00, c10, c01, c11)
242    }
243
244    /// Multiplies a 2x2 matrix by a 2D vector.
245    ///
246    /// Transforms the vector by the matrix:
247    /// ```text
248    /// v' = M * v
249    /// ```
250    pub fn mul_matrix_vector(l: &Self, r: &Vector2<T>) -> Vector2<T> {
251        Self::mul_vector_matrix(r, &l.transpose())
252    }
253
254    /// Multiplies a 2D vector by a 2x2 matrix (row vector).
255    ///
256    /// ```text
257    /// v' = vᵀ * M
258    /// ```
259    pub fn mul_vector_matrix(l: &Vector2<T>, r: &Self) -> Vector2<T> {
260        Vector2::new(Vector2::dot(l, &r.col[0]), Vector2::dot(l, &r.col[1]))
261    }
262
263    /// Adds two matrices element-wise.
264    ///
265    /// ```text
266    /// C[i,j] = A[i,j] + B[i,j]
267    /// ```
268    pub fn add_matrix_matrix(l: &Self, r: &Self) -> Self {
269        Matrix2 {
270            col: [l.col[0] + r.col[0], l.col[1] + r.col[1]],
271        }
272    }
273
274    /// Subtracts two matrices element-wise.
275    ///
276    /// ```text
277    /// C[i,j] = A[i,j] - B[i,j]
278    /// ```
279    pub fn sub_matrix_matrix(l: &Self, r: &Self) -> Self {
280        Matrix2 {
281            col: [l.col[0] - r.col[0], l.col[1] - r.col[1]],
282        }
283    }
284}
285
286/******************************************************************************
287 * Matrix3
288 *
289 * i j -------------------------------->
290 * | [m0 = c0_x | m3 = c1_x | m6 = c2_x]
291 * | [m1 = c0_y | m4 = c1_y | m7 = c2_y]
292 * V [m2 = c0_z | m5 = c1_z | m8 = c2_z]
293 *
294 *  aij => i = row, j = col (yx form)
295 *
296 *****************************************************************************/
297impl<T: Scalar> Matrix3<T> {
298    /// Creates a new 3x3 matrix from column-major elements.
299    #[allow(clippy::too_many_arguments)]
300    pub fn new(m0: T, m1: T, m2: T, m3: T, m4: T, m5: T, m6: T, m7: T, m8: T) -> Self {
301        Matrix3 {
302            col: [
303                Vector3::new(m0, m1, m2),
304                Vector3::new(m3, m4, m5),
305                Vector3::new(m6, m7, m8),
306            ],
307        }
308    }
309
310    /// Returns the 3x3 identity matrix.
311    ///
312    /// ```text
313    /// [1 0 0]
314    /// [0 1 0]
315    /// [0 0 1]
316    /// ```
317    pub fn identity() -> Self {
318        Self::new(
319            <T as One>::one(),
320            <T as Zero>::zero(),
321            <T as Zero>::zero(),
322            <T as Zero>::zero(),
323            <T as One>::one(),
324            <T as Zero>::zero(),
325            <T as Zero>::zero(),
326            <T as Zero>::zero(),
327            <T as One>::one(),
328        )
329    }
330
331    /// Attempts to cast each element into another numeric type.
332    pub fn try_cast<U>(&self) -> Option<Matrix3<U>>
333    where
334        T: num_traits::ToPrimitive,
335        U: Scalar + num_traits::NumCast,
336    {
337        Some(Matrix3::new(
338            num_traits::NumCast::from(self.col[0].x)?,
339            num_traits::NumCast::from(self.col[0].y)?,
340            num_traits::NumCast::from(self.col[0].z)?,
341            num_traits::NumCast::from(self.col[1].x)?,
342            num_traits::NumCast::from(self.col[1].y)?,
343            num_traits::NumCast::from(self.col[1].z)?,
344            num_traits::NumCast::from(self.col[2].x)?,
345            num_traits::NumCast::from(self.col[2].y)?,
346            num_traits::NumCast::from(self.col[2].z)?,
347        ))
348    }
349
350    /// Computes the determinant of the matrix.
351    pub fn determinant(&self) -> T {
352        let m00 = self.col[0].x;
353        let m10 = self.col[0].y;
354        let m20 = self.col[0].z;
355
356        let m01 = self.col[1].x;
357        let m11 = self.col[1].y;
358        let m21 = self.col[1].z;
359
360        let m02 = self.col[2].x;
361        let m12 = self.col[2].y;
362        let m22 = self.col[2].z;
363
364        m00 * m11 * m22 + m01 * m12 * m20 + m02 * m10 * m21
365            - m00 * m12 * m21
366            - m01 * m10 * m22
367            - m02 * m11 * m20
368    }
369
370    /// Returns the transpose of the matrix.
371    ///
372    /// ```text
373    /// Mᵀ[i,j] = M[j,i]
374    /// ```
375    pub fn transpose(&self) -> Self {
376        let m00 = self.col[0].x;
377        let m10 = self.col[0].y;
378        let m20 = self.col[0].z;
379
380        let m01 = self.col[1].x;
381        let m11 = self.col[1].y;
382        let m21 = self.col[1].z;
383
384        let m02 = self.col[2].x;
385        let m12 = self.col[2].y;
386        let m22 = self.col[2].z;
387
388        Self::new(m00, m01, m02, m10, m11, m12, m20, m21, m22)
389    }
390
391    /// Computes the inverse of the matrix.
392    ///
393    /// Uses the adjugate matrix method:
394    /// ```text
395    /// M⁻¹ = (1/det(M)) * adj(M)
396    /// ```
397    ///
398    /// # Note
399    /// Returns NaN or Inf if the matrix is singular (determinant = 0).
400    pub fn inverse(&self) -> Self
401    where
402        T: FloatScalar,
403    {
404        let m00 = self.col[0].x;
405        let m10 = self.col[0].y;
406        let m20 = self.col[0].z;
407
408        let m01 = self.col[1].x;
409        let m11 = self.col[1].y;
410        let m21 = self.col[1].z;
411
412        let m02 = self.col[2].x;
413        let m12 = self.col[2].y;
414        let m22 = self.col[2].z;
415
416        let inv_det = <T as One>::one()
417            / (m00 * m11 * m22 + m01 * m12 * m20 + m02 * m10 * m21
418                - m00 * m12 * m21
419                - m01 * m10 * m22
420                - m02 * m11 * m20);
421
422        let r00 = (m11 * m22 - m12 * m21) * inv_det;
423        let r01 = (m02 * m21 - m01 * m22) * inv_det;
424        let r02 = (m01 * m12 - m02 * m11) * inv_det;
425        let r10 = (m12 * m20 - m10 * m22) * inv_det;
426        let r11 = (m00 * m22 - m02 * m20) * inv_det;
427        let r12 = (m02 * m10 - m00 * m12) * inv_det;
428        let r20 = (m10 * m21 - m11 * m20) * inv_det;
429        let r21 = (m01 * m20 - m00 * m21) * inv_det;
430        let r22 = (m00 * m11 - m01 * m10) * inv_det;
431
432        Self::new(r00, r10, r20, r01, r11, r21, r02, r12, r22)
433    }
434
435    /// Multiplies two 3x3 matrices.
436    ///
437    /// Matrix multiplication follows the rule:
438    /// ```text
439    /// C[i,j] = Σₖ A[i,k] * B[k,j]
440    /// ```
441    pub fn mul_matrix_matrix(l: &Self, r: &Self) -> Self {
442        let a00 = l.col[0].x;
443        let a10 = l.col[0].y;
444        let a20 = l.col[0].z;
445
446        let a01 = l.col[1].x;
447        let a11 = l.col[1].y;
448        let a21 = l.col[1].z;
449
450        let a02 = l.col[2].x;
451        let a12 = l.col[2].y;
452        let a22 = l.col[2].z;
453
454        let b00 = r.col[0].x;
455        let b10 = r.col[0].y;
456        let b20 = r.col[0].z;
457
458        let b01 = r.col[1].x;
459        let b11 = r.col[1].y;
460        let b21 = r.col[1].z;
461
462        let b02 = r.col[2].x;
463        let b12 = r.col[2].y;
464        let b22 = r.col[2].z;
465
466        let c00 = a00 * b00 + a01 * b10 + a02 * b20;
467        let c01 = a00 * b01 + a01 * b11 + a02 * b21;
468        let c02 = a00 * b02 + a01 * b12 + a02 * b22;
469
470        let c10 = a10 * b00 + a11 * b10 + a12 * b20;
471        let c11 = a10 * b01 + a11 * b11 + a12 * b21;
472        let c12 = a10 * b02 + a11 * b12 + a12 * b22;
473
474        let c20 = a20 * b00 + a21 * b10 + a22 * b20;
475        let c21 = a20 * b01 + a21 * b11 + a22 * b21;
476        let c22 = a20 * b02 + a21 * b12 + a22 * b22;
477
478        Self::new(c00, c10, c20, c01, c11, c21, c02, c12, c22)
479    }
480
481    /// Multiplies a 3x3 matrix by a 3D vector.
482    ///
483    /// Transforms the vector by the matrix:
484    /// ```text
485    /// v' = M * v
486    /// ```
487    pub fn mul_matrix_vector(l: &Self, r: &Vector3<T>) -> Vector3<T> {
488        Self::mul_vector_matrix(r, &l.transpose())
489    }
490
491    /// Multiplies a 3D vector by a 3x3 matrix (row vector).
492    ///
493    /// ```text
494    /// v' = vᵀ * M
495    /// ```
496    pub fn mul_vector_matrix(l: &Vector3<T>, r: &Self) -> Vector3<T> {
497        Vector3::new(
498            Vector3::dot(l, &r.col[0]),
499            Vector3::dot(l, &r.col[1]),
500            Vector3::dot(l, &r.col[2]),
501        )
502    }
503
504    /// Adds two matrices element-wise.
505    ///
506    /// ```text
507    /// C[i,j] = A[i,j] + B[i,j]
508    /// ```
509    pub fn add_matrix_matrix(l: &Self, r: &Self) -> Self {
510        Matrix3 {
511            col: [
512                l.col[0] + r.col[0],
513                l.col[1] + r.col[1],
514                l.col[2] + r.col[2],
515            ],
516        }
517    }
518
519    /// Subtracts two matrices element-wise.
520    ///
521    /// ```text
522    /// C[i,j] = A[i,j] - B[i,j]
523    /// ```
524    pub fn sub_matrix_matrix(l: &Self, r: &Self) -> Self {
525        Matrix3 {
526            col: [
527                l.col[0] - r.col[0],
528                l.col[1] - r.col[1],
529                l.col[2] - r.col[2],
530            ],
531        }
532    }
533}
534
535impl<T: FloatScalar> Matrix3<T> {
536    /// Creates a 3x3 rotation matrix from an axis and angle.
537    ///
538    /// Uses Rodrigues' rotation formula:
539    /// ```text
540    /// R = I + sin(θ)K + (1 - cos(θ))K²
541    /// ```
542    /// where K is the cross-product matrix of the normalized axis.
543    ///
544    /// # Parameters
545    /// - `axis`: The rotation axis (will be normalized)
546    /// - `angle`: The rotation angle in radians
547    /// - `epsilon`: Minimum axis length to treat as valid
548    ///
549    /// # Returns
550    /// - `Some(matrix)` for a valid axis
551    /// - `None` if the axis length is too small
552    pub fn of_axis_angle(axis: &Vector3<T>, angle: T, epsilon: T) -> Option<Self> {
553        let len_sq = Vector3::dot(axis, axis);
554        if len_sq <= epsilon * epsilon {
555            return None;
556        }
557        let inv_len = <T as One>::one() / len_sq.tsqrt();
558        let n = *axis * inv_len;
559        let c = T::tcos(angle);
560        let s = T::tsin(angle);
561        let ux = n.x;
562        let uy = n.y;
563        let uz = n.z;
564        let uxx = ux * ux;
565        let uyy = uy * uy;
566        let uzz = uz * uz;
567
568        let oc = <T as One>::one() - c;
569
570        let m0 = c + uxx * oc;
571        let m1 = uy * ux * oc + uz * s;
572        let m2 = uz * ux * oc - uy * s;
573
574        let m3 = ux * uy * oc - uz * s;
575        let m4 = c + uyy * oc;
576        let m5 = uz * uy * oc + ux * s;
577
578        let m6 = ux * uz * oc + uy * s;
579        let m7 = uy * uz * oc - ux * s;
580        let m8 = c + uzz * oc;
581
582        Some(Self::new(m0, m1, m2, m3, m4, m5, m6, m7, m8))
583    }
584}
585
/******************************************************************************
 * Matrix4
 *
 * i j -------------------------------------------->
 * | [m0 = c0_x | m4 = c1_x | m8 = c2_x | m12= c3_x]
 * | [m1 = c0_y | m5 = c1_y | m9 = c2_y | m13= c3_y]
 * | [m2 = c0_z | m6 = c1_z | m10= c2_z | m14= c3_z]
 * V [m3 = c0_w | m7 = c1_w | m11= c2_w | m15= c3_w]
 *
 *  aij => i = row, j = col (yx form)
 *
 *****************************************************************************/
598impl<T: Scalar> Matrix4<T> {
599    /// Creates a new 4x4 matrix from column-major elements.
600    #[allow(clippy::too_many_arguments)]
601    pub fn new(
602        m0: T,
603        m1: T,
604        m2: T,
605        m3: T,
606        m4: T,
607        m5: T,
608        m6: T,
609        m7: T,
610        m8: T,
611        m9: T,
612        m10: T,
613        m11: T,
614        m12: T,
615        m13: T,
616        m14: T,
617        m15: T,
618    ) -> Self {
619        Matrix4 {
620            col: [
621                Vector4::new(m0, m1, m2, m3),
622                Vector4::new(m4, m5, m6, m7),
623                Vector4::new(m8, m9, m10, m11),
624                Vector4::new(m12, m13, m14, m15),
625            ],
626        }
627    }
628
629    /// Returns the 4x4 identity matrix.
630    pub fn identity() -> Self {
631        Self::new(
632            <T as One>::one(),
633            <T as Zero>::zero(),
634            <T as Zero>::zero(),
635            <T as Zero>::zero(),
636            <T as Zero>::zero(),
637            <T as One>::one(),
638            <T as Zero>::zero(),
639            <T as Zero>::zero(),
640            <T as Zero>::zero(),
641            <T as Zero>::zero(),
642            <T as One>::one(),
643            <T as Zero>::zero(),
644            <T as Zero>::zero(),
645            <T as Zero>::zero(),
646            <T as Zero>::zero(),
647            <T as One>::one(),
648        )
649    }
650
651    /// Attempts to cast each element into another numeric type.
652    pub fn try_cast<U>(&self) -> Option<Matrix4<U>>
653    where
654        T: num_traits::ToPrimitive,
655        U: Scalar + num_traits::NumCast,
656    {
657        Some(Matrix4::new(
658            num_traits::NumCast::from(self.col[0].x)?,
659            num_traits::NumCast::from(self.col[0].y)?,
660            num_traits::NumCast::from(self.col[0].z)?,
661            num_traits::NumCast::from(self.col[0].w)?,
662            num_traits::NumCast::from(self.col[1].x)?,
663            num_traits::NumCast::from(self.col[1].y)?,
664            num_traits::NumCast::from(self.col[1].z)?,
665            num_traits::NumCast::from(self.col[1].w)?,
666            num_traits::NumCast::from(self.col[2].x)?,
667            num_traits::NumCast::from(self.col[2].y)?,
668            num_traits::NumCast::from(self.col[2].z)?,
669            num_traits::NumCast::from(self.col[2].w)?,
670            num_traits::NumCast::from(self.col[3].x)?,
671            num_traits::NumCast::from(self.col[3].y)?,
672            num_traits::NumCast::from(self.col[3].z)?,
673            num_traits::NumCast::from(self.col[3].w)?,
674        ))
675    }
676
    /// Computes the determinant of the matrix.
    ///
    /// Evaluated as the fully expanded 24-term sum, with one signed product
    /// per permutation. The terms are grouped in blocks of six that share a
    /// factor from the bottom row (`m30`, `m31`, `m32`, `m33`), i.e. a Laplace
    /// expansion along the last row with every 3x3 minor written out.
    ///
    /// A non-zero determinant indicates the matrix is invertible.
    pub fn determinant(&self) -> T {
        // mIJ = element at row I, column J (column-major: col[J].{x,y,z,w}).
        let m00 = self.col[0].x;
        let m10 = self.col[0].y;
        let m20 = self.col[0].z;
        let m30 = self.col[0].w;

        let m01 = self.col[1].x;
        let m11 = self.col[1].y;
        let m21 = self.col[1].z;
        let m31 = self.col[1].w;

        let m02 = self.col[2].x;
        let m12 = self.col[2].y;
        let m22 = self.col[2].z;
        let m32 = self.col[2].w;

        let m03 = self.col[3].x;
        let m13 = self.col[3].y;
        let m23 = self.col[3].z;
        let m33 = self.col[3].w;

        m03 * m12 * m21 * m30 - m02 * m13 * m21 * m30 - m03 * m11 * m22 * m30
            + m01 * m13 * m22 * m30
            + m02 * m11 * m23 * m30
            - m01 * m12 * m23 * m30
            - m03 * m12 * m20 * m31
            + m02 * m13 * m20 * m31
            + m03 * m10 * m22 * m31
            - m00 * m13 * m22 * m31
            - m02 * m10 * m23 * m31
            + m00 * m12 * m23 * m31
            + m03 * m11 * m20 * m32
            - m01 * m13 * m20 * m32
            - m03 * m10 * m21 * m32
            + m00 * m13 * m21 * m32
            + m01 * m10 * m23 * m32
            - m00 * m11 * m23 * m32
            - m02 * m11 * m20 * m33
            + m01 * m12 * m20 * m33
            + m02 * m10 * m21 * m33
            - m00 * m12 * m21 * m33
            - m01 * m10 * m22 * m33
            + m00 * m11 * m22 * m33
    }
725
726    /// Returns the transpose of the matrix.
727    ///
728    /// ```text
729    /// Mᵀ[i,j] = M[j,i]
730    /// ```
731    pub fn transpose(&self) -> Self {
732        let m00 = self.col[0].x;
733        let m10 = self.col[0].y;
734        let m20 = self.col[0].z;
735        let m30 = self.col[0].w;
736
737        let m01 = self.col[1].x;
738        let m11 = self.col[1].y;
739        let m21 = self.col[1].z;
740        let m31 = self.col[1].w;
741
742        let m02 = self.col[2].x;
743        let m12 = self.col[2].y;
744        let m22 = self.col[2].z;
745        let m32 = self.col[2].w;
746
747        let m03 = self.col[3].x;
748        let m13 = self.col[3].y;
749        let m23 = self.col[3].z;
750        let m33 = self.col[3].w;
751
752        Self::new(
753            m00, m01, m02, m03, m10, m11, m12, m13, m20, m21, m22, m23, m30, m31, m32, m33,
754        )
755    }
756
757    /// Returns true if the matrix is affine (last row equals [0, 0, 0, 1]).
758    pub fn is_affine(&self, epsilon: T) -> bool {
759        self.col[0].w.tabs() <= epsilon
760            && self.col[1].w.tabs() <= epsilon
761            && self.col[2].w.tabs() <= epsilon
762            && (self.col[3].w - <T as One>::one()).tabs() <= epsilon
763    }
764
765    /// Computes the inverse of the matrix.
766    ///
767    /// Uses the adjugate matrix method:
768    /// ```text
769    /// M⁻¹ = (1/det(M)) * adj(M)
770    /// ```
771    ///
772    /// # Note
773    /// Returns NaN or Inf if the matrix is singular (determinant = 0).
774    pub fn inverse(&self) -> Self
775    where
776        T: FloatScalar,
777    {
778        let m00 = self.col[0].x;
779        let m10 = self.col[0].y;
780        let m20 = self.col[0].z;
781        let m30 = self.col[0].w;
782
783        let m01 = self.col[1].x;
784        let m11 = self.col[1].y;
785        let m21 = self.col[1].z;
786        let m31 = self.col[1].w;
787
788        let m02 = self.col[2].x;
789        let m12 = self.col[2].y;
790        let m22 = self.col[2].z;
791        let m32 = self.col[2].w;
792
793        let m03 = self.col[3].x;
794        let m13 = self.col[3].y;
795        let m23 = self.col[3].z;
796        let m33 = self.col[3].w;
797
798        let denom = m03 * m12 * m21 * m30 - m02 * m13 * m21 * m30 - m03 * m11 * m22 * m30
799            + m01 * m13 * m22 * m30
800            + m02 * m11 * m23 * m30
801            - m01 * m12 * m23 * m30
802            - m03 * m12 * m20 * m31
803            + m02 * m13 * m20 * m31
804            + m03 * m10 * m22 * m31
805            - m00 * m13 * m22 * m31
806            - m02 * m10 * m23 * m31
807            + m00 * m12 * m23 * m31
808            + m03 * m11 * m20 * m32
809            - m01 * m13 * m20 * m32
810            - m03 * m10 * m21 * m32
811            + m00 * m13 * m21 * m32
812            + m01 * m10 * m23 * m32
813            - m00 * m11 * m23 * m32
814            - m02 * m11 * m20 * m33
815            + m01 * m12 * m20 * m33
816            + m02 * m10 * m21 * m33
817            - m00 * m12 * m21 * m33
818            - m01 * m10 * m22 * m33
819            + m00 * m11 * m22 * m33;
820        let inv_det = <T as One>::one() / denom;
821
822        let r00 = (m12 * m23 * m31 - m13 * m22 * m31 + m13 * m21 * m32
823            - m11 * m23 * m32
824            - m12 * m21 * m33
825            + m11 * m22 * m33)
826            * inv_det;
827
828        let r01 = (m03 * m22 * m31 - m02 * m23 * m31 - m03 * m21 * m32
829            + m01 * m23 * m32
830            + m02 * m21 * m33
831            - m01 * m22 * m33)
832            * inv_det;
833
834        let r02 = (m02 * m13 * m31 - m03 * m12 * m31 + m03 * m11 * m32
835            - m01 * m13 * m32
836            - m02 * m11 * m33
837            + m01 * m12 * m33)
838            * inv_det;
839
840        let r03 = (m03 * m12 * m21 - m02 * m13 * m21 - m03 * m11 * m22
841            + m01 * m13 * m22
842            + m02 * m11 * m23
843            - m01 * m12 * m23)
844            * inv_det;
845
846        let r10 = (m13 * m22 * m30 - m12 * m23 * m30 - m13 * m20 * m32
847            + m10 * m23 * m32
848            + m12 * m20 * m33
849            - m10 * m22 * m33)
850            * inv_det;
851
852        let r11 = (m02 * m23 * m30 - m03 * m22 * m30 + m03 * m20 * m32
853            - m00 * m23 * m32
854            - m02 * m20 * m33
855            + m00 * m22 * m33)
856            * inv_det;
857
858        let r12 = (m03 * m12 * m30 - m02 * m13 * m30 - m03 * m10 * m32
859            + m00 * m13 * m32
860            + m02 * m10 * m33
861            - m00 * m12 * m33)
862            * inv_det;
863
864        let r13 = (m02 * m13 * m20 - m03 * m12 * m20 + m03 * m10 * m22
865            - m00 * m13 * m22
866            - m02 * m10 * m23
867            + m00 * m12 * m23)
868            * inv_det;
869
870        let r20 = (m11 * m23 * m30 - m13 * m21 * m30 + m13 * m20 * m31
871            - m10 * m23 * m31
872            - m11 * m20 * m33
873            + m10 * m21 * m33)
874            * inv_det;
875
876        let r21 = (m03 * m21 * m30 - m01 * m23 * m30 - m03 * m20 * m31
877            + m00 * m23 * m31
878            + m01 * m20 * m33
879            - m00 * m21 * m33)
880            * inv_det;
881
882        let r22 = (m01 * m13 * m30 - m03 * m11 * m30 + m03 * m10 * m31
883            - m00 * m13 * m31
884            - m01 * m10 * m33
885            + m00 * m11 * m33)
886            * inv_det;
887
888        let r23 = (m03 * m11 * m20 - m01 * m13 * m20 - m03 * m10 * m21
889            + m00 * m13 * m21
890            + m01 * m10 * m23
891            - m00 * m11 * m23)
892            * inv_det;
893
894        let r30 = (m12 * m21 * m30 - m11 * m22 * m30 - m12 * m20 * m31
895            + m10 * m22 * m31
896            + m11 * m20 * m32
897            - m10 * m21 * m32)
898            * inv_det;
899
900        let r31 = (m01 * m22 * m30 - m02 * m21 * m30 + m02 * m20 * m31
901            - m00 * m22 * m31
902            - m01 * m20 * m32
903            + m00 * m21 * m32)
904            * inv_det;
905
906        let r32 = (m02 * m11 * m30 - m01 * m12 * m30 - m02 * m10 * m31
907            + m00 * m12 * m31
908            + m01 * m10 * m32
909            - m00 * m11 * m32)
910            * inv_det;
911
912        let r33 = (m01 * m12 * m20 - m02 * m11 * m20 + m02 * m10 * m21
913            - m00 * m12 * m21
914            - m01 * m10 * m22
915            + m00 * m11 * m22)
916            * inv_det;
917
918        Self::new(
919            r00, r10, r20, r30, r01, r11, r21, r31, r02, r12, r22, r32, r03, r13, r23, r33,
920        )
921    }
922
923    /// Computes the inverse of an affine matrix (rotation/scale + translation).
924    ///
925    /// Assumes the last row is `[0, 0, 0, 1]`.
926    pub fn inverse_affine(&self) -> Self
927    where
928        T: FloatScalar,
929    {
930        let rot = Matrix3::new(
931            self.col[0].x,
932            self.col[0].y,
933            self.col[0].z,
934            self.col[1].x,
935            self.col[1].y,
936            self.col[1].z,
937            self.col[2].x,
938            self.col[2].y,
939            self.col[2].z,
940        );
941        let rot_inv = rot.inverse();
942        let trans = Vector3::new(self.col[3].x, self.col[3].y, self.col[3].z);
943        let trans_inv = -(rot_inv * trans);
944
945        Self::new(
946            rot_inv.col[0].x,
947            rot_inv.col[0].y,
948            rot_inv.col[0].z,
949            <T as Zero>::zero(),
950            rot_inv.col[1].x,
951            rot_inv.col[1].y,
952            rot_inv.col[1].z,
953            <T as Zero>::zero(),
954            rot_inv.col[2].x,
955            rot_inv.col[2].y,
956            rot_inv.col[2].z,
957            <T as Zero>::zero(),
958            trans_inv.x,
959            trans_inv.y,
960            trans_inv.z,
961            <T as One>::one(),
962        )
963    }
964
965    /// Multiplies two 4x4 matrices.
966    ///
967    /// Matrix multiplication follows the rule:
968    /// ```text
969    /// C[i,j] = Σₖ A[i,k] * B[k,j]
970    /// ```
971    pub fn mul_matrix_matrix(l: &Self, r: &Self) -> Self {
972        let a00 = l.col[0].x;
973        let a10 = l.col[0].y;
974        let a20 = l.col[0].z;
975        let a30 = l.col[0].w;
976
977        let a01 = l.col[1].x;
978        let a11 = l.col[1].y;
979        let a21 = l.col[1].z;
980        let a31 = l.col[1].w;
981
982        let a02 = l.col[2].x;
983        let a12 = l.col[2].y;
984        let a22 = l.col[2].z;
985        let a32 = l.col[2].w;
986
987        let a03 = l.col[3].x;
988        let a13 = l.col[3].y;
989        let a23 = l.col[3].z;
990        let a33 = l.col[3].w;
991
992        let b00 = r.col[0].x;
993        let b10 = r.col[0].y;
994        let b20 = r.col[0].z;
995        let b30 = r.col[0].w;
996
997        let b01 = r.col[1].x;
998        let b11 = r.col[1].y;
999        let b21 = r.col[1].z;
1000        let b31 = r.col[1].w;
1001
1002        let b02 = r.col[2].x;
1003        let b12 = r.col[2].y;
1004        let b22 = r.col[2].z;
1005        let b32 = r.col[2].w;
1006
1007        let b03 = r.col[3].x;
1008        let b13 = r.col[3].y;
1009        let b23 = r.col[3].z;
1010        let b33 = r.col[3].w;
1011
1012        let c00 = a00 * b00 + a01 * b10 + a02 * b20 + a03 * b30;
1013        let c01 = a00 * b01 + a01 * b11 + a02 * b21 + a03 * b31;
1014        let c02 = a00 * b02 + a01 * b12 + a02 * b22 + a03 * b32;
1015        let c03 = a00 * b03 + a01 * b13 + a02 * b23 + a03 * b33;
1016
1017        let c10 = a10 * b00 + a11 * b10 + a12 * b20 + a13 * b30;
1018        let c11 = a10 * b01 + a11 * b11 + a12 * b21 + a13 * b31;
1019        let c12 = a10 * b02 + a11 * b12 + a12 * b22 + a13 * b32;
1020        let c13 = a10 * b03 + a11 * b13 + a12 * b23 + a13 * b33;
1021
1022        let c20 = a20 * b00 + a21 * b10 + a22 * b20 + a23 * b30;
1023        let c21 = a20 * b01 + a21 * b11 + a22 * b21 + a23 * b31;
1024        let c22 = a20 * b02 + a21 * b12 + a22 * b22 + a23 * b32;
1025        let c23 = a20 * b03 + a21 * b13 + a22 * b23 + a23 * b33;
1026
1027        let c30 = a30 * b00 + a31 * b10 + a32 * b20 + a33 * b30;
1028        let c31 = a30 * b01 + a31 * b11 + a32 * b21 + a33 * b31;
1029        let c32 = a30 * b02 + a31 * b12 + a32 * b22 + a33 * b32;
1030        let c33 = a30 * b03 + a31 * b13 + a32 * b23 + a33 * b33;
1031
1032        Self::new(
1033            c00, c10, c20, c30, c01, c11, c21, c31, c02, c12, c22, c32, c03, c13, c23, c33,
1034        )
1035    }
1036
1037    /// Multiplies a 4x4 matrix by a 4D vector.
1038    ///
1039    /// ```text
1040    /// v' = M * v
1041    /// ```
1042    pub fn mul_matrix_vector(l: &Self, r: &Vector4<T>) -> Vector4<T> {
1043        Self::mul_vector_matrix(r, &l.transpose())
1044    }
1045
1046    //
1047    //                     [m0 = c0_x | m4 = c1_x | m8 = c2_x | m12= c3_x]
1048    // [v_x v_y v_z v_w] * [m1 = c0_y | m5 = c1_y | m9 = c2_y | m13= c3_y] = [dot(v, c0) dot(v, c1) dot(v, c2) dot(v, c3)]
1049    //                     [m2 = c0_z | m6 = c1_z | m10= c2_z | m14= c3_z]
1050    //                     [m3 = c0_w | m7 = c1_w | m11= c2_w | m15= c3_w]
1051    //
1052    /// Multiplies a 4D vector by a 4x4 matrix (row vector).
1053    ///
1054    /// ```text
1055    /// v' = vᵀ * M
1056    /// ```
1057    pub fn mul_vector_matrix(l: &Vector4<T>, r: &Self) -> Vector4<T> {
1058        Vector4::new(
1059            Vector4::dot(l, &r.col[0]),
1060            Vector4::dot(l, &r.col[1]),
1061            Vector4::dot(l, &r.col[2]),
1062            Vector4::dot(l, &r.col[3]),
1063        )
1064    }
1065
1066    /// Adds two matrices element-wise.
1067    ///
1068    /// ```text
1069    /// C[i,j] = A[i,j] + B[i,j]
1070    /// ```
1071    pub fn add_matrix_matrix(l: &Self, r: &Self) -> Self {
1072        Matrix4 {
1073            col: [
1074                l.col[0] + r.col[0],
1075                l.col[1] + r.col[1],
1076                l.col[2] + r.col[2],
1077                l.col[3] + r.col[3],
1078            ],
1079        }
1080    }
1081
1082    /// Subtracts two matrices element-wise.
1083    ///
1084    /// ```text
1085    /// C[i,j] = A[i,j] - B[i,j]
1086    /// ```
1087    pub fn sub_matrix_matrix(l: &Self, r: &Self) -> Self {
1088        Matrix4 {
1089            col: [
1090                l.col[0] - r.col[0],
1091                l.col[1] - r.col[1],
1092                l.col[2] - r.col[2],
1093                l.col[3] - r.col[3],
1094            ],
1095        }
1096    }
1097}
1098
1099/******************************************************************************
1100 * Operator overloading
1101 *****************************************************************************/
1102macro_rules! implMatrixOps {
1103    ($mat:ident, $vec: ident) => {
1104        impl<T: Scalar> Mul<$mat<T>> for $vec<T> {
1105            type Output = $vec<T>;
1106            fn mul(self, rhs: $mat<T>) -> $vec<T> {
1107                $mat::mul_vector_matrix(&self, &rhs)
1108            }
1109        }
1110
1111        impl<T: Scalar> Mul<$vec<T>> for $mat<T> {
1112            type Output = $vec<T>;
1113            fn mul(self, rhs: $vec<T>) -> $vec<T> {
1114                $mat::mul_matrix_vector(&self, &rhs)
1115            }
1116        }
1117
1118        impl<T: Scalar> Mul<$mat<T>> for $mat<T> {
1119            type Output = $mat<T>;
1120            fn mul(self, rhs: $mat<T>) -> $mat<T> {
1121                $mat::mul_matrix_matrix(&self, &rhs)
1122            }
1123        }
1124
1125        impl<T: Scalar> Add<$mat<T>> for $mat<T> {
1126            type Output = $mat<T>;
1127            fn add(self, rhs: $mat<T>) -> $mat<T> {
1128                $mat::add_matrix_matrix(&self, &rhs)
1129            }
1130        }
1131
1132        impl<T: Scalar> Sub<$mat<T>> for $mat<T> {
1133            type Output = $mat<T>;
1134            fn sub(self, rhs: $mat<T>) -> $mat<T> {
1135                $mat::sub_matrix_matrix(&self, &rhs)
1136            }
1137        }
1138    };
1139}
1140
1141implMatrixOps!(Matrix2, Vector2);
1142implMatrixOps!(Matrix3, Vector3);
1143implMatrixOps!(Matrix4, Vector4);
1144
1145impl<T: Scalar> Mul<Matrix4<T>> for Vector3<T> {
1146    type Output = Vector3<T>;
1147    fn mul(self, rhs: Matrix4<T>) -> Vector3<T> {
1148        Matrix4::mul_vector_matrix(
1149            &Vector4::new(self.x, self.y, self.z, <T as One>::one()),
1150            &rhs,
1151        )
1152        .xyz()
1153    }
1154}
1155
1156impl<T: Scalar> Mul<Vector3<T>> for Matrix4<T> {
1157    type Output = Vector3<T>;
1158    fn mul(self, rhs: Vector3<T>) -> Vector3<T> {
1159        Matrix4::mul_matrix_vector(&self, &Vector4::new(rhs.x, rhs.y, rhs.z, <T as One>::one()))
1160            .xyz()
1161    }
1162}
1163
1164/// Convenience extension for extracting a 3x3 submatrix.
1165pub trait Matrix4Extension<T: Scalar> {
1166    /// Returns the upper-left 3x3 rotation/scale submatrix.
1167    fn mat3(&self) -> Matrix3<T>;
1168}
1169
1170impl<T: Scalar> Matrix4Extension<T> for Matrix4<T> {
1171    fn mat3(&self) -> Matrix3<T> {
1172        Matrix3::new(
1173            self.col[0].x,
1174            self.col[0].y,
1175            self.col[0].z,
1176            self.col[1].x,
1177            self.col[1].y,
1178            self.col[1].z,
1179            self.col[2].x,
1180            self.col[2].y,
1181            self.col[2].z,
1182        )
1183    }
1184}
1185
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the entry at (`row`, `col`) of a column-major 3x3 matrix.
    fn m3_at(m: &Matrix3<f32>, row: usize, col: usize) -> f32 {
        let c = &m.col[col];
        match row {
            0 => c.x,
            1 => c.y,
            2 => c.z,
            _ => panic!("row index {} out of range for a 3x3 matrix", row),
        }
    }

    /// Reads the entry at (`row`, `col`) of a column-major 4x4 matrix.
    fn m4_at(m: &Matrix4<f32>, row: usize, col: usize) -> f32 {
        let c = &m.col[col];
        match row {
            0 => c.x,
            1 => c.y,
            2 => c.z,
            3 => c.w,
            _ => panic!("row index {} out of range for a 4x4 matrix", row),
        }
    }

    /// Asserts that every entry of `actual` is within `eps` of `expected`.
    fn assert_mat4_near(actual: &Matrix4<f32>, expected: &Matrix4<f32>, eps: f32) {
        for i in 0..4 {
            for j in 0..4 {
                let got = m4_at(actual, i, j);
                let want = m4_at(expected, i, j);
                assert!(
                    (got - want).abs() < eps,
                    "mismatch at [{}, {}]: expected {}, got {}",
                    i,
                    j,
                    want,
                    got
                );
            }
        }
    }

    /// Builds a 4x4 matrix whose column-major storage is `start, start + 1, ..., start + 15`.
    /// Entry (row `i`, col `j`) therefore equals `start + 4 * j + i`.
    fn seq4(start: f32) -> Matrix4<f32> {
        Matrix4::new(
            start,
            start + 1.0,
            start + 2.0,
            start + 3.0,
            start + 4.0,
            start + 5.0,
            start + 6.0,
            start + 7.0,
            start + 8.0,
            start + 9.0,
            start + 10.0,
            start + 11.0,
            start + 12.0,
            start + 13.0,
            start + 14.0,
            start + 15.0,
        )
    }

    /// Builds an affine matrix that translates points by `(x, y, z)`.
    fn translation4(x: f32, y: f32, z: f32) -> Matrix4<f32> {
        Matrix4::new(
            1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, x, y, z, 1.0,
        )
    }

    #[test]
    fn test_matrix2_identity() {
        let m = Matrix2::<f32>::identity();
        assert_eq!(m.col[0].x, 1.0);
        assert_eq!(m.col[0].y, 0.0);
        assert_eq!(m.col[1].x, 0.0);
        assert_eq!(m.col[1].y, 1.0);
    }

    #[test]
    fn test_matrix2_determinant() {
        let m = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let det = m.determinant();
        assert_eq!(det, -2.0); // 1*4 - 3*2 = -2

        // Test singular matrix
        let m_singular = Matrix2::<f32>::new(1.0, 2.0, 2.0, 4.0);
        let det_singular = m_singular.determinant();
        assert_eq!(det_singular, 0.0);
    }

    #[test]
    fn test_matrix2_inverse() {
        let m = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let m_inv = m.inverse();
        let product = Matrix2::mul_matrix_matrix(&m, &m_inv);

        // Check if product is identity
        assert!((product.col[0].x - 1.0).abs() < 0.001);
        assert!((product.col[0].y).abs() < 0.001);
        assert!((product.col[1].x).abs() < 0.001);
        assert!((product.col[1].y - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_matrix2_transpose() {
        let m = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let mt = m.transpose();
        assert_eq!(mt.col[0].x, 1.0);
        assert_eq!(mt.col[0].y, 3.0);
        assert_eq!(mt.col[1].x, 2.0);
        assert_eq!(mt.col[1].y, 4.0);

        // Transpose of transpose should be original
        let mtt = mt.transpose();
        assert_eq!(mtt.col[0].x, m.col[0].x);
        assert_eq!(mtt.col[0].y, m.col[0].y);
        assert_eq!(mtt.col[1].x, m.col[1].x);
        assert_eq!(mtt.col[1].y, m.col[1].y);
    }

    #[test]
    fn test_matrix3_identity() {
        let m = Matrix3::<f32>::identity();
        assert_eq!(m.col[0].x, 1.0);
        assert_eq!(m.col[1].y, 1.0);
        assert_eq!(m.col[2].z, 1.0);
        assert_eq!(m.col[0].y, 0.0);
        assert_eq!(m.col[0].z, 0.0);
    }

    #[test]
    fn test_matrix3_determinant() {
        let m = Matrix3::<f32>::new(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0);
        assert_eq!(m.determinant(), 1.0);

        let m2 = Matrix3::<f32>::new(2.0, 3.0, 1.0, 1.0, 0.0, 2.0, 1.0, 2.0, 1.0);
        let det = m2.determinant();
        assert!((det - -3.0).abs() < 0.001);
    }

    #[test]
    fn test_matrix3_inverse() {
        let m = Matrix3::<f32>::new(2.0, 3.0, 1.0, 1.0, 0.0, 2.0, 1.0, 2.0, 1.0);
        let m_inv = m.inverse();
        let product = Matrix3::mul_matrix_matrix(&m, &m_inv);

        // Check if product is close to identity
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((m3_at(&product, i, j) - expected).abs() < 0.001);
            }
        }
    }

    #[test]
    fn test_matrix3_axis_angle_zero_axis() {
        let axis = Vector3::<f32>::new(0.0, 0.0, 0.0);
        assert!(Matrix3::of_axis_angle(&axis, 1.0, EPS_F32).is_none());
    }

    #[test]
    fn test_matrix4_identity() {
        let m = Matrix4::<f32>::identity();
        for i in 0..4 {
            for j in 0..4 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_eq!(m4_at(&m, i, j), expected);
            }
        }
    }

    #[test]
    fn test_matrix_vector_multiplication() {
        // Test Matrix2 * Vector2
        let m2 = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let v2 = Vector2::<f32>::new(5.0, 6.0);
        let result2 = m2 * v2;
        assert_eq!(result2.x, 23.0); // 1*5 + 3*6 = 23
        assert_eq!(result2.y, 34.0); // 2*5 + 4*6 = 34

        // Test Matrix3 * Vector3
        let m3 = Matrix3::<f32>::identity();
        let v3 = Vector3::<f32>::new(1.0, 2.0, 3.0);
        let result3 = m3 * v3;
        assert_eq!(result3.x, 1.0);
        assert_eq!(result3.y, 2.0);
        assert_eq!(result3.z, 3.0);

        // Test Matrix4 * Vector4
        let m4 = Matrix4::<f32>::identity();
        let v4 = Vector4::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let result4 = m4 * v4;
        assert_eq!(result4.x, 1.0);
        assert_eq!(result4.y, 2.0);
        assert_eq!(result4.z, 3.0);
        assert_eq!(result4.w, 4.0);
    }

    #[test]
    fn test_matrix_multiplication() {
        // Test associativity: (A * B) * C == A * (B * C)
        let a = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let b = Matrix2::<f32>::new(5.0, 6.0, 7.0, 8.0);
        let c = Matrix2::<f32>::new(9.0, 10.0, 11.0, 12.0);

        let left = (a * b) * c;
        let right = a * (b * c);

        assert!((left.col[0].x - right.col[0].x).abs() < 0.001);
        assert!((left.col[0].y - right.col[0].y).abs() < 0.001);
        assert!((left.col[1].x - right.col[1].x).abs() < 0.001);
        assert!((left.col[1].y - right.col[1].y).abs() < 0.001);
    }

    #[test]
    fn test_matrix4_mul_matrix_matrix_matches_definition() {
        let a = seq4(1.0);
        let b = seq4(17.0);
        let c = Matrix4::mul_matrix_matrix(&a, &b);

        // Hand-computed entries:
        // C[0,0] = 1*17 + 5*18 + 9*19 + 13*20 = 538
        // C[3,3] = 4*29 + 8*30 + 12*31 + 16*32 = 1240
        assert_eq!(m4_at(&c, 0, 0), 538.0);
        assert_eq!(m4_at(&c, 3, 3), 1240.0);

        // Every entry matches C[i,j] = Σₖ A[i,k] * B[k,j]. All values are small
        // integers, so f32 arithmetic is exact regardless of summation order.
        for i in 0..4 {
            for j in 0..4 {
                let expected: f32 = (0..4).map(|k| m4_at(&a, i, k) * m4_at(&b, k, j)).sum();
                assert_eq!(m4_at(&c, i, j), expected, "mismatch at [{}, {}]", i, j);
            }
        }

        // The operator form agrees with the associated function.
        let via_operator = seq4(1.0) * seq4(17.0);
        assert_mat4_near(&via_operator, &c, 1e-6);

        // The identity is neutral on both sides.
        let id = Matrix4::<f32>::identity();
        assert_mat4_near(&Matrix4::mul_matrix_matrix(&a, &id), &a, 1e-6);
        assert_mat4_near(&Matrix4::mul_matrix_matrix(&id, &a), &a, 1e-6);
    }

    #[test]
    fn test_matrix4_vector_products() {
        let a = seq4(1.0);

        // Column-vector convention: M * e0 is the first column of M.
        let col0 = Matrix4::mul_matrix_vector(&a, &Vector4::new(1.0, 0.0, 0.0, 0.0));
        assert_eq!((col0.x, col0.y, col0.z, col0.w), (1.0, 2.0, 3.0, 4.0));

        // Row-vector convention: e0 * M picks the x component of each column (row 0).
        let row0 = Matrix4::mul_vector_matrix(&Vector4::new(1.0, 0.0, 0.0, 0.0), &a);
        assert_eq!((row0.x, row0.y, row0.z, row0.w), (1.0, 5.0, 9.0, 13.0));

        // Row i of A is (i+1, i+5, i+9, i+13); dotted with (1, 2, 3, 4) that gives 90 + 10i.
        let mv = Matrix4::mul_matrix_vector(&a, &Vector4::new(1.0, 2.0, 3.0, 4.0));
        assert_eq!((mv.x, mv.y, mv.z, mv.w), (90.0, 100.0, 110.0, 120.0));

        // M * v == v * Mᵀ.
        let vmt = Vector4::new(1.0, 2.0, 3.0, 4.0) * a.transpose();
        assert_eq!((vmt.x, vmt.y, vmt.z, vmt.w), (90.0, 100.0, 110.0, 120.0));
    }

    #[test]
    fn test_matrix4_vector3_products() {
        // Matrix4 * Vector3 treats the vector as a point (w = 1), so translation applies.
        let p = translation4(3.0, 4.0, 5.0) * Vector3::<f32>::new(1.0, 2.0, 3.0);
        assert_eq!((p.x, p.y, p.z), (4.0, 6.0, 8.0));

        // Vector3 * Matrix4 uses the row-vector convention, so the transposed
        // matrix is required to get the same translated point.
        let q = Vector3::<f32>::new(1.0, 2.0, 3.0) * translation4(3.0, 4.0, 5.0).transpose();
        assert_eq!((q.x, q.y, q.z), (4.0, 6.0, 8.0));
    }

    #[test]
    fn test_matrix_addition_subtraction() {
        let m1 = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let m2 = Matrix2::<f32>::new(5.0, 6.0, 7.0, 8.0);

        let sum = m1 + m2;
        assert_eq!(sum.col[0].x, 6.0);
        assert_eq!(sum.col[0].y, 8.0);
        assert_eq!(sum.col[1].x, 10.0);
        assert_eq!(sum.col[1].y, 12.0);

        let diff = m2 - m1;
        assert_eq!(diff.col[0].x, 4.0);
        assert_eq!(diff.col[0].y, 4.0);
        assert_eq!(diff.col[1].x, 4.0);
        assert_eq!(diff.col[1].y, 4.0);
    }

    #[test]
    fn test_matrix4_addition_subtraction() {
        let sum = seq4(1.0) + seq4(17.0);
        let diff = seq4(17.0) - seq4(1.0);
        for i in 0..4 {
            for j in 0..4 {
                // Both operands hold `start + slot` at storage slot 4*j + i, so the
                // sum is 18 + 2*slot and the difference is always 16.
                let slot = (4 * j + i) as f32;
                assert_eq!(m4_at(&sum, i, j), 18.0 + 2.0 * slot);
                assert_eq!(m4_at(&diff, i, j), 16.0);
            }
        }
    }

    #[test]
    fn test_matrix4_inverse() {
        // Test with a known invertible matrix
        let m = Matrix4::<f32>::new(
            2.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 1.0, 2.0, 3.0, 1.0,
        );

        let m_inv = m.inverse();
        let product = m * m_inv;

        // Check if product is close to identity
        assert_mat4_near(&product, &Matrix4::identity(), 0.001);
    }

    #[test]
    fn test_matrix4_inverse_affine() {
        let m = Matrix4::<f32>::new(
            0.0, 1.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 3.0, 4.0, 5.0, 1.0,
        );
        assert!(m.is_affine(EPS_F32));

        let inv_affine = m.inverse_affine();
        let inv_full = m.inverse();
        assert_mat4_near(&inv_affine, &inv_full, 0.001);

        let non_affine = Matrix4::<f32>::new(
            1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 0.0, 0.0,
        );
        assert!(!non_affine.is_affine(EPS_F32));
    }

    #[test]
    fn test_matrix4_inverse_affine_translation() {
        // Inverting a pure translation negates the offset and leaves R = I.
        let inv = translation4(3.0, 4.0, 5.0).inverse_affine();
        assert_mat4_near(&inv, &translation4(-3.0, -4.0, -5.0), 1e-6);

        // Round trip: the inverse maps the translated point back.
        let back = translation4(3.0, 4.0, 5.0).inverse_affine() * Vector3::<f32>::new(4.0, 6.0, 8.0);
        assert!((back.x - 1.0).abs() < 1e-6);
        assert!((back.y - 2.0).abs() < 1e-6);
        assert!((back.z - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_matrix_try_cast() {
        let mi = Matrix2::<i32>::new(1, 2, 3, 4);
        let mf = mi
            .try_cast::<f32>()
            .expect("integer matrix should cast to f32");
        assert_eq!(mf.col[0].x, 1.0);
        assert_eq!(mf.col[0].y, 2.0);
        assert_eq!(mf.col[1].x, 3.0);
        assert_eq!(mf.col[1].y, 4.0);
    }
}