#include "radix.h"
#include "thread_pool.h"
#include "thread_support.h"
/* Build a table of powers of the single-limb value b.

   The exponent chain starts at exps[0] = max_exp and is repeatedly halved
   (rounding up) until it reaches 1. For every entry i, pows[i] points to
   b^exps[i] occupying sizes[i] limbs; val_limbs[i] is always set to 0 here.
   All powers share one heap allocation, recorded in powers->buf.

   Requires b != 0, max_exp >= 1 and b < LIMB_RADIX(radix). */
static void
radix_powers_init_ui_radix(radix_powers_t powers, ulong b, ulong max_exp, const radix_t radix)
{
    slong i, count, size, total_limbs;
    nn_ptr cur, next_free;

    FLINT_ASSERT(b != 0);
    FLINT_ASSERT(max_exp >= 1);
    FLINT_ASSERT(b < LIMB_RADIX(radix));

    /* Lay out the exponent chain and reserve exps[i] + 1 limbs per entry. */
    powers->exps[0] = max_exp;
    total_limbs = max_exp + 1;
    for (count = 1; powers->exps[count - 1] > 1; count++)
    {
        powers->exps[count] = (powers->exps[count - 1] + 1) / 2;
        total_limbs += powers->exps[count] + 1;
    }
    powers->len = count;

    next_free = flint_malloc(total_limbs * sizeof(ulong));
    powers->buf = next_free;

    /* The last entry has exponent 1, i.e. it is b itself. */
    i = count - 1;
    powers->pows[i] = next_free;
    powers->pows[i][0] = b;
    powers->sizes[i] = 1;
    powers->val_limbs[i] = 0;
    next_free++;

    /* Work back towards entry 0: square the next-smaller power and, when
       the target exponent is odd (2 * exps[i + 1] overshoots by one),
       divide one factor of b back out. */
    for (i = count - 2; i >= 0; i--)
    {
        cur = next_free;
        powers->pows[i] = cur;

        size = powers->sizes[i + 1];
        radix_sqr(cur, powers->pows[i + 1], size, radix);
        size *= 2;
        if (cur[size - 1] == 0)
            size--;

        if (powers->exps[i] != 2 * powers->exps[i + 1])
        {
            radix_divexact_1(cur, cur, size, b, radix);
            if (cur[size - 1] == 0)
                size--;
        }

        powers->sizes[i] = size;
        powers->val_limbs[i] = 0;

        /* Entries are packed: the next power starts right after this one. */
        next_free += size;
    }
}
/* Extract the bit range [start, stop) of x into res, so that bit `start`
   of x becomes bit 0 of res. Exactly ceil((stop - start) / FLINT_BITS)
   limbs of res are written, and the bits of the top limb beyond the
   requested range are cleared. Requires start < stop. */
static void
flint_mpn_slice_bits(nn_ptr res, nn_srcptr x, mp_bitcnt_t start, mp_bitcnt_t stop)
{
    FLINT_ASSERT(start >= 0);
    FLINT_ASSERT(start < stop);

    slong nlimbs = (stop - start + FLINT_BITS - 1) / FLINT_BITS;
    slong first = start / FLINT_BITS;
    slong last = stop / FLINT_BITS;
    slong shift = start % FLINT_BITS;
    slong tail = (stop - start) % FLINT_BITS;
    nn_srcptr src = x + first;

    if (shift != 0)
    {
        mpn_rshift(res, src, nlimbs, shift);

        /* The shift leaves the top `shift` bits of the result empty. If the
           range reaches into one more source limb, pull its low bits in. */
        if (last == first + nlimbs && last * FLINT_BITS < stop)
            res[nlimbs - 1] |= (x[last] << (FLINT_BITS - shift));
    }
    else
    {
        /* Limb-aligned start: a plain copy suffices. */
        mpn_copyi(res, src, nlimbs);
    }

    /* Drop any bits at or above position stop - start. */
    if (tail != 0)
        res[nlimbs - 1] &= ((UWORD(1) << tail) - 1);
}
/* Convert the an-limb binary integer a to limbs in radix B = LIMB_RADIX(radix)
   by repeated division by B, writing the remainders to res from least
   significant to most significant.

   Leading zero limbs of a are ignored. Returns the number of limbs written
   to res (0 when a is zero). The input a is not modified; the divisions are
   done on a temporary copy. */
