use crate::dtype::{DType, FloatElement};
use crate::error::{Result, TorshError};
/// Categories of IEEE 754 values that need special treatment: the infinities,
/// the two kinds of NaN, the signed zeros, and subnormal numbers.
///
/// Every variant is fieldless, so full equality (`Eq`) and hashing are sound;
/// this lets the categories be used directly as `HashSet`/`HashMap` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpecialValue {
    /// Positive infinity (+∞).
    PositiveInfinity,
    /// Negative infinity (−∞).
    NegativeInfinity,
    /// A quiet NaN, which propagates through operations without signalling.
    QuietNaN,
    /// A signaling NaN, which signals an invalid-operation exception when consumed.
    SignalingNaN,
    /// Positive zero (+0.0).
    PositiveZero,
    /// Negative zero (−0.0); compares equal to +0.0 but has its sign bit set.
    NegativeZero,
    /// A non-zero finite value smaller in magnitude than the smallest normal number.
    Subnormal,
}
/// The five IEEE 754 rounding-direction attributes.
///
/// Derives `Hash` so modes can be used as map/set keys, and implements
/// `Default` as [`RoundingMode::ToNearestTiesToEven`], which IEEE 754 designates
/// as the default rounding direction for binary floating-point formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoundingMode {
    /// roundTiesToEven: nearest representable value; ties go to the even significand.
    ToNearestTiesToEven,
    /// roundTiesToAway: nearest representable value; ties go away from zero.
    ToNearestTiesAway,
    /// roundTowardZero: truncate toward zero.
    TowardZero,
    /// roundTowardPositive: round toward +∞.
    TowardPositiveInfinity,
    /// roundTowardNegative: round toward −∞.
    TowardNegativeInfinity,
}

impl Default for RoundingMode {
    /// Returns round-to-nearest, ties-to-even (the IEEE 754 default for binary formats).
    fn default() -> Self {
        RoundingMode::ToNearestTiesToEven
    }
}
/// The five floating-point exceptions defined by IEEE 754.
///
/// Derives `Hash` so a set of raised exception flags can be collected in a
/// `HashSet` or used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Exception {
    /// The operation has no usefully definable result (e.g. ∞ − ∞, 0 × ∞, 0 / 0).
    InvalidOperation,
    /// An exact infinite result was produced from finite operands (e.g. 1 / 0).
    DivisionByZero,
    /// The rounded result exceeds the largest finite representable magnitude.
    Overflow,
    /// The result is tiny (below the normal range) and possibly inexact.
    Underflow,
    /// The rounded result differs from the exact mathematical result.
    Inexact,
}
/// Classification and construction of IEEE 754 special values: the infinities,
/// NaN, the signed zeros, and subnormal numbers.
///
/// This module implements the trait for `f32` and `f64`.
///
/// Note: `f32`/`f64` also have inherent methods named `is_subnormal` and
/// `copysign`, and inherent methods take priority in method-call syntax
/// (`x.copysign(y)` on a concrete `f32` calls the std method). Use the trait
/// path, e.g. `IEEE754Float::copysign(x, y)` or `IEEE754Float::is_subnormal(x)`,
/// to select this trait's version unambiguously.
pub trait IEEE754Float: FloatElement {
    /// Returns `true` only for positive infinity.
    fn is_positive_infinity(self) -> bool;

    /// Returns `true` only for negative infinity.
    fn is_negative_infinity(self) -> bool;

    /// Returns `true` for either infinity.
    ///
    /// Provided in terms of the two signed predicates; the negative check is
    /// skipped once the positive one has matched.
    fn is_infinity(self) -> bool {
        if self.is_positive_infinity() {
            return true;
        }
        self.is_negative_infinity()
    }

    /// Returns `true` for any NaN value.
    // presumably named `is_nan_value` to avoid the inherent `is_nan` on f32/f64
    fn is_nan_value(self) -> bool;

    /// Returns `true` only for +0.0 (not for −0.0).
    fn is_positive_zero(self) -> bool;

    /// Returns `true` only for −0.0 (not for +0.0).
    fn is_negative_zero(self) -> bool;

