#include <stdint.h>
#include <math.h>
#include "nmod.h"
#include "nmod_vec.h"
#include "nmod_poly.h"
#if FLINT_HAVE_FFT_SMALL
#include "fft_small.h"
/* Modulus used for the product of the packed polynomials (it is passed to
   nmod_init() before the packed multiplication is performed). */
#define FFT_PRIME UWORD(1108307720798209)
/* Packing radices: two input coefficients are combined into one word as
   lo | (hi << 16) or lo | (hi << 17) respectively. */
#define M16 (UWORD(1) << 16)
#define M17 (UWORD(1) << 17)
/* Largest modulus for which the repacking strategy is attempted. */
#define FFT_SMALL_MAX_REPACK 23
/* Length limits for the repacking strategy, indexed by the modulus n.
   _nmod_poly_mulmid_fft_small_repack() declines to repack when
   an + bn > 2.3 * repack_limit_tab[n]. The table has 36 entries, but the
   visible code only reads indices 0..FFT_SMALL_MAX_REPACK.
   NOTE(review): the values look like empirically tuned cutoffs -- confirm
   before changing them. */
static const int repack_limit_tab[] = {
0,
264497,
256720,
77049,
36925,
21385,
14099,
9872,
7337,
5565,
4481,
3609,
3084,
2562,
2129,
1859,
1687,
1446,
1264,
1160,
1035,
942,
874,
725,
662,
626,
587,
562,
517,
490,
426,
410,
363,
342,
336,
321,
};
#if 0#endif
#if 1
FLINT_FORCE_INLINE
uint32_t u32_mod_preinv(uint32_t x, uint32_t d, uint64_t dinv)
{
uint64_t l = dinv * x;
#if !defined(__GNUC__)
uint64_t hi, lo;
umul_ppmm(hi, lo, l, d);
return hi;
#else
return ((__uint128_t) l * d ) >> 64;
#endif
}
FLINT_FORCE_INLINE
uint64_t u32_preinv(uint32_t d)
{
return UINT64_C(0xFFFFFFFFFFFFFFFF) / d + 1;
}
#else#endif
/*
 * UNPACK(m): split the packed product z2 back into individual coefficients
 * and reduce them modulo mod.n, writing coefficients znlo..zn-1 to z.
 *
 * Each word of z2 is read as c0 + c1 * m + c2 * m^2. c0 belongs to an even
 * output index, c1 to the following odd index, and c2 is carried into the c0
 * of the next word.
 *
 * The macro expands in the body of _nmod_poly_mulmid_fft_small_repack_m and
 * uses that function's locals by name: z, z2, zn, zn2, znlo, znlo2, mod, i,
 * d, cy, c, c0, c1, c2. It declares dd and ddinv, so every expansion must sit
 * in its own brace-enclosed block.
 *
 * The radix is taken only from the macro parameter m; the macro does not
 * read the caller's M variable.
 */
#define UNPACK(m) \
    cy = 0; \
    uint32_t dd = mod.n; \
    uint64_t ddinv = u32_preinv(dd); \
    if (znlo != 0) \
    { \
        if (znlo % 2) \
        { \
            c = z2[znlo / 2 - znlo2]; \
            c1 = (c / (m)) % (m); \
            c2 = c / ((m) * (m)); \
            d = u32_mod_preinv(c1, dd, ddinv); \
            z[0] = d; \
            cy = c2; \
        } \
        else \
            cy = z2[znlo / 2 - 1 - znlo2] / ((m) * (m)); \
    } \
    for (i = (znlo + 1) / 2; 2 * i + 1 < zn; i++) \
    { \
        c = z2[i - znlo2]; \
        c0 = c % (m); \
        c1 = (c / (m)) % (m); \
        c2 = c / ((m) * (m)); \
        d = u32_mod_preinv(c0 + cy, dd, ddinv); \
        z[2 * i - znlo] = d; \
        d = u32_mod_preinv(c1, dd, ddinv); \
        z[2 * i + 1 - znlo] = d; \
        cy = c2; \
    } \
    if (zn % 2) \
    { \
        if (zn > 2 * zn2) \
            c0 = 0; \
        else \
            c0 = z2[zn2 - 1 - znlo2] % (m); \
        d = u32_mod_preinv(c0 + cy, dd, ddinv); \
        z[zn - 1 - znlo] = d; \
    }
