#include <stdlib.h>
#include "fmpz.h"
#include "gr.h"
#include "gr_poly.h"
#include "mpn_mod.h"
#include "profiler.h"
#include "double_extras.h"
static void
mpn_mod_set_mpn2(nn_ptr res, nn_srcptr s, slong l, gr_ctx_t ctx)
{
    /* Reduce the l-limb integer {s, l} into res.  High zero limbs are
       dropped first so mpn_mod_set_mpn is handed a normalized length. */
    slong len = l;

    while (len != 0 && s[len - 1] == 0)
        len--;

    mpn_mod_set_mpn(res, s, len, ctx);
}
/* Set (R2, R1, R0) to the low three limbs of the two-limb by two-limb product
   (a1*W + a0) * (b1*W + b0), W being the limb base; i.e. the product mod W^3.
   Anything at or above W^3 is discarded (only the low limb of a1*b1 is kept),
   so the result is the exact product only when that product fits in three
   limbs.  R0/R1/R2 are written before all inputs have been read, so the
   outputs must not alias the inputs; arguments may be evaluated repeatedly.
   The definition must end at "while (0)" with no line continuation, otherwise
   the next source line is spliced into the macro body. */
#define FLINT_MPN_MUL_3_2X2(R2, R1, R0, a1, a0, b1, b0)     \
    do                                                     \
    {                                                      \
        ulong __tmp2, __tmp1;                              \
        umul_ppmm(R1, R0, a0, b0);                         \
        (R2) = (a1) * (b1);                                \
        umul_ppmm(__tmp2, __tmp1, a0, b1);                 \
        add_ssaaaa(R2, R1, R2, R1, __tmp2, __tmp1);        \
        umul_ppmm(__tmp2, __tmp1, a1, b0);                 \
        add_ssaaaa(R2, R1, R2, R1, __tmp2, __tmp1);        \
    }                                                      \
    while (0)
/* Reference ("old") division with remainder of A by B for a two-coefficient
   quotient.  The quotient is derived from A[lenA - 1] and A[lenA - 2] only,
   so this is written for lenA == lenB + 1 (the only case best_table uses).
   Q receives the two quotient coefficients, R the lenB - 1 remainder
   coefficients.  invL is the inverse of B's leading coefficient (best_table
   computes it with mpn_mod_inv); if it equals one, the multiplications by
   invL are skipped.  For R[1 .. lenB - 2] the sum is accumulated unreduced
   in t and reduced once per coefficient by mpn_mod_set_mpn2.
   Always returns GR_SUCCESS.
   NOTE(review): B[lenB - 2] is read unconditionally, so lenB >= 2 is
   assumed — confirm for any other callers. */