    /// Returns `true` for non-zero finite values below the smallest normal magnitude.
    fn is_subnormal(self) -> bool;

    /// Returns `true` when the sign bit is set (negative values, −0.0, negative NaNs).
    fn sign_bit(self) -> bool;

    /// Constructs positive infinity.
    fn positive_infinity() -> Self;

    /// Constructs negative infinity.
    fn negative_infinity() -> Self;

    /// Constructs a NaN value.
    fn quiet_nan() -> Self;

    /// Constructs +0.0.
    fn positive_zero() -> Self;

    /// Constructs −0.0.
    fn negative_zero() -> Self;

    /// Returns a value with the magnitude of `self` and the sign of `sign`.
    fn copysign(self, sign: Self) -> Self;
}
/// `IEEE754Float` for IEEE 754 binary32, expressed via std classification
/// helpers and the raw bit pattern.
impl IEEE754Float for f32 {
    fn is_positive_infinity(self) -> bool {
        // +inf is the only value that is infinite with a clear sign bit.
        self.is_infinite() && self.is_sign_positive()
    }

    fn is_negative_infinity(self) -> bool {
        self.is_infinite() && self.is_sign_negative()
    }

    fn is_nan_value(self) -> bool {
        f32::is_nan(self)
    }

    fn is_positive_zero(self) -> bool {
        // +0.0 is the all-zero bit pattern.
        self.to_bits() == 0
    }

    fn is_negative_zero(self) -> bool {
        // -0.0 is the all-zero pattern with only the sign bit set.
        self.to_bits() == 0x8000_0000
    }

    fn is_subnormal(self) -> bool {
        matches!(self.classify(), std::num::FpCategory::Subnormal)
    }

    fn sign_bit(self) -> bool {
        (self.to_bits() >> 31) != 0
    }

    fn positive_infinity() -> Self {
        Self::INFINITY
    }

    fn negative_infinity() -> Self {
        Self::NEG_INFINITY
    }

    fn quiet_nan() -> Self {
        Self::NAN
    }

    fn positive_zero() -> Self {
        Self::from_bits(0)
    }

    fn negative_zero() -> Self {
        Self::from_bits(0x8000_0000)
    }

    fn copysign(self, sign: Self) -> Self {
        // Path syntax resolves to the inherent `f32::copysign`, not this trait method.
        f32::copysign(self, sign)
    }
}
/// `IEEE754Float` for IEEE 754 binary64, expressed via std classification
/// helpers and the raw bit pattern.
impl IEEE754Float for f64 {
    fn is_positive_infinity(self) -> bool {
        // +inf is the only value that is infinite with a clear sign bit.
        self.is_infinite() && self.is_sign_positive()
    }

    fn is_negative_infinity(self) -> bool {
        self.is_infinite() && self.is_sign_negative()
    }

    fn is_nan_value(self) -> bool {
        f64::is_nan(self)
    }

    fn is_positive_zero(self) -> bool {
        // +0.0 is the all-zero bit pattern.
        self.to_bits() == 0
    }

    fn is_negative_zero(self) -> bool {
        // -0.0 is the all-zero pattern with only the sign bit set.
        self.to_bits() == 0x8000_0000_0000_0000
    }

    fn is_subnormal(self) -> bool {
        matches!(self.classify(), std::num::FpCategory::Subnormal)
    }

    fn sign_bit(self) -> bool {
        (self.to_bits() >> 63) != 0
    }

    fn positive_infinity() -> Self {
        Self::INFINITY
    }

    fn negative_infinity() -> Self {
        Self::NEG_INFINITY
    }

    fn quiet_nan() -> Self {
        Self::NAN
    }

    fn positive_zero() -> Self {
        Self::from_bits(0)
    }

    fn negative_zero() -> Self {
        Self::from_bits(0x8000_0000_0000_0000)
    }

