#include "mpn_extras.h"
/*
 * FLINT_ASSERT_NOCARRY(x): evaluate x, whose result is expected to be zero
 * (no carry or borrow out).  In assert-enabled builds the zero result is
 * checked via FLINT_ASSERT; otherwise x is just evaluated for its effect.
 */
#if !FLINT_WANT_ASSERT
# define FLINT_ASSERT_NOCARRY(x) x
#else
# define FLINT_ASSERT_NOCARRY(x) FLINT_ASSERT((x) == UWORD(0))
#endif

/*
 * TOOM32_MUL_N_REC(rp, xp, yp, len, tmp): multiply the len-limb operands xp
 * and yp into rp using flint_mpn_mul_n.  The workspace argument tmp is
 * accepted to mirror the recursive-call shape but is ignored.
 */
#define TOOM32_MUL_N_REC(rp, xp, yp, len, tmp) \
    do { \
        flint_mpn_mul_n((rp), (xp), (yp), (len)); \
    } while (0)
/*
 * flint_mpn_mul_toom32 -- multiply {ap, an} by {bp, bn} with the Toom-3/2
 * evaluation/interpolation scheme, writing the an + bn limb product to pp.
 *
 * The code appears to follow GMP's mpn_toom32_mul (same evaluation points
 * and the same interpolation sequence).
 *
 * With B the limb base, the operands are split into pieces of n limbs:
 *   a = a0 + a1 B^n + a2 B^(2n)    (a0, a1: n limbs; a2: s limbs)
 *   b = b0 + b1 B^n                (b0: n limbs;     b1: t limbs)
 * The product polynomial x0 + x1 X + x2 X^2 + x3 X^3 is evaluated at
 * X = 0 (a0*b0), +1, -1 and infinity (a2*b1), then interpolated.
 *
 * Parameters:
 *   pp      -- output, an + bn = 3n + s + t limbs.  pp is also used to hold
 *              the evaluated operands, and a0, a2, b0, b1 are read after pp
 *              has been written, so pp must not overlap {ap, an} or {bp, bn}.
 *   ap, an  -- first operand.
 *   bp, bn  -- second operand.  Requires bn + 2 <= an && an + 6 <= 3*bn
 *              (checked only by FLINT_ASSERT).
 *   scratch -- temporary area; scratch[0 .. 2n] (2n + 1 limbs) holds v1.
 *              The recursive products ignore their workspace argument.
 */