int _mpn_mod_poly_divrem_q1_preinv1_old(nn_ptr Q, nn_ptr R,
nn_srcptr A, slong lenA, nn_srcptr B, slong lenB,
nn_srcptr invL, gr_ctx_t ctx)
{
ulong q0[MPN_MOD_MAX_LIMBS];
ulong q1[MPN_MOD_MAX_LIMBS];
/* Scratch for an unreduced coefficient sum: up to 2 * nlimbs + 1 limbs. */
ulong t[2 * MPN_MOD_MAX_LIMBS + 1];
ulong u[2 * MPN_MOD_MAX_LIMBS];
slong i;
slong nlimbs = MPN_MOD_CTX_NLIMBS(ctx);
/* invL == 1: multiplying by invL would be a no-op, so copy instead. */
int monic = mpn_mod_is_one(invL, ctx) == T_TRUE;
/* q1 = A[lenA - 1] * invL, the leading quotient coefficient. */
if (monic)
mpn_mod_set(q1, A + (lenA - 1) * nlimbs, ctx);
else
mpn_mod_mul(q1, A + (lenA - 1) * nlimbs, invL, ctx);
/* q0 = (q1 * B[lenB - 2] - A[lenA - 2]) * invL: this is the NEGATED low
   quotient coefficient (Q[0] is set to -q0 below). */
mpn_mod_mul(t, q1, B + (lenB - 2) * nlimbs, ctx);
mpn_mod_sub(t, t, A + (lenA - 2) * nlimbs, ctx);
if (monic)
mpn_mod_set(q0, t, ctx);
else
mpn_mod_mul(q0, t, invL, ctx);
/* R[0] = A[0] + q0 * B[0] = A[0] - Q[0] * B[0]. */
mpn_mod_mul(t, q0, B, ctx);
mpn_mod_add(R, A, t, ctx);
/* Q = -q0 + q1 * x. */
mpn_mod_neg(Q, q0, ctx);
mpn_mod_set(Q + nlimbs, q1, ctx);
/* Negate q1 so every remaining coefficient is a plain sum of products:
   R[i] = A[i] + q1 * B[i - 1] + q0 * B[i]  for 1 <= i < lenB - 1. */
mpn_mod_neg(q1, q1, ctx);
if (nlimbs == 2)
{
/* Two-limb modulus: reserve 2 * modulus_bits + 1 bits (slimbs limbs) for
   the unreduced sum. */
slong bits = 2 * MPN_MOD_CTX_MODULUS_BITS(ctx) + 1;
slong slimbs = (bits + FLINT_BITS - 1) / FLINT_BITS;
if (slimbs == 3)
{
/* Three-limb budget: the truncated (mod W^3) products of
   FLINT_MPN_MUL_3_2X2 are used, with 3-limb additions. */
for (i = 1; i < lenB - 1; i++)
{
nn_srcptr B1ptr = B + (i - 1) * nlimbs;
nn_srcptr Bptr = B + i * nlimbs;
nn_srcptr Aptr = A + i * nlimbs;
FLINT_MPN_MUL_3_2X2(t[2], t[1], t[0], q1[1], q1[0], B1ptr[1], B1ptr[0]);
add_sssaaaaaa(t[2], t[1], t[0], t[2], t[1], t[0], 0, Aptr[1], Aptr[0]);
FLINT_MPN_MUL_3_2X2(u[2], u[1], u[0], q0[1], q0[0], Bptr[1], Bptr[0]);
add_sssaaaaaa(t[2], t[1], t[0], t[2], t[1], t[0], u[2], u[1], u[0]);
mpn_mod_set_mpn2(R + i * nlimbs, t, slimbs, ctx);
}
}
else
{
/* Larger budget: full 4-limb products; the final addition collects
   its carry in t[4], and slimbs of t[0 .. 4] are reduced. */
for (i = 1; i < lenB - 1; i++)
{
nn_srcptr B1ptr = B + (i - 1) * nlimbs;
nn_srcptr Bptr = B + i * nlimbs;
nn_srcptr Aptr = A + i * nlimbs;
FLINT_MPN_MUL_2X2(t[3], t[2], t[1], t[0], q1[1], q1[0], B1ptr[1], B1ptr[0]);
add_ssssaaaaaaaa(t[3], t[2], t[1], t[0], t[3], t[2], t[1], t[0], 0, 0, Aptr[1], Aptr[0]);
FLINT_MPN_MUL_2X2(u[3], u[2], u[1], u[0], q0[1], q0[0], Bptr[1], Bptr[0]);
add_sssssaaaaaaaaaa(t[4], t[3], t[2], t[1], t[0], 0, t[3], t[2], t[1], t[0], 0, u[3], u[2], u[1], u[0]);
mpn_mod_set_mpn2(R + i * nlimbs, t, slimbs, ctx);
}
}
}
else
{
/* Generic size: t[0 .. 2n] = q1 * B[i - 1] + q0 * B[i] + A[i], where the
   carry of the product sum lands in t[2n] and A[i]'s carry is propagated
   through the upper n + 1 limbs. */
for (i = 1; i < lenB - 1; i++)
{
flint_mpn_mul_n(t, q1, B + (i - 1) * nlimbs, nlimbs);
flint_mpn_mul_n(u, q0, B + i * nlimbs, nlimbs);
t[2 * nlimbs] = mpn_add_n(t, t, u, 2 * nlimbs);
ulong cy = mpn_add_n(t, t, A + i * nlimbs, nlimbs);
mpn_add_1(t + nlimbs, t + nlimbs, nlimbs + 1, cy);
mpn_mod_set_mpn2(R + i * nlimbs, t, 2 * nlimbs + 1, ctx);
}
}
return GR_SUCCESS;
}
/* Same inputs/outputs as _mpn_mod_poly_divrem_q1_preinv1_old, but the
   remainder coefficients R[i] = A[i] + q1 * B[i - 1] + q0 * B[i] (q0, q1
   being the negated quotient coefficients) are computed with Karatsuba-style
   sharing, all arithmetic reduced mod the modulus:
     odd i:  R[i] = A[i] - q0*B[i-1] + (q0+q1)(B[i-1]+B[i]) - q1*B[i]
     even i: R[i] = A[i] + q1*B[i-1] + q0*B[i]
   where q0*B[i-1] (odd i) and q1*B[i-1] (even i) are the products left in t
   by the previous step.  Each pair of coefficients therefore costs three
   modular multiplications instead of four.  Always returns GR_SUCCESS. */
