#include <string.h>
#include <gmp.h>
#include "ulong_extras.h"
#include "bernoulli.h"
#include "bernoulli/impl.h"
/* Set to a non-zero value to compile in the printf diagnostics guarded by
   "#if DEBUG" in the functions below. */
#define DEBUG 0
/* NOTE(review): TIMING is not referenced in the code visible here; confirm
   that nothing else depends on it before touching it. */
#define TIMING 1
/*
    Divides the product a*b by n using a floating-point estimate of the
    quotient.

    bnpre is a double approximation of b/n supplied by the caller (the caller
    in this file passes (double) g / (double) p). The truncated estimate
    a * bnpre is corrected by at most one step downwards and one step
    upwards, so it is assumed to be within one of the true quotient.

    The corrected quotient is stored in *q; the remainder is returned.
*/
ulong _bernoulli_n_muldivrem_precomp(ulong * q, ulong a, ulong b, ulong n, double bnpre)
{
    ulong quot = (ulong) ((double) a * bnpre);
    /* Computed with wrapping word arithmetic; only the small difference
       between the two products matters. */
    ulong rem = a * b - quot * n;

    if ((slong) rem < 0)
    {
        /* the estimate was one too large */
        quot -= 1;
        rem += n;
    }

    if (rem >= n)
    {
        /* the estimate was one too small */
        quot += 1;
        rem -= n;
    }

    *q = quot;
    return rem;
}
/*
    Accumulates, modulo p, the sum over j = 1, ..., (p-1)/2 of

        -(q_j - (g-1)/2) * g^(j(k-1)),

    where q_j is the integer quotient of g * (g^(j-1) mod p) by p.

    p is the modulus, pinv the precomputed inverse expected by the
    n_*_preinv helpers, k the exponent parameter and g the base (the caller
    in this file passes a primitive root modulo p).
*/
static ulong
bernoulli_sum_powg(ulong p, ulong pinv, ulong k, ulong g)
{
    const ulong num_terms = (p - 1) / 2;
    const double g_over_p = (double) g / (double) p;
    /* (g-1)/2 mod p: for even g, p is added first so that the numerator is
       even (for odd p) before halving */
    const ulong half_gm1 = (g + ((g & 1) ? 0 : p) - 1) / 2;
    const ulong g_km1 = n_powmod2_preinv(g, k - 1, p, pinv);
    ulong g_pow = 1;            /* g^(j-1) mod p */
    ulong g_km1_pow = g_km1;    /* g^(j(k-1)) mod p */
    ulong acc = 0;
    ulong term_idx;

    for (term_idx = 0; term_idx < num_terms; term_idx++)
    {
        ulong quot, h, term;

        /* advance g^(j-1) -> g^j, keeping the integer quotient by p */
        g_pow = _bernoulli_n_muldivrem_precomp(&quot, g_pow, g, p, g_over_p);

        h = n_submod(quot, half_gm1, p);
        term = n_mulmod2_preinv(h, g_km1_pow, p, pinv);
        acc = n_submod(acc, term, p);

        g_km1_pow = n_mulmod2_preinv(g_km1_pow, g_km1, p, pinv);
    }

    return acc;
}
/* Upper bound on the number of limbs of precision an expander_t can be
   initialised with (checked by the assertion in expander_init). */
#define MAX_INV 256
/* Precomputed fixed-point expansion of 1/p, filled in by expander_init() and
   read by expander_expand(). */
