1use serde::{Deserialize, Serialize};
4use std::ops::{Add, Index, IndexMut, Mul, Sub};
5
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
18pub struct Vector<T> {
19 data: Vec<T>,
20}
21
22impl<T: Copy> Vector<T> {
23 #[must_use]
25 pub fn from_slice(data: &[T]) -> Self {
26 Self {
27 data: data.to_vec(),
28 }
29 }
30
31 #[must_use]
33 pub fn from_vec(data: Vec<T>) -> Self {
34 Self { data }
35 }
36
37 #[must_use]
39 pub fn len(&self) -> usize {
40 self.data.len()
41 }
42
43 #[must_use]
45 pub fn is_empty(&self) -> bool {
46 self.data.is_empty()
47 }
48
49 #[must_use]
51 pub fn as_slice(&self) -> &[T] {
52 &self.data
53 }
54
55 pub fn as_mut_slice(&mut self) -> &mut [T] {
57 &mut self.data
58 }
59
60 #[must_use]
62 pub fn slice(&self, start: usize, end: usize) -> Self {
63 Self::from_slice(&self.data[start..end])
64 }
65}
66
67impl<T> Index<usize> for Vector<T> {
68 type Output = T;
69
70 fn index(&self, index: usize) -> &Self::Output {
71 &self.data[index]
72 }
73}
74
75impl<T> IndexMut<usize> for Vector<T> {
76 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
77 &mut self.data[index]
78 }
79}
80
81impl Vector<f32> {
82 #[must_use]
84 pub fn zeros(len: usize) -> Self {
85 Self {
86 data: vec![0.0; len],
87 }
88 }
89
90 #[must_use]
92 pub fn ones(len: usize) -> Self {
93 Self {
94 data: vec![1.0; len],
95 }
96 }
97
98 #[must_use]
100 pub fn sum(&self) -> f32 {
101 self.data.iter().sum()
102 }
103
104 #[must_use]
106 pub fn mean(&self) -> f32 {
107 if self.data.is_empty() {
108 return 0.0;
109 }
110 self.sum() / self.data.len() as f32
111 }
112
113 #[must_use]
119 pub fn dot(&self, other: &Self) -> f32 {
120 assert_eq!(
121 self.len(),
122 other.len(),
123 "Vector lengths must match for dot product"
124 );
125 self.data
126 .iter()
127 .zip(other.data.iter())
128 .map(|(a, b)| a * b)
129 .sum()
130 }
131
132 #[must_use]
134 pub fn add_scalar(&self, scalar: f32) -> Self {
135 Self {
136 data: self.data.iter().map(|x| x + scalar).collect(),
137 }
138 }
139
140 #[must_use]
142 pub fn mul_scalar(&self, scalar: f32) -> Self {
143 Self {
144 data: self.data.iter().map(|x| x * scalar).collect(),
145 }
146 }
147
148 #[must_use]
150 pub fn norm_squared(&self) -> f32 {
151 self.dot(self)
152 }
153
154 #[must_use]
156 pub fn norm(&self) -> f32 {
157 self.norm_squared().sqrt()
158 }
159
160 #[must_use]
162 pub fn argmin(&self) -> usize {
163 self.data
164 .iter()
165 .enumerate()
166 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
167 .map_or(0, |(i, _)| i)
168 }
169
170 #[must_use]
172 pub fn argmax(&self) -> usize {
173 self.data
174 .iter()
175 .enumerate()
176 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
177 .map_or(0, |(i, _)| i)
178 }
179
180 #[must_use]
182 pub fn variance(&self) -> f32 {
183 if self.data.is_empty() {
184 return 0.0;
185 }
186 let mean = self.mean();
187 self.data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / self.data.len() as f32
188 }
189
190 #[must_use]
204 pub fn std(&self) -> f32 {
205 self.variance().sqrt()
206 }
207
208 #[must_use]
232 pub fn gini_coefficient(&self) -> f32 {
233 if self.data.is_empty() {
234 return 0.0;
235 }
236
237 let mean = self.mean();
238 if mean == 0.0 {
239 return 0.0;
240 }
241
242 let n = self.data.len() as f32;
243 let mut sum_abs_diff = 0.0;
244
245 for i in 0..self.data.len() {
246 for j in 0..self.data.len() {
247 sum_abs_diff += (self.data[i] - self.data[j]).abs();
248 }
249 }
250
251 sum_abs_diff / (2.0 * n * n * mean)
252 }
253}
254
255impl Add for &Vector<f32> {
256 type Output = Vector<f32>;
257
258 fn add(self, other: Self) -> Self::Output {
259 assert_eq!(
260 self.len(),
261 other.len(),
262 "Vector lengths must match for addition"
263 );
264 Vector {
265 data: self
266 .data
267 .iter()
268 .zip(other.data.iter())
269 .map(|(a, b)| a + b)
270 .collect(),
271 }
272 }
273}
274
275impl Sub for &Vector<f32> {
276 type Output = Vector<f32>;
277
278 fn sub(self, other: Self) -> Self::Output {
279 assert_eq!(
280 self.len(),
281 other.len(),
282 "Vector lengths must match for subtraction"
283 );
284 Vector {
285 data: self
286 .data
287 .iter()
288 .zip(other.data.iter())
289 .map(|(a, b)| a - b)
290 .collect(),
291 }
292 }
293}
294
295impl Mul for &Vector<f32> {
296 type Output = Vector<f32>;
297
298 fn mul(self, other: Self) -> Self::Output {
299 assert_eq!(
300 self.len(),
301 other.len(),
302 "Vector lengths must match for multiplication"
303 );
304 Vector {
305 data: self
306 .data
307 .iter()
308 .zip(other.data.iter())
309 .map(|(a, b)| a * b)
310 .collect(),
311 }
312 }
313}
314
// Unit tests for `Vector`. Most floating-point results are compared against
// an absolute tolerance (typically 1e-6) rather than with exact equality.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_slice() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        assert_eq!(v.len(), 3);
        assert!((v[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_zeros() {
        let v = Vector::<f32>::zeros(5);
        assert_eq!(v.len(), 5);
        // Exact comparison is safe: the elements are literal 0.0 values.
        assert!(v.as_slice().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_ones() {
        let v = Vector::<f32>::ones(5);
        assert_eq!(v.len(), 5);
        assert!(v.as_slice().iter().all(|&x| (x - 1.0).abs() < 1e-6));
    }

    #[test]
    fn test_sum() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        assert!((v.sum() - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_mean() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        assert!((v.mean() - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot() {
        let a = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let b = Vector::from_slice(&[4.0_f32, 5.0, 6.0]);
        // 1*4 + 2*5 + 3*6 = 32
        assert!((a.dot(&b) - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_add_scalar() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let result = v.add_scalar(10.0);
        assert!((result[0] - 11.0).abs() < 1e-6);
        assert!((result[1] - 12.0).abs() < 1e-6);
        assert!((result[2] - 13.0).abs() < 1e-6);
    }

    #[test]
    fn test_mul_scalar() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let result = v.mul_scalar(2.0);
        assert!((result[0] - 2.0).abs() < 1e-6);
        assert!((result[1] - 4.0).abs() < 1e-6);
        assert!((result[2] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_norm() {
        // 3-4-5 right triangle: sqrt(9 + 16) = 5.
        let v = Vector::from_slice(&[3.0_f32, 4.0]);
        assert!((v.norm() - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_argmin() {
        let v = Vector::from_slice(&[3.0_f32, 1.0, 2.0]);
        assert_eq!(v.argmin(), 1);
    }

    #[test]
    fn test_argmax() {
        let v = Vector::from_slice(&[3.0_f32, 1.0, 2.0]);
        assert_eq!(v.argmax(), 0);
    }

    #[test]
    fn test_variance() {
        // Population variance (divides by n): mean is 5, the squared deviations
        // sum to 32, and 32 / 8 = 4.
        let v = Vector::from_slice(&[2.0_f32, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]);
        assert!((v.variance() - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_add_vectors() {
        let a = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let b = Vector::from_slice(&[4.0_f32, 5.0, 6.0]);
        let result = &a + &b;
        assert!((result[0] - 5.0).abs() < 1e-6);
        assert!((result[1] - 7.0).abs() < 1e-6);
        assert!((result[2] - 9.0).abs() < 1e-6);
    }

    #[test]
    fn test_sub_vectors() {
        let a = Vector::from_slice(&[4.0_f32, 5.0, 6.0]);
        let b = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let result = &a - &b;
        assert!((result[0] - 3.0).abs() < 1e-6);
        assert!((result[1] - 3.0).abs() < 1e-6);
        assert!((result[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_slice() {
        // `slice(1, 4)` is the half-open range 1..4, i.e. [2.0, 3.0, 4.0].
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0, 4.0, 5.0]);
        let sliced = v.slice(1, 4);
        assert_eq!(sliced.len(), 3);
        assert!((sliced[0] - 2.0).abs() < 1e-6);
        assert!((sliced[2] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_empty_mean() {
        // The mean of an empty vector is expected to be 0.0, not NaN.
        let v = Vector::<f32>::from_vec(vec![]);
        assert!((v.mean() - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_is_empty() {
        let empty = Vector::<f32>::from_vec(vec![]);
        assert!(empty.is_empty());

        let non_empty = Vector::from_slice(&[1.0_f32]);
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn test_argmax_single_element() {
        let v = Vector::from_slice(&[42.0_f32]);
        assert_eq!(v.argmax(), 0);
    }

    #[test]
    fn test_argmax_all_equal() {
        // Tie-breaking is deliberately not pinned here: any in-range index
        // holding the maximum value is accepted.
        let v = Vector::from_slice(&[5.0_f32, 5.0, 5.0]);
        let idx = v.argmax();
        assert!(idx < v.len());
        assert!((v[idx] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_argmin_single_element() {
        let v = Vector::from_slice(&[42.0_f32]);
        assert_eq!(v.argmin(), 0);
    }

    #[test]
    fn test_argmin_all_equal() {
        // As with argmax, any in-range index holding the minimum is accepted.
        let v = Vector::from_slice(&[5.0_f32, 5.0, 5.0]);
        let idx = v.argmin();
        assert!(idx < v.len());
        assert!((v[idx] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_argmax_not_at_zero() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 10.0]);
        assert_eq!(v.argmax(), 2);
    }

    #[test]
    fn test_mul_vectors() {
        let a = Vector::from_slice(&[2.0_f32, 3.0, 4.0]);
        let b = Vector::from_slice(&[5.0_f32, 6.0, 7.0]);
        let result = &a * &b;
        assert!((result[0] - 10.0).abs() < 1e-6);
        assert!((result[1] - 18.0).abs() < 1e-6);
        assert!((result[2] - 28.0).abs() < 1e-6);

        // Guard assertions: 7.0 = 2 + 5 and 0.5 = 3 / 6, so these reject an
        // implementation that adds or divides instead of multiplying.
        assert!((result[0] - 7.0).abs() > 0.1);
        assert!((result[1] - 0.5).abs() > 1.0);
    }

    #[test]
    fn test_is_empty_true() {
        // The annotation fixes the element type, which the empty literal
        // alone cannot convey.
        let v: Vector<f32> = Vector::from_slice(&[]);
        assert!(v.is_empty(), "Empty vector should return true for is_empty");
        assert_eq!(v.len(), 0, "Empty vector should have len 0");
    }

    #[test]
    fn test_is_empty_false() {
        let v = Vector::from_slice(&[1.0_f32]);
        assert!(
            !v.is_empty(),
            "Non-empty vector should return false for is_empty"
        );
    }

    #[test]
    fn test_argmin_not_at_one() {
        let v = Vector::from_slice(&[1.0_f32, 5.0, 3.0]);
        assert_eq!(v.argmin(), 0, "Minimum should be at index 0, not 1");
    }

    #[test]
    fn test_argmin_at_end() {
        let v = Vector::from_slice(&[5.0_f32, 3.0, 1.0]);
        assert_eq!(v.argmin(), 2, "Minimum should be at index 2");
    }

    #[test]
    fn test_as_mut_slice_modifies() {
        let mut v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        // The inner scope ends the mutable borrow before `v` is indexed again.
        {
            let slice = v.as_mut_slice();
            slice[0] = 10.0;
            slice[1] = 20.0;
        }
        assert!(
            (v[0] - 10.0).abs() < 1e-6,
            "First element should be modified to 10.0"
        );
        assert!(
            (v[1] - 20.0).abs() < 1e-6,
            "Second element should be modified to 20.0"
        );
    }

    #[test]
    fn test_as_mut_slice_length() {
        let mut v = Vector::from_slice(&[1.0_f32, 2.0, 3.0, 4.0]);
        let slice = v.as_mut_slice();
        assert_eq!(slice.len(), 4, "Mutable slice should have correct length");
    }

    #[test]
    fn test_argmax_f32_returns_nonzero() {
        // Regression guard against an implementation that always returns 0.
        let v: Vector<f32> = Vector::from_slice(&[1.0, 2.0, 999.0, 3.0]);
        assert_eq!(
            v.argmax(),
            2,
            "argmax must return 2 (position of max 999.0), not 0"
        );
        assert_ne!(v.argmax(), 0, "argmax must not always return 0");
    }

    #[test]
    fn test_as_mut_slice_f32_not_empty() {
        let mut v: Vector<f32> = Vector::from_slice(&[10.0, 20.0, 30.0]);
        let slice = v.as_mut_slice();

        assert_eq!(
            slice.len(),
            3,
            "as_mut_slice must return slice with 3 elements, not empty"
        );

        // Writing through the slice must be visible in the vector itself.
        slice[0] = 100.0;
        assert!(
            (v[0] - 100.0).abs() < 1e-6,
            "as_mut_slice must allow mutation of original data"
        );
    }

    #[test]
    fn test_mul_f32_not_addition() {
        let a: Vector<f32> = Vector::from_slice(&[3.0, 4.0]);
        let b: Vector<f32> = Vector::from_slice(&[5.0, 6.0]);
        let result = &a * &b;

        assert!(
            (result[0] - 15.0).abs() < 1e-6,
            "3*5 must equal 15, not 3+5=8"
        );
        assert!(
            (result[1] - 24.0).abs() < 1e-6,
            "4*6 must equal 24, not 4+6=10"
        );

        // Guard assertions: 8 and 10 are the element-wise sums.
        assert!((result[0] - 8.0).abs() > 1.0, "Must not be addition");
        assert!((result[1] - 10.0).abs() > 1.0, "Must not be addition");
    }

    #[test]
    fn test_mul_f32_not_division() {
        let a: Vector<f32> = Vector::from_slice(&[12.0, 20.0]);
        let b: Vector<f32> = Vector::from_slice(&[3.0, 4.0]);
        let result = &a * &b;

        assert!(
            (result[0] - 36.0).abs() < 1e-6,
            "12*3 must equal 36, not 12/3=4"
        );
        assert!(
            (result[1] - 80.0).abs() < 1e-6,
            "20*4 must equal 80, not 20/4=5"
        );

        // Guard assertions: 4 and 5 are the element-wise quotients.
        assert!((result[0] - 4.0).abs() > 1.0, "Must not be division");
        assert!((result[1] - 5.0).abs() > 1.0, "Must not be division");
    }

    #[test]
    fn test_std() {
        // Population standard deviation of 1..=5: variance 10 / 5 = 2, so the
        // expected value is sqrt(2), about 1.414.
        let v = Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let std = v.std();

        assert!((std - 1.414).abs() < 0.01, "std = {std}");
    }

    #[test]
    fn test_std_uniform() {
        // Identical values have zero spread.
        let v = Vector::from_slice(&[5.0, 5.0, 5.0, 5.0]);
        assert!((v.std() - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_gini_coefficient_perfect_equality() {
        let v = Vector::from_slice(&[5.0, 5.0, 5.0, 5.0]);
        assert!((v.gini_coefficient() - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_gini_coefficient_inequality() {
        // Σ|x_i - x_j| over 1..=5 is 40, and 40 / (2 * 25 * 3) = 4/15, about 0.267.
        let v = Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let gini = v.gini_coefficient();

        assert!(gini > 0.0 && gini < 1.0, "Gini = {gini}");

        assert!((gini - 0.267).abs() < 0.01, "Gini = {gini}");
    }

    #[test]
    fn test_gini_coefficient_maximum_inequality() {
        // One holder of everything among n = 4 gives (n - 1) / n = 0.75.
        let v = Vector::from_slice(&[0.0, 0.0, 0.0, 100.0]);
        let gini = v.gini_coefficient();

        assert!(gini > 0.7 && gini < 0.8, "Gini = {gini}");
    }

    #[test]
    fn test_gini_coefficient_empty() {
        let v: Vector<f32> = Vector::from_slice(&[]);
        assert_eq!(v.gini_coefficient(), 0.0);
    }
}