#include "mpn_mod.h"
#include "mpn_mod/impl.h"
/* Reversed dot product of two-limb entries, truncated to three limbs.

   Sums a[k] * b[len - 1 - k] over k = 0, ..., len - 1, where entry k of
   either operand is the limb pair at offsets 2k and 2k + 1, and writes the
   low three limbs of that sum to s[0..2]. Only the low limb of the
   high-by-high product is used; anything of higher weight is dropped. */
void _mpn_dot_rev_2x2_3(nn_ptr s, nn_srcptr a, nn_srcptr b, slong len)
{
    /* low*low products at limbs 0-1, plus the low limb of high*high at limb 2 */
    ulong r2 = 0, r1 = 0, r0 = 0;
    /* limb 1 of the two cross products, shared by both carry chains */
    ulong x1 = 0;
    /* limb 2 of the cross products, one accumulator per product kind */
    ulong x2_hl = 0, x2_lh = 0;
    slong i, j;

    /* i walks a forwards while j walks b backwards */
    for (i = 0, j = len - 1; i < len; i++, j--)
    {
        ulong alo = a[2 * i + 0];
        ulong ahi = a[2 * i + 1];
        ulong blo = b[2 * j + 0];
        ulong bhi = b[2 * j + 1];
        ulong w2, w1, w0;

        /* ahi * blo has weight one limb */
        umul_ppmm(w2, w1, ahi, blo);
        add_ssaaaa(x2_hl, x1, x2_hl, x1, w2, w1);

        /* alo * blo has weight zero; only the low limb of ahi * bhi fits */
        w2 = ahi * bhi;
        umul_ppmm(w1, w0, alo, blo);
        add_sssaaaaaa(r2, r1, r0, r2, r1, r0, w2, w1, w0);

        /* alo * bhi has weight one limb; its carry out of x1 goes to x2_lh */
        umul_ppmm(w2, w1, alo, bhi);
        add_ssaaaa(x2_lh, x1, x2_lh, x1, w2, w1);
    }

    /* Merge the cross-product accumulators into limbs 1 and 2 */
    x2_hl = x2_hl + x2_lh;
    add_ssaaaa(r2, r1, r2, r1, x2_hl, x1);

    s[0] = r0;
    s[1] = r1;
    s[2] = r2;
}
/* Reversed dot product of two-limb entries with a four-limb result.

   Sums a[k] * b[len - 1 - k] over k = 0, ..., len - 1, where entry k of
   either operand is the limb pair at offsets 2k and 2k + 1, and writes
   four limbs of the sum to s[0..3]. Carries out of limb 3 are dropped. */
void _mpn_dot_rev_2x2_4(nn_ptr s, nn_srcptr a, nn_srcptr b, slong len)
{
    /* low*low (limbs 0-1) and high*high (limbs 2-3) products */
    ulong r3 = 0, r2 = 0, r1 = 0, r0 = 0;
    /* limb 1 of the two cross products, shared by both carry chains */
    ulong x1 = 0;
    /* limbs 2-3 of the cross products, one accumulator pair per product kind */
    ulong x3_hl = 0, x2_hl = 0;
    ulong x3_lh = 0, x2_lh = 0;
    slong i, j;

    /* i walks a forwards while j walks b backwards */
    for (i = 0, j = len - 1; i < len; i++, j--)
    {
        ulong alo = a[2 * i + 0];
        ulong ahi = a[2 * i + 1];
        ulong blo = b[2 * j + 0];
        ulong bhi = b[2 * j + 1];
        ulong w3, w2, w1, w0;

        /* ahi * blo has weight one limb */
        umul_ppmm(w2, w1, ahi, blo);
        add_sssaaaaaa(x3_hl, x2_hl, x1, x3_hl, x2_hl, x1, UWORD(0), w2, w1);

        /* ahi * bhi (weight two) and alo * blo (weight zero) do not overlap,
           so they are added as one four-limb value */
        umul_ppmm(w3, w2, ahi, bhi);
        umul_ppmm(w1, w0, alo, blo);
        add_ssssaaaaaaaa(r3, r2, r1, r0, r3, r2, r1, r0, w3, w2, w1, w0);

        /* alo * bhi has weight one limb; carries out of x1 go to the _lh pair */
        umul_ppmm(w2, w1, alo, bhi);
        add_sssaaaaaa(x3_lh, x2_lh, x1, x3_lh, x2_lh, x1, UWORD(0), w2, w1);
    }

    /* Merge the cross-product accumulators into limbs 1..3 */
    add_ssaaaa(x3_hl, x2_hl, x3_hl, x2_hl, x3_lh, x2_lh);
    add_sssaaaaaa(r3, r2, r1, r3, r2, r1, x3_hl, x2_hl, x1);

    s[0] = r0;
    s[1] = r1;
    s[2] = r2;
    s[3] = r3;
}
/* Reversed dot product of two-limb entries with a five-limb result.

   Sums a[k] * b[len - 1 - k] over k = 0, ..., len - 1, where entry k of
   either operand is the limb pair at offsets 2k and 2k + 1, and writes
   five limbs of the sum to s[0..4]. */
