#include "mpn_extras.h"
/* Balanced recursive product {p, 2n} = {a, n} * {b, n}.
   Uses FLINT_MPN_MUL_N_HARD when FLINT_HAVE_MUL_N_FUNC(n) reports that a
   dedicated routine exists for this size; otherwise recurses into
   flint_mpn_mul_toom22 with ws as workspace. */
#define TOOM22_MUL_N_REC(p, a, b, n, ws) \
    do { \
        if (FLINT_HAVE_MUL_N_FUNC((n))) \
        { \
            FLINT_MPN_MUL_N_HARD((p), (a), (b), (n)); \
        } \
        else \
        { \
            flint_mpn_mul_toom22((p), (a), (n), (b), (n), (ws)); \
        } \
    } while (0)
/* Recursive product {p, an + bn} = {a, an} * {b, bn} for operands with
   an >= bn (used for the high halves a1 * b1 when s > t). Dispatch:
     - FLINT_MPN_MUL_HARD when FLINT_HAVE_MUL_FUNC(an, bn) reports a
       dedicated routine for these sizes;
     - flint_mpn_mul when bn <= 12 (a fixed small-size cutoff) or when
       4 * an >= 5 * bn, i.e. the operands are too unbalanced; for bn > 12
       the check 4 * an < 5 * bn keeps bn above ceil(an / 2), which
       flint_mpn_mul_toom22 requires (0 < t);
     - otherwise flint_mpn_mul_toom22 with workspace ws.
   an and bn are used inside arithmetic comparisons, so every parameter is
   parenthesized: an expression argument such as `s - 1` must not bind to
   the surrounding `4 *` or `5 *`. */
#define TOOM22_MUL_REC(p, a, an, b, bn, ws) \
    do { \
        if (FLINT_HAVE_MUL_FUNC((an), (bn))) \
        { \
            FLINT_MPN_MUL_HARD((p), (a), (an), (b), (bn)); \
        } \
        else if ((bn) <= 12 || 4 * (an) >= 5 * (bn)) \
        { \
            flint_mpn_mul((p), (a), (an), (b), (bn)); \
        } \
        else \
        { \
            flint_mpn_mul_toom22((p), (a), (an), (b), (bn), (ws)); \
        } \
    } while (0)
