use core::fmt::{Debug, Display};
use core::iter::Sum;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
/// Floating-point scalar abstraction for writing numeric code once and using
/// it with both `f32` and `f64`.
///
/// The supertraits provide:
/// - the four arithmetic operators plus their compound-assignment forms,
/// - negation,
/// - `PartialOrd` comparison,
/// - iterator summation (`Sum`),
/// - `Debug`/`Display` formatting,
/// - `Send + Sync + 'static`.
///
/// The methods mirror the like-named std float operations. The `f32`/`f64`
/// implementations in this file delegate to the `libm` crate.
///
/// Note: on a concrete `f32`/`f64` value, method-call syntax such as
/// `x.sqrt()` resolves to the *inherent* std method whenever one with the same
/// name exists. Use `Scalar::sqrt(x)`, or call through a generic `S: Scalar`,
/// to be sure this trait's implementation is used.
pub trait Scalar:
    Copy
    + Clone
    + Debug
    + Display
    + PartialOrd
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Div<Output = Self>
    + Neg<Output = Self>
    + AddAssign
    + SubAssign
    + MulAssign
    + DivAssign
    + Sum
    + Send
    + Sync
    + 'static
{
    /// Additive identity, `0`.
    const ZERO: Self;
    /// Multiplicative identity, `1`.
    const ONE: Self;
    /// The value `2`.
    const TWO: Self;
    /// The value `0.5`.
    const HALF: Self;
    /// Machine epsilon: the gap between `1` and the next larger representable value.
    const EPSILON: Self;
    /// Positive infinity.
    const INFINITY: Self;
    /// Negative infinity.
    const NEG_INFINITY: Self;
    /// Not-a-Number.
    const NAN: Self;
    /// Archimedes' constant, π.
    const PI: Self;
    /// Euler's number, e.
    const E: Self;
    /// √2.
    const SQRT_2: Self;
    /// ln 2.
    const LN_2: Self;

    /// Converts from `f64`, rounding if `Self` is narrower.
    fn from_f64(x: f64) -> Self;
    /// Converts from `f32`.
    fn from_f32(x: f32) -> Self;
    /// Converts from `i32`, rounding if the value is not exactly representable.
    fn from_i32(x: i32) -> Self;
    /// Converts from `usize`, rounding if the value is not exactly representable.
    fn from_usize(n: usize) -> Self;
    /// Converts to `f64`.
    fn to_f64(self) -> f64;
    /// Converts to `f32`, possibly losing precision.
    fn to_f32(self) -> f32;

    /// Absolute value.
    fn abs(self) -> Self;
    /// Square root.
    fn sqrt(self) -> Self;
    /// Cube root.
    fn cbrt(self) -> Self;
    /// Square, `self * self`.
    #[inline]
    fn sq(self) -> Self {
        self * self
    }
    /// Raises `self` to an integer power.
    fn powi(self, n: i32) -> Self;
    /// Raises `self` to a floating-point power.
    fn powf(self, n: Self) -> Self;
    /// Reciprocal, `1 / self`.
    #[inline]
    fn recip(self) -> Self {
        Self::ONE / self
    }
    /// `sqrt(self² + other²)`.
    fn hypot(self, other: Self) -> Self;

    /// Sine (radians).
    fn sin(self) -> Self;
    /// Cosine (radians).
    fn cos(self) -> Self;
    /// Tangent (radians).
    fn tan(self) -> Self;
    /// Arcsine, in radians.
    fn asin(self) -> Self;
    /// Arccosine, in radians.
    fn acos(self) -> Self;
    /// Arctangent, in radians.
    fn atan(self) -> Self;
    /// Four-quadrant arctangent of `self / other` (`self` is the y coordinate).
    fn atan2(self, other: Self) -> Self;
    /// Returns `(sin(self), cos(self))`.
    ///
    /// Implementors may override this with a routine that computes both at once.
    #[inline]
    fn sincos(self) -> (Self, Self) {
        (self.sin(), self.cos())
    }

    /// e^self.
    fn exp(self) -> Self;
    /// 2^self.
    fn exp2(self) -> Self;
    /// e^self − 1, accurate for `self` near zero.
    fn exp_m1(self) -> Self;
    /// Natural logarithm.
    fn ln(self) -> Self;
    /// Base-2 logarithm.
    fn log2(self) -> Self;
    /// Base-10 logarithm.
    fn log10(self) -> Self;
    /// ln(1 + self), accurate for `self` near zero.
    fn ln_1p(self) -> Self;

    /// Hyperbolic sine.
    fn sinh(self) -> Self;
    /// Hyperbolic cosine.
    fn cosh(self) -> Self;
    /// Hyperbolic tangent.
    fn tanh(self) -> Self;
    /// Inverse hyperbolic sine.
    fn asinh(self) -> Self;
    /// Inverse hyperbolic cosine.
    fn acosh(self) -> Self;
    /// Inverse hyperbolic tangent.
    fn atanh(self) -> Self;

    /// The larger of `self` and `other`.
    fn max(self, other: Self) -> Self;
    /// The smaller of `self` and `other`.
    fn min(self, other: Self) -> Self;
    /// Restricts `self` to the interval `[min, max]`.
    fn clamp(self, min: Self, max: Self) -> Self;
    /// The magnitude of `self` with the sign of `sign`.
    fn copysign(self, sign: Self) -> Self;

    /// `true` if neither infinite nor NaN.
    fn is_finite(self) -> bool;
    /// `true` if NaN.
    fn is_nan(self) -> bool;
    /// `true` if positive or negative infinity.
    fn is_infinite(self) -> bool;
    /// `true` if the sign bit is clear.
    fn is_sign_positive(self) -> bool;
    /// `true` if the sign bit is set.
    fn is_sign_negative(self) -> bool;

    /// Largest integer value `<= self`.
    fn floor(self) -> Self;
    /// Smallest integer value `>= self`.
    fn ceil(self) -> Self;
    /// Nearest integer value.
    fn round(self) -> Self;
    /// Integer part, rounding toward zero.
    fn trunc(self) -> Self;
    /// Fractional part.
    fn fract(self) -> Self;

    /// The gamma function Γ(self).
    fn gamma_fn(self) -> Self;
    /// Natural log of the (absolute) gamma function.
    fn ln_gamma(self) -> Self;
    /// The error function erf(self).
    fn erf_fn(self) -> Self;
    /// The complementary error function 1 − erf(self).
    fn erfc_fn(self) -> Self;
    /// Computes `self * a + b`.
    fn mul_add(self, a: Self, b: Self) -> Self;

    /// Mathematical sign function.
    ///
    /// Returns:
    /// - `1` for positive values, including +∞,
    /// - `-1` for negative values, including −∞,
    /// - `0` for either signed zero,
    /// - NaN for NaN.
    ///
    /// Unlike std's inherent `f64::signum`, zero maps to `0` rather than `±1`.
    #[inline]
    fn signum(self) -> Self {
        if self > Self::ZERO {
            Self::ONE
        } else if self < Self::ZERO {
            -Self::ONE
        } else if self.is_nan() {
            // Both comparisons above are false for NaN; propagate it instead of
            // silently turning an invalid value into a valid-looking zero.
            self
        } else {
            Self::ZERO
        }
    }
}
/// `Scalar` for `f64`, implemented on top of the `libm` crate's double-precision
/// routines.
impl Scalar for f64 {
    const ZERO: Self = 0.0;
    const ONE: Self = 1.0;
    const TWO: Self = 2.0;
    const HALF: Self = 0.5;
    // `f64::EPSILON` etc. resolve to the inherent std constants (inherent items
    // take precedence over trait items), so these are not self-referential.
    const EPSILON: Self = f64::EPSILON;
    const INFINITY: Self = f64::INFINITY;
    const NEG_INFINITY: Self = f64::NEG_INFINITY;
    const NAN: Self = f64::NAN;
    const PI: Self = core::f64::consts::PI;
    const E: Self = core::f64::consts::E;
    const SQRT_2: Self = core::f64::consts::SQRT_2;
    const LN_2: Self = core::f64::consts::LN_2;