void _mpn_dot_rev_2x2_5(nn_ptr s, nn_srcptr a, nn_srcptr b, slong len)
{
    /* r0 takes limb 0 directly; r1..r3 collect the cross products;
       r4 collects carries from the high*high products */
    ulong r4 = 0, r3 = 0, r2 = 0, r1 = 0, r0 = 0;
    /* limb 1 (and its carries) of the low*low products, folded in at the end */
    ulong lo2 = 0, lo1 = 0;
    /* limbs 2-3 of the high*high products, folded in at the end */
    ulong hi3 = 0, hi2 = 0;
    slong i, j;

    /* i walks a forwards while j walks b backwards */
    for (i = 0, j = len - 1; i < len; i++, j--)
    {
        ulong alo = a[2 * i + 0];
        ulong ahi = a[2 * i + 1];
        ulong blo = b[2 * j + 0];
        ulong bhi = b[2 * j + 1];
        ulong w3, w2, w1, w0;

        /* ahi * blo: weight one limb */
        umul_ppmm(w2, w1, ahi, blo);
        add_sssaaaaaa(r3, r2, r1, r3, r2, r1, UWORD(0), w2, w1);

        /* blo * alo: weight zero; the upper part is kept in (lo2, lo1) */
        umul_ppmm(w1, w0, blo, alo);
        add_sssaaaaaa(lo2, lo1, r0, lo2, lo1, r0, UWORD(0), w1, w0);

        /* bhi * alo: weight one limb */
        umul_ppmm(w2, w1, bhi, alo);
        add_sssaaaaaa(r3, r2, r1, r3, r2, r1, UWORD(0), w2, w1);

        /* bhi * ahi: weight two limbs; carries land in r4 */
        umul_ppmm(w3, w2, bhi, ahi);
        add_sssaaaaaa(r4, hi3, hi2, r4, hi3, hi2, UWORD(0), w3, w2);
    }

    /* Fold the side accumulators into the result at their limb positions */
    add_sssaaaaaa(r3, r2, r1, r3, r2, r1, UWORD(0), lo2, lo1);
    add_sssaaaaaa(r4, r3, r2, r4, r3, r2, UWORD(0), hi3, hi2);

    s[0] = r0;
    s[1] = r1;
    s[2] = r2;
    s[3] = r3;
    s[4] = r4;
}
/* Reversed dot product of three-limb entries, truncated to five limbs.
   Sums a[k] * b[len - 1 - k] over k = 0, ..., len - 1, where entry k of
   either operand occupies limbs 3k, 3k + 1 and 3k + 2, and writes the low
   five limbs of the sum to s[0..4]. Partial products are accumulated by
   limb weight; anything that would land in limb 5 or above is dropped. */