/*
 * Middle product for a tiny modulus by coefficient repacking.
 *
 * Pairs of coefficients of a and b are packed into single words as
 * lo | (hi << 16) when M == M16, otherwise lo | (hi << 17). The half-length
 * packed polynomials are multiplied modulo FFT_PRIME, and the result is
 * unpacked and reduced modulo mod.n into z (coefficients znlo..zn-1 of the
 * product).
 *
 * The caller must have checked that the packed coefficients cannot overflow
 * for the chosen M; this function does not check it. Always returns 1.
 */
static int
_nmod_poly_mulmid_fft_small_repack_m(nn_ptr z, nn_srcptr a, slong an, nn_srcptr b, slong bn, slong znlo, slong zn, ulong M, nmod_t mod)
{
    TMP_INIT;
    ulong *packed_a, *packed_b, *z2;
    slong packed_an, packed_bn, input_words, zn2, znlo2;
    ulong d, cy, c, c0, c1, c2;   /* scratch used by UNPACK */
    nmod_t packed_mod;
    slong i;
    int same_input;
    int shift;

    FLINT_ASSERT(M == M16 || M == M17);
    FLINT_ASSERT(mod.n <= FFT_SMALL_MAX_REPACK);

    same_input = (a == b) && (an == bn);
    shift = (M == M16) ? 16 : 17;

    TMP_START;

    /* Lengths and window after halving. */
    packed_an = (an + 1) / 2;
    packed_bn = (bn + 1) / 2;
    zn2 = FLINT_MIN((zn + 1) / 2, packed_an + packed_bn - 1);
    znlo2 = FLINT_MAX(0, (znlo - 1) / 2);

    /* One allocation holds the packed input(s) followed by the packed output;
       for a squaring the second operand aliases the first. */
    input_words = same_input ? packed_an : packed_an + packed_bn;
    packed_a = TMP_ALLOC((input_words + (zn2 - znlo2)) * sizeof(ulong));
    packed_b = same_input ? packed_a : packed_a + packed_an;
    z2 = packed_a + input_words;

    /* Pack a: full pairs, then a lone trailing coefficient if an is odd. */
    for (i = 0; i + 1 < an; i += 2)
        packed_a[i / 2] = a[i] | (a[i + 1] << shift);
    if (an % 2)
        packed_a[an / 2] = a[an - 1];

    /* Pack b the same way unless it shares storage with a. */
    if (!same_input)
    {
        for (i = 0; i + 1 < bn; i += 2)
            packed_b[i / 2] = b[i] | (b[i + 1] << shift);
        if (bn % 2)
            packed_b[bn / 2] = b[bn - 1];
    }

    nmod_init(&packed_mod, FFT_PRIME);
    _nmod_poly_mul_mid_default_mpn_ctx(z2, znlo2, zn2, packed_a, packed_an, packed_b, packed_bn, packed_mod);

    /* UNPACK declares locals, so each expansion needs its own block. */
    if (M == M16)
    {
        UNPACK(M16);
    }
    else
    {
        UNPACK(M17);
    }

    TMP_END;
    return 1;
}
/*
 * Upper bounds for the Euclidean norms of the even-index and odd-index
 * coefficients of v[0..len-1]: floor(sqrt(sum of squares)) + 1 for each.
 * The sums are accumulated in 32-bit arithmetic, as the caller restricts
 * len to at most 2^19 and the modulus to at most FFT_SMALL_MAX_REPACK.
 */
static void
_nmod_poly_repack_norm_bounds(uint32_t * even, uint32_t * odd, nn_srcptr v, slong len)
{
    uint32_t sum_even = 0;
    uint32_t sum_odd = 0;
    slong i;

    for (i = 0; i + 1 < len; i += 2)
    {
        sum_even += (uint32_t) v[i] * (uint32_t) v[i];
        sum_odd += (uint32_t) v[i + 1] * (uint32_t) v[i + 1];
    }

    /* For odd len the last coefficient has an even index and is not covered
       by the loop above; it must enter the sum squared like the others. */
    if (len % 2)
        sum_even += (uint32_t) v[len - 1] * (uint32_t) v[len - 1];

    *even = (uint32_t) (sqrt((double) sum_even) + 1);
    *odd = (uint32_t) (sqrt((double) sum_odd) + 1);
}

/*
 * Try to compute coefficients znlo..zn-1 of a * b modulo mod.n by packing
 * two coefficients per word and multiplying modulo FFT_PRIME.
 *
 * Returns 1 if the product was computed and written to z. Returns 0, leaving
 * z untouched, when the strategy does not apply: modulus above
 * FFT_SMALL_MAX_REPACK, an operand longer than 2^19, combined length above
 * 2.3 * repack_limit_tab[n], or coefficient bounds too large for both
 * packing radices.
 *
 * With packed words a_even + a_odd * M the product words have the form
 * c0 + c1 * M + c2 * M^2, where c0 collects even*even products, c1 the
 * even*odd and odd*even products, and c2 the odd*odd products. Unpacking is
 * only valid if c0, c1 < M and c2 * M^2 < FFT_PRIME.
 */