/*
    Multiply {ap, an} by {bp, bn} using one level of Toom-2.2 (Karatsuba)
    and write the an + bn limb product to {pp, an + bn}.

    With s = floor(an / 2), n = an - s and t = bn - n, the operands are
    split as
        a = a0 + a1 * B^n    (a0: n limbs, a1: s limbs),
        b = b0 + b1 * B^n    (b0: n limbs, b1: t limbs),
    and the product is recovered from three point values:
        v0   = a0 * b0                  (2n limbs, stored at pp),
        vinf = a1 * b1                  (s + t limbs, stored at pp + 2n),
        vm1  = |a0 - a1| * |b0 - b1|    (2n limbs, stored in scratch).
    vm1_neg records whether the signed value (a0 - a1)(b0 - b1) is -vm1.

    Requirements (checked with FLINT_ASSERT): n < bn <= an, which is
    equivalent to 0 < t <= s.

    pp receives an + bn limbs. pp is written (asm1/bsm1, then vinf) before
    ap and bp are last read, so pp must not overlap the inputs.

    scratch may be NULL, in which case 2 * (an + FLINT_BITS) limbs are
    obtained with TMP_ALLOC and released before returning. Otherwise the
    caller provides the workspace: vm1 occupies its first 2n limbs and the
    recursive sub-products use scratch + 2n.
    NOTE(review): a caller-supplied scratch presumably needs the same
    2 * (an + FLINT_BITS) limbs -- confirm against the callers.
*/
void
flint_mpn_mul_toom22(mp_ptr pp,
                     mp_srcptr ap, mp_size_t an,
                     mp_srcptr bp, mp_size_t bn,
                     mp_ptr scratch)
{
    mp_size_t n, s, t;
    int vm1_neg;
    mp_limb_t cy, cy2;
    mp_ptr asm1;
    mp_ptr bsm1;
    TMP_INIT;

    TMP_START;

    if (scratch == NULL)
    {
        scratch = TMP_ALLOC(sizeof(mp_limb_t) * (2 * ((an) + FLINT_BITS)));
    }

#define a0 ap
#define a1 (ap + n)
#define b0 bp
#define b1 (bp + n)

    s = an >> 1;
    n = an - s;
    t = bn - n;

    FLINT_ASSERT(an >= bn);
    FLINT_ASSERT(0 < s && s <= n && (n - s) == (an & 1));
    FLINT_ASSERT(0 < t && t <= s);

    /* The n-limb differences asm1 and bsm1 are parked in the low 2n limbs
       of pp, which are not needed until v0 is computed. */
    asm1 = pp;
    bsm1 = pp + n;

    vm1_neg = 0;

    /* asm1 = |a0 - a1| */
    if ((an & 1) == 0)
    {
        /* s == n: a0 and a1 have the same length. */
        if (mpn_cmp(a0, a1, n) < 0)
        {
            mpn_sub_n(asm1, a1, a0, n);
            vm1_neg = 1;
        }
        else
        {
            mpn_sub_n(asm1, a0, a1, n);
        }
    }
    else
    {
        /* n == s + 1: a0 has one more limb than a1, so a0 < a1 is only
           possible when that top limb a0[s] is zero. */
        if (a0[s] == 0 && mpn_cmp(a0, a1, s) < 0)
        {
            mpn_sub_n(asm1, a1, a0, s);
            asm1[s] = 0;
            vm1_neg = 1;
        }
        else
        {
            asm1[s] = a0[s] - mpn_sub_n(asm1, a0, a1, s);
        }
    }

    /* bsm1 = |b0 - b1|; vm1_neg is flipped when b0 < b1. */
    if (t == n)
    {
        if (mpn_cmp(b0, b1, n) < 0)
        {
            mpn_sub_n(bsm1, b1, b0, n);
            vm1_neg ^= 1;
        }
        else
        {
            mpn_sub_n(bsm1, b0, b1, n);
        }
    }
    else
    {
        /* b1 has t < n limbs: b0 < b1 requires the top n - t limbs of b0
           to be zero. */
        if (flint_mpn_zero_p(b0 + t, n - t) && mpn_cmp(b0, b1, t) < 0)
        {
            mpn_sub_n(bsm1, b1, b0, t);
            flint_mpn_zero(bsm1 + t, n - t);
            vm1_neg ^= 1;
        }
        else
        {
            mpn_sub(bsm1, b0, n, b1, t);
        }
    }

#define v0 pp
#define vinf (pp + 2 * n)
#define vm1 scratch
#define scratch_out scratch + 2 * n

    /* Order matters:
       - vm1 is computed first because it reads asm1/bsm1 from pp[0, 2n).
       - vinf then fills pp[2n, an + bn), which does not overlap them.
       - v0 finally overwrites pp[0, 2n). */
    TOOM22_MUL_N_REC (vm1, asm1, bsm1, n, scratch_out);
    if (s > t)
    {
        TOOM22_MUL_REC(vinf, a1, s, b1, t, scratch_out);
    }
    else
    {
        /* t <= s, so here s == t. */
        TOOM22_MUL_N_REC(vinf, a1, b1, s, scratch_out);
    }
    TOOM22_MUL_N_REC(v0, ap, bp, n, scratch_out);

    /* Interpolation: the product is
           v0 + (v0 + vinf -/+ vm1) * B^n + vinf * B^2n,
       using + vm1 when vm1_neg is set. The middle term is accumulated in
       place into pp[n, 3n). cy2 and cy are the pending carries into pp[2n]
       and pp[3n] respectively. */

    /* pp[2n, 3n) = H(v0) + L(vinf). This sum is shared by both halves of
       the middle term, so it is computed only once. */
    cy = mpn_add_n(pp + 2 * n, v0 + n, vinf, n);

    /* pp[n, 2n) = L(v0) + (H(v0) + L(vinf)) */
    cy2 = cy + mpn_add_n(pp + n, pp + 2 * n, v0, n);

    /* pp[2n, 3n) += H(vinf), which has s + t - n (possibly zero) limbs. */
    cy += mpn_add(pp + 2 * n, pp + 2 * n, n, vinf + n, s + t - n);

    if (vm1_neg)
    {
        cy += mpn_add_n(pp + n, pp + n, vm1, 2 * n);
    }
    else
    {
        cy -= mpn_sub_n(pp + n, pp + n, vm1, 2 * n);
        /* cy is unsigned, so cy + 1 == 0 means it went to -1: the borrow
           out of pp[n, 3n) outweighed the carries collected so far. */
        if (FLINT_UNLIKELY(cy + 1 == 0))
        {
#if FLINT_WANT_ASSERT
            /* The borrow must be cancelled by cy2: adding it to pp[2n, 3n)
               carries all the way out and brings cy back to zero. */
            FLINT_ASSERT(cy2 == 1);
            cy += mpn_add_1(pp + 2 * n, pp + 2 * n, n, cy2);
            FLINT_ASSERT (cy == 0);
#else
            /* Same outcome without the checks: pp[2n, 3n) ends up zero, so
               it is written directly. The assertion below only records an
               expectation; FLINT_ASSERT is presumably a no-op in this
               !FLINT_WANT_ASSERT branch. */
            flint_mpn_zero(pp + 2 * n, n);
            FLINT_ASSERT(s + t == n || flint_mpn_zero_p(pp + 3 * n, s + t - n));
#endif
            goto cleanup;
        }
    }

    FLINT_ASSERT(cy <= 2);
    FLINT_ASSERT(cy2 <= 2);

    MPN_INCR_U(pp + 2 * n, s + t, cy2);

    /* When s + t == n (an odd and t == 1), the product occupies exactly
       an + bn == 3n limbs. In that case pp + 3n is one past the end and
       the carry into it must be zero. Skip the increment instead of
       handing MPN_INCR_U a zero-length region at an out-of-bounds
       address. */
    if (s + t > n)
    {
        MPN_INCR_U(pp + 3 * n, s + t - n, cy);
    }
    else
    {
        FLINT_ASSERT(cy == 0);
    }

cleanup:
    TMP_END;
}