void _mpn_dot_rev_3x3_5(nn_ptr s, nn_srcptr a, nn_srcptr b, slong len)
{
ulong A0, A1, A2, B0, B1, B2;
ulong p4, p3, p2, p1, p0;
ulong s4, s3, s2, s1, s0;
/* (u2, u1): limb 1 and carries of the A0 * B0 products, folded in after the loop */
ulong u2, u1;
/* (v3, v2): limbs 2-3 of the weight-two and weight-three products, folded in after the loop */
ulong v3, v2;
slong k;
s4 = s3 = s2 = s1 = s0 = 0;
u2 = u1 = 0;
v3 = v2 = 0;
for (k = 0; k < len; k++)
{
/* a is read forwards, b backwards */
A0 = a[3 * k + 0];
A1 = a[3 * k + 1];
A2 = a[3 * k + 2];
B0 = b[3 * (len - 1 - k) + 0];
B1 = b[3 * (len - 1 - k) + 1];
B2 = b[3 * (len - 1 - k) + 2];
/* Weight one limb: A1 * B0 */
umul_ppmm(p2, p1, A1, B0);
add_sssaaaaaa(s3, s2, s1, s3, s2, s1, UWORD(0), p2, p1);
/* Weight zero: A0 * B0; low limb goes to s0, the rest to (u2, u1) */
umul_ppmm(p1, p0, B0, A0);
add_sssaaaaaa(u2, u1, s0, u2, u1, s0, UWORD(0), p1, p0);
/* Weight one limb: A0 * B1 */
umul_ppmm(p2, p1, B1, A0);
add_sssaaaaaa(s3, s2, s1, s3, s2, s1, UWORD(0), p2, p1);
/* Weight two limbs: A1 * B1, A2 * B0, A0 * B2; carries go into s4 */
umul_ppmm(p3, p2, B1, A1);
add_sssaaaaaa(s4, v3, v2, s4, v3, v2, UWORD(0), p3, p2);
umul_ppmm(p3, p2, A2, B0);
add_sssaaaaaa(s4, v3, v2, s4, v3, v2, UWORD(0), p3, p2);
umul_ppmm(p3, p2, B2, A0);
add_sssaaaaaa(s4, v3, v2, s4, v3, v2, UWORD(0), p3, p2);
/* Weight three limbs: A2 * B1, A1 * B2; carries out of s4 are dropped */
umul_ppmm(p4, p3, A2, B1);
add_ssaaaa(s4, v3, s4, v3, p4, p3);
umul_ppmm(p4, p3, B2, A1);
add_ssaaaa(s4, v3, s4, v3, p4, p3);
/* Weight four limbs: only the low limb of A2 * B2 is kept */
s4 += B2 * A2;
}
/* Fold the side accumulators into the result at their limb positions.
   NOTE(review): a carry out of s3 in the first fold would be lost; this
   presumably cannot happen for the input sizes callers use - confirm. */
add_sssaaaaaa(s3, s2, s1, s3, s2, s1, UWORD(0), u2, u1);
add_sssaaaaaa(s4, s3, s2, s4, s3, s2, UWORD(0), v3, v2);
s[0] = s0;
s[1] = s1;
s[2] = s2;
s[3] = s3;
s[4] = s4;
}
/* Reversed dot product of nlimbs-limb entries with a 2 * nlimbs limb result.

   Sums a[j] * b[len - 1 - j] over j = 0, ..., len - 1, with entry j of
   either operand starting at limb offset j * nlimbs, and leaves the sum in
   res[0 .. 2 * nlimbs - 1]. The carry returned by each addition is
   ignored. The j = 0 term is read unconditionally. */
void
_mpn_dot_rev_nxn_2n(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs)
{
    ulong prod[2 * MPN_MOD_MAX_LIMBS + 3];
    const slong plimbs = 2 * nlimbs;
    slong fwd, rev;

    /* The first product is written straight into the output */
    flint_mpn_mul_n(res, a, b + (len - 1) * nlimbs, nlimbs);

    /* Remaining products go through scratch space and are added in */
    for (fwd = 1, rev = len - 2; fwd < len; fwd++, rev--)
    {
        nn_srcptr x = a + fwd * nlimbs;
        nn_srcptr y = b + rev * nlimbs;

        flint_mpn_mul_n(prod, x, y, nlimbs);
        mpn_add_n(res, res, prod, plimbs);
    }
}
/* Reversed dot product of nlimbs-limb entries, truncated to
   2 * nlimbs - 1 limbs.

   Sums a[j] * b[len - 1 - j] over j = 0, ..., len - 1, with entry j of
   either operand starting at limb offset j * nlimbs, and leaves the low
   2 * nlimbs - 1 limbs of the sum in res. Every product is formed in
   scratch space because a full product is one limb wider than the output
   slot. The j = 0 term is read unconditionally. */
