rs_math3d/
matrix.rs

1// Copyright 2020-Present (c) Raja Lehtihet & Wael El Oraiby
2//
3// Redistribution and use in source and binary forms, with or without
4// modification, are permitted provided that the following conditions are met:
5//
6// 1. Redistributions of source code must retain the above copyright notice,
7// this list of conditions and the following disclaimer.
8//
9// 2. Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// 3. Neither the name of the copyright holder nor the names of its contributors
14// may be used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//! Matrix mathematics module providing 2x2, 3x3, and 4x4 matrices.
29//!
30//! This module provides square matrix types commonly used in computer graphics
31//! and linear algebra. Matrices are stored in column-major order for compatibility
32//! with graphics APIs like OpenGL.
33//!
34//! # Examples
35//!
36//! ```
37//! use rs_math3d::matrix::Matrix4;
38//! use rs_math3d::vector::Vector4;
39//! 
40//! let m = Matrix4::<f32>::identity();
41//! let v = Vector4::new(1.0, 2.0, 3.0, 1.0);
42//! let result = m * v; // Transform vector
43//! ```
44
45use crate::scalar::*;
46use crate::vector::*;
47use num_traits::{Zero, One};
48use core::ops::*;
49
50/// A 2x2 matrix stored in column-major order.
51///
52/// # Layout
53/// ```text
54/// [m₀₀ m₀₁]
55/// [m₁₀ m₁₁]
56/// ```
57/// where `col[j][i]` represents element at row i, column j.
58#[repr(C)]
59#[derive(Clone, Copy, Debug)]
60pub struct Matrix2<T: Scalar> {
61    /// Column vectors of the matrix
62    pub col: [Vector2<T>; 2],
63}
64
65/// A 3x3 matrix stored in column-major order.
66///
67/// Commonly used for 2D transformations (with homogeneous coordinates)
68/// and 3D rotations.
69///
70/// # Layout
71/// ```text
72/// [m₀₀ m₀₁ m₀₂]
73/// [m₁₀ m₁₁ m₁₂]
74/// [m₂₀ m₂₁ m₂₂]
75/// ```
76#[repr(C)]
77#[derive(Clone, Copy, Debug)]
78pub struct Matrix3<T: Scalar> {
79    /// Column vectors of the matrix
80    pub col: [Vector3<T>; 3],
81}
82
83/// A 4x4 matrix stored in column-major order.
84///
85/// The standard matrix for 3D transformations using homogeneous coordinates.
86///
87/// # Layout
88/// ```text
89/// [m₀₀ m₀₁ m₀₂ m₀₃]
90/// [m₁₀ m₁₁ m₁₂ m₁₃]
91/// [m₂₀ m₂₁ m₂₂ m₂₃]
92/// [m₃₀ m₃₁ m₃₂ m₃₃]
93/// ```
94#[repr(C)]
95#[derive(Clone, Copy, Debug)]
96pub struct Matrix4<T: Scalar> {
97    /// Column vectors of the matrix
98    pub col: [Vector4<T>; 4],
99}
100
101/******************************************************************************
102 * Matrix2
103 *
104 * i j ------------------->
105 * | [m0 = c0_x | m2 = c1_x]
106 * V [m1 = c0_y | m3 = c1_y]
107 *
108 *  aij => i = row, j = col (yx form)
109 *
110 *****************************************************************************/
111impl<T: Scalar> Matrix2<T> {
112    /// Creates a new 2x2 matrix from individual elements.
113    ///
114    /// Elements are provided in column-major order:
115    /// ```text
116    /// [m0 m2]
117    /// [m1 m3]
118    /// ```
119    pub fn new(m0: T, m1: T, m2: T, m3: T) -> Self {
120        Matrix2 {
121            col: [Vector2::new(m0, m1), Vector2::new(m2, m3)],
122        }
123    }
124
125    /// Returns the 2x2 identity matrix.
126    ///
127    /// ```text
128    /// [1 0]
129    /// [0 1]
130    /// ```
131    pub fn identity() -> Self {
132        Self::new(<T as One>::one(), <T as Zero>::zero(), <T as Zero>::zero(), <T as One>::one())
133    }
134
135    /// Computes the determinant of the matrix.
136    ///
137    /// For a 2x2 matrix:
138    /// ```text
139    /// det(M) = m₀₀m₁₁ - m₀₁m₁₀
140    /// ```
141    pub fn determinant(&self) -> T {
142        let m00 = self.col[0].x;
143        let m10 = self.col[0].y;
144
145        let m01 = self.col[1].x;
146        let m11 = self.col[1].y;
147
148        m00 * m11 - m01 * m10
149    }
150
151    /// Returns the transpose of the matrix.
152    ///
153    /// ```text
154    /// Mᵀ[i,j] = M[j,i]
155    /// ```
156    pub fn transpose(&self) -> Self {
157        let m00 = self.col[0].x;
158        let m10 = self.col[0].y;
159
160        let m01 = self.col[1].x;
161        let m11 = self.col[1].y;
162
163        Self::new(m00, m01, m10, m11)
164    }
165
166    /// Computes the inverse of the matrix.
167    ///
168    /// For a 2x2 matrix:
169    /// ```text
170    /// M⁻¹ = (1/det(M)) * [m₁₁  -m₀₁]
171    ///                    [-m₁₀  m₀₀]
172    /// ```
173    ///
174    /// # Note
175    /// Returns NaN or Inf if the matrix is singular (determinant = 0).
176    pub fn inverse(&self) -> Self {
177        let m00 = self.col[0].x;
178        let m10 = self.col[0].y;
179
180        let m01 = self.col[1].x;
181        let m11 = self.col[1].y;
182
183        let inv_det = <T as One>::one() / (m00 * m11 - m01 * m10);
184
185        let r00 = m11 * inv_det;
186        let r01 = -m01 * inv_det;
187        let r10 = -m10 * inv_det;
188        let r11 = m00 * inv_det;
189
190        Self::new(r00, r10, r01, r11)
191    }
192
193    /// Multiplies two 2x2 matrices.
194    ///
195    /// Matrix multiplication follows the rule:
196    /// ```text
197    /// C[i,j] = Σₖ A[i,k] * B[k,j]
198    /// ```
199    pub fn mul_matrix_matrix(l: &Self, r: &Self) -> Self {
200        let a00 = l.col[0].x;
201        let a10 = l.col[0].y;
202        let a01 = l.col[1].x;
203        let a11 = l.col[1].y;
204
205        let b00 = r.col[0].x;
206        let b10 = r.col[0].y;
207        let b01 = r.col[1].x;
208        let b11 = r.col[1].y;
209
210        let c00 = a00 * b00 + a01 * b10;
211        let c01 = a00 * b01 + a01 * b11;
212        let c10 = a10 * b00 + a11 * b10;
213        let c11 = a10 * b01 + a11 * b11;
214
215        Self::new(c00, c10, c01, c11)
216    }
217
218    /// Multiplies a 2x2 matrix by a 2D vector.
219    ///
220    /// Transforms the vector by the matrix:
221    /// ```text
222    /// v' = M * v
223    /// ```
224    pub fn mul_matrix_vector(l: &Self, r: &Vector2<T>) -> Vector2<T> {
225        Self::mul_vector_matrix(r, &l.transpose())
226    }
227
228    /// Multiplies a 2D vector by a 2x2 matrix (row vector).
229    ///
230    /// ```text
231    /// v' = vᵀ * M
232    /// ```
233    pub fn mul_vector_matrix(l: &Vector2<T>, r: &Self) -> Vector2<T> {
234        Vector2::new(Vector2::dot(l, &r.col[0]), Vector2::dot(l, &r.col[1]))
235    }
236
237    /// Adds two matrices element-wise.
238    ///
239    /// ```text
240    /// C[i,j] = A[i,j] + B[i,j]
241    /// ```
242    pub fn add_matrix_matrix(l: &Self, r: &Self) -> Self {
243        Matrix2 {
244            col: [l.col[0] + r.col[0], l.col[1] + r.col[1]],
245        }
246    }
247
248    /// Subtracts two matrices element-wise.
249    ///
250    /// ```text
251    /// C[i,j] = A[i,j] - B[i,j]
252    /// ```
253    pub fn sub_matrix_matrix(l: &Self, r: &Self) -> Self {
254        Matrix2 {
255            col: [l.col[0] - r.col[0], l.col[1] - r.col[1]],
256        }
257    }
258}
259
260/******************************************************************************
261 * Matrix3
262 *
263 * i j -------------------------------->
264 * | [m0 = c0_x | m3 = c1_x | m6 = c2_x]
265 * | [m1 = c0_y | m4 = c1_y | m7 = c2_y]
266 * V [m2 = c0_z | m5 = c1_z | m8 = c2_z]
267 *
268 *  aij => i = row, j = col (yx form)
269 *
270 *****************************************************************************/
271impl<T: Scalar> Matrix3<T> {
272    pub fn new(m0: T, m1: T, m2: T, m3: T, m4: T, m5: T, m6: T, m7: T, m8: T) -> Self {
273        Matrix3 {
274            col: [
275                Vector3::new(m0, m1, m2),
276                Vector3::new(m3, m4, m5),
277                Vector3::new(m6, m7, m8),
278            ],
279        }
280    }
281
282    pub fn identity() -> Self {
283        Self::new(
284            <T as One>::one(),
285            <T as Zero>::zero(),
286            <T as Zero>::zero(),
287            <T as Zero>::zero(),
288            <T as One>::one(),
289            <T as Zero>::zero(),
290            <T as Zero>::zero(),
291            <T as Zero>::zero(),
292            <T as One>::one(),
293        )
294    }
295
296    pub fn determinant(&self) -> T {
297        let m00 = self.col[0].x;
298        let m10 = self.col[0].y;
299        let m20 = self.col[0].z;
300
301        let m01 = self.col[1].x;
302        let m11 = self.col[1].y;
303        let m21 = self.col[1].z;
304
305        let m02 = self.col[2].x;
306        let m12 = self.col[2].y;
307        let m22 = self.col[2].z;
308
309        m00 * m11 * m22 + m01 * m12 * m20 + m02 * m10 * m21
310            - m00 * m12 * m21
311            - m01 * m10 * m22
312            - m02 * m11 * m20
313    }
314
315    /// Returns the transpose of the matrix.
316    ///
317    /// ```text
318    /// Mᵀ[i,j] = M[j,i]
319    /// ```
320    pub fn transpose(&self) -> Self {
321        let m00 = self.col[0].x;
322        let m10 = self.col[0].y;
323        let m20 = self.col[0].z;
324
325        let m01 = self.col[1].x;
326        let m11 = self.col[1].y;
327        let m21 = self.col[1].z;
328
329        let m02 = self.col[2].x;
330        let m12 = self.col[2].y;
331        let m22 = self.col[2].z;
332
333        Self::new(m00, m01, m02, m10, m11, m12, m20, m21, m22)
334    }
335
336    /// Computes the inverse of the matrix.
337    ///
338    /// Uses the adjugate matrix method:
339    /// ```text
340    /// M⁻¹ = (1/det(M)) * adj(M)
341    /// ```
342    ///
343    /// # Note
344    /// Returns NaN or Inf if the matrix is singular (determinant = 0).
345    pub fn inverse(&self) -> Self {
346        let m00 = self.col[0].x;
347        let m10 = self.col[0].y;
348        let m20 = self.col[0].z;
349
350        let m01 = self.col[1].x;
351        let m11 = self.col[1].y;
352        let m21 = self.col[1].z;
353
354        let m02 = self.col[2].x;
355        let m12 = self.col[2].y;
356        let m22 = self.col[2].z;
357
358        let inv_det = <T as One>::one()
359            / (m00 * m11 * m22 + m01 * m12 * m20 + m02 * m10 * m21
360                - m00 * m12 * m21
361                - m01 * m10 * m22
362                - m02 * m11 * m20);
363
364        let r00 = (m11 * m22 - m12 * m21) * inv_det;
365        let r01 = (m02 * m21 - m01 * m22) * inv_det;
366        let r02 = (m01 * m12 - m02 * m11) * inv_det;
367        let r10 = (m12 * m20 - m10 * m22) * inv_det;
368        let r11 = (m00 * m22 - m02 * m20) * inv_det;
369        let r12 = (m02 * m10 - m00 * m12) * inv_det;
370        let r20 = (m10 * m21 - m11 * m20) * inv_det;
371        let r21 = (m01 * m20 - m00 * m21) * inv_det;
372        let r22 = (m00 * m11 - m01 * m10) * inv_det;
373
374        Self::new(r00, r10, r20, r01, r11, r21, r02, r12, r22)
375    }
376
377    /// Multiplies two 3x3 matrices.
378    ///
379    /// Matrix multiplication follows the rule:
380    /// ```text
381    /// C[i,j] = Σₖ A[i,k] * B[k,j]
382    /// ```
383    pub fn mul_matrix_matrix(l: &Self, r: &Self) -> Self {
384        let a00 = l.col[0].x;
385        let a10 = l.col[0].y;
386        let a20 = l.col[0].z;
387
388        let a01 = l.col[1].x;
389        let a11 = l.col[1].y;
390        let a21 = l.col[1].z;
391
392        let a02 = l.col[2].x;
393        let a12 = l.col[2].y;
394        let a22 = l.col[2].z;
395
396        let b00 = r.col[0].x;
397        let b10 = r.col[0].y;
398        let b20 = r.col[0].z;
399
400        let b01 = r.col[1].x;
401        let b11 = r.col[1].y;
402        let b21 = r.col[1].z;
403
404        let b02 = r.col[2].x;
405        let b12 = r.col[2].y;
406        let b22 = r.col[2].z;
407
408        let c00 = a00 * b00 + a01 * b10 + a02 * b20;
409        let c01 = a00 * b01 + a01 * b11 + a02 * b21;
410        let c02 = a00 * b02 + a01 * b12 + a02 * b22;
411
412        let c10 = a10 * b00 + a11 * b10 + a12 * b20;
413        let c11 = a10 * b01 + a11 * b11 + a12 * b21;
414        let c12 = a10 * b02 + a11 * b12 + a12 * b22;
415
416        let c20 = a20 * b00 + a21 * b10 + a22 * b20;
417        let c21 = a20 * b01 + a21 * b11 + a22 * b21;
418        let c22 = a20 * b02 + a21 * b12 + a22 * b22;
419
420        Self::new(c00, c10, c20, c01, c11, c21, c02, c12, c22)
421    }
422
423    /// Multiplies a 3x3 matrix by a 3D vector.
424    ///
425    /// Transforms the vector by the matrix:
426    /// ```text
427    /// v' = M * v
428    /// ```
429    pub fn mul_matrix_vector(l: &Self, r: &Vector3<T>) -> Vector3<T> {
430        Self::mul_vector_matrix(r, &l.transpose())
431    }
432
433    /// Multiplies a 3D vector by a 3x3 matrix (row vector).
434    ///
435    /// ```text
436    /// v' = vᵀ * M
437    /// ```
438    pub fn mul_vector_matrix(l: &Vector3<T>, r: &Self) -> Vector3<T> {
439        Vector3::new(
440            Vector3::dot(l, &r.col[0]),
441            Vector3::dot(l, &r.col[1]),
442            Vector3::dot(l, &r.col[2]),
443        )
444    }
445
446    /// Adds two matrices element-wise.
447    ///
448    /// ```text
449    /// C[i,j] = A[i,j] + B[i,j]
450    /// ```
451    pub fn add_matrix_matrix(l: &Self, r: &Self) -> Self {
452        Matrix3 {
453            col: [
454                l.col[0] + r.col[0],
455                l.col[1] + r.col[1],
456                l.col[2] + r.col[2],
457            ],
458        }
459    }
460
461    /// Subtracts two matrices element-wise.
462    ///
463    /// ```text
464    /// C[i,j] = A[i,j] - B[i,j]
465    /// ```
466    pub fn sub_matrix_matrix(l: &Self, r: &Self) -> Self {
467        Matrix3 {
468            col: [
469                l.col[0] - r.col[0],
470                l.col[1] - r.col[1],
471                l.col[2] - r.col[2],
472            ],
473        }
474    }
475}
476
477impl<T: FloatScalar> Matrix3<T> {
478    /// Creates a 3x3 rotation matrix from an axis and angle.
479    ///
480    /// Uses Rodrigues' rotation formula:
481    /// ```text
482    /// R = I + sin(θ)K + (1 - cos(θ))K²
483    /// ```
484    /// where K is the cross-product matrix of the normalized axis.
485    ///
486    /// # Parameters
487    /// - `axis`: The rotation axis (will be normalized)
488    /// - `angle`: The rotation angle in radians
489    pub fn of_axis_angle(axis: &Vector3<T>, angle: T) -> Self {
490        let c = T::tcos(angle);
491        let s = T::tsin(angle);
492        let n = Vector3::normalize(axis);
493        let ux = n.x;
494        let uy = n.y;
495        let uz = n.z;
496        let uxx = ux * ux;
497        let uyy = uy * uy;
498        let uzz = uz * uz;
499
500        let oc = <T as One>::one() - c;
501
502        let m0 = c + uxx * oc;
503        let m1 = uy * ux * oc + uz * s;
504        let m2 = uz * ux * oc - uy * s;
505
506        let m3 = ux * uy * oc - uz * s;
507        let m4 = c + uyy * oc;
508        let m5 = uz * uy * oc + ux * s;
509
510        let m6 = ux * uz * oc + uy * s;
511        let m7 = uy * uz * oc - ux * s;
512        let m8 = c + uzz * oc;
513
514        Self::new(m0, m1, m2, m3, m4, m5, m6, m7, m8)
515    }
516}
517
518/******************************************************************************
519 * Matrix4
520 *
521 * i j -------------------------------------------->
522 * | [m0 = c0_x | m4 = c1_x | m8 = c2_x | m12= c3_x]
523 * | [m1 = c0_y | m5 = c1_y | m9 = c2_y | m13= c3_y]
524 * | [m2 = c0_z | m6 = c1_z | m10= c2_z | m14= c3_z]
525 * V [m3 = c0_w | m7 = c1_w | m11= c2_w | m15= c3_w]
526 *
527 *  aij => i = row, j = col (yx form)
528 *
529 *****************************************************************************/
530impl<T: Scalar> Matrix4<T> {
531    pub fn new(
532        m0: T,
533        m1: T,
534        m2: T,
535        m3: T,
536        m4: T,
537        m5: T,
538        m6: T,
539        m7: T,
540        m8: T,
541        m9: T,
542        m10: T,
543        m11: T,
544        m12: T,
545        m13: T,
546        m14: T,
547        m15: T,
548    ) -> Self {
549        Matrix4 {
550            col: [
551                Vector4::new(m0, m1, m2, m3),
552                Vector4::new(m4, m5, m6, m7),
553                Vector4::new(m8, m9, m10, m11),
554                Vector4::new(m12, m13, m14, m15),
555            ],
556        }
557    }
558
559    pub fn identity() -> Self {
560        Self::new(
561            <T as One>::one(),
562            <T as Zero>::zero(),
563            <T as Zero>::zero(),
564            <T as Zero>::zero(),
565            <T as Zero>::zero(),
566            <T as One>::one(),
567            <T as Zero>::zero(),
568            <T as Zero>::zero(),
569            <T as Zero>::zero(),
570            <T as Zero>::zero(),
571            <T as One>::one(),
572            <T as Zero>::zero(),
573            <T as Zero>::zero(),
574            <T as Zero>::zero(),
575            <T as Zero>::zero(),
576            <T as One>::one(),
577        )
578    }
579
580    /// Computes the determinant of the matrix.
581    ///
582    /// Uses Laplace expansion along the first column.
583    /// A non-zero determinant indicates the matrix is invertible.
584    pub fn determinant(&self) -> T {
585        let m00 = self.col[0].x;
586        let m10 = self.col[0].y;
587        let m20 = self.col[0].z;
588        let m30 = self.col[0].w;
589
590        let m01 = self.col[1].x;
591        let m11 = self.col[1].y;
592        let m21 = self.col[1].z;
593        let m31 = self.col[1].w;
594
595        let m02 = self.col[2].x;
596        let m12 = self.col[2].y;
597        let m22 = self.col[2].z;
598        let m32 = self.col[2].w;
599
600        let m03 = self.col[3].x;
601        let m13 = self.col[3].y;
602        let m23 = self.col[3].z;
603        let m33 = self.col[3].w;
604
605        m03 * m12 * m21 * m30 - m02 * m13 * m21 * m30 - m03 * m11 * m22 * m30
606            + m01 * m13 * m22 * m30
607            + m02 * m11 * m23 * m30
608            - m01 * m12 * m23 * m30
609            - m03 * m12 * m20 * m31
610            + m02 * m13 * m20 * m31
611            + m03 * m10 * m22 * m31
612            - m00 * m13 * m22 * m31
613            - m02 * m10 * m23 * m31
614            + m00 * m12 * m23 * m31
615            + m03 * m11 * m20 * m32
616            - m01 * m13 * m20 * m32
617            - m03 * m10 * m21 * m32
618            + m00 * m13 * m21 * m32
619            + m01 * m10 * m23 * m32
620            - m00 * m11 * m23 * m32
621            - m02 * m11 * m20 * m33
622            + m01 * m12 * m20 * m33
623            + m02 * m10 * m21 * m33
624            - m00 * m12 * m21 * m33
625            - m01 * m10 * m22 * m33
626            + m00 * m11 * m22 * m33
627    }
628
629    /// Returns the transpose of the matrix.
630    ///
631    /// ```text
632    /// Mᵀ[i,j] = M[j,i]
633    /// ```
634    pub fn transpose(&self) -> Self {
635        let m00 = self.col[0].x;
636        let m10 = self.col[0].y;
637        let m20 = self.col[0].z;
638        let m30 = self.col[0].w;
639
640        let m01 = self.col[1].x;
641        let m11 = self.col[1].y;
642        let m21 = self.col[1].z;
643        let m31 = self.col[1].w;
644
645        let m02 = self.col[2].x;
646        let m12 = self.col[2].y;
647        let m22 = self.col[2].z;
648        let m32 = self.col[2].w;
649
650        let m03 = self.col[3].x;
651        let m13 = self.col[3].y;
652        let m23 = self.col[3].z;
653        let m33 = self.col[3].w;
654
655        Self::new(
656            m00, m01, m02, m03, m10, m11, m12, m13, m20, m21, m22, m23, m30, m31, m32, m33,
657        )
658    }
659
660    /// Computes the inverse of the matrix.
661    ///
662    /// Uses the adjugate matrix method:
663    /// ```text
664    /// M⁻¹ = (1/det(M)) * adj(M)
665    /// ```
666    ///
667    /// # Note
668    /// Returns NaN or Inf if the matrix is singular (determinant = 0).
669    pub fn inverse(&self) -> Self {
670        let m00 = self.col[0].x;
671        let m10 = self.col[0].y;
672        let m20 = self.col[0].z;
673        let m30 = self.col[0].w;
674
675        let m01 = self.col[1].x;
676        let m11 = self.col[1].y;
677        let m21 = self.col[1].z;
678        let m31 = self.col[1].w;
679
680        let m02 = self.col[2].x;
681        let m12 = self.col[2].y;
682        let m22 = self.col[2].z;
683        let m32 = self.col[2].w;
684
685        let m03 = self.col[3].x;
686        let m13 = self.col[3].y;
687        let m23 = self.col[3].z;
688        let m33 = self.col[3].w;
689
690        let denom = m03 * m12 * m21 * m30 - m02 * m13 * m21 * m30 - m03 * m11 * m22 * m30
691            + m01 * m13 * m22 * m30
692            + m02 * m11 * m23 * m30
693            - m01 * m12 * m23 * m30
694            - m03 * m12 * m20 * m31
695            + m02 * m13 * m20 * m31
696            + m03 * m10 * m22 * m31
697            - m00 * m13 * m22 * m31
698            - m02 * m10 * m23 * m31
699            + m00 * m12 * m23 * m31
700            + m03 * m11 * m20 * m32
701            - m01 * m13 * m20 * m32
702            - m03 * m10 * m21 * m32
703            + m00 * m13 * m21 * m32
704            + m01 * m10 * m23 * m32
705            - m00 * m11 * m23 * m32
706            - m02 * m11 * m20 * m33
707            + m01 * m12 * m20 * m33
708            + m02 * m10 * m21 * m33
709            - m00 * m12 * m21 * m33
710            - m01 * m10 * m22 * m33
711            + m00 * m11 * m22 * m33;
712        let inv_det = <T as One>::one() / denom;
713
714        let r00 = (m12 * m23 * m31 - m13 * m22 * m31 + m13 * m21 * m32
715            - m11 * m23 * m32
716            - m12 * m21 * m33
717            + m11 * m22 * m33)
718            * inv_det;
719
720        let r01 = (m03 * m22 * m31 - m02 * m23 * m31 - m03 * m21 * m32
721            + m01 * m23 * m32
722            + m02 * m21 * m33
723            - m01 * m22 * m33)
724            * inv_det;
725
726        let r02 = (m02 * m13 * m31 - m03 * m12 * m31 + m03 * m11 * m32
727            - m01 * m13 * m32
728            - m02 * m11 * m33
729            + m01 * m12 * m33)
730            * inv_det;
731
732        let r03 = (m03 * m12 * m21 - m02 * m13 * m21 - m03 * m11 * m22
733            + m01 * m13 * m22
734            + m02 * m11 * m23
735            - m01 * m12 * m23)
736            * inv_det;
737
738        let r10 = (m13 * m22 * m30 - m12 * m23 * m30 - m13 * m20 * m32
739            + m10 * m23 * m32
740            + m12 * m20 * m33
741            - m10 * m22 * m33)
742            * inv_det;
743
744        let r11 = (m02 * m23 * m30 - m03 * m22 * m30 + m03 * m20 * m32
745            - m00 * m23 * m32
746            - m02 * m20 * m33
747            + m00 * m22 * m33)
748            * inv_det;
749
750        let r12 = (m03 * m12 * m30 - m02 * m13 * m30 - m03 * m10 * m32
751            + m00 * m13 * m32
752            + m02 * m10 * m33
753            - m00 * m12 * m33)
754            * inv_det;
755
756        let r13 = (m02 * m13 * m20 - m03 * m12 * m20 + m03 * m10 * m22
757            - m00 * m13 * m22
758            - m02 * m10 * m23
759            + m00 * m12 * m23)
760            * inv_det;
761
762        let r20 = (m11 * m23 * m30 - m13 * m21 * m30 + m13 * m20 * m31
763            - m10 * m23 * m31
764            - m11 * m20 * m33
765            + m10 * m21 * m33)
766            * inv_det;
767
768        let r21 = (m03 * m21 * m30 - m01 * m23 * m30 - m03 * m20 * m31
769            + m00 * m23 * m31
770            + m01 * m20 * m33
771            - m00 * m21 * m33)
772            * inv_det;
773
774        let r22 = (m01 * m13 * m30 - m03 * m11 * m30 + m03 * m10 * m31
775            - m00 * m13 * m31
776            - m01 * m10 * m33
777            + m00 * m11 * m33)
778            * inv_det;
779
780        let r23 = (m03 * m11 * m20 - m01 * m13 * m20 - m03 * m10 * m21
781            + m00 * m13 * m21
782            + m01 * m10 * m23
783            - m00 * m11 * m23)
784            * inv_det;
785
786        let r30 = (m12 * m21 * m30 - m11 * m22 * m30 - m12 * m20 * m31
787            + m10 * m22 * m31
788            + m11 * m20 * m32
789            - m10 * m21 * m32)
790            * inv_det;
791
792        let r31 = (m01 * m22 * m30 - m02 * m21 * m30 + m02 * m20 * m31
793            - m00 * m22 * m31
794            - m01 * m20 * m32
795            + m00 * m21 * m32)
796            * inv_det;
797
798        let r32 = (m02 * m11 * m30 - m01 * m12 * m30 - m02 * m10 * m31
799            + m00 * m12 * m31
800            + m01 * m10 * m32
801            - m00 * m11 * m32)
802            * inv_det;
803
804        let r33 = (m01 * m12 * m20 - m02 * m11 * m20 + m02 * m10 * m21
805            - m00 * m12 * m21
806            - m01 * m10 * m22
807            + m00 * m11 * m22)
808            * inv_det;
809
810        Self::new(
811            r00, r10, r20, r30, r01, r11, r21, r31, r02, r12, r22, r32, r03, r13, r23, r33,
812        )
813    }
814
815    /// Multiplies two 4x4 matrices.
816    ///
817    /// Matrix multiplication follows the rule:
818    /// ```text
819    /// C[i,j] = Σₖ A[i,k] * B[k,j]
820    /// ```
821    pub fn mul_matrix_matrix(l: &Self, r: &Self) -> Self {
822        let a00 = l.col[0].x;
823        let a10 = l.col[0].y;
824        let a20 = l.col[0].z;
825        let a30 = l.col[0].w;
826
827        let a01 = l.col[1].x;
828        let a11 = l.col[1].y;
829        let a21 = l.col[1].z;
830        let a31 = l.col[1].w;
831
832        let a02 = l.col[2].x;
833        let a12 = l.col[2].y;
834        let a22 = l.col[2].z;
835        let a32 = l.col[2].w;
836
837        let a03 = l.col[3].x;
838        let a13 = l.col[3].y;
839        let a23 = l.col[3].z;
840        let a33 = l.col[3].w;
841
842        let b00 = r.col[0].x;
843        let b10 = r.col[0].y;
844        let b20 = r.col[0].z;
845        let b30 = r.col[0].w;
846
847        let b01 = r.col[1].x;
848        let b11 = r.col[1].y;
849        let b21 = r.col[1].z;
850        let b31 = r.col[1].w;
851
852        let b02 = r.col[2].x;
853        let b12 = r.col[2].y;
854        let b22 = r.col[2].z;
855        let b32 = r.col[2].w;
856
857        let b03 = r.col[3].x;
858        let b13 = r.col[3].y;
859        let b23 = r.col[3].z;
860        let b33 = r.col[3].w;
861
862        let c00 = a00 * b00 + a01 * b10 + a02 * b20 + a03 * b30;
863        let c01 = a00 * b01 + a01 * b11 + a02 * b21 + a03 * b31;
864        let c02 = a00 * b02 + a01 * b12 + a02 * b22 + a03 * b32;
865        let c03 = a00 * b03 + a01 * b13 + a02 * b23 + a03 * b33;
866
867        let c10 = a10 * b00 + a11 * b10 + a12 * b20 + a13 * b30;
868        let c11 = a10 * b01 + a11 * b11 + a12 * b21 + a13 * b31;
869        let c12 = a10 * b02 + a11 * b12 + a12 * b22 + a13 * b32;
870        let c13 = a10 * b03 + a11 * b13 + a12 * b23 + a13 * b33;
871
872        let c20 = a20 * b00 + a21 * b10 + a22 * b20 + a23 * b30;
873        let c21 = a20 * b01 + a21 * b11 + a22 * b21 + a23 * b31;
874        let c22 = a20 * b02 + a21 * b12 + a22 * b22 + a23 * b32;
875        let c23 = a20 * b03 + a21 * b13 + a22 * b23 + a23 * b33;
876
877        let c30 = a30 * b00 + a31 * b10 + a32 * b20 + a33 * b30;
878        let c31 = a30 * b01 + a31 * b11 + a32 * b21 + a33 * b31;
879        let c32 = a30 * b02 + a31 * b12 + a32 * b22 + a33 * b32;
880        let c33 = a30 * b03 + a31 * b13 + a32 * b23 + a33 * b33;
881
882        Self::new(
883            c00, c10, c20, c30, c01, c11, c21, c31, c02, c12, c22, c32, c03, c13, c23, c33,
884        )
885    }
886
887    pub fn mul_matrix_vector(l: &Self, r: &Vector4<T>) -> Vector4<T> {
888        Self::mul_vector_matrix(r, &l.transpose())
889    }
890
891    //
892    //                     [m0 = c0_x | m4 = c1_x | m8 = c2_x | m12= c3_x]
893    // [v_x v_y v_z v_w] * [m1 = c0_y | m5 = c1_y | m9 = c2_y | m13= c3_y] = [dot(v, c0) dot(v, c1) dot(v, c2) dot(v, c3)]
894    //                     [m2 = c0_z | m6 = c1_z | m10= c2_z | m14= c3_z]
895    //                     [m3 = c0_w | m7 = c1_w | m11= c2_w | m15= c3_w]
896    //
897    pub fn mul_vector_matrix(l: &Vector4<T>, r: &Self) -> Vector4<T> {
898        Vector4::new(
899            Vector4::dot(l, &r.col[0]),
900            Vector4::dot(l, &r.col[1]),
901            Vector4::dot(l, &r.col[2]),
902            Vector4::dot(l, &r.col[3]),
903        )
904    }
905
906    /// Adds two matrices element-wise.
907    ///
908    /// ```text
909    /// C[i,j] = A[i,j] + B[i,j]
910    /// ```
911    pub fn add_matrix_matrix(l: &Self, r: &Self) -> Self {
912        Matrix4 {
913            col: [
914                l.col[0] + r.col[0],
915                l.col[1] + r.col[1],
916                l.col[2] + r.col[2],
917                l.col[3] + r.col[3],
918            ],
919        }
920    }
921
922    /// Subtracts two matrices element-wise.
923    ///
924    /// ```text
925    /// C[i,j] = A[i,j] - B[i,j]
926    /// ```
927    pub fn sub_matrix_matrix(l: &Self, r: &Self) -> Self {
928        Matrix4 {
929            col: [
930                l.col[0] - r.col[0],
931                l.col[1] - r.col[1],
932                l.col[2] - r.col[2],
933                l.col[3] - r.col[3],
934            ],
935        }
936    }
937}
938
939/******************************************************************************
940 * Operator overloading
941 *****************************************************************************/
// Implements the arithmetic operators for one square matrix type `$m` and the
// vector type `$v` of the same dimension. Each operator forwards to the
// matrix's inherent associated function:
//
//   $v * $m  -> mul_vector_matrix  (row vector times matrix)
//   $m * $v  -> mul_matrix_vector  (matrix times column vector)
//   $m * $m  -> mul_matrix_matrix
//   $m + $m  -> add_matrix_matrix  (element-wise)
//   $m - $m  -> sub_matrix_matrix  (element-wise)
macro_rules! implMatrixOps {
    ($m:ident, $v:ident) => {
        impl<T: Scalar> Mul<$m<T>> for $v<T> {
            type Output = Self;
            fn mul(self, rhs: $m<T>) -> Self::Output {
                $m::mul_vector_matrix(&self, &rhs)
            }
        }

        impl<T: Scalar> Mul<$v<T>> for $m<T> {
            type Output = $v<T>;
            fn mul(self, rhs: $v<T>) -> Self::Output {
                Self::mul_matrix_vector(&self, &rhs)
            }
        }

        impl<T: Scalar> Mul for $m<T> {
            type Output = Self;
            fn mul(self, rhs: Self) -> Self::Output {
                Self::mul_matrix_matrix(&self, &rhs)
            }
        }

        impl<T: Scalar> Add for $m<T> {
            type Output = Self;
            fn add(self, rhs: Self) -> Self::Output {
                Self::add_matrix_matrix(&self, &rhs)
            }
        }

        impl<T: Scalar> Sub for $m<T> {
            type Output = Self;
            fn sub(self, rhs: Self) -> Self::Output {
                Self::sub_matrix_matrix(&self, &rhs)
            }
        }
    };
}
980
981implMatrixOps!(Matrix2, Vector2);
982implMatrixOps!(Matrix3, Vector3);
983implMatrixOps!(Matrix4, Vector4);
984
985impl<T: Scalar> Mul<Matrix4<T>> for Vector3<T> {
986    type Output = Vector3<T>;
987    fn mul(self, rhs: Matrix4<T>) -> Vector3<T> {
988        Matrix4::mul_vector_matrix(&Vector4::new(self.x, self.y, self.z, <T as One>::one()), &rhs).xyz()
989    }
990}
991
992impl<T: Scalar> Mul<Vector3<T>> for Matrix4<T> {
993    type Output = Vector3<T>;
994    fn mul(self, rhs: Vector3<T>) -> Vector3<T> {
995        Matrix4::mul_matrix_vector(&self, &Vector4::new(rhs.x, rhs.y, rhs.z, <T as One>::one())).xyz()
996    }
997}
998
999pub trait Matrix4Extension<T: Scalar> {
1000    fn mat3(&self) -> Matrix3<T>;
1001}
1002
1003impl<T: Scalar> Matrix4Extension<T> for Matrix4<T> {
1004    fn mat3(&self) -> Matrix3<T> {
1005        Matrix3::new(
1006            self.col[0].x,
1007            self.col[0].y,
1008            self.col[0].z,
1009            self.col[1].x,
1010            self.col[1].y,
1011            self.col[1].z,
1012            self.col[2].x,
1013            self.col[2].y,
1014            self.col[2].z,
1015        )
1016    }
1017}
1018
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::*;

    /// Absolute tolerance for comparisons involving rounded floating-point results.
    const EPS: f32 = 0.001;

    /// `true` when `a` and `b` differ by less than [`EPS`].
    fn approx(a: f32, b: f32) -> bool {
        (a - b).abs() < EPS
    }

    /// Expected entry of an identity matrix at (`row`, `column`).
    fn delta(row: usize, column: usize) -> f32 {
        if row == column {
            1.0
        } else {
            0.0
        }
    }

    /// Entry at (`row`, `column`) of a `Matrix2`, read from its `col` array.
    fn m2_at(m: &Matrix2<f32>, row: usize, column: usize) -> f32 {
        let c = &m.col[column];
        match row {
            0 => c.x,
            1 => c.y,
            _ => panic!("row {} out of range for Matrix2", row),
        }
    }

    /// Entry at (`row`, `column`) of a `Matrix3`, read from its `col` array.
    fn m3_at(m: &Matrix3<f32>, row: usize, column: usize) -> f32 {
        let c = &m.col[column];
        match row {
            0 => c.x,
            1 => c.y,
            2 => c.z,
            _ => panic!("row {} out of range for Matrix3", row),
        }
    }

    /// Entry at (`row`, `column`) of a `Matrix4`, read from its `col` array.
    fn m4_at(m: &Matrix4<f32>, row: usize, column: usize) -> f32 {
        let c = &m.col[column];
        match row {
            0 => c.x,
            1 => c.y,
            2 => c.z,
            3 => c.w,
            _ => panic!("row {} out of range for Matrix4", row),
        }
    }

    /// Translation by (`x`, `y`, `z`); the offset is stored in the fourth column.
    fn translation(x: f32, y: f32, z: f32) -> Matrix4<f32> {
        Matrix4::new(
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
            x, y, z, 1.0,
        )
    }

    #[test]
    fn test_matrix2_identity() {
        let m = Matrix2::<f32>::identity();
        for row in 0..2 {
            for column in 0..2 {
                assert_eq!(m2_at(&m, row, column), delta(row, column));
            }
        }
    }

    #[test]
    fn test_matrix2_determinant() {
        let m = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        assert_eq!(m.determinant(), -2.0); // 1*4 - 3*2 = -2

        // Linearly dependent columns give a singular matrix.
        let m_singular = Matrix2::<f32>::new(1.0, 2.0, 2.0, 4.0);
        assert_eq!(m_singular.determinant(), 0.0);
    }

    #[test]
    fn test_matrix2_inverse() {
        let m = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let m_inv = m.inverse();
        let product = Matrix2::mul_matrix_matrix(&m, &m_inv);

        // M * M^-1 must be (close to) the identity.
        for row in 0..2 {
            for column in 0..2 {
                let val = m2_at(&product, row, column);
                assert!(
                    approx(val, delta(row, column)),
                    "Matrix2 inverse failed at [{}, {}]: got {}",
                    row,
                    column,
                    val
                );
            }
        }
    }

    #[test]
    fn test_matrix2_transpose() {
        let m = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let mt = m.transpose();
        assert_eq!(mt.col[0].x, 1.0);
        assert_eq!(mt.col[0].y, 3.0);
        assert_eq!(mt.col[1].x, 2.0);
        assert_eq!(mt.col[1].y, 4.0);

        // Entry (r, c) of the transpose is entry (c, r) of the original, and
        // transposing twice restores the original.
        let mtt = mt.transpose();
        for row in 0..2 {
            for column in 0..2 {
                assert_eq!(m2_at(&mt, row, column), m2_at(&m, column, row));
                assert_eq!(m2_at(&mtt, row, column), m2_at(&m, row, column));
            }
        }
    }

    #[test]
    fn test_matrix3_identity() {
        let m = Matrix3::<f32>::identity();
        for row in 0..3 {
            for column in 0..3 {
                assert_eq!(m3_at(&m, row, column), delta(row, column));
            }
        }
    }

    #[test]
    fn test_matrix3_determinant() {
        let m = Matrix3::<f32>::new(
            1.0, 0.0, 0.0,
            0.0, 1.0, 0.0,
            0.0, 0.0, 1.0
        );
        assert_eq!(m.determinant(), 1.0);

        let m2 = Matrix3::<f32>::new(
            2.0, 3.0, 1.0,
            1.0, 0.0, 2.0,
            1.0, 2.0, 1.0
        );
        assert!(approx(m2.determinant(), -3.0));
    }

    #[test]
    fn test_matrix3_inverse() {
        let m = Matrix3::<f32>::new(
            2.0, 3.0, 1.0,
            1.0, 0.0, 2.0,
            1.0, 2.0, 1.0
        );
        let m_inv = m.inverse();
        let product = Matrix3::mul_matrix_matrix(&m, &m_inv);

        // M * M^-1 must be (close to) the identity.
        for row in 0..3 {
            for column in 0..3 {
                let val = m3_at(&product, row, column);
                assert!(
                    approx(val, delta(row, column)),
                    "Matrix3 inverse failed at [{}, {}]: got {}",
                    row,
                    column,
                    val
                );
            }
        }
    }

    #[test]
    fn test_matrix4_identity() {
        let m = Matrix4::<f32>::identity();
        for row in 0..4 {
            for column in 0..4 {
                assert_eq!(m4_at(&m, row, column), delta(row, column));
            }
        }
    }

    #[test]
    fn test_matrix_vector_multiplication() {
        // Test Matrix2 * Vector2
        let m2 = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let v2 = Vector2::<f32>::new(5.0, 6.0);
        let result2 = m2 * v2;
        assert_eq!(result2.x, 23.0); // 1*5 + 3*6 = 23
        assert_eq!(result2.y, 34.0); // 2*5 + 4*6 = 34

        // Test Matrix3 * Vector3
        let m3 = Matrix3::<f32>::identity();
        let v3 = Vector3::<f32>::new(1.0, 2.0, 3.0);
        let result3 = m3 * v3;
        assert_eq!(result3.x, 1.0);
        assert_eq!(result3.y, 2.0);
        assert_eq!(result3.z, 3.0);

        // Test Matrix4 * Vector4
        let m4 = Matrix4::<f32>::identity();
        let v4 = Vector4::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let result4 = m4 * v4;
        assert_eq!(result4.x, 1.0);
        assert_eq!(result4.y, 2.0);
        assert_eq!(result4.z, 3.0);
        assert_eq!(result4.w, 4.0);
    }

    #[test]
    fn test_matrix4_vector3_multiplication() {
        // Matrix on the left: the point is promoted with w = 1, so the offset applies.
        let moved = translation(10.0, 20.0, 30.0) * Vector3::<f32>::new(1.0, 2.0, 3.0);
        assert_eq!(moved.x, 11.0);
        assert_eq!(moved.y, 22.0);
        assert_eq!(moved.z, 33.0);

        // Vector on the left: each component is the dot product with a column, so
        // the offset (stored in the fourth column) only reaches the discarded w.
        let unmoved = Vector3::<f32>::new(1.0, 2.0, 3.0) * translation(10.0, 20.0, 30.0);
        assert_eq!(unmoved.x, 1.0);
        assert_eq!(unmoved.y, 2.0);
        assert_eq!(unmoved.z, 3.0);
    }

    #[test]
    fn test_matrix4_mat3() {
        let m = Matrix4::<f32>::new(
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        );
        let m3 = m.mat3();

        // The extracted block must match the upper-left 3x3 entries of the source.
        for row in 0..3 {
            for column in 0..3 {
                assert_eq!(m3_at(&m3, row, column), m4_at(&m, row, column));
            }
        }
    }

    #[test]
    fn test_matrix_multiplication() {
        // Test associativity: (A * B) * C == A * (B * C)
        let a = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let b = Matrix2::<f32>::new(5.0, 6.0, 7.0, 8.0);
        let c = Matrix2::<f32>::new(9.0, 10.0, 11.0, 12.0);

        let left = (a * b) * c;
        let right = a * (b * c);

        for row in 0..2 {
            for column in 0..2 {
                assert!(approx(m2_at(&left, row, column), m2_at(&right, row, column)));
            }
        }
    }

    #[test]
    fn test_matrix_addition_subtraction() {
        let m1 = Matrix2::<f32>::new(1.0, 2.0, 3.0, 4.0);
        let m2 = Matrix2::<f32>::new(5.0, 6.0, 7.0, 8.0);

        let sum = m1 + m2;
        assert_eq!(sum.col[0].x, 6.0);
        assert_eq!(sum.col[0].y, 8.0);
        assert_eq!(sum.col[1].x, 10.0);
        assert_eq!(sum.col[1].y, 12.0);

        let diff = m2 - m1;
        assert_eq!(diff.col[0].x, 4.0);
        assert_eq!(diff.col[0].y, 4.0);
        assert_eq!(diff.col[1].x, 4.0);
        assert_eq!(diff.col[1].y, 4.0);
    }

    #[test]
    fn test_matrix4_inverse() {
        // Test with a known invertible matrix (scale followed by a translation column).
        let m = Matrix4::<f32>::new(
            2.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 0.5, 0.0,
            1.0, 2.0, 3.0, 1.0
        );

        let m_inv = m.inverse();
        let product = m * m_inv;

        // M * M^-1 must be (close to) the identity.
        for i in 0..4 {
            for j in 0..4 {
                let expected = delta(i, j);
                let val = m4_at(&product, i, j);
                assert!(
                    approx(val, expected),
                    "Matrix inverse failed at [{}, {}]: expected {}, got {}",
                    i,
                    j,
                    expected,
                    val
                );
            }
        }
    }
}