/// A vector made of `Scalar` components supporting basic linear-algebra operations.
///
/// Every operation that produces a vector returns a new value of the same type;
/// only [`Vector::update`] mutates in place.
pub trait Vector {
    /// Type of a single component.
    type Scalar: Copy;

    // --- Component access -------------------------------------------------

    /// Number of components in the vector.
    fn size(&self) -> usize;

    /// Returns the component stored at `index`.
    fn at(&self, index: usize) -> Self::Scalar;

    /// Overwrites the component stored at `index` with `value`.
    fn update(&mut self, index: usize, value: Self::Scalar);

    // --- Component-wise arithmetic ----------------------------------------

    /// Returns `self + other`, computed component by component.
    fn add(&self, other: &Self) -> Self;

    /// Returns `self - other`, computed component by component.
    fn subtract(&self, other: &Self) -> Self;

    /// Returns a copy of `self` with every component multiplied by `scalar`.
    fn scale(&self, scalar: Self::Scalar) -> Self;

    // --- Geometry ---------------------------------------------------------

    /// Returns the dot (inner) product of `self` and `other`.
    fn dot(&self, other: &Self) -> Self::Scalar;

    /// Returns the cross product `self × other`.
    fn cross(&self, other: &Self) -> Self;

    /// Returns the Euclidean length (magnitude) of the vector.
    fn length(&self) -> Self::Scalar;