    fn copysign(self, sign: Self) -> Self {
        // Path syntax resolves to the inherent `f64::copysign`, not this trait method.
        f64::copysign(self, sign)
    }
}
/// Namespace for generic IEEE 754 compliance checks over `IEEE754Float` types.
///
/// Stateless; all checks are associated functions. Derives the common traits
/// so the public type can be debug-printed, copied, and default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct ComplianceChecker;
impl ComplianceChecker {
pub fn check_special_values<T: IEEE754Float>() -> Result<()> {
let pos_inf = T::positive_infinity();
let neg_inf = T::negative_infinity();
assert!(pos_inf.is_positive_infinity());
assert!(neg_inf.is_negative_infinity());
assert!(pos_inf.is_infinity());
assert!(neg_inf.is_infinity());
let nan = T::quiet_nan();
assert!(nan.is_nan_value());
assert!(nan != nan);
let pos_zero = T::positive_zero();
let neg_zero = T::negative_zero();
assert!(pos_zero.is_positive_zero());
assert!(neg_zero.is_negative_zero());
assert!(pos_zero == neg_zero);
Ok(())
}
pub fn check_special_arithmetic<T: IEEE754Float>() -> Result<()> {
let pos_inf = T::positive_infinity();
let neg_inf = T::negative_infinity();
let nan = T::quiet_nan();
let one = T::from(1.0).expect("numeric conversion should succeed");
let zero = T::from(0.0).expect("numeric conversion should succeed");
assert!((pos_inf + pos_inf).is_positive_infinity());
assert!((neg_inf + neg_inf).is_negative_infinity());
assert!((pos_inf - pos_inf).is_nan_value());
assert!((pos_inf * one).is_positive_infinity());
assert!((neg_inf * one).is_negative_infinity());
assert!((pos_inf * zero).is_nan_value());
let pos_zero = T::positive_zero();
let neg_zero = T::negative_zero();
assert!((one / pos_zero).is_positive_infinity());
assert!((one / neg_zero).is_negative_infinity());
assert!((zero / zero).is_nan_value());
assert!((nan + one).is_nan_value());
assert!((nan * one).is_nan_value());
assert!((nan / one).is_nan_value());
Ok(())
}
pub fn check_comparisons<T: IEEE754Float>() -> Result<()> {
let pos_inf = T::positive_infinity();
let neg_inf = T::negative_infinity();
let nan = T::quiet_nan();
let one = T::from(1.0).expect("numeric conversion should succeed");
let pos_zero = T::positive_zero();
let neg_zero = T::negative_zero();
assert!(!(nan == nan));
assert!(!(nan < one));
assert!(!(nan > one));
assert!(!(nan <= one));
assert!(!(nan >= one));
assert!(nan != nan);
assert!(pos_inf > one);
assert!(neg_inf < one);
assert!(pos_inf > neg_inf);
assert!(pos_zero == neg_zero);
assert!(!(pos_zero < neg_zero));
assert!(!(pos_zero > neg_zero));
Ok(())
}
pub fn check_sign_operations<T: IEEE754Float>() -> Result<()> {
let one = T::from(1.0).expect("numeric conversion should succeed");
let neg_one = T::from(-1.0).expect("numeric conversion should succeed");
let pos_zero = T::positive_zero();
let neg_zero = T::negative_zero();
assert!(!one.sign_bit());
assert!(neg_one.sign_bit());
assert!(!pos_zero.sign_bit());
assert!(neg_zero.sign_bit());
assert_eq!(IEEE754Float::copysign(one, neg_one), neg_one);
assert_eq!(IEEE754Float::copysign(neg_one, one), one);
assert!(IEEE754Float::copysign(pos_zero, neg_one).is_negative_zero());
assert!(IEEE754Float::copysign(neg_zero, one).is_positive_zero());
Ok(())
}
pub fn check_subnormal_handling<T: IEEE754Float>() -> Result<()> {
let min_positive = if std::mem::size_of::<T>() == 4 {
T::from(f32::MIN_POSITIVE).expect("numeric conversion should succeed")
} else {
T::from(f64::MIN_POSITIVE).expect("numeric conversion should succeed")
};
let two = T::from(2.0).expect("numeric conversion should succeed");
let half_min = min_positive / two;
if half_min != T::from(0.0).expect("numeric conversion should succeed") {
assert!(IEEE754Float::is_subnormal(half_min));
}
Ok(())
}
pub fn run_all_checks<T: IEEE754Float>() -> Result<()> {
Self::check_special_values::<T>()?;
Self::check_special_arithmetic::<T>()?;
Self::check_comparisons::<T>()?;
Self::check_sign_operations::<T>()?;
Self::check_subnormal_handling::<T>()?;
Ok(())
}
}
/// Returns `true` if `dtype` is one of the IEEE 754 binary floating-point
/// dtypes: `F16`, `F32` or `F64`.
///
/// Every other dtype (integers, `Bool`, and any float-like dtype not listed
/// here) returns `false`. Marked `#[must_use]` because discarding the result
/// of this pure predicate is always a mistake.
#[must_use]
pub fn is_ieee754_compliant(dtype: DType) -> bool {
    matches!(dtype, DType::F16 | DType::F32 | DType::F64)
}
/// Ensures that `operation` is being applied to an IEEE 754 floating-point dtype.
///
/// `operation` is used only to build the error message.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when `is_ieee754_compliant(dtype)` is false.
pub fn validate_ieee754_operation(dtype: DType, operation: &str) -> Result<()> {
    if is_ieee754_compliant(dtype) {
        return Ok(());
    }
    let message = format!(
        "Operation '{}' requires IEEE 754 compliant floating-point type, got {:?}",
        operation, dtype
    );
    Err(TorshError::InvalidArgument(message))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_f32_special_values() {
        assert!(ComplianceChecker::check_special_values::<f32>().is_ok());
    }

    #[test]
    fn test_f64_special_values() {
        assert!(ComplianceChecker::check_special_values::<f64>().is_ok());
    }

    #[test]
    fn test_f32_special_arithmetic() {
        assert!(ComplianceChecker::check_special_arithmetic::<f32>().is_ok());
    }

    #[test]
    fn test_f64_special_arithmetic() {
        assert!(ComplianceChecker::check_special_arithmetic::<f64>().is_ok());
    }

    #[test]
    fn test_f32_comparisons() {
        assert!(ComplianceChecker::check_comparisons::<f32>().is_ok());
    }

    #[test]
    fn test_f64_comparisons() {
        assert!(ComplianceChecker::check_comparisons::<f64>().is_ok());
    }

    #[test]
    fn test_f32_sign_operations() {
        assert!(ComplianceChecker::check_sign_operations::<f32>().is_ok());
    }

    #[test]
    fn test_f64_sign_operations() {
        assert!(ComplianceChecker::check_sign_operations::<f64>().is_ok());
    }

    #[test]
    fn test_f32_subnormal_handling() {
        assert!(ComplianceChecker::check_subnormal_handling::<f32>().is_ok());
    }

    #[test]
    fn test_f64_subnormal_handling() {
        assert!(ComplianceChecker::check_subnormal_handling::<f64>().is_ok());
    }

    #[test]
    fn test_f32_full_compliance() {
        assert!(ComplianceChecker::run_all_checks::<f32>().is_ok());
    }

    #[test]
    fn test_f64_full_compliance() {
        assert!(ComplianceChecker::run_all_checks::<f64>().is_ok());
    }

    #[test]
    fn test_is_ieee754_compliant() {
        assert!(is_ieee754_compliant(DType::F32));
        assert!(is_ieee754_compliant(DType::F64));
        assert!(is_ieee754_compliant(DType::F16));
        assert!(!is_ieee754_compliant(DType::I32));
        assert!(!is_ieee754_compliant(DType::I64));
        assert!(!is_ieee754_compliant(DType::Bool));
    }

    #[test]
    fn test_validate_ieee754_operation() {
        assert!(validate_ieee754_operation(DType::F32, "sin").is_ok());
        assert!(validate_ieee754_operation(DType::F64, "cos").is_ok());
        assert!(validate_ieee754_operation(DType::F16, "exp").is_ok());
        let result = validate_ieee754_operation(DType::I32, "sin");
        assert!(result.is_err());
        // Pin the error variant and that the message names the rejected operation.
        if let Err(TorshError::InvalidArgument(message)) = result {
            assert!(message.contains("'sin'"));
            assert!(message.contains("IEEE 754"));
        } else {
            panic!("expected TorshError::InvalidArgument for a non-float dtype");
        }
    }

    #[test]
    fn test_nan_properties() {
        let nan = f32::NAN;
        assert!(nan != nan);
        assert!(!(nan == nan));
        assert!(!(nan < 0.0));
        assert!(!(nan > 0.0));
        assert!(!(nan <= 0.0));
        assert!(!(nan >= 0.0));
        assert!(nan.is_nan_value());
    }

    #[test]
    fn test_infinity_arithmetic() {
        let inf = f32::INFINITY;
        let neg_inf = f32::NEG_INFINITY;
        assert!((inf + 1.0) == inf);
        assert!((inf * 2.0) == inf);
        assert!((neg_inf * 2.0) == neg_inf);
        assert!((inf * neg_inf).is_negative_infinity());
        assert!((inf / inf).is_nan());
    }

    #[test]
    fn test_zero_sign() {
        let pos_zero = 0.0_f32;
        let neg_zero = -0.0_f32;
        assert!(pos_zero == neg_zero);
        assert!(pos_zero.is_positive_zero());
        assert!(neg_zero.is_negative_zero());
        assert!((1.0 / pos_zero).is_positive_infinity());
        assert!((1.0 / neg_zero).is_negative_infinity());
    }

    #[test]
    fn test_subnormal_numbers() {
        let min_positive = f32::MIN_POSITIVE;
        let subnormal = min_positive / 2.0;
        if subnormal != 0.0 {
            // Method syntax (`subnormal.is_subnormal()`) resolves to the inherent
            // `f32::is_subnormal`; the trait path is needed to test this module's impl.
            assert!(IEEE754Float::is_subnormal(subnormal));
            assert!(subnormal.is_subnormal());
            assert!(subnormal.is_finite());
            assert!(subnormal > 0.0);
            assert!(subnormal < min_positive);
        }
        assert!(!IEEE754Float::is_subnormal(min_positive));
        assert!(!IEEE754Float::is_subnormal(0.0_f32));
        assert!(!IEEE754Float::is_subnormal(f32::INFINITY));
        assert!(!IEEE754Float::is_subnormal(f32::NAN));
    }

    #[test]
    fn test_f64_trait_subnormal_boundaries() {
        let half_min = f64::MIN_POSITIVE / 2.0;
        if half_min != 0.0 {
            assert!(IEEE754Float::is_subnormal(half_min));
        }
        assert!(!IEEE754Float::is_subnormal(f64::MIN_POSITIVE));
        assert!(!IEEE754Float::is_subnormal(0.0_f64));
        assert!(!IEEE754Float::is_subnormal(f64::NEG_INFINITY));
    }

    #[test]
    fn test_copysign() {
        // Trait-path calls: `x.copysign(y)` would hit the inherent `f32::copysign`
        // and leave this module's impl untested.
        let x = 3.0_f32;
        let y = -5.0_f32;
        assert_eq!(IEEE754Float::copysign(x, y), -3.0);
        assert_eq!(IEEE754Float::copysign(-x, x), 3.0);
        let pos_zero = 0.0_f32;
        let neg_zero = -0.0_f32;
        assert!(IEEE754Float::copysign(pos_zero, neg_zero).is_negative_zero());
        assert!(IEEE754Float::copysign(neg_zero, pos_zero).is_positive_zero());
    }

    #[test]
    fn test_rounding_mode_enum() {
        let mode = RoundingMode::ToNearestTiesToEven;
        assert_eq!(mode, RoundingMode::ToNearestTiesToEven);
        assert_ne!(mode, RoundingMode::TowardZero);
    }

    #[test]
    fn test_exception_enum() {
        let exc = Exception::InvalidOperation;
        assert_eq!(exc, Exception::InvalidOperation);
        assert_ne!(exc, Exception::DivisionByZero);
    }

    #[test]
    fn test_special_value_enum() {
        let val = SpecialValue::PositiveInfinity;
        assert_eq!(val, SpecialValue::PositiveInfinity);
        assert_ne!(val, SpecialValue::NegativeInfinity);
    }
}