#ifndef VEC_AVX_H
#define VEC_AVX_H
#include <immintrin.h>
#include <math.h>
#include "celt/x86/x86cpu.h"
#define MAX_INPUTS (2048)
#define USE_SU_BIAS
/* Approximate floor(x) as round-to-nearest(x - 0.5).  Matches floor() for
   the non-integral magnitudes used below, but under round-half-to-even it
   differs on exact integers (e.g. 1.0 -> 0), which the exp approximation
   tolerates after its input clamp. */
static inline __m128 mm_floor_ps(__m128 x) {
   __m128 half = _mm_set1_ps(0.5);
   return _mm_cvtepi32_ps(_mm_cvtps_epi32(_mm_sub_ps(x, half)));
}
#ifndef __SSE4_1__
/* The native _mm_floor_ps intrinsic needs SSE4.1.  Note the predefined
   macro is __SSE4_1__ -- this file previously tested the misspelled
   __SSE_4_1__, which is never defined, so the emulation was used even on
   SSE4.1-capable builds. */
#undef _mm_floor_ps
#define _mm_floor_ps(x) mm_floor_ps(x)
#endif
#ifndef __AVX__
/* AVX is unavailable: emulate the 8-float __m256 type and the AVX float
   operations this file needs with pairs of 128-bit SSE registers.  The
   member names lo/hi are also used by the AVX2 emulation further down. */
typedef struct {
   __m128 lo;
   __m128 hi;
} mm256_emu;
#define __m256 mm256_emu
/* Unaligned 8-float load as two 4-float loads. */
static inline mm256_emu mm256_loadu_ps(const float *src) {
   mm256_emu v;
   v.lo = _mm_loadu_ps(src);
   v.hi = _mm_loadu_ps(src + 4);
   return v;
}
#define _mm256_loadu_ps(src) mm256_loadu_ps(src)
/* Unaligned 8-float store as two 4-float stores. */
static inline void mm256_storeu_ps(float *dst, mm256_emu src) {
   _mm_storeu_ps(dst, src.lo);
   _mm_storeu_ps(dst + 4, src.hi);
}
#define _mm256_storeu_ps(dst, src) mm256_storeu_ps(dst, src)
/* All-zero 8-float vector. */
static inline mm256_emu mm256_setzero_ps(void) {
   mm256_emu v;
   v.hi = v.lo = _mm_setzero_ps();
   return v;
}
#define _mm256_setzero_ps mm256_setzero_ps
/* Broadcast the float at *x to all 8 lanes. */
static inline mm256_emu mm256_broadcast_ss(const float *x) {
   mm256_emu v;
   v.hi = v.lo = _mm_set1_ps(x[0]);
   return v;
}
#define _mm256_broadcast_ss(x) mm256_broadcast_ss(x)
/* Broadcast a scalar to all 8 lanes. */
static inline mm256_emu mm256_set1_ps(float x) {
   mm256_emu v;
   v.hi = v.lo = _mm_set1_ps(x);
   return v;
}
#define _mm256_set1_ps(x) mm256_set1_ps(x)
/* Lanewise multiply. */
static inline mm256_emu mm256_mul_ps(mm256_emu a, mm256_emu b) {
   mm256_emu v;
   v.lo = _mm_mul_ps(a.lo, b.lo);
   v.hi = _mm_mul_ps(a.hi, b.hi);
   return v;
}
#define _mm256_mul_ps(a,b) mm256_mul_ps(a,b)
/* Lanewise add. */
static inline mm256_emu mm256_add_ps(mm256_emu a, mm256_emu b) {
   mm256_emu v;
   v.lo = _mm_add_ps(a.lo, b.lo);
   v.hi = _mm_add_ps(a.hi, b.hi);
   return v;
}
#define _mm256_add_ps(a,b) mm256_add_ps(a,b)
/* Lanewise maximum. */
static inline mm256_emu mm256_max_ps(mm256_emu a, mm256_emu b) {
   mm256_emu v;
   v.lo = _mm_max_ps(a.lo, b.lo);
   v.hi = _mm_max_ps(a.hi, b.hi);
   return v;
}
#define _mm256_max_ps(a,b) mm256_max_ps(a,b)
/* Lanewise minimum. */
static inline mm256_emu mm256_min_ps(mm256_emu a, mm256_emu b) {
   mm256_emu v;
   v.lo = _mm_min_ps(a.lo, b.lo);
   v.hi = _mm_min_ps(a.hi, b.hi);
   return v;
}
#define _mm256_min_ps(a,b) mm256_min_ps(a,b)
/* Lanewise approximate reciprocal. */
static inline mm256_emu mm256_rcp_ps(mm256_emu a) {
   mm256_emu v;
   v.lo = _mm_rcp_ps(a.lo);
   v.hi = _mm_rcp_ps(a.hi);
   return v;
}
#define _mm256_rcp_ps(a) mm256_rcp_ps(a)
/* Return the selected 128-bit half (0 = low, 1 = high). */
static inline __m128 mm256_extractf128_ps(mm256_emu x, int i) {
   return (i == 0) ? x.lo : x.hi;
}
#undef _mm256_extractf128_ps
#define _mm256_extractf128_ps(x,i) mm256_extractf128_ps(x,i)
/* Replace the selected 128-bit half of dst (0 = low, 1 = high). */
static inline mm256_emu mm256_insertf128_ps(mm256_emu dst, __m128 src, int i) {
   if (i == 0) {
      dst.lo = src;
   } else {
      dst.hi = src;
   }
   return dst;
}
#undef _mm256_insertf128_ps
#define _mm256_insertf128_ps(dst,src,i) mm256_insertf128_ps(dst,src,i)
#endif
#ifndef __AVX2__
/* AVX2 is unavailable: emulate the 256-bit integer type and the AVX2
   integer operations used below with pairs of 128-bit SSE registers. */