void
flint_mpn_mul_toom32(mp_ptr pp,
mp_srcptr ap, mp_size_t an,
mp_srcptr bp, mp_size_t bn,
mp_ptr scratch)
{
    mp_size_t n, s, t;
    int vm1_neg;               /* nonzero when vm1 holds -(value at -1) */
    mp_limb_t cy;
    mp_limb_signed_t hi;
    mp_limb_t ap1_hi, bp1_hi;  /* limbs above the low n limbs of ap1, bp1 */

    /* Operand pieces (n is assigned below, before any use). */
#define a0 ap
#define a1 (ap + n)
#define a2 (ap + 2 * n)
#define b0 bp
#define b1 (bp + n)

    /* Required so that s + t >= n. */
    FLINT_ASSERT(bn + 2 <= an && an + 6 <= 3*bn);

    n = 2 * an >= 3 * bn ? (an + 2) / (size_t) 3 : ((ulong) bn + 1) >> 1;

    s = an - 2 * n;
    t = bn - n;

    FLINT_ASSERT(0 < s && s <= n);
    FLINT_ASSERT(0 < t && t <= n);
    FLINT_ASSERT(s + t >= n);

    /* Storage for evaluated operands and point values. */
#define ap1 (pp)                        /* n limbs, high part in ap1_hi */
#define bp1 (pp + n)                    /* n limbs, high part in bp1_hi */
#define am1 (pp + 2*n)                  /* n limbs, high part in hi */
#define bm1 (pp + 3*n)                  /* n limbs */
#define v1 (scratch)                    /* 2n + 1 limbs */
#define vm1 (pp)                        /* 2n + 1 limbs; vm1[2n] aliases am1[0] */
#define scratch_out (scratch + 2*n + 1) /* workspace argument, currently ignored */

    /* ap1 = a0 + a1 + a2, am1 = |a0 - a1 + a2|.  Since a0 + a2 < 2 B^n and
       a1 < B^n, ap1_hi is at most 2.  The difference is negated when it would
       be negative, which is recorded in vm1_neg. */
    ap1_hi = mpn_add(ap1, a0, n, a2, s);
#if FLINT_HAVE_NATIVE_mpn_add_n_sub_n
    /* mpn_add_n_sub_n's return value is used here as: bit 1 = carry of the
       sum, bit 0 = borrow of the difference. */
    if (ap1_hi == 0 && mpn_cmp(ap1, a1, n) < 0)
    {
        ap1_hi = mpn_add_n_sub_n(ap1, am1, a1, ap1, n) >> 1;
        hi = 0;
        vm1_neg = 1;
    }
    else
    {
        cy = mpn_add_n_sub_n(ap1, am1, ap1, a1, n);
        hi = ap1_hi - (cy & 1);
        ap1_hi += (cy >> 1);
        vm1_neg = 0;
    }
#else
    if (ap1_hi == 0 && mpn_cmp(ap1, a1, n) < 0)
    {
        FLINT_ASSERT_NOCARRY(mpn_sub_n(am1, a1, ap1, n));
        hi = 0;
        vm1_neg = 1;
    }
    else
    {
        hi = ap1_hi - mpn_sub_n(am1, ap1, a1, n);
        vm1_neg = 0;
    }
    ap1_hi += mpn_add_n(ap1, ap1, a1, n);
#endif

    /* bp1 = b0 + b1 (bp1_hi is at most 1), bm1 = |b0 - b1|.  A negative
       difference flips the sign flag of the value at -1. */
    if (t == n)
    {
#if FLINT_HAVE_NATIVE_mpn_add_n_sub_n
        if (mpn_cmp(b0, b1, n) < 0)
        {
            cy = mpn_add_n_sub_n(bp1, bm1, b1, b0, n);
            vm1_neg ^= 1;
        }
        else
        {
            cy = mpn_add_n_sub_n(bp1, bm1, b0, b1, n);
        }
        bp1_hi = cy >> 1;
#else
        bp1_hi = mpn_add_n(bp1, b0, b1, n);

        if (mpn_cmp(b0, b1, n) < 0)
        {
            FLINT_ASSERT_NOCARRY(mpn_sub_n(bm1, b1, b0, n));
            vm1_neg ^= 1;
        }
        else
        {
            FLINT_ASSERT_NOCARRY(mpn_sub_n(bm1, b0, b1, n));
        }
#endif
    }
    else
    {
        bp1_hi = mpn_add(bp1, b0, n, b1, t);

        /* b1 has only t < n limbs, so b0 < b1 is possible only when the top
           n - t limbs of b0 are all zero. */
        if (mpn_zero_p(b0 + t, n - t) && mpn_cmp(b0, b1, t) < 0)
        {
            FLINT_ASSERT_NOCARRY(mpn_sub_n(bm1, b1, b0, t));
            flint_mpn_zero(bm1 + t, n - t);
            vm1_neg ^= 1;
        }
        else
        {
            FLINT_ASSERT_NOCARRY(mpn_sub(bm1, b0, n, b1, t));
        }
    }

    /* v1 = ap1 * bp1, 2n + 1 limbs.  The n x n product handles the low
       parts.  The cross terms ap1_hi * bp1 and bp1_hi * ap1 are added at
       position n, and ap1_hi * bp1_hi goes into the top limb. */
    TOOM32_MUL_N_REC(v1, ap1, bp1, n, scratch_out);
    if (ap1_hi == 1)
    {
        cy = mpn_add_n(v1 + n, v1 + n, bp1, n);
    }
    else if (ap1_hi > 1)  /* ap1_hi == 2 */
    {
#if FLINT_HAVE_NATIVE_mpn_addlsh1_n_ip1
        cy = mpn_addlsh1_n_ip1(v1 + n, bp1, n);
#else
        cy = mpn_addmul_1(v1 + n, bp1, n, UWORD(2));
#endif
    }
    else
        cy = 0;
    if (bp1_hi != 0)
        cy += ap1_hi + mpn_add_n(v1 + n, v1 + n, ap1, n);
    v1[2 * n] = cy;

    /* vm1 = am1 * bm1, 2n + 1 limbs.  A nonzero high part hi of am1
       contributes bm1 * B^n.  This overwrites ap1, bp1 and am1[0], which
       are no longer needed. */
    TOOM32_MUL_N_REC(vm1, am1, bm1, n, scratch_out);
    if (hi)
        hi = mpn_add_n(vm1+n, vm1+n, bm1, n);

    vm1[2*n] = hi;

    /* v1 <- (v1 + p(-1)) / 2 = x0 + x2, where p(-1) = x0 - x1 + x2 - x3 is
       +vm1 or -vm1 according to vm1_neg.  The sum is even, so the shift
       discards no bits. */
    if (vm1_neg)
    {
#if FLINT_HAVE_NATIVE_mpn_rsh1sub_n
        mpn_rsh1sub_n(v1, v1, vm1, 2*n+1);
#else
        mpn_sub_n(v1, v1, vm1, 2*n+1);
        FLINT_ASSERT_NOCARRY(mpn_rshift(v1, v1, 2*n+1, 1));
#endif
    }
    else
    {
#if FLINT_HAVE_NATIVE_mpn_rsh1add_n
        mpn_rsh1add_n(v1, v1, vm1, 2*n+1);
#else
        mpn_add_n(v1, v1, vm1, 2*n+1);
        FLINT_ASSERT_NOCARRY(mpn_rshift(v1, v1, 2*n+1, 1));
#endif
    }

    /* Form y = (x1 + x3) + (x0 + x2) B^n = (x0 + x2) B^n + (x0 + x2) - p(-1),
       3n + 1 limbs.  Its pieces are stored as y0 at scratch, y1 at pp + 2n
       and y2 (n + 1 limbs) at scratch + n.  vm1[2n] lives at pp + 2n, so it
       is read into hi before pp + 2n is overwritten.  The middle sum comes
       first because y0 shares storage with the low half of x0 + x2. */
    hi = vm1[2*n];
    cy = mpn_add_n(pp + 2*n, v1, v1 + n, n);
    MPN_INCR_U(v1 + n, n + 1, cy + v1[2*n]);

    if (vm1_neg)
    {
        cy = mpn_add_n(v1, v1, vm1, n);
        hi += mpn_add_nc(pp + 2*n, pp + 2*n, vm1 + n, n, cy);
        MPN_INCR_U(v1 + n, n+1, hi);
    }
    else
    {
        cy = mpn_sub_n(v1, v1, vm1, n);
        hi += mpn_sub_nc(pp + 2*n, pp + 2*n, vm1 + n, n, cy);
        MPN_DECR_U(v1 + n, n+1, hi);
    }

    /* x0 = a0 * b0 (2n limbs at pp, replacing vm1).  x3 = a2 * b1 (s + t
       limbs at pp + 3n).  flint_mpn_mul gets the operand that is not shorter
       as its first argument. */
    TOOM32_MUL_N_REC(pp, a0, b0, n, scratch_out);
    if (s > t) flint_mpn_mul(pp+3*n, a2, s, b1, t);
    else flint_mpn_mul(pp+3*n, b1, t, a2, s);

    /* Remaining interpolation.  The result is y B^n + x0 + x3 B^(3n)
       - x0 B^(2n) - x3 B^n, where L/H denote the low/high n limbs of a
       2n-limb value:
         L x0 + (y0 + H x0 - L x3) B^n + (y1 - L x0 - H x3) B^(2n)
         + (y2 - (H x0 - L x3)) B^(3n) + H x3 B^(4n).
       hi collects the signed carry into limb position 4n. */
    cy = mpn_sub_n(pp + n, pp + n, pp+3*n, n);
    hi = scratch[2*n] + cy;

    cy = mpn_sub_nc(pp + 2*n, pp + 2*n, pp, n, cy);
    hi -= mpn_sub_nc(pp + 3*n, scratch + n, pp + n, n, cy);

    hi += mpn_add(pp + n, pp + n, 3*n, scratch, n);

    /* When s + t > n, H x3 occupies s + t - n limbs at pp + 4n and must
       still be subtracted at position 2n. */
    if (FLINT_LIKELY(s + t > n))
    {
        hi -= mpn_sub(pp + 2*n, pp + 2*n, 2*n, pp + 4*n, s+t-n);

        FLINT_ASSERT(hi >= 0);
        MPN_INCR_U(pp + 4*n, s+t-n, hi);
    }
    else
    {
        FLINT_ASSERT(hi == 0);
    }
}

/* The helper macros above are private to flint_mpn_mul_toom32.  Remove them
   so that common names such as a0, b1 and v1 do not leak into the rest of
   the translation unit. */
#undef a0
#undef a1
#undef a2
#undef b0
#undef b1
#undef ap1
#undef bp1
#undef am1
#undef bm1
#undef v1
#undef vm1
#undef scratch_out