// Function-specifier macro applied to every helper in this header; here it
// expands to `static inline`.
#define EIGEN_STRONG_INLINE static inline
// Qualifier macro placed on every function in this header; here it expands to
// nothing.
#define EIGEN_DEVICE_FUNC
#ifndef EIGEN_HALF_CUDA_H
#define EIGEN_HALF_CUDA_H
namespace Eigen {
namespace half_impl {
// Storage-only 16-bit half-precision value. The struct performs no arithmetic;
// it only carries the raw bit pattern in `x`.
struct __half {
  // Raw 16-bit pattern of the value.
  unsigned short x;

  // Default construction yields the all-zero bit pattern.
  EIGEN_DEVICE_FUNC __half() : x(0) {}

  // Wraps an existing raw bit pattern. Explicit, so plain integers never
  // convert to __half implicitly.
  explicit EIGEN_DEVICE_FUNC __half(unsigned short raw) : x(raw) {}
};
// Forward declarations of the conversion helpers defined further down in this
// header.

// Returns a __half whose stored bits equal `x` (no numeric conversion).
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half raw_uint16_to_half(unsigned short x);
// Converts a float to a __half; the software path rounds to nearest, ties to
// even.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half float_to_half_rtne(float ff);
// Converts a __half back to a float.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half h);
// Reinterprets a raw 16-bit pattern as a __half. The bits are stored as-is;
// no numeric conversion takes place.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half raw_uint16_to_half(unsigned short x) {
  return __half(x);
}
// Overlays an unsigned int and a float so the conversion routines can edit a
// float's bit pattern with integer arithmetic.
// NOTE: `u` must stay the first member. The constants in the conversion
// functions are brace-initialized (e.g. `{ 255 << 23 }`), and brace
// initialization of a union sets its first member.
// NOTE(review): assumes unsigned int and float are both 32 bits wide.
union FP32 {
unsigned int u;  // integer view of the bits
float f;  // floating-point view of the same bits
};
// Converts a single-precision float to a half-precision bit pattern.
//
// Three implementations are selected at compile time:
//   * CUDA device code (arch >= 300) with EIGEN_HAS_CUDA_FP16: __float2half.
//   * EIGEN_HAS_FP16_C: the _cvtss_sh intrinsic with rounding immediate 0.
//   * Otherwise: a software path that rounds to nearest, ties to even.
//
// In the software path, magnitudes of 2^16 or more become infinity (0x7c00),
// NaN inputs become the quiet NaN 0x7e00, and the sign bit is carried over.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half float_to_half_rtne(float ff) {
#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  return __float2half(ff);
#elif defined(EIGEN_HAS_FP16_C)
  __half result;
  result.x = _cvtss_sh(ff, 0);
  return result;
#else
  // Bit patterns of +infinity, of 2^16 (first magnitude too large for half),
  // and of 0.5f. At 0.5f one float ulp equals the smallest half subnormal
  // (2^-24), which is what makes the "magic add" below work.
  const FP32 f32infty = { 255 << 23 };
  const FP32 f16max = { (127 + 16) << 23 };
  const FP32 denorm_magic = { ((127 - 15) + (23 - 10) + 1) << 23 };
  const unsigned int sign_mask = 0x80000000u;

  FP32 bits;
  bits.f = ff;

  // Split off the sign so the rest of the routine works on the magnitude only.
  const unsigned int sign = bits.u & sign_mask;
  bits.u ^= sign;

  __half result;
  if (bits.u >= f16max.u) {
    // Too large for half, infinite, or NaN: anything above the infinity
    // pattern is a NaN and maps to a quiet NaN, everything else to infinity.
    result.x = (bits.u > f32infty.u) ? 0x7e00 : 0x7c00;
  } else if (bits.u < (113 << 23)) {
    // Magnitude below 2^-14: the half result is subnormal or zero. Adding
    // 0.5f lets the floating-point add shift and round the mantissa into the
    // low bits (relies on float addition rounding to nearest-even); removing
    // the magic bit pattern then leaves the half bits.
    bits.f += denorm_magic.f;
    result.x = static_cast<unsigned short>(bits.u - denorm_magic.u);
  } else {
    // Normal half result. Rebias the exponent from 127 to 15 and round: adding
    // 0xfff plus the lowest kept mantissa bit carries into bit 13 exactly when
    // round-to-nearest-even must round up.
    const unsigned int mant_odd = (bits.u >> 13) & 1;
    bits.u += (static_cast<unsigned int>(15 - 127) << 23) + 0xfff + mant_odd;
    result.x = static_cast<unsigned short>(bits.u >> 13);
  }

  // Move the float sign (bit 31) down to the half sign position (bit 15).
  result.x |= static_cast<unsigned short>(sign >> 16);
  return result;
#endif
}
// Converts a half-precision bit pattern to a single-precision float.
//
// Three implementations are selected at compile time:
//   * CUDA device code (arch >= 300) with EIGEN_HAS_CUDA_FP16: __half2float.
//   * EIGEN_HAS_FP16_C: the _cvtsh_ss intrinsic.
//   * Otherwise: a software path built on integer manipulation of the bits.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half h) {
#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  return __half2float(h);
#elif defined(EIGEN_HAS_FP16_C)
  return _cvtsh_ss(h.x);
#else
  // Bit pattern of 2^-14 (smallest normal half), used to renormalize
  // subnormal inputs.
  const FP32 magic = { 113 << 23 };
  // Half exponent mask (0x7c00) after being moved into float position.
  const unsigned int shifted_exp = 0x7c00 << 13;

  // Half sign (bit 15) moved up to the float sign position (bit 31).
  const unsigned int sign_bit = (h.x & 0x8000) << 16;

  // Place the half exponent and mantissa where the float expects them.
  FP32 result;
  result.u = (h.x & 0x7fff) << 13;
  const unsigned int exponent_bits = result.u & shifted_exp;

  // Rebias the exponent from 15 (half) to 127 (float).
  result.u += (127 - 15) << 23;

  const bool is_inf_or_nan = (exponent_bits == shifted_exp);
  const bool is_zero_or_subnormal = (exponent_bits == 0);
  if (is_inf_or_nan) {
    // Push the exponent field up to all ones so the value stays Inf/NaN.
    result.u += (128 - 16) << 23;
  } else if (is_zero_or_subnormal) {
    // Treat the mantissa as if it had an implicit leading one at 2^-14, then
    // subtract that 2^-14 again; the float subtraction renormalizes.
    result.u += 1 << 23;
    result.f -= magic.f;
  }

  result.u |= sign_bit;
  return result.f;
#endif
}
}
}
#endif