#ifndef HVX_SQRT_H
#define HVX_SQRT_H
#include <stdbool.h>
#include <stdint.h>
#include "hex-utils.h"
#include "hvx-base.h"
#define RSQRT_CONST 0x5f3759df
#define RSQRT_ONE_HALF 0x3f000000
#define RSQRT_THREE_HALVES 0x3fc00000
#if __HVX_ARCH__ < 79
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
HVX_Vector onehalf = Q6_V_vsplat_R(RSQRT_ONE_HALF);
HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES);
HVX_Vector x2, y, ypower2, temp;
x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf);
x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero());
y = Q6_Vw_vasr_VwR(in_vec, 1);
y = Q6_Vw_vsub_VwVw(rsqrtconst, y);
ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y);
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
temp = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp));
y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
return Q6_Vsf_equals_Vqf32(temp);
}
#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const uint32_t nvec = n / VLEN_FP32; \
const uint32_t nloe = n % VLEN_FP32; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
vdst[i] = sqrt_res; \
} \
if (nloe) { \
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \
} \
} while(0)
static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {
if ((unsigned long) dst % 128 == 0) {
if ((unsigned long) src % 128 == 0) {
hvx_sqrt_f32_aa(dst, src, num_elems);
} else {
hvx_sqrt_f32_au(dst, src, num_elems);
}
} else {
if ((unsigned long) src % 128 == 0) {
hvx_sqrt_f32_ua(dst, src, num_elems);
} else {
hvx_sqrt_f32_uu(dst, src, num_elems);
}
}
}
#endif