#include <math.h>
#include "nmod.h"
#include "fq_zech.h"
#include "mpoly.h"
#include "nmod_mpoly.h"
#include "fq_zech_mpoly.h"
/*
    Probabilistic non-square test for small/medium moduli.

    The coefficients Acoeffs[0..Alen-1] are lifted into an fq_zech context
    created from (mod.n, edeg), the polynomial is evaluated at random points
    of that context, and the function returns 1 as soon as one evaluation is
    reported as a non-square by fq_zech_is_square. Up to 3*count + 1 points
    are tried. A return of 0 means "nothing proved", not "A is a square".

    The extension degree edeg grows with count and is capped by
    max_degree = n_flog(1000000, mod.n) (presumably the largest e with
    mod.n^e <= 10^6 - TODO confirm against n_flog's documentation). If even
    degree 2 exceeds that cap, the function gives up immediately with 0.

    state supplies the randomness; (Aexps, Alen, Abits, mctx) describe the
    packed exponents of A.
*/
static int _is_proved_not_square_medprime(
    int count,
    flint_rand_t state,
    const ulong * Acoeffs,
    const ulong * Aexps,
    slong Alen,
    flint_bitcnt_t Abits,
    const mpoly_ctx_t mctx,
    nmod_t mod)
{
    int success = 0;
    slong i;
    fq_zech_struct eval[1], * t, * alphas, ** alpha_ptrs;
    fq_zech_ctx_t fqctx;
    slong edeg, max_degree = n_flog(1000000, mod.n);
    TMP_INIT;

    FLINT_ASSERT(Alen > 0);

    /* later calls (larger count) use larger extension degrees */
    edeg = (max_degree + count - 2)/2;
    edeg = FLINT_MAX(2, edeg);
    if (edeg > max_degree)
        return 0;

    fq_zech_ctx_init_ui(fqctx, mod.n, edeg, "#");
    fq_zech_init(eval, fqctx);

    TMP_START;

    /* evaluation point: one field element per variable */
    alphas = (fq_zech_struct *) TMP_ALLOC(mctx->nvars*sizeof(fq_zech_struct));
    alpha_ptrs = (fq_zech_struct **) TMP_ALLOC(mctx->nvars*sizeof(fq_zech_struct *));
    for (i = 0; i < mctx->nvars; i++)
    {
        alpha_ptrs[i] = alphas + i;
        fq_zech_init(alphas + i, fqctx);
    }

    /* coefficients of A converted into the fq_zech context */
    t = (fq_zech_struct *) TMP_ALLOC(Alen*sizeof(fq_zech_struct));
    for (i = 0; i < Alen; i++)
    {
        fq_zech_init(t + i, fqctx);
        fq_zech_set_ui(t + i, Acoeffs[i], fqctx);
    }

    count *= 3;

    do {
        for (i = 0; i < mctx->nvars; i++)
            fq_zech_rand(alphas + i, state, fqctx);

        _fq_zech_mpoly_eval_all_fq_zech(eval, t, Aexps, Alen, Abits,
                                                     alpha_ptrs, mctx, fqctx);

        success = !fq_zech_is_square(eval, fqctx);
    } while (!success && --count >= 0);

    /*
        Pair every fq_zech_init above with a clear, and do so while the
        context is still alive.
    */
    for (i = 0; i < Alen; i++)
        fq_zech_clear(t + i, fqctx);
    for (i = 0; i < mctx->nvars; i++)
        fq_zech_clear(alphas + i, fqctx);
    fq_zech_clear(eval, fqctx);
    fq_zech_ctx_clear(fqctx);

    TMP_END;

    return success;
}
/*
    Heuristic driver: try to prove that A is not a square.

    1. On the first call only (count == 1), run the exponent-only test
       mpoly_is_proved_not_square.
    2. Otherwise, or if that test proved nothing, evaluate A modulo mod.n at
       random points with coordinates drawn by n_urandint(state, mod.n); a
       negative n_jacobi_unsigned(value, mod.n) is taken as proof. Up to
       3*count + 1 points are tried.
    3. If still nothing is proved, defer to
       _is_proved_not_square_medprime with the same arguments.

    Returns nonzero when non-squareness was proved, 0 when nothing could be
    concluded.
*/
static int _is_proved_not_square(
    int count,
    flint_rand_t state,
    const ulong * Acoeffs,
    const ulong * Aexps,
    slong Alen,
    flint_bitcnt_t Abits,
    const mpoly_ctx_t mctx,
    nmod_t mod)
{
    const slong words = mpoly_words_per_exp(Abits, mctx);
    int proved = 0;
    ulong * scratch;
    TMP_INIT;

    FLINT_ASSERT(Alen > 0);

    TMP_START;

    /* workspace of one exponent vector for the exponent-only test */
    scratch = (ulong *) TMP_ALLOC(words*sizeof(ulong));

    if (count == 1)
        proved = mpoly_is_proved_not_square(Aexps, Alen, Abits, words, scratch);

    if (!proved)
    {
        int budget = 3*count;
        slong v;
        ulong value;
        ulong * point;

        point = (ulong *) TMP_ALLOC(mctx->nvars*sizeof(ulong));

        do {
            for (v = 0; v < mctx->nvars; v++)
                point[v] = n_urandint(state, mod.n);

            value = _nmod_mpoly_eval_all_ui(Acoeffs, Aexps, Alen, Abits,
                                                           point, mctx, mod);

            proved = n_jacobi_unsigned(value, mod.n) < 0;
        } while (!proved && --budget >= 0);
    }

    TMP_END;

    if (proved)
        return proved;

    /* last resort: evaluations in an extension field */
    return _is_proved_not_square_medprime(count, state,
                                     Acoeffs, Aexps, Alen, Abits, mctx, mod);
}
/*
    Heap-based square root for polynomials whose packed exponent vectors
    occupy exactly one word each (asserted below).

    Input:  the terms (Acoeffs[k], Aexps[k]), k = 0..Alen-1, with exponents
            packed using "bits" bits per field, and the modulus "mod".
    Output: the candidate root is built in Q's coefficient/exponent arrays,
            which are grown through _nmod_mpoly_fit_length; Q->coeffs,
            Q->exps and Q->length are written back at "cleanup".

    Returns 1 (Qlen > 0) when the main loop runs to completion, and 0 with
    Q->length == 0 whenever one of the "not_sqrt" exits is taken.

    Acoeffs[0] and Acoeffs[Alen - 1] are read unconditionally and
    2*Qcoeffs[0] is inverted, so the caller must pass Alen > 0 and a modulus
    for which that inverse exists; nmod_mpoly_sqrt_heap in this file filters
    out even moduli and A == 0 before getting here.
*/
static int _nmod_mpoly_sqrt_heap1(
nmod_mpoly_t Q,
const ulong * Acoeffs,
const ulong * Aexps,
slong Alen,
flint_bitcnt_t bits,
const mpoly_ctx_t mctx,
nmod_t mod)
{
slong i, j, Qlen, Ai;
slong next_loc, heap_len = 1, heap_alloc;
mpoly_heap1_s * heap;
mpoly_heap_t * chain_nodes[64];
mpoly_heap_t ** chain;
slong exp_alloc;
slong * store, * store_base;
mpoly_heap_t * x;
ulong * Qcoeffs = Q->coeffs;
ulong * Qexps = Q->exps;
ulong mask, exp, exp3 = 0;
ulong maskhi;
ulong pp1, pp0, acc2, acc1, acc0, tmp2, tmp1, tmp0, lc_minus_inv;
int lt_divides;
flint_rand_t heuristic_state;
int heuristic_count = 0;
FLINT_ASSERT(mpoly_words_per_exp(bits, mctx) == 1);
/* one-word comparison mask used by the heap and by mpoly_monomial_gt1 */
mpoly_get_cmpmask(&maskhi, 1, bits, mctx);
flint_rand_init(heuristic_state);
/*
    Initial sizes. The heap array gets heap_alloc + 1 slots because it is
    addressed from index 1 (heap_len starts at 1 and heap[1] is the top).
    chain[k] points at the node reserved for quotient term k, and "store"
    records the (i, j) pair of every node popped in one round, two slongs
    per node.
    NOTE(review): store holds slongs but is sized with
    sizeof(mpoly_heap_t *); this is only right when both sizes agree.
*/
next_loc = 2*n_sqrt(Alen) + 4;
heap_alloc = next_loc - 3;
heap = (mpoly_heap1_s *) flint_malloc((heap_alloc + 1)*sizeof(mpoly_heap1_s));
chain_nodes[0] = (mpoly_heap_t *) flint_malloc(heap_alloc*sizeof(mpoly_heap_t));
chain = (mpoly_heap_t **) flint_malloc(heap_alloc*sizeof(mpoly_heap_t*));
store = store_base = (slong *) flint_malloc(2*heap_alloc*sizeof(mpoly_heap_t *));
for (i = 0; i < heap_alloc; i++)
chain[i] = chain_nodes[0] + i;
/* number of chain_nodes[] blocks allocated so far (freed at cleanup) */
exp_alloc = 1;
mask = mpoly_overflow_mask_sp(bits);
/* term 0 of A is consumed by the first-term computation, scan from 1 */
Ai = 1;
/* first term of Q: coefficient is a square root of A's leading coefficient */
Qlen = 0;
_nmod_mpoly_fit_length(&Qcoeffs, &Q->coeffs_alloc,
&Qexps, &Q->exps_alloc, 1, Qlen + 1);
Qcoeffs[0] = n_sqrtmod(Acoeffs[0], mod.n);
/* a zero result from n_sqrtmod is treated as "no square root" */
if (Qcoeffs[0] == 0)
goto not_sqrt;
Qlen = 1;
/* lc_minus_inv = -(2*Qcoeffs[0])^(-1) mod n, reused for every later term */
lc_minus_inv = mod.n - nmod_inv(nmod_add(Qcoeffs[0], Qcoeffs[0], mod), mod);
/* leading exponent of Q is half of A's leading exponent, if that exists */
if (!mpoly_monomial_halves1(Qexps + 0, Aexps[0], mask))
goto not_sqrt;
/*
    Early rejection using the last term of A: its coefficient must have a
    square root and its exponent must halve.
    NOTE(review): exp3 is computed here but never read again in this
    function.
*/
{
if (0 == n_sqrtmod(Acoeffs[Alen - 1], mod.n))
goto not_sqrt;
if (!mpoly_monomial_halves1(&exp3, Aexps[Alen - 1], mask))
goto not_sqrt;
exp3 += Qexps[0];
}
/* main loop: one exponent of (Q^2 - A) is processed per iteration */
while (heap_len > 1 || Ai < Alen)
{
_nmod_mpoly_fit_length(&Qcoeffs, &Q->coeffs_alloc,
&Qexps, &Q->exps_alloc, 1, Qlen + 1);
/* three-limb accumulator for the coefficient at the current exponent */
acc2 = acc1 = acc0 = 0;
if (heap_len > 1 && Ai < Alen && Aexps[Ai] == heap[1].exp)
{
/* same exponent on both sides: take the term from A and from the heap */
exp = Aexps[Ai];
acc0 = mod.n - Acoeffs[Ai];
Ai++;
}
else if (heap_len > 1 && (Ai >= Alen ||
mpoly_monomial_gt1(heap[1].exp, Aexps[Ai], maskhi)))
{
/* heap top comes first (or A is exhausted): take only from the heap */
exp = heap[1].exp;
if (mpoly_monomial_overflows1(exp, mask))
goto not_sqrt;
}
else
{
/* take only from A; there are no heap products to accumulate */
FLINT_ASSERT(Ai < Alen);
exp = Aexps[Ai];
acc0 = mod.n - Acoeffs[Ai];
Ai++;
goto skip_heap;
}
/*
    Pop every heap node with exponent exp. Products with i != j are
    collected in tmp and added twice below; products with i == j go
    straight into acc, so the total is acc + 2*tmp.
*/
tmp2 = tmp1 = tmp0 = 0;
do {
x = _mpoly_heap_pop1(heap, &heap_len, maskhi);
do {
*store++ = x->i;
*store++ = x->j;
umul_ppmm(pp1, pp0, Qcoeffs[x->i], Qcoeffs[x->j]);
if (x->i != x->j)
add_sssaaaaaa(tmp2, tmp1, tmp0, tmp2, tmp1, tmp0,
UWORD(0), pp1, pp0);
else
add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0,
UWORD(0), pp1, pp0);
} while ((x = x->next) != NULL);
} while (heap_len > 1 && heap[1].exp == exp);
add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, tmp2, tmp1, tmp0);
add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, tmp2, tmp1, tmp0);
/* reduce the three-limb sum modulo n into acc0 */
NMOD_RED3(acc0, acc2, acc1, acc0, mod);
/* re-insert the successor (i, j + 1) of each popped node while j < i */
while (store > store_base)
{
j = *--store;
i = *--store;
if (j < i)
{
x = chain[i];
x->i = i;
x->j = j + 1;
x->next = NULL;
_mpoly_heap_insert1(heap, Qexps[x->i] + Qexps[x->j], x,
&next_loc, &heap_len, maskhi);
}
}
/* everything cancelled at this exponent: no new term of Q */
if (acc0 == 0)
continue;
skip_heap:
/* the new exponent of Q is exp divided by Q's leading monomial */
lt_divides = mpoly_monomial_divides1(Qexps + Qlen, exp, Qexps[0], mask);
if (!lt_divides)
goto not_sqrt;
Qcoeffs[Qlen] = nmod_mul(acc0, lc_minus_inv, mod);
if (Qlen >= heap_alloc)
{
/*
    Q has outgrown the current allocation. Once it is also longer than A,
    run the probabilistic non-square tests before paying for a bigger heap.
*/
if (Qlen > Alen && _is_proved_not_square(
++heuristic_count, heuristic_state,
Acoeffs, Aexps, Alen, bits, mctx, mod))
{
goto not_sqrt;
}
/* double all structures; the new chain nodes live in a fresh block */
heap_alloc *= 2;
heap = (mpoly_heap1_s *) flint_realloc(heap, (heap_alloc + 1)*sizeof(mpoly_heap1_s));
chain_nodes[exp_alloc] = (mpoly_heap_t *) flint_malloc((heap_alloc/2)*sizeof(mpoly_heap_t));
chain = (mpoly_heap_t **) flint_realloc(chain, heap_alloc*sizeof(mpoly_heap_t*));
store = store_base = (slong *) flint_realloc(store_base, 2*heap_alloc*sizeof(mpoly_heap_t *));
for (i = 0; i < heap_alloc/2; i++)
chain[i + heap_alloc/2] = chain_nodes[exp_alloc] + i;
exp_alloc++;
}
/* put the product (Qlen, 1) in the heap and accept the new term */
i = Qlen;
x = chain[i];
x->i = i;
x->j = 1;
x->next = NULL;
_mpoly_heap_insert1(heap, Qexps[i] + Qexps[1], x,
&next_loc, &heap_len, maskhi);
Qlen++;
}
cleanup:
flint_rand_clear(heuristic_state);
/* hand the (possibly reallocated) arrays back to Q */
Q->coeffs = Qcoeffs;
Q->exps = Qexps;
Q->length = Qlen;
flint_free(heap);
flint_free(chain);
flint_free(store_base);
for (i = 0; i < exp_alloc; i++)
flint_free(chain_nodes[i]);
return Qlen > 0;
not_sqrt:
/* failure: report an empty Q, then share the cleanup path */
Qlen = 0;
goto cleanup;
}
/*
    Heap-based square root for arbitrary exponent width: each packed
    exponent vector occupies N = mpoly_words_per_exp(bits, mctx) words.
    The N == 1 case is forwarded to _nmod_mpoly_sqrt_heap1.

    Input:  the terms (Acoeffs[k], Aexps + N*k), k = 0..Alen-1, and the
            modulus "mod".
    Output: the candidate root is built in Q's coefficient/exponent arrays,
            which are grown through _nmod_mpoly_fit_length; Q->coeffs,
            Q->exps and Q->length are written back at "cleanup".

    Returns 1 (Qlen > 0) when the main loop runs to completion, and 0 with
    Q->length == 0 whenever one of the "not_sqrt" exits is taken.

    Acoeffs[0] and Acoeffs[Alen - 1] are read unconditionally and
    2*Qcoeffs[0] is inverted, so the caller must pass Alen > 0 and a modulus
    for which that inverse exists; nmod_mpoly_sqrt_heap in this file filters
    out even moduli and A == 0 before getting here.

    Throughout, bits <= FLINT_BITS selects the single-precision-field
    helpers and the *_mp helpers are used otherwise.
*/
static int _nmod_mpoly_sqrt_heap(
nmod_mpoly_t Q,
const ulong * Acoeffs,
const ulong * Aexps,
slong Alen,
flint_bitcnt_t bits,
const mpoly_ctx_t mctx,
nmod_t mod)
{
slong N = mpoly_words_per_exp(bits, mctx);
ulong * cmpmask;
slong i, j, Qlen, Ai;
slong next_loc;
slong heap_len = 1, heap_alloc;
int exp_alloc;
mpoly_heap_s * heap;
mpoly_heap_t * chain_nodes[64];
mpoly_heap_t ** chain;
slong * store, * store_base;
mpoly_heap_t * x;
ulong * Qcoeffs = Q->coeffs;
ulong * Qexps = Q->exps;
ulong * exp, * exp3;
ulong * exps[64];
ulong ** exp_list;
slong exp_next;
ulong mask;
ulong pp1, pp0, acc2, acc1, acc0, tmp2, tmp1, tmp0, lc_minus_inv;
int lt_divides, halves;
flint_rand_t heuristic_state;
int heuristic_count = 0;
TMP_INIT;
/* one-word exponents have a dedicated implementation */
if (N == 1)
return _nmod_mpoly_sqrt_heap1(Q, Acoeffs, Aexps, Alen, bits, mctx, mod);
TMP_START;
cmpmask = (ulong *) TMP_ALLOC(N*sizeof(ulong));
mpoly_get_cmpmask(cmpmask, N, bits, mctx);
flint_rand_init(heuristic_state);
/*
    Initial sizes. The heap array gets heap_alloc + 1 slots because it is
    addressed from index 1 (heap_len starts at 1 and heap[1] is the top).
    chain[k] points at the node reserved for quotient term k, and "store"
    records the (i, j) pair of every node popped in one round.
    NOTE(review): this uses the floating-point sqrt() with an implicit
    double -> slong conversion, whereas the one-word variant uses n_sqrt;
    the two can give slightly different initial sizes.
    NOTE(review): store holds slongs but is sized with
    sizeof(mpoly_heap_t *); this is only right when both sizes agree.
*/
next_loc = 2*sqrt(Alen) + 4;
heap_alloc = next_loc - 3;
heap = (mpoly_heap_s *) flint_malloc((heap_alloc + 1)*sizeof(mpoly_heap_s));
chain_nodes[0] = (mpoly_heap_t *) flint_malloc(heap_alloc*sizeof(mpoly_heap_t));
chain = (mpoly_heap_t **) flint_malloc(heap_alloc*sizeof(mpoly_heap_t*));
store = store_base = (slong *) flint_malloc(2*heap_alloc*sizeof(mpoly_heap_t *));
for (i = 0; i < heap_alloc; i++)
chain[i] = chain_nodes[0] + i;
/*
    Pool of N-word exponent buffers for heap entries. exp_list[exp_next] is
    the next free buffer; exps[] / chain_nodes[] remember the allocated
    blocks and exp_alloc counts them (all freed at cleanup).
*/
exps[0] = (ulong *) flint_malloc(heap_alloc*N*sizeof(ulong));
exp_alloc = 1;
exp_list = (ulong **) flint_malloc(heap_alloc*sizeof(ulong *));
/* exp: exponent currently being processed; exp3: see the block below */
exp = (ulong *) TMP_ALLOC(N*sizeof(ulong));
exp3 = (ulong *) TMP_ALLOC(N*sizeof(ulong));
exp_next = 0;
for (i = 0; i < heap_alloc; i++)
exp_list[i] = exps[0] + i*N;
/* the overflow mask is only used on the bits <= FLINT_BITS paths */
mask = (bits <= FLINT_BITS) ? mpoly_overflow_mask_sp(bits) : 0;
/* term 0 of A is consumed by the first-term computation, scan from 1 */
Ai = 1;
/* first term of Q: coefficient is a square root of A's leading coefficient */
Qlen = 0;
_nmod_mpoly_fit_length(&Qcoeffs, &Q->coeffs_alloc,
&Qexps, &Q->exps_alloc, N, Qlen + 1);
Qcoeffs[0] = n_sqrtmod(Acoeffs[0], mod.n);
/* a zero result from n_sqrtmod is treated as "no square root" */
if (Qcoeffs[0] == 0)
goto not_sqrt;
Qlen = 1;
/* lc_minus_inv = -(2*Qcoeffs[0])^(-1) mod n, reused for every later term */
lc_minus_inv = mod.n - nmod_inv(nmod_add(Qcoeffs[0], Qcoeffs[0], mod), mod);
/* leading exponent of Q is half of A's leading exponent, if that exists */
if (bits <= FLINT_BITS)
halves = mpoly_monomial_halves(Qexps + 0, Aexps + 0, N, mask);
else
halves = mpoly_monomial_halves_mp(Qexps + 0, Aexps + 0, N, bits);
if (!halves)
goto not_sqrt;
/*
    Early rejection using the last term of A: its coefficient must have a
    square root and its exponent must halve.
    NOTE(review): exp3 is computed here but never read again in this
    function.
*/
{
if (0 == n_sqrtmod(Acoeffs[Alen - 1], mod.n))
goto not_sqrt;
if (bits <= FLINT_BITS)
halves = mpoly_monomial_halves(exp3, Aexps + (Alen - 1)*N, N, mask);
else
halves = mpoly_monomial_halves_mp(exp3, Aexps + (Alen - 1)*N, N, bits);
if (!halves)
goto not_sqrt;
if (bits <= FLINT_BITS)
mpoly_monomial_add(exp3, exp3, Qexps + 0, N);
else
mpoly_monomial_add_mp(exp3, exp3, Qexps + 0, N);
}
/* main loop: one exponent of (Q^2 - A) is processed per iteration */
while (heap_len > 1 || Ai < Alen)
{
_nmod_mpoly_fit_length(&Qcoeffs, &Q->coeffs_alloc,
&Qexps, &Q->exps_alloc, N, Qlen + 1);
/* three-limb accumulator for the coefficient at the current exponent */
acc2 = acc1 = acc0 = 0;
if (heap_len > 1 && Ai < Alen &&
mpoly_monomial_equal(Aexps + N*Ai, heap[1].exp, N))
{
/* same exponent on both sides: take the term from A and from the heap */
mpoly_monomial_set(exp, Aexps + N*Ai, N);
acc0 = mod.n - Acoeffs[Ai];
Ai++;
}
else if (heap_len > 1 && (Ai >= Alen || mpoly_monomial_lt(
Aexps + N*Ai, heap[1].exp, N, cmpmask)))
{
/* heap top comes first (or A is exhausted): take only from the heap */
mpoly_monomial_set(exp, heap[1].exp, N);
if (bits <= FLINT_BITS ? mpoly_monomial_overflows(exp, N, mask)
: mpoly_monomial_overflows_mp(exp, N, bits))
goto not_sqrt;
}
else
{
/* take only from A; there are no heap products to accumulate */
FLINT_ASSERT(Ai < Alen);
mpoly_monomial_set(exp, Aexps + N*Ai, N);
acc0 = mod.n - Acoeffs[Ai];
Ai++;
goto skip_heap;
}
/*
    Pop every heap node with exponent exp. Products with i != j are
    collected in tmp and added twice below; products with i == j go
    straight into acc, so the total is acc + 2*tmp.
*/
tmp2 = tmp1 = tmp0 = 0;
do {
/* hand the exponent buffer of the popped entry back to the pool */
exp_list[--exp_next] = heap[1].exp;
x = _mpoly_heap_pop(heap, &heap_len, N, cmpmask);
do {
*store++ = x->i;
*store++ = x->j;
FLINT_ASSERT(x->i > 0 && x->j > 0);
umul_ppmm(pp1, pp0, Qcoeffs[x->i], Qcoeffs[x->j]);
if (x->i != x->j)
add_sssaaaaaa(tmp2, tmp1, tmp0, tmp2, tmp1, tmp0,
UWORD(0), pp1, pp0);
else
add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0,
UWORD(0), pp1, pp0);
} while ((x = x->next) != NULL);
} while (heap_len > 1 && mpoly_monomial_equal(heap[1].exp, exp, N));
add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, tmp2, tmp1, tmp0);
add_sssaaaaaa(acc2, acc1, acc0, acc2, acc1, acc0, tmp2, tmp1, tmp0);
/* reduce the three-limb sum modulo n into acc0 */
NMOD_RED3(acc0, acc2, acc1, acc0, mod);
/* re-insert the successor (i, j + 1) of each popped node while j < i */
while (store > store_base)
{
j = *--store;
i = *--store;
if (j < i)
{
x = chain[i];
x->i = i;
x->j = j + 1;
x->next = NULL;
if (bits <= FLINT_BITS)
mpoly_monomial_add(exp_list[exp_next], Qexps + N*x->i,
Qexps + N*x->j, N);
else
mpoly_monomial_add_mp(exp_list[exp_next], Qexps + N*x->i,
Qexps + N*x->j, N);
/*
    exp_next advances by the insert's return value - presumably 1 when the
    buffer was taken by a new heap entry and 0 when the node was chained
    onto an existing one (confirm against _mpoly_heap_insert).
*/
exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
&next_loc, &heap_len, N, cmpmask);
}
}
/* everything cancelled at this exponent: no new term of Q */
if (acc0 == 0)
continue;
skip_heap:
/* the new exponent of Q is exp divided by Q's leading monomial */
if (bits <= FLINT_BITS)
lt_divides = mpoly_monomial_divides(Qexps + N*Qlen,
exp, Qexps + N*0, N, mask);
else
lt_divides = mpoly_monomial_divides_mp(Qexps + N*Qlen,
exp, Qexps + N*0, N, bits);
if (!lt_divides)
goto not_sqrt;
Qcoeffs[Qlen] = nmod_mul(acc0, lc_minus_inv, mod);
if (Qlen >= heap_alloc)
{
/*
    Q has outgrown the current allocation. Once it is also longer than A,
    run the probabilistic non-square tests before paying for a bigger heap.
*/
if (Qlen > Alen && _is_proved_not_square(
++heuristic_count, heuristic_state,
Acoeffs, Aexps, Alen, bits, mctx, mod))
{
goto not_sqrt;
}
/* double all structures; new nodes and buffers live in fresh blocks */
heap_alloc *= 2;
heap = (mpoly_heap_s *) flint_realloc(heap, (heap_alloc + 1)*sizeof(mpoly_heap_s));
chain_nodes[exp_alloc] = (mpoly_heap_t *) flint_malloc((heap_alloc/2)*sizeof(mpoly_heap_t));
chain = (mpoly_heap_t **) flint_realloc(chain, heap_alloc*sizeof(mpoly_heap_t*));
store = store_base = (slong *) flint_realloc(store_base, 2*heap_alloc*sizeof(mpoly_heap_t *));
exps[exp_alloc] = (ulong *) flint_malloc((heap_alloc/2)*N*sizeof(ulong));
exp_list = (ulong **) flint_realloc(exp_list, heap_alloc*sizeof(ulong *));
for (i = 0; i < heap_alloc/2; i++)
{
chain[i + heap_alloc/2] = chain_nodes[exp_alloc] + i;
exp_list[i + heap_alloc/2] = exps[exp_alloc] + i*N;
}
exp_alloc++;
}
/* put the product (Qlen, 1) in the heap and accept the new term */
i = Qlen;
x = chain[i];
x->i = i;
x->j = 1;
x->next = NULL;
if (bits <= FLINT_BITS)
mpoly_monomial_add(exp_list[exp_next], Qexps + x->i*N,
Qexps + x->j*N, N);
else
mpoly_monomial_add_mp(exp_list[exp_next], Qexps + x->i*N,
Qexps + x->j*N, N);
exp_next += _mpoly_heap_insert(heap, exp_list[exp_next], x,
&next_loc, &heap_len, N, cmpmask);
Qlen++;
}
cleanup:
flint_rand_clear(heuristic_state);
/* hand the (possibly reallocated) arrays back to Q */
Q->coeffs = Qcoeffs;
Q->exps = Qexps;
Q->length = Qlen;
flint_free(heap);
flint_free(chain);
flint_free(store_base);
flint_free(exp_list);
for (i = 0; i < exp_alloc; i++)
{
flint_free(exps[i]);
flint_free(chain_nodes[i]);
}
TMP_END;
return Qlen > 0;
not_sqrt:
/* failure: report an empty Q, then share the cleanup path */
Qlen = 0;
goto cleanup;
}
/*
    Set Q to a square root of A if one is found and return 1; otherwise set
    Q to zero length and return 0. Q may alias A.

    Even modulus:
      - any even modulus other than 2 is rejected with flint_throw;
      - for modulus 2 every coefficient of the result is 1 and every
        exponent vector is halved term by term; the first exponent that
        cannot be halved makes the function return 0. When Q aliases A the
        halving is done in place (the coefficients are left as they are).

    Odd modulus:
      - A == 0 gives Q == 0 and success;
      - otherwise the heap algorithm is run, through a temporary when Q
        aliases A, with n_sqrt(A->length) as the initial length estimate.
*/
int nmod_mpoly_sqrt_heap(nmod_mpoly_t Q, const nmod_mpoly_t A,
                                                    const nmod_mpoly_ctx_t ctx)
{
    int success;
    slong lenq_est;

    if ((ctx->mod.n % 2) == 0)
    {
        flint_bitcnt_t bits = A->bits;
        ulong * Aexps = A->exps;
        slong Alen = A->length;
        slong i, N = mpoly_words_per_exp(bits, ctx->minfo);
        ulong mask = (bits <= FLINT_BITS) ? mpoly_overflow_mask_sp(bits) : 0;

        /*
            %wd consumes a word-sized signed integer from the variadic
            list, so pass both values as slong; a bare int literal here
            would be a mismatched variadic argument.
        */
        if (ctx->mod.n != 2)
            flint_throw(FLINT_IMPINV, "nmod_mpoly_sqrt_heap: "
                    "cannot compute sqrt modulo %wd*%wd",
                    (slong) 2, (slong) (ctx->mod.n/2));

        if (Q != A)
        {
            nmod_mpoly_fit_length_reset_bits(Q, Alen, bits, ctx);
            for (i = 0; i < Alen; i++)
                Q->coeffs[i] = 1;
        }

        /* halve each exponent vector; any failure means "not a square" */
        for (i = 0; i < Alen; i++)
        {
            if (bits <= FLINT_BITS ?
                !mpoly_monomial_halves(Q->exps + N*i, Aexps + N*i, N, mask) :
                !mpoly_monomial_halves_mp(Q->exps + N*i, Aexps + N*i, N, bits))
            {
                Q->length = 0;
                return 0;
            }
        }

        Q->length = Alen;
        return 1;
    }

    if (nmod_mpoly_is_zero(A, ctx))
    {
        nmod_mpoly_zero(Q, ctx);
        return 1;
    }

    lenq_est = n_sqrt(A->length);

    if (Q == A)
    {
        /* the heap routine reads A while writing Q, so use a temporary */
        nmod_mpoly_t T;
        nmod_mpoly_init3(T, lenq_est, A->bits, ctx);
        success = _nmod_mpoly_sqrt_heap(T, A->coeffs, A->exps, A->length,
                                             A->bits, ctx->minfo, ctx->mod);
        nmod_mpoly_swap(Q, T, ctx);
        nmod_mpoly_clear(T, ctx);
    }
    else
    {
        nmod_mpoly_fit_length_reset_bits(Q, lenq_est, A->bits, ctx);
        success = _nmod_mpoly_sqrt_heap(Q, A->coeffs, A->exps, A->length,
                                             A->bits, ctx->minfo, ctx->mod);
    }

    return success;
}