typedef struct {
__m128i lo;
__m128i hi;
} mm256i_emu;
/* Keep a name for the compiler's genuine __m256i (usable when __AVX__ is
   present without __AVX2__) before shadowing __m256i below. */
typedef __m256i real_m256i;
#define __m256i mm256i_emu
/* All-zero 256-bit integer vector. */
static inline mm256i_emu mm256_setzero_si256(void) {
mm256i_emu ret;
ret.lo = _mm_setzero_si128();
ret.hi = ret.lo;
return ret;
}
#define _mm256_setzero_si256 mm256_setzero_si256
/* Unaligned 256-bit load as two 128-bit loads. */
static inline mm256i_emu mm256_loadu_si256(const mm256i_emu *src) {
mm256i_emu ret;
ret.lo = _mm_loadu_si128((const __m128i*)src);
ret.hi = _mm_loadu_si128(&((const __m128i*)src)[1]);
return ret;
}
#define _mm256_loadu_si256(src) mm256_loadu_si256(src)
/* Unaligned 256-bit store as two 128-bit stores. */
static inline void mm256_storeu_si256(mm256i_emu *dst, mm256i_emu src) {
_mm_storeu_si128((__m128i*)dst, src.lo);
_mm_storeu_si128(&((__m128i*)dst)[1], src.hi);
}
#define _mm256_storeu_si256(dst, src) mm256_storeu_si256(dst, src)
/* Broadcast the low 32-bit element of x to all 8 lanes. */
static inline mm256i_emu mm256_broadcastd_epi32(__m128i x) {
mm256i_emu ret;
ret.hi = ret.lo = _mm_shuffle_epi32(x, 0);
return ret;
}
#define _mm256_broadcastd_epi32(x) mm256_broadcastd_epi32(x)
/* Broadcast a 32-bit integer to all 8 lanes. */
static inline mm256i_emu mm256_set1_epi32(int x) {
mm256i_emu ret;
ret.lo = _mm_set1_epi32(x);
ret.hi = ret.lo;
return ret;
}
#define _mm256_set1_epi32(x) mm256_set1_epi32(x)
/* Broadcast a 16-bit integer to all 16 lanes. */
static inline mm256i_emu mm256_set1_epi16(int x) {
mm256i_emu ret;
ret.lo = _mm_set1_epi16(x);
ret.hi = ret.lo;
return ret;
}
#define _mm256_set1_epi16(x) mm256_set1_epi16(x)
/* Lanewise 32-bit addition. */
static inline mm256i_emu mm256_add_epi32(mm256i_emu a, mm256i_emu b) {
mm256i_emu ret;
ret.lo = _mm_add_epi32(a.lo, b.lo);
ret.hi = _mm_add_epi32(a.hi, b.hi);
return ret;
}
#define _mm256_add_epi32(a,b) mm256_add_epi32(a,b)
/* Multiply signed 16-bit lanes and add adjacent products into 32 bits. */
static inline mm256i_emu mm256_madd_epi16(mm256i_emu a, mm256i_emu b) {
mm256i_emu ret;
ret.lo = _mm_madd_epi16(a.lo, b.lo);
ret.hi = _mm_madd_epi16(a.hi, b.hi);
return ret;
}
#define _mm256_madd_epi16(a,b) mm256_madd_epi16(a,b)
/* Multiply unsigned bytes of a by signed bytes of b, adding adjacent
   products into saturated 16-bit lanes (the underlying _mm_maddubs_epi16
   requires SSSE3). */
