lambda/math/
matrix.rs

//! Matrix math types and functions.

3use lambda_platform::rand::get_uniformly_random_floats_between;
4
5use super::{
6  turns_to_radians,
7  vector::Vector,
8};

// -------------------------------- MATRIX -------------------------------------

12/// Matrix trait which defines the basic operations that can be performed on a
13/// matrix. Lambda currently implements this trait for f32 arrays of arrays
14/// for any size.
15pub trait Matrix<V: Vector> {
16  fn add(&self, other: &Self) -> Self;
17  fn subtract(&self, other: &Self) -> Self;
18  fn multiply(&self, other: &Self) -> Self;
19  fn transpose(&self) -> Self;
20  fn inverse(&self) -> Self;
21  fn transform(&self, other: &V) -> V;
22  fn determinant(&self) -> f32;
23  fn size(&self) -> (usize, usize);
24  fn row(&self, row: usize) -> &V;
25  fn at(&self, row: usize, column: usize) -> V::Scalar;
26  fn update(&mut self, row: usize, column: usize, value: V::Scalar);
27}

// -------------------------------- FUNCTIONS ----------------------------------

31/// Obtain the submatrix of the input matrix starting from the given row &
32/// column.
33pub fn submatrix<V: Vector<Scalar = f32>, MatrixLike: Matrix<V>>(
34  matrix: MatrixLike,
35  row: usize,
36  column: usize,
37) -> Vec<Vec<V::Scalar>> {
38  let mut submatrix = Vec::new();
39  let (rows, columns) = matrix.size();
40
41  for k in 0..rows {
42    if k != row {
43      let mut row = Vec::new();
44      for l in 0..columns {
45        if l != column {
46          row.push(matrix.at(k, l));
47        }
48      }
49      submatrix.push(row);
50    }
51  }
52  return submatrix;
53}
55/// Creates a translation matrix with the given translation vector. The output vector
56pub fn translation_matrix<
57  InputVector: Vector<Scalar = f32>,
58  ResultingVector: Vector<Scalar = f32>,
59  OutputMatrix: Matrix<ResultingVector> + Default,
60>(
61  vector: InputVector,
62) -> OutputMatrix {
63  let mut result = OutputMatrix::default();
64  let (rows, columns) = result.size();
65  assert_eq!(
66    rows - 1,
67    vector.size(),
68    "Vector must contain one less element than the vectors of the input matrix"
69  );
70
71  for i in 0..rows {
72    for j in 0..columns {
73      if i == j {
74        result.update(i, j, 1.0);
75      } else if j == columns - 1 {
76        result.update(i, j, vector.at(i));
77      } else {
78        result.update(i, j, 0.0);
79      }
80    }
81  }
82
83  return result;
84}
86/// Rotates the input matrix by the given number of turns around the given axis.
87/// The axis must be a unit vector and the turns must be in the range [0, 1).
88/// The rotation is counter-clockwise when looking down the axis.
89pub fn rotate_matrix<
90  InputVector: Vector<Scalar = f32>,
91  ResultingVector: Vector<Scalar = f32>,
92  OutputMatrix: Matrix<ResultingVector> + Default + Clone,
93>(
94  matrix_to_rotate: OutputMatrix,
95  axis_to_rotate: InputVector,
96  angle_in_turns: f32,
97) -> OutputMatrix {
98  let (rows, columns) = matrix_to_rotate.size();
99  assert_eq!(rows, columns, "Matrix must be square");
100  assert_eq!(rows, 4, "Matrix must be 4x4");
101  assert_eq!(
102    axis_to_rotate.size(),
103    3,
104    "Axis vector must have 3 elements (x, y, z)"
105  );
106
107  let angle_in_radians = turns_to_radians(angle_in_turns);
108  let cosine_of_angle = angle_in_radians.cos();
109  let sin_of_angle = angle_in_radians.sin();
110
111  let t = 1.0 - cosine_of_angle;
112  let x = axis_to_rotate.at(0);
113  let y = axis_to_rotate.at(1);
114  let z = axis_to_rotate.at(2);
115
116  let mut rotation_matrix = OutputMatrix::default();
117
118  let rotation = match (x as u8, y as u8, z as u8) {
119    (0, 0, 0) => {
120      // No rotation
121      return matrix_to_rotate;
122    }
123    (0, 0, 1) => {
124      // Rotate around z-axis
125      [
126        [cosine_of_angle, sin_of_angle, 0.0, 0.0],
127        [-sin_of_angle, cosine_of_angle, 0.0, 0.0],
128        [0.0, 0.0, 1.0, 0.0],
129        [0.0, 0.0, 0.0, 1.0],
130      ]
131    }
132    (0, 1, 0) => {
133      // Rotate around y-axis
134      [
135        [cosine_of_angle, 0.0, -sin_of_angle, 0.0],
136        [0.0, 1.0, 0.0, 0.0],
137        [sin_of_angle, 0.0, cosine_of_angle, 0.0],
138        [0.0, 0.0, 0.0, 1.0],
139      ]
140    }
141    (1, 0, 0) => {
142      // Rotate around x-axis
143      [
144        [1.0, 0.0, 0.0, 0.0],
145        [0.0, cosine_of_angle, sin_of_angle, 0.0],
146        [0.0, -sin_of_angle, cosine_of_angle, 0.0],
147        [0.0, 0.0, 0.0, 1.0],
148      ]
149    }
150    _ => {
151      panic!("Axis must be a unit vector")
152    }
153  };
154
155  for i in 0..rows {
156    for j in 0..columns {
157      rotation_matrix.update(i, j, rotation[i][j]);
158    }
159  }
160
161  return matrix_to_rotate.multiply(&rotation_matrix);
162}
164/// Creates a 4x4 perspective matrix given the fov in turns (unit between
165/// 0..2pi radians), aspect ratio, near clipping plane (also known as z_near),
166/// and far clipping plane (also known as z_far). Enforces that the matrix being
167/// created is square in both debug and release builds, but only enforces that
168/// the output matrix is 4x4 in debug builds.
169pub fn perspective_matrix<
170  V: Vector<Scalar = f32>,
171  MatrixLike: Matrix<V> + Default,
172>(
173  fov: V::Scalar,
174  aspect_ratio: V::Scalar,
175  near_clipping_plane: V::Scalar,
176  far_clipping_plane: V::Scalar,
177) -> MatrixLike {
178  let mut result = MatrixLike::default();
179  let (rows, columns) = result.size();
180  assert_eq!(
181    rows, columns,
182    "Matrix must be square to be a perspective matrix"
183  );
184  debug_assert_eq!(rows, 4, "Matrix must be 4x4 to be a perspective matrix");
185  let fov_in_radians = turns_to_radians(fov);
186  let f = 1.0 / (fov_in_radians / 2.0).tan();
187  let range = near_clipping_plane - far_clipping_plane;
188
189  result.update(0, 0, f / aspect_ratio);
190  result.update(1, 1, f);
191  result.update(2, 2, (near_clipping_plane + far_clipping_plane) / range);
192  result.update(2, 3, -1.0);
193  result.update(
194    3,
195    2,
196    (2.0 * near_clipping_plane * far_clipping_plane) / range,
197  );
198
199  return result;
200}
202/// Create a matrix of any size that is filled with zeros.
203pub fn zeroed_matrix<
204  V: Vector<Scalar = f32>,
205  MatrixLike: Matrix<V> + Default,
206>(
207  rows: usize,
208  columns: usize,
209) -> MatrixLike {
210  let mut result = MatrixLike::default();
211  for i in 0..rows {
212    for j in 0..columns {
213      result.update(i, j, 0.0);
214    }
215  }
216  return result;
217}
219/// Creates a new matrix with the given number of rows and columns, and fills it
220/// with the given value.
221pub fn filled_matrix<
222  V: Vector<Scalar = f32>,
223  MatrixLike: Matrix<V> + Default,
224>(
225  rows: usize,
226  columns: usize,
227  value: V::Scalar,
228) -> MatrixLike {
229  let mut result = MatrixLike::default();
230  for i in 0..rows {
231    for j in 0..columns {
232      result.update(i, j, value);
233    }
234  }
235  return result;
236}
238/// Creates an identity matrix of the given size.
239pub fn identity_matrix<
240  V: Vector<Scalar = f32>,
241  MatrixLike: Matrix<V> + Default,
242>(
243  rows: usize,
244  columns: usize,
245) -> MatrixLike {
246  assert_eq!(
247    rows, columns,
248    "Matrix must be square to be an identity matrix"
249  );
250  let mut result = MatrixLike::default();
251  for i in 0..rows {
252    for j in 0..columns {
253      if i == j {
254        result.update(i, j, 1.0);
255      } else {
256        result.update(i, j, 0.0);
257      }
258    }
259  }
260  return result;
261}

