tract-linalg 0.3.2

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
Documentation
use std::{fmt, ops};

/// Half-precision (16-bit) floating point value, wrapping `half::f16`.
///
/// The inner value is public. The lower-case name mirrors Rust's primitive float types,
/// hence the `non_camel_case_types` allowance.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, Default)]
#[derive(PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
pub struct f16(pub half::f16);

/// Generates a two-operand `f16` method named `$f`.
///
/// Both operands are widened to `f32`, the `f32` method of the same name is applied,
/// and the result is rounded back to `f16`.
macro_rules! binary_f16 {
    ($f:ident) => {
        fn $f(self, other: f16) -> f16 {
            let lhs = self.0.to_f32();
            let rhs = other.0.to_f32();
            f16(half::f16::from_f32(lhs.$f(rhs)))
        }
    };
}

/// Generates a one-operand `f16` method named `$f`.
///
/// The value is widened to `f32`, the `f32` method of the same name is applied,
/// and the result is rounded back to `f16`.
macro_rules! unary_as_f32 {
    ($f:ident) => {
        fn $f(self) -> f16 {
            let widened = self.0.to_f32();
            f16(half::f16::from_f32(widened.$f()))
        }
    };
}

/// Generates a method named `$f` returning `$t`.
///
/// The call is forwarded to the same-named method on the wrapped `half::f16`,
/// without widening to `f32`.
macro_rules! unary_f16 {
    ($f:ident, $t:ty) => {
        fn $f(self) -> $t {
            let inner = self.0;
            inner.$f()
        }
    };
}

/// Generates an associated function `$f()` that returns the `half::consts::$c`
/// constant wrapped in `f16`.
macro_rules! const_f16 {
    ($f:ident, $c:ident) => {
        fn $f() -> f16 {
            let value: half::f16 = half::consts::$c;
            f16(value)
        }
    };
}

/// `num_traits::Float` for [`f16`].
///
/// How each group of methods is evaluated:
/// - Most operations widen the operand(s) to `f32`, evaluate there, and round the
///   result back to `f16`.
/// - Classification predicates and `signum` act on the wrapped `half::f16` directly.
/// - The special constants come from `half::consts`.
// `allow(deprecated)` covers the deprecated `abs_sub` and, depending on the `half`
// version, the `half::consts` module. NOTE(review): confirm which of these applies.
#[allow(deprecated)]
impl num_traits::Float for f16 {
    unary_as_f32!(floor);
    unary_as_f32!(ceil);
    unary_as_f32!(round);
    unary_as_f32!(trunc);
    unary_as_f32!(fract);
    unary_as_f32!(abs);
    unary_as_f32!(recip);
    unary_as_f32!(sqrt);
    unary_as_f32!(exp);
    unary_as_f32!(exp2);
    unary_as_f32!(ln);
    unary_as_f32!(log2);
    unary_as_f32!(log10);
    unary_as_f32!(cbrt);
    unary_as_f32!(sin);
    unary_as_f32!(cos);
    unary_as_f32!(tan);
    unary_as_f32!(sinh);
    unary_as_f32!(cosh);
    unary_as_f32!(tanh);
    unary_as_f32!(asin);
    unary_as_f32!(acos);
    unary_as_f32!(atan);
    unary_as_f32!(asinh);
    unary_as_f32!(acosh);
    unary_as_f32!(atanh);
    unary_as_f32!(exp_m1);
    unary_as_f32!(ln_1p);
    // Angle conversions are done in f32. The trait's generic fallback evaluates in the
    // implementing type, and for f16 its intermediate `x * 90` overflows (max 65504)
    // once |x| exceeds roughly 728.
    unary_as_f32!(to_degrees);
    unary_as_f32!(to_radians);
    unary_f16!(classify, ::std::num::FpCategory);
    unary_f16!(is_nan, bool);
    unary_f16!(is_infinite, bool);
    unary_f16!(is_finite, bool);
    unary_f16!(is_normal, bool);
    unary_f16!(is_sign_positive, bool);
    unary_f16!(is_sign_negative, bool);
    binary_f16!(powf);
    binary_f16!(log);
    binary_f16!(max);
    binary_f16!(min);
    binary_f16!(abs_sub);
    binary_f16!(hypot);
    binary_f16!(atan2);
    const_f16!(nan, NAN);
    const_f16!(infinity, INFINITY);
    const_f16!(neg_infinity, NEG_INFINITY);
    const_f16!(neg_zero, NEG_ZERO);
    const_f16!(max_value, MAX);
    const_f16!(min_value, MIN);
    const_f16!(min_positive_value, MIN_POSITIVE);
    // f16 machine epsilon (2^-10). Without this override the trait's fallback casts
    // `f32::EPSILON` (2^-23) into f16. That yields a subnormal about 8000x too small.
    const_f16!(epsilon, EPSILON);