static inline mm256i_emu mm256_maddubs_epi16(mm256i_emu a, mm256i_emu b) {
mm256i_emu ret;
ret.lo = _mm_maddubs_epi16(a.lo, b.lo);
ret.hi = _mm_maddubs_epi16(a.hi, b.hi);
return ret;
}
#define _mm256_maddubs_epi16(a,b) mm256_maddubs_epi16(a,b)
#ifdef __AVX__
/* AVX (without AVX2) does provide 256-bit float<->int conversion, so
   type-pun between the emulated struct and the real __m256i via a union. */
typedef union {
mm256i_emu fake;
real_m256i real;
} mm256_union;
static inline __m256 mm256_cvtepi32_ps(mm256i_emu a) {
mm256_union src;
src.fake = a;
return _mm256_cvtepi32_ps(src.real);
}
#define _mm256_cvtepi32_ps(a) mm256_cvtepi32_ps(a)
static inline mm256i_emu mm256_cvtps_epi32(__m256 a) {
mm256_union ret;
ret.real = _mm256_cvtps_epi32(a);
return ret.fake;
}
#define _mm256_cvtps_epi32(a) mm256_cvtps_epi32(a)
#else
/* No AVX at all: convert each 128-bit half separately. */
static inline mm256_emu mm256_cvtepi32_ps(mm256i_emu a) {
mm256_emu ret;
ret.lo = _mm_cvtepi32_ps(a.lo);
ret.hi = _mm_cvtepi32_ps(a.hi);
return ret;
}
#define _mm256_cvtepi32_ps(a) mm256_cvtepi32_ps(a)
static inline mm256i_emu mm256_cvtps_epi32(mm256_emu a) {
mm256i_emu ret;
ret.lo = _mm_cvtps_epi32(a.lo);
ret.hi = _mm_cvtps_epi32(a.hi);
return ret;
}
#define _mm256_cvtps_epi32(a) mm256_cvtps_epi32(a)
#endif
#endif
#if !(defined(__FMA__) && defined(__AVX__))
/* No FMA hardware: emulate fused multiply-add as a separate multiply and
   add (results may differ from a true fused operation in the last bit). */
#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
#endif
#ifdef __AVX2__
/*
 * 8-wide approximation of exp(x) using e^x = 2^(x*log2(e)).
 * The scaled input is clamped to [-50, 50] and split into an integer part
 * I and a fraction X in [0, 1).  2^X is approximated by the cubic
 * K0 + K1*X + K2*X^2 + K3*X^3, and 2^I is applied by adding I directly
 * into the IEEE-754 exponent field (I << 23).
 */
static inline __m256 exp8_approx(__m256 X)
{
const __m256 K0 = _mm256_set1_ps(0.99992522f);
const __m256 K1 = _mm256_set1_ps(0.69583354f);
const __m256 K2 = _mm256_set1_ps(0.22606716f);
const __m256 K3 = _mm256_set1_ps(0.078024523f);
const __m256 log2_E = _mm256_set1_ps(1.44269504f);
const __m256 max_in = _mm256_set1_ps(50.f);
const __m256 min_in = _mm256_set1_ps(-50.f);
__m256 XF, Y;
__m256i I;
X = _mm256_mul_ps(X, log2_E);
/* Clamp so the exponent arithmetic below cannot overflow. */
X = _mm256_max_ps(min_in, _mm256_min_ps(max_in, X));
XF = _mm256_floor_ps(X);
I = _mm256_cvtps_epi32(XF);
X = _mm256_sub_ps(X, XF);
/* Horner evaluation: K0 + X*(K1 + X*(K2 + X*K3)). */
Y = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(K3, X, K2), X, K1), X, K0);
/* Scale by 2^I by adding I into the float exponent bits. */
I = _mm256_slli_epi32(I, 23);
Y = _mm256_castsi256_ps(_mm256_add_epi32(I, _mm256_castps_si256(Y)));
return Y;
}
/*
 * Quantize len floats (nominally in [-1, 1]) to unsigned bytes as
 * 127*v + 127, 8 values per iteration; the packus instructions provide
 * unsigned saturation.  The two pack steps narrow 32-bit -> 16-bit ->
 * 8-bit, and the permutes undo the 128-bit-lane interleaving of the
 * 256-bit packs so the 8 result bytes are contiguous.
 * NOTE(review): each iteration stores a full 32-byte register at &x[i],
 * so up to 24 bytes beyond the last output are clobbered -- assumes the
 * destination has that much slack past len (x is MAX_INPUTS bytes in the
 * callers below; TODO confirm cols leaves enough headroom).
 */