void
_mpn_dot_rev_nxn_2nm1(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs)
{
    ulong prod[2 * MPN_MOD_MAX_LIMBS + 3];
    const slong keep = 2 * nlimbs - 1;
    slong fwd, rev;

    /* First term: multiply into scratch, keep only the low limbs */
    flint_mpn_mul_n(prod, a, b + (len - 1) * nlimbs, nlimbs);
    flint_mpn_copyi(res, prod, keep);

    for (fwd = 1, rev = len - 2; fwd < len; fwd++, rev--)
    {
        nn_srcptr x = a + fwd * nlimbs;
        nn_srcptr y = b + rev * nlimbs;

        flint_mpn_mul_n(prod, x, y, nlimbs);
        mpn_add_n(res, res, prod, keep);
    }
}
/* Reversed dot product of nlimbs-limb entries with a 2 * nlimbs + 1 limb
   result.

   Sums a[j] * b[len - 1 - j] over j = 0, ..., len - 1, with entry j of
   either operand starting at limb offset j * nlimbs. The low 2 * nlimbs
   limbs of res hold the running sum and the extra top limb counts the
   carries out of it. The j = 0 term is read unconditionally. */
void
_mpn_dot_rev_nxn_2np1(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs)
{
    ulong prod[2 * MPN_MOD_MAX_LIMBS + 3];
    const slong body = 2 * nlimbs;      /* limbs of a full product */
    const slong top = 2 * nlimbs;       /* index of the carry limb, slimbs - 1 */
    slong fwd, rev;

    /* The first product is written straight into the output; no carry yet */
    flint_mpn_mul_n(res, a, b + (len - 1) * nlimbs, nlimbs);
    res[top] = 0;

    for (fwd = 1, rev = len - 2; fwd < len; fwd++, rev--)
    {
        nn_srcptr x = a + fwd * nlimbs;
        nn_srcptr y = b + rev * nlimbs;
        ulong carry;

        flint_mpn_mul_n(prod, x, y, nlimbs);
        carry = mpn_add_n(res, res, prod, body);
        res[top] += carry;
    }
}
/* Classical (schoolbook) low product of two polynomials with multi-limb
   coefficients, without modular reduction.

   poly1 and poly2 hold len1 and len2 coefficients of nlimbs limbs each.
   For every i in [0, len), coefficient i of the product is computed as one
   reversed dot product and written to the slimbs-limb slot res + i * slimbs.
   slimbs must be 2 * nlimbs - 1, 2 * nlimbs or 2 * nlimbs + 1; the kernel
   is selected once from (nlimbs, slimbs), outside the coefficient loop. */
static void
_mpn_poly_mullow_classical(nn_ptr res, nn_srcptr poly1, slong len1, nn_srcptr poly2, slong len2, slong len, slong nlimbs, slong slimbs)
{
    slong k, hi1, hi2, terms;
    nn_ptr out;
    nn_srcptr x, y;

    FLINT_ASSERT((slimbs == 2 * nlimbs) || (slimbs == 2 * nlimbs + 1) || (slimbs == 2 * nlimbs - 1));

/* Runs KERNEL_CALL once per output coefficient k. hi1 and hi2 are the
   largest indices of poly1 and poly2 contributing to coefficient k; the
   dot product pairs poly1[k - hi2 .. hi1] with poly2[k - hi1 .. hi2] in
   reverse, which is hi1 + hi2 - k + 1 terms. */
#define MULLOW_EACH_COEFF(KERNEL_CALL)              \
    do {                                            \
        for (k = 0; k < len; k++)                   \
        {                                           \
            hi1 = FLINT_MIN(len1 - 1, k);           \
            hi2 = FLINT_MIN(len2 - 1, k);           \
            out = res + k * slimbs;                 \
            x = poly1 + (k - hi2) * nlimbs;         \
            y = poly2 + (k - hi1) * nlimbs;         \
            terms = hi1 + hi2 - k + 1;              \
            KERNEL_CALL;                            \
        }                                           \
    } while (0)

    if (nlimbs == 2)
    {
        /* Two-limb coefficients: fixed-size kernels for each result width */
        if (slimbs == 3)
            MULLOW_EACH_COEFF(_mpn_dot_rev_2x2_3(out, x, y, terms));
        else if (slimbs == 4)
            MULLOW_EACH_COEFF(_mpn_dot_rev_2x2_4(out, x, y, terms));
        else
            MULLOW_EACH_COEFF(_mpn_dot_rev_2x2_5(out, x, y, terms));
    }
    else if (nlimbs == 3 && slimbs == 5)
    {
        MULLOW_EACH_COEFF(_mpn_dot_rev_3x3_5(out, x, y, terms));
    }
    else if (slimbs == 2 * nlimbs - 1)
    {
        MULLOW_EACH_COEFF(_mpn_dot_rev_nxn_2nm1(out, x, y, terms, nlimbs));
    }
    else if (slimbs == 2 * nlimbs)
    {
        MULLOW_EACH_COEFF(_mpn_dot_rev_nxn_2n(out, x, y, terms, nlimbs));
    }
    else
    {
        MULLOW_EACH_COEFF(_mpn_dot_rev_nxn_2np1(out, x, y, terms, nlimbs));
    }

#undef MULLOW_EACH_COEFF
}
/* Writes the square of the nlimbs-limb integer at x into the slimbs-limb
   slot at rp. When the slot is one limb narrower than a full square the
   product is formed in tmp and truncated; when it is one limb wider the
   extra top limb is zeroed. */