    #[inline]
    fn from_f64(x: f64) -> Self {
        x
    }
    #[inline]
    fn from_f32(x: f32) -> Self {
        x as f64
    }
    #[inline]
    fn from_i32(x: i32) -> Self {
        x as f64
    }
    #[inline]
    fn from_usize(n: usize) -> Self {
        // Rounds for n > 2^53.
        n as f64
    }
    #[inline]
    fn to_f64(self) -> f64 {
        self
    }
    #[inline]
    fn to_f32(self) -> f32 {
        self as f32
    }
    #[inline]
    fn abs(self) -> Self {
        libm::fabs(self)
    }
    #[inline]
    fn sqrt(self) -> Self {
        libm::sqrt(self)
    }
    #[inline]
    fn cbrt(self) -> Self {
        libm::cbrt(self)
    }
    #[inline]
    fn powi(self, n: i32) -> Self {
        // Every i32 is exactly representable as f64, so this is a faithful
        // integer power.
        libm::pow(self, n as f64)
    }
    #[inline]
    fn powf(self, n: Self) -> Self {
        libm::pow(self, n)
    }
    #[inline]
    fn hypot(self, other: Self) -> Self {
        libm::hypot(self, other)
    }
    #[inline]
    fn sin(self) -> Self {
        libm::sin(self)
    }
    #[inline]
    fn cos(self) -> Self {
        libm::cos(self)
    }
    #[inline]
    fn tan(self) -> Self {
        libm::tan(self)
    }
    #[inline]
    fn asin(self) -> Self {
        libm::asin(self)
    }
    #[inline]
    fn acos(self) -> Self {
        libm::acos(self)
    }
    #[inline]
    fn atan(self) -> Self {
        libm::atan(self)
    }
    #[inline]
    fn atan2(self, other: Self) -> Self {
        libm::atan2(self, other)
    }
    #[inline]
    fn sincos(self) -> (Self, Self) {
        libm::sincos(self)
    }
    #[inline]
    fn exp(self) -> Self {
        libm::exp(self)
    }
    #[inline]
    fn exp2(self) -> Self {
        libm::exp2(self)
    }
    #[inline]
    fn exp_m1(self) -> Self {
        libm::expm1(self)
    }
    #[inline]
    fn ln(self) -> Self {
        libm::log(self)
    }
    #[inline]
    fn log2(self) -> Self {
        libm::log2(self)
    }
    #[inline]
    fn log10(self) -> Self {
        libm::log10(self)
    }
    #[inline]
    fn ln_1p(self) -> Self {
        libm::log1p(self)
    }
    #[inline]
    fn sinh(self) -> Self {
        libm::sinh(self)
    }
    #[inline]
    fn cosh(self) -> Self {
        libm::cosh(self)
    }
    #[inline]
    fn tanh(self) -> Self {
        libm::tanh(self)
    }
    #[inline]
    fn asinh(self) -> Self {
        libm::asinh(self)
    }
    #[inline]
    fn acosh(self) -> Self {
        libm::acosh(self)
    }
    #[inline]
    fn atanh(self) -> Self {
        libm::atanh(self)
    }
    #[inline]
    fn max(self, other: Self) -> Self {
        libm::fmax(self, other)
    }
    #[inline]
    fn min(self, other: Self) -> Self {
        libm::fmin(self, other)
    }
    /// Clamps `self` into `[min, max]`.
    ///
    /// A NaN `self` is returned unchanged, as std's `f64::clamp` does, instead of
    /// being replaced by a bound. NaN bounds are ignored, because `fmin`/`fmax`
    /// return their non-NaN operand.
    #[inline]
    fn clamp(self, min: Self, max: Self) -> Self {
        if f64::is_nan(self) {
            // Without this guard, fmin(NaN, max) == max would hide the NaN.
            return self;
        }
        libm::fmax(min, libm::fmin(self, max))
    }
    #[inline]
    fn copysign(self, sign: Self) -> Self {
        libm::copysign(self, sign)
    }
    // The classification methods forward to std's inherent `f64` methods.
    // Spelling them as `f64::…(self)` makes it explicit that they are not
    // recursive calls back into this trait.
    #[inline]
    fn is_finite(self) -> bool {
        f64::is_finite(self)
    }
    #[inline]
    fn is_nan(self) -> bool {
        f64::is_nan(self)
    }
    #[inline]
    fn is_infinite(self) -> bool {
        f64::is_infinite(self)
    }
    #[inline]
    fn is_sign_positive(self) -> bool {
        f64::is_sign_positive(self)
    }
    #[inline]
    fn is_sign_negative(self) -> bool {
        f64::is_sign_negative(self)
    }
    #[inline]
    fn floor(self) -> Self {
        libm::floor(self)
    }
    #[inline]
    fn ceil(self) -> Self {
        libm::ceil(self)
    }
    #[inline]
    fn round(self) -> Self {
        libm::round(self)
    }
    #[inline]
    fn trunc(self) -> Self {
        libm::trunc(self)
    }
    #[inline]
    fn fract(self) -> Self {
        // Same sign as `self`; NaN for ±∞ (∞ − ∞).
        self - libm::trunc(self)
    }
    #[inline]
    fn gamma_fn(self) -> Self {
        libm::tgamma(self)
    }
    #[inline]
    fn ln_gamma(self) -> Self {
        libm::lgamma(self)
    }
    #[inline]
    fn erf_fn(self) -> Self {
        libm::erf(self)
    }
    #[inline]
    fn erfc_fn(self) -> Self {
        libm::erfc(self)
    }
    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self {
        // Fused multiply-add: a single rounding of self * a + b.
        libm::fma(self, a, b)
    }
}
impl Scalar for f32 {
const ZERO: Self = 0.0;
const ONE: Self = 1.0;
const TWO: Self = 2.0;
const HALF: Self = 0.5;
const EPSILON: Self = f32::EPSILON;
const INFINITY: Self = f32::INFINITY;
const NEG_INFINITY: Self = f32::NEG_INFINITY;
const NAN: Self = f32::NAN;
const PI: Self = core::f32::consts::PI;
const E: Self = core::f32::consts::E;
const SQRT_2: Self = core::f32::consts::SQRT_2;
const LN_2: Self = core::f32::consts::LN_2;
#[inline]
fn from_f64(x: f64) -> Self {
x as f32
}
#[inline]
fn from_f32(x: f32) -> Self {
x
}
#[inline]
fn from_i32(x: i32) -> Self {
x as f32
}
#[inline]
fn from_usize(n: usize) -> Self {
n as f32
}
#[inline]
fn to_f64(self) -> f64 {
self as f64
}
#[inline]
fn to_f32(self) -> f32 {
self
}
#[inline]
fn abs(self) -> Self {
libm::fabsf(self)
}
#[inline]
fn sqrt(self) -> Self {
libm::sqrtf(self)
}
#[inline]
fn cbrt(self) -> Self {
libm::cbrtf(self)
}
#[inline]
fn powi(self, n: i32) -> Self {
libm::powf(self, n as f32)
}
#[inline]
fn powf(self, n: Self) -> Self {
libm::powf(self, n)
}
#[inline]
fn hypot(self, other: Self) -> Self {
libm::hypotf(self, other)
}
#[inline]
fn sin(self) -> Self {
libm::sinf(self)
}
#[inline]
fn cos(self) -> Self {
libm::cosf(self)
}
#[inline]
fn tan(self) -> Self {
libm::tanf(self)
}
#[inline]
fn asin(self) -> Self {
libm::asinf(self)
}
#[inline]
fn acos(self) -> Self {
libm::acosf(self)
}
#[inline]
fn atan(self) -> Self {
libm::atanf(self)
}
#[inline]
fn atan2(self, other: Self) -> Self {
libm::atan2f(self, other)
}
#[inline]
fn sincos(self) -> (Self, Self) {
libm::sincosf(self)
}
#[inline]
fn exp(self) -> Self {
libm::expf(self)
}
#[inline]
fn exp2(self) -> Self {
libm::exp2f(self)
}
#[inline]
fn exp_m1(self) -> Self {
libm::expm1f(self)
}
#[inline]
fn ln(self) -> Self {
libm::logf(self)
}
#[inline]
fn log2(self) -> Self {
libm::log2f(self)
}
#[inline]
fn log10(self) -> Self {
libm::log10f(self)
}
#[inline]
fn ln_1p(self) -> Self {
libm::log1pf(self)
}
#[inline]
fn sinh(self) -> Self {
libm::sinhf(self)
}
#[inline]
fn cosh(self) -> Self {
libm::coshf(self)
}
#[inline]
fn tanh(self) -> Self {
libm::tanhf(self)
}
#[inline]
fn asinh(self) -> Self {
libm::asinhf(self)
}
#[inline]
fn acosh(self) -> Self {
libm::acoshf(self)
}
#[inline]
fn atanh(self) -> Self {
libm::atanhf(self)
}
#[inline]
fn max(self, other: Self) -> Self {
libm::fmaxf(self, other)
}
#[inline]
fn min(self, other: Self) -> Self {
libm::fminf(self, other)
}
#[inline]
fn clamp(self, min: Self, max: Self) -> Self {
libm::fmaxf(min, libm::fminf(self, max))
}
#[inline]
fn copysign(self, sign: Self) -> Self {
libm::copysignf(self, sign)
}
#[inline]
fn is_finite(self) -> bool {
self.is_finite()
}
#[inline]
fn is_nan(self) -> bool {
self.is_nan()
}
#[inline]
fn is_infinite(self) -> bool {
self.is_infinite()
}
#[inline]
fn is_sign_positive(self) -> bool {
self.is_sign_positive()
}
#[inline]
fn is_sign_negative(self) -> bool {
self.is_sign_negative()
}
#[inline]
fn floor(self) -> Self {
libm::floorf(self)
}
#[inline]
fn ceil(self) -> Self {
libm::ceilf(self)
}
#[inline]
fn round(self) -> Self {
libm::roundf(self)
}
#[inline]
fn trunc(self) -> Self {
libm::truncf(self)
}
#[inline]
fn fract(self) -> Self {
self - libm::truncf(self)
}
#[inline]
fn gamma_fn(self) -> Self {
libm::tgammaf(self)
}
#[inline]
fn ln_gamma(self) -> Self {
libm::lgammaf(self)
}
#[inline]
fn erf_fn(self) -> Self {
libm::erff(self)
}
#[inline]
fn erfc_fn(self) -> Self {
libm::erfcf(self)
}
#[inline]
fn mul_add(self, a: Self, b: Self) -> Self {
libm::fmaf(self, a, b)
}
}
/// Converts a slice of any [`Scalar`] into a newly allocated `Vec<f64>`.
///
/// Each element is converted with [`Scalar::to_f64`]; the output has the same
/// length and order as `v`.
pub fn to_f64_vec<S: Scalar>(v: &[S]) -> Vec<f64> {
    let mut widened = Vec::with_capacity(v.len());
    widened.extend(v.iter().copied().map(S::to_f64));
    widened
}
/// Converts a slice of `f64` into a newly allocated `Vec` of the target
/// [`Scalar`] type `S`.
///
/// Each element is converted with [`Scalar::from_f64`], which may lose
/// precision when `S` is narrower than `f64`. The output has the same length
/// and order as `v`.
pub fn from_f64_vec<S: Scalar>(v: &[f64]) -> Vec<S> {
    let mut narrowed = Vec::with_capacity(v.len());
    for &value in v {
        narrowed.push(S::from_f64(value));
    }
    narrowed
}
#[cfg(test)]
mod tests {
    use super::*;

