#include "gmpcompat.h"
#include "mpn_extras.h"
#include "ulong_extras.h"
#include "fmpz.h"
#include "fmpz_mat.h"
#include "fft.h"
/*
    Signed recombination of FFT coefficients.

    Adds the alen coefficients a[0], ..., a[alen - 1] into the zn-limb array
    z, with a[i] placed at bit offset bits*i. Each a[i] occupies limbs + 1
    limbs. A coefficient whose top limb is nonzero, or whose limb at index
    limbs - 1 exceeds 2^(FLINT_BITS - 1), is treated as negative: one is
    subtracted from its low limbs and it is extended with an all-ones limb.

    z is built up as a two's complement number: limbs of z that have not been
    touched yet are first filled with the current extension limb (0 or all
    ones), so z does not need to be initialised by the caller.

    Returns the final extension flag (0 or 1). The caller in this file
    negates z and marks the result as negative when the flag is nonzero.
*/
static ulong fft_combine_bits_signed(
    ulong * z,
    ulong ** a, slong alen,
    flint_bitcnt_t bits,
    slong limbs,
    slong zn)
{
    const ulong top_bit = UWORD(1) << (FLINT_BITS - 1);
    slong k, filled;
    ulong * tmp;
    ulong ext;
    TMP_INIT;

    FLINT_ASSERT(bits > 1);

    TMP_START;
    tmp = TMP_ARRAY_ALLOC((limbs + 1), ulong);

    ext = 0;        /* current sign-extension state of z: 0 or 1 */
    filled = 0;     /* number of limbs of z that have been initialised */

    for (k = 0; k < alen; k++)
    {
        /* position of the k-th coefficient: whole limbs + leftover bits */
        slong limb_off = (bits*k)/FLINT_BITS;
        slong bit_off = (bits*k)%FLINT_BITS;
        slong stop;
        ulong neg;

        /* copy the coefficient into tmp as a (limbs + 1)-limb signed value */
        if (a[k][limbs] != 0 || a[k][limbs - 1] > top_bit)
        {
            mpn_sub_1(tmp, a[k], limbs, UWORD(1));
            neg = 1;
        }
        else
        {
            mpn_copyi(tmp, a[k], limbs);
            neg = 0;
        }
        tmp[limbs] = -neg;

        if (bit_off != 0)
            mpn_lshift(tmp, tmp, limbs + 1, bit_off);

        if (limb_off >= zn)
        {
            /*
                The coefficient starts at or beyond the end of z. When it
                starts exactly at limb zn, its lowest bit is folded into the
                extension flag; in either case nothing further is added.
            */
            if (limb_off == zn)
                ext ^= tmp[0]&1;
            break;
        }

        /* extend z with the current extension limb up to where tmp ends */
        stop = FLINT_MIN(zn, limb_off + limbs + 1);
        FLINT_ASSERT(stop >= filled);
        for ( ; filled < stop; filled++)
            z[filled] = -ext;

        /* new extension = old extension + sign of tmp + carry, mod 2 */
        FLINT_ASSERT(stop > limb_off);
        ext ^= neg;
        ext ^= mpn_add_n(z + limb_off, z + limb_off, tmp, stop - limb_off);
    }

    /* whatever is left of z is pure sign extension */
    for ( ; filled < zn; filled++)
        z[filled] = -ext;

    TMP_END;

    FLINT_ASSERT(ext == 0 || ext == 1);

    return ext;
}
/*
    Splits |x| into chunks of the given number of bits via fft_split_bits,
    storing the chunks in poly[0], poly[1], ...; returns the number of
    chunks reported by fft_split_bits (0 when x is zero, in which case poly
    is not touched).

    When x is negative every produced chunk is replaced by its negation
    through mpn_negmod_2expp1 with the given limb count.
*/
static slong fft_split_bits_fmpz(
    ulong ** poly,
    const fmpz_t x,
    flint_bitcnt_t bits,
    slong limbs)
{
    slong count = 0;
    slong k;
    int negative = 0;

    if (COEFF_IS_MPZ(*x))
    {
        /* large value: work directly on the limbs of the underlying mpz */
        slong used = COEFF_TO_PTR(*x)->_mp_size;

        negative = (used < 0);
        if (negative)
            used = -used;

        count = fft_split_bits(poly, COEFF_TO_PTR(*x)->_mp_d,
                                                          used, bits, limbs);
    }
    else if (!fmpz_is_zero(x))
    {
        /* small nonzero value: present its magnitude as a one-limb array */
        ulong magnitude;

        negative = (*x < 0);
        if (negative)
            magnitude = -*x;
        else
            magnitude = *x;

        count = fft_split_bits(poly, &magnitude, 1, bits, limbs);
    }

    if (!negative)
        return count;

    for (k = 0; k < count; k++)
        mpn_negmod_2expp1(poly[k], poly[k], limbs);

    return count;
}
/*
    Recombines `length` coefficients in poly, spaced `bits` bits apart, into
    the integer x using total_limbs limbs of storage.

    With sign == 0 the storage is zeroed and the unsigned fft_combine_bits
    routine is used. With sign != 0 the signed combiner is used; if it
    reports a negative result the limbs are negated and x is given a
    negative size. In all cases the limb count is normalised before being
    stored and x is demoted at the end.
*/
static void fft_combine_bits_fmpz(
    fmpz_t x,
    ulong ** poly, slong length,
    flint_bitcnt_t bits,
    slong limbs,
    slong total_limbs,
    int sign)
{
    mpz_ptr big = _fmpz_promote(x);
    ulong * out = FLINT_MPZ_REALLOC(big, total_limbs);
    int negative = 0;

    if (!sign)
    {
        /* the unsigned combiner accumulates into out, so clear it first */
        flint_mpn_zero(out, total_limbs);
        fft_combine_bits(out, poly, length, bits, limbs, total_limbs);
    }
    else if (fft_combine_bits_signed(out, poly, length, bits, limbs,
                                                                total_limbs))
    {
        /* two's complement result is negative: store its magnitude */
        mpn_neg(out, out, total_limbs);
        negative = 1;
    }

    /* strip leading zero limbs; this shrinks total_limbs in place */
    MPN_NORM(out, total_limbs);

    if (negative)
    {
        FLINT_ASSERT(total_limbs > 0);
        big->_mp_size = -total_limbs;
    }
    else
    {
        big->_mp_size = total_limbs;
    }

    _fmpz_demote_val(x);
}
/*
    Runs a truncated sqrt2 forward transform on coeffs in place and then
    normalises the resulting coefficients with mpn_normmod_2expp1.

    use_mfa == 0: fft_truncate_sqrt2 is used and the first trunc
    coefficients are normalised.

    use_mfa != 0: fft_mfa_truncate_sqrt2 is used with row length n1. The
    first 2*n coefficients are normalised, followed by (trunc - 2*n)/n1
    blocks of n1 coefficients located at 2*n + n_revbin(s, depth - depth/2
    + 1)*n1 for s = 0, 1, ... (presumably the entries populated by the
    truncated matrix Fourier transform).

    t1, t2, t3 are forwarded unchanged to the transform routines.
*/
static void _either_fft_or_mfa(
    ulong ** coeffs,
    slong n, flint_bitcnt_t w,
    ulong ** t1, ulong ** t2, ulong ** t3,
    slong n1,
    flint_bitcnt_t depth,
    slong trunc,
    slong limbs,
    int use_mfa)
{
    ulong nblocks, blk, rev;
    slong idx, col;

    if (!use_mfa)
    {
        fft_truncate_sqrt2(coeffs, n, w, t1, t2, t3, trunc);

        for (idx = 0; idx < trunc; idx++)
            mpn_normmod_2expp1(coeffs[idx], limbs);

        return;
    }

    fft_mfa_truncate_sqrt2(coeffs, n, w, t1, t2, t3, n1, trunc);

    /* lower half: all 2*n coefficients */
    for (idx = 0; idx < 2*n; idx++)
        mpn_normmod_2expp1(coeffs[idx], limbs);

    /* upper half: whole blocks of n1 coefficients in bit-reversed order */
    nblocks = (trunc - 2*n)/n1;
    for (blk = 0; blk < nblocks; blk++)
    {
        rev = n_revbin(blk, depth - depth/2 + 1);

        for (col = 0; col < n1; col++)
        {
            idx = 2*n + rev*n1 + col;
            mpn_normmod_2expp1(coeffs[idx], limbs);
        }
    }
}
/*
    Dot product of len residues, each stored in limbs + 1 limbs:

        c = sum_{i < len} A[i*Astride] * B[i*Bstride]

    Every product is formed with flint_mpn_mulmod_2expp1_basecase on
    limbs*FLINT_BITS bits, with the top limbs of the two operands passed
    packed as 2*a[limbs] + b[limbs]. After each addition the running sum is
    renormalised with mpn_normmod_2expp1.

    t receives each intermediate product (limbs + 1 limbs); t2 is scratch
    handed to the multiplication routine. At least one term is always
    processed; len > 0 is asserted.
*/
static void _dot(
    ulong* c,
    ulong** A, slong Astride,
    ulong** B, slong Bstride,
    slong len,
    slong limbs,
    ulong* t,
    ulong* t2)
{
    const flint_bitcnt_t modbits = limbs*FLINT_BITS;
    const ulong * lhs;
    const ulong * rhs;
    slong k;

    FLINT_ASSERT(len > 0);

    /* the first product is written straight into the accumulator */
    lhs = A[0];
    rhs = B[0];
    c[limbs] = flint_mpn_mulmod_2expp1_basecase(c, lhs, rhs,
                                  2*lhs[limbs] + rhs[limbs], modbits, t2);

    /* remaining products go through t and are added in */
    for (k = 1; k < len; k++)
    {
        lhs = A[k*Astride];
        rhs = B[k*Bstride];

        t[limbs] = flint_mpn_mulmod_2expp1_basecase(t, lhs, rhs,
                                  2*lhs[limbs] + rhs[limbs], modbits, t2);

        /* add top limbs, then low limbs with the carry pushed to the top */
        c[limbs] += t[limbs];
        c[limbs] += mpn_add_n(c, c, t, limbs);
        mpn_normmod_2expp1(c, limbs);
    }
}
/*
    Matrix product via truncated sqrt2 FFTs: every entry of A and B is split
    into chunks and transformed once, and each entry C(i, j) is then
    obtained from a dot product over the inner dimension taken coefficient
    by coefficient in the transform domain, followed by a single inverse
    transform and recombination.

    Abits, Bbits: only used to size the output (Climbs below).
    depth, w:     transform parameters; n = 2^depth and each coefficient
                  holds n*w bits, which must be a whole number of limbs.
    j1, j2:       chunk counts (presumably for the entries of A and B
                  respectively); the product uses j1 + j2 - 1 coefficients.
    use_mfa:      nonzero selects the matrix Fourier (mfa) transform
                  routines, zero the plain truncated ones.
    sign:         nonzero selects signed recombination of the output and
                  contributes one extra bit to clgK.

    All entries of A and B are read into the scratch buffers before any
    entry of C is written.
*/
static void _fmpz_mat_mul_truncate_sqrt2(
fmpz_mat_t C,
const fmpz_mat_t A, slong Abits,
const fmpz_mat_t B, slong Bbits,
flint_bitcnt_t depth,
flint_bitcnt_t w,
slong j1, slong j2,
const int use_mfa,
const int sign)
{
slong M = fmpz_mat_nrows(A);
slong K = fmpz_mat_ncols(A);
slong N = fmpz_mat_ncols(B);
/* extra bits allowed for: ceil(log2(K)) plus one when signed */
slong clgK = FLINT_CLOG2(K) + sign;
slong n = WORD(1) << depth;
ulong trunc, sqrt;
/* bits per chunk used for both splitting the inputs and recombining */
ulong bits1 = (n*w - (depth + 1 + clgK))/2;
/* limbs needed to hold Abits + Bbits + clgK bits */
slong Climbs = (Abits + Bbits + clgK + FLINT_BITS - 1)/FLINT_BITS;
/* limbs per coefficient; each coefficient is stored in limbs + 1 limbs */
slong limbs = (n*w)/FLINT_BITS;
slong size = limbs + 1;
slong i, j, l, h;
ulong * temp, *t, * t1, * t2, * t3, * Adata, * Bdata, * Cdata;
ulong ** coeffs, ** Acoeffs, ** Bcoeffs, ** Ccoeffs;
FLINT_ASSERT(limbs > 0);
FLINT_ASSERT(limbs*FLINT_BITS == n*w);
FLINT_ASSERT(j1 <= 2*n || j2 <= 2*n);
/*
    Single allocation laid out as:
      temp           2*size limbs, multiplication scratch handed to _dot
      t, t1, t2, t3  size limbs each
      Adata          4*n coefficients for each of the M*K entries of A
      Bdata          4*n coefficients for each of the K*N entries of B
      Cdata          4*n coefficients for the output entry being computed
*/
temp = FLINT_ARRAY_ALLOC((6 + 4*n*(M*K + K*N + 1))*size, ulong);
t = temp + 2*size;
t1 = t + size;
t2 = t1 + size;
t3 = t2 + size;
Adata = t3 + size;
Bdata = Adata + 4*n*M*K*size;
Cdata = Bdata + 4*n*K*N*size;
/* pointer table with one pointer per coefficient, in the same order */
coeffs = FLINT_ARRAY_ALLOC(4*n*(M*K + K*N + 1), ulong*);
Acoeffs = coeffs;
Bcoeffs = Acoeffs + 4*n*M*K;
Ccoeffs = Bcoeffs + 4*n*K*N;
for (i = 0; i < M; i++)
for (j = 0; j < K; j++)
for (l = 0; l < 4*n; l++)
Acoeffs[(i*K + j)*4*n + l] = Adata + ((i*K + j)*4*n + l)*size;
for (i = 0; i < K; i++)
for (j = 0; j < N; j++)
for (l = 0; l < 4*n; l++)
Bcoeffs[(i*N + j)*4*n + l] = Bdata + ((i*N + j)*4*n + l)*size;
for (l = 0; l < 4*n; l++)
Ccoeffs[l] = Cdata + l*size;
/* split each entry of A into chunks and zero the unused coefficients */
for (i = 0; i < M; i++)
for (j = 0; j < K; j++)
{
h = fft_split_bits_fmpz(Acoeffs + (i*K + j)*4*n,
fmpz_mat_entry(A, i, j), bits1, limbs);
for (l = h; l < 4*n; l++)
flint_mpn_zero(Acoeffs[(i*K + j)*4*n + l], size);
}
/* same for each entry of B */
for (i = 0; i < K; i++)
for (j = 0; j < N; j++)
{
h = fft_split_bits_fmpz(Bcoeffs + (i*N + j)*4*n,
fmpz_mat_entry(B, i, j), bits1, limbs);
for (l = h; l < 4*n; l++)
flint_mpn_zero(Bcoeffs[(i*N + j)*4*n + l], size);
}
FLINT_ASSERT(j1 > 0);
FLINT_ASSERT(j2 > 0);
FLINT_ASSERT(j1 + j2 - 1 <= 4*n);
/*
    Truncation length: at least 2*n + 1, rounded up to a multiple of
    2*sqrt, where sqrt is 2^(depth/2) for the mfa routines and 1 otherwise.
*/
trunc = j1 + j2 - 1;
trunc = FLINT_MAX(trunc, (ulong) 2*n + 1);
if (use_mfa)
{
sqrt = UWORD(1) << (depth/2);
trunc = (trunc + 2*sqrt - 1) & (-2*sqrt);
}
else
{
sqrt = 1;
trunc = (trunc + 1) & -UWORD(2);
}
FLINT_ASSERT(trunc > 2*n);
FLINT_ASSERT(trunc % (2*sqrt) == 0);
/* forward transform of every entry of A, then of every entry of B */
for (i = 0; i < M; i++)
for (j = 0; j < K; j++)
{
_either_fft_or_mfa(Acoeffs + (i*K + j)*4*n, n, w,
&t1, &t2, &t3, sqrt, depth, trunc, limbs, use_mfa);
}
for (i = 0; i < K; i++)
for (j = 0; j < N; j++)
{
_either_fft_or_mfa(Bcoeffs + (i*N + j)*4*n, n, w,
&t1, &t2, &t3, sqrt, depth, trunc, limbs, use_mfa);
}
/*
    One output entry at a time: for each coefficient index, take the dot
    product of row i of A (stride 4*n between consecutive entries) with
    column j of B (stride N*4*n), then inverse transform the result.
*/
for (i = 0; i < M; i++)
for (j = 0; j < N; j++)
{
ulong ux;
if (use_mfa)
{
ulong trunc2, rs, s, u;
/* lower half: every index below 2*n */
for (l = 0; l < 2 * n; l++)
{
_dot(Ccoeffs[l], Acoeffs + (i*K + 0)*4*n + l, 4*n,
Bcoeffs + (0*N + j)*4*n + l, N*4*n,
K, limbs, t, temp);
}
/*
    upper half: blocks of sqrt indices at bit-reversed block positions,
    the same index pattern that _either_fft_or_mfa normalises
*/
trunc2 = (trunc - 2*n)/sqrt;
for (s = 0; s < trunc2; s++)
{
rs = n_revbin(s, depth - depth/2 + 1);
for (u = 0; u < sqrt; u++)
{
l = 2*n + rs*sqrt + u;
_dot(Ccoeffs[l], Acoeffs + (i*K + 0)*4*n + l, 4*n,
Bcoeffs + (0*N + j)*4*n + l, N*4*n,
K, limbs, t, temp);
}
}
ifft_mfa_truncate_sqrt2(Ccoeffs, n, w, &t1, &t2, &t3, sqrt, trunc);
}
else
{
for (ux = 0; ux < trunc; ux++)
{
_dot(Ccoeffs[ux], Acoeffs + (i*K + 0)*4*n + ux, 4*n,
Bcoeffs + (0*N + j)*4*n + ux, N*4*n,
K, limbs, t, temp);
}
ifft_truncate_sqrt2(Ccoeffs, n, w, &t1, &t2, &t3, trunc);
}
/*
    Divide every coefficient by 2^(depth + 2) = 4*n and renormalise
    (presumably this removes the scaling left by the transform pair).
*/
for (ux = 0; ux < trunc; ux++)
{
mpn_div_2expmod_2expp1(Ccoeffs[ux], Ccoeffs[ux], limbs, depth + 2);
mpn_normmod_2expp1(Ccoeffs[ux], limbs);
}
/* reassemble the j1 + j2 - 1 product coefficients into C(i, j) */
fft_combine_bits_fmpz(fmpz_mat_entry(C, i, j), Ccoeffs,
j1 + j2 - 1, bits1, limbs, Climbs, sign);
}
flint_free(temp);
flint_free(coeffs);
}
void _fmpz_mat_mul_fft(
fmpz_mat_t C,
const fmpz_mat_t A, slong abits,
const fmpz_mat_t B, slong bbits,
int sign)
{
slong K = fmpz_mat_ncols(A);
slong clgK = FLINT_CLOG2(K) + sign;
slong depth = 6;
slong w = 1;
slong n = WORD(1) << depth;
ulong bits = (n*w - (depth + 1 + clgK))/2;
ulong bits1 = FLINT_MAX(abits, WORD(2000));
ulong bits2 = FLINT_MAX(bbits, WORD(2000));
slong j1 = (bits1 + bits - 1)/bits;
slong j2 = (bits2 + bits - 1)/bits;
int use_mfa;
FLINT_ASSERT(sign == 0 || sign == 1);
FLINT_ASSERT(abits > 0);
FLINT_ASSERT(bbits > 0);
FLINT_ASSERT(j1 + j2 - 1 > 2*n);
while (j1 + j2 - 1 > 4*n)
{
if (w == 1)
{
w = 2;
}
else
{
depth++;
w = 1;
n *= 2;
}
bits = (n*w - (depth + 1 + clgK))/2;
j1 = (bits1 + bits - 1)/bits;
j2 = (bits2 + bits - 1)/bits;
}
if (depth < 11)
{
slong wadj = 1;
if (depth < 6)
wadj = WORD(1) << (6 - depth);
if (w > wadj)
{
do {
w -= wadj;
bits = (n*w - (depth + 1 + clgK))/2;
j1 = (bits1 + bits - 1)/bits;
j2 = (bits2 + bits - 1)/bits;
} while (j1 + j2 - 1 <= 4*n && w > wadj);
w += wadj;
}
use_mfa = 0;
}
else
{
use_mfa = 1;
}
bits = (n*w - (depth + 1 + clgK))/2;
j1 = (bits1 + bits - 1)/bits;
j2 = (bits2 + bits - 1)/bits;
_fmpz_mat_mul_truncate_sqrt2(C, A, abits, B, bbits, depth, w, j1, j2,
use_mfa, sign);
}
/*
    Sets C to the product of A and B using the FFT-based routine.

    C is set to zero without calling the FFT code when A has no rows, B has
    no rows or no columns, or when fmpz_mat_max_bits reports 0 for either
    operand. A negative value from fmpz_mat_max_bits is replaced by its
    absolute value and causes the signed variant (sign = 1) to be used.
*/
void fmpz_mat_mul_fft(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)
{
    slong abits, bbits;
    int sign = 0;

    /* degenerate shapes */
    if (fmpz_mat_nrows(A) == 0 || fmpz_mat_nrows(B) == 0
                               || fmpz_mat_ncols(B) == 0)
    {
        fmpz_mat_zero(C);
        return;
    }

    abits = fmpz_mat_max_bits(A);
    if (abits < 0)
    {
        abits = -abits;
        sign = 1;
    }

    bbits = fmpz_mat_max_bits(B);
    if (bbits < 0)
    {
        bbits = -bbits;
        sign = 1;
    }

    /* a zero bit count on either side gives a zero product */
    if (abits == 0 || bbits == 0)
    {
        fmpz_mat_zero(C);
        return;
    }

    _fmpz_mat_mul_fft(C, A, abits, B, bbits, sign);
}