use serde::{Deserialize, Serialize};
use std::ops::{Add, Index, IndexMut, Mul, Sub};
/// A dense vector whose elements are stored in a `Vec<T>`.
///
/// Numeric operations (statistics, dot product, norms, and element-wise
/// `+`, `-`, `*` on references) are implemented for `Vector<f32>`.
/// Serialization goes through serde's derived `Serialize`/`Deserialize`
/// impls, i.e. as a struct with a single `data` field.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Vector<T> {
    /// Element storage; `vector[i]` reads `data[i]`.
    data: Vec<T>,
}
impl<T: Copy> Vector<T> {
#[must_use]
pub fn from_slice(data: &[T]) -> Self {
Self {
data: data.to_vec(),
}
}
#[must_use]
pub fn from_vec(data: Vec<T>) -> Self {
Self { data }
}
#[must_use]
pub fn len(&self) -> usize {
self.data.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
#[must_use]
pub fn as_slice(&self) -> &[T] {
&self.data
}
pub fn as_mut_slice(&mut self) -> &mut [T] {
&mut self.data
}
#[must_use]
pub fn slice(&self, start: usize, end: usize) -> Self {
Self::from_slice(&self.data[start..end])
}
}
impl<T> Index<usize> for Vector<T> {
    type Output = T;

    /// Borrows the element at `position`.
    ///
    /// # Panics
    ///
    /// Panics when `position` is not less than the number of elements, with
    /// the standard slice "index out of bounds" message.
    fn index(&self, position: usize) -> &T {
        let elements: &[T] = &self.data;
        &elements[position]
    }
}
impl<T> IndexMut<usize> for Vector<T> {
    /// Mutably borrows the element at `position`.
    ///
    /// # Panics
    ///
    /// Panics when `position` is not less than the number of elements, with
    /// the standard slice "index out of bounds" message.
    fn index_mut(&mut self, position: usize) -> &mut T {
        let elements: &mut [T] = &mut self.data;
        &mut elements[position]
    }
}
impl Vector<f32> {
    /// Creates a vector of `len` elements, all `0.0`.
    #[must_use]
    pub fn zeros(len: usize) -> Self {
        Self {
            data: vec![0.0; len],
        }
    }

    /// Creates a vector of `len` elements, all `1.0`.
    #[must_use]
    pub fn ones(len: usize) -> Self {
        Self {
            data: vec![1.0; len],
        }
    }

    /// Returns the sum of all elements (`0.0` for an empty vector).
    #[must_use]
    pub fn sum(&self) -> f32 {
        self.data.iter().sum()
    }

    /// Returns the arithmetic mean, or `0.0` for an empty vector.
    #[must_use]
    pub fn mean(&self) -> f32 {
        if self.data.is_empty() {
            return 0.0;
        }
        self.sum() / self.data.len() as f32
    }

    /// Returns the dot product `Σ self[i] * other[i]`.
    ///
    /// # Panics
    ///
    /// Panics if the two vectors have different lengths.
    #[must_use]
    pub fn dot(&self, other: &Self) -> f32 {
        assert_eq!(
            self.data.len(),
            other.data.len(),
            "Vector lengths must match for dot product"
        );
        self.data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| a * b)
            .sum()
    }

    /// Returns a new vector with `scalar` added to every element.
    #[must_use]
    pub fn add_scalar(&self, scalar: f32) -> Self {
        Self {
            data: self.data.iter().map(|x| x + scalar).collect(),
        }
    }

    /// Returns a new vector with every element multiplied by `scalar`.
    #[must_use]
    pub fn mul_scalar(&self, scalar: f32) -> Self {
        Self {
            data: self.data.iter().map(|x| x * scalar).collect(),
        }
    }

    /// Returns the squared Euclidean norm, i.e. the dot product with itself.
    #[must_use]
    pub fn norm_squared(&self) -> f32 {
        self.dot(self)
    }

    /// Returns the Euclidean (L2) norm.
    #[must_use]
    pub fn norm(&self) -> f32 {
        self.norm_squared().sqrt()
    }

    /// Returns the index of the smallest element.
    ///
    /// Ties resolve to the first occurrence. NaN entries are ignored; an empty
    /// or all-NaN vector yields `0`.
    #[must_use]
    pub fn argmin(&self) -> usize {
        self.first_extremum_index(|candidate, best| candidate < best)
    }

    /// Returns the index of the largest element.
    ///
    /// Ties resolve to the first occurrence, the same rule as [`Self::argmin`].
    /// NaN entries are ignored; an empty or all-NaN vector yields `0`.
    #[must_use]
    pub fn argmax(&self) -> usize {
        self.first_extremum_index(|candidate, best| candidate > best)
    }