// -------------------------- ARRAY IMPLEMENTATION -----------------------------

265/// Matrix implementations for arrays of f32 arrays. Including the trait Matrix into
266/// your code will allow you to use these function implementation for any array
267/// of f32 arrays.
268impl<Array, V> Matrix<V> for Array
269where
270  Array: AsMut<[V]> + AsRef<[V]> + Default,
271  V: AsMut<[f32]> + AsRef<[f32]> + Vector<Scalar = f32> + Sized,
272{
273  fn add(&self, other: &Self) -> Self {
274    let mut result = Self::default();
275    for (i, (a, b)) in
276      self.as_ref().iter().zip(other.as_ref().iter()).enumerate()
277    {
278      result.as_mut()[i] = a.add(b);
279    }
280    return result;
281  }
282
283  fn subtract(&self, other: &Self) -> Self {
284    let mut result = Self::default();
285
286    for (i, (a, b)) in
287      self.as_ref().iter().zip(other.as_ref().iter()).enumerate()
288    {
289      result.as_mut()[i] = a.subtract(b);
290    }
291    return result;
292  }
293
294  fn multiply(&self, other: &Self) -> Self {
295    let mut result = Self::default();
296
297    // We transpose the other matrix to convert the columns into rows, allowing
298    // us to compute the new values of each index using the dot product
299    // function.
300    let transposed = other.transpose();
301
302    for (i, a) in self.as_ref().iter().enumerate() {
303      for (j, b) in transposed.as_ref().iter().enumerate() {
304        result.update(i, j, a.dot(b));
305      }
306    }
307    return result;
308  }
309
310  /// Transposes the matrix, swapping the rows and columns.
311  fn transpose(&self) -> Self {
312    let mut result = Self::default();
313    for (i, a) in self.as_ref().iter().enumerate() {
314      for j in 0..a.as_ref().len() {
315        result.update(i, j, self.at(j, i));
316      }
317    }
318    return result;
319  }
320
321  fn inverse(&self) -> Self {
322    todo!()
323  }
324
325  fn transform(&self, other: &V) -> V {
326    todo!()
327  }
328
329  /// Computes the determinant of any square matrix using Laplace expansion.
330  fn determinant(&self) -> f32 {
331    let (width, height) =
332      (self.as_ref()[0].as_ref().len(), self.as_ref().len());
333
334    if width != height {
335      panic!("Cannot compute determinant of non-square matrix");
336    }
337
338    return match height {
339      1 => self.as_ref()[0].as_ref()[0],
340      2 => {
341        let a = self.at(0, 0);
342        let b = self.at(0, 1);
343        let c = self.at(1, 0);
344        let d = self.at(1, 1);
345        a * d - b * c
346      }
347      _ => {
348        let mut result = 0.0;
349        for i in 0..height {
350          let mut submatrix: Vec<Vec<f32>> = Vec::with_capacity(height - 1);
351          for j in 1..height {
352            let mut row = Vec::new();
353            for k in 0..height {
354              if k != i {
355                row.push(self.at(j, k));
356              }
357            }
358            submatrix.push(row);
359          }
360          result += self.at(0, i)
361            * submatrix.determinant()
362            * (-1.0 as f32).powi(i as i32);
363        }
364        result
365      }
366    };
367  }
368
369  /// Return the size as a (rows, columns).
370  fn size(&self) -> (usize, usize) {
371    return (self.as_ref().len(), self.row(0).size());
372  }
373
374  /// Return a reference to the row.
375  fn row(&self, row: usize) -> &V {
376    return &self.as_ref()[row];
377  }
378
379  ///
380  fn at(&self, row: usize, column: usize) -> <V as Vector>::Scalar {
381    return self.as_ref()[row].as_ref()[column];
382  }
383
384  fn update(&mut self, row: usize, column: usize, new_value: V::Scalar) {
385    self.as_mut()[row].as_mut()[column] = new_value;
386  }
387}

