#![forbid(unsafe_code)]
use core::hint::cold_path;
use crate::LaError;
/// A fixed-length vector of `D` `f64` components.
///
/// Values obtained through the checked constructor [`Vector::try_new`] hold
/// only finite entries; the crate-private `new_unchecked` performs no check.
#[derive(Debug, Clone, Copy, PartialEq)]
#[must_use]
pub struct Vector<const D: usize> {
    /// Component storage, one `f64` per dimension.
    data: [f64; D],
}
impl<const D: usize> Vector<D> {
    /// Builds a vector from `data`, panicking on any non-finite entry.
    ///
    /// Only compiled under `cfg(test)`; non-test code uses
    /// [`Vector::try_new`] instead.
    ///
    /// # Panics
    ///
    /// Panics with `"Vector::new requires finite entries"` whenever
    /// [`Vector::try_new`] rejects `data`.
    #[cfg(test)]
    #[inline]
    pub(crate) const fn new(data: [f64; D]) -> Self {
        match Self::try_new(data) {
            Err(_) => panic!("Vector::new requires finite entries"),
            Ok(checked) => checked,
        }
    }

    /// Builds a vector from `data` after verifying that every entry is finite.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `LaError::non_finite_at` for the lowest
    /// index that holds NaN or an infinity.
    #[inline]
    pub const fn try_new(data: [f64; D]) -> Result<Self, LaError> {
        match Self::first_non_finite_entry_in(&data) {
            None => Ok(Self::new_unchecked(data)),
            Some(bad) => Err(LaError::non_finite_at(bad)),
        }
    }

    /// Wraps `data` as-is.
    ///
    /// Performs no validation: unlike [`Vector::try_new`], non-finite entries
    /// are accepted without complaint.
    #[inline]
    pub(crate) const fn new_unchecked(data: [f64; D]) -> Self {
        Self { data }
    }

    /// Index of the first NaN or infinite entry, or `None` when all are finite.
    ///
    /// Written as a `while` loop because iterator adapters cannot be used in
    /// a `const fn`.
    const fn first_non_finite_entry_in(data: &[f64; D]) -> Option<usize> {
        let mut idx = 0;
        while idx < D {
            if data[idx].is_finite() {
                idx += 1;
            } else {
                return Some(idx);
            }
        }
        None
    }

    /// The vector whose every component is `0.0`.
    #[inline]
    pub const fn zero() -> Self {
        Self::new_unchecked([0.0_f64; D])
    }

    /// Borrows the components as a fixed-size array.
    #[inline]
    #[must_use]
    pub const fn as_array(&self) -> &[f64; D] {
        &self.data
    }

    /// Consumes the vector and hands back its components.
    #[inline]
    #[must_use]
    pub const fn into_array(self) -> [f64; D] {
        self.data
    }

    /// Dot product of `self` and `other`.
    ///
    /// Terms are accumulated from index `0` upward. Each step uses a fused
    /// multiply-add (`f64::mul_add`), so every product-plus-sum is rounded
    /// only once.
    ///
    /// # Errors
    ///
    /// As soon as the running sum becomes non-finite after adding term `i`
    /// (for example because a product overflows), returns the error produced
    /// by `LaError::non_finite_at(i)`. Later terms are not evaluated.
    #[inline]
    pub const fn dot(self, other: Self) -> Result<f64, LaError> {
        let (lhs, rhs) = (&self.data, &other.data);
        let mut sum = 0.0_f64;
        let mut k = 0;
        while k < D {
            sum = lhs[k].mul_add(rhs[k], sum);
            if sum.is_finite() {
                k += 1;
            } else {
                // Overflow is the exceptional case; keep it off the hot path.
                cold_path();
                return Err(LaError::non_finite_at(k));
            }
        }
        Ok(sum)
    }

