#include "mpn_extras.h"
/*
    Writes to the n limbs at r a remainder of the product of the n-limb
    operands a and b with respect to the n-limb modulus d, using dinv to
    estimate the quotient.

    Outline of what the code does:
      1. form the 2n-limb product (the product is shifted right by norm bits
         in the generic case; for n == 2 the operand b is shifted right by
         norm bits before multiplying);
      2. multiply the high half of the product by dinv, keep the high half of
         that result and add the high half of the product to it: this is the
         quotient estimate;
      3. subtract (quotient estimate * d) from the low part of the product;
      4. subtract d while a borrow-tracking limb is nonzero, then once more
         if the result is still >= d.

    NOTE(review): presumably d is normalised (shifted left by norm bits),
    dinv is its precomputed inverse, and a, b are already reduced and stored
    with the same shift -- confirm against the library documentation.
*/
void flint_mpn_mulmod_preinvn(mp_ptr r,
        mp_srcptr a, mp_srcptr b, mp_size_t n,
        mp_srcptr d, mp_srcptr dinv, ulong norm)
{
    mp_limb_t cy;

    if (n != 2)
    {
        /* Generic length: work in a 5n-limb scratch area. */
        mp_ptr prod, prod_hi, qd, qhat;
        mp_limb_t borrow;
        TMP_INIT;

        TMP_START;
        prod = TMP_ALLOC(5*n*sizeof(mp_limb_t));
        prod_hi = prod + n;     /* high half of the 2n-limb product */
        qd = prod + 2*n;        /* receives (quotient estimate) * d */
        qhat = prod + 4*n;      /* quotient estimate, n limbs */

        /* A dedicated squaring routine is used when both operands coincide. */
        if (a != b)
            flint_mpn_mul_n(prod, a, b, n);
        else
            flint_mpn_sqr(prod, a, n);

        if (norm != 0)
            mpn_rshift(prod, prod, 2*n, norm);

        /* Only the n limbs at prod + 4n of this result are used afterwards. */
        flint_mpn_mul_or_mulhigh_n(prod + 3*n, prod_hi, dinv, n);
        mpn_add_n(qhat, qhat, prod_hi, n);

        /* NOTE(review): qd[n] is read below, so this call is presumably
           expected to fill at least n + 1 limbs -- confirm. */
        flint_mpn_mul_or_mullow_n(qd, qhat, d, n);

        /* Low n limbs of the difference go to r; cy tracks limb n. */
        borrow = mpn_sub_n(r, prod, qd, n);
        cy = prod_hi[0] - qd[n] - borrow;

        while (cy != 0)
            cy -= mpn_sub_n(r, r, d, n);

        if (mpn_cmp(r, d, n) >= 0)
            mpn_sub_n(r, r, d, n);

        FLINT_ASSERT(mpn_cmp(r, d, n) < 0);

        TMP_END;
    }
    else
    {
        /* Two-limb case: same steps, carried out with fixed-size macros. */
        mp_limb_t prod[4];  /* a * (b >> norm) */
        mp_limb_t q[4];     /* q[3], q[2] hold the quotient estimate */
        mp_limb_t qd[3];    /* (quotient estimate) * d, three limbs */
        mp_limb_t bs_lo, bs_hi;

        bs_lo = b[0];
        bs_hi = b[1];
        if (norm != 0)
        {
            /* Two-limb right shift of b; the shift by FLINT_BITS - norm is
               only evaluated when norm is nonzero. */
            bs_lo = (bs_lo >> norm) | (bs_hi << (FLINT_BITS - norm));
            bs_hi >>= norm;
        }

        FLINT_MPN_MUL_2X2(prod[3], prod[2], prod[1], prod[0],
                          a[1], a[0], bs_hi, bs_lo);

        FLINT_MPN_MUL_2X2(q[3], q[2], q[1], q[0],
                          prod[3], prod[2], dinv[1], dinv[0]);
        add_ssaaaa(q[3], q[2], q[3], q[2], prod[3], prod[2]);

        FLINT_MPN_MUL_3P2X2(qd[2], qd[1], qd[0], q[3], q[2], d[1], d[0]);

        sub_dddmmmsss(cy, r[1], r[0],
                      prod[2], prod[1], prod[0], qd[2], qd[1], qd[0]);

        while (cy != 0)
        {
            sub_dddmmmsss(cy, r[1], r[0], cy, r[1], r[0], 0, d[1], d[0]);
        }

        /* Final conditional subtraction when (r[1], r[0]) >= (d[1], d[0]). */
        if (r[1] > d[1] || (r[1] == d[1] && r[0] >= d[0]))
        {
            sub_ddmmss(r[1], r[0], r[1], r[0], d[1], d[0]);
        }
    }
}
/*
    Two-limb modular multiplication using dinv to estimate the quotient.

    b is shifted left by norm bits before multiplying by a, the product is
    reduced with respect to the two-limb d using the quotient estimate
    derived from dinv, and the two-limb remainder is shifted right by norm
    bits before being stored to r[0], r[1].

    NOTE(review): presumably d is the modulus already shifted left by norm
    bits, dinv is its precomputed inverse, and a, b are reduced and stored
    unshifted -- confirm against the library documentation.
*/
void flint_mpn_mulmod_preinvn_2(mp_ptr r,
        mp_srcptr a, mp_srcptr b,
        mp_srcptr d, mp_srcptr dinv, ulong norm)
{
    mp_limb_t prod[4];  /* a * (b << norm) */
    mp_limb_t q[4];     /* q[3], q[2] hold the quotient estimate */
    mp_limb_t qd[3];    /* (quotient estimate) * d, three limbs */
    mp_limb_t bs_lo, bs_hi, rem_lo, rem_hi, borrow;

    bs_lo = b[0];
    bs_hi = b[1];
    if (norm != 0)
    {
        /* Two-limb left shift of b; the high limb must be formed before the
           low limb is overwritten. */
        bs_hi = (bs_hi << norm) | (bs_lo >> (FLINT_BITS - norm));
        bs_lo <<= norm;
    }

    FLINT_MPN_MUL_2X2(prod[3], prod[2], prod[1], prod[0],
                      a[1], a[0], bs_hi, bs_lo);

    /* Multiply the high half of the product by dinv, then add the high half
       of the product to the upper two limbs of that result. */
    FLINT_MPN_MUL_2X2(q[3], q[2], q[1], q[0],
                      prod[3], prod[2], dinv[1], dinv[0]);
    add_ssaaaa(q[3], q[2], q[3], q[2], prod[3], prod[2]);

    FLINT_MPN_MUL_3P2X2(qd[2], qd[1], qd[0], q[3], q[2], d[1], d[0]);

    sub_dddmmmsss(borrow, rem_hi, rem_lo,
                  prod[2], prod[1], prod[0], qd[2], qd[1], qd[0]);

    /* Keep subtracting d while the third limb of the difference is nonzero. */
    while (borrow != 0)
    {
        sub_dddmmmsss(borrow, rem_hi, rem_lo,
                      borrow, rem_hi, rem_lo, 0, d[1], d[0]);
    }

    /* One more subtraction if the two-limb value is still >= d. */
    if (rem_hi > d[1] || (rem_hi == d[1] && rem_lo >= d[0]))
    {
        sub_ddmmss(rem_hi, rem_lo, rem_hi, rem_lo, d[1], d[0]);
    }

    if (norm != 0)
    {
        /* Undo the shift applied to b at the start. */
        r[0] = (rem_lo >> norm) | (rem_hi << (FLINT_BITS - norm));
        r[1] = rem_hi >> norm;
    }
    else
    {
        r[0] = rem_lo;
        r[1] = rem_hi;
    }
}
/*
    Fused multiply-multiply-add with reduction: forms a*b + e*f from n-limb
    operands, reduces the sum with respect to the n-limb d using dinv to
    estimate the quotient, and stores n limbs to r.

    The 2n-limb sum is shifted left by norm bits before the reduction and
    the n-limb remainder is shifted right by norm bits at the end.

    NOTE(review): presumably d is the modulus already shifted left by norm
    bits, dinv is its precomputed inverse, and a, b, e, f are reduced and
    stored unshifted -- confirm against the library documentation.
*/
void flint_mpn_fmmamod_preinvn(mp_ptr r,
        mp_srcptr a, mp_srcptr b,
        mp_srcptr e, mp_srcptr f,
        mp_size_t n,
        mp_srcptr d, mp_srcptr dinv, ulong norm)
{
    mp_ptr sum, sum_hi, qd, qhat, ef;
    mp_limb_t borrow;
    ulong cy;
    TMP_INIT;

    TMP_START;
    sum = TMP_ALLOC((7 * n) * sizeof(mp_limb_t));
    sum_hi = sum + n;       /* high half of the 2n-limb sum */
    qd = sum + 2 * n;       /* receives (quotient estimate) * d */
    qhat = sum + 4 * n;     /* quotient estimate, n limbs */
    ef = sum + 5 * n;       /* second product, 2n limbs */

    /* A dedicated squaring routine is used when both operands coincide. */
    if (a != b)
        flint_mpn_mul_n(sum, a, b, n);
    else
        flint_mpn_sqr(sum, a, n);

    if (e != f)
        flint_mpn_mul_n(ef, e, f, n);
    else
        flint_mpn_sqr(ef, e, n);

    /* cy ends up as the carry out of the addition when norm is zero, and as
       the bits shifted out by mpn_lshift otherwise. */
    cy = mpn_add_n(sum, sum, ef, 2 * n);
    if (norm != 0)
        cy = mpn_lshift(sum, sum, 2 * n, norm);

    /* Bring the high half below d before estimating the quotient. */
    if (cy != 0 || mpn_cmp(sum_hi, d, n) >= 0)
        mpn_sub_n(sum_hi, sum_hi, d, n);

    /* Only the n limbs at sum + 4n of this result are used afterwards. */
    flint_mpn_mul_or_mulhigh_n(sum + 3 * n, sum_hi, dinv, n);
    mpn_add_n(qhat, qhat, sum_hi, n);

    /* NOTE(review): qd[n] is read below, so this call is presumably
       expected to fill at least n + 1 limbs -- confirm. */
    flint_mpn_mul_or_mullow_n(qd, qhat, d, n);

    /* Low n limbs of the difference go to r; cy tracks limb n. */
    borrow = mpn_sub_n(r, sum, qd, n);
    cy = sum_hi[0] - qd[n] - borrow;

    while (cy != 0)
        cy -= mpn_sub_n(r, r, d, n);

    if (mpn_cmp(r, d, n) >= 0)
        mpn_sub_n(r, r, d, n);

    FLINT_ASSERT(mpn_cmp(r, d, n) < 0);

    /* Undo the shift applied to the sum. */
    if (norm != 0)
        mpn_rshift(r, r, n, norm);

    TMP_END;
}
/*
    Two-limb fused multiply-multiply-add with reduction: forms
    a*(b << norm) + e*(f << norm), reduces the sum with respect to the
    two-limb d using dinv to estimate the quotient, shifts the two-limb
    remainder right by norm bits and stores it to r[0], r[1].

    NOTE(review): presumably d is the modulus already shifted left by norm
    bits, dinv is its precomputed inverse, and a, b, e, f are reduced and
    stored unshifted -- confirm against the library documentation.
*/
void flint_mpn_fmmamod_preinvn_2(mp_ptr r,
        mp_srcptr a, mp_srcptr b,
        mp_srcptr e, mp_srcptr f,
        mp_srcptr d, mp_srcptr dinv, ulong norm)
{
    mp_limb_t sum[4];   /* first product, then the sum of both products */
    mp_limb_t ef[4];    /* second product */
    mp_limb_t q[4];     /* q[3], q[2] hold the quotient estimate */
    mp_limb_t qd[3];    /* (quotient estimate) * d, three limbs */
    mp_limb_t bs_lo, bs_hi, fs_lo, fs_hi;
    mp_limb_t rem_lo, rem_hi, cy;

    bs_lo = b[0];
    bs_hi = b[1];
    fs_lo = f[0];
    fs_hi = f[1];
    if (norm != 0)
    {
        /* Two-limb left shifts; each high limb must be formed before the
           corresponding low limb is overwritten. */
        bs_hi = (bs_hi << norm) | (bs_lo >> (FLINT_BITS - norm));
        bs_lo <<= norm;
        fs_hi = (fs_hi << norm) | (fs_lo >> (FLINT_BITS - norm));
        fs_lo <<= norm;
    }

    FLINT_MPN_MUL_2X2(sum[3], sum[2], sum[1], sum[0],
                      a[1], a[0], bs_hi, bs_lo);
    FLINT_MPN_MUL_2X2(ef[3], ef[2], ef[1], ef[0],
                      e[1], e[0], fs_hi, fs_lo);

    /* Four-limb addition; cy receives the fifth limb of the result. */
    add_sssssaaaaaaaaaa(cy, sum[3], sum[2], sum[1], sum[0],
                        0, sum[3], sum[2], sum[1], sum[0],
                        0, ef[3], ef[2], ef[1], ef[0]);

    /* Bring the high half below d before estimating the quotient. */
    if (cy != 0 || mpn_cmp(sum + 2, d, 2) >= 0)
        sub_ddmmss(sum[3], sum[2], sum[3], sum[2], d[1], d[0]);

    /* Multiply the high half of the sum by dinv, then add the high half of
       the sum to the upper two limbs of that result. */
    FLINT_MPN_MUL_2X2(q[3], q[2], q[1], q[0],
                      sum[3], sum[2], dinv[1], dinv[0]);
    add_ssaaaa(q[3], q[2], q[3], q[2], sum[3], sum[2]);

    FLINT_MPN_MUL_3P2X2(qd[2], qd[1], qd[0], q[3], q[2], d[1], d[0]);

    sub_dddmmmsss(cy, rem_hi, rem_lo,
                  sum[2], sum[1], sum[0], qd[2], qd[1], qd[0]);

    /* Keep subtracting d while the third limb of the difference is nonzero. */
    while (cy != 0)
    {
        sub_dddmmmsss(cy, rem_hi, rem_lo, cy, rem_hi, rem_lo, 0, d[1], d[0]);
    }

    /* One more subtraction if the two-limb value is still >= d. */
    if (rem_hi > d[1] || (rem_hi == d[1] && rem_lo >= d[0]))
    {
        sub_ddmmss(rem_hi, rem_lo, rem_hi, rem_lo, d[1], d[0]);
    }

    if (norm != 0)
    {
        /* Undo the shift applied to b and f at the start. */
        r[0] = (rem_lo >> norm) | (rem_hi << (FLINT_BITS - norm));
        r[1] = rem_hi >> norm;
    }
    else
    {
        r[0] = rem_lo;
        r[1] = rem_hi;
    }
}