#include "mpn_mod.h"
/*
    Bit-pack len coefficients into res. Each coefficient is stored as nlimbs
    limbs at x, and coefficient k is placed at bit offset k * bits of res.

    A coefficient that does not start on a limb boundary is written with
    mpn_lshift. That stores nlimbs whole limbs starting at the destination
    limb, with the low bits shifted in as zeros, and puts the shifted-out
    carry in the limb just above.

    An aligned coefficient is copied as nlimbs whole limbs, and the limb above
    it is left untouched. Consequences:
      - res must be zero-initialised by the caller.
      - The live bits of coefficient k - 1 must end before the limb in which
        coefficient k starts, otherwise the whole-limb stores would clobber
        them. NOTE(review): confirm this spacing for any new caller.
*/
static void
_mpn_mod_poly_bit_pack(nn_ptr res, nn_srcptr x, slong len, flint_bitcnt_t bits, slong nlimbs)
{
    slong k;
    flint_bitcnt_t offset = 0;

    for (k = 0; k < len; k++, offset += bits)
    {
        nn_ptr dst = res + offset / FLINT_BITS;
        nn_srcptr src = x + k * nlimbs;
        unsigned int sh = offset % FLINT_BITS;

        if (sh != 0)
            dst[nlimbs] = mpn_lshift(dst, src, nlimbs, sh);
        else
            flint_mpn_copyi(dst, src, nlimbs);
    }
}
/*
    Unpack coefficients nlo, ..., nhi - 1 from the packed integer x.
    Coefficient i occupies bits [i * bits, (i + 1) * bits) of x. Each
    coefficient is:
      1. shifted down into a scratch buffer,
      2. masked to exactly `bits` bits,
      3. reduced with mpn_mod_set_mpn and stored as nlimbs limbs at
         res + (i - nlo) * nlimbs.

    At most bits / FLINT_BITS + 2 limbs of x are read per coefficient.
    The fixed scratch buffer of 2 * MPN_MOD_MAX_LIMBS + 3 limbs is therefore
    sufficient while bits < FLINT_BITS * (2 * MPN_MOD_MAX_LIMBS + 2).

    The status returned by mpn_mod_set_mpn is not checked here.
*/
static void
_mpn_mod_poly_bit_unpack(nn_ptr res, nn_srcptr x, slong nlo, slong nhi, flint_bitcnt_t bits, slong nlimbs, gr_ctx_t ctx)
{
    ulong buf[2 * MPN_MOD_MAX_LIMBS + 3];
    const slong chunk_limbs = (bits + FLINT_BITS - 1) / FLINT_BITS;
    const unsigned int top_bits = bits % FLINT_BITS;
    /* Keeps only the bits of the top chunk limb that belong to this coefficient. */
    const ulong top_mask = (top_bits == 0) ? ~UWORD(0) : ((~UWORD(0)) >> (FLINT_BITS - top_bits));
    nn_ptr out = res;
    slong j;

    for (j = nlo; j < nhi; j++, out += nlimbs)
    {
        slong start = bits * j;
        slong stop = bits * (j + 1);
        slong first = start / FLINT_BITS;
        slong span = (stop + FLINT_BITS - 1) / FLINT_BITS - first;
        unsigned int sh = start % FLINT_BITS;
        slong used = chunk_limbs;

        if (sh != 0)
            mpn_rshift(buf, x + first, span, sh);
        else
            flint_mpn_copyi(buf, x + first, span);

        /* Discard bits belonging to the next coefficient, then strip zero limbs. */
        buf[used - 1] &= top_mask;
        MPN_NORM(buf, used);

        mpn_mod_set_mpn(out, buf, used, ctx);
    }
}
/*
    Kronecker-substitution middle product. Writes coefficients
    nlo, ..., nhi - 1 of poly1 * poly2 to res, as (nhi - nlo) consecutive
    nlimbs-limb residues.

    Method: both operands are bit-packed with one coefficient every `bits`
    bits, multiplied with a single mpn multiplication (or squaring when the
    operands are identical), and the requested coefficients are extracted
    and reduced.

    NOTE(review): assumes len1, len2 >= 1 and nhi <= len1 + len2 - 1, the
    usual preconditions of underscore polynomial functions. The unpacking
    step reads the product up to bit bits * nhi.

    Always returns GR_SUCCESS.
*/
int
_mpn_mod_poly_mulmid_KS(nn_ptr res, nn_srcptr poly1, slong len1, nn_srcptr poly2, slong len2, slong nlo, slong nhi, gr_ctx_t ctx)
{
    slong bits, nbits, nlimbs, limbs1, limbs2;
    nn_ptr arr1, arr2, arr;
    int squaring;

    /* Empty output window: nothing to write. Skip packing, allocating and
       multiplying operands whose product would be discarded. */
    if (nlo >= nhi)
        return GR_SUCCESS;

    /* Coefficients at index >= nhi cannot contribute to outputs below nhi. */
    len1 = FLINT_MIN(len1, nhi);
    len2 = FLINT_MIN(len2, nhi);

    nlimbs = MPN_MOD_CTX_NLIMBS(ctx);
    nbits = MPN_MOD_CTX_MODULUS_BITS(ctx);

    if (nlo != 0)
    {
        /* poly1[j] * poly2[k] lands at index j + k <= j + len2 - 1, so
           poly1[j] with j < nlo - (len2 - 1) never reaches the window, and
           symmetrically for poly2. Hence only the top nlo2 coefficients of
           each operand matter.

           nlo2 is unchanged by the first truncation, because len1 and nlo
           shrink by the same amount. */
        slong nlo2 = (len1 + len2 - 1) - nlo;

        if (len1 > nlo2)
        {
            slong trunc = len1 - nlo2;
            poly1 += trunc * nlimbs;
            len1 -= trunc;
            nlo -= trunc;
            nhi -= trunc;
        }

        if (len2 > nlo2)
        {
            slong trunc = len2 - nlo2;
            poly2 += trunc * nlimbs;
            len2 -= trunc;
            nlo -= trunc;
            nhi -= trunc;
        }
    }

    squaring = (poly1 == poly2 && len1 == len2);

    /* Each product coefficient is a sum of at most min(len1, len2) products
       of two values below 2^nbits (assuming reduced residues). It therefore
       fits in this many bits, so packed slots never carry into one another. */
    bits = 2 * nbits + FLINT_BIT_COUNT(FLINT_MIN(len1, len2));
    limbs1 = (bits * len1 - 1) / FLINT_BITS + 1;
    limbs2 = (bits * len2 - 1) / FLINT_BITS + 1;

    /* Room for the carry limb written when packing the last coefficient. */
    FLINT_ASSERT(limbs1 >= (bits * (len1 - 1) / FLINT_BITS + nlimbs + 1));
    FLINT_ASSERT(limbs2 >= (bits * (len2 - 1) / FLINT_BITS + nlimbs + 1));

    /* Packing only writes the limbs it covers, so the buffer must start zeroed. */
    arr1 = flint_calloc(squaring ? limbs1 : limbs1 + limbs2, sizeof(ulong));
    arr2 = squaring ? arr1 : arr1 + limbs1;

    _mpn_mod_poly_bit_pack(arr1, poly1, len1, bits, nlimbs);
    if (!squaring)
        _mpn_mod_poly_bit_pack(arr2, poly2, len2, bits, nlimbs);

    arr = flint_malloc((limbs1 + limbs2) * sizeof(ulong));

    /* flint_mpn_mul expects the longer operand first. */
    if (squaring)
        flint_mpn_sqr(arr, arr1, limbs1);
    else if (limbs1 >= limbs2)
        flint_mpn_mul(arr, arr1, limbs1, arr2, limbs2);
    else
        flint_mpn_mul(arr, arr2, limbs2, arr1, limbs1);

    /* The packed operands are no longer needed; release them before
       unpacking to lower peak memory use. */
    flint_free(arr1);

    _mpn_mod_poly_bit_unpack(res, arr, nlo, nhi, bits, nlimbs, ctx);
    flint_free(arr);

    return GR_SUCCESS;
}
/*
    Low product via Kronecker substitution. Writes the first len
    coefficients of poly1 * poly2 to res. This is the middle product over
    the output window [0, len).
*/
int
_mpn_mod_poly_mullow_KS(nn_ptr res, nn_srcptr poly1, slong len1, nn_srcptr poly2, slong len2, slong len, gr_ctx_t ctx)
{
    const slong window_start = 0;

    return _mpn_mod_poly_mulmid_KS(res, poly1, len1, poly2, len2,
                                   window_start, len, ctx);
}