#include "fmpz.h"
#include "fmpz/impl.h"
#include "fmpz_vec.h"
#include "fmpz_poly.h"
#include "fmpz_poly/impl.h"
#include "nmod_vec.h"
#include "gr_poly.h"
/*
    Returns 1 if the length-n polynomial poly evaluates to ys[i] at xs[i]
    for every i in [0, n), and 0 as soon as one point does not match.
    For n == 0 there is nothing to check and 1 is returned.
*/
static int
_fmpz_poly_check_interpolant(const fmpz * poly, const fmpz * xs, const fmpz * ys, slong n)
{
    fmpz_t value;
    slong idx = 0;
    int all_match = 1;

    fmpz_init(value);

    while (all_match && idx < n)
    {
        _fmpz_poly_evaluate_fmpz(value, poly, n, xs + idx);
        if (!fmpz_equal(value, ys + idx))
            all_match = 0;
        idx++;
    }

    fmpz_clear(value);
    return all_match;
}
/*
    Returns 1 if no two of the n entries of x are equal, and 0 otherwise.
    The check runs on a sorted private copy, so x itself is not modified.
*/
static int
_fmpz_vec_has_unique_entries(const fmpz * x, slong n)
{
    fmpz * sorted = _fmpz_vec_init(n);
    slong k;
    int distinct = 1;

    _fmpz_vec_set(sorted, x, n);
    _fmpz_vec_sort(sorted, n);

    /* Once sorted, equal values must sit in adjacent slots. */
    for (k = 1; distinct && k < n; k++)
        distinct = !fmpz_equal(sorted + k - 1, sorted + k);

    _fmpz_vec_clear(sorted, n);
    return distinct;
}
/*
    Interpolates the n points (x[i], y[i]) over Z/(mod.n)Z using the generic
    gr fast interpolation and writes the n resulting coefficients to r.

    Returns 1 if _gr_poly_interpolate_fast reports GR_SUCCESS and 0
    otherwise. A failure presumably means the x values are not pairwise
    distinct modulo mod.n (TODO confirm against the gr_poly docs).

    NOTE(review): callers in this file pass the same buffer for r and y.
    This assumes _gr_poly_interpolate_fast supports that aliasing.
*/
int
_checked_nmod_poly_interpolate(nn_ptr r, nn_srcptr x, nn_srcptr y, slong n, nmod_t mod)
{
    gr_ctx_t ctx;
    int status;

    gr_ctx_init_nmod(ctx, mod.n);
    status = _gr_poly_interpolate_fast(r, x, y, n, ctx);

    /* Every gr context must be released after initialisation. */
    gr_ctx_clear(ctx);

    return (status == GR_SUCCESS);
}
/*
    Tries to find the polynomial with integer coefficients and length n
    (degree < n) passing through the points (xs[i], ys[i]) for i in [0, n),
    and writes its n coefficients to poly.

    Method (multi-modular):
    - The points are reduced modulo word-size primes, starting just above
      2^(FLINT_BITS-1).
    - Each image is interpolated over Z/pZ.
    - The images are combined with the running result by CRT. M holds the
      product of all primes used so far.
    - The final argument 1 to the CRT helpers presumably selects symmetric
      (signed) residues. TODO confirm.

    After each round the candidate in poly is evaluated at every xs[i].

    Return value:
    - 1 as soon as the candidate reproduces ys exactly.
    - 1 immediately for n == 0, without touching poly.
    - 0 if xs is found to contain duplicates. This is checked the first
      time a round produces no usable prime.
    - 0 if M grows beyond bound_bits + 1 bits without a match. This
      presumably means no integer-coefficient interpolant exists.

    poly must point to n initialised fmpz entries. When 0 is returned,
    poly holds whatever the last CRT step produced.
*/
int
_fmpz_poly_interpolate_multi_mod(fmpz * poly,
const fmpz * xs, const fmpz * ys, slong n)
{
ulong p;
slong j, k;
nmod_t mod;
nn_ptr xm, ym;
fmpz_t M, t, u, c, M2, M1M2;
slong total_primes, num_primes, count_good;
slong xbits, ybits;
int ok = 1;
int checked_unique = 0;
slong bound_bits;
nn_ptr primes = NULL, xmod = NULL, ymod = NULL, residues = NULL;
int * good = NULL;
if (n == 0)
return 1;
fmpz_init(M);
fmpz_init(t);
fmpz_init(u);
fmpz_init(c);
fmpz_init(M2);
fmpz_init(M1M2);
xbits = _fmpz_vec_max_bits(xs, n);
xbits = FLINT_ABS(xbits);
ybits = _fmpz_vec_max_bits(ys, n);
ybits = FLINT_ABS(ybits);
/* Give up once the CRT modulus M exceeds this many bits (+1). The bound
   depends only on n and the bit sizes of xs and ys.
   NOTE(review): the derivation is not shown here; confirm that it bounds
   twice the coefficient size of any integer interpolant. */
bound_bits = (xbits + 1) * (n - 1) + FLINT_BIT_COUNT(n) * ybits;
/* Scratch vectors for the single-prime path (total_primes < 16). */
xm = _nmod_vec_init(n);
ym = _nmod_vec_init(n);
total_primes = 0;
/* Primes are taken in increasing order above 2^(FLINT_BITS-1), so every
   modulus uses the full word. */
p = UWORD(1) << (FLINT_BITS - 1);
for (;;)
{
if (total_primes < 16)
{
/* Warm-up phase: one prime per round. */
p = n_nextprime(p, 1);
nmod_init(&mod, p);
_fmpz_vec_get_nmod_vec(xm, xs, n, mod);
_fmpz_vec_get_nmod_vec(ym, ys, n, mod);
/* Interpolate mod p; the coefficients overwrite ym in place. */
if (_checked_nmod_poly_interpolate(ym, xm, ym, n, mod))
{
num_primes = 1;
if (total_primes == 0)
{
/* First good prime: lift the residues directly. */
_fmpz_vec_set_nmod_vec(poly, ym, n, mod);
fmpz_set_ui(M, p);
}
else
{
/* Merge the new image into poly (mod M) to get poly mod M*p. */
_fmpz_poly_CRT_ui(poly, poly, n, M, ym, n, mod.n, mod.ninv, 1);
fmpz_mul_ui(M, M, p);
}
}
else
{
/* Interpolation failed mod p; this prime contributes nothing. */
num_primes = 0;
}
}
else
{
/* Batch phase: process total_primes / 2 new primes at once, so the
   number of primes (and bits of M) grows geometrically per round. */
fmpz_comb_t comb;
fmpz_comb_temp_t temp;
num_primes = FLINT_MAX(1, total_primes / 2);
primes = flint_realloc(primes, sizeof(ulong) * num_primes);
xmod = flint_realloc(xmod, sizeof(ulong) * n * num_primes);
ymod = flint_realloc(ymod, sizeof(ulong) * n * num_primes);
residues = flint_realloc(residues, sizeof(ulong) * num_primes);
good = flint_realloc(good, sizeof(int) * num_primes);
for (k = 0; k < num_primes; k++)
{
p = n_nextprime(p, 1);
primes[k] = p;
}
/* M2 = product of this batch's primes. */
_fmpz_ui_vec_prod(M2, primes, num_primes);
fmpz_comb_init(comb, primes, num_primes);
fmpz_comb_temp_init(temp, comb);
/* Reduce every x and y modulo all batch primes. The residues of
   point j modulo primes[k] are stored at index k * n + j, so each
   prime's n values are contiguous. */
for (j = 0; j < n; j++)
{
fmpz_multi_mod_ui(residues, xs + j, comb, temp);
for (k = 0; k < num_primes; k++)
xmod[k * n + j] = residues[k];
fmpz_multi_mod_ui(residues, ys + j, comb, temp);
for (k = 0; k < num_primes; k++)
ymod[k * n + j] = residues[k];
}
/* Interpolate modulo each prime, in place in its ymod slice. */
count_good = 0;
for (k = 0; k < num_primes; k++)
{
nmod_init(&mod, primes[k]);
good[k] = _checked_nmod_poly_interpolate(ymod + k * n, xmod + k * n, ymod + k * n, n, mod);
count_good += (good[k] != 0);
}
if (count_good < num_primes)
{
/* Drop the primes that failed: compact primes[] and the ymod
   slices to the front, then rebuild M2 and the comb. */
count_good = 0;
for (k = 0; k < num_primes; k++)
{
if (good[k])
{
primes[count_good] = primes[k];
if (count_good != k)
_nmod_vec_set(ymod + count_good * n, ymod + k * n, n);
count_good++;
}
}
num_primes = count_good;
if (num_primes != 0)
{
_fmpz_ui_vec_prod(M2, primes, num_primes);
fmpz_comb_temp_clear(temp);
fmpz_comb_clear(comb);
fmpz_comb_init(comb, primes, num_primes);
fmpz_comb_temp_init(temp, comb);
}
/* If num_primes is 0, the original comb/temp stay alive and are
   cleared below. */
}
if (num_primes != 0)
{
/* Combine poly (mod M) with the batch result (mod M2) to get the
   result mod M1M2 = M * M2. c is (M mod M2)^-1 mod M2, which is
   passed to _fmpz_CRT. */
fmpz_mul(M1M2, M, M2);
fmpz_mod(c, M, M2);
fmpz_invmod(c, c, M2);
for (j = 0; j < n; j++)
{
/* t = coefficient j reconstructed mod M2 from the batch. */
for (k = 0; k < num_primes; k++)
residues[k] = ymod[k * n + j];
fmpz_multi_CRT_ui(t, residues, comb, temp, 0);
_fmpz_CRT(u, poly + j, M, t, M2, M1M2, c, 1);
fmpz_set(poly + j, u);
}
/* M <- M * M2 (M1M2 becomes scratch for the next round). */
fmpz_swap(M, M1M2);
}
fmpz_comb_temp_clear(temp);
fmpz_comb_clear(comb);
}
total_primes += num_primes;
if (num_primes == 0 && !checked_unique)
{
/* No prime succeeded this round. Check once whether xs contains
   duplicates; if it does, no interpolant exists, so fail. */
if (!_fmpz_vec_has_unique_entries(xs, n))
{
ok = 0;
break;
}
checked_unique = 1;
}
else
{
/* Accept the candidate if it reproduces every ys exactly.
   NOTE(review): this also runs when num_primes == 0 after uniqueness
   was already checked. In that case poly is unchanged from the
   previous round, so the O(n^2) evaluation is redundant. */
if (_fmpz_poly_check_interpolant(poly, xs, ys, n))
break;
}
/* The modulus is past the bound with no exact match: give up. */
if ((slong) fmpz_bits(M) > bound_bits + 1)
{
ok = 0;
break;
}
}
fmpz_clear(M);
fmpz_clear(t);
fmpz_clear(u);
fmpz_clear(c);
fmpz_clear(M2);
fmpz_clear(M1M2);
_nmod_vec_clear(xm);
_nmod_vec_clear(ym);
flint_free(primes);
flint_free(xmod);
flint_free(ymod);
flint_free(residues);
flint_free(good);
return ok;
}
/*
    Sets poly to the integer-coefficient polynomial of length at most n that
    interpolates the points (xs[i], ys[i]) for i in [0, n).

    Returns the status of _fmpz_poly_interpolate_multi_mod: 1 on success
    and 0 on failure. The result is normalised in both cases. On failure,
    its coefficients are whatever the internal routine left behind.
*/
int
fmpz_poly_interpolate_multi_mod(fmpz_poly_t poly,
                                const fmpz * xs, const fmpz * ys, slong n)
{
    int success;

    /* Make room for n coefficients before handing down the raw buffer. */
    fmpz_poly_fit_length(poly, n);
    success = _fmpz_poly_interpolate_multi_mod(poly->coeffs, xs, ys, n);

    /* Expose all n slots, then strip any leading zero coefficients. */
    _fmpz_poly_set_length(poly, n);
    _fmpz_poly_normalise(poly);

    return success;
}