static void
_mpn_poly_sqrlow_put_square(nn_ptr rp, nn_srcptr x, slong nlimbs, slong slimbs, nn_ptr tmp)
{
    if (slimbs == 2 * nlimbs - 1)
    {
        flint_mpn_sqr(tmp, x, nlimbs);
        flint_mpn_copyi(rp, tmp, slimbs);
    }
    else
    {
        flint_mpn_sqr(rp, x, nlimbs);
        if (slimbs != 2 * nlimbs)
            rp[slimbs - 1] = 0;
    }
}

/* Adds the square of the nlimbs-limb integer at x to the slimbs-limb slot
   at rp, using tmp as scratch. For the widest slot the carry out of the
   low 2 * nlimbs limbs is added to the top limb; otherwise it is ignored. */
static void
_mpn_poly_sqrlow_add_square(nn_ptr rp, nn_srcptr x, slong nlimbs, slong slimbs, nn_ptr tmp)
{
    flint_mpn_sqr(tmp, x, nlimbs);

    if (slimbs == 2 * nlimbs - 1 || slimbs == 2 * nlimbs)
        mpn_add_n(rp, rp, tmp, slimbs);
    else
        rp[slimbs - 1] += mpn_add_n(rp, rp, tmp, slimbs - 1);
}

/* Picks the reversed dot-product kernel matching (nlimbs, slimbs) and
   writes the sum of `terms` products into rp. */
static void
_mpn_poly_sqrlow_cross_terms(nn_ptr rp, nn_srcptr x, nn_srcptr y, slong terms, slong nlimbs, slong slimbs)
{
    if (slimbs == 2 * nlimbs - 1)
    {
        if (nlimbs == 2)
            _mpn_dot_rev_2x2_3(rp, x, y, terms);
        else if (nlimbs == 3)
            _mpn_dot_rev_3x3_5(rp, x, y, terms);
        else
            _mpn_dot_rev_nxn_2nm1(rp, x, y, terms, nlimbs);
    }
    else if (slimbs == 2 * nlimbs)
    {
        if (nlimbs == 2)
            _mpn_dot_rev_2x2_4(rp, x, y, terms);
        else
            _mpn_dot_rev_nxn_2n(rp, x, y, terms, nlimbs);
    }
    else
    {
        if (nlimbs == 2)
            _mpn_dot_rev_2x2_5(rp, x, y, terms);
        else
            _mpn_dot_rev_nxn_2np1(rp, x, y, terms, nlimbs);
    }
}

/* Classical low square of a polynomial with multi-limb coefficients,
   without modular reduction.

   poly1 holds len1 coefficients of nlimbs limbs each; the coefficients of
   its square with index below len are written to slimbs-limb slots in res.
   slimbs must be 2 * nlimbs - 1, 2 * nlimbs or 2 * nlimbs + 1. Each middle
   coefficient is formed as twice the sum of its off-diagonal products,
   plus the square of the central coefficient when the index is even. */
