#include "profiler.h"
#include "mpn_extras.h"
#if FLINT_HAVE_ASSEMBLY_x86_64_adx
/*
 * Override the TIMEIT_END_REPEAT definition that is in effect at this point
 * (presumably the one from profiler.h -- TODO confirm) with a variant that
 * uses a cutoff of 10 on __timer->cpu.
 *
 * The expansion closes three brace levels and a do/while that must have been
 * opened by the matching "start" macro, which is not visible in this file.
 * After the timed body it stops the timer, leaves the repeat loop as soon as
 * __timer->cpu has reached 10, and otherwise multiplies __reps by 10 and
 * goes around again.
 *
 * NOTE(review): __timer->cpu is presumably measured in milliseconds, which
 * would make this a 10 ms minimum per measurement -- verify against the
 * timeit_t definition.
 */
#undef TIMEIT_END_REPEAT
#define TIMEIT_END_REPEAT(__timer, __reps) \
} \
timeit_stop(__timer); \
if (__timer->cpu >= 10) \
break; \
__reps *= 10; \
} \
} while (0)
/*
 * Split-size table under tuning, indexed by the operand length n.
 * sqrhigh() reads ktab[n] for n < FLINT_MPN_SQRHIGH_K_TAB_SIZE and interprets
 * it as follows: 0 selects the basecase routines, n selects a full square,
 * and any other value k is the number of high limbs that get squared in
 * full.  main() overwrites entries while timing candidates and finally
 * stores the fastest choice; entries for n < 8 are never tuned and stay 0.
 */
short ktab[FLINT_MPN_SQRHIGH_K_TAB_SIZE] = { 0 };
/*
 * Local prototype for a routine that is defined outside this file.
 * NOTE(review): presumably an internal FLINT function that the included
 * headers do not declare -- confirm.  sqrhigh() calls it with length l and
 * afterwards reads its result from rp[l - 1 .. 2l - 1].
 */
void
_flint_mpn_mulhigh_n_mulders_recursive(mp_ptr rp, mp_srcptr np, mp_srcptr mp, mp_size_t n);
/*
 * High-part squaring of the n-limb operand {np, n} into rp, with the
 * strategy for each length taken from the tuning table ktab.
 *
 * Strategy selection:
 *   - lengths accepted by FLINT_HAVE_SQRHIGH_FUNC go to the routine in
 *     flint_mpn_sqrhigh_func_tab[n];
 *   - otherwise the split size is ktab[n] when n is inside the table and
 *     (n + 4) / 2 when it is not;
 *   - split size 0 dispatches to the odd/even basecase routine;
 *   - split size n computes the full 2n-limb square into {rp, 2n};
 *   - any other split size k (asserted to be >= (n + 4) / 2) squares the top
 *     k limbs in full and adds twice the high part of the cross product of
 *     the low l = n - k limbs with the top l limbs.
 *
 * The dedicated and basecase routines receive rp + n as their destination
 * and their return value is stored in rp[n - 1].
 * NOTE(review): the result is presumably an approximation of the high n + 1
 * limbs rp[n - 1 .. 2n - 1] of the square -- confirm the exact error bound
 * against the FLINT documentation.
 */
void
sqrhigh(mp_ptr rp, mp_srcptr np, mp_size_t n)
{
    mp_limb_t carry;
    mp_size_t low_len;
    mp_ptr cross, mid, top;
    slong split;

    /* Lengths with a dedicated routine bypass the tuning table entirely. */
    if (FLINT_HAVE_SQRHIGH_FUNC(n))
    {
        rp[n - 1] = flint_mpn_sqrhigh_func_tab[n](rp + n, np);
        return;
    }

    split = (n < FLINT_MPN_SQRHIGH_K_TAB_SIZE) ? ktab[n] : (n + 4) / 2;

    /* Table entry 0: use the basecase matching the parity of n. */
    if (split == 0)
    {
        const mp_size_t half = n >> 1;

        if ((n & 1) == 0)
            rp[n - 1] = _flint_mpn_sqrhigh_basecase_even(rp + n, np, half);
        else
            rp[n - 1] = _flint_mpn_sqrhigh_basecase_odd(rp + n, np, half);

        return;
    }

    FLINT_ASSERT(split >= (n + 4) / 2);

    /* Table entry n: a plain full square. */
    if (split == n)
    {
        flint_mpn_sqr(rp, np, n);
        return;
    }

    low_len = n - split;

    /* Full square of the top `split` limbs, placed at rp[2 * low_len ..]. */
    flint_mpn_sqr(rp + 2 * low_len, np + low_len, split);

    /* High part of (low low_len limbs) * (top low_len limbs); the limbs used
       below are rp[low_len - 1 .. 2 * low_len - 1]. */
    _flint_mpn_mulhigh_n_mulders_recursive(rp, np, np + split, low_len);

    cross = rp + low_len - 1;
    mid = rp + n - 1;
    top = rp + n + low_len;

    /* The cross product occurs twice in the square: double it in place,
       add it in at limb n - 1, and push the combined carry through the
       remaining `split` limbs. */
    carry = mpn_lshift(cross, cross, low_len + 1, 1);
    carry += mpn_add_n(mid, mid, cross, low_len + 1);
    mpn_add_1(top, top, split, carry);
}
/*
 * Tells whether split size k should be timed for operand length n.
 * k == n (the full square) is always timed.  Otherwise k has to lie on a
 * grid that gets coarser as n grows (multiples of 4 above n = 64, of 8 above
 * 512, of 16 above 1024), and for n > 768 it must also be at least 0.7 * n.
 */