static inline void vector_ps_to_epi8(unsigned char *x, const float *_x, int len) {
int i;
__m256 const127 = _mm256_set1_ps(127.f);
for (i=0;i<len;i+=8) {
__m256 xf;
__m256i xi;
xf = _mm256_loadu_ps(&_x[i]);
xf = _mm256_fmadd_ps(xf, const127, const127);
xi = _mm256_cvtps_epi32(xf);
xi = _mm256_packus_epi32(xi, _mm256_setzero_si256());
xi = _mm256_permute4x64_epi64(xi, 0xD8);
xi = _mm256_packus_epi16(xi, _mm256_setzero_si256());
xi = _mm256_permutevar8x32_epi32(xi, _mm256_setr_epi32(0,1, 0,0, 0,0, 0,0));
_mm256_storeu_si256 ((__m256i *)(void*)&x[i], xi);
}
}
#else
/*
 * 4-wide SSE approximation of exp(x); same algorithm as the AVX2
 * exp8_approx: e^x = 2^(x*log2(e)) with a cubic approximation of the
 * fractional power of two and the integer part folded into the IEEE-754
 * exponent bits.  The 0x7fffffff mask defensively clears the sign bit in
 * case the exponent addition wraps into it.
 */
static inline __m128 exp4_approx(__m128 X)
{
   const __m128 K0 = _mm_set1_ps(0.99992522f);
   const __m128 K1 = _mm_set1_ps(0.69583354f);
   const __m128 K2 = _mm_set1_ps(0.22606716f);
   const __m128 K3 = _mm_set1_ps(0.078024523f);
   /* Float literal (was the unsuffixed double 1.44269504, which converted
      to the same float anyway) for consistency with the AVX2 path. */
   const __m128 log2_E = _mm_set1_ps(1.44269504f);
   const __m128 max_in = _mm_set1_ps(50.f);
   const __m128 min_in = _mm_set1_ps(-50.f);
   const __m128i mask = _mm_set1_epi32(0x7fffffff);
   __m128 XF, Y;
   __m128i I;
   X = _mm_mul_ps(X, log2_E);
   /* Clamp so the exponent arithmetic below stays in range. */
   X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
   XF = _mm_floor_ps(X);
   I = _mm_cvtps_epi32(XF);
   X = _mm_sub_ps(X, XF);
   /* Horner evaluation: K0 + X*(K1 + X*(K2 + X*K3)). */
   Y = _mm_fmadd_ps(_mm_fmadd_ps(_mm_fmadd_ps(K3, X, K2), X, K1), X, K0);
   /* Scale by 2^I by adding I into the float exponent bits. */
   I = _mm_slli_epi32(I, 23);
   Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
   return Y;
}
/* Emulated 8-wide exp(): run the 4-wide approximation on each 128-bit
   half and reassemble the result. */
static inline __m256 exp8_approx(__m256 X)
{
   __m128 lo, hi;
   __m256 Y;
   lo = exp4_approx(_mm256_extractf128_ps(X, 0));
   hi = exp4_approx(_mm256_extractf128_ps(X, 1));
   Y = _mm256_setzero_ps();
   Y = _mm256_insertf128_ps(Y, lo, 0);
   Y = _mm256_insertf128_ps(Y, hi, 1);
   return Y;
}
/* Scalar fallback: quantize len floats (nominally in [-1, 1]) to unsigned
   bytes as round(127*v) + 127.  Saturate to [0, 255] so out-of-range
   inputs behave like the saturating pack instructions in the AVX2 path
   instead of wrapping on the conversion to unsigned char. */
static inline void vector_ps_to_epi8(unsigned char *x, const float *_x, int len) {
   int i;
   for (i=0;i<len;i++) {
      int v = 127+(int)floor(.5+127*_x[i]);
      if (v < 0) v = 0;
      if (v > 255) v = 255;
      x[i] = (unsigned char)v;
   }
}
#endif
#ifdef __AVX__
/*
 * 8-wide approximation of tanh(x) as an odd rational function:
 *   tanh(x) ~= x*(N0 + N1*x^2 + N2*x^4) / (D0 + D1*x^2 + D2*x^4)
 * The division uses the fast approximate reciprocal _mm256_rcp_ps; the
 * final clamp bounds the output to [-1, 1].
 */