// ---------------------------------- TESTS ------------------------------------

#[cfg(test)]
mod tests {
  use super::{
    filled_matrix,
    perspective_matrix,
    rotate_matrix,
    submatrix,
    translation_matrix,
    Matrix,
  };
  use crate::math::turns_to_radians;

  #[test]
  fn square_matrix_add() {
    let left = [[1.0, 2.0], [3.0, 4.0]];
    let right = [[5.0, 6.0], [7.0, 8.0]];
    assert_eq!(left.add(&right), [[6.0, 8.0], [10.0, 12.0]]);
  }

  #[test]
  fn square_matrix_subtract() {
    let left = [[1.0, 2.0], [3.0, 4.0]];
    let right = [[5.0, 6.0], [7.0, 8.0]];
    assert_eq!(left.subtract(&right), [[-4.0, -4.0], [-4.0, -4.0]]);
  }

  /// Matrix multiplication is not commutative, so both orders are checked.
  #[test]
  fn square_matrix_multiply() {
    let first = [[1.0, 2.0], [3.0, 4.0]];
    let second = [[2.0, 0.0], [1.0, 2.0]];
    assert_eq!(first.multiply(&second), [[4.0, 4.0], [10.0, 8.0]]);
    assert_eq!(second.multiply(&first), [[2.0, 4.0], [7.0, 10.0]])
  }