    /// Squared Euclidean norm, computed as the dot product of `self` with
    /// itself.
    ///
    /// # Errors
    ///
    /// Same as [`Vector::dot`]: fails with the index of the term at which the
    /// running sum of squares became non-finite.
    #[inline]
    pub const fn norm2_sq(self) -> Result<f64, LaError> {
        Self::dot(self, self)
    }
}
impl<const D: usize> Default for Vector<D> {
#[inline]
fn default() -> Self {
Self::zero()
}
}
// Unit tests for `Vector`; the whole suite is stamped out once per dimension
// (2 through 5) by `gen_public_api_vector_tests!` at the bottom of the module.
#[cfg(test)]
mod tests {
use super::*;
use core::hint::black_box;
use approx::assert_abs_diff_eq;
use pastey::paste;
// Generates the `Vector` test suite for one literal dimension `$d`.
// `paste!` splices `$d` into each test function name (e.g. `..._3d`) so every
// instantiation yields distinct tests.
// NOTE: only `//` comments are used inside this macro; a `///` doc comment
// here would become a `#[doc]` attribute on the generated functions.
macro_rules! gen_public_api_vector_tests {
($d:literal) => {
paste! {
// Round-trip: `new` -> `as_array` / `into_array` must return the exact
// entries that went in (epsilon = 0.0 means bit-for-bit equality of values).
#[test]
fn [<public_api_vector_new_as_array_into_array_ $d d>]() {
// Fill from a fixed list of five values; `zip` stops at the shorter
// side, so any slots beyond the fifth would stay 0.0.
let arr = {
let mut arr = [0.0f64; $d];
let values = [1.0f64, 2.0, 3.0, 4.0, 5.0];
for (dst, src) in arr.iter_mut().zip(values.iter()) {
*dst = *src;
}
arr
};
let v = Vector::<$d>::new(arr);
for i in 0..$d {
assert_abs_diff_eq!(v.as_array()[i], arr[i], epsilon = 0.0);
}
let out = v.into_array();
for i in 0..$d {
assert_abs_diff_eq!(out[i], arr[i], epsilon = 0.0);
}
}
// `zero()` and `Default::default()` must both produce all-zero entries.
#[test]
fn [<public_api_vector_zero_as_array_into_array_default_ $d d>]() {
let z = Vector::<$d>::zero();
for &x in z.as_array() {
assert_abs_diff_eq!(x, 0.0, epsilon = 0.0);
}
for x in z.into_array() {
assert_abs_diff_eq!(x, 0.0, epsilon = 0.0);
}
let d = Vector::<$d>::default();
for x in d.into_array() {
assert_abs_diff_eq!(x, 0.0, epsilon = 0.0);
}
}
// `dot` and `norm2_sq` must match a reference computed with the same
// left-to-right `mul_add` accumulation. Inputs, function pointers and
// arguments go through `black_box` so the optimizer cannot
// constant-fold the calls under test.
#[test]
fn [<public_api_vector_dot_and_norm2_sq_ $d d>]() {
let a_arr = {
let mut arr = [0.0f64; $d];
let values = [1.0f64, 2.0, 3.0, 4.0, 5.0];
for (dst, src) in arr.iter_mut().zip(values.iter()) {
*dst = black_box(*src);
}
arr
};
let b_arr = {
let mut arr = [0.0f64; $d];
let values = [-2.0f64, 0.5, 4.0, -1.0, 2.0];
for (dst, src) in arr.iter_mut().zip(values.iter()) {
*dst = black_box(*src);
}
arr
};
// Reference dot product: sum of a[i] * b[i], accumulated with `mul_add`.
let expected_dot = {
let mut acc = 0.0;
let mut i = 0;
while i < $d {
acc = a_arr[i].mul_add(b_arr[i], acc);
i += 1;
}
acc
};
// Reference squared norm: sum of a[i] * a[i], accumulated with `mul_add`.
let expected_norm2_sq = {
let mut acc = 0.0;
let mut i = 0;
while i < $d {
acc = a_arr[i].mul_add(a_arr[i], acc);
i += 1;
}
acc
};
let a = Vector::<$d>::new(black_box(a_arr));
let b = Vector::<$d>::new(black_box(b_arr));
let dot_fn: fn(Vector<$d>, Vector<$d>) -> Result<f64, LaError> =
black_box(Vector::<$d>::dot);
let norm2_sq_fn: fn(Vector<$d>) -> Result<f64, LaError> =
black_box(Vector::<$d>::norm2_sq);
assert_abs_diff_eq!(
dot_fn(black_box(a), black_box(b)).unwrap(),
expected_dot,
epsilon = 1e-14
);
assert_abs_diff_eq!(
norm2_sq_fn(black_box(a)).unwrap(),
expected_norm2_sq,
epsilon = 1e-14
);
}
// A NaN in the last slot must be reported as `NonFinite` with `row: None`
// and `col` equal to that last index.
#[test]
fn [<public_api_vector_try_new_rejects_nonfinite_ $d d>]() {
let mut a_arr = [1.0f64; $d];
a_arr[$d - 1] = f64::NAN;
assert_eq!(
Vector::<$d>::try_new(a_arr),
Err(LaError::NonFinite {
row: None,
col: $d - 1,
})
);
}
// An infinity in the first slot must be reported at index 0.
#[test]
fn [<public_api_vector_try_new_rejects_nonfinite_rhs_fixture_ $d d>]() {
let mut b_arr = [1.0f64; $d];
b_arr[0] = f64::INFINITY;
assert_eq!(
Vector::<$d>::try_new(b_arr),
Err(LaError::NonFinite { row: None, col: 0 })
);
}
// Finite inputs whose first product overflows: `f64::MAX * 2.0` for `dot`
// and `f64::MAX * f64::MAX` for `norm2_sq` both become infinite on term 0,
// so both calls must fail with `col: 0`.
#[test]
fn [<public_api_vector_dot_and_norm2_sq_reject_overflow_ $d d>]() {
let mut a_arr = [1.0f64; $d];
a_arr[0] = f64::MAX;
let a = Vector::<$d>::new(a_arr);
let mut b_arr = [1.0f64; $d];
b_arr[0] = 2.0;
let b = Vector::<$d>::new(b_arr);
assert_eq!(a.dot(b), Err(LaError::NonFinite { row: None, col: 0 }));
assert_eq!(a.norm2_sq(), Err(LaError::NonFinite { row: None, col: 0 }));
}
}
};
}
// Instantiate the suite for dimensions 2, 3, 4 and 5.
gen_public_api_vector_tests!(2);
gen_public_api_vector_tests!(3);
gen_public_api_vector_tests!(4);
gen_public_api_vector_tests!(5);
}