lambda/math/
vector.rs

//! Vector math types and functions.
2
3/// Generalized Vector operations that can be implemented by any vector like
4/// type.
5pub trait Vector {
6  type Scalar: Copy;
7  fn add(&self, other: &Self) -> Self;
8  fn subtract(&self, other: &Self) -> Self;
9  fn scale(&self, scalar: Self::Scalar) -> Self;
10  fn dot(&self, other: &Self) -> Self::Scalar;
11  fn cross(&self, other: &Self) -> Self;
12  fn length(&self) -> Self::Scalar;
13  fn normalize(&self) -> Self;
14  fn size(&self) -> usize;
15  fn at(&self, index: usize) -> Self::Scalar;
16  fn update(&mut self, index: usize, value: Self::Scalar);
17}
18
19impl<T> Vector for T
20where
21  T: AsMut<[f32]> + AsRef<[f32]> + Default,
22{
23  type Scalar = f32;
24
25  /// Add two vectors of any size together.
26  fn add(&self, other: &Self) -> Self {
27    let mut result = Self::default();
28
29    self
30      .as_ref()
31      .iter()
32      .zip(other.as_ref().iter())
33      .enumerate()
34      .for_each(|(i, (a, b))| result.as_mut()[i] = a + b);
35
36    return result;
37  }
38
39  /// Subtract two vectors of any size.
40  fn subtract(&self, other: &Self) -> Self {
41    let mut result = Self::default();
42
43    self
44      .as_ref()
45      .iter()
46      .zip(other.as_ref().iter())
47      .enumerate()
48      .for_each(|(i, (a, b))| result.as_mut()[i] = a - b);
49
50    return result;
51  }
52
53  fn dot(&self, other: &Self) -> Self::Scalar {
54    assert_eq!(
55      self.as_ref().len(),
56      other.as_ref().len(),
57      "Vectors must be the same length"
58    );
59
60    let mut result = 0.0;
61    for (a, b) in self.as_ref().iter().zip(other.as_ref().iter()) {
62      result += a * b;
63    }
64    return result;
65  }
66
67  /// Cross product of two 3D vectors. Panics if the vectors are not 3D.
68  fn cross(&self, other: &Self) -> Self {
69    assert_eq!(
70      self.as_ref().len(),
71      other.as_ref().len(),
72      "Vectors must be the same length"
73    );
74
75    let mut result = Self::default();
76    let a = self.as_ref();
77    let b = other.as_ref();
78
79    // TODO: This is only for 3D vectors
80    match a.len() {
81      3 => {
82        result.as_mut()[0] = a[1] * b[2] - a[2] * b[1];
83        result.as_mut()[1] = a[2] * b[0] - a[0] * b[2];
84        result.as_mut()[2] = a[0] * b[1] - a[1] * b[0];
85      }
86      _ => {
87        panic!("Cross product is only defined for 3 dimensional vectors.")
88      }
89    }
90    return result;
91  }
92
93  fn length(&self) -> Self::Scalar {
94    let mut result = 0.0;
95    for a in self.as_ref().iter() {
96      result += a * a;
97    }
98    result.sqrt()
99  }
100
101  fn normalize(&self) -> Self {
102    assert_ne!(self.length(), 0.0, "Cannot normalize a zero length vector");
103    let mut result = Self::default();
104    let length = self.length();
105
106    self.as_ref().iter().enumerate().for_each(|(i, a)| {
107      result.as_mut()[i] = a / length;
108    });
109
110    return result;
111  }
112
113  fn scale(&self, scalar: Self::Scalar) -> Self {
114    let mut result = Self::default();
115    self.as_ref().iter().enumerate().for_each(|(i, a)| {
116      result.as_mut()[i] = a * scalar;
117    });
118
119    return result;
120  }
121
122  fn size(&self) -> usize {
123    return self.as_ref().len();
124  }
125
126  fn at(&self, index: usize) -> Self::Scalar {
127    return self.as_ref()[index];
128  }
129
130  fn update(&mut self, index: usize, value: Self::Scalar) {
131    self.as_mut()[index] = value;
132  }
133}
134
#[cfg(test)]
mod tests {
  use super::Vector;

  #[test]
  fn adding_vectors() {
    let a = [1.0, 2.0, 3.0];
    let b = [4.0, 5.0, 6.0];
    let c = [5.0, 7.0, 9.0];

    let result = a.add(&b);

    assert_eq!(result, c);
  }

  #[test]
  fn subtracting_vectors() {
    let a = [1.0, 2.0, 3.0];
    let b = [4.0, 5.0, 6.0];
    let c = [-3.0, -3.0, -3.0];

    let result = a.subtract(&b);

    assert_eq!(result, c);
  }

  #[test]
  fn scaling_vectors() {
    let a = [1.0, 2.0, 3.0];
    let b = [2.0, 4.0, 6.0];
    let scalar = 2.0;

    let result = a.scale(scalar);
    assert_eq!(result, b);
  }

  #[test]
  fn scaling_by_zero_and_negative() {
    let a: [f32; 3] = [1.0, -2.0, 3.0];

    // `-0.0 == 0.0`, so the sign of zero components does not matter here.
    assert_eq!(a.scale(0.0), [0.0, 0.0, 0.0]);
    assert_eq!(a.scale(-1.0), [-1.0, 2.0, -3.0]);
  }

  #[test]
  fn dot_product() {
    let a = [1.0, 2.0, 3.0];
    let b = [4.0, 5.0, 6.0];
    let c = 32.0;

    let result = a.dot(&b);
    assert_eq!(result, c);
  }

  #[test]
  fn dot_product_fails_for_mismatched_lengths() {
    let a: Vec<f32> = vec![1.0, 2.0];
    let b: Vec<f32> = vec![1.0];

    let result = std::panic::catch_unwind(|| a.dot(&b));
    assert!(result.is_err());
  }

  #[test]
  fn cross_product() {
    let a = [1.0, 2.0, 3.0];
    let b = [4.0, 5.0, 6.0];
    let c = [-3.0, 6.0, -3.0];

    let result = a.cross(&b);
    assert_eq!(result, c);
  }

  #[test]
  fn cross_product_fails_for_non_3d_vectors() {
    let a = [1.0, 2.0];
    let b = [4.0, 5.0];

    let result = std::panic::catch_unwind(|| a.cross(&b));
    assert!(result.is_err());
  }

  #[test]
  fn length() {
    let a = [1.0, 2.0, 3.0];
    let b = 3.7416573867739413;

    let result = a.length();
    assert_eq!(result, b);

    let c = [1.0, 2.0, 3.0, 4.0];
    let d = 5.477225575051661;
    let result = c.length();
    assert_eq!(result, d);
  }

  #[test]
  fn normalize() {
    let a = [4.0, 3.0, 2.0];
    let b = [0.74278135, 0.55708605, 0.37139067];
    let result = a.normalize();
    assert_eq!(result, b);
  }

  #[test]
  fn normalize_fails_for_zero_length_vector() {
    let a = [0.0, 0.0, 0.0];

    let result = std::panic::catch_unwind(|| a.normalize());
    assert!(result.is_err());
  }

  #[test]
  fn size_at_and_update() {
    let mut a: [f32; 3] = [1.0, 2.0, 3.0];

    assert_eq!(a.size(), 3);
    assert_eq!(a.at(1), 2.0);

    a.update(1, 9.0);
    assert_eq!(a.at(1), 9.0);
    assert_eq!(a, [1.0, 9.0, 3.0]);
  }
}
239}