int _mpn_mod_poly_divrem_q1_preinv1_karatsuba(nn_ptr Q, nn_ptr R,
nn_srcptr A, slong lenA, nn_srcptr B, slong lenB,
nn_srcptr invL, gr_ctx_t ctx)
{
ulong q0[MPN_MOD_MAX_LIMBS];
ulong q1[MPN_MOD_MAX_LIMBS];
ulong q0q1[MPN_MOD_MAX_LIMBS];
ulong t[2 * MPN_MOD_MAX_LIMBS + 1];
ulong u[2 * MPN_MOD_MAX_LIMBS];
slong i;
slong nlimbs = MPN_MOD_CTX_NLIMBS(ctx);
int monic = mpn_mod_is_one(invL, ctx) == T_TRUE;
/* q1 = A[lenA - 1] * invL; q0 = (q1 * B[lenB - 2] - A[lenA - 2]) * invL
   (the negated low quotient coefficient), exactly as in the old variant. */
if (monic)
mpn_mod_set(q1, A + (lenA - 1) * nlimbs, ctx);
else
mpn_mod_mul(q1, A + (lenA - 1) * nlimbs, invL, ctx);
mpn_mod_mul(t, q1, B + (lenB - 2) * nlimbs, ctx);
mpn_mod_sub(t, t, A + (lenA - 2) * nlimbs, ctx);
if (monic)
mpn_mod_set(q0, t, ctx);
else
mpn_mod_mul(q0, t, invL, ctx);
/* t = q0 * B[0]: used for R[0] here and kept as the carried product for
   the first (odd, i = 1) loop step. */
mpn_mod_mul(t, q0, B, ctx);
mpn_mod_add(R, A, t, ctx);
/* Q = -q0 + q1 * x. */
mpn_mod_neg(Q, q0, ctx);
mpn_mod_set(Q + nlimbs, q1, ctx);
mpn_mod_neg(q1, q1, ctx);
/* Middle Karatsuba factor q0 + q1 (with q1 already negated). */
mpn_mod_add(q0q1, q0, q1, ctx);
nn_srcptr d = MPN_MOD_CTX_MODULUS(ctx);
/* Loop invariant: entering an odd i, t == q0 * B[i - 1];
   entering an even i, t == q1 * B[i - 1]. */
for (i = 1; i < lenB - 1; i++)
{
if (i % 2 == 1)
{
flint_mpn_submod_n(R + i * nlimbs, A + i * nlimbs, t, d, nlimbs);
flint_mpn_addmod_n(u, B + (i - 1) * nlimbs, B + i * nlimbs, d, nlimbs);
mpn_mod_mul(t, q0q1, u, ctx);
flint_mpn_addmod_n(R + i * nlimbs, R + i * nlimbs, t, d, nlimbs);
/* t = q1 * B[i]: subtracted here and reused by the next (even) step. */
mpn_mod_mul(t, q1, B + i * nlimbs, ctx);
flint_mpn_submod_n(R + i * nlimbs, R + i * nlimbs, t, d, nlimbs);
}
else
{
flint_mpn_addmod_n(R + i * nlimbs, A + i * nlimbs, t, d, nlimbs);
/* t = q0 * B[i]: added here and reused by the next (odd) step. */
mpn_mod_mul(t, q0, B + i * nlimbs, ctx);
flint_mpn_addmod_n(R + i * nlimbs, R + i * nlimbs, t, d, nlimbs);
}
}
return GR_SUCCESS;
}
/* Local replacement for profiler.h's TIMEIT_END_REPEAT (change the #if to 0
   to keep the library definition).  It closes the repetition loops opened
   by the TIMEIT_START machinery: after each batch the timer is stopped, the
   measurement ends once __timer->cpu reaches 20, and otherwise the number of
   repetitions is multiplied by 10 and the batch is rerun.
   NOTE(review): presumably a lower cut-off than profiler.h's default so the
   whole table runs faster, with __timer->cpu in milliseconds — confirm
   against profiler.h. */