static void
_mpn_poly_sqrlow_classical(nn_ptr res, nn_srcptr poly1, slong len1, slong len, slong nlimbs, slong slimbs)
{
    ulong sq[2 * MPN_MOD_MAX_LIMBS + 3];
    const slong last = 2 * len1 - 2;            /* index of the leading coefficient */
    const slong bound = FLINT_MIN(len, last);   /* end of the middle coefficients */
    slong k;

    FLINT_ASSERT((slimbs == 2 * nlimbs) || (slimbs == 2 * nlimbs + 1) || (slimbs == 2 * nlimbs - 1));

    /* Constant term: just the square of coefficient 0 */
    _mpn_poly_sqrlow_put_square(res, poly1, nlimbs, slimbs, sq);

    for (k = 1; k < bound; k++)
    {
        nn_ptr rp = res + k * slimbs;
        /* Off-diagonal pairs (j, k - j) with lo <= j <= hi < k - j */
        slong lo = FLINT_MAX(0, k - len1 + 1);
        slong hi = FLINT_MIN(len1 - 1, (k + 1) / 2 - 1);

        _mpn_poly_sqrlow_cross_terms(rp, poly1 + lo * nlimbs, poly1 + (k - hi) * nlimbs,
                                     hi - lo + 1, nlimbs, slimbs);

        /* Each off-diagonal product occurs twice */
        mpn_lshift(rp, rp, slimbs, 1);

        /* Even index: add the diagonal term */
        if (k % 2 == 0 && k / 2 < len1)
            _mpn_poly_sqrlow_add_square(rp, poly1 + (k / 2) * nlimbs, nlimbs, slimbs, sq);
    }

    /* Leading coefficient: the square of the last input coefficient */
    if (len1 >= 2 && len >= 2 * len1 - 1)
        _mpn_poly_sqrlow_put_square(res + last * slimbs, poly1 + (len1 - 1) * nlimbs,
                                    nlimbs, slimbs, sq);
}
/* Adds two polynomials with nlimbs-limb coefficients into res, keeping
   nlimbs limbs per coefficient.

   The overlapping coefficients are added with a single mpn_add_n over the
   whole limb range, so a carry out of one coefficient would run into the
   next; the trailing coefficients of the longer input are copied. The
   carry returned by the addition is ignored. */
FLINT_FORCE_INLINE void
_mpn_poly_add_n(nn_ptr res, nn_srcptr f, slong flen, nn_srcptr g, slong glen, slong nlimbs)
{
    const slong common = FLINT_MIN(flen, glen);
    const slong done = nlimbs * common;     /* limbs handled by the addition */

    mpn_add_n(res, f, g, done);

    /* At most one input has coefficients beyond the common part */
    if (flen > common)
        flint_mpn_copyi(res + done, f + done, nlimbs * (flen - common));
    else if (glen > common)
        flint_mpn_copyi(res + done, g + done, nlimbs * (glen - common));
}
/* Adds two polynomials with nlimbs-limb coefficients into res, widening
   every output coefficient to nlimbs + 1 limbs.

   For the overlapping coefficients the extra top limb receives the carry
   of the addition; coefficients present in only one input are copied and
   their top limb is set to zero. */