static inline __m256 tanh8_approx(__m256 X)
{
const __m256 N0 = _mm256_set1_ps(952.52801514f);
const __m256 N1 = _mm256_set1_ps(96.39235687f);
const __m256 N2 = _mm256_set1_ps(0.60863042f);
const __m256 D0 = _mm256_set1_ps(952.72399902f);
const __m256 D1 = _mm256_set1_ps(413.36801147f);
const __m256 D2 = _mm256_set1_ps(11.88600922f);
const __m256 max_out = _mm256_set1_ps(1.f);
const __m256 min_out = _mm256_set1_ps(-1.f);
__m256 X2, num, den;
X2 = _mm256_mul_ps(X, X);
/* Horner in x^2 for numerator and denominator polynomials. */
num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
num = _mm256_mul_ps(num, X);
den = _mm256_rcp_ps(den);
num = _mm256_mul_ps(num, den);
return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
}
/*
 * 8-wide approximation of the logistic sigmoid:
 *   sigmoid(x) ~= 0.5 + x*(N0 + N1*x^2 + N2*x^4) / (D0 + D1*x^2 + D2*x^4)
 * Uses the approximate reciprocal _mm256_rcp_ps; the final clamp bounds
 * the output to [0, 1].
 */
static inline __m256 sigmoid8_approx(__m256 X)
{
const __m256 N0 = _mm256_set1_ps(238.13200378f);
const __m256 N1 = _mm256_set1_ps(6.02452230f);
const __m256 N2 = _mm256_set1_ps(0.00950985f);
const __m256 D0 = _mm256_set1_ps(952.72399902f);
const __m256 D1 = _mm256_set1_ps(103.34200287f);
const __m256 D2 = _mm256_set1_ps(0.74287558f);
const __m256 half = _mm256_set1_ps(0.5);
const __m256 max_out = _mm256_set1_ps(1.f);
const __m256 min_out = _mm256_set1_ps(0.f);
__m256 X2, num, den;
X2 = _mm256_mul_ps(X, X);
/* Horner in x^2 for numerator and denominator polynomials. */
num = _mm256_fmadd_ps(_mm256_fmadd_ps(N2, X2, N1), X2, N0);
den = _mm256_fmadd_ps(_mm256_fmadd_ps(D2, X2, D1), X2, D0);
num = _mm256_mul_ps(num, X);
den = _mm256_rcp_ps(den);
/* Shift the odd rational part up by 0.5. */
num = _mm256_fmadd_ps(num, den, half);
return _mm256_max_ps(min_out, _mm256_min_ps(max_out, num));
}
static inline float tanh_approx(float x)
{
float out[8];
__m256 X, Y;
X = _mm256_set1_ps(x);
Y = tanh8_approx(X);
_mm256_storeu_ps(out, Y);
return out[0];
}
static inline float sigmoid_approx(float x)
{
float out[8];
__m256 X, Y;
X = _mm256_set1_ps(x);
Y = sigmoid8_approx(X);
_mm256_storeu_ps(out, Y);
return out[0];
}
#else
/*
 * 4-wide SSE tanh approximation; same rational form as the AVX version:
 * tanh(x) ~= x*(N0 + N1*x^2 + N2*x^4)/(D0 + D1*x^2 + D2*x^4), evaluated
 * with an approximate reciprocal and clamped to [-1, 1].
 */
static inline __m128 tanh4_approx(__m128 X)
{
const __m128 N0 = _mm_set1_ps(952.52801514f);
const __m128 N1 = _mm_set1_ps(96.39235687f);
const __m128 N2 = _mm_set1_ps(0.60863042f);
const __m128 D0 = _mm_set1_ps(952.72399902f);
const __m128 D1 = _mm_set1_ps(413.36801147f);
const __m128 D2 = _mm_set1_ps(11.88600922f);
const __m128 max_out = _mm_set1_ps(1.f);
const __m128 min_out = _mm_set1_ps(-1.f);
__m128 X2, num, den;
X2 = _mm_mul_ps(X, X);
num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
num = _mm_mul_ps(num, X);
den = _mm_rcp_ps(den);
num = _mm_mul_ps(num, den);
return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
}
/*
 * 4-wide SSE sigmoid approximation; same rational form as the AVX
 * version: sigmoid(x) ~= 0.5 + x*N(x^2)/D(x^2), evaluated with an
 * approximate reciprocal and clamped to [0, 1].
 */