    /// Sign of the value, computed by `half::f16::signum` on the wrapped value.
    fn signum(self) -> f16 {
        f16(self.0.signum())
    }

    /// `self * a + b`, fused in `f32` and then rounded once more to `f16`.
    fn mul_add(self, a: f16, b: f16) -> f16 {
        (self.0).to_f32().mul_add((a.0).to_f32(), (b.0).to_f32()).into()
    }

    /// Integer power, evaluated in `f32` and rounded back to `f16`.
    fn powi(self, i: i32) -> f16 {
        (self.0).to_f32().powi(i).into()
    }

    /// `(sin, cos)`, both evaluated in `f32` and then each rounded to `f16`.
    fn sin_cos(self) -> (f16, f16) {
        let (s, c) = (self.0).to_f32().sin_cos();
        (s.into(), c.into())
    }

    /// Decomposes the value into `(mantissa, exponent, sign)`.
    ///
    /// The decomposition is done on the widened `f32`. Every f16 value is exactly
    /// representable as f32, so `sign * mantissa * 2^exponent` still equals `self`.
    /// The mantissa and exponent therefore follow f32's layout, not f16's.
    fn integer_decode(self) -> (u64, i16, i8) {
        // Fully qualified so the call does not depend on `Float` being imported.
        num_traits::Float::integer_decode((self.0).to_f32())
    }
}

impl num_traits::Num for f16 {
    type FromStrRadixErr = <f32 as num_traits::Num>::FromStrRadixErr;
    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
        f32::from_str_radix(str, radix).map(|it| it.into())
    }
}

impl num_traits::Zero for f16 {
    fn is_zero(&self) -> bool {
        f32::from(self.0).is_zero()
    }
    fn zero() -> f16 {
        0.0f32.into()
    }
}

impl num_traits::One for f16 {
    fn one() -> f16 {
        1.0f32.into()
    }
}

/// Conversions out of `f16` to other primitive numeric types.
///
/// The integer conversions widen to `f32` and reuse `f32`'s range-checked conversions.
/// Those truncate toward zero and return `None` for NaN and out-of-range values.
///
/// The float conversions are overridden explicitly. `ToPrimitive`'s default
/// `to_f32`/`to_f64` are derived from `to_i64`/`to_u64`, which would drop the
/// fractional part (e.g. `1.5` would become `1.0`) and map NaN to `None`.
/// Widening f16 to f32/f64 is always exact.
impl num_traits::ToPrimitive for f16 {
    fn to_i64(&self) -> Option<i64> {
        num_traits::ToPrimitive::to_i64(&self.0.to_f32())
    }

    fn to_u64(&self) -> Option<u64> {
        num_traits::ToPrimitive::to_u64(&self.0.to_f32())
    }

    fn to_f32(&self) -> Option<f32> {
        Some(self.0.to_f32())
    }

    fn to_f64(&self) -> Option<f64> {
        Some(self.0.to_f64())
    }
}

/// `as`-style cast to `usize`.
///
/// The value is widened to `f32` and then cast with Rust's float-to-int `as` semantics.
impl num_traits::AsPrimitive<usize> for f16 {
    fn as_(self) -> usize {
        let widened: f32 = self.0.to_f32();
        widened as usize
    }
}

/// Lossless widening cast from `f16` to `f32`.
impl num_traits::AsPrimitive<f32> for f16 {
    fn as_(self) -> f32 {
        let inner = self.0;
        inner.to_f32()
    }
}

