#include "fmpz.h"
#include "fmpz_mod_poly.h"
#include "gr.h"
#include "gr_poly.h"
void _fmpz_mod_poly_sqrt_series(fmpz * g, const fmpz * h, slong hlen, slong n, const fmpz_mod_ctx_t mod)
{
gr_ctx_t gr_ctx;
_gr_ctx_init_fmpz_mod_from_ref(gr_ctx, mod);
GR_MUST_SUCCEED(_gr_poly_sqrt_series(g, h, hlen, n, gr_ctx));
}
void fmpz_mod_poly_sqrt_series(fmpz_mod_poly_t g, const fmpz_mod_poly_t h, slong n, const fmpz_mod_ctx_t ctx)
{
const slong hlen = h->length;
if (n == 0 || h->length == 0)
{
fmpz_mod_poly_zero(g, ctx);
return;
}
if (!fmpz_is_one(h->coeffs + 0))
{
flint_throw(FLINT_ERROR, "Exception (fmpz_mod_poly_sqrt_series). Constant term != 1.\n");
}
if (hlen == 1)
n = 1;
if (g == h)
{
fmpz_mod_poly_t t;
fmpz_mod_poly_init2(t, n, ctx);
_fmpz_mod_poly_sqrt_series(t->coeffs, h->coeffs, hlen, n, ctx);
_fmpz_mod_poly_set_length(t, n);
_fmpz_mod_poly_normalise(t);
fmpz_mod_poly_swap(g, t, ctx);
fmpz_mod_poly_clear(t, ctx);
}
else
{
_fmpz_mod_poly_fit_length(g, n);
_fmpz_mod_poly_sqrt_series(g->coeffs, h->coeffs, hlen, n, ctx);
_fmpz_mod_poly_set_length(g, n);
_fmpz_mod_poly_normalise(g);
}
}