int
_nmod_poly_mulmid_fft_small_repack(nn_ptr z, nn_srcptr a, slong an, nn_srcptr b, slong bn, slong znlo, slong zn, nmod_t mod)
{
    uint32_t aeven, aodd, beven, bodd;
    ulong c0_bound, c1_bound, c2_bound;
    ulong n = mod.n;

    /* Coefficients at index >= zn cannot influence the requested output. */
    an = FLINT_MIN(an, zn);
    bn = FLINT_MIN(bn, zn);

    if (n > FFT_SMALL_MAX_REPACK || FLINT_MAX(an, bn) > (1 << 19))
        return 0;

    if (an + bn > 2.3 * repack_limit_tab[n])
        return 0;

    /* Cheap bound: at most min(an, bn) terms of size at most (n - 1)^2. */
    c0_bound = FLINT_MIN(an, bn) * (n - 1) * (n - 1);
    if (c0_bound < M16)
        return _nmod_poly_mulmid_fft_small_repack_m(z, a, an, b, bn, znlo, zn, M16, mod);

    /* Sharper data-dependent bound: each product coefficient is bounded by
       the product of the Euclidean norms of the contributing coefficient
       vectors (Cauchy-Schwarz). */
    _nmod_poly_repack_norm_bounds(&aeven, &aodd, a, an);

    if (a == b && an == bn)
    {
        beven = aeven;
        bodd = aodd;
    }
    else
    {
        _nmod_poly_repack_norm_bounds(&beven, &bodd, b, bn);
    }

    /* Widen before multiplying so the bounds themselves cannot wrap. */
    c0_bound = (ulong) aeven * beven;
    c1_bound = (ulong) aeven * bodd + (ulong) aodd * beven;
    c2_bound = (ulong) aodd * bodd;

    if (c0_bound < M16 && c1_bound < M16 && c2_bound * M16 * M16 < FFT_PRIME)
        return _nmod_poly_mulmid_fft_small_repack_m(z, a, an, b, bn, znlo, zn, M16, mod);

    if (c0_bound < M17 && c1_bound < M17 && c2_bound * M17 * M17 < FFT_PRIME)
        return _nmod_poly_mulmid_fft_small_repack_m(z, a, an, b, bn, znlo, zn, M17, mod);

    return 0;
}
/*
 * Writes coefficients znlo..zn-1 of a * b modulo mod.n to z.
 * The repacking strategy is tried first; if it declines (returns 0), the
 * generic fft_small middle product is used instead.
 */
void
_nmod_poly_mulmid_fft_small(nn_ptr z, nn_srcptr a, slong an, nn_srcptr b, slong bn, slong znlo, slong zn, nmod_t mod)
{
    int handled = _nmod_poly_mulmid_fft_small_repack(z, a, an, b, bn, znlo, zn, mod);

    if (handled)
        return;

    _nmod_poly_mul_mid_default_mpn_ctx(z, znlo, zn, a, an, b, bn, mod);
}
/* Low product: writes the first zn coefficients of a * b modulo mod.n to z.
   Implemented as the middle product with lower bound 0. */
void
_nmod_poly_mullow_fft_small(nn_ptr z, nn_srcptr a, slong an, nn_srcptr b, slong bn, slong zn, nmod_t mod)
{
_nmod_poly_mulmid_fft_small(z, a, an, b, bn, 0, zn, mod);
}
/* Length cutoffs consulted by _nmod_poly_mullow_want_fft_small() for moduli
   above FFT_SMALL_MAX_REPACK. Each table has 64 entries and is indexed by
   NMOD_BITS(mod) - 1.
   NOTE(review): the values look like empirically tuned thresholds --
   confirm before editing. */

