1#[cfg(not(feature = "std"))]
19use alloc::vec::Vec;
20
21use crate::Scalar;
22
23pub trait Vector<S: Scalar>: Clone + Sized {
41 fn zeros(len: usize) -> Self;
43
44 fn fill(len: usize, value: S) -> Self;
46
47 fn from_slice(data: &[S]) -> Self;
49
50 fn len(&self) -> usize;
52
53 #[inline]
55 fn is_empty(&self) -> bool {
56 self.len() == 0
57 }
58
59 fn get(&self, i: usize) -> S;
61
62 fn set(&mut self, i: usize, value: S);
64
65 fn get_mut(&mut self, i: usize) -> &mut S;
67
68 fn as_slice(&self) -> &[S];
70
71 fn as_mut_slice(&mut self) -> &mut [S];
73
74 fn copy_from(&mut self, other: &Self);
76
77 fn axpy(&mut self, a: S, x: &Self);
83
84 fn axpby(&mut self, a: S, x: &Self, b: S);
86
87 fn dot(&self, other: &Self) -> S;
89
90 fn scale(&mut self, a: S);
92
93 #[inline]
97 fn norm2(&self) -> S {
98 self.dot(self).sqrt()
99 }
100
101 fn norm_inf(&self) -> S;
103
104 fn norm1(&self) -> S;
106
107 fn weighted_rms_norm(&self, weights: &Self) -> S {
110 let n = S::from_usize(self.len());
111 let mut sum = S::ZERO;
112 for i in 0..self.len() {
113 let xi = self.get(i) / weights.get(i);
114 sum += xi * xi;
115 }
116 (sum / n).sqrt()
117 }
118
119 fn abs_inplace(&mut self);
123
124 fn max_elementwise(&mut self, other: &Self);
126
127 fn min_elementwise(&mut self, other: &Self);
129
130 fn sum(&self) -> S;
134
135 fn max_element(&self) -> S;
137
138 fn min_element(&self) -> S;
140
141 fn map_inplace<F: Fn(S) -> S>(&mut self, f: F);
145}
146
147impl<S: Scalar> Vector<S> for Vec<S> {
152 #[inline]
153 fn zeros(len: usize) -> Self {
154 vec![S::ZERO; len]
155 }
156
157 #[inline]
158 fn fill(len: usize, value: S) -> Self {
159 vec![value; len]
160 }
161
162 #[inline]
163 fn from_slice(data: &[S]) -> Self {
164 data.to_vec()
165 }
166
167 #[inline]
168 fn len(&self) -> usize {
169 Vec::len(self)
170 }
171
172 #[inline]
173 fn get(&self, i: usize) -> S {
174 self[i]
175 }
176
177 #[inline]
178 fn set(&mut self, i: usize, value: S) {
179 self[i] = value;
180 }
181
182 #[inline]
183 fn get_mut(&mut self, i: usize) -> &mut S {
184 &mut self[i]
185 }
186
187 #[inline]
188 fn as_slice(&self) -> &[S] {
189 self
190 }
191
192 #[inline]
193 fn as_mut_slice(&mut self) -> &mut [S] {
194 self
195 }
196
197 #[inline]
198 fn copy_from(&mut self, other: &Self) {
199 self.copy_from_slice(other);
200 }
201
202 fn axpy(&mut self, a: S, x: &Self) {
203 debug_assert_eq!(self.len(), x.len());
204 for (yi, xi) in self.iter_mut().zip(x.iter()) {
205 *yi += a * *xi;
206 }
207 }
208
209 fn axpby(&mut self, a: S, x: &Self, b: S) {
210 debug_assert_eq!(self.len(), x.len());
211 for (yi, xi) in self.iter_mut().zip(x.iter()) {
212 *yi = a * *xi + b * *yi;
213 }
214 }
215
216 fn dot(&self, other: &Self) -> S {
217 debug_assert_eq!(self.len(), other.len());
218 self.iter()
219 .zip(other.iter())
220 .fold(S::ZERO, |acc, (a, b)| acc + *a * *b)
221 }
222
223 fn scale(&mut self, a: S) {
224 for x in self.iter_mut() {
225 *x *= a;
226 }
227 }
228
229 fn norm_inf(&self) -> S {
230 self.iter().fold(S::ZERO, |acc, x| acc.max(x.abs()))
231 }
232
233 fn norm1(&self) -> S {
234 self.iter().fold(S::ZERO, |acc, x| acc + x.abs())
235 }
236
237 fn abs_inplace(&mut self) {
238 for x in self.iter_mut() {
239 *x = x.abs();
240 }
241 }
242
243 fn max_elementwise(&mut self, other: &Self) {
244 debug_assert_eq!(self.len(), other.len());
245 for (yi, xi) in self.iter_mut().zip(other.iter()) {
246 *yi = yi.max(*xi);
247 }
248 }
249
250 fn min_elementwise(&mut self, other: &Self) {
251 debug_assert_eq!(self.len(), other.len());
252 for (yi, xi) in self.iter_mut().zip(other.iter()) {
253 *yi = yi.min(*xi);
254 }
255 }
256
257 fn sum(&self) -> S {
258 self.iter().fold(S::ZERO, |acc, x| acc + *x)
259 }
260
261 fn max_element(&self) -> S {
262 self.iter().fold(S::NEG_INFINITY, |acc, x| acc.max(*x))
263 }
264
265 fn min_element(&self) -> S {
266 self.iter().fold(S::INFINITY, |acc, x| acc.min(*x))
267 }
268
269 fn map_inplace<F: Fn(S) -> S>(&mut self, f: F) {
270 for x in self.iter_mut() {
271 *x = f(*x);
272 }
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-10;

    /// Returns `true` when `a` and `b` differ by less than `TOL`.
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < TOL
    }

    #[test]
    fn test_zeros() {
        let v: Vec<f64> = Vector::zeros(5);
        assert_eq!(v.len(), 5);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_fill() {
        let v: Vec<f64> = Vector::fill(3, 2.5);
        assert_eq!(v, vec![2.5, 2.5, 2.5]);
    }

    #[test]
    fn test_from_slice_and_is_empty() {
        let v: Vec<f64> = Vector::from_slice(&[1.0, -2.0]);
        assert_eq!(v, vec![1.0, -2.0]);
        // `Vec` has an inherent `is_empty`; call the trait method explicitly.
        assert!(!Vector::is_empty(&v));
        let e: Vec<f64> = Vector::zeros(0);
        assert!(Vector::is_empty(&e));
    }

    #[test]
    fn test_get_set_get_mut() {
        // Slices have inherent `get`/`get_mut` returning `Option`, so the trait
        // methods are called through explicit paths.
        let mut v: Vec<f64> = Vector::zeros(3);
        Vector::set(&mut v, 1, 7.0);
        assert_eq!(Vector::get(&v, 1), 7.0);
        *Vector::get_mut(&mut v, 2) = -1.5;
        assert_eq!(v, vec![0.0, 7.0, -1.5]);
    }

    #[test]
    fn test_copy_from() {
        let src: Vec<f64> = vec![1.0, 2.0, 3.0];
        let mut dst: Vec<f64> = Vector::zeros(3);
        Vector::copy_from(&mut dst, &src);
        assert_eq!(dst, src);
    }

    #[test]
    #[should_panic]
    fn test_copy_from_length_mismatch_panics() {
        let src: Vec<f64> = vec![1.0, 2.0];
        let mut dst: Vec<f64> = Vector::zeros(3);
        Vector::copy_from(&mut dst, &src);
    }

    #[test]
    fn test_axpy() {
        let x: Vec<f64> = vec![1.0, 2.0, 3.0];
        let mut y: Vec<f64> = vec![4.0, 5.0, 6.0];
        y.axpy(2.0, &x);
        assert_eq!(y, vec![6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_axpby() {
        let x: Vec<f64> = vec![1.0, 2.0, 3.0];
        let mut y: Vec<f64> = vec![4.0, 5.0, 6.0];
        y.axpby(2.0, &x, 0.5);
        // y = 2 * x + 0.5 * y
        assert!(close(y[0], 4.0));
        assert!(close(y[1], 6.5));
        assert!(close(y[2], 9.0));
    }

    #[test]
    fn test_dot() {
        let x: Vec<f64> = vec![1.0, 2.0, 3.0];
        let y: Vec<f64> = vec![4.0, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6 = 32
        assert!(close(x.dot(&y), 32.0));
    }

    #[test]
    fn test_norm2() {
        let v: Vec<f64> = vec![3.0, 4.0];
        assert!(close(v.norm2(), 5.0));
    }

    #[test]
    fn test_norm_inf() {
        let v: Vec<f64> = vec![-5.0, 3.0, -1.0];
        assert!(close(v.norm_inf(), 5.0));
    }

    #[test]
    fn test_norm1() {
        let v: Vec<f64> = vec![-1.0, 2.0, -3.0];
        assert!(close(v.norm1(), 6.0));
    }

    #[test]
    fn test_empty_vector_reductions() {
        let e: Vec<f64> = Vec::new();
        assert_eq!(e.norm1(), 0.0);
        assert_eq!(e.norm_inf(), 0.0);
        assert_eq!(e.norm2(), 0.0);
        assert_eq!(e.sum(), 0.0);
        assert_eq!(e.dot(&e), 0.0);
    }

    #[test]
    fn test_scale() {
        let mut v: Vec<f64> = vec![1.0, 2.0, 3.0];
        v.scale(2.0);
        assert_eq!(v, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_abs_inplace() {
        let mut v: Vec<f64> = vec![-1.0, 0.0, 2.5];
        v.abs_inplace();
        assert_eq!(v, vec![1.0, 0.0, 2.5]);
    }

    #[test]
    fn test_max_min_elementwise() {
        let other: Vec<f64> = vec![2.0, 2.0, 2.0];
        let mut hi: Vec<f64> = vec![1.0, 3.0, 2.0];
        hi.max_elementwise(&other);
        assert_eq!(hi, vec![2.0, 3.0, 2.0]);
        let mut lo: Vec<f64> = vec![1.0, 3.0, 2.0];
        lo.min_elementwise(&other);
        assert_eq!(lo, vec![1.0, 2.0, 2.0]);
    }

    #[test]
    fn test_sum() {
        let v: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        assert!(close(v.sum(), 10.0));
    }

    #[test]
    fn test_max_min_element() {
        let v: Vec<f64> = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
        assert!(close(v.max_element(), 9.0));
        assert!(close(v.min_element(), 1.0));
    }

    #[test]
    fn test_weighted_rms_norm() {
        // y / w = [2, 2], so the RMS value is 2.
        let y: Vec<f64> = vec![2.0, 4.0];
        let w: Vec<f64> = vec![1.0, 2.0];
        assert!(close(y.weighted_rms_norm(&w), 2.0));
    }

    #[test]
    fn test_map_inplace() {
        let mut v: Vec<f64> = vec![1.0, 4.0, 9.0];
        v.map_inplace(|x| x.sqrt());
        assert!(close(v[0], 1.0));
        assert!(close(v[1], 2.0));
        assert!(close(v[2], 3.0));
    }
}