slong
radix_set_mpn_basecase(nn_ptr res, nn_srcptr a, slong an, const radix_t radix)
{
    ulong B = LIMB_RADIX(radix);
    slong ndigits = 0;
    nn_ptr work;
    TMP_INIT;

    for ( ; an > 0 && a[an - 1] == 0; an--)
        ;

    if (an == 0)
        return 0;

    /* A single limb already smaller than B is its own representation. */
    if (an == 1 && a[0] < B)
    {
        res[0] = a[0];
        return 1;
    }

    TMP_START;
    work = TMP_ALLOC(an * sizeof(ulong));
    flint_mpn_copyi(work, a, an);

    /* Each division yields one output limb (the remainder) and leaves the
       quotient in place. Above 12 limbs the GMP routine is used, below that
       the variant with the precomputed inverse stored in radix->B. */
    while (an != 0)
    {
        if (an > 12)
            res[ndigits] = mpn_divrem_1(work, 0, work, an, B);
        else
            res[ndigits] = flint_mpn_divrem_1_preinv(work, work, an, radix->B.n, radix->B.ninv, radix->B.norm);

        ndigits++;

        if (work[an - 1] == 0)
            an--;
    }

    TMP_END;
    return ndigits;
}
/* Basecase conversion of a slice of a. The slice is given in units of
   abits bits: it covers bits [astart * abits, astop * abits) of a, with the
   upper end clamped to the an_words * FLINT_BITS bits that a actually has.
   The slice is copied to a temporary and handed to radix_set_mpn_basecase.
   Returns the number of limbs written to res. */
static slong
radix_set_mpn_basecase1(nn_ptr res, nn_srcptr a, slong an_words, slong abits, slong astart, slong astop, const radix_t radix)
{
    slong bit_lo, bit_hi, nlimbs, len;
    nn_ptr chunk;
    TMP_INIT;

    bit_lo = astart * abits;
    bit_hi = FLINT_MIN(astop * abits, an_words * FLINT_BITS);
    nlimbs = (bit_hi - bit_lo + FLINT_BITS - 1) / FLINT_BITS;

    TMP_START;
    chunk = TMP_ALLOC(nlimbs * sizeof(ulong));

    flint_mpn_slice_bits(chunk, a, bit_lo, bit_hi);
    len = radix_set_mpn_basecase(res, chunk, nlimbs, radix);

    TMP_END;
    return len;
}
/* Forward declaration: worker() below calls this function, and the function
   itself dispatches worker() to a pool thread. */
static slong
_radix_set_mpn_recursive(nn_ptr res, nn_srcptr a, slong an_words, slong abits, slong astart, slong astop, const radix_powers_t powers, slong depth, const radix_t radix);
/* Argument bundle for worker(): one field per parameter of
   _radix_set_mpn_recursive, in the same order, plus a slot for its
   return value. */
typedef struct
{
nn_ptr res; /* output buffer for this subproblem */
nn_srcptr a; /* full binary input */
slong an_words; /* number of limbs in a */
slong abits; /* bits per chunk of the input */
slong astart; /* first chunk index of the subproblem */
slong astop; /* one past the last chunk index of the subproblem */
const radix_powers_struct * powers; /* precomputed power table */
slong depth; /* index into the power table to start from */
const radix_struct * radix; /* radix context */
slong out_len; /* filled in by worker() with the result length */
}
worker_args_struct;
/* Thread entry point: unpack a worker_args_struct, run the recursive
   conversion on it and store the returned length in out_len. */
