#pragma once
#include "src/arm_common/simd_macro/marm_neon.h"
namespace megdnn {
namespace arm_common {
// Short aliases for 128-bit NEON vector types.
typedef float32x4_t v4sf;  // four float32 lanes
typedef uint32x4_t v4su;   // four uint32 lanes
typedef int32x4_t v4si;    // four int32 lanes
// Vector math routines operating on four float32 lanes at once. Only the
// prototypes are visible in this header; the per-function notes below are
// inferred from the names and must be confirmed against the definitions.

// Presumably the lane-wise natural logarithm of x -- TODO confirm.
v4sf log_ps_f32(v4sf x);
// Presumably the lane-wise exponential e^x -- TODO confirm.
v4sf exp_ps_f32(v4sf x);
// Presumably computes sine and cosine of x in one pass and stores them
// through ysin and ycos respectively -- TODO confirm (including whether the
// pointers may be null or alias each other).
void sincos_ps_f32(v4sf x, v4sf* ysin, v4sf* ycos);
// Presumably the lane-wise sine of x -- TODO confirm.
v4sf sin_ps_f32(v4sf x);
// Presumably the lane-wise cosine of x -- TODO confirm.
v4sf cos_ps_f32(v4sf x);
// Presumably the lane-wise tangent of x -- TODO confirm.
v4sf tan_ps_f32(v4sf x);
// Lane-wise division x / y of two four-lane float32 vectors.
//
// When MEGDNN_AARCH64 is non-zero the hardware divide intrinsic vdivq_f32 is
// used. Otherwise the quotient is approximated as x * (1 / y): vrecpeq_f32
// supplies an initial reciprocal estimate and a single vrecpsq_f32
// Newton-Raphson step refines it, so this path is less accurate than a true
// division.
//
// Both operands are only read, so they are taken by const reference; this
// also allows callers to pass temporaries and const vectors.
static inline v4sf div_ps_f32(const v4sf& x, const v4sf& y) {
#if MEGDNN_AARCH64
    return vdivq_f32(x, y);
#else
    // Initial low-precision estimate of 1 / y.
    float32x4_t recp = vrecpeq_f32(y);
    // vrecpsq_f32(y, recp) evaluates (2 - y * recp); multiplying it by the
    // current estimate performs one Newton-Raphson refinement step.
    recp = vmulq_f32(vrecpsq_f32(y, recp), recp);
    return vmulq_f32(x, recp);
#endif
}
// fma_ps_f32(c, b, a): lane-wise multiply-accumulate yielding c + a * b.
// The accumulator is the first argument; the two multiplicands are forwarded
// to the intrinsic in swapped order (a first, then b).
#if !defined(__ARM_FEATURE_FMA)
// The compiler does not advertise fused multiply-add: use the
// multiply-accumulate intrinsic instead.
#define fma_ps_f32(c, b, a) vmlaq_f32((c), (a), (b))
#else
// Fused multiply-add is available.
#define fma_ps_f32(c, b, a) vfmaq_f32((c), (a), (b))
#endif
// Declared here, defined elsewhere. Presumably the lane-wise logistic sigmoid
// 1 / (1 + e^-x) over four float32 lanes -- TODO confirm against the
// definition.
v4sf sigmoid_ps_f32(v4sf x);
// Half-precision (float16x8_t) variants, compiled only when the compiler
// reports __ARM_FEATURE_FP16_VECTOR_ARITHMETIC as non-zero.
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// Declared here, defined elsewhere. Presumably the lane-wise exponential e^x
// over eight float16 lanes -- TODO confirm against the definition.
float16x8_t exp_ps_f16(float16x8_t x);
// Lane-wise division x / y of two eight-lane float16 vectors.
//
// When MEGDNN_AARCH64 is non-zero the hardware divide intrinsic vdivq_f16 is
// used. Otherwise the quotient is approximated as x * (1 / y): vrecpeq_f16
// supplies an initial reciprocal estimate and a single vrecpsq_f16
// Newton-Raphson step refines it, so this path is less accurate than a true
// division.
//
// Both operands are only read, so they are taken by const reference; this
// also allows callers to pass temporaries and const vectors.
static inline float16x8_t div_ps_f16(const float16x8_t& x, const float16x8_t& y) {
#if MEGDNN_AARCH64
    return vdivq_f16(x, y);
#else
    // Initial low-precision estimate of 1 / y.
    float16x8_t recp = vrecpeq_f16(y);
    // vrecpsq_f16(y, recp) evaluates (2 - y * recp); multiplying it by the
    // current estimate performs one Newton-Raphson refinement step.
    recp = vmulq_f16(vrecpsq_f16(y, recp), recp);
    return vmulq_f16(x, recp);
#endif
}
// Declared here, defined elsewhere. Presumably the lane-wise logistic sigmoid
// 1 / (1 + e^-x) over eight float16 lanes -- TODO confirm against the
// definition.
float16x8_t sigmoid_ps_f16(float16x8_t x);
#endif  // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} /* namespace arm_common */ } /* namespace megdnn */