typedef struct
{
/* Output area of the mpn_divrem_1 call in expander_init(), which is asked
   for max_words + 1 fraction limbs of 1/p plus one integer limb; hence the
   MAX_INV + 2 capacity. */
ulong pinv[MAX_INV + 2];
/* The modulus whose reciprocal is stored in pinv. */
ulong p;
/* Precision requested at initialisation; asserted to lie in [1, MAX_INV]. */
int max_words;
}
expander_t;
/*
    Prepares an expander for the modulus p with max_words limbs of
    precision (1 <= max_words <= MAX_INV).

    The single-limb value 1 is divided by p with mpn_divrem_1, requesting
    max_words + 1 fraction limbs; the quotient limbs are written to
    this->pinv.
*/
static void
expander_init(expander_t * this, ulong p, int max_words)
{
    ulong numerator = 1;

    FLINT_ASSERT(max_words >= 1);
    FLINT_ASSERT(max_words <= MAX_INV);

    this->p = p;
    this->max_words = max_words;

    mpn_divrem_1(this->pinv, max_words + 1, &numerator, 1, p);
}
/*
    Writes n limbs of the fixed-point expansion of s/p to res[1], ..., res[n]
    (0 < s < p, 1 <= n <= this->max_words).

    For s == 1 the stored limbs of 1/p are copied and res[0] is left
    untouched. Otherwise the top n + 1 stored limbs are multiplied by s,
    which also writes the low limb res[0]; if the exact division fallback is
    taken, mpn_divrem_1 writes res[0], ..., res[n + 1].
*/
static void
expander_expand(nn_ptr res, expander_t * this, ulong s, ulong n)
{
    ulong * tail;
    ulong idx;

    FLINT_ASSERT(s > 0 && s < this->p);
    FLINT_ASSERT(n >= 1);
    FLINT_ASSERT(n <= this->max_words);

    /* the n + 1 most significant stored limbs of 1/p */
    tail = this->pinv + this->max_words - n;

    if (s == 1)
    {
        /* 1/p is already available; copy it */
        for (idx = 1; idx <= n; idx++)
            res[idx] = tail[idx];
        return;
    }

    mpn_mul_1(res, tail, n + 1, s);

    /* The stored 1/p is truncated, so the product may be short by fewer
       than s units in the lowest limb. If adding that much to res[0] could
       carry into res[1], redo the computation by direct division. */
    if (res[0] > -s)
    {
        ulong numerator = s;
        mpn_divrem_1(res, n + 1, &numerator, 1, this->p);
    }
}
/* Each word of an expansion is cut into slices of TABLE_LG_SIZE bits; a
   slice value indexes a table of TABLE_SIZE accumulators. */
#define TABLE_LG_SIZE 8
/* Number of entries per table (one per possible slice value). */
#define TABLE_SIZE (WORD(1) << TABLE_LG_SIZE)
/* Mask extracting one slice from the low bits of a word. */
#define TABLE_MASK (TABLE_SIZE - 1)
/* Number of slices per word, i.e. one table per slice position. */
#define NUM_TABLES (FLINT_BITS / TABLE_LG_SIZE)
/* The slicing only covers a whole word if the word size is a multiple of the
   slice width. */