static int
split_is_candidate(slong n, slong k)
{
    if (k == n)
        return 1;

    if (n > 64 && k % 4 != 0)
        return 0;

    if (n > 512 && k % 8 != 0)
        return 0;

    if (n > 1024 && k % 16 != 0)
        return 0;

    if (n > 768 && k < 0.7 * n)
        return 0;

    return 1;
}

/*
 * Tuning driver.  For every length n from 8 up to the table size it times
 * sqrhigh() with ktab[n] = 0 and with every candidate split size, running
 * the whole sweep twice and keeping the fastest CPU time seen.  The winner
 * is stored in ktab[n] and one line is printed per n:
 *
 *     n = <n> <best k> <best k / n> <tbase / tbest> <tfull / tbest>
 *
 * where tbase is the best time with ktab[n] = 0 and tfull the best time
 * with k = n.  Whenever n % 8 == 7 the table entries 0..n are dumped as a
 * brace-enclosed list, 30 values per line.
 */
int
main()
{
    mp_ptr operand, result;
    slong len, split, idx, best_split;
    double t_cpu, t_wall, t_base, t_best, t_full;
    int pass;

    operand = flint_malloc(sizeof(mp_limb_t) * FLINT_MPN_SQRHIGH_K_TAB_SIZE);
    result = flint_malloc(sizeof(mp_limb_t) * 2 * FLINT_MPN_SQRHIGH_K_TAB_SIZE);

    /* Deterministic input: limb idx holds the negation of idx. */
    for (idx = 0; idx < FLINT_MPN_SQRHIGH_K_TAB_SIZE; idx++)
        operand[idx] = -idx;

    for (len = 8; len < FLINT_MPN_SQRHIGH_K_TAB_SIZE; len++)
    {
        flint_printf("n = %wd ", len);
        fflush(stdout);

        /* Sentinel times that any real measurement will beat. */
        t_full = t_base = t_best = 1e100;
        best_split = 0;
        ktab[len] = 0;

        /* Two passes over all candidates; the minimum over both is kept. */
        for (pass = 0; pass < 2; pass++)
        {
            /* Table entry 0 makes sqrhigh() take its basecase branch. */
            ktab[len] = 0;
            TIMEIT_START;
            sqrhigh(result, operand, len);
            TIMEIT_STOP_VALUES(t_cpu, t_wall);

            if (t_cpu < t_base)
                t_base = t_cpu;

            if (t_cpu < t_best)
            {
                t_best = t_cpu;
                best_split = 0;
            }

            for (split = (len + 4) / 2; split <= len; split++)
            {
                if (!split_is_candidate(len, split))
                    continue;

                ktab[len] = split;
                TIMEIT_START;
                sqrhigh(result, operand, len);
                TIMEIT_STOP_VALUES(t_cpu, t_wall);

                if (t_cpu < t_best)
                {
                    t_best = t_cpu;
                    best_split = split;
                }

                /* split == len is the full square; track it separately. */
                if (split == len && t_cpu < t_full)
                    t_full = t_cpu;
            }
        }

        ktab[len] = best_split;

        flint_printf("%wd %.3f %.3f %.3f\n", best_split, best_split / (double) len, t_base / t_best, t_full / t_best);

        /* Periodic dump of the table tuned so far. */
        if (len % 8 == 7)
        {
            flint_printf("{");

            for (idx = 0; idx <= len; idx++)
            {
                flint_printf("%wd, ", (slong) ktab[idx]);

                if (idx % 30 == 29)
                    flint_printf("\n");
            }

            flint_printf("}\n");
        }

        /* Wall-clock time is measured but not used. */
        (void) t_wall;
    }

    flint_free(operand);
    flint_free(result);
}
#else
/*
 * Stub entry point for builds where FLINT_HAVE_ASSEMBLY_x86_64_adx is not
 * set: there is nothing to tune, so exit successfully right away.
 */
int
main(void)
{
    return 0;
}
#endif