FLINT_FORCE_INLINE void
_mpn_poly_add_n_carry(nn_ptr res, nn_srcptr f, slong flen, nn_srcptr g, slong glen, slong nlimbs)
{
    const slong common = FLINT_MIN(flen, glen);
    const slong total = FLINT_MAX(flen, glen);
    const slong wide = nlimbs + 1;                  /* limbs per output coefficient */
    nn_srcptr longer = (flen > glen) ? f : g;       /* supplies the tail, if any */
    slong k;

    if (nlimbs == 2)
    {
        /* Two-limb inputs: three-limb add with zero top limbs */
        for (k = 0; k < common; k++)
        {
            nn_ptr r = res + k * wide;
            nn_srcptr x = f + k * nlimbs;
            nn_srcptr y = g + k * nlimbs;

            add_sssaaaaaa(r[2], r[1], r[0],
                          0, x[1], x[0],
                          0, y[1], y[0]);
        }
    }
    else
    {
        for (k = 0; k < common; k++)
            res[(k + 1) * wide - 1] = mpn_add_n(res + k * wide, f + k * nlimbs, g + k * nlimbs, nlimbs);
    }

    /* Coefficients beyond the shorter input: copy and clear the top limb */
    for (k = common; k < total; k++)
    {
        flint_mpn_copyi(res + k * wide, longer + k * nlimbs, nlimbs);
        res[(k + 1) * wide - 1] = 0;
    }
}
/* Recursive Karatsuba product of two polynomials with multi-limb
   coefficients, without modular reduction.

   f and g hold flen and glen coefficients of nlimbs limbs each; the
   flen + glen - 1 product coefficients are written to slimbs-limb slots in
   res. When the shorter input has fewer than max(cutoff, 2) coefficients
   the classical routines are used. A call with f == g and flen == glen is
   treated as a squaring.

   norm selects how the half-sums f0 + f1 and g0 + g1 are formed: when it
   is 0 each sum coefficient is widened by one carry limb; otherwise the
   sums keep nlimbs limbs and norm - 1 is passed down. */
