#include <gmp.h>
#include "fmpz.h"
#include "ulong_extras.h"
#include "profiler.h"
/*
    Merge two CRT solutions: given r1 mod m1 and r2 mod m2 with
    gcd(m1, m2) = 1, set m1m2 = m1*m2 and r1r2 to the unique value in
    [0, m1*m2) congruent to r1 mod m1 and to r2 mod m2.

    Uses r1r2 = r1 + (r2 - r1) * e  (mod m1*m2), where
    e = m1 * (m1^-1 mod m2) satisfies e = 0 mod m1 and e = 1 mod m2.

    The outputs must not alias any input: m1m2 holds e while m1 and m2 are
    still needed, and r1r2 holds a partial result before r1 is read again.
    Aborts if m1 is not invertible modulo m2.
*/
void
_fmpz_crt_combine(fmpz_t r1r2, fmpz_t m1m2, const fmpz_t r1, const fmpz_t m1, const fmpz_t r2, const fmpz_t m2)
{
    /* m1m2 serves as scratch for e; it receives the modulus only at the end.
       fmpz_invmod leaves its output undefined (and returns 0) when no
       inverse exists, so that case must not be allowed to continue. */
    if (!fmpz_invmod(m1m2, m1, m2))
    {
        flint_printf("Exception (_fmpz_crt_combine): moduli are not coprime.\n");
        flint_abort();
    }
    fmpz_mul(m1m2, m1m2, m1);

    fmpz_sub(r1r2, r2, r1);
    fmpz_mul(r1r2, r1r2, m1m2);
    fmpz_add(r1r2, r1r2, r1);

    fmpz_mul(m1m2, m1, m2);
    fmpz_mod(r1r2, r1r2, m1m2);
}
/*
    Word-sized version of _fmpz_crt_combine: given residues r1 mod m1 and
    r2 mod m2 with gcd(m1, m2) = 1, set m1m2 = m1*m2 and r1r2 to the unique
    value in [0, m1*m2) congruent to r1 mod m1 and to r2 mod m2.

    Assumes r1 < m1 and r2 < m2 (the size bounds below rely on it).
    As in the fmpz version, r1r2 = r1 + (r2 - r1) * c  (mod m1*m2) with
    c = m1 * (m1^-1 mod m2).  The product m1*m2 occupies one or two limbs,
    and each case is handled directly without fmpz arithmetic.
*/
void
_fmpz_crt_combine_uiui(fmpz_t r1r2, fmpz_t m1m2, ulong r1, ulong m1, ulong r2, ulong m2)
{
    ulong M[2], inv;

    /* Reduce m1 first so n_invmod is never handed an argument >= its
       modulus; callers therefore need not order the two primes.  The
       inverse is unchanged because m1 = (m1 mod m2)  (mod m2). */
    inv = n_invmod(m1 % m2, m2);

    /* M = m1*m2 as the two-limb number {M[0], M[1]}. */
    umul_ppmm(M[1], M[0], m1, m2);

    if (M[1] == 0)
    {
        /* Single-limb modulus.  inv <= m2 - 1, so
           c = inv*m1 <= m1*m2 - m1 < M and the product cannot overflow. */
        ulong c, v;

        c = inv * m1;
        if (r2 >= r1)
        {
            v = n_mulmod2(r2 - r1, c, M[0]);
        }
        else
        {
            /* r2 - r1 is negative: multiply by its representative
               M - (r1 - r2) instead. */
            v = n_mulmod2(n_negmod(r1 - r2, M[0]), c, M[0]);
        }
        v = n_addmod(v, r1, M[0]);
        fmpz_set_ui(r1r2, v);
        fmpz_set_ui(m1m2, M[0]);
    }
    else
    {
        /* Two-limb modulus: compute with mpn and reduce by M using
           mpn_tdiv_qr, which is valid because the top limb M[1] != 0.
           q must hold nn - 2 + 1 <= 3 limbs; r receives 2 limbs. */
        ulong c[2], t[4], q[3], r[3];

        umul_ppmm(c[1], c[0], inv, m1);
        if (r2 >= r1)
        {
            /* t = c*(r2 - r1) + r1 < M*m2, which fits in three limbs,
               so the carry out of mpn_add_1 is always zero. */
            t[2] = mpn_mul_1(t, c, 2, r2 - r1);
            mpn_add_1(t, t, 3, r1);
            mpn_tdiv_qr(q, r, 0, t, 3, M, 2);
        }
        else
        {
            /* r = M - (r1 - r2) stands in for the negative r2 - r1;
               t = c*r + r1 < M^2 fits in four limbs, so again no carry. */
            sub_ddmmss(r[1], r[0], M[1], M[0], 0, r1 - r2);
            mpn_mul_n(t, c, r, 2);
            mpn_add_1(t, t, 4, r1);
            mpn_tdiv_qr(q, r, 0, t, 4, M, 2);
        }
        fmpz_set_uiui(r1r2, r[1], r[0]);
        fmpz_set_uiui(m1m2, M[1], M[0]);
    }
}
/*
    Chinese remaindering by divide and conquer.  The len residue/prime pairs
    are split in half, each half is solved recursively, and the two partial
    solutions are merged with _fmpz_crt_combine.  A pair of word-sized
    primes is merged directly by _fmpz_crt_combine_uiui.

    On return m is the product of primes[0..len-1] (1 when len == 0), and r
    is congruent to residues[i] modulo primes[i] for each i.  The combine
    helpers presumably require pairwise coprime primes and
    residues[i] < primes[i]; verify at call sites.
*/
void
tree_crt(fmpz_t r, fmpz_t m, nn_srcptr residues, nn_srcptr primes, slong len)
{
    fmpz_t left_r, left_m, right_r, right_m;
    slong half;

    /* Base cases: empty product, a single pair, and two single-word pairs. */
    switch (len)
    {
        case 0:
            fmpz_zero(r);
            fmpz_one(m);
            return;

        case 1:
            fmpz_set_ui(r, residues[0]);
            fmpz_set_ui(m, primes[0]);
            return;

        case 2:
            _fmpz_crt_combine_uiui(r, m, residues[0], primes[0],
                                   residues[1], primes[1]);
            return;

        default:
            break;
    }

    half = len / 2;

    fmpz_init(left_r);
    fmpz_init(left_m);
    fmpz_init(right_r);
    fmpz_init(right_m);

    /* Solve the lower and upper halves independently, then merge. */
    tree_crt(left_r, left_m, residues, primes, half);
    tree_crt(right_r, right_m, residues + half, primes + half, len - half);

    _fmpz_crt_combine(r, m, left_r, left_m, right_r, right_m);

    fmpz_clear(left_r);
    fmpz_clear(left_m);
    fmpz_clear(right_r);
    fmpz_clear(right_m);
}
/*
    Output hook invoked after each timed CRT method in the benchmark.
    It currently prints nothing; the argument is discarded explicitly so
    that unused-parameter warnings stay quiet.
*/
void
fmpz_print1(const fmpz_t n)
{
    (void) n;
}
void
benchmark(slong num_primes, slong prime_bits)
{
flint_rand_t state;
fmpz_comb_temp_t temp;
fmpz_comb_t comb;
nn_ptr primes, residues;
fmpz_t res;
slong k;
flint_rand_init(state);
primes = flint_malloc(num_primes * sizeof(ulong));
residues = flint_malloc(num_primes * sizeof(ulong));
fmpz_init(res);
primes[0] = n_nextprime(UWORD(1) << (prime_bits - 1), 0);
for (k = 1; k < num_primes; k++)
primes[k] = n_nextprime(primes[k-1], 0);
for (k = 0; k < num_primes; k++)
residues[k] = n_randint(state, primes[k]);
printf("simple tree: ");
TIMEIT_START;
fmpz_t tmp;
fmpz_init(tmp);
tree_crt(res, tmp, residues, primes, num_primes);
fmpz_clear(tmp);
TIMEIT_STOP;
fmpz_print1(res);
printf("multi CRT: ");
TIMEIT_START;
fmpz_comb_init(comb, primes, num_primes);
fmpz_comb_temp_init(temp, comb);
fmpz_multi_CRT_ui(res, residues, comb, temp, 0);
fmpz_comb_clear(comb);
fmpz_comb_temp_clear(temp);
TIMEIT_STOP;
fmpz_print1(res);
printf("multi precomp: ");
fmpz_comb_init(comb, primes, num_primes);
fmpz_comb_temp_init(temp, comb);
TIMEIT_START;
fmpz_multi_CRT_ui(res, residues, comb, temp, 0);
TIMEIT_STOP;
fmpz_comb_clear(comb);
fmpz_comb_temp_clear(temp);
fmpz_print1(res);
flint_free(primes);
flint_free(residues);
fmpz_clear(res);
}
int main(void)
{
slong len, bits;
bits = 5;
for (len = 1; len <= 4000; len = FLINT_MAX(len * 1.5, len + 1))
{
printf("bits = %ld, len = %ld\n", bits, len);
benchmark(len, bits);
}
bits = 64;
for (len = 1; len <= 4000; len = FLINT_MAX(len * 1.5, len + 1))
{
printf("bits = %ld, len = %ld\n", bits, len);
benchmark(len, bits);
}
return 0;
}