#if 1
#undef TIMEIT_END_REPEAT
#define TIMEIT_END_REPEAT(__timer, __reps) \
} \
timeit_stop(__timer); \
if (__timer->cpu >= 20) \
break; \
__reps *= 10; \
} \
} while (0)
#endif
slong parameters[] = { 2, 3, 4, 6, 8, 10, 12, 14, 16, 20, 24, 32, 48, 64, 96, 128, 0 };
/* Fill the len entries of vec with pseudo-random elements of ctx.  Each one
   is drawn with fmpz_randbits using 10 bits more than the modulus and then
   converted by gr_set_fmpz; the conversion status is deliberately ignored. */
void
randvec(gr_ptr vec, flint_rand_t state, slong len, gr_ctx_t ctx)
{
    slong nbits = MPN_MOD_CTX_MODULUS_BITS(ctx) + 10;
    slong k = 0;
    fmpz_t c;

    fmpz_init(c);
    while (k < len)
    {
        fmpz_randbits(c, state, nbits);
        GR_IGNORE(gr_set_fmpz(GR_ENTRY(vec, k, ctx->sizeof_elem), c, ctx));
        k++;
    }
    fmpz_clear(c);
}
/* Variant indices, in the order best_table times the variants into
   times[0] .. times[4]. */
enum
{
    OLD = 0,
    FMMA = 1,
    FMMA_PRECOND = 2,
    KARATSUBA = 3,
    KARATSUBA_PRECOND = 4
};

/* Values for best_table's `comparison` argument (what it prints per length). */
enum
{
    FIND_BEST = 0,
    BEST_VS_DEFAULT = 1,
    OLD_VS_DEFAULT = 2
};
/* Benchmark the q1 divrem variants for every divisor length in parameters[]
   and print one table row for ctx (without a trailing newline).

   For each len, A gets len + 1 random coefficients and B gets len; invL is
   the inverse of B's leading coefficient.  times[0..4] receive the timings
   of the old, fmma, fmma_precond, karatsuba and karatsuba_precond variants
   and times[5] that of the default _mpn_mod_poly_divrem_q1_preinv1.  Each
   slot holds the second value reported by TIMEIT_STOP_VALUES
   (NOTE(review): presumably per-call wall time — confirm in profiler.h).

   comparison selects what is printed per length:
     FIND_BEST        index (0..4) of the fastest non-default variant
     BEST_VS_DEFAULT  fastest variant's time / default time
     OLD_VS_DEFAULT   old variant's time / default time */