static inline __m128 sigmoid4_approx(__m128 X)
{
const __m128 N0 = _mm_set1_ps(238.13200378f);
const __m128 N1 = _mm_set1_ps(6.02452230f);
const __m128 N2 = _mm_set1_ps(0.00950985f);
const __m128 D0 = _mm_set1_ps(952.72399902f);
const __m128 D1 = _mm_set1_ps(103.34200287f);
const __m128 D2 = _mm_set1_ps(0.74287558f);
const __m128 half = _mm_set1_ps(0.5);
const __m128 max_out = _mm_set1_ps(1.f);
const __m128 min_out = _mm_set1_ps(0.f);
__m128 X2, num, den;
X2 = _mm_mul_ps(X, X);
num = _mm_fmadd_ps(_mm_fmadd_ps(N2, X2, N1), X2, N0);
den = _mm_fmadd_ps(_mm_fmadd_ps(D2, X2, D1), X2, D0);
num = _mm_mul_ps(num, X);
den = _mm_rcp_ps(den);
num = _mm_fmadd_ps(num, den, half);
return _mm_max_ps(min_out, _mm_min_ps(max_out, num));
}
static inline float tanh_approx(float x)
{
float out[4];
__m128 X, Y;
X = _mm_set1_ps(x);
Y = tanh4_approx(X);
_mm_storeu_ps(out, Y);
return out[0];
}
static inline float sigmoid_approx(float x)
{
float out[4];
__m128 X, Y;
X = _mm_set1_ps(x);
Y = sigmoid4_approx(X);
_mm_storeu_ps(out, Y);
return out[0];
}
#endif
static inline float lpcnet_exp(float x)
{
float out[8];
__m256 X, Y;
X = _mm256_set1_ps(x);
Y = exp8_approx(X);
_mm256_storeu_ps(out, Y);
return out[0];
}
static inline void softmax(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-7;i+=8)
{
__m256 X, Y;
X = _mm256_loadu_ps(&x[i]);
Y = exp8_approx(X);
_mm256_storeu_ps(&y[i], Y);
}
for (;i<N;i++)
y[i] = lpcnet_exp(x[i]);
}
#ifdef __AVX__
/* Elementwise tanh approximation over N floats: 8-wide vector body with
   a scalar tail. */
static inline void vec_tanh(float *y, const float *x, int N)
{
   int k = 0;
   while (k + 8 <= N)
   {
      _mm256_storeu_ps(&y[k], tanh8_approx(_mm256_loadu_ps(&x[k])));
      k += 8;
   }
   for (; k < N; k++)
   {
      y[k] = tanh_approx(x[k]);
   }
}
static inline void vec_sigmoid(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-7;i+=8)
{
__m256 X, Y;
X = _mm256_loadu_ps(&x[i]);
Y = sigmoid8_approx(X);
_mm256_storeu_ps(&y[i], Y);
}
for (;i<N;i++)
{
y[i] = sigmoid_approx(x[i]);
}
}
#else
static inline void vec_tanh(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-3;i+=4)
{
__m128 X, Y;
X = _mm_loadu_ps(&x[i]);
Y = tanh4_approx(X);
_mm_storeu_ps(&y[i], Y);
}
for (;i<N;i++)
{
y[i] = tanh_approx(x[i]);
}
}
static inline void vec_sigmoid(float *y, const float *x, int N)
{
int i;
for (i=0;i<N-3;i+=4)
{
__m128 X, Y;
X = _mm_loadu_ps(&x[i]);
Y = sigmoid4_approx(X);
_mm_storeu_ps(&y[i], Y);
}
for (;i<N;i++)
{
y[i] = sigmoid_approx(x[i]);
}
}
#endif
#if defined(__AVXVNNI__) || defined(__AVX512VNNI__)
/* Native VNNI dot product: dst = src + per-32-bit-lane sums of four
   (unsigned byte of a) * (signed byte of b) products, with saturation. */
#define opus_mm256_dpbusds_epi32(src, a, b) _mm256_dpbusds_epi32(src, a, b)
#elif defined(__AVX2__)
/* AVX2 fallback: maddubs produces 16-bit sums of unsigned*signed byte
   pairs (with 16-bit saturation, so extreme inputs can differ slightly
   from true VNNI), then madd against 1 widens and sums to 32 bits. */