  #[test]
  fn transpose_square_matrix() {
    let original = [[1.0, 2.0], [5.0, 6.0]];
    assert_eq!(original.transpose(), [[1.0, 5.0], [2.0, 6.0]]);
  }

  #[test]
  fn square_matrix_determinant() {
    let two_by_two = [[3.0, 8.0], [4.0, 6.0]];
    assert_eq!(two_by_two.determinant(), -14.0);

    let three_by_three = [[6.0, 1.0, 1.0], [4.0, -2.0, 5.0], [2.0, 8.0, 7.0]];
    assert_eq!(three_by_three.determinant(), -306.0);
  }

  #[test]
  fn non_square_matrix_determinant() {
    let tall = [[3.0, 8.0], [4.0, 6.0], [0.0, 1.0]];
    let outcome = std::panic::catch_unwind(|| tall.determinant());
    assert!(outcome.is_err());
  }

  #[test]
  fn submatrix_for_matrix_array() {
    let source = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
    // Deleting row 1 and column 0 leaves the corners of columns 1 and 2.
    assert_eq!(
      vec![vec![2.0, 3.0], vec![8.0, 9.0]],
      submatrix(source, 1, 0)
    );
  }

  #[test]
  fn translate_matrix() {
    let translation_2d: [[f32; 3]; 3] = translation_matrix([56.0, 5.0]);
    assert_eq!(
      translation_2d,
      [[1.0, 0.0, 56.0], [0.0, 1.0, 5.0], [0.0, 0.0, 1.0]]
    );

    let translation_3d: [[f32; 4]; 4] = translation_matrix([10.0, 2.0, 3.0]);
    assert_eq!(
      translation_3d,
      [
        [1.0, 0.0, 0.0, 10.0],
        [0.0, 1.0, 0.0, 2.0],
        [0.0, 0.0, 1.0, 3.0],
        [0.0, 0.0, 0.0, 1.0],
      ]
    );
  }

  #[test]
  fn perspective_matrix_test() {
    let perspective: [[f32; 4]; 4] =
      perspective_matrix(1.0 / 4.0, 1.0, 1.0, 0.0);

    // Recompute the focal scale by hand from the same quarter-turn field of
    // view.
    let quarter_turn_in_radians = turns_to_radians(1.0 / 4.0);
    let focal_scale = 1.0 / (quarter_turn_in_radians / 2.0).tan();

    let expected: [[f32; 4]; 4] = [
      [focal_scale, 0.0, 0.0, 0.0],
      [0.0, focal_scale, 0.0, 0.0],
      [0.0, 0.0, 1.0, -1.0],
      [0.0, 0.0, 0.0, 0.0],
    ];
    assert_eq!(perspective, expected);
  }

  /// Rotations of a 4x4 matrix about coordinate axes.
  #[test]
  fn rotate_matrices() {
    // A zero-turn rotation must leave the matrix untouched.
    let ones: [[f32; 4]; 4] = filled_matrix(4, 4, 1.0);
    assert_eq!(rotate_matrix(ones, [0.0, 0.0, 1.0], 0.0), ones);

    // A quarter turn about the y-axis.
    let matrix = [
      [1.0, 2.0, 3.0, 4.0],
      [5.0, 6.0, 7.0, 8.0],
      [9.0, 10.0, 11.0, 12.0],
      [13.0, 14.0, 15.0, 16.0],
    ];
    let rotated = rotate_matrix(matrix, [0.0, 1.0, 0.0], 0.25);
    let expected = [
      [3.0, 1.9999999, -1.0000001, 4.0],
      [7.0, 5.9999995, -5.0000005, 8.0],
      [11.0, 9.999999, -9.000001, 12.0],
      [14.999999, 13.999999, -13.000001, 16.0],
    ];

    for row in 0..4 {
      for column in 0..4 {
        crate::assert_approximately_equal!(
          rotated.at(row, column),
          expected.at(row, column),
          0.1
        );
      }
    }
  }
}