#include "thread_pool.h"
#include "thread_support.h"
#include "nmod.h"
#include "nmod_mat.h"
#include "mpn_mod.h"
#include "gr_mat.h"
/*
 * Work description handed (as void *) to _mod_worker and _crt_worker.
 * The calling thread fills one instance; every helper thread receives its
 * own copy that differs only in the row ranges it is responsible for.
 */
typedef struct {
    /* Dimensions: A is m x k, B is k x n, C is m x n. */
    slong m, k, n;

    /* Half-open row ranges [start, stop) handled by this worker: rows of A
       and B in the reduction phase, rows of C in the CRT phase. */
    slong Astartrow, Astoprow;
    slong Bstartrow, Bstoprow;
    slong Cstartrow, Cstoprow;

    /* Raw limb data of the mpn_mod matrices. Each entry occupies
       MPN_MOD_CTX_NLIMBS(ctx) limbs; strides are counted in entries. */
    nn_ptr Aentries;
    slong Astride;
    nn_ptr Bentries;
    slong Bstride;
    nn_ptr Centries;
    slong Cstride;

    /* Residue images, one nmod_mat per modulus. mod_B is NULL when A == B
       (squaring), in which case A's images are reused for B. */
    nmod_mat_t * mod_A;
    nmod_mat_t * mod_B;
    nmod_mat_t * mod_C;

    /* The num_primes moduli: primes[0] is a power of two, the remaining
       ones are obtained from n_nextprime. */
    slong num_primes;
    nn_ptr primes;

    /* mpn_mod ring context (limb count and reduction of CRT results). */
    gr_ctx_struct * ctx;
} _worker_arg;
/* Reduce the two-limb integer ad[1] * 2^FLINT_BITS + ad[0] modulo mod.n,
   folding in the high limb first and then the low limb. */