void
best_table(flint_rand_t state, int comparison, gr_ctx_t ctx)
{
    gr_ptr A, B, Q, R, invL;
    double times[6];
    double t_ignored;   /* first TIMEIT_STOP_VALUES output, not used */
    slong lenA, lenB, len, leni;
    int best, i;

    for (leni = 0; (len = parameters[leni]) != 0; leni++)
    {
        lenA = len + 1;
        lenB = len;

        A = gr_heap_init_vec(lenA, ctx);
        B = gr_heap_init_vec(lenB, ctx);
        Q = gr_heap_init_vec(2, ctx);
        R = gr_heap_init_vec(lenB - 1, ctx);
        invL = gr_heap_init_vec(1, ctx);

        randvec(A, state, lenA, ctx);
        randvec(B, state, lenB, ctx);
        GR_MUST_SUCCEED(mpn_mod_inv(invL, GR_ENTRY(B, lenB - 1, ctx->sizeof_elem), ctx));

        TIMEIT_START;
        GR_MUST_SUCCEED(_mpn_mod_poly_divrem_q1_preinv1_old(Q, R, A, lenA, B, lenB, invL, ctx));
        TIMEIT_STOP_VALUES(t_ignored, times[0]);

        TIMEIT_START;
        GR_MUST_SUCCEED(_mpn_mod_poly_divrem_q1_preinv1_fmma(Q, R, A, lenA, B, lenB, invL, ctx));
        TIMEIT_STOP_VALUES(t_ignored, times[1]);

        TIMEIT_START;
        GR_MUST_SUCCEED(_mpn_mod_poly_divrem_q1_preinv1_fmma_precond(Q, R, A, lenA, B, lenB, invL, ctx));
        TIMEIT_STOP_VALUES(t_ignored, times[2]);

        TIMEIT_START;
        GR_MUST_SUCCEED(_mpn_mod_poly_divrem_q1_preinv1_karatsuba(Q, R, A, lenA, B, lenB, invL, ctx));
        TIMEIT_STOP_VALUES(t_ignored, times[3]);

        TIMEIT_START;
        GR_MUST_SUCCEED(_mpn_mod_poly_divrem_q1_preinv1_karatsuba_precond(Q, R, A, lenA, B, lenB, invL, ctx));
        TIMEIT_STOP_VALUES(t_ignored, times[4]);

        TIMEIT_START;
        GR_MUST_SUCCEED(_mpn_mod_poly_divrem_q1_preinv1(Q, R, A, lenA, B, lenB, invL, ctx));
        TIMEIT_STOP_VALUES(t_ignored, times[5]);

        /* Fastest of the five explicit variants; the default (times[5]) is
           the reference, not a candidate. */
        best = 0;
        for (i = 1; i < 5; i++)
        {
            if (times[i] < times[best])
                best = i;
        }

        if (comparison == FIND_BEST)
        {
            /* %wd reads an slong, so widen the int explicitly. */
            flint_printf("%6wd", (slong) best);
        }
        else if (comparison == BEST_VS_DEFAULT)
        {
            flint_printf(" %.3f", times[best] / times[5]);
        }
        else if (comparison == OLD_VS_DEFAULT)
        {
            flint_printf(" %.3f", times[0] / times[5]);
        }

        fflush(stdout);
        (void) t_ignored;

        gr_heap_clear_vec(A, lenA, ctx);
        gr_heap_clear_vec(B, lenB, ctx);
        gr_heap_clear_vec(Q, 2, ctx);
        gr_heap_clear_vec(R, lenB - 1, ctx);
        /* invL is allocated every iteration and must be released too. */
        gr_heap_clear_vec(invL, 1, ctx);
    }
}
/* Print a table of old-vs-default timing ratios: a header row with the
   divisor lengths from parameters[], then one row per modulus bit size from
   96 to 1024 in steps of 32, each modulus drawn with fmpz_randprime. */
int main(void)
{
    fmpz_t p;
    gr_ctx_t ctx;
    flint_rand_t state;
    slong bits;
    slong len, leni;

    flint_rand_init(state);

    /* Initialise p once: fmpz_randprime overwrites it on every pass, and
       re-initialising a multi-limb p each pass would leak its storage. */
    fmpz_init(p);

    flint_printf(" ");
    for (leni = 0; (len = parameters[leni]) != 0; leni++)
        flint_printf("%5wd ", len);
    flint_printf("\n");

    for (bits = 96; bits <= 1024; bits += 32)
    {
        flint_printf("%5wd", bits);
        fflush(stdout);

        fmpz_randprime(p, state, bits, 0);
        GR_MUST_SUCCEED(gr_ctx_init_mpn_mod(ctx, p));
        best_table(state, OLD_VS_DEFAULT, ctx);
        gr_ctx_clear(ctx);

        flint_printf("\n");
    }

    fmpz_clear(p);
    flint_rand_clear(state);
    return 0;
}