Skip to main content

naga_rust_rt/
matrix.rs

1use core::ops;
2
3use crate::{Scalar, Vec2, Vec3, Vec4};
4
5// -------------------------------------------------------------------------------------------------
6
/// Expands to an expression that builds one row of a matrix as a vector.
///
/// For the matrix expression `$matrix`, the component name `$component`, and the list of the
/// matrix's column field names, this produces
/// `$vector::new($matrix.<first column>.$component, $matrix.<second column>.$component, …)`.
///
/// This is its own macro because a single `macro_rules!` transcriber cannot form the cross
/// product of two independent repetitions (rows × columns).
macro_rules! generate_row_expr {
    (
        $matrix:ident,
        $vector:ident,
        $component:ident,
        [$($column:ident),*]
    ) => {
        $vector::new(
            $( $matrix.$column.$component ),*
        )
    };
}
14
/// Expands to a struct-literal expression for the transpose of `$matrix`.
///
/// Every field of the output type `$out_matrix` is named after one entry of the `$row` list and
/// is filled, via `generate_row_expr!`, with the corresponding row of the input, built as an
/// `$out_column` vector whose components are taken from the input's columns (`$columns`).
///
/// The column list is passed through as one token tree because nesting it as a second
/// repetition inside the row repetition is not expressible with `macro_rules!`.
macro_rules! transpose_body {
    (
        $matrix:ident,
        $out_matrix:ident,
        $out_column:ident,
        [$($row:ident),*],
        $columns:tt
    ) => {
        $out_matrix {
            $( $row: generate_row_expr!($matrix, $out_column, $row, $columns) ),*
        }
    };
}
26
/// Implement `Index` and `IndexMut` on a matrix type; indexing yields a whole column.
///
/// * `$columns`: the number of columns, as an integer literal.
/// * `$mat_type`: the matrix type, which must provide `as_array_of_columns_ref()` and
///   `as_array_of_columns_mut()` returning its columns as an array of length `$columns`.
/// * `$column_type`: the column vector type.
/// * `$index_type`: the integer type accepted as an index (it must implement `Display` for the
///   panic message).
///
/// Indices outside `0..$columns` (including negative ones for signed index types) panic. The
/// panic is attributed to the caller's indexing expression (`#[track_caller]`), and its message
/// reports the offending index and the column count.
macro_rules! impl_index {
    ($columns:literal, $mat_type:ident, $column_type:ident, $index_type:ty) => {
        impl<T> ops::Index<$index_type> for $mat_type<T> {
            type Output = $column_type<T>;

            #[inline]
            #[track_caller]
            fn index(&self, index: $index_type) -> &Self::Output {
                // Manual bounds check: the index must be converted to `usize`, and doing a
                // checked conversion followed by array indexing would give two panic branches
                // (failed conversion and out-of-range index). Checking the range first leaves
                // one, and makes the `as usize` cast below lossless.
                if (0..$columns).contains(&index) {
                    &self.as_array_of_columns_ref()[index as usize]
                } else {
                    panic!(
                        "matrix indexing out of bounds: the index is {index} but the matrix has {} columns",
                        $columns
                    )
                }
            }
        }

        impl<T> ops::IndexMut<$index_type> for $mat_type<T> {
            #[inline]
            #[track_caller]
            fn index_mut(&mut self, index: $index_type) -> &mut Self::Output {
                // Same single-branch bounds check as `index` above.
                if (0..$columns).contains(&index) {
                    &mut self.as_array_of_columns_mut()[index as usize]
                } else {
                    panic!(
                        "matrix indexing out of bounds: the index is {index} but the matrix has {} columns",
                        $columns
                    )
                }
            }
        }
    };
}
57
/// Define the matrix type `Mat{columns}x{rows}<T>` together with its constructor, `transpose()`,
/// column indexing, and its arithmetic operator impls.
///
/// Parameters:
/// * `$columns`, `$rows`: the matrix dimensions, as integer literals.
/// * `$column_type`, `[$column_field, …]`: the vector type of a single column (one component per
///   row) and the names of the matrix's column fields (one per column).
/// * `$row_type`, `[$row_field, …]`: the vector type of a single row (one component per column)
///   and the component names of a column vector (one per row).
macro_rules! matrix_struct {
    ($columns:literal, $rows:literal, $column_type:ident, [$($column_field:ident),*], $row_type:ident, [$($row_field:ident),*]) => {
        paste::paste! {
            #[doc = concat!("Matrix with ", $columns, " columns and ", $rows, " rows.")]
            ///
            /// The matrix is stored column-major; that is, each field is a whole column of the matrix.
            #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
            #[repr(C)]
            pub struct [< Mat $columns x $rows >] <T> {
                $( pub $column_field: $column_type<T>, )*
            }

            impl<T> [< Mat $columns x $rows >] <T> {
                /// Creates a matrix from its columns, given in order from the first column to
                /// the last.
                #[must_use]
                pub fn new($([< $column_field _column >]: $column_type<T>,)*) -> Self {
                    Self { $($column_field: [< $column_field _column >],)* }
                }

                /// Returns the transpose of this matrix.
                ///
                /// Row `i` of `self` becomes column `i` of the result, so the column and row
                /// counts are swapped.
                #[must_use]
                pub fn transpose(self) -> [< Mat $rows x $columns >] <T> {
                    transpose_body!(
                        self,
                        [< Mat $rows x $columns >],
                        $row_type,
                        [$($row_field),*],
                        [$($column_field),*]
                    )
                }

                /// Views the columns of this matrix as an array, first column first.
                #[inline]
                fn as_array_of_columns_ref(&self) -> &[$column_type<T>; $columns] {
                    // Reinterpret the reference to self as a reference to an array.
                    // SAFETY: The struct is `repr(C)` and its fields are exactly the columns,
                    // all of the same column vector type, in declaration order. Under
                    // `repr(C)` each field is placed at the next offset that is a multiple of
                    // its alignment, and a type's size is always a multiple of its alignment,
                    // so the columns are laid out back to back with no padding. The struct
                    // therefore has the same size, alignment, and element offsets as the
                    // array. The returned reference borrows `self`, so it cannot outlive it.
                    unsafe { &*(&raw const *self).cast::<[$column_type<T>; $columns]>() }
                }

                /// Views the columns of this matrix as a mutable array, first column first.
                #[inline]
                fn as_array_of_columns_mut(&mut self) -> &mut [$column_type<T>; $columns] {
                    // Reinterpret the reference to self as a reference to an array.
                    // SAFETY: Same layout argument as `as_array_of_columns_ref`. The exclusive
                    // borrow of `self` is carried over to the returned reference, so no
                    // aliasing access to the matrix can exist while it is alive.
                    unsafe { &mut *(&raw mut *self).cast::<[$column_type<T>; $columns]>() }
                }
            }

            // Indexing, by usize, i32, or u32, yields a column vector
            impl_index!($columns, [< Mat $columns x $rows >], $column_type, usize);
            impl_index!($columns, [< Mat $columns x $rows >], $column_type, i32);
            impl_index!($columns, [< Mat $columns x $rows >], $column_type, u32);

            impl<T> ops::Add for [< Mat $columns x $rows >]<T>
            where
                $column_type<T>: ops::Add<Output = $column_type<T>>,
            {
                type Output = Self;

                /// Performs component-wise addition.
                #[inline]
                fn add(self, rhs: Self) -> Self::Output {
                    Self::new(
                        $( self.$column_field + rhs.$column_field ),*
                    )
                }
            }

            impl<T> ops::Sub for [< Mat $columns x $rows >]<T>
            where
                $column_type<T>: ops::Sub<Output = $column_type<T>>,
            {
                type Output = Self;

                /// Performs component-wise subtraction.
                #[inline]
                fn sub(self, rhs: Self) -> Self::Output {
                    Self::new(
                        $( self.$column_field - rhs.$column_field ),*
                    )
                }
            }

            impl<T> ops::Mul<Scalar<T>> for [< Mat $columns x $rows >]<T>
            where
                $column_type<T>: ops::Mul<Scalar<T>, Output = $column_type<T>>,
                T: Copy,
            {
                type Output = Self;

                /// Performs component-wise multiplication by a scalar.
                #[inline]
                fn mul(self, rhs: Scalar<T>) -> Self::Output {
                    Self::new(
                        $( self.$column_field * rhs ),*
                    )
                }
            }

            impl<T> ops::Mul<[< Mat $columns x $rows >]<T>> for Scalar<T>
            where
                Scalar<T>: ops::Mul<$column_type<T>, Output = $column_type<T>>,
                T: Copy,
            {
                type Output = [< Mat $columns x $rows >]<T>;

                /// Performs component-wise multiplication by a scalar.
                #[inline]
                fn mul(self, rhs: [< Mat $columns x $rows >]<T>) -> Self::Output {
                    Self::Output::new(
                        $( self * rhs.$column_field ),*
                    )
                }
            }

            impl<T> ops::Mul<$row_type<T>> for [< Mat $columns x $rows >]<T>
            where
                // bounds copied from dot()
                Scalar<T>: ops::Mul<Output = Scalar<T>> + num_traits::ConstZero,
                T: Copy,
            {
                type Output = $column_type<T>;

                /// Multiplication with matrix on the left and vector on the right.
                #[inline]
                fn mul(self, rhs: $row_type<T>) -> Self::Output {
                    let t = self.transpose();
                    $column_type::from_scalars(
                        // dot product of LHS rows with RHS column
                        $( t.$row_field.dot(rhs) ),*
                    )
                }
            }

            impl<T> ops::Mul<[< Mat $columns x $rows >]<T>> for $column_type<T>
            where
                Scalar<T>: ops::Mul<Output = Scalar<T>> + num_traits::ConstZero,
                T: Copy,
            {
                type Output = $row_type<T>;

                /// Multiplication with vector on the left and matrix on the right.
                #[inline]
                fn mul(self, rhs: [< Mat $columns x $rows >]<T>) -> Self::Output {
                    $row_type::from_scalars(
                        // dot product of LHS row with RHS columns
                        $( self.dot(rhs.$column_field) ),*
                    )
                }
            }
        }
    }
}
204
205matrix_struct!(2, 2, Vec2, [x, y], Vec2, [x, y]);
206matrix_struct!(2, 3, Vec3, [x, y], Vec2, [x, y, z]);
207matrix_struct!(2, 4, Vec4, [x, y], Vec2, [x, y, z, w]);
208matrix_struct!(3, 2, Vec2, [x, y, z], Vec3, [x, y]);
209matrix_struct!(3, 3, Vec3, [x, y, z], Vec3, [x, y, z]);
210matrix_struct!(3, 4, Vec4, [x, y, z], Vec3, [x, y, z, w]);
211matrix_struct!(4, 2, Vec2, [x, y, z, w], Vec4, [x, y]);
212matrix_struct!(4, 3, Vec3, [x, y, z, w], Vec4, [x, y, z]);
213matrix_struct!(4, 4, Vec4, [x, y, z, w], Vec4, [x, y, z, w]);
214
215// -------------------------------------------------------------------------------------------------
216
/// Implement matrix–matrix multiplication `Mat{common}x{rows} * Mat{columns}x{common}`,
/// producing a `Mat{columns}x{rows}`.
///
/// `[$column_field, …]` lists the column field names of the right-hand matrix (one per output
/// column). Each output column is the left-hand matrix multiplied by the corresponding
/// right-hand column, using the matrix–vector product defined by `matrix_struct!`.
///
/// The impl is generic over the scalar type `T`. It requires exactly the bounds of the
/// matrix–vector product it is built on, so it still covers `f32` and also any other scalar
/// type for which that product exists.
macro_rules! matrix_multiply {
    (
        $rows:literal,
        $columns:literal,
        $common:literal,
        [$($column_field:ident),*]
    ) => {
        paste::paste! {
            impl<T> ops::Mul<[< Mat $columns x $common >]<T>> for [< Mat $common x $rows >]<T>
            where
                // Bounds of the matrix-vector product used for each output column.
                // `T: Copy` also makes the left-hand matrix `Copy`, so `self` can be reused.
                Scalar<T>: ops::Mul<Output = Scalar<T>> + num_traits::ConstZero,
                T: Copy,
            {
                type Output = [< Mat $columns x $rows >]<T>;

                /// Performs matrix multiplication.
                #[inline]
                fn mul(self, rhs: [< Mat $columns x $common >]<T>) -> Self::Output {
                    [< Mat $columns x $rows >]::new(
                        $( self * rhs.$column_field ),*
                    )
                }
            }
        }
    }
}
240
241matrix_multiply!(2, 2, 2, [x, y]);
242matrix_multiply!(2, 2, 3, [x, y]);
243matrix_multiply!(2, 2, 4, [x, y]);
244matrix_multiply!(2, 3, 2, [x, y, z]);
245matrix_multiply!(2, 3, 3, [x, y, z]);
246matrix_multiply!(2, 3, 4, [x, y, z]);
247matrix_multiply!(2, 4, 2, [x, y, z, w]);
248matrix_multiply!(2, 4, 3, [x, y, z, w]);
249matrix_multiply!(2, 4, 4, [x, y, z, w]);
250matrix_multiply!(3, 2, 2, [x, y]);
251matrix_multiply!(3, 2, 3, [x, y]);
252matrix_multiply!(3, 2, 4, [x, y]);
253matrix_multiply!(3, 3, 2, [x, y, z]);
254matrix_multiply!(3, 3, 3, [x, y, z]);
255matrix_multiply!(3, 3, 4, [x, y, z]);
256matrix_multiply!(3, 4, 2, [x, y, z, w]);
257matrix_multiply!(3, 4, 3, [x, y, z, w]);
258matrix_multiply!(3, 4, 4, [x, y, z, w]);
259matrix_multiply!(4, 2, 2, [x, y]);
260matrix_multiply!(4, 2, 3, [x, y]);
261matrix_multiply!(4, 2, 4, [x, y]);
262matrix_multiply!(4, 3, 2, [x, y, z]);
263matrix_multiply!(4, 3, 3, [x, y, z]);
264matrix_multiply!(4, 3, 4, [x, y, z]);
265matrix_multiply!(4, 4, 2, [x, y, z, w]);
266matrix_multiply!(4, 4, 3, [x, y, z, w]);
267matrix_multiply!(4, 4, 4, [x, y, z, w]);