#include "mpn_mod.h"
#include "gr_mat.h"
/*
    Accumulate c += (a1 + b1) * (a2 + b2).

    a1, b1, a2, b2 are read as nlimbs-limb operands; c is an accumulator of
    2 * nlimbs + 1 limbs. val0, val1 and val2 are scratch buffers whose
    contents are overwritten.

    When add_can_overflow_nlimbs is nonzero the carry of each sum is stored
    as an extra limb, so val1 and val2 must have room for nlimbs + 1 limbs
    and val0 for 2 * nlimbs + 2 limbs. Otherwise the carries of the sums are
    discarded (the caller guarantees they cannot occur) and only nlimbs
    resp. 2 * nlimbs limbs of scratch are touched.
*/
FLINT_FORCE_INLINE void
addmul_addadd(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2, slong nlimbs, int add_can_overflow_nlimbs)
{
    if (add_can_overflow_nlimbs)
    {
        /* Sums may need one more limb: keep each carry as the top limb. */
        val1[nlimbs] = mpn_add_n(val1, a1, b1, nlimbs);
        val2[nlimbs] = mpn_add_n(val2, a2, b2, nlimbs);

        flint_mpn_mul_n(val0, val1, val2, nlimbs + 1);

        /* The product of two (nlimbs + 1)-limb sums is expected to fit in
           2 * nlimbs + 1 limbs, which is the width of the accumulator. */
        FLINT_ASSERT(val0[2 * nlimbs + 1] == 0);

        mpn_add_n(c, c, val0, 2 * nlimbs + 1);
        return;
    }

    /* Sums fit in nlimbs limbs, so the product is exactly 2 * nlimbs limbs. */
    mpn_add_n(val1, a1, b1, nlimbs);
    mpn_add_n(val2, a2, b2, nlimbs);

    flint_mpn_mul_n(val0, val1, val2, nlimbs);

    /* Add the low 2 * nlimbs limbs and push the carry into the top limb. */
    c[2 * nlimbs] += mpn_add_n(c, c, val0, 2 * nlimbs);
}
/*
    Accumulate c += (a1 - b1) * (a2 - b2).

    a1, b1, a2, b2 are read as nlimbs-limb operands; c is an accumulator of
    2 * nlimbs + 1 limbs whose top limb absorbs the carry or borrow.
    val0 (2 * nlimbs limbs), val1 and val2 (nlimbs limbs each) are scratch
    buffers whose contents are overwritten.

    NOTE(review): flint_mpn_signed_sub_n presumably stores |x - y| and
    returns a nonzero flag when the true difference is negative -- confirm
    against its documentation. The code below relies on that: the product
    of the two stored values is subtracted when exactly one flag is set and
    added otherwise.
*/
FLINT_FORCE_INLINE void
addmul_subsub(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2, slong nlimbs)
{
    int first_neg = flint_mpn_signed_sub_n(val1, a1, b1, nlimbs);
    int second_neg = flint_mpn_signed_sub_n(val2, a2, b2, nlimbs);

    flint_mpn_mul_n(val0, val1, val2, nlimbs);

    if ((first_neg ^ second_neg) == 0)
    {
        /* Signs agree: the product is non-negative. */
        c[2 * nlimbs] += mpn_add_n(c, c, val0, 2 * nlimbs);
    }
    else
    {
        /* Signs differ: the product is negative, so subtract its magnitude. */
        c[2 * nlimbs] -= mpn_sub_n(c, c, val0, 2 * nlimbs);
    }
}
/*
    Set C = A * B over the mpn_mod ring described by ctx, using Waksman's
    scheme, which needs roughly half as many limb multiplications as the
    schoolbook method.

    Dimensions are taken as rows = A->r, inner = B->r, p = B->c. No shape
    compatibility check is done here (A->c and the shape of C are never
    inspected), so the caller must pass compatible matrices. If any
    dimension is zero the result of gr_mat_zero(C, ctx) is returned;
    otherwise GR_SUCCESS is returned.

    Method. Columns of A / rows of B are consumed in pairs (lo, hi). With
    x = A(l, lo), y = A(l, hi), u = B(lo, k), v = B(hi, k):

        (x + v) * (y + u) =  (x*u + y*v) + (x*y + u*v)
        (x - v) * (y - u) = -(x*u + y*v) + (x*y + u*v)

    Summed over all pairs, x*u + y*v is the wanted entry, x*y depends only
    on the row l and u*v only on the column k. The "+" products are
    accumulated for every entry; the "-" products only for row 0 (into
    Crow) and column 0 (into Ccol). Half the sum of both accumulations is
    the row term plus the column term, which is enough to recover the
    unwanted part for every entry. If inner is odd, the last column/row is
    added with plain products.

    All intermediate entries are unreduced accumulators of
    slimbs = 2 * nlimbs + 1 limbs, treated as two's-complement signed values
    (the final pass asserts that each top limb is non-negative). C is only
    written in the final reduction pass, after the last read of A and B.
*/
int mpn_mod_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
{
    slong nlimbs = MPN_MOD_CTX_NLIMBS(ctx);
    slong slimbs = 2 * nlimbs + 1;
    slong rows = A->r;
    slong inner = B->r;
    slong p = B->c;
    slong Astride = A->stride;
    slong Bstride = B->stride;
    slong Cstride = C->stride;
    nn_srcptr Aentries = A->entries;
    nn_srcptr Bentries = B->entries;
    nn_ptr Centries = C->entries;
    /* With no spare bits in the top limb of the modulus, the sum of two
       residues may carry out of nlimbs limbs. */
    int sums_need_extra_limb = (MPN_MOD_CTX_NORM(ctx) == 0);

    if (rows == 0 || inner == 0 || p == 0)
        return gr_mat_zero(C, ctx);

    slong row, col, hi;

    /*
        One zero-initialized allocation, in units of slimbs limbs:
          rows * p  accumulators for the entries of the product,
          p         accumulators Crow ("-" products for row 0),
          rows      accumulators Ccol ("-" products for column 0),
          5         scratch: s0 (double width), s1, s2, coladj.
    */
    nn_ptr Ctmp = flint_calloc(slimbs * ((rows * p) + (p + rows) + 5), sizeof(ulong));
    nn_ptr Crow = Ctmp + slimbs * (rows * p);
    nn_ptr Ccol = Crow + slimbs * p;
    nn_ptr s0 = Ccol + slimbs * rows;
    nn_ptr s1 = s0 + 2 * slimbs;
    nn_ptr s2 = s1 + slimbs;
    nn_ptr coladj = s2 + slimbs;

#define A_ENTRY(ii, jj) (Aentries + ((ii) * Astride + (jj)) * nlimbs)
#define B_ENTRY(ii, jj) (Bentries + ((ii) * Bstride + (jj)) * nlimbs)
#define C_ENTRY(ii, jj) (Ctmp + ((ii) * p + (jj)) * slimbs)
#define COUT_ENTRY(ii, jj) (Centries + ((ii) * Cstride + (jj)) * nlimbs)
#define Crow_ENTRY(ii) (Crow + (ii) * slimbs)
#define Ccol_ENTRY(ii) (Ccol + (ii) * slimbs)

    slong pairs = inner >> 1;

    /* hi runs over the odd index 1, 3, ..., 2 * pairs - 1 of each pair. */
    for (hi = 1; hi < 2 * pairs; hi += 2)
    {
        slong lo = hi - 1;
        nn_srcptr a0_lo = A_ENTRY(0, lo);
        nn_srcptr a0_hi = A_ENTRY(0, hi);

        /* Row 0: both the "+" and the "-" accumulations. */
        for (col = 0; col < p; col++)
        {
            nn_srcptr b_lo = B_ENTRY(lo, col);
            nn_srcptr b_hi = B_ENTRY(hi, col);

            addmul_addadd(s0, s1, s2, C_ENTRY(0, col), a0_lo, b_hi, a0_hi, b_lo, nlimbs, sums_need_extra_limb);
            addmul_subsub(s0, s1, s2, Crow_ENTRY(col), a0_lo, b_hi, a0_hi, b_lo, nlimbs);
        }

        /* Column 0 (below row 0): both accumulations as well. */
        for (row = 1; row < rows; row++)
        {
            nn_srcptr a_lo = A_ENTRY(row, lo);
            nn_srcptr a_hi = A_ENTRY(row, hi);
            nn_srcptr b0_lo = B_ENTRY(lo, 0);
            nn_srcptr b0_hi = B_ENTRY(hi, 0);

            addmul_addadd(s0, s1, s2, C_ENTRY(row, 0), a_lo, b0_hi, a_hi, b0_lo, nlimbs, sums_need_extra_limb);
            addmul_subsub(s0, s1, s2, Ccol_ENTRY(row), a_lo, b0_hi, a_hi, b0_lo, nlimbs);
        }

        /* Interior entries: only the "+" accumulation is needed. */
        for (col = 1; col < p; col++)
        {
            nn_srcptr b_lo = B_ENTRY(lo, col);
            nn_srcptr b_hi = B_ENTRY(hi, col);

            for (row = 1; row < rows; row++)
            {
                addmul_addadd(s0, s1, s2, C_ENTRY(row, col), A_ENTRY(row, lo), b_hi, A_ENTRY(row, hi), b_lo, nlimbs, sums_need_extra_limb);
            }
        }
    }

    /* Column 0: Ccol(row) becomes (row term + column-0 term), which is
       then removed from C(row, 0). */
    for (row = 1; row < rows; row++)
    {
        nn_ptr half = Ccol_ENTRY(row);
        nn_ptr acc = C_ENTRY(row, 0);

        mpn_add_n(s1, half, acc, slimbs);
        flint_mpn_signed_div2(half, s1, slimbs);
        mpn_sub_n(acc, acc, half, slimbs);
    }

    /* Entry (0, 0): s0 keeps (row-0 term + column-0 term) for the loop
       below, so it must not be clobbered until that loop is done. */
    mpn_add_n(s1, Crow, C_ENTRY(0, 0), slimbs);
    flint_mpn_signed_div2(s0, s1, slimbs);
    mpn_sub_n(C_ENTRY(0, 0), C_ENTRY(0, 0), s0, slimbs);

    for (col = 1; col < p; col++)
    {
        nn_ptr top = C_ENTRY(0, col);

        /* s1 = row-0 term + column term of this column; fix up C(0, col). */
        mpn_add_n(coladj, Crow_ENTRY(col), top, slimbs);
        flint_mpn_signed_div2(s1, coladj, slimbs);
        mpn_sub_n(top, top, s1, slimbs);

        /* coladj = column term of this column - column-0 term. */
        mpn_sub_n(coladj, s1, s0, slimbs);

        /* Interior: subtracting coladj and Ccol(row) removes exactly the
           row term and the column term of this entry. */
        for (row = 1; row < rows; row++)
        {
            nn_ptr acc = C_ENTRY(row, col);

            mpn_sub_n(s2, acc, coladj, slimbs);
            mpn_sub_n(acc, s2, Ccol_ENTRY(row), slimbs);
        }
    }

    /* Odd inner dimension: the unpaired last column of A / row of B is
       added with ordinary products. */
    if (inner & 1)
    {
        slong last = inner - 1;

        for (row = 0; row < rows; row++)
        {
            for (col = 0; col < p; col++)
            {
                nn_ptr acc = C_ENTRY(row, col);

                flint_mpn_mul_n(s0, A_ENTRY(row, last), B_ENTRY(last, col), nlimbs);
                acc[2 * nlimbs] += mpn_add_n(acc, acc, s0, 2 * nlimbs);
            }
        }
    }

    /* Reduce every accumulator modulo the ring modulus into C. */
    for (row = 0; row < rows; row++)
    {
        for (col = 0; col < p; col++)
        {
            nn_ptr acc = C_ENTRY(row, col);
            slong used = slimbs;

            FLINT_ASSERT((slong) acc[slimbs - 1] >= 0);

            /* Strip leading zero limbs before handing over the value. */
            MPN_NORM(acc, used);
            mpn_mod_set_mpn(COUT_ENTRY(row, col), acc, used, ctx);
        }
    }

    flint_free(Ctmp);
    return GR_SUCCESS;
}