#include "fmpz.h"
#include "fmpz_vec.h"
#include "fmpz_poly.h"
#include "fmpz_mod_vec.h"
#include "fmpz_mod_poly.h"
static inline
int _fmpz_mod_poly_sqrt_2(fmpz * s, const fmpz * p, slong len)
{
slong i;
for (i = 1; i < len; i += 2)
if (!fmpz_is_zero(p + i))
return 0;
for (i = 0; i < len; i += 2)
fmpz_set(s + i / 2, p + i);
return 1;
}
int
_fmpz_mod_poly_sqrt(fmpz * s, const fmpz * p, slong len, const fmpz_mod_ctx_t mod)
{
slong slen, i;
int result;
fmpz * t;
fmpz_t c, d;
if (len % 2 == 0)
return len == 0;
if (fmpz_cmp_ui(mod->n, 2) == 0)
return _fmpz_mod_poly_sqrt_2(s, p, len);
while (fmpz_is_zero(p + 0))
{
if (!fmpz_is_zero(p + 1))
return 0;
fmpz_zero(s + 0);
p += 2;
len -= 2;
s++;
}
fmpz_init(c);
fmpz_init(d);
fmpz_set(d, p + 0);
fmpz_set(c, d);
if (!fmpz_is_one(c))
{
if (!fmpz_sqrtmod(c, c, mod->n))
{
result = 0;
goto cleanup;
}
}
if (len == 1)
{
fmpz_set(s + 0, c);
result = 1;
goto cleanup;
}
slen = len / 2 + 1;
t = _fmpz_vec_init(len);
if (fmpz_is_one(c))
_fmpz_mod_poly_sqrt_series(s, p, slen, slen, mod);
else
{
fmpz_invmod(d, d, mod->n);
_fmpz_mod_vec_scalar_mul_fmpz_mod(t, p, slen, d, mod);
_fmpz_mod_poly_sqrt_series(s, t, slen, slen, mod);
}
if (!fmpz_is_one(c))
_fmpz_mod_vec_scalar_mul_fmpz_mod(s, s, slen, c, mod);
_fmpz_poly_mulhigh(t, s, slen, s, slen, slen);
for (i = 0; i < slen; i++)
fmpz_zero(t + i);
_fmpz_mod_vec_set_fmpz_vec(t + slen, t + slen, slen - 1, mod);
result = _fmpz_vec_equal(t + slen, p + slen, len - slen);
_fmpz_vec_clear(t, len);
cleanup:
fmpz_clear(c);
fmpz_clear(d);
return result;
}
int
fmpz_mod_poly_sqrt(fmpz_mod_poly_t b, const fmpz_mod_poly_t a, const fmpz_mod_ctx_t ctx)
{
slong blen, len = a->length;
int result;
if (len % 2 == 0)
{
fmpz_mod_poly_zero(b, ctx);
return len == 0;
}
if (b == a)
{
fmpz_mod_poly_t tmp;
fmpz_mod_poly_init(tmp, ctx);
result = fmpz_mod_poly_sqrt(tmp, a, ctx);
fmpz_mod_poly_swap(b, tmp, ctx);
fmpz_mod_poly_clear(tmp, ctx);
return result;
}
blen = len / 2 + 1;
fmpz_mod_poly_fit_length(b, blen, ctx);
result = _fmpz_mod_poly_sqrt(b->coeffs, a->coeffs, len, ctx);
if (!result)
blen = 0;
_fmpz_mod_poly_set_length(b, blen);
_fmpz_mod_poly_normalise(b);
return result;
}