    /// Returns the population variance `Σ (x - mean)² / n` (`0.0` if empty).
    #[must_use]
    pub fn variance(&self) -> f32 {
        if self.data.is_empty() {
            return 0.0;
        }
        let mean = self.mean();
        self.data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / self.data.len() as f32
    }

    /// Returns the population standard deviation (square root of the variance).
    #[must_use]
    pub fn std(&self) -> f32 {
        self.variance().sqrt()
    }

    /// Returns the Gini coefficient `Σ_i Σ_j |x_i - x_j| / (2 n² mean)`.
    ///
    /// Returns `0.0` for an empty vector or when the mean is exactly `0.0`;
    /// NaN input produces NaN. Runs in O(n log n): after sorting ascending, the
    /// pairwise sum equals `2 Σ_k (2k - n + 1) x_(k)`, so no double loop is
    /// needed. The weighted sum is accumulated in `f64` to limit rounding.
    #[must_use]
    pub fn gini_coefficient(&self) -> f32 {
        if self.data.is_empty() {
            return 0.0;
        }
        let mean = self.mean();
        if mean == 0.0 {
            return 0.0;
        }
        let mut sorted = self.data.clone();
        // `total_cmp` is a total order, so NaN cannot make the sort misbehave.
        sorted.sort_unstable_by(f32::total_cmp);
        let n = sorted.len() as f64;
        let weighted: f64 = sorted
            .iter()
            .enumerate()
            .map(|(k, &x)| (2.0 * k as f64 - n + 1.0) * f64::from(x))
            .sum();
        (weighted / (n * n * f64::from(mean))) as f32
    }

    /// Scans once and returns the index of the first non-NaN element that no
    /// later element `beats` (strictly), or `0` if there is no such element.
    fn first_extremum_index(&self, beats: impl Fn(f32, f32) -> bool) -> usize {
        let mut best: Option<(usize, f32)> = None;
        for (i, &x) in self.data.iter().enumerate() {
            if x.is_nan() {
                continue;
            }
            let improves = match best {
                None => true,
                Some((_, current)) => beats(x, current),
            };
            if improves {
                best = Some((i, x));
            }
        }
        best.map_or(0, |(i, _)| i)
    }
}
impl Add for &Vector<f32> {
    type Output = Vector<f32>;

    /// Returns the element-wise sum `self[i] + other[i]`.
    ///
    /// # Panics
    ///
    /// Panics if the operands have different lengths.
    fn add(self, other: Self) -> Self::Output {
        assert_eq!(
            self.data.len(),
            other.data.len(),
            "Vector lengths must match for addition"
        );
        let mut data = Vec::with_capacity(self.data.len());
        for (lhs, rhs) in self.data.iter().zip(&other.data) {
            data.push(lhs + rhs);
        }
        Vector { data }
    }
}
impl Sub for &Vector<f32> {
    type Output = Vector<f32>;

    /// Returns the element-wise difference `self[i] - other[i]`.
    ///
    /// # Panics
    ///
    /// Panics if the operands have different lengths.
    fn sub(self, other: Self) -> Self::Output {
        let (minuend, subtrahend) = (&self.data, &other.data);
        assert_eq!(
            minuend.len(),
            subtrahend.len(),
            "Vector lengths must match for subtraction"
        );
        let data: Vec<f32> = minuend
            .iter()
            .copied()
            .zip(subtrahend.iter().copied())
            .map(|(x, y)| x - y)
            .collect();
        Vector { data }
    }
}
impl Mul for &Vector<f32> {
    type Output = Vector<f32>;

    /// Returns the element-wise (Hadamard) product `self[i] * other[i]`.
    ///
    /// # Panics
    ///
    /// Panics if the operands have different lengths.
    fn mul(self, other: Self) -> Self::Output {
        let expected = self.data.len();
        assert_eq!(
            expected,
            other.data.len(),
            "Vector lengths must match for multiplication"
        );
        let mut product = Vec::with_capacity(expected);
        product.extend(
            self.data
                .iter()
                .zip(&other.data)
                .map(|(&x, &y)| x * y),
        );
        Vector { data: product }
    }
}
// Unit tests for this module, compiled only when the `test` cfg is set
// (e.g. under `cargo test`); the module body is read from the `#[path]` file.
#[cfg(test)]
#[path = "vector_tests.rs"]
mod tests;
// A second test-only module, likewise loaded from its `#[path]` file.
#[cfg(test)]
#[path = "tests_vector_contract.rs"]
mod tests_vector_contract;