static void
worker(void * arg)
{
    worker_args_struct * job = arg;

    job->out_len = _radix_set_mpn_recursive(job->res, job->a,
                        job->an_words, job->abits,
                        job->astart, job->astop,
                        job->powers, job->depth, job->radix);
}
/* Divide-and-conquer conversion of chunks [astart, astop) of a to radix
   limbs in res, where a chunk is abits bits of the binary input.

   The range is split at an_lo = powers->exps[depth] chunks (after advancing
   depth until the exponent is at most half the range, rounded up). The low
   and high parts are converted separately, possibly on two threads, and
   recombined as res = hi * pows[depth] + lo. Returns the number of limbs
   written to res.

   an_words is the limb count of a; it is only used to clamp the bit range
   in the basecase. */
static slong
_radix_set_mpn_recursive(nn_ptr res, nn_srcptr a, slong an_words, slong abits, slong astart, slong astop, const radix_powers_t powers, slong depth, const radix_t radix)
{
    slong an, an_lo;
    nn_ptr lo, hi;
    slong len_lo, len_hi, len_pow, len_res;
    slong nworkers = 0, nthreads = 1, nworkers_save;
    int want_workers;
    /* Must start out NULL so the release below is a no-op when no request
       was made. */
    thread_pool_handle * threads = NULL;
    ulong cy;
    TMP_INIT;

    an = astop - astart;

    /* Small ranges go straight to the quadratic basecase. */
    if (an < 25)
        return radix_set_mpn_basecase1(res, a, an_words, abits, astart, astop, radix);

    /* Pick the largest tabulated power not exceeding half the range. */
    while (powers->exps[depth] > (an + 1) / 2)
        depth++;

    an_lo = powers->exps[depth];

    /* One scratch block holds both halves; each half gets one spare limb. */
    TMP_START;
    lo = TMP_ALLOC((an + 2) * sizeof(ulong));
    hi = lo + an_lo + 1;

    /* Only consider threading for sufficiently large ranges. */
    if (an > 500)
    {
        nthreads = flint_get_num_threads();
        want_workers = nthreads >= 2 && (an_lo <= 100000000 || depth >= 2);
        nworkers = flint_request_threads(&threads, want_workers ? 2 : 1);
    }

    if (nworkers == 1)
    {
        worker_args_struct work_lo = { lo, a, an_words, abits, astart, astart + an_lo, powers, depth, radix, 0 };
        worker_args_struct work_hi = { hi, a, an_words, abits, astart + an_lo, astop, powers, depth, radix, 0 };

        /* Split the thread budget between the helper thread (low half) and
           this thread (high half). */
        nworkers_save = flint_set_num_workers(nthreads - nthreads / 2 - 1);
        thread_pool_wake(global_thread_pool, threads[0], nthreads / 2 - 1, worker, &work_lo);
        worker(&work_hi);
        flint_reset_num_workers(nworkers_save);
        thread_pool_wait(global_thread_pool, threads[0]);

        len_lo = work_lo.out_len;
        len_hi = work_hi.out_len;
    }
    else
    {
        len_lo = _radix_set_mpn_recursive(lo, a, an_words, abits, astart, astart + an_lo, powers, depth, radix);
        len_hi = _radix_set_mpn_recursive(hi, a, an_words, abits, astart + an_lo, astop, powers, depth, radix);
    }

    /* flint_request_threads may allocate the handle array even when it
       hands out zero workers, so release it on every path, not only when
       a worker was actually used. */
    if (threads != NULL)
        flint_give_back_threads(threads, nworkers);

    if (len_hi == 0)
    {
        /* High half is zero: the result is just the low half. */
        flint_mpn_copyi(res, lo, len_lo);
        len_res = len_lo;
    }
    else
    {
        /* The low val limbs of the power are skipped in the product and
           filled in directly from lo below. */
        slong val = powers->val_limbs[depth];

        len_pow = powers->sizes[depth] - val;
        radix_mul(res + val, hi, len_hi, powers->pows[depth] + val, len_pow, radix);
        len_res = len_pow + len_hi + val;
        len_res -= (res[len_res - 1] == 0);

        FLINT_ASSERT(len_res >= len_lo);

        if (len_lo <= val)
        {
            /* lo fits entirely below the product: copy and zero-pad. */
            flint_mpn_copyi(res, lo, len_lo);
            flint_mpn_zero(res + len_lo, val - len_lo);
        }
        else
        {
            /* Copy the part of lo below the product, add the rest. */
            flint_mpn_copyi(res, lo, val);
            cy = radix_add(res + val, res + val, len_res - val, lo + val, len_lo - val, radix);
            if (cy != 0)
            {
                res[len_res] = cy;
                len_res++;
            }
        }
    }

    TMP_END;
    return len_res;
}
/* Convert the an-limb binary integer a to limbs in radix B = LIMB_RADIX(radix)
   using divide and conquer. Returns the number of limbs written to res
   (0 when a is zero). res must not alias a.

   The input is cut into chunks of abits bits, where 2^abits < B, so every
   chunk is a valid single digit; powers of 2^abits are precomputed in radix
   representation and used to recombine the converted halves. */
