#include "flint-mparam.h"
#include "nmod_vec.h"
/* Sets res[k] = res[k] + c * vec[k] (mod n) for 0 <= k < len, one entry at a
   time through nmod_addmul. This is the fallback chosen by
   _nmod_vec_scalar_addmul_nmod for short vectors and for full-word moduli. */
void
_nmod_vec_scalar_addmul_nmod_generic(nn_ptr res, nn_srcptr vec, slong len, ulong c, nmod_t mod)
{
    slong k = 0;

    while (k < len)
    {
        res[k] = nmod_addmul(res[k], vec[k], c, mod);
        k++;
    }
}
/* Sets res[j] = res[j] + c * vec[j] (mod n) for 0 <= j < len using Shoup
   multiplication: the precomputed value for c is computed once and reused for
   every entry, then each product is added with _nmod_add.
   NOTE(review): _nmod_vec_scalar_addmul_nmod only selects this routine when
   NMOD_BITS(mod) != FLINT_BITS; direct callers should presumably respect the
   same restriction on n -- confirm against the n_mulmod_shoup documentation. */
void _nmod_vec_scalar_addmul_nmod_shoup(nn_ptr res, nn_srcptr vec,
                                        slong len, ulong c, nmod_t mod)
{
    const ulong n = mod.n;
    const ulong c_precomp = n_mulmod_precomp_shoup(c, n);
    slong j;

    for (j = 0; j < len; j++)
    {
        const ulong prod = n_mulmod_shoup(c, vec[j], c_precomp, n);
        res[j] = _nmod_add(res[j], prod, mod);
    }
}
/* Sets res[i] = res[i] + c * vec[i] (mod n) for 0 <= i < len.

   Multipliers 0, 1 and n - 1 are special-cased (no-op, vector add, vector
   subtract). Otherwise Shoup multiplication is used for vectors of at least
   FLINT_MULMOD_SHOUP_THRESHOLD entries when n is not a full-word modulus,
   and the generic nmod_addmul loop in all remaining cases. */
void _nmod_vec_scalar_addmul_nmod(nn_ptr res, nn_srcptr vec,
                                  slong len, ulong c, nmod_t mod)
{
    /* Adding 0 * vec leaves res untouched. */
    if (c == UWORD(0))
        return;

    if (c == UWORD(1))
    {
        _nmod_vec_add(res, res, vec, len, mod);
        return;
    }

    /* c == -1 (mod n): res + (-1) * vec is a subtraction. */
    if (c == mod.n - UWORD(1))
    {
        _nmod_vec_sub(res, res, vec, len, mod);
        return;
    }

    /* The Shoup path is never taken for full-word moduli. */
    if (NMOD_BITS(mod) != FLINT_BITS && len >= FLINT_MULMOD_SHOUP_THRESHOLD)
        _nmod_vec_scalar_addmul_nmod_shoup(res, vec, len, c, mod);
    else
        _nmod_vec_scalar_addmul_nmod_generic(res, vec, len, c, mod);
}
/* Sets res[idx] = c * vec[idx] (mod n) for 0 <= idx < len using REDC
   arithmetic: a REDC context is built from mod, c alone is converted with
   nmod_redc_set_nmod, and each entry is multiplied with nmod_redc_mul.
   NOTE(review): _nmod_vec_scalar_mul_nmod only calls this for odd full-word
   moduli with len >= 8; the result is presumably in ordinary (non-REDC)
   form because only one factor was converted -- confirm in nmod_redc docs. */
void
_nmod_vec_scalar_mul_nmod_redc(nn_ptr res, nn_srcptr vec, slong len, ulong c, nmod_t mod)
{
    nmod_redc_ctx_t rctx;
    ulong c_conv;
    slong idx;

    nmod_redc_ctx_init_nmod(rctx, mod);
    c_conv = nmod_redc_set_nmod(c, rctx);

    for (idx = 0; idx < len; idx++)
    {
        res[idx] = nmod_redc_mul(vec[idx], c_conv, rctx);
    }
}
/* Sets res[k] = c * vec[k] (mod n) for 0 <= k < len with NMOD_MUL_FULLWORD.
   Used by _nmod_vec_scalar_mul_nmod for full-word moduli when the REDC path
   is not taken (even n, or fewer than 8 entries). */