    // On a concrete `f64`/`f32`, method syntax like `x.sqrt()` and paths like
    // `f64::EPSILON` resolve to the *inherent* std items, not to `Scalar`.
    // These tests use `Scalar::…(x)` / `<T as Scalar>::…` so that they really
    // exercise the trait implementations in this module.

    #[test]
    // The checked values are compile-time constants; asserting on them is the
    // point of the test, so clippy's assertions_on_constants lint does not apply.
    #[allow(clippy::assertions_on_constants)]
    fn test_constants_f64() {
        assert_eq!(<f64 as Scalar>::ZERO, 0.0);
        assert_eq!(<f64 as Scalar>::ONE, 1.0);
        assert_eq!(<f64 as Scalar>::TWO, 2.0);
        assert_eq!(<f64 as Scalar>::HALF, 0.5);
        assert!(<f64 as Scalar>::EPSILON > 0.0);
        assert!(<f64 as Scalar>::EPSILON < 1e-10);
        assert!(<f64 as Scalar>::NAN.is_nan());
        assert!(<f64 as Scalar>::INFINITY.is_infinite());
    }

    #[test]
    #[allow(clippy::assertions_on_constants)]
    fn test_constants_f32() {
        assert_eq!(<f32 as Scalar>::ZERO, 0.0);
        assert_eq!(<f32 as Scalar>::ONE, 1.0);
        assert!(<f32 as Scalar>::EPSILON > 0.0);
        assert!(<f32 as Scalar>::NAN.is_nan());
        assert!(<f32 as Scalar>::NEG_INFINITY.is_infinite());
    }