/// Narrowing cast from `f32` to `f16`, using `half::f16::from_f32`'s rounding.
impl num_traits::AsPrimitive<f16> for f32 {
    fn as_(self) -> f16 {
        let narrowed = half::f16::from_f32(self);
        f16(narrowed)
    }
}

/// Lossless widening cast from `f16` to `f64`.
impl num_traits::AsPrimitive<f64> for f16 {
    fn as_(self) -> f64 {
        let inner = self.0;
        inner.to_f64()
    }
}

/// Narrowing cast from `f64` to `f16`, using `half::f16::from_f64`.
impl num_traits::AsPrimitive<f16> for f64 {
    fn as_(self) -> f16 {
        let narrowed = half::f16::from_f64(self);
        f16(narrowed)
    }
}

/// Generic numeric cast into `f16`.
///
/// The source value is converted to `f32` first, then rounded to `f16`.
/// Returns `None` when the source cannot be represented as `f32`.
impl num_traits::NumCast for f16 {
    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
        let single = n.to_f32()?;
        Some(f16(half::f16::from_f32(single)))
    }
}

/// Finite bounds of `f16`, taken from `half::consts`.
impl num_traits::Bounded for f16 {
    /// Largest finite `f16`.
    fn max_value() -> f16 {
        f16(half::consts::MAX)
    }

    /// Most negative finite `f16`.
    fn min_value() -> f16 {
        f16(half::consts::MIN)
    }
}

impl ops::Neg for f16 {
    type Output = f16;
    fn neg(self) -> f16 {
        self.0.to_f32().neg().into()
    }
}

/// Rounds an `f32` to `f16` using `half::f16::from_f32`.
impl From<f32> for f16 {
    fn from(f: f32) -> f16 {
        let narrowed = half::f16::from_f32(f);
        f16(narrowed)
    }
}

/// Formats exactly like the wrapped `half::f16`, honouring the formatter's flags.
impl fmt::Display for f16 {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(&self.0, fmt)
    }
}

/// `as`-style cast from `usize` to `f16`.
///
/// The value goes through `f64` first, then `half::f16::from_f64`.
impl num_traits::AsPrimitive<f16> for usize {
    fn as_(self) -> f16 {
        let wide = self as f64;
        f16(half::f16::from_f64(wide))
    }
}

impl ops::Add<f16> for f16 {
    type Output = f16;
    fn add(self, other: f16) -> f16 {
        (self.0.to_f32() + other.0.to_f32()).into()
    }
}

impl ops::Sub<f16> for f16 {
    type Output = f16;
    fn sub(self, other: f16) -> f16 {
        (self.0.to_f32() - other.0.to_f32()).into()
    }
}

impl ops::Mul<f16> for f16 {
    type Output = f16;
    fn mul(self, other: f16) -> f16 {
        (self.0.to_f32() * other.0.to_f32()).into()
    }
}

impl ops::Div<f16> for f16 {
    type Output = f16;
    fn div(self, other: f16) -> f16 {
        (self.0.to_f32() / other.0.to_f32()).into()
    }
}

impl ops::Rem<f16> for f16 {
    type Output = f16;
    fn rem(self, other: f16) -> f16 {
        (self.0.to_f32() % other.0.to_f32()).into()
    }
}

impl std::iter::Sum for f16 {
    fn sum<I>(iter: I) -> Self
    where
        I: Iterator<Item = f16>,
    {
        iter.fold(0.0f32, |acc, i| acc + i.0.to_f32()).into()
    }
}

impl<'a> std::iter::Sum<&'a f16> for f16 {
    fn sum<I>(iter: I) -> Self
    where
        I: Iterator<Item = &'a f16>,
    {
        iter.fold(0.0f32, |acc, i| acc + i.0.to_f32()).into()
    }
}

impl std::str::FromStr for f16 {
    type Err = std::num::ParseFloatError;
    fn from_str(s: &str) -> Result<f16, Self::Err> {
        s.parse::<f32>().map(|f| f.into())
    }
}