use core::f32::consts::PI;
use crate::dtype::DType;
use crate::scalar::{Scalar, Float};
#[cfg(feature = "half")]
use half::{bf16, f16};
#[cfg(feature = "complex")]
use num_complex::Complex;
/// `Scalar` implementation for 32-bit IEEE 754 floats.
///
/// Math operations delegate to the inherent `f32` methods. Inherent methods take
/// priority over the `Scalar` trait methods of the same name, so e.g.
/// `f32::sin(self)` inside `fn sin` is not a recursive call.
/// Predicate-style operations (`cmplt`, `cmpgt`, `nonzero`, `not`, `or`) return
/// `1.0` for true and `0.0` for false.
impl Scalar for f32 {
    #[cfg(feature = "half")]
    fn from_bf16(t: bf16) -> Self {
        t.into()
    }

    #[cfg(feature = "half")]
    fn from_f16(t: f16) -> Self {
        t.into()
    }

    fn from_f32(t: f32) -> Self {
        t
    }

    /// Rounds to the nearest `f32`; values outside the `f32` range become ±infinity.
    fn from_f64(t: f64) -> Self {
        t as f32
    }

    /// Keeps the real part only; the imaginary part is discarded.
    #[cfg(feature = "complex")]
    fn from_cf32(t: Complex<f32>) -> Self {
        t.re
    }

    /// Keeps the real part only (rounded to `f32`); the imaginary part is discarded.
    #[cfg(feature = "complex")]
    fn from_cf64(t: Complex<f64>) -> Self {
        t.re as f32
    }

    fn from_u8(t: u8) -> Self {
        t as f32
    }

    fn from_i8(t: i8) -> Self {
        t as f32
    }

    fn from_i16(t: i16) -> Self {
        t as f32
    }

    /// Exact for magnitudes up to 2^24; larger values may be rounded to the nearest `f32`.
    fn from_i32(t: i32) -> Self {
        t as f32
    }

    /// Exact for magnitudes up to 2^24; larger values may be rounded to the nearest `f32`.
    fn from_i64(t: i64) -> Self {
        t as f32
    }

    /// `true` -> `1.0`, `false` -> `0.0`.
    fn from_bool(t: bool) -> Self {
        t as i32 as f32
    }

    /// Decodes the first four bytes of `bytes` as a little-endian `f32`.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` has fewer than four elements; any extra bytes are ignored.
    fn from_le_bytes(bytes: &[u8]) -> Self {
        f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
    }

    fn dtype() -> DType {
        DType::F32
    }

    fn zero() -> Self {
        0.
    }

    fn one() -> Self {
        1.
    }

    /// Size of one `f32` in bytes.
    fn byte_size() -> usize {
        4
    }

    fn abs(self) -> Self {
        self.abs()
    }

    /// `1 / self`; yields ±infinity for ±0.
    fn reciprocal(self) -> Self {
        1.0 / self
    }

    fn neg(self) -> Self {
        -self
    }

    /// `max(self, 0)`; a NaN input maps to `0.0` because `f32::max` returns the
    /// non-NaN operand.
    fn relu(self) -> Self {
        self.max(0.)
    }

    /// Sine of `self` in radians.
    ///
    /// Delegates to `f32::sin`, which is valid over the whole real line (no
    /// range restriction) and has the correct sign.
    fn sin(self) -> Self {
        f32::sin(self)
    }

    /// Largest integral value less than or equal to `self`.
    ///
    /// Delegates to `f32::floor`, which is exact for every input. An `as i32`
    /// round-trip would saturate beyond ±2^31 and map NaN to 0.
    fn floor(self) -> Self {
        f32::floor(self)
    }

    /// Cosine of `self` in radians.
    ///
    /// Delegates to `f32::cos`, mirroring `sin` so the two stay consistent.
    fn cos(self) -> Self {
        f32::cos(self)
    }

    /// Square root.
    ///
    /// Results: NaN for negative inputs, `-0.0` for `-0.0`, `0.0` for `0.0`,
    /// and `+inf` for `+inf`.
    fn sqrt(self) -> Self {
        f32::sqrt(self)
    }

    fn add(self, rhs: Self) -> Self {
        self + rhs
    }

    fn sub(self, rhs: Self) -> Self {
        self - rhs
    }

    fn mul(self, rhs: Self) -> Self {
        self * rhs
    }

    fn div(self, rhs: Self) -> Self {
        self / rhs
    }

    /// `self` raised to the floating-point power `rhs`.
    fn pow(self, rhs: Self) -> Self {
        self.powf(rhs)
    }

    /// `1.0` if `self < rhs`, else `0.0` (also `0.0` when either operand is NaN).
    fn cmplt(self, rhs: Self) -> Self {
        (self < rhs) as i32 as f32
    }

    /// Larger of the two values; if one operand is NaN the other is returned.
    fn max(self, rhs: Self) -> Self {
        f32::max(self, rhs)
    }

    /// Largest finite `f32`.
    fn max_value() -> Self {
        f32::MAX
    }

    /// Most negative finite `f32` (`-f32::MAX`), not the smallest positive value.
    fn min_value() -> Self {
        f32::MIN
    }

    /// Absolute tolerance (`1e-5`) used by `is_equal`.
    ///
    /// Note: this is not the machine epsilon `f32::EPSILON`.
    fn epsilon() -> Self {
        0.00001
    }

    /// Approximate equality.
    ///
    /// Returns `true` when any of these holds:
    /// - the values are exactly equal (this covers matching infinities of either sign);
    /// - they differ by less than `epsilon()` (absolute tolerance);
    /// - they differ by less than 1% of `|self|` (relative tolerance, measured
    ///   against `self` only, so the test is not symmetric).
    ///
    /// NaN never compares equal to anything.
    fn is_equal(self, rhs: Self) -> bool {
        self == rhs
            || (self - rhs).abs() < Self::epsilon()
            || (self - rhs).abs() < self.abs() * 0.01
    }

    fn exp2(self) -> Self {
        self.exp2()
    }

    fn log2(self) -> Self {
        self.log2()
    }

    /// `1 / self` (same as `reciprocal`).
    fn inv(self) -> Self {
        1. / self
    }

    /// Logical NOT: `1.0` for zero (including `-0.0`), `0.0` for anything else (including NaN).
    fn not(self) -> Self {
        if self != 0. {
            0.
        } else {
            1.
        }
    }

    /// `1.0` if `self` is non-zero (NaN counts as non-zero), else `0.0`.
    fn nonzero(self) -> Self {
        (self != 0.) as i32 as f32
    }

    /// `1.0` if `self > rhs`, else `0.0` (also `0.0` when either operand is NaN).
    fn cmpgt(self, rhs: Self) -> Self {
        (self > rhs) as i32 as f32
    }

    /// Logical OR: `1.0` if either operand is non-zero, else `0.0`.
    fn or(self, rhs: Self) -> Self {
        (self != 0. || rhs != 0.) as i32 as f32
    }
}
/// Opts `f32` into the `Float` trait without overriding any of its items.
impl Float for f32 {
    // Nothing to override: whatever items `Float` provides apply unchanged.
}