aprender/primitives/
vector.rs

1//! Vector type for 1D numeric data.
2
3use serde::{Deserialize, Serialize};
4use std::ops::{Add, Index, IndexMut, Mul, Sub};
5
6/// A 1D vector of floating-point values.
7///
8/// # Examples
9///
10/// ```
11/// use aprender::primitives::Vector;
12///
13/// let v = Vector::from_slice(&[1.0, 2.0, 3.0]);
14/// assert_eq!(v.len(), 3);
15/// assert!((v.sum() - 6.0).abs() < 1e-6);
16/// ```
17#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
18pub struct Vector<T> {
19    data: Vec<T>,
20}
21
22impl<T: Copy> Vector<T> {
23    /// Creates a new vector from a slice.
24    #[must_use]
25    pub fn from_slice(data: &[T]) -> Self {
26        Self {
27            data: data.to_vec(),
28        }
29    }
30
31    /// Creates a new vector from a Vec.
32    #[must_use]
33    pub fn from_vec(data: Vec<T>) -> Self {
34        Self { data }
35    }
36
37    /// Returns the number of elements.
38    #[must_use]
39    pub fn len(&self) -> usize {
40        self.data.len()
41    }
42
43    /// Returns true if the vector is empty.
44    #[must_use]
45    pub fn is_empty(&self) -> bool {
46        self.data.is_empty()
47    }
48
49    /// Returns a slice of the underlying data.
50    #[must_use]
51    pub fn as_slice(&self) -> &[T] {
52        &self.data
53    }
54
55    /// Returns a mutable slice of the underlying data.
56    pub fn as_mut_slice(&mut self) -> &mut [T] {
57        &mut self.data
58    }
59
60    /// Returns a slice from start to end indices.
61    #[must_use]
62    pub fn slice(&self, start: usize, end: usize) -> Self {
63        Self::from_slice(&self.data[start..end])
64    }
65}
66
67impl<T> Index<usize> for Vector<T> {
68    type Output = T;
69
70    fn index(&self, index: usize) -> &Self::Output {
71        &self.data[index]
72    }
73}
74
75impl<T> IndexMut<usize> for Vector<T> {
76    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
77        &mut self.data[index]
78    }
79}
80
81impl Vector<f32> {
82    /// Creates a vector of zeros.
83    #[must_use]
84    pub fn zeros(len: usize) -> Self {
85        Self {
86            data: vec![0.0; len],
87        }
88    }
89
90    /// Creates a vector of ones.
91    #[must_use]
92    pub fn ones(len: usize) -> Self {
93        Self {
94            data: vec![1.0; len],
95        }
96    }
97
98    /// Computes the sum of all elements.
99    #[must_use]
100    pub fn sum(&self) -> f32 {
101        self.data.iter().sum()
102    }
103
104    /// Computes the mean of all elements.
105    #[must_use]
106    pub fn mean(&self) -> f32 {
107        if self.data.is_empty() {
108            return 0.0;
109        }
110        self.sum() / self.data.len() as f32
111    }
112
113    /// Computes the dot product with another vector.
114    ///
115    /// # Panics
116    ///
117    /// Panics if vectors have different lengths.
118    #[must_use]
119    pub fn dot(&self, other: &Self) -> f32 {
120        assert_eq!(
121            self.len(),
122            other.len(),
123            "Vector lengths must match for dot product"
124        );
125        self.data
126            .iter()
127            .zip(other.data.iter())
128            .map(|(a, b)| a * b)
129            .sum()
130    }
131
132    /// Adds a scalar to each element.
133    #[must_use]
134    pub fn add_scalar(&self, scalar: f32) -> Self {
135        Self {
136            data: self.data.iter().map(|x| x + scalar).collect(),
137        }
138    }
139
140    /// Multiplies each element by a scalar.
141    #[must_use]
142    pub fn mul_scalar(&self, scalar: f32) -> Self {
143        Self {
144            data: self.data.iter().map(|x| x * scalar).collect(),
145        }
146    }
147
148    /// Computes the squared L2 norm.
149    #[must_use]
150    pub fn norm_squared(&self) -> f32 {
151        self.dot(self)
152    }
153
154    /// Computes the L2 norm.
155    #[must_use]
156    pub fn norm(&self) -> f32 {
157        self.norm_squared().sqrt()
158    }
159
160    /// Returns the index of the minimum element.
161    #[must_use]
162    pub fn argmin(&self) -> usize {
163        self.data
164            .iter()
165            .enumerate()
166            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
167            .map_or(0, |(i, _)| i)
168    }
169
170    /// Returns the index of the maximum element.
171    #[must_use]
172    pub fn argmax(&self) -> usize {
173        self.data
174            .iter()
175            .enumerate()
176            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
177            .map_or(0, |(i, _)| i)
178    }
179
180    /// Computes variance of all elements.
181    #[must_use]
182    pub fn variance(&self) -> f32 {
183        if self.data.is_empty() {
184            return 0.0;
185        }
186        let mean = self.mean();
187        self.data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / self.data.len() as f32
188    }
189
190    /// Computes standard deviation of all elements.
191    ///
192    /// Standard deviation is the square root of variance.
193    ///
194    /// # Examples
195    ///
196    /// ```
197    /// use aprender::primitives::Vector;
198    ///
199    /// let v = Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]);
200    /// let std = v.std();
201    /// assert!((std - 1.414).abs() < 0.01);
202    /// ```
203    #[must_use]
204    pub fn std(&self) -> f32 {
205        self.variance().sqrt()
206    }
207
208    /// Computes Gini coefficient (inequality measure).
209    ///
210    /// The Gini coefficient measures inequality in a distribution.
211    /// Formula: G = Σ Σ |x_i - x_j| / (2n² * mean)
212    ///
213    /// # Returns
214    /// - 0.0: Perfect equality (all values are the same)
215    /// - 1.0: Maximum inequality (one value has everything)
216    ///
217    /// # Examples
218    ///
219    /// ```
220    /// use aprender::primitives::Vector;
221    ///
222    /// // Perfect equality
223    /// let v = Vector::from_slice(&[5.0, 5.0, 5.0]);
224    /// assert!((v.gini_coefficient() - 0.0).abs() < 0.01);
225    ///
226    /// // Some inequality
227    /// let v = Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]);
228    /// let gini = v.gini_coefficient();
229    /// assert!(gini > 0.0 && gini < 1.0);
230    /// ```
231    #[must_use]
232    pub fn gini_coefficient(&self) -> f32 {
233        if self.data.is_empty() {
234            return 0.0;
235        }
236
237        let mean = self.mean();
238        if mean == 0.0 {
239            return 0.0;
240        }
241
242        let n = self.data.len() as f32;
243        let mut sum_abs_diff = 0.0;
244
245        for i in 0..self.data.len() {
246            for j in 0..self.data.len() {
247                sum_abs_diff += (self.data[i] - self.data[j]).abs();
248            }
249        }
250
251        sum_abs_diff / (2.0 * n * n * mean)
252    }
253}
254
255impl Add for &Vector<f32> {
256    type Output = Vector<f32>;
257
258    fn add(self, other: Self) -> Self::Output {
259        assert_eq!(
260            self.len(),
261            other.len(),
262            "Vector lengths must match for addition"
263        );
264        Vector {
265            data: self
266                .data
267                .iter()
268                .zip(other.data.iter())
269                .map(|(a, b)| a + b)
270                .collect(),
271        }
272    }
273}
274
275impl Sub for &Vector<f32> {
276    type Output = Vector<f32>;
277
278    fn sub(self, other: Self) -> Self::Output {
279        assert_eq!(
280            self.len(),
281            other.len(),
282            "Vector lengths must match for subtraction"
283        );
284        Vector {
285            data: self
286                .data
287                .iter()
288                .zip(other.data.iter())
289                .map(|(a, b)| a - b)
290                .collect(),
291        }
292    }
293}
294
295impl Mul for &Vector<f32> {
296    type Output = Vector<f32>;
297
298    fn mul(self, other: Self) -> Self::Output {
299        assert_eq!(
300            self.len(),
301            other.len(),
302            "Vector lengths must match for multiplication"
303        );
304        Vector {
305            data: self
306                .data
307                .iter()
308                .zip(other.data.iter())
309                .map(|(a, b)| a * b)
310                .collect(),
311        }
312    }
313}
314
#[cfg(test)]
mod tests {
    // Unit tests for `Vector`. Several tests are written to kill specific
    // mutants reported by cargo-mutants; the later ones cover panic paths and
    // edge cases (empty input, zero mean, unsorted Gini input).
    use super::*;