static inline __m256i opus_mm256_dpbusds_epi32(__m256i src, __m256i a, __m256i b) {
__m256i ones, tmp;
ones = _mm256_set1_epi16(1);
tmp = _mm256_maddubs_epi16(a, b);
tmp = _mm256_madd_epi16(tmp, ones);
return _mm256_add_epi32(src, tmp);
}
#elif defined(__SSSE3__)
/* Same maddubs/madd construction on the emulated 256-bit integer type
   (processed as two SSE halves by the wrappers above). */
static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) {
mm256i_emu ones, tmp;
ones = _mm256_set1_epi16(1);
tmp = _mm256_maddubs_epi16(a, b);
tmp = _mm256_madd_epi16(tmp, ones);
return _mm256_add_epi32(src, tmp);
}
#elif defined(__SSE2__)
/* Pure SSE2 fallback: split the bytes into even/odd 16-bit lanes -- a is
   zero-extended (unsigned) via logical shifts, b is sign-extended via
   arithmetic shifts -- then use madd_epi16 on each half and add. */
static inline __m128i mm_dpbusds_epi32(__m128i src, __m128i a, __m128i b) {
__m128i ah, al, bh, bl, tmp;
ah = _mm_srli_epi16(a, 8);
bh = _mm_srai_epi16(b, 8);
al = _mm_srli_epi16(_mm_slli_epi16(a, 8), 8);
bl = _mm_srai_epi16(_mm_slli_epi16(b, 8), 8);
tmp = _mm_add_epi32(_mm_madd_epi16(ah, bh), _mm_madd_epi16(al, bl));
return _mm_add_epi32(src, tmp);
}
/* Apply the SSE2 kernel to both halves of the emulated 256-bit type. */
static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) {
mm256i_emu res;
res.hi = mm_dpbusds_epi32(src.hi, a.hi, b.hi);
res.lo = mm_dpbusds_epi32(src.lo, a.lo, b.lo);
return res;
}
#else
#error "No optimizations in vec_avx.h. This should never happen. "
#endif
/*
 * Dense matrix-vector product:
 *   out[i] = sum_j weights[j*col_stride + i] * x[j],  i in [0, rows)
 * The weight matrix is stored column-major with col_stride floats between
 * consecutive columns.  Rows are processed in tiles of 16, then 8, then
 * 4, with a scalar loop for any remainder.
 */
static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
int i, j;
i=0;
/* 16 rows at a time: two 8-float accumulators per column sweep. */
for (;i<rows-15;i+=16)
{
float *y;
__m256 vy0, vy8;
y = &out[i];
vy0 = _mm256_setzero_ps();
vy8 = _mm256_setzero_ps();
for (j=0;j<cols;j++)
{
__m256 vxj;
__m256 vw;
vxj = _mm256_broadcast_ss(&x[j]);
vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
}
_mm256_storeu_ps (&y[0], vy0);
_mm256_storeu_ps (&y[8], vy8);
}
/* 8 rows at a time. */
for (;i<rows-7;i+=8)
{
float *y;
__m256 vy0;
y = &out[i];
vy0 = _mm256_setzero_ps();
for (j=0;j<cols;j++)
{
__m256 vxj;
__m256 vw;
vxj = _mm256_broadcast_ss(&x[j]);
vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
}
_mm256_storeu_ps (&y[0], vy0);
}
/* 4 rows at a time with SSE. */
for (;i<rows-3;i+=4)
{
float *y;
__m128 vy0;
y = &out[i];
vy0 = _mm_setzero_ps();
for (j=0;j<cols;j++)
{
__m128 vxj;
__m128 vw;
vxj = _mm_set1_ps(x[j]);
vw = _mm_loadu_ps(&weights[j*col_stride + i]);
vy0 = _mm_fmadd_ps(vw, vxj, vy0);
}
_mm_storeu_ps (&y[0], vy0);
}
/* Scalar remainder (rows not a multiple of 4). */
for (;i<rows;i++)
{
out[i] = 0;
for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
}
}
/*
 * Sparse float matrix-vector product with 8x4 blocking.
 * For each group of 8 output rows, the idx stream supplies a block count
 * followed by one column index per non-zero block; each block multiplies
 * 4 consecutive x entries (x[id..id+3]) by a packed 8x4 weight tile
 * (32 floats, consumed from weights in stream order).
 * Assumes rows is a multiple of 8 (there is no tail loop).
 */