static void
_mpn_poly_mul_karatsuba(nn_ptr res, nn_srcptr f, slong flen, nn_srcptr g, slong glen, slong nlimbs, slong slimbs, slong cutoff, int norm)
{
    const int squaring = (f == g) && (flen == glen);
    const slong minlen = FLINT_MIN(flen, glen);
    slong split, fhi_len, ghi_len, fsum_len, gsum_len, mid_len, alloc;
    slong sub_nlimbs;
    int sub_norm;
    nn_ptr fsum, gsum, mid;
    nn_srcptr fhi, ghi;
    TMP_INIT;

    /* Base case */
    if (minlen < FLINT_MAX(cutoff, 2))
    {
        if (squaring)
            _mpn_poly_sqrlow_classical(res, f, flen, flen + glen - 1, nlimbs, slimbs);
        else
            _mpn_poly_mullow_classical(res, f, flen, g, glen, flen + glen - 1, nlimbs, slimbs);
        return;
    }

    /* Split both inputs at the same point: f = f0 + x^split * f1, etc. */
    split = (minlen + 1) / 2;
    fhi = f + split * nlimbs;
    ghi = g + split * nlimbs;
    fhi_len = flen - split;
    ghi_len = glen - split;

    /* Low and high half-products go directly into res, separated by one
       zeroed coefficient. Both recursive calls are made before any
       temporary space is taken in this frame. */
    _mpn_poly_mul_karatsuba(res, f, split, g, split, nlimbs, slimbs, cutoff, norm);
    flint_mpn_zero(res + (2 * split - 1) * slimbs, slimbs);
    _mpn_poly_mul_karatsuba(res + (2 * split) * slimbs, fhi, fhi_len, ghi, ghi_len, nlimbs, slimbs, cutoff, norm);

    fsum_len = FLINT_MAX(split, fhi_len);
    gsum_len = FLINT_MAX(split, ghi_len);
    mid_len = fsum_len + gsum_len - 1;

    /* Room for both half-sums at their widest (nlimbs + 1) and the product */
    alloc = fsum_len * (nlimbs + 1) + gsum_len * (nlimbs + 1) + mid_len * slimbs;

    TMP_START;
    fsum = TMP_ALLOC(alloc * sizeof(ulong));
    gsum = fsum + fsum_len * (nlimbs + 1);
    mid = gsum + gsum_len * (nlimbs + 1);

    if (norm == 0)
    {
        /* Widen each sum coefficient by a carry limb. The recursion gets
           norm = -1, which stays nonzero as it is decremented, so deeper
           levels add without widening again. */
        _mpn_poly_add_n_carry(fsum, f, split, fhi, fhi_len, nlimbs);
        if (!squaring)
            _mpn_poly_add_n_carry(gsum, g, split, ghi, ghi_len, nlimbs);
        sub_nlimbs = nlimbs + 1;
        sub_norm = -1;
    }
    else
    {
        _mpn_poly_add_n(fsum, f, split, fhi, fhi_len, nlimbs);
        if (!squaring)
            _mpn_poly_add_n(gsum, g, split, ghi, ghi_len, nlimbs);
        sub_nlimbs = nlimbs;
        sub_norm = norm - 1;
    }

    /* mid = (f0 + f1) * (g0 + g1); pass the same operand twice when
       squaring so the recursion detects it as a squaring too */
    if (squaring)
        _mpn_poly_mul_karatsuba(mid, fsum, fsum_len, fsum, fsum_len, sub_nlimbs, slimbs, cutoff, sub_norm);
    else
        _mpn_poly_mul_karatsuba(mid, fsum, fsum_len, gsum, gsum_len, sub_nlimbs, slimbs, cutoff, sub_norm);

    /* mid -= f0 * g0 and mid -= f1 * g1, then add it in at offset split.
       These run over whole coefficient ranges as single integers. */
    mpn_sub_n(mid, mid, res, (2 * split - 1) * slimbs);
    mpn_sub_n(mid, mid, res + 2 * split * slimbs, (fhi_len + ghi_len - 1) * slimbs);
    mpn_add_n(res + split * slimbs, res + split * slimbs, mid, mid_len * slimbs);

    TMP_END;
}
int
_mpn_mod_poly_mulmid_karatsuba(nn_ptr res, nn_srcptr poly1, slong len1, nn_srcptr poly2, slong len2, slong nlo, slong nhi, slong cutoff, gr_ctx_t ctx)
{
nn_ptr t;
slong i, l;
slong nlimbs, slimbs;
flint_bitcnt_t sbits;
int norm;
TMP_INIT;
TMP_START;
norm = MPN_MOD_CTX_NORM(ctx);
len1 = FLINT_MIN(len1, nhi);
len2 = FLINT_MIN(len2, nhi);
nlimbs = MPN_MOD_CTX_NLIMBS(ctx);
sbits = MPN_MOD_CTX_MODULUS_BITS(ctx);
if (nlo != 0)
{
slong nlo2 = (len1 + len2 - 1) - nlo;
if (len1 > nlo2)
{
slong trunc = len1 - nlo2;
poly1 += trunc * nlimbs;
len1 -= trunc;
nlo -= trunc;
nhi -= trunc;
}
if (len2 > nlo2)
{
slong trunc = len2 - nlo2;
poly2 += trunc * nlimbs;
len2 -= trunc;
nlo -= trunc;
nhi -= trunc;
}
}
if (cutoff == -1)
{
if (poly1 == poly2 && len1 == len2)
{
if (sbits <= 128)
cutoff = 32;
else if (sbits <= 154)
cutoff = 24;
else if (sbits <= 448)
cutoff = 12;
else if (sbits <= 600)
cutoff = 6;
else
cutoff = 4;
}
else
{
if (sbits <= 155)
cutoff = 12;
else if (sbits <= 320)
cutoff = 8;
else if (sbits <= 560)
cutoff = 6;
else
cutoff = 4;
}
cutoff = FLINT_MIN(cutoff, len1);
cutoff = FLINT_MIN(cutoff, len2);
}
sbits = 2 * sbits + 2 * FLINT_BIT_COUNT(FLINT_MIN(len1, len2));
slimbs = (sbits + FLINT_BITS - 1) / FLINT_BITS;
t = TMP_ALLOC(sizeof(ulong) * slimbs * (len1 + len2 - 1));
_mpn_poly_mul_karatsuba(t, poly1, len1, poly2, len2, nlimbs, slimbs, cutoff, norm);
for (i = nlo; i < nhi; i++)
{
l = slimbs;
MPN_NORM(t + i * slimbs, l);
mpn_mod_set_mpn(res + (i - nlo) * nlimbs, t + i * slimbs, l, ctx);
}
TMP_END;
return GR_SUCCESS;
}
/* Low product by Karatsuba multiplication: the first len coefficients of
   poly1 * poly2. This is the middle product with a lower bound of zero;
   the return value is that of _mpn_mod_poly_mulmid_karatsuba. */
int
_mpn_mod_poly_mullow_karatsuba(nn_ptr res, nn_srcptr poly1, slong len1, nn_srcptr poly2, slong len2, slong len, slong cutoff, gr_ctx_t ctx)
{
    const slong from_constant_term = 0;

    return _mpn_mod_poly_mulmid_karatsuba(res, poly1, len1, poly2, len2,
                                          from_constant_term, len, cutoff, ctx);
}