    #[test]
    fn test_from_slice() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        assert_eq!(v.len(), 3);
        assert!((v[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_zeros() {
        let v = Vector::<f32>::zeros(5);
        assert_eq!(v.len(), 5);
        assert!(v.as_slice().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_ones() {
        let v = Vector::<f32>::ones(5);
        assert_eq!(v.len(), 5);
        assert!(v.as_slice().iter().all(|&x| (x - 1.0).abs() < 1e-6));
    }

    #[test]
    fn test_sum() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        assert!((v.sum() - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_mean() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        assert!((v.mean() - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot() {
        let a = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let b = Vector::from_slice(&[4.0_f32, 5.0, 6.0]);
        assert!((a.dot(&b) - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_add_scalar() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let result = v.add_scalar(10.0);
        assert!((result[0] - 11.0).abs() < 1e-6);
        assert!((result[1] - 12.0).abs() < 1e-6);
        assert!((result[2] - 13.0).abs() < 1e-6);
    }

    #[test]
    fn test_mul_scalar() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let result = v.mul_scalar(2.0);
        assert!((result[0] - 2.0).abs() < 1e-6);
        assert!((result[1] - 4.0).abs() < 1e-6);
        assert!((result[2] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_norm() {
        let v = Vector::from_slice(&[3.0_f32, 4.0]);
        assert!((v.norm() - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_argmin() {
        let v = Vector::from_slice(&[3.0_f32, 1.0, 2.0]);
        assert_eq!(v.argmin(), 1);
    }

    #[test]
    fn test_argmax() {
        let v = Vector::from_slice(&[3.0_f32, 1.0, 2.0]);
        assert_eq!(v.argmax(), 0);
    }

    #[test]
    fn test_variance() {
        let v = Vector::from_slice(&[2.0_f32, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]);
        assert!((v.variance() - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_add_vectors() {
        let a = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let b = Vector::from_slice(&[4.0_f32, 5.0, 6.0]);
        let result = &a + &b;
        assert!((result[0] - 5.0).abs() < 1e-6);
        assert!((result[1] - 7.0).abs() < 1e-6);
        assert!((result[2] - 9.0).abs() < 1e-6);
    }

    #[test]
    fn test_sub_vectors() {
        let a = Vector::from_slice(&[4.0_f32, 5.0, 6.0]);
        let b = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let result = &a - &b;
        assert!((result[0] - 3.0).abs() < 1e-6);
        assert!((result[1] - 3.0).abs() < 1e-6);
        assert!((result[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_slice() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0, 4.0, 5.0]);
        let sliced = v.slice(1, 4);
        assert_eq!(sliced.len(), 3);
        assert!((sliced[0] - 2.0).abs() < 1e-6);
        assert!((sliced[2] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_empty_mean() {
        let v = Vector::<f32>::from_vec(vec![]);
        assert!((v.mean() - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_is_empty() {
        let empty = Vector::<f32>::from_vec(vec![]);
        assert!(empty.is_empty());

        let non_empty = Vector::from_slice(&[1.0_f32]);
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn test_argmax_single_element() {
        let v = Vector::from_slice(&[42.0_f32]);
        assert_eq!(v.argmax(), 0);
    }

    #[test]
    fn test_argmax_all_equal() {
        let v = Vector::from_slice(&[5.0_f32, 5.0, 5.0]);
        let idx = v.argmax();
        // When all equal, any valid index is acceptable
        assert!(idx < v.len());
        assert!((v[idx] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_argmin_single_element() {
        let v = Vector::from_slice(&[42.0_f32]);
        assert_eq!(v.argmin(), 0);
    }

    #[test]
    fn test_argmin_all_equal() {
        let v = Vector::from_slice(&[5.0_f32, 5.0, 5.0]);
        let idx = v.argmin();
        // When all equal, any valid index is acceptable
        assert!(idx < v.len());
        assert!((v[idx] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_argmax_not_at_zero() {
        // Max at index 2, not 0 - catches "argmax -> 0" mutation
        let v = Vector::from_slice(&[1.0_f32, 2.0, 10.0]);
        assert_eq!(v.argmax(), 2);
    }

    #[test]
    fn test_mul_vectors() {
        // Element-wise multiplication - catches operator mutations
        let a = Vector::from_slice(&[2.0_f32, 3.0, 4.0]);
        let b = Vector::from_slice(&[5.0_f32, 6.0, 7.0]);
        let result = &a * &b;
        // 2*5=10, 3*6=18, 4*7=28
        assert!((result[0] - 10.0).abs() < 1e-6);
        assert!((result[1] - 18.0).abs() < 1e-6);
        assert!((result[2] - 28.0).abs() < 1e-6);

        // Verify it's not addition: would be 7, 9, 11
        assert!((result[0] - 7.0).abs() > 0.1);
        // Verify it's not division: would be 0.4, 0.5, 0.571...
        assert!((result[1] - 0.5).abs() > 1.0);
    }

    #[test]
    fn test_is_empty_true() {
        // Test that empty vector returns true for is_empty
        // Catches is_empty -> false mutation
        let v: Vector<f32> = Vector::from_slice(&[]);
        assert!(v.is_empty(), "Empty vector should return true for is_empty");
        assert_eq!(v.len(), 0, "Empty vector should have len 0");
    }

    #[test]
    fn test_is_empty_false() {
        // Test that non-empty vector returns false for is_empty
        let v = Vector::from_slice(&[1.0_f32]);
        assert!(
            !v.is_empty(),
            "Non-empty vector should return false for is_empty"
        );
    }

    #[test]
    fn test_argmin_not_at_one() {
        // Min at index 0, not 1 - catches "argmin -> 1" mutation
        let v = Vector::from_slice(&[1.0_f32, 5.0, 3.0]);
        assert_eq!(v.argmin(), 0, "Minimum should be at index 0, not 1");
    }

    #[test]
    fn test_argmin_at_end() {
        // Min at last index - catches various index mutations
        let v = Vector::from_slice(&[5.0_f32, 3.0, 1.0]);
        assert_eq!(v.argmin(), 2, "Minimum should be at index 2");
    }

    #[test]
    fn test_as_mut_slice_modifies() {
        // Test that as_mut_slice actually allows modification
        // Catches as_mut_slice -> empty slice mutation
        let mut v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        {
            let slice = v.as_mut_slice();
            slice[0] = 10.0;
            slice[1] = 20.0;
        }
        assert!(
            (v[0] - 10.0).abs() < 1e-6,
            "First element should be modified to 10.0"
        );
        assert!(
            (v[1] - 20.0).abs() < 1e-6,
            "Second element should be modified to 20.0"
        );
    }

    #[test]
    fn test_as_mut_slice_length() {
        // Verify as_mut_slice returns correct length
        let mut v = Vector::from_slice(&[1.0_f32, 2.0, 3.0, 4.0]);
        let slice = v.as_mut_slice();
        assert_eq!(slice.len(), 4, "Mutable slice should have correct length");
    }

    // EXTREME TDD: Additional mutation-killing tests
    // These tests explicitly target MISSED mutants reported by cargo-mutants

    #[test]
    fn test_argmax_f32_returns_nonzero() {
        // MUTATION TARGET: "replace Vector<f32>::argmax -> usize with 0"
        // This test ensures argmax() does NOT always return 0
        let v: Vector<f32> = Vector::from_slice(&[1.0, 2.0, 999.0, 3.0]);
        assert_eq!(
            v.argmax(),
            2,
            "argmax must return 2 (position of max 999.0), not 0"
        );
        // Double check: if mutation makes argmax() return 0, this fails
        assert_ne!(v.argmax(), 0, "argmax must not always return 0");
    }

    #[test]
    fn test_as_mut_slice_f32_not_empty() {
        // MUTATION TARGET: "replace as_mut_slice -> &mut[T] with Vec::leak(Vec::new())"
        // This test ensures as_mut_slice() does NOT return empty slice
        let mut v: Vector<f32> = Vector::from_slice(&[10.0, 20.0, 30.0]);
        let slice = v.as_mut_slice();

        // If mutation returns empty slice, len check fails
        assert_eq!(
            slice.len(),
            3,
            "as_mut_slice must return slice with 3 elements, not empty"
        );

        // If mutation returns empty slice, modification has no effect
        slice[0] = 100.0;
        assert!(
            (v[0] - 100.0).abs() < 1e-6,
            "as_mut_slice must allow mutation of original data"
        );
    }

    #[test]
    fn test_mul_f32_not_addition() {
        // MUTATION TARGET: "replace * with + in <impl Mul for &Vector<f32>>::mul"
        // This test ensures Mul uses *, not +
        let a: Vector<f32> = Vector::from_slice(&[3.0, 4.0]);
        let b: Vector<f32> = Vector::from_slice(&[5.0, 6.0]);
        let result = &a * &b;

        // Multiplication: [3*5=15, 4*6=24]
        assert!(
            (result[0] - 15.0).abs() < 1e-6,
            "3*5 must equal 15, not 3+5=8"
        );
        assert!(
            (result[1] - 24.0).abs() < 1e-6,
            "4*6 must equal 24, not 4+6=10"
        );

        // If mutation uses +, we get [8, 10] instead
        assert!((result[0] - 8.0).abs() > 1.0, "Must not be addition");
        assert!((result[1] - 10.0).abs() > 1.0, "Must not be addition");
    }

    #[test]
    fn test_mul_f32_not_division() {
        // MUTATION TARGET: "replace * with / in <impl Mul for &Vector<f32>>::mul"
        // This test ensures Mul uses *, not /
        let a: Vector<f32> = Vector::from_slice(&[12.0, 20.0]);
        let b: Vector<f32> = Vector::from_slice(&[3.0, 4.0]);
        let result = &a * &b;

        // Multiplication: [12*3=36, 20*4=80]
        assert!(
            (result[0] - 36.0).abs() < 1e-6,
            "12*3 must equal 36, not 12/3=4"
        );
        assert!(
            (result[1] - 80.0).abs() < 1e-6,
            "20*4 must equal 80, not 20/4=5"
        );

        // If mutation uses /, we get [4, 5] instead
        assert!((result[0] - 4.0).abs() > 1.0, "Must not be division");
        assert!((result[1] - 5.0).abs() > 1.0, "Must not be division");
    }

    #[test]
    fn test_std() {
        // Test standard deviation calculation
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0, 4.0, 5.0]);
        let std = v.std();

        // Expected: sqrt(variance) = sqrt(2.0) ≈ 1.414
        assert!((std - 1.414).abs() < 0.01, "std = {std}");
    }

    #[test]
    fn test_std_uniform() {
        // Uniform values should have std = 0
        let v = Vector::from_slice(&[5.0_f32, 5.0, 5.0, 5.0]);
        assert!((v.std() - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_gini_coefficient_perfect_equality() {
        // All values equal -> Gini = 0
        let v = Vector::from_slice(&[5.0_f32, 5.0, 5.0, 5.0]);
        assert!((v.gini_coefficient() - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_gini_coefficient_inequality() {
        // Some inequality
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0, 4.0, 5.0]);
        let gini = v.gini_coefficient();

        // Should be between 0 and 1
        assert!(gini > 0.0 && gini < 1.0, "Gini = {gini}");

        // For this specific distribution: Gini ≈ 0.267
        assert!((gini - 0.267).abs() < 0.01, "Gini = {gini}");
    }

    #[test]
    fn test_gini_coefficient_maximum_inequality() {
        // Maximum inequality: one person has everything
        let v = Vector::from_slice(&[0.0_f32, 0.0, 0.0, 100.0]);
        let gini = v.gini_coefficient();

        // Maximum for non-negative data is (n - 1) / n, approaching 1.0 as n grows.
        // For n=4: Gini = 0.75
        assert!(gini > 0.7 && gini < 0.8, "Gini = {gini}");
    }

    #[test]
    fn test_gini_coefficient_empty() {
        let v: Vector<f32> = Vector::from_slice(&[]);
        assert_eq!(v.gini_coefficient(), 0.0);
    }

    // Edge cases and panic paths not covered above.

    #[test]
    fn test_norm_squared() {
        let v = Vector::from_slice(&[3.0_f32, 4.0]);
        assert!((v.norm_squared() - 25.0).abs() < 1e-6);
    }

    #[test]
    fn test_empty_variance_and_std() {
        let v = Vector::<f32>::from_vec(vec![]);
        assert_eq!(v.variance(), 0.0);
        assert_eq!(v.std(), 0.0);
    }

    #[test]
    fn test_argmin_argmax_empty() {
        // Empty vectors fall back to index 0.
        let v = Vector::<f32>::from_vec(vec![]);
        assert_eq!(v.argmin(), 0);
        assert_eq!(v.argmax(), 0);
    }

    #[test]
    fn test_gini_coefficient_zero_mean() {
        let v = Vector::from_slice(&[0.0_f32, 0.0, 0.0]);
        assert_eq!(v.gini_coefficient(), 0.0);
    }

    #[test]
    fn test_gini_coefficient_order_independent() {
        // Same multiset as the maximum-inequality case, but unsorted.
        let v = Vector::from_slice(&[100.0_f32, 0.0, 0.0, 0.0]);
        let gini = v.gini_coefficient();
        assert!((gini - 0.75).abs() < 1e-5, "Gini = {gini}");
    }

    #[test]
    fn test_gini_coefficient_matches_pairwise_definition() {
        let data = [3.5_f32, 0.25, 9.0, 1.0, 4.75, 2.0, 7.5];
        let n = data.len() as f32;
        let mean = data.iter().sum::<f32>() / n;
        let pairwise: f32 = data
            .iter()
            .flat_map(|a| data.iter().map(move |b| (a - b).abs()))
            .sum();
        let expected = pairwise / (2.0 * n * n * mean);
        let gini = Vector::from_slice(&data).gini_coefficient();
        assert!(
            (gini - expected).abs() < 1e-5,
            "Gini = {gini}, expected {expected}"
        );
    }

    #[test]
    #[should_panic(expected = "Vector lengths must match for dot product")]
    fn test_dot_length_mismatch_panics() {
        let a = Vector::from_slice(&[1.0_f32, 2.0]);
        let b = Vector::from_slice(&[1.0_f32]);
        let _dot = a.dot(&b);
    }

    #[test]
    #[should_panic(expected = "Vector lengths must match for addition")]
    fn test_add_length_mismatch_panics() {
        let a = Vector::from_slice(&[1.0_f32, 2.0]);
        let b = Vector::from_slice(&[1.0_f32]);
        let _sum = &a + &b;
    }

    #[test]
    #[should_panic(expected = "Vector lengths must match for subtraction")]
    fn test_sub_length_mismatch_panics() {
        let a = Vector::from_slice(&[1.0_f32, 2.0]);
        let b = Vector::from_slice(&[1.0_f32]);
        let _diff = &a - &b;
    }

    #[test]
    #[should_panic(expected = "Vector lengths must match for multiplication")]
    fn test_mul_length_mismatch_panics() {
        let a = Vector::from_slice(&[1.0_f32, 2.0]);
        let b = Vector::from_slice(&[1.0_f32]);
        let _product = &a * &b;
    }

    #[test]
    #[should_panic]
    fn test_slice_out_of_range_panics() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let _sliced = v.slice(2, 5);
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_index_out_of_range_panics() {
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let _value = v[3];
    }
}