#include "thread_support.h"
#include "mpn_mod.h"
#include "gr_mat.h"
/*
    Algorithm cutoffs consulted by mpn_mod_mat_mul.

    Each row is a pair { waksman, multi_mod }:
      - column 0: smallest dimension at which the Waksman routine is used
                  instead of classical multiplication;
      - column 1: smallest dimension at which the multi-modular (or Strassen)
                  path is used instead of Waksman.

    The row is selected by (bits - FLINT_BITS - 1) / 16, where bits is the bit
    size of the modulus, so each row covers a span of 16 consecutive bit sizes
    starting at FLINT_BITS + 1.  The leading comment on each line below gives
    the index of the first row on that line.
*/
static const short mat_mul_cutoff_tab[][2] = {
    /*  0 */ {88, 88}, {230, 230}, {220, 220}, {160, 160},
    /*  4 */ {10, 39}, {9, 57}, {17, 57}, {74, 74},
    /*  8 */ {8, 64}, {8, 64}, {8, 92}, {17, 63},
    /* 12 */ {8, 69}, {8, 65}, {8, 91}, {14, 67},
    /* 16 */ {8, 75}, {7, 77}, {7, 93}, {10, 79},
    /* 20 */ {6, 77}, {6, 91}, {6, 93}, {10, 73},
    /* 24 */ {6, 75}, {6, 87}, {6, 87}, {9, 71},
    /* 28 */ {6, 70}, {5, 77}, {5, 84}, {6, 67},
    /* 32 */ {5, 75}, {5, 74}, {4, 78}, {6, 63},
    /* 36 */ {4, 68}, {4, 68}, {4, 75}, {6, 59},
    /* 40 */ {4, 66}, {4, 66}, {4, 71}, {5, 59},
    /* 44 */ {4, 59}, {4, 66}, {4, 63}, {4, 58},
    /* 48 */ {4, 57}, {4, 61}, {4, 61}, {4, 55},
    /* 52 */ {4, 57}, {4, 64}, {4, 65}, {4, 57},
    /* 56 */ {4, 61}, {4, 62}, {4, 65}, {4, 53},
};
/*
    Matrix multiplication entry point for mpn_mod matrices.

    Forwards (C, A, B, ctx) unchanged to one of several multiplication
    routines and returns that routine's status code.  The choice depends on
    the smallest of the three dimensions (rows of A, columns of A, columns
    of B) and on the bit size of the modulus:

      - rows of A <= 3, or smallest dimension below the Waksman cutoff:
        gr_mat_mul_classical;
      - smallest dimension below the multi-modular cutoff:
        mpn_mod_mat_mul_waksman;
      - otherwise mpn_mod_mat_mul_multi_mod, except that when
        FLINT_USES_BLAS is not defined, moduli of 113..128 bits with exactly
        one available thread go to gr_mat_mul_strassen.

    Both cutoffs are read from mat_mul_cutoff_tab.
*/
int
mpn_mod_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
{
    const slong rows = A->r;
    slong inner, cols, min_dim;
    slong modulus_bits, row_index;
    slong waksman_threshold, multi_mod_threshold;

    /* Very thin products never benefit from the table lookup below. */
    if (rows <= 3)
        return gr_mat_mul_classical(C, A, B, ctx);

    inner = A->c;
    cols = B->c;

    /* "Any dimension below the cutoff" is the same as "the minimum is below". */
    min_dim = rows;
    if (inner < min_dim)
        min_dim = inner;
    if (cols < min_dim)
        min_dim = cols;

    /* One table row per 16 bit sizes, counted from FLINT_BITS + 1. */
    modulus_bits = MPN_MOD_CTX_MODULUS_BITS(ctx);
    row_index = (modulus_bits - FLINT_BITS - 1) / 16;
    waksman_threshold = mat_mul_cutoff_tab[row_index][0];
    multi_mod_threshold = mat_mul_cutoff_tab[row_index][1];

    if (min_dim < waksman_threshold)
        return gr_mat_mul_classical(C, A, B, ctx);

    if (min_dim < multi_mod_threshold)
        return mpn_mod_mat_mul_waksman(C, A, B, ctx);

#ifndef FLINT_USES_BLAS
    /* The thread count is only queried when the bit size is in range. */
    if (modulus_bits >= 113 && modulus_bits <= 128)
    {
        if (flint_get_num_available_threads() == 1)
            return gr_mat_mul_strassen(C, A, B, ctx);
    }
#endif

    return mpn_mod_mat_mul_multi_mod(C, A, B, ctx);
}