static void _nmod_vec_scalar_mul_nmod_fullword(nn_ptr res, nn_srcptr vec,
                                               slong len, ulong c, nmod_t mod)
{
    slong k = 0;

    while (k < len)
    {
        NMOD_MUL_FULLWORD(res[k], vec[k], c, mod);
        k++;
    }
}
/* Sets res[j] = c * vec[j] (mod n) for 0 <= j < len with NMOD_MUL_PRENORM.
   The multiplier is passed pre-shifted left by mod.norm, as that macro is
   called here; the shift is loop-invariant, so it is computed once. */
void _nmod_vec_scalar_mul_nmod_generic(nn_ptr res, nn_srcptr vec,
                                       slong len, ulong c, nmod_t mod)
{
    const ulong c_shifted = c << mod.norm;
    slong j;

    for (j = 0; j < len; j++)
    {
        NMOD_MUL_PRENORM(res[j], vec[j], c_shifted, mod);
    }
}
/* Sets res[k] = c * vec[k] (mod n) for 0 <= k < len using Shoup
   multiplication with a single precomputed value for c.
   NOTE(review): _nmod_vec_scalar_mul_nmod only reaches this routine for
   moduli with NMOD_BITS(mod) != FLINT_BITS; direct callers should presumably
   respect the same bound -- confirm against n_mulmod_shoup documentation. */
void _nmod_vec_scalar_mul_nmod_shoup(nn_ptr res, nn_srcptr vec,
                                     slong len, ulong c, nmod_t mod)
{
    const ulong n = mod.n;
    const ulong c_precomp = n_mulmod_precomp_shoup(c, n);
    slong k = 0;

    while (k < len)
    {
        res[k] = n_mulmod_shoup(c, vec[k], c_precomp, n);
        k++;
    }
}
/* Sets res[i] = c * vec[i] (mod n) for 0 <= i < len, choosing the kernel by
   the value of c, the size of the modulus and the vector length.

   Trivial multipliers avoid multiplication entirely:
     c = 0      -> res is zeroed,
     c = 1      -> vec is copied into res,
     c = n - 1  -> res is the negation of vec.
   Otherwise:
     full-word n (NMOD_BITS == FLINT_BITS):
       REDC kernel for odd n with len >= 8, else NMOD_MUL_FULLWORD;
     smaller n:
       Shoup kernel for len >= FLINT_MULMOD_SHOUP_THRESHOLD,
       else NMOD_MUL_PRENORM. */
void _nmod_vec_scalar_mul_nmod(nn_ptr res, nn_srcptr vec,
                               slong len, ulong c, nmod_t mod)
{
    if (c == UWORD(0))
    {
        _nmod_vec_zero(res, len);
    }
    else if (c == UWORD(1))
    {
        _nmod_vec_set(res, vec, len);
    }
    else if (c == mod.n - UWORD(1))
    {
        _nmod_vec_neg(res, vec, len, mod);
    }
    else if (NMOD_BITS(mod) == FLINT_BITS)
    {
        /* Braced explicitly: the inner if/else must never capture the outer
           chain's else branch (dangling-else hazard).
           REDC needs an odd modulus; the constant 8 is presumably a tuning
           threshold below which the REDC setup cost (context init and
           conversion of c) is not amortized -- confirm with benchmarks. */
        if (len >= 8 && mod.n % 2 != 0)
            _nmod_vec_scalar_mul_nmod_redc(res, vec, len, c, mod);
        else
            _nmod_vec_scalar_mul_nmod_fullword(res, vec, len, c, mod);
    }
    else if (len >= FLINT_MULMOD_SHOUP_THRESHOLD)
    {
        _nmod_vec_scalar_mul_nmod_shoup(res, vec, len, c, mod);
    }
    else
    {
        _nmod_vec_scalar_mul_nmod_generic(res, vec, len, c, mod);
    }
}