static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
{
int i, j;
for (i=0;i<rows;i+=8)
{
float *y;
int cols;
__m256 vy0;
y = &out[i];
vy0 = _mm256_setzero_ps();
/* Number of non-zero 8x4 blocks in this row group. */
cols = *idx++;
for (j=0;j<cols;j++)
{
int id;
__m256 vxj;
__m256 vw;
id = *idx++;
/* Accumulate the 8x4 tile: four broadcast inputs times four 8-wide
   weight columns. */
vxj = _mm256_broadcast_ss(&x[id]);
vw = _mm256_loadu_ps(&weights[0]);
vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
vxj = _mm256_broadcast_ss(&x[id+1]);
vw = _mm256_loadu_ps(&weights[8]);
vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
vxj = _mm256_broadcast_ss(&x[id+2]);
vw = _mm256_loadu_ps(&weights[16]);
vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
vxj = _mm256_broadcast_ss(&x[id+3]);
vw = _mm256_loadu_ps(&weights[24]);
vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
weights += 32;
}
_mm256_storeu_ps (&y[0], vy0);
}
}
/*
 * Sparse 8-bit matrix-vector product with 8x4 blocking.
 * The float input is first quantized to unsigned bytes (see
 * vector_ps_to_epi8); weights are signed 8-bit.  Each block dot-products
 * 4 quantized inputs (loaded as one 32-bit word and broadcast to every
 * lane) against a 32-byte weight tile via opus_mm256_dpbusds_epi32,
 * accumulating in 32-bit integer lanes.  The integer result is converted
 * to float and multiplied by the per-row scale[] (which presumably folds
 * in the inverse quantization scale -- see SCALE_1 below).
 * idx layout matches sparse_sgemv8x4: per 8-row group, a block count then
 * one column index per block.  Assumes rows is a multiple of 8 and that
 * cols <= MAX_INPUTS.
 */
static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
{
int i, j;
unsigned char x[MAX_INPUTS];
vector_ps_to_epi8(x, _x, cols);
for (i=0;i<rows;i+=8)
{
int colblocks;
__m256i vy0;
__m256 vout;
colblocks = *idx++;
vy0 = _mm256_setzero_si256();
j=0;
#if 1
/* Main loop manually unrolled by 4 blocks. */
for (;j<colblocks-3;j+=4)
{
__m256i vxj;
__m256i vw;
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
}
#endif
/* Remaining 0-3 blocks. */
for (;j<colblocks;j++)
{
__m256i vxj;
__m256i vw;
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[*idx++]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
}
vout = _mm256_cvtepi32_ps(vy0);
vout = _mm256_mul_ps(vout, _mm256_loadu_ps(&scale[i]));
_mm256_storeu_ps(&_out[i], vout);
}
}
/*
 * Dense 8-bit matrix-vector product with 8x4 blocking: the dense
 * counterpart of sparse_cgemv8x4.  Inputs are quantized to unsigned
 * bytes; signed 8-bit weights are consumed as 32-byte tiles covering
 * 8 rows x 4 columns, accumulated in 32-bit lanes and scaled per row.
 * Assumes rows % 8 == 0 and cols % 4 == 0 (the tail loop advances by 4);
 * with 4-aligned cols, the unrolled loop's guard j < cols-12 keeps all
 * four 4-byte input loads within x[0..cols-1].
 */
static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
{
int i, j;
unsigned char x[MAX_INPUTS];
vector_ps_to_epi8(x, _x, cols);
for (i=0;i<rows;i+=8)
{
__m256i vy0;
__m256 vout;
vy0 = _mm256_setzero_si256();
j=0;
#if 1
/* Main loop manually unrolled to 16 columns (four 8x4 tiles). */
for (;j<cols-12;j+=16)
{
__m256i vxj;
__m256i vw;
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+4]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+8]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j+12]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
}
#endif
/* Remaining columns, 4 at a time. */
for (;j<cols;j+=4)
{
__m256i vxj;
__m256i vw;
vxj = _mm256_broadcastd_epi32(_mm_loadu_si32(&x[j]));
vw = _mm256_loadu_si256((const __m256i *)(void*)w);
vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw);
w += 32;
}
vout = _mm256_cvtepi32_ps(vy0);
vout = _mm256_mul_ps(vout, _mm256_loadu_ps(&scale[i]));
_mm256_storeu_ps(&_out[i], vout);
}
}
/* Combined quantization scale for the 8-bit kernels and its inverse
   (presumably 127 for the activation quantization -- see
   vector_ps_to_epi8's x127 mapping -- times 128 for the weights;
   TODO confirm against the weight-quantization code). */
#define SCALE (128.f*127.f)
#define SCALE_1 (1.f/128.f/127.f)
/* Note: USE_SU_BIAS is also defined near the top of this file; this
   duplicate definition is identical and therefore harmless. */
#define USE_SU_BIAS
#endif /* VEC_AVX_H */