slong
radix_set_mpn_divconquer(nn_ptr res, nn_srcptr a, slong an, const radix_t radix)
{
    slong abits, abits_limbs, e, rn;
    radix_powers_t powers;
    ulong B = LIMB_RADIX(radix);

    FLINT_ASSERT(res != a);

    /* Strip leading zero limbs for every radix, not only B == 2: this keeps
       the power table no larger than necessary and ensures the recursion is
       never entered with an empty bit range. */
    while (an > 0 && a[an - 1] == 0)
        an--;

    if (an == 0)
        return 0;

    if (B == 2)
    {
        /* Radix 2: each output limb is a single bit of the input. (The
           general path cannot be used since abits would be 0.) */
        slong i, j;

        for (i = 0; i < an; i++)
        {
            ulong c = a[i];

            for (j = 0; j < FLINT_BITS; j++)
                res[i * FLINT_BITS + j] = ((c >> j) & UWORD(1));
        }

        rn = an * FLINT_BITS;
        while (rn > 0 && res[rn - 1] == 0)
            rn--;

        return rn;
    }

    /* Largest chunk width with 2^abits strictly below B. */
    abits = FLINT_BIT_COUNT(B - 1) - 1;
    /* Number of abits-wide chunks needed to cover all input bits. */
    abits_limbs = (an * FLINT_BITS + abits - 1) / abits;

    /* The top-level split is at half the chunk count, rounded up. */
    e = FLINT_MAX((abits_limbs + 1) / 2, 1);
    radix_powers_init_ui_radix(powers, UWORD(1) << abits, e, radix);

    rn = _radix_set_mpn_recursive(res, a, an, abits, 0, abits_limbs, powers, 0, radix);

    radix_powers_clear(powers);
    return rn;
}
/* Convert the an-limb binary integer a to limbs in the given radix, writing
   them to res and returning how many were written. Inputs of fewer than 40
   limbs use the basecase; larger ones use divide and conquer.
   res must not alias a. */
slong
radix_set_mpn(nn_ptr res, nn_srcptr a, slong an, const radix_t radix)
{
    FLINT_ASSERT(res != a);

    if (an >= 40)
        return radix_set_mpn_divconquer(res, a, an, radix);

    return radix_set_mpn_basecase(res, a, an, radix);
}
/* Limb count associated with an n-limb binary input: 0 when n is 0,
   otherwise ceil(n * FLINT_BITS / (NMOD_BITS(radix->B) - 1)) + 1.
   NOTE(review): presumably the output buffer size callers must provide to
   radix_set_mpn -- confirm against the header documentation. */
slong
radix_set_mpn_need_alloc(slong n, const radix_t radix)
{
    slong digit_bits;

    if (n == 0)
        return 0;

    digit_bits = NMOD_BITS(radix->B) - 1;

    return (n * FLINT_BITS + digit_bits - 1) / digit_bits + 1;
}