    #[test]
    fn test_basic_ops_f64() {
        let x: f64 = 4.0;
        assert!((Scalar::sqrt(x) - 2.0).abs() < 1e-10);
        assert!((Scalar::sq(x) - 16.0).abs() < 1e-10);
        assert!((Scalar::recip(x) - 0.25).abs() < 1e-15);
        assert!((Scalar::powi(2.0_f64, 10) - 1024.0).abs() < 1e-10);
        assert_eq!(Scalar::abs(-3.0_f64), 3.0);
        assert_eq!(Scalar::signum(-7.0_f64), -1.0);
        assert_eq!(Scalar::signum(0.0_f64), 0.0);
    }

    #[test]
    fn test_trig_f64() {
        let x: f64 = <f64 as Scalar>::PI / 4.0;
        let (s, c) = Scalar::sincos(x);
        // sin(π/4) == cos(π/4).
        assert!((s - c).abs() < 1e-10);
        assert!((s - Scalar::sin(x)).abs() < 1e-15);
        assert!((c - Scalar::cos(x)).abs() < 1e-15);
    }

    #[test]
    fn test_special_functions_f64() {
        // Γ(5) = 4! = 24.
        let g: f64 = Scalar::gamma_fn(5.0_f64);
        assert!((g - 24.0).abs() < 1e-10);
        let e: f64 = Scalar::erf_fn(0.0_f64);
        assert!(e.abs() < 1e-10);
    }

    #[test]
    fn test_to_f64_vec_identity() {
        let v = vec![1.0_f64, 2.5, 3.7];
        let converted = to_f64_vec(&v);
        assert_eq!(v, converted);
    }

    #[test]
    fn test_from_f64_vec_identity() {
        let v = vec![1.0, 2.5, 3.7];
        let converted: Vec<f64> = from_f64_vec(&v);
        assert_eq!(v, converted);
    }

    #[test]
    fn test_f32_roundtrip() {
        // f32 -> f64 is an exact widening and f64 -> f32 of such a value is
        // exact, so the round trip must reproduce the input bit-for-bit.
        let orig = vec![1.0_f32, 2.5, 3.7];
        let f64_vec = to_f64_vec(&orig);
        let back: Vec<f32> = from_f64_vec(&f64_vec);
        assert_eq!(orig, back);
    }
}