FLINT_FORCE_INLINE ulong
nmod_set_mpn_2(nn_srcptr ad, nmod_t mod)
{
    ulong hi = ad[1];
    ulong lo = ad[0];
    ulong rem = UWORD(0);

    NMOD_RED2(rem, rem, hi, mod);   /* rem = hi mod n */
    NMOD_RED2(rem, rem, lo, mod);   /* rem = (rem * 2^FLINT_BITS + lo) mod n */

    return rem;
}
#if 0#endif
FLINT_FORCE_INLINE ulong
nmod_set_mpn(nn_srcptr ad, slong an, nmod_t mod)
{
return mpn_mod_1(ad, an, mod.n);
}
static void _mod_worker(void * varg)
{
_worker_arg * arg = (_worker_arg *) varg;
slong i, j, l;
slong k = arg->k;
slong n = arg->n;
slong Astartrow = arg->Astartrow;
slong Astoprow = arg->Astoprow;
slong Bstartrow = arg->Bstartrow;
slong Bstoprow = arg->Bstoprow;
nn_ptr Aentries = arg->Aentries;
slong Astride = arg->Astride;
nn_ptr Bentries = arg->Bentries;
slong Bstride = arg->Bstride;
nmod_mat_t * mod_A = arg->mod_A;
nmod_mat_t * mod_B = arg->mod_B;
slong num_primes = arg->num_primes;
slong nlimbs = MPN_MOD_CTX_NLIMBS(arg->ctx);
ulong first_prime = UWORD(1) << (FLINT_BITS - 1);
if (nlimbs == 2 && arg->primes[0] == first_prime)
{
for (i = Astartrow; i < Astoprow; i++)
{
for (j = 0; j < k; j++)
{
nmod_mat_entry(mod_A[0], i, j) = (Aentries + (i * Astride + j) * nlimbs)[0] & (first_prime - 1);
for (l = 1; l < num_primes; l++)
nmod_mat_entry(mod_A[l], i, j) = nmod_set_mpn_2(Aentries + (i * Astride + j) * nlimbs, mod_A[l]->mod);
}
}
if (mod_B != NULL)
{
for (i = Bstartrow; i < Bstoprow; i++)
for (j = 0; j < n; j++)
{
nmod_mat_entry(mod_B[0], i, j) = (Bentries + (i * Bstride + j) * nlimbs)[0] & (first_prime - 1);
for (l = 1; l < num_primes; l++)
nmod_mat_entry(mod_B[l], i, j) = nmod_set_mpn_2(Bentries + (i * Bstride + j) * nlimbs, mod_A[l]->mod);
}
}
}
else
{
for (i = Astartrow; i < Astoprow; i++)
for (j = 0; j < k; j++)
for (l = 0; l < num_primes; l++)
nmod_mat_entry(mod_A[l], i, j) = nmod_set_mpn(Aentries + (i * Astride + j) * nlimbs, nlimbs, mod_A[l]->mod);
if (mod_B != NULL)
{
for (i = Bstartrow; i < Bstoprow; i++)
for (j = 0; j < n; j++)
for (l = 0; l < num_primes; l++)
nmod_mat_entry(mod_B[l], i, j) = nmod_set_mpn(Bentries + (i * Bstride + j) * nlimbs, nlimbs, mod_A[l]->mod);
}
}
}
static void _crt_worker(void * varg)
{
_worker_arg * arg = (_worker_arg *) varg;
slong i, j, l;
slong n = arg->n;
slong Cstartrow = arg->Cstartrow;
slong Cstoprow = arg->Cstoprow;
nn_ptr Centries = arg->Centries;
slong Cstride = arg->Cstride;
nmod_mat_t * mod_C = arg->mod_C;
ulong * primes = arg->primes;
slong num_primes = arg->num_primes;
gr_ctx_struct * ctx = arg->ctx;
slong nlimbs = MPN_MOD_CTX_NLIMBS(ctx);
{
nn_ptr M, Ns, T, U;
slong Msize, Nsize;
ulong cy, ri;
M = FLINT_ARRAY_ALLOC(num_primes + 1, ulong);
M[0] = primes[0];
Msize = 1;
for (i = 1; i < num_primes; i++)
{
FLINT_ASSERT(Msize > 0);
M[Msize] = cy = mpn_mul_1(M, M, Msize, primes[i]);
Msize += (cy != 0);
}
Nsize = Msize + 2;
Ns = FLINT_ARRAY_ALLOC(Nsize*num_primes, ulong);
T = FLINT_ARRAY_ALLOC(Nsize, ulong);
U = FLINT_ARRAY_ALLOC(Nsize, ulong);
for (i = 0; i < num_primes; i++)
{
Ns[i*Nsize + (Nsize - 1)] = 0;
Ns[i*Nsize + (Nsize - 2)] = 0;
mpn_divrem_1(Ns + i * Nsize, 0, M, Msize, primes[i]);
ri = mpn_mod_1(Ns + i * Nsize, Msize, primes[i]);
ri = n_invmod(ri, primes[i]);
FLINT_ASSERT(Msize > 0);
Ns[i*Nsize + Msize] = mpn_mul_1(Ns + i*Nsize, Ns + i*Nsize, Msize, ri);
}
for (i = Cstartrow; i < Cstoprow; i++)
for (j = 0; j < n; j++)
{
ri = nmod_mat_entry(mod_C[0], i, j);
FLINT_ASSERT(Nsize > 1);
T[Nsize - 1] = mpn_mul_1(T, Ns, Nsize - 1, ri);
for (l = 1; l < num_primes; l++)
{
ri = nmod_mat_entry(mod_C[l], i, j);
T[Nsize - 1] += mpn_addmul_1(T, Ns + l*Nsize, Nsize - 1, ri);
}
mpn_tdiv_qr(U, T, 0, T, Nsize, M, Msize);
mpn_mod_set_mpn(Centries + (i * Cstride + j) * nlimbs, T, Msize, ctx);
}
flint_free(M);
flint_free(Ns);
flint_free(T);
flint_free(U);
}
}
int mpn_mod_mat_mul_multi_mod(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
{
slong i, start, stop;
slong m, k, n;
flint_bitcnt_t primes_bits;
_worker_arg mainarg;
_worker_arg * args;
slong num_workers;
thread_pool_handle * handles;
slong limit;
ulong first_prime;
int squaring = (A == B);
flint_bitcnt_t Abits, Bbits, Cbits, bits, mod_bits;
slong nlimbs;
nlimbs = MPN_MOD_CTX_NLIMBS(ctx);
mod_bits = FLINT_BITS * (nlimbs - 1) + (FLINT_BITS - MPN_MOD_CTX_NORM(ctx));
Abits = mod_bits;
Bbits = mod_bits;
Cbits = Abits + Bbits + FLINT_BIT_COUNT(A->c);
bits = Cbits;
mainarg.m = m = A->r;
mainarg.k = k = A->c;
mainarg.n = n = B->c;
if (m < 1 || n < 1 || k < 1)
return gr_mat_zero(C, ctx);
mainarg.ctx = ctx;
mainarg.Aentries = (nn_ptr) A->entries;
mainarg.Astride = A->stride;
mainarg.Bentries = (nn_ptr) B->entries;
mainarg.Bstride = B->stride;
mainarg.Centries = (nn_ptr) C->entries;
mainarg.Cstride = C->stride;
primes_bits = NMOD_MAT_OPTIMAL_MODULUS_BITS;
if (bits < primes_bits || bits <= FLINT_BITS - 1)
{
mainarg.num_primes = 1;
first_prime = UWORD(1) << bits;
}
else
{
mainarg.num_primes = 1 + (bits - (FLINT_BITS - 1) + primes_bits - 1)/primes_bits;
first_prime = UWORD(1) << (FLINT_BITS - 1);
}
mainarg.primes = FLINT_ARRAY_ALLOC(mainarg.num_primes, ulong);
mainarg.primes[0] = first_prime;
if (mainarg.num_primes > 1)
{
mainarg.primes[1] = n_nextprime(UWORD(1) << primes_bits, 0);
for (i = 2; i < mainarg.num_primes; i++)
mainarg.primes[i] = n_nextprime(mainarg.primes[i-1], 0);
}
mainarg.mod_A = FLINT_ARRAY_ALLOC(mainarg.num_primes, nmod_mat_t);
if (squaring)
mainarg.mod_B = NULL;
else
mainarg.mod_B = FLINT_ARRAY_ALLOC(mainarg.num_primes, nmod_mat_t);
mainarg.mod_C = FLINT_ARRAY_ALLOC(mainarg.num_primes, nmod_mat_t);
for (i = 0; i < mainarg.num_primes; i++)
{
nmod_mat_init(mainarg.mod_A[i], A->r, A->c, mainarg.primes[i]);
if (!squaring)
nmod_mat_init(mainarg.mod_B[i], B->r, B->c, mainarg.primes[i]);
nmod_mat_init(mainarg.mod_C[i], C->r, C->c, mainarg.primes[i]);
}
limit = ((m + k + n)/128)*(1 + bits/1024);
limit = FLINT_MIN(limit, (m + k)/4);
if (limit < 2)
{
mod_single:
mainarg.Astartrow = 0;
mainarg.Astoprow = m;
mainarg.Bstartrow = 0;
mainarg.Bstoprow = k;
_mod_worker(&mainarg);
}
else
{
num_workers = flint_request_threads(&handles, limit);
if (num_workers < 1)
{
flint_give_back_threads(handles, num_workers);
goto mod_single;
}
args = FLINT_ARRAY_ALLOC(num_workers, _worker_arg);
for (start = 0, i = 0; i < num_workers; start = stop, i++)
{
args[i] = mainarg;
stop = _thread_pool_find_work_2(m, k, k, n, i + 1, num_workers + 1);
_thread_pool_distribute_work_2(start, stop,
&args[i].Astartrow, &args[i].Astoprow, m,
&args[i].Bstartrow, &args[i].Bstoprow, k);
}
_thread_pool_distribute_work_2(start, m + k,
&mainarg.Astartrow, &mainarg.Astoprow, m,
&mainarg.Bstartrow, &mainarg.Bstoprow, k);
for (i = 0; i < num_workers; i++)
thread_pool_wake(global_thread_pool, handles[i], 0, _mod_worker, &args[i]);
_mod_worker(&mainarg);
for (i = 0; i < num_workers; i++)
thread_pool_wait(global_thread_pool, handles[i]);
flint_give_back_threads(handles, num_workers);
flint_free(args);
}
for (i = 0; i < mainarg.num_primes; i++)
nmod_mat_mul(mainarg.mod_C[i], mainarg.mod_A[i], squaring ? mainarg.mod_A[i] : mainarg.mod_B[i]);
limit = ((m + n)/64)*(1 + bits/1024);
limit = FLINT_MIN(limit, m/2);
if (limit < 2)
{
crt_single:
mainarg.Cstartrow = 0;
mainarg.Cstoprow = m;
_crt_worker(&mainarg);
}
else
{
num_workers = flint_request_threads(&handles, limit);
if (num_workers < 1)
{
flint_give_back_threads(handles, num_workers);
goto crt_single;
}
args = FLINT_ARRAY_ALLOC(num_workers, _worker_arg);
for (start = 0, i = 0; i < num_workers; start = stop, i++)
{
args[i] = mainarg;
stop = (i + 1)*m/(num_workers + 1);
args[i].Cstartrow = start;
args[i].Cstoprow = stop;
}
mainarg.Cstartrow = start;
mainarg.Cstoprow = m;
for (i = 0; i < num_workers; i++)
thread_pool_wake(global_thread_pool, handles[i], 0, _crt_worker, &args[i]);
_crt_worker(&mainarg);
for (i = 0; i < num_workers; i++)
thread_pool_wait(global_thread_pool, handles[i]);
flint_give_back_threads(handles, num_workers);
flint_free(args);
}
for (i = 0; i < mainarg.num_primes; i++)
{
nmod_mat_clear(mainarg.mod_A[i]);
if (!squaring)
nmod_mat_clear(mainarg.mod_B[i]);
nmod_mat_clear(mainarg.mod_C[i]);
}
flint_free(mainarg.mod_A);
if (!squaring)
flint_free(mainarg.mod_B);
flint_free(mainarg.mod_C);
flint_free(mainarg.primes);
return GR_SUCCESS;
}