    /// Returns a copy of `self` scaled to unit length.
    fn normalize(&self) -> Self;
}
18
19impl<T> Vector for T
20where
21 T: AsMut<[f32]> + AsRef<[f32]> + Default,
22{
23 type Scalar = f32;
24
25 fn add(&self, other: &Self) -> Self {
27 let mut result = Self::default();
28
29 self
30 .as_ref()
31 .iter()
32 .zip(other.as_ref().iter())
33 .enumerate()
34 .for_each(|(i, (a, b))| result.as_mut()[i] = a + b);
35
36 return result;
37 }
38
39 fn subtract(&self, other: &Self) -> Self {
41 let mut result = Self::default();
42
43 self
44 .as_ref()
45 .iter()
46 .zip(other.as_ref().iter())
47 .enumerate()
48 .for_each(|(i, (a, b))| result.as_mut()[i] = a - b);
49
50 return result;
51 }
52
53 fn dot(&self, other: &Self) -> Self::Scalar {
54 assert_eq!(
55 self.as_ref().len(),
56 other.as_ref().len(),
57 "Vectors must be the same length"
58 );
59
60 let mut result = 0.0;
61 for (a, b) in self.as_ref().iter().zip(other.as_ref().iter()) {
62 result += a * b;
63 }
64 return result;
65 }
66
67 fn cross(&self, other: &Self) -> Self {
69 assert_eq!(
70 self.as_ref().len(),
71 other.as_ref().len(),
72 "Vectors must be the same length"
73 );
74
75 let mut result = Self::default();
76 let a = self.as_ref();
77 let b = other.as_ref();
78
79 match a.len() {
81 3 => {
82 result.as_mut()[0] = a[1] * b[2] - a[2] * b[1];
83 result.as_mut()[1] = a[2] * b[0] - a[0] * b[2];
84 result.as_mut()[2] = a[0] * b[1] - a[1] * b[0];
85 }
86 _ => {
87 panic!("Cross product is only defined for 3 dimensional vectors.")
88 }
89 }
90 return result;
91 }
92
93 fn length(&self) -> Self::Scalar {
94 let mut result = 0.0;
95 for a in self.as_ref().iter() {
96 result += a * a;
97 }
98 result.sqrt()
99 }
100
101 fn normalize(&self) -> Self {
102 assert_ne!(self.length(), 0.0, "Cannot normalize a zero length vector");
103 let mut result = Self::default();
104 let length = self.length();
105
106 self.as_ref().iter().enumerate().for_each(|(i, a)| {
107 result.as_mut()[i] = a / length;
108 });
109
110 return result;
111 }
112
113 fn scale(&self, scalar: Self::Scalar) -> Self {
114 let mut result = Self::default();
115 self.as_ref().iter().enumerate().for_each(|(i, a)| {
116 result.as_mut()[i] = a * scalar;
117 });
118
119 return result;
120 }
121
122 fn size(&self) -> usize {
123 return self.as_ref().len();
124 }
125
126 fn at(&self, index: usize) -> Self::Scalar {
127 return self.as_ref()[index];
128 }
129
130 fn update(&mut self, index: usize, value: Self::Scalar) {
131 self.as_mut()[index] = value;
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::Vector;

    /// Absolute tolerance for results that go through `sqrt` or division and
    /// therefore should not be compared bit-for-bit.
    const EPSILON: f32 = 1e-5;

    /// Asserts that `actual` and `expected` have the same length and that every
    /// pair of components differs by at most `EPSILON`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (a, e) in actual.iter().zip(expected) {
            assert!(
                (a - e).abs() <= EPSILON,
                "{:?} is not close to {:?}",
                actual,
                expected
            );
        }
    }

    #[test]
    fn adding_vectors() {
        let a: [f32; 3] = [1.0, 2.0, 3.0];
        let b: [f32; 3] = [4.0, 5.0, 6.0];
        assert_eq!(a.add(&b), [5.0, 7.0, 9.0]);
    }

    #[test]
    fn subtracting_vectors() {
        let a: [f32; 3] = [1.0, 2.0, 3.0];
        let b: [f32; 3] = [4.0, 5.0, 6.0];
        assert_eq!(a.subtract(&b), [-3.0, -3.0, -3.0]);
    }

    #[test]
    fn scaling_vectors() {
        let a: [f32; 3] = [1.0, 2.0, 3.0];
        assert_eq!(a.scale(2.0), [2.0, 4.0, 6.0]);
    }

    #[test]
    fn dot_product() {
        let a: [f32; 3] = [1.0, 2.0, 3.0];
        let b: [f32; 3] = [4.0, 5.0, 6.0];
        assert_eq!(a.dot(&b), 32.0);
    }

    #[test]
    fn cross_product() {
        let a: [f32; 3] = [1.0, 2.0, 3.0];
        let b: [f32; 3] = [4.0, 5.0, 6.0];
        assert_eq!(a.cross(&b), [-3.0, 6.0, -3.0]);
    }

    #[test]
    #[should_panic(expected = "Cross product is only defined for 3 dimensional vectors.")]
    fn cross_product_fails_for_non_3d_vectors() {
        let a: [f32; 2] = [1.0, 2.0];
        let b: [f32; 2] = [4.0, 5.0];
        let _ = a.cross(&b);
    }

    #[test]
    fn length() {
        let a: [f32; 3] = [1.0, 2.0, 3.0];
        assert!((a.length() - 3.741_657_4).abs() <= EPSILON);

        let c: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        assert!((c.length() - 5.477_226).abs() <= EPSILON);
    }

    #[test]
    fn normalize() {
        let a: [f32; 3] = [4.0, 3.0, 2.0];
        let result = a.normalize();
        assert_close(&result, &[0.742_781_35, 0.557_086_05, 0.371_390_67]);
        // A normalized vector must have unit length.
        assert!((result.length() - 1.0).abs() <= EPSILON);
    }

    #[test]
    #[should_panic(expected = "Cannot normalize a zero length vector")]
    fn normalize_fails_for_zero_length_vector() {
        let a: [f32; 3] = [0.0, 0.0, 0.0];
        let _ = a.normalize();
    }

    #[test]
    fn scale() {
        // Complements `scaling_vectors` with negative and zero scalars.
        let a: [f32; 3] = [1.0, -2.0, 3.0];
        assert_eq!(a.scale(-1.0), [-1.0, 2.0, -3.0]);
        assert_eq!(a.scale(0.0), [0.0, 0.0, 0.0]);
    }

    #[test]
    fn size_at_and_update() {
        let mut a: [f32; 3] = [1.0, 2.0, 3.0];
        assert_eq!(a.size(), 3);
        assert_eq!(a.at(2), 3.0);

        a.update(0, 9.0);
        assert_eq!(a, [9.0, 2.0, 3.0]);
    }
}