#include "mpn_extras.h"
#include "fmpz.h"
#include "fmpz_vec.h"
#include "fmpz_poly.h"
static slong flint_mpn_tdiv_q_fmpz_inplace(nn_ptr arrayg, slong limbsg, fmpz_t gc)
{
if (fmpz_size(gc) == 1)
{
mpn_divmod_1(arrayg, arrayg, limbsg, fmpz_get_ui(gc));
return limbsg - (arrayg[limbsg - 1] == 0);
}
else
{
slong tlimbs;
mpz_ptr mgc = COEFF_TO_PTR(*gc);
nn_ptr temp = flint_malloc(limbsg*sizeof(ulong));
flint_mpn_copyi(temp, arrayg, limbsg);
mpn_tdiv_q(arrayg, temp, limbsg, mgc->_mp_d, mgc->_mp_size);
tlimbs = limbsg - mgc->_mp_size + 1;
tlimbs -= (arrayg[tlimbs - 1] == 0);
flint_free(temp);
return tlimbs;
}
}
static int multiplies_out(fmpz * P, slong len, const fmpz * Q, slong qlen,
const fmpz * G, slong glen, slong sign, fmpz * temp)
{
int divides = 0;
if (qlen >= glen)
_fmpz_poly_mul(temp, Q, qlen, G, glen);
else
_fmpz_poly_mul(temp, G, glen, Q, qlen);
if (sign < WORD(0)) _fmpz_vec_neg(temp, temp, glen + qlen - 1);
divides = (glen + qlen - 1 == len && _fmpz_vec_equal(temp, P, len));
return divides;
}
int
_fmpz_poly_gcd_heuristic(fmpz * res, const fmpz * poly1, slong len1,
const fmpz * poly2, slong len2)
{
slong bits1, bits2, max_bits, pack_bits, bound_bits, bits_G, bits_Q;
ulong limbs1, limbs2, limbsg, pack_limbs, qlimbs, qlimbs2;
ulong log_glen, log_length;
slong sign1, sign2, glen, qlen, qlen2;
fmpz_t ac, bc, d, gc;
fmpz * A, * B, * G, * Q, * t;
nn_ptr array1, array2, arrayg, q, temp;
int divides;
fmpz_init(ac);
fmpz_init(bc);
fmpz_init(d);
_fmpz_poly_content(ac, poly1, len1);
_fmpz_poly_content(bc, poly2, len2);
fmpz_gcd(d, ac, bc);
if (len2 == 1)
{
fmpz_set(res, d);
fmpz_clear(ac);
fmpz_clear(bc);
fmpz_clear(d);
return 1;
}
A = _fmpz_vec_init(len1);
B = _fmpz_vec_init(len2);
_fmpz_vec_scalar_divexact_fmpz(A, poly1, len1, ac);
_fmpz_vec_scalar_divexact_fmpz(B, poly2, len2, bc);
fmpz_clear(ac);
fmpz_clear(bc);
if (len2 == 2)
{
Q = _fmpz_vec_init(len1 - len2 + 1);
if (_fmpz_poly_divides(Q, A, len1, B, 2))
{
_fmpz_vec_scalar_mul_fmpz(res, B, 2, d);
if (fmpz_sgn(res + 1) < 0)
_fmpz_vec_neg(res, res, 2);
}
else
{
fmpz_set(res, d);
fmpz_zero(res + 1);
}
fmpz_clear(d);
_fmpz_vec_clear(A, len1);
_fmpz_vec_clear(B, len2);
_fmpz_vec_clear(Q, len1 - len2 + 1);
return 1;
}
bits1 = FLINT_ABS(_fmpz_vec_max_bits(A, len1));
bits2 = FLINT_ABS(_fmpz_vec_max_bits(B, len2));
max_bits = FLINT_MAX(bits1, bits2) + 1;
bound_bits = FLINT_MIN(bits1, bits2) + 6;
pack_bits = FLINT_MAX(bound_bits, max_bits);
pack_limbs = (pack_bits - 1)/FLINT_BITS + 1;
if (pack_bits >= 32)
pack_bits = FLINT_BITS*pack_limbs;
limbs1 = (pack_bits*len1 - 1)/FLINT_BITS + 1;
limbs2 = (pack_bits*len2 - 1)/FLINT_BITS + 1;
array1 = flint_calloc(limbs1, sizeof(ulong));
array2 = flint_calloc(limbs2, sizeof(ulong));
arrayg = flint_calloc(limbs2, sizeof(ulong));
sign1 = (slong) fmpz_sgn(A + len1 - 1);
_fmpz_poly_bit_pack(array1, A, len1, pack_bits, sign1);
while (array1[limbs1 - 1] == 0) limbs1--;
sign2 = (slong) fmpz_sgn(B + len2 - 1);
_fmpz_poly_bit_pack(array2, B, len2, pack_bits, sign2);
while (array2[limbs2 - 1] == 0) limbs2--;
limbsg = flint_mpn_gcd_full(arrayg, array1, limbs1, array2, limbs2);
glen = FLINT_MIN((slong) ((limbsg * FLINT_BITS) / pack_bits) + 1, len2);
G = _fmpz_vec_init(glen);
flint_mpn_zero(arrayg + limbsg, limbs2-limbsg);
_fmpz_poly_bit_unpack(G, 0, glen, arrayg, pack_bits, 0);
while (G[glen - 1] == 0) glen--;
fmpz_init(gc);
_fmpz_poly_content(gc, G, glen);
if (!fmpz_is_one(gc))
limbsg = flint_mpn_tdiv_q_fmpz_inplace(arrayg, limbsg, gc);
qlimbs = limbs1 - limbsg + 1;
qlen = FLINT_MIN(len1, (slong) ((qlimbs * FLINT_BITS) / pack_bits) + 1);
qlimbs2 = limbs2 - limbsg + 1;
qlen2 = FLINT_MIN(len2, (slong) ((qlimbs2 * FLINT_BITS) / pack_bits) + 1);
qlimbs = (FLINT_MAX(qlen, qlen2)*pack_bits - 1)/FLINT_BITS + 1;
q = flint_calloc(qlimbs, sizeof(ulong));
temp = flint_malloc(limbsg*sizeof(ulong));
divides = 0;
if (flint_mpn_divides(q, array1, limbs1, arrayg, limbsg, temp))
{
Q = _fmpz_vec_init(len1);
t = _fmpz_vec_init(len1 + glen);
_fmpz_poly_bit_unpack(Q, 0, qlen, q, pack_bits, 0);
while (Q[qlen - 1] == 0) qlen--;
_fmpz_vec_scalar_divexact_fmpz(G, G, glen, gc);
bits_G = FLINT_ABS(_fmpz_vec_max_bits(G, glen));
bits_Q = FLINT_ABS(_fmpz_vec_max_bits(Q, qlen));
log_glen = FLINT_BIT_COUNT(glen);
log_length = FLINT_MIN(log_glen, FLINT_BIT_COUNT(qlen));
divides = (bits_G + bits_Q + log_length < (ulong) pack_bits);
if (!divides)
divides = multiplies_out(A, len1, Q, qlen, G, glen, sign1, t);
if (divides)
{
divides = 0;
flint_mpn_zero(q, qlimbs);
if (flint_mpn_divides(q, array2, limbs2, arrayg, limbsg, temp))
{
_fmpz_poly_bit_unpack(Q, 0, qlen2, q, pack_bits, 0);
while (Q[qlen2 - 1] == 0) qlen2--;
bits_Q = FLINT_ABS(_fmpz_vec_max_bits(Q, qlen2));
log_length = FLINT_MIN(log_glen, FLINT_BIT_COUNT(qlen2));
divides = (bits_G + bits_Q + log_length < (ulong) pack_bits);
if (!divides)
divides = multiplies_out(B, len2, Q, qlen2, G, glen, sign1, t);
}
}
_fmpz_vec_clear(t, len1 + glen);
_fmpz_vec_clear(Q, len1);
}
flint_free(q);
flint_free(temp);
flint_free(arrayg);
flint_free(array1);
flint_free(array2);
fmpz_clear(gc);
_fmpz_vec_clear(A, len1);
_fmpz_vec_clear(B, len2);
if (divides)
{
_fmpz_vec_zero(res + glen, len2 - glen);
_fmpz_vec_scalar_mul_fmpz(res, G, glen, d);
}
fmpz_clear(d);
_fmpz_vec_clear(G, glen);
return divides;
}
int
fmpz_poly_gcd_heuristic(fmpz_poly_t res, const fmpz_poly_t poly1,
const fmpz_poly_t poly2)
{
if (poly1->length < poly2->length)
{
return fmpz_poly_gcd_heuristic(res, poly2, poly1);
}
else
{
const slong len1 = poly1->length;
const slong len2 = poly2->length;
int done = 1;
if (len1 == 0)
{
fmpz_poly_zero(res);
}
else if (len2 == 0)
{
if (fmpz_sgn(poly1->coeffs + (len1 - 1)) > 0)
fmpz_poly_set(res, poly1);
else
fmpz_poly_neg(res, poly1);
}
else
{
fmpz_poly_fit_length(res, len2);
done = _fmpz_poly_gcd_heuristic(res->coeffs, poly1->coeffs, len1,
poly2->coeffs, len2);
if (done)
{
_fmpz_poly_set_length(res, len2);
_fmpz_poly_normalise(res);
}
}
return done;
}
}