/* Full product of two distinct operands; compared with min(len1, 2 * len2). */
static const short fft_mul_tab[] = {1326, 1326, 1095, 802, 674, 537, 330, 306, 290,
274, 200, 192, 182, 173, 163, 99, 97, 93, 90, 82, 80, 438, 414, 324, 393,
298, 298, 268, 187, 185, 176, 176, 168, 167, 158, 158, 97, 96, 93, 92, 89,
89, 85, 85, 80, 81, 177, 172, 163, 162, 164, 176, 171, 167, 167, 164, 163,
163, 160, 165, 95, 96, 90, 94, };
/* Full product when squaring; compared with min(len1, 2 * len2). */
static const short fft_sqr_tab[] = {1420, 1420, 1353, 964, 689, 569, 407, 353, 321,
321, 292, 279, 200, 182, 182, 159, 159, 152, 145, 139, 723, 626, 626, 569,
597, 448, 542, 292, 292, 200, 191, 191, 182, 182, 166, 166, 166, 159, 159,
159, 152, 152, 145, 145, 93, 200, 191, 182, 182, 182, 182, 191, 191, 191,
182, 182, 174, 182, 182, 182, 152, 152, 152, 145, };
/* Truncated (low) product; compared with min(len1, len2). */
static const short fft_mullow_tab[] = {1115, 1115, 597, 569, 407, 321, 306, 279, 191,
182, 166, 159, 152, 145, 139, 89, 85, 78, 75, 75, 69, 174, 174, 166, 159,
152, 152, 152, 97, 101, 106, 111, 101, 101, 101, 139, 145, 145, 139, 145,
145, 139, 145, 145, 145, 182, 182, 182, 182, 182, 182, 191, 200, 220, 210,
200, 210, 210, 210, 210, 191, 182, 182, 174, };
/*
 * Heuristic: returns nonzero if the fft_small code should be used to
 * multiply polynomials of lengths len1 and len2 truncated to length n, and
 * 0 otherwise.
 *
 * The case n == len1 + len2 - 1 is treated as a full product, anything else
 * as a truncated product. Moduli up to FFT_SMALL_MAX_REPACK use fixed length
 * windows; larger moduli use the cutoff tables indexed by NMOD_BITS(mod) - 1.
 * squaring only selects the table used for full products.
 */
int
_nmod_poly_mullow_want_fft_small(slong len1, slong len2, slong n, int squaring, nmod_t mod)
{
    slong prod_len, bits, cutoff_len;
    int full_product;

    /* Normalise so that len1 >= len2. */
    if (len1 < len2)
        FLINT_SWAP(slong, len1, len2);

    prod_len = len1 + len2 - 1;
    full_product = (n == prod_len);

    if (mod.n <= FFT_SMALL_MAX_REPACK)
    {
        if (len2 < 64)
            return 0;

        if (full_product)
        {
            if (prod_len <= 370)
                return 0;
            if (prod_len > 512 && prod_len < 660)
                return 0;
            if (mod.n <= 3 && prod_len > 1024 && prod_len < 1100)
                return 0;
            return 1;
        }

        if (prod_len <= 450)
            return 0;
        return !(prod_len > 512 && prod_len < 700);
    }

    bits = NMOD_BITS(mod);

    if (!full_product)
    {
        cutoff_len = FLINT_MIN(len1, len2);
        return cutoff_len >= fft_mullow_tab[bits - 1];
    }

    cutoff_len = FLINT_MIN(len1, 2 * len2);
    if (squaring)
        return cutoff_len >= fft_sqr_tab[bits - 1];
    return cutoff_len >= fft_mul_tab[bits - 1];
}
#else
/* Build without fft_small: repacking is never performed. Always returns 0
   and leaves z untouched. */
int
_nmod_poly_mulmid_fft_small_repack(nn_ptr z, nn_srcptr a, slong an, nn_srcptr b, slong bn, slong znlo, slong zn, nmod_t mod)
{
return 0;
}
/* Build without fft_small: the FFT multiplication is never recommended.
   Always returns 0. */
int
_nmod_poly_mullow_want_fft_small(slong len1, slong len2, slong n, int squaring, nmod_t mod)
{
return 0;
}
/* Build without fft_small: there is no implementation to fall back on, so
   any call raises a FLINT error via flint_throw(). */
void
_nmod_poly_mulmid_fft_small(nn_ptr z, nn_srcptr a, slong an, nn_srcptr b, slong bn, slong znlo, slong zn, nmod_t mod)
{
flint_throw(FLINT_ERROR, "fft_small is not available");
}
/* Build without fft_small: forwards to _nmod_poly_mulmid_fft_small() with
   lower bound 0, which throws in this configuration. */
void
_nmod_poly_mullow_fft_small(nn_ptr z, nn_srcptr a, slong an, nn_srcptr b, slong bn, slong zn, nmod_t mod)
{
_nmod_poly_mulmid_fft_small(z, a, an, b, bn, 0, zn, mod);
}
#endif