#if FLINT_BITS % TABLE_LG_SIZE != 0
#error Number of bits in a ulong must be divisible by TABLE_LG_SIZE
#endif
/*
    Table-driven accumulation of a signed sum of powers modulo p, using fully
    reduced modular arithmetic throughout.

    p is the modulus and pinv the precomputed inverse expected by the
    n_*_preinv helpers. The caller in this file passes a primitive root
    modulo p as g and the multiplicative order of 2 modulo p as n.

    For every i in [0, m) the binary expansion of (g^i mod p)/p is scanned,
    most significant bit first, for n bits (m and n as adjusted at the top of
    the body). The bit at position j (counting from 0) contributes the term
    g^(i(k-1)) * 2^(j(k-1)) to the total, subtracted when the bit is 1 and
    added when it is 0. Whole words are not processed bit by bit: they are
    bucketed into lookup tables and weighted once at the end.

    Returns the total, reduced modulo p.
*/
static ulong bernsum_pow2(ulong p, ulong pinv, ulong k, ulong g, ulong n)
{
slong i, m;
ulong g_to_km1, two_to_km1, B_to_km1, s_jump;
/* tables[h][v] collects (negated, modulo p) the x values of every word whose
   h-th TABLE_LG_SIZE-bit slice, counted from the low end, equals v */
ulong tables[NUM_TABLES][TABLE_SIZE];
ulong g_to_km1_to_i;
ulong g_to_i;
ulong sum;
expander_t expander;
slong h;
ulong x;
ulong weights[TABLE_SIZE];
ulong x_jump;
memset(tables, 0, sizeof(ulong) * NUM_TABLES * TABLE_SIZE);
/* When n divides p - 1 (true for the order of 2), halving either m or n
   leaves m * n == (p - 1) / 2. */
m = (p-1) / n;
if (n & 1)
m >>= 1;
else
n >>= 1;
/* g^(k-1), 2^(k-1), 2^((k-1)*FLINT_BITS) and 2^(MAX_INV*FLINT_BITS) mod p */
g_to_km1 = n_powmod2_preinv(g, k - 1, p, pinv);
two_to_km1 = n_powmod2_preinv(2, k - 1, p, pinv);
B_to_km1 = n_powmod2_preinv(two_to_km1, FLINT_BITS, p, pinv);
s_jump = n_powmod2_preinv(2, MAX_INV * FLINT_BITS, p, pinv);
g_to_km1_to_i = 1;
g_to_i = 1;
sum = 0;
/* Precision: enough words for n bits, capped at MAX_INV. */
expander_init(&expander, p, (n >= MAX_INV * FLINT_BITS)
? MAX_INV : ((n - 1) / FLINT_BITS + 1));
for (i = 0; i < m; i++)
{
/* NOTE(review): this x shadows the function-scope x declared above. */
ulong s, x, y;
slong nn;
/* s = g^i mod p, x = g^(i(k-1)) mod p */
s = g_to_i;
x = g_to_km1_to_i;
/* Consume the n bits in chunks of at most MAX_INV * FLINT_BITS bits. */
for (nn = n; nn > 0; nn -= MAX_INV * FLINT_BITS)
{
ulong s_over_p[MAX_INV + 2];
slong bits, words;
nn_ptr next;
if (nn >= MAX_INV * FLINT_BITS)
{
bits = MAX_INV * FLINT_BITS;
words = MAX_INV;
}
else
{
bits = nn;
words = (nn - 1) / FLINT_BITS + 1;
}
/* Expansion of s/p; the most significant word is s_over_p[words]. */
expander_expand(s_over_p, &expander, s, words);
next = s_over_p + words;
/* Whole words: bucket x by each slice of the word, then advance x by a full
   word's worth of bit positions. */
for (; bits >= FLINT_BITS; bits -= FLINT_BITS, next--)
{
/* NOTE(review): shadows the y declared at the top of the outer loop body. */
ulong y;
#if NUM_TABLES != 8 && NUM_TABLES != 4
nn_ptr target;
#else
nn_ptr target0, target1, target2, target3, target4, target5, target6, target7;
#endif
y = *next;
#if NUM_TABLES != 8 && NUM_TABLES != 4
/* generic slice loop */
for (h = 0; h < NUM_TABLES; h++)
{
target = &(tables[h][y & TABLE_MASK]);
*target = n_submod(*target, x, p);
y >>= TABLE_LG_SIZE;
}
#else
/* the same slice loop written out for 4 or 8 tables */
target0 = &(tables[0][y & TABLE_MASK]);
*target0 = n_submod(*target0, x, p);
target1 = &(tables[1][(y >> TABLE_LG_SIZE) & TABLE_MASK]);
*target1 = n_submod(*target1, x, p);
target2 = &(tables[2][(y >> (2*TABLE_LG_SIZE)) & TABLE_MASK]);
*target2 = n_submod(*target2, x, p);
target3 = &(tables[3][(y >> (3*TABLE_LG_SIZE)) & TABLE_MASK]);
*target3 = n_submod(*target3, x, p);
#if NUM_TABLES == 8
target4 = &(tables[4][(y >> (4*TABLE_LG_SIZE)) & TABLE_MASK]);
*target4 = n_submod(*target4, x, p);
target5 = &(tables[5][(y >> (5*TABLE_LG_SIZE)) & TABLE_MASK]);
*target5 = n_submod(*target5, x, p);
target6 = &(tables[6][(y >> (6*TABLE_LG_SIZE)) & TABLE_MASK]);
*target6 = n_submod(*target6, x, p);
target7 = &(tables[7][(y >> (7*TABLE_LG_SIZE)) & TABLE_MASK]);
*target7 = n_submod(*target7, x, p);
#endif
#endif
x = n_mulmod2_preinv(x, B_to_km1, p, pinv);
}
/* Remaining bits of a partial word: handle them one at a time, starting from
   the top bit, accumulating directly into sum. */
y = *next;
for (; bits > 0; bits--)
{
if (y & (UWORD(1) << (FLINT_BITS - 1)))
sum = n_submod(sum, x, p);
else
sum = n_addmod(sum, x, p);
x = n_mulmod2_preinv(x, two_to_km1, p, pinv);
y <<= 1;
}
/* Multiply s by 2^(MAX_INV*FLINT_BITS) so that the next chunk's expansion
   continues where this one stopped. */
s = n_mulmod2_preinv(s, s_jump, p, pinv);
}
/* step to the next power of g */
g_to_i = n_mulmod2_preinv(g_to_i, g, p, pinv);
g_to_km1_to_i = n_mulmod2_preinv(g_to_km1_to_i, g_to_km1, p, pinv);
}
#if DEBUG
{
/* NOTE(review): i and j are slong but are printed with %lu, and the table
   entries are ulong; check the format specifiers before enabling DEBUG. */
slong i, j;
for (i = 0; i < NUM_TABLES; i++)
{
printf("tab[%lu] = ", i);
for (j = 0; j < TABLE_SIZE; j++)
printf("%lu ", tables[i][j]);
printf("\n");
}
}
#endif
/* Build weights[v] for every slice value v: the bits of v, most significant
   first, contribute successive powers of 2^(k-1) starting at power 0, with a
   minus sign for a set bit and a plus sign for a clear bit. The table is
   doubled in place, so i runs downwards to read weights[i] before it is
   overwritten. */
weights[0] = 0;
for (h = 0, x = 1; h < TABLE_LG_SIZE;
h++, x = n_mulmod2_preinv(x, two_to_km1, p, pinv))
{
for (i = (WORD(1) << h) - 1; i >= 0; i--)
{
weights[2*i+1] = n_submod(weights[i], x, p);
weights[2*i] = n_addmod(weights[i], x, p);
}
}
/* Combine tables and weights. The top slice (h = NUM_TABLES - 1) gets scale
   1 and each lower slice an extra factor 2^((k-1)*TABLE_LG_SIZE). The table
   entries were accumulated negated, so subtracting here restores the sign. */
x_jump = n_powmod2_preinv(two_to_km1, TABLE_LG_SIZE, p, pinv);
for (h = NUM_TABLES - 1, x = 1; h >= 0; h--)
{
for (i = 0; i < TABLE_SIZE; i++)
{
ulong y = n_mulmod2_preinv(tables[h][i], weights[i], p, pinv);
y = n_mulmod2_preinv(y, x, p, pinv);
sum = n_submod(sum, y, p);
}
x = n_mulmod2_preinv(x_jump, x, p, pinv);
}
return sum;
}
/* Mask selecting the low half (FLINT_BITS / 2 bits) of a word. */
#define LOW_MASK ((UWORD(1) << (FLINT_BITS / 2)) - 1)
/*
    Half-word Montgomery-style reduction step without the final correction.

    With ninv2 as produced by PrepRedc(n), a multiple of n is added to x so
    that the low half of the word becomes zero, and the high half is
    returned. The result is congruent to x / 2^(FLINT_BITS/2) modulo n but is
    not necessarily smaller than n. The intermediate value x + n * multiplier
    is assumed to fit in one word.
*/
static inline ulong RedcFast(ulong x, ulong n, ulong ninv2)
{
    const ulong multiplier = (x * ninv2) & LOW_MASK;

    return (x + n * multiplier) >> (FLINT_BITS / 2);
}
/*
    Same as RedcFast, followed by a single conditional subtraction of n so
    that a RedcFast result below 2n ends up below n.
*/
static inline ulong Redc(ulong x, ulong n, ulong ninv2)
{
    const ulong reduced = RedcFast(x, n, ninv2);

    return (reduced >= n) ? reduced - n : reduced;
}
/*
    Computes the constant used by RedcFast/Redc for the (odd) modulus n:
    the value -1/n modulo 2^(FLINT_BITS/2).

    The starting guess -n is already correct modulo 8 for odd n; every
    Newton step doubles the number of correct low bits.
*/
static ulong PrepRedc(ulong n)
{
    ulong inv = -n;
    ulong correct_bits = 3;

    while (correct_bits < FLINT_BITS/2)
    {
        inv = 2*inv + n * inv * inv;
        correct_bits *= 2;
    }

    return inv & LOW_MASK;
}
/*
    Variant of bernsum_pow2 that replaces most modular multiplications by
    half-word REDC steps (RedcFast/Redc) and accumulates the tables without
    reducing modulo p.

    The parameters have the same meaning as for bernsum_pow2. The caller in
    this file only selects this variant when p < 2^(FLINT_BITS/2 - 1);
    presumably that bound is what keeps the unreduced products and table
    entries within one word.

    Constants with the suffix _redc are premultiplied by
    F = 2^(FLINT_BITS/2) mod p, which cancels the division by 2^(FLINT_BITS/2)
    performed by each REDC step.

    Returns the total, reduced modulo p.
*/
static ulong bernsum_pow2_redc(ulong p, ulong pinv, ulong k, ulong g, ulong n)
{
ulong pinv2 = PrepRedc(p);
ulong F = (UWORD(1) << (FLINT_BITS/2)) % p;
ulong x;
slong h, i, m;
ulong weights[TABLE_SIZE];
ulong x_jump;
ulong x_jump_redc;
ulong g_to_km1;
ulong two_to_km1;
ulong B_to_km1;
ulong s_jump;
ulong g_redc;
ulong g_to_km1_redc;
ulong two_to_km1_redc;
ulong B_to_km1_redc;
ulong s_jump_redc;
ulong g_to_km1_to_i;
ulong g_to_i;
ulong sum;
/* tables[h][v] collects, as a plain unreduced word sum, the x values of every
   word whose h-th TABLE_LG_SIZE-bit slice, counted from the low end, equals v */
ulong tables[NUM_TABLES][TABLE_SIZE];
expander_t expander;
memset(tables, 0, sizeof(ulong) * NUM_TABLES * TABLE_SIZE);
/* When n divides p - 1 (true for the order of 2), halving either m or n
   leaves m * n == (p - 1) / 2. */
m = (p-1) / n;
if (n & 1)
m >>= 1;
else
n >>= 1;
/* g^(k-1), 2^(k-1), 2^((k-1)*FLINT_BITS) and 2^(MAX_INV*FLINT_BITS) mod p */
g_to_km1 = n_powmod2_preinv(g, k-1, p, pinv);
two_to_km1 = n_powmod2_preinv(2, k-1, p, pinv);
B_to_km1 = n_powmod2_preinv(two_to_km1, FLINT_BITS, p, pinv);
s_jump = n_powmod2_preinv(2, MAX_INV * FLINT_BITS, p, pinv);
/* the same constants premultiplied by F for use with REDC */
g_redc = n_mulmod2_preinv(g, F, p, pinv);
g_to_km1_redc = n_mulmod2_preinv(g_to_km1, F, p, pinv);
two_to_km1_redc = n_mulmod2_preinv(two_to_km1, F, p, pinv);
B_to_km1_redc = n_mulmod2_preinv(B_to_km1, F, p, pinv);
s_jump_redc = n_mulmod2_preinv(s_jump, F, p, pinv);
/* Both running powers are updated with RedcFast below, so they are only
   congruent to the intended value and may be >= p. */
g_to_km1_to_i = 1;
g_to_i = 1;
sum = 0;
#if DEBUG
printf("%lu %lu %lu %lu %lu %lu %lu %lu %lu\n", F, g_to_km1, two_to_km1, B_to_km1, s_jump, g_redc, g_to_km1_redc, B_to_km1_redc, s_jump_redc);
#endif
/* Precision: enough words for n bits, capped at MAX_INV. */
expander_init(&expander, p, (n >= MAX_INV * FLINT_BITS)
? MAX_INV : ((n - 1) / FLINT_BITS + 1));
for (i = 0; i < m; i++)
{
/* NOTE(review): this x shadows the function-scope x declared above. */
ulong s, x, y;
slong nn, bits, words;
nn_ptr next;
/* s must satisfy s < p for expander_expand, so apply the correction step
   that RedcFast skipped. */
s = g_to_i;
if (s >= p)
s -= p;
x = g_to_km1_to_i;
/* Consume the n bits in chunks of at most MAX_INV * FLINT_BITS bits. */
for (nn = n; nn > 0; nn -= MAX_INV * FLINT_BITS)
{
ulong s_over_p[MAX_INV + 2];
if (nn >= MAX_INV * FLINT_BITS)
{
bits = MAX_INV * FLINT_BITS;
words = MAX_INV;
}
else
{
bits = nn;
words = (nn - 1) / FLINT_BITS + 1;
}
/* Expansion of s/p; the most significant word is s_over_p[words]. */
expander_expand(s_over_p, &expander, s, words);
next = s_over_p + words;
/* Whole words: add x to one entry per slice, then advance x by a full
   word's worth of bit positions. */
for (; bits >= FLINT_BITS; bits -= FLINT_BITS, next--)
{
y = *next;
#if DEBUG
/* NOTE(review): i, nn, words and bits are slong but printed with %lu. */
printf("i = %lu nn = %lu words = %lu bits = %lu y = %lu\n", i, nn, words, bits, y);
#endif
#if NUM_TABLES != 8 && NUM_TABLES != 4
/* generic slice loop */
for (h = 0; h < NUM_TABLES; h++)
{
tables[h][y & TABLE_MASK] += x;
y >>= TABLE_LG_SIZE;
}
#else
/* the same slice loop written out for 4 or 8 tables */
tables[0][ y & TABLE_MASK] += x;
tables[1][(y >> TABLE_LG_SIZE ) & TABLE_MASK] += x;
tables[2][(y >> (2*TABLE_LG_SIZE)) & TABLE_MASK] += x;
tables[3][(y >> (3*TABLE_LG_SIZE)) & TABLE_MASK] += x;
#if NUM_TABLES == 8
tables[4][(y >> (4*TABLE_LG_SIZE)) & TABLE_MASK] += x;
tables[5][(y >> (5*TABLE_LG_SIZE)) & TABLE_MASK] += x;
tables[6][(y >> (6*TABLE_LG_SIZE)) & TABLE_MASK] += x;
tables[7][(y >> (7*TABLE_LG_SIZE)) & TABLE_MASK] += x;
#endif
#endif
x = RedcFast(x * B_to_km1_redc, p, pinv2);
}
/* The bitwise loop feeds x to n_submod/n_addmod, so reduce it below p
   first. */
if (x >= p)
x -= p;
/* Remaining bits of a partial word: handle them one at a time, starting from
   the top bit, accumulating directly into sum. */
y = *next;
for (; bits > 0; bits--)
{
if (y & (UWORD(1) << (FLINT_BITS - 1)))
sum = n_submod(sum, x, p);
else
sum = n_addmod(sum, x, p);
x = Redc(x * two_to_km1_redc, p, pinv2);
y <<= 1;
}
/* Multiply s by 2^(MAX_INV*FLINT_BITS) so that the next chunk's expansion
   continues where this one stopped; Redc keeps s below p. */
s = Redc(s * s_jump_redc, p, pinv2);
}
/* step to the next power of g */
g_to_i = RedcFast(g_to_i * g_redc, p, pinv2);
g_to_km1_to_i = RedcFast(g_to_km1_to_i * g_to_km1_redc, p, pinv2);
}
#if DEBUG
{
/* NOTE(review): i and j are slong but are printed with %lu; check the
   format specifiers before enabling DEBUG. */
slong i, j;
for (i = 0; i < NUM_TABLES; i++)
{
printf("tab[%lu] = ", i);
for (j = 0; j < TABLE_SIZE; j++)
printf("%lu ", tables[i][j]);
printf("\n");
}
}
#endif
/* Build weights[v] as in bernsum_pow2 (minus sign for a set bit, plus for a
   clear bit, most significant bit first), but scaled by 2^(3*FLINT_BITS/2) to
   compensate for the three RedcFast steps applied to each table entry in the
   final loop. The table is doubled in place, so i runs downwards. */
weights[0] = 0;
for (h = 0, x = n_powmod2_preinv(2, 3*FLINT_BITS/2, p, pinv);
h < TABLE_LG_SIZE; h++, x = Redc(x * two_to_km1_redc, p, pinv2))
{
for (i = (WORD(1) << h) - 1; i >= 0; i--)
{
weights[2*i+1] = n_submod(weights[i], x, p);
weights[2*i] = n_addmod(weights[i], x, p);
}
}
/* Combine tables and weights. The top slice (h = NUM_TABLES - 1) gets scale
   1 and each lower slice an extra factor 2^((k-1)*TABLE_LG_SIZE). */
x_jump = n_powmod2_preinv(two_to_km1, TABLE_LG_SIZE, p, pinv);
x_jump_redc = n_mulmod2_preinv(x_jump, F, p, pinv);
for (h = NUM_TABLES - 1, x = 1; h >= 0; h--)
{
for (i = 0; i < TABLE_SIZE; i++)
{
ulong y;
/* three REDC steps: reduce the raw entry, apply the weight, apply the
   slice scale */
y = RedcFast(tables[h][i], p, pinv2);
y = RedcFast(y * weights[i], p, pinv2);
y = RedcFast(y * x, p, pinv2);
/* sum is allowed to exceed p here; it is reduced once on return */
sum += y;
}
x = Redc(x * x_jump_redc, p, pinv2);
}
return sum % p;
}
/*
    Evaluation based on powers of a primitive root g modulo p.

    The sum from bernoulli_sum_powg is divided by (1 - g^k) and doubled, all
    modulo p. Judging by how the public entry point uses the result, this is
    presumably B_k/k mod p.
*/
ulong _bernoulli_mod_p_harvey_powg(ulong p, ulong pinv, ulong k)
{
    const ulong root = n_primitive_root_prime(p);
    ulong acc = bernoulli_sum_powg(p, pinv, k, root);
    const ulong root_to_k = n_powmod2_preinv(root, k, p, pinv);
    /* inverse of (1 - g^k); p is added so the subtraction cannot wrap */
    const ulong scale = n_invmod(p + 1 - root_to_k, p);

    acc = n_mulmod2_preinv(acc, scale, p, pinv);

    return n_addmod(acc, acc, p);
}
/*
    Returns the multiplicative order of x modulo p.

    F must hold the prime factors of p - 1. Starting from p - 1, which is a
    multiple of the order, each prime factor is divided out for as long as
    the reduced exponent still sends x to 1.
*/
static ulong
n_multiplicative_order(ulong x, ulong p, ulong pinv, n_factor_t * F)
{
    ulong order = p - 1;
    slong idx;

    for (idx = 0; idx < F->num; idx++)
    {
        const ulong prime = F->p[idx];

        while (order % prime == 0
               && n_powmod2_preinv(x, order / prime, p, pinv) == 1)
        {
            order /= prime;
        }
    }

    return order;
}
/*
    Evaluation based on the binary expansions handled by bernsum_pow2 and
    bernsum_pow2_redc.

    p - 1 is factored once; the factorisation is shared between the
    primitive root search and the computation of the order of 2. The REDC
    variant is used for p < 2^(FLINT_BITS/2 - 1), the general variant
    otherwise. The sum is finally divided by 2 * (2^(-k) - 1) modulo p, so
    the caller must make sure that 2^k is not 1 modulo p.
*/
ulong _bernoulli_mod_p_harvey_pow2(ulong p, ulong pinv, ulong k)
{
    n_factor_t factors;
    ulong root, order_of_2, acc, scale;

    n_factor_init(&factors);
    n_factor(&factors, p - 1, 1);

    root = n_primitive_root_prime_prefactor(p, &factors);
    order_of_2 = n_multiplicative_order(2, p, pinv, &factors);

#if DEBUG
    printf("g = %lu, n = %lu\n", root, order_of_2);
#endif

    acc = (p < (UWORD(1) << (FLINT_BITS/2 - 1)))
        ? bernsum_pow2_redc(p, pinv, k, root, order_of_2)
        : bernsum_pow2(p, pinv, k, root, order_of_2);

    /* scale = 1 / (2 * (2^(-k) - 1)) mod p */
    scale = n_invmod(n_powmod2_preinv(2, k, p, pinv), p);
    scale = n_submod(scale, 1, p);
    scale = n_addmod(scale, scale, p);
    scale = n_invmod(scale, p);

    return n_mulmod2_preinv(acc, scale, p, pinv);
}
/*
    Dispatches between the two evaluation strategies.

    The power-of-2 strategy divides by 2 * (2^(-k) - 1), which is zero
    modulo p when 2^k is 1 modulo p; in that case the primitive-root
    strategy is used instead.
*/
static ulong _bernoulli_mod_p_harvey(ulong p, ulong pinv, ulong k)
{
    if (n_powmod2_preinv(2, k, p, pinv) == 1)
        return _bernoulli_mod_p_harvey_powg(p, pinv, k);

    return _bernoulli_mod_p_harvey_pow2(p, pinv, k);
}
/*
    Returns the Bernoulli number B_k reduced modulo p, or UWORD_MAX when no
    residue is returned (the cases p <= 3 with even k >= 2, p - 1 dividing k,
    and k == 1 with p == 2).

    p must satisfy 2 <= p < 2^FLINT_D_BITS; the helpers called here
    additionally treat it as a prime.
*/
ulong bernoulli_mod_p_harvey(ulong k, ulong p)
{
    ulong reduced_k, pinv, quotient;

    FLINT_ASSERT(k >= 0);
    FLINT_ASSERT(2 <= p && p < (UWORD(1) << FLINT_D_BITS));

    if (k <= 1)
    {
        if (k == 0)
            return 1;

        /* k == 1: (p - 1)/2 is -1/2 modulo an odd p; for p == 2 the
           all-ones sentinel is returned */
        return (p == 2) ? UWORD_MAX : (p - 1) / 2;
    }

    if (k % 2 != 0)
        return 0;

    if (p <= 3)
        return UWORD_MAX;

    /* reduce the index modulo p - 1; the helper result is multiplied back
       by k at the end */
    reduced_k = k % (p - 1);
    if (reduced_k == 0)
        return UWORD_MAX;

    pinv = n_preinvert_limb(p);
    quotient = _bernoulli_mod_p_harvey(p, pinv, reduced_k);

    return n_mulmod2_preinv(quotient, k % p, p, pinv);
}