#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use crate::Scalar;
/// Dense, indexable vector of [`Scalar`] values with BLAS-style operations.
///
/// The elements must be viewable as one contiguous slice (see
/// [`Vector::as_slice`]). Binary operations (`copy_from`, `axpy`, `dot`,
/// `max_elementwise`, ...) expect both operands to have the same length;
/// implementations are not required to check this in release builds.
pub trait Vector<S: Scalar>: Clone + Sized {
    /// Creates a vector of `len` zeros.
    fn zeros(len: usize) -> Self;

    /// Creates a vector of `len` copies of `value`.
    fn fill(len: usize, value: S) -> Self;

    /// Creates a vector holding a copy of `data`.
    fn from_slice(data: &[S]) -> Self;

    /// Number of elements.
    fn len(&self) -> usize;

    /// Returns `true` when the vector has no elements.
    #[inline]
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns element `i`. Implementations may panic when `i >= len()`.
    fn get(&self, i: usize) -> S;

    /// Overwrites element `i` with `value`. May panic when `i >= len()`.
    fn set(&mut self, i: usize, value: S);

    /// Returns a mutable reference to element `i`. May panic when `i >= len()`.
    fn get_mut(&mut self, i: usize) -> &mut S;

    /// Borrows the elements as a contiguous slice.
    fn as_slice(&self) -> &[S];

    /// Mutably borrows the elements as a contiguous slice.
    fn as_mut_slice(&mut self) -> &mut [S];

    /// Overwrites `self` with the contents of `other` (lengths must match).
    fn copy_from(&mut self, other: &Self);

    /// In-place update `self <- a*x + self`.
    fn axpy(&mut self, a: S, x: &Self);

    /// In-place update `self <- a*x + b*self`.
    fn axpby(&mut self, a: S, x: &Self, b: S);

    /// Inner product `sum_i self_i * other_i`.
    fn dot(&self, other: &Self) -> S;

    /// In-place scaling `self <- a*self`.
    fn scale(&mut self, a: S);

    /// Euclidean norm, computed as `sqrt(self . self)`.
    ///
    /// No rescaling is applied. The intermediate sum of squares can
    /// therefore overflow or underflow for entries of extreme magnitude.
    #[inline]
    fn norm2(&self) -> S {
        self.dot(self).sqrt()
    }

    /// Max-norm `max_i |self_i|`.
    fn norm_inf(&self) -> S;

    /// One-norm `sum_i |self_i|`.
    fn norm1(&self) -> S;

    /// Weighted root-mean-square norm `sqrt((1/n) * sum_i (self_i / weights_i)^2)`.
    ///
    /// Returns zero for an empty vector instead of evaluating `0/0`.
    ///
    /// # Panics
    /// Panics if `weights` is shorter than `self`. Equality of the two
    /// lengths is only asserted in debug builds.
    fn weighted_rms_norm(&self, weights: &Self) -> S {
        debug_assert_eq!(self.len(), weights.len());
        let n = self.len();
        if n == 0 {
            // An RMS over zero terms would be 0/0 = NaN; report zero instead.
            return S::ZERO;
        }
        let x = self.as_slice();
        // Slicing `weights` to `n` (rather than relying on `zip` alone) keeps
        // the out-of-bounds panic for a short `weights`. Without it, the loop
        // would silently average over fewer terms.
        let w = &weights.as_slice()[..n];
        let sum = x.iter().zip(w).fold(S::ZERO, |acc, (&xi, &wi)| {
            let r = xi / wi;
            acc + r * r
        });
        (sum / S::from_usize(n)).sqrt()
    }

    /// Replaces every element with its absolute value.
    fn abs_inplace(&mut self);

    /// Element-wise maximum: `self_i <- max(self_i, other_i)`.
    fn max_elementwise(&mut self, other: &Self);

    /// Element-wise minimum: `self_i <- min(self_i, other_i)`.
    fn min_elementwise(&mut self, other: &Self);

    /// Sum of all elements.
    fn sum(&self) -> S;

    /// Largest element. The `Vec` implementation yields `S::NEG_INFINITY`
    /// for an empty vector.
    fn max_element(&self) -> S;

    /// Smallest element. The `Vec` implementation yields `S::INFINITY`
    /// for an empty vector.
    fn min_element(&self) -> S;

    /// Replaces every element `v` with `f(v)`.
    fn map_inplace<F: Fn(S) -> S>(&mut self, f: F);
}
/// Reference dense backend: every operation is one sequential pass over the
/// `Vec`'s storage. Length agreement between operands is checked with
/// `debug_assert_eq!` only, so release builds operate over the shorter
/// operand where `zip` is used.
impl<S: Scalar> Vector<S> for Vec<S> {
    #[inline]
    fn zeros(len: usize) -> Self {
        <Self as Vector<S>>::fill(len, S::ZERO)
    }

    #[inline]
    fn fill(len: usize, value: S) -> Self {
        vec![value; len]
    }

    #[inline]
    fn from_slice(data: &[S]) -> Self {
        Vec::from(data)
    }

    #[inline]
    fn len(&self) -> usize {
        <[S]>::len(self)
    }

    #[inline]
    fn get(&self, i: usize) -> S {
        let items: &[S] = self;
        items[i]
    }

    #[inline]
    fn set(&mut self, i: usize, value: S) {
        let items: &mut [S] = self;
        items[i] = value;
    }

    #[inline]
    fn get_mut(&mut self, i: usize) -> &mut S {
        let items: &mut [S] = self;
        &mut items[i]
    }

    #[inline]
    fn as_slice(&self) -> &[S] {
        &self[..]
    }

    #[inline]
    fn as_mut_slice(&mut self) -> &mut [S] {
        &mut self[..]
    }

    #[inline]
    fn copy_from(&mut self, other: &Self) {
        // Panics (inside `copy_from_slice`) when the lengths differ.
        self[..].copy_from_slice(&other[..]);
    }

    fn axpy(&mut self, a: S, x: &Self) {
        debug_assert_eq!(self.len(), x.len());
        self.iter_mut()
            .zip(x.iter().copied())
            .for_each(|(y, xv)| *y += a * xv);
    }

    fn axpby(&mut self, a: S, x: &Self, b: S) {
        debug_assert_eq!(self.len(), x.len());
        for (y, &xv) in self.iter_mut().zip(x) {
            let scaled_x = a * xv;
            *y = scaled_x + b * *y;
        }
    }

    fn dot(&self, other: &Self) -> S {
        debug_assert_eq!(self.len(), other.len());
        let mut acc = S::ZERO;
        for (&p, &q) in self.iter().zip(other) {
            acc = acc + p * q;
        }
        acc
    }

    fn scale(&mut self, a: S) {
        self.iter_mut().for_each(|v| *v *= a);
    }

    fn norm_inf(&self) -> S {
        let mut largest = S::ZERO;
        for v in self.iter() {
            largest = largest.max(v.abs());
        }
        largest
    }

    fn norm1(&self) -> S {
        self.iter()
            .map(|v| v.abs())
            .fold(S::ZERO, |total, magnitude| total + magnitude)
    }

    fn abs_inplace(&mut self) {
        self.iter_mut().for_each(|v| *v = v.abs());
    }

    fn max_elementwise(&mut self, other: &Self) {
        debug_assert_eq!(self.len(), other.len());
        self.iter_mut()
            .zip(other.iter().copied())
            .for_each(|(mine, theirs)| *mine = mine.max(theirs));
    }

    fn min_elementwise(&mut self, other: &Self) {
        debug_assert_eq!(self.len(), other.len());
        self.iter_mut()
            .zip(other.iter().copied())
            .for_each(|(mine, theirs)| *mine = mine.min(theirs));
    }

    fn sum(&self) -> S {
        let mut total = S::ZERO;
        for &v in self.iter() {
            total = total + v;
        }
        total
    }

    fn max_element(&self) -> S {
        // Seeded with -inf, so an empty vector reports NEG_INFINITY.
        self.iter()
            .copied()
            .fold(S::NEG_INFINITY, |hi, v| hi.max(v))
    }

    fn min_element(&self) -> S {
        // Seeded with +inf, so an empty vector reports INFINITY.
        self.iter()
            .copied()
            .fold(S::INFINITY, |lo, v| lo.min(v))
    }

    fn map_inplace<F: Fn(S) -> S>(&mut self, f: F) {
        self.iter_mut().for_each(|slot| *slot = f(*slot));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `actual` is within 1e-10 of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_zeros() {
        let v: Vec<f64> = Vector::zeros(5);
        assert_eq!(v.len(), 5);
        for x in &v {
            assert_eq!(*x, 0.0);
        }
    }

    #[test]
    fn test_fill() {
        let v: Vec<f64> = Vector::fill(3, 2.5);
        assert_eq!(v, vec![2.5, 2.5, 2.5]);
    }

    #[test]
    fn test_from_slice_and_element_access() {
        let data = [1.0_f64, -2.0, 3.0];
        let mut v: Vec<f64> = Vector::from_slice(&data[..]);
        assert_eq!(v, vec![1.0, -2.0, 3.0]);
        assert_eq!(Vector::len(&v), 3);
        assert!(!Vector::is_empty(&v));
        assert_eq!(Vector::get(&v, 1), -2.0);

        v.set(1, 5.0);
        *Vector::get_mut(&mut v, 2) += 1.0;
        Vector::as_mut_slice(&mut v)[0] = 9.0;
        assert_eq!(v, vec![9.0, 5.0, 4.0]);
        assert_eq!(Vector::as_slice(&v), &[9.0, 5.0, 4.0][..]);
    }

    #[test]
    fn test_copy_from() {
        let src: Vec<f64> = vec![7.0, 8.0];
        let mut dst: Vec<f64> = Vector::zeros(2);
        dst.copy_from(&src);
        assert_eq!(dst, src);
    }

    #[test]
    fn test_axpy() {
        let x: Vec<f64> = vec![1.0, 2.0, 3.0];
        let mut y: Vec<f64> = vec![4.0, 5.0, 6.0];
        y.axpy(2.0, &x);
        assert_eq!(y, vec![6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_axpby() {
        let x: Vec<f64> = vec![1.0, 2.0, 3.0];
        let mut y: Vec<f64> = vec![4.0, 5.0, 6.0];
        y.axpby(2.0, &x, 0.5);
        assert_close(y[0], 4.0);
        assert_close(y[1], 6.5);
        assert_close(y[2], 9.0);
    }

    #[test]
    fn test_dot() {
        let x: Vec<f64> = vec![1.0, 2.0, 3.0];
        let y: Vec<f64> = vec![4.0, 5.0, 6.0];
        assert_close(x.dot(&y), 32.0);
    }

    #[test]
    fn test_norm2() {
        let v: Vec<f64> = vec![3.0, 4.0];
        assert_close(v.norm2(), 5.0);
    }

    #[test]
    fn test_norm_inf() {
        let v: Vec<f64> = vec![-5.0, 3.0, -1.0];
        assert_close(v.norm_inf(), 5.0);
    }

    #[test]
    fn test_norm1() {
        let v: Vec<f64> = vec![-1.0, 2.0, -3.0];
        assert_close(v.norm1(), 6.0);
    }

    #[test]
    fn test_scale() {
        let mut v: Vec<f64> = vec![1.0, 2.0, 3.0];
        v.scale(2.0);
        assert_eq!(v, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_sum() {
        let v: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        assert_close(v.sum(), 10.0);
    }

    #[test]
    fn test_max_min_element() {
        let v: Vec<f64> = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
        assert_close(v.max_element(), 9.0);
        assert_close(v.min_element(), 1.0);
    }

    #[test]
    fn test_abs_inplace() {
        let mut v: Vec<f64> = vec![-1.5, 0.0, 2.0];
        v.abs_inplace();
        assert_eq!(v, vec![1.5, 0.0, 2.0]);
    }

    #[test]
    fn test_max_min_elementwise() {
        let other: Vec<f64> = vec![2.0, 2.0, 2.0];

        let mut hi: Vec<f64> = vec![1.0, 3.0, 2.0];
        hi.max_elementwise(&other);
        assert_eq!(hi, vec![2.0, 3.0, 2.0]);

        let mut lo: Vec<f64> = vec![1.0, 3.0, 2.0];
        lo.min_elementwise(&other);
        assert_eq!(lo, vec![1.0, 2.0, 2.0]);
    }

    #[test]
    fn test_empty_reductions() {
        // Reductions seeded with zero must return zero on an empty vector.
        let v: Vec<f64> = Vector::zeros(0);
        assert!(Vector::is_empty(&v));
        assert_eq!(v.norm_inf(), 0.0);
        assert_eq!(v.norm1(), 0.0);
        assert_eq!(v.sum(), 0.0);
        assert_eq!(v.dot(&v), 0.0);
        assert_eq!(v.norm2(), 0.0);
    }

    #[test]
    fn test_weighted_rms_norm() {
        let y: Vec<f64> = vec![2.0, 4.0];
        let w: Vec<f64> = vec![1.0, 2.0];
        assert_close(y.weighted_rms_norm(&w), 2.0);
    }

    #[test]
    fn test_map_inplace() {
        let mut v: Vec<f64> = vec![1.0, 4.0, 9.0];
        v.map_inplace(|x| x.sqrt());
        assert_close(v[0], 1.0);
        assert_close(v[1], 2.0);
        assert_close(v[2], 3.0);
    }
}