#include "fmpz.h"
#include "fmpz_vec.h"
#include "fmpz_mat.h"
#include "fq_nmod.h"
#include "n_poly.h"
#include "mpoly.h"
#include "fq_nmod_mpoly.h"
/*
    Build A from B by mapping every exponent vector of B through the
    integer matrix M.

    For each term of B, its packed exponent is unpacked into a vector u of
    ctxB->minfo->nfields entries and v = M*u is computed.  M must have
    ctxAC->minfo->nfields + 1 rows and ctxB->minfo->nfields columns.
    The last entry of v acts as a filter: a term is dropped when that entry
    is nonzero.  Otherwise the term is appended to A with the coefficient of
    B copied unchanged and the first ctxAC->minfo->nfields entries of v as
    its exponent.  After all terms are processed, A is sorted and like terms
    are combined.

    A must not be the same object as B.  The entries of v that are kept are
    expected to be nonnegative (checked by assertion only).

    NOTE(review): coefficients are copied verbatim using the degree taken
    from ctxAC->fqctx, so ctxB and ctxAC presumably share the same finite
    field -- confirm against callers.
*/
void _fq_nmod_mpoly_compose_mat(
    fq_nmod_mpoly_t A,
    const fq_nmod_mpoly_t B,
    const fmpz_mat_t M,
    const fq_nmod_mpoly_ctx_t ctxB,
    const fq_nmod_mpoly_ctx_t ctxAC)
{
    slong d = fq_nmod_ctx_degree(ctxAC->fqctx);
    slong nfieldsB = ctxB->minfo->nfields;
    slong nfieldsAC = ctxAC->minfo->nfields;
    slong i;
    fmpz * u, * v;
    slong vmaxbits;
    flint_bitcnt_t vbits;
    slong Blen = B->length;
    flint_bitcnt_t Bbits = B->bits;
    slong BN = mpoly_words_per_exp(Bbits, ctxB->minfo);
    const ulong * Bexp = B->exps;
    const ulong * Bcoeffs = B->coeffs;
    slong AN;

    FLINT_ASSERT(A != B);

    FLINT_ASSERT(fmpz_mat_nrows(M) == nfieldsAC + 1);
    FLINT_ASSERT(fmpz_mat_ncols(M) == nfieldsB);

    u = _fmpz_vec_init(nfieldsB);
    v = _fmpz_vec_init(nfieldsAC + 1);

    /* start from an empty A with room for every term of B */
    fq_nmod_mpoly_fit_length_reset_bits(A, Blen, MPOLY_MIN_BITS, ctxAC);
    A->length = 0;

    for (i = 0; i < Blen; i++)
    {
        mpoly_unpack_vec_fmpz(u, Bexp + BN*i, Bbits, nfieldsB, 1);
        fmpz_mat_mul_fmpz_vec(v, M, u, fmpz_mat_ncols(M));

        /* the extra last row of M selects which terms survive */
        if (!fmpz_is_zero(v + nfieldsAC))
            continue;

        /*
            Keep the result in a signed variable: _fmpz_vec_max_bits reports
            a negative entry through a negative return value, and the check
            below would be vacuous on the unsigned flint_bitcnt_t.
        */
        vmaxbits = _fmpz_vec_max_bits(v, nfieldsAC);
        FLINT_ASSERT(vmaxbits >= 0);
        vbits = mpoly_fix_bits(vmaxbits + 1, ctxAC->minfo);

        fq_nmod_mpoly_fit_length_fit_bits(A, A->length + 1, vbits, ctxAC);

        _n_fq_set(A->coeffs + d*A->length, Bcoeffs + d*i, d);

        /* A->bits may have just changed, so the word count is recomputed */
        AN = mpoly_words_per_exp(A->bits, ctxAC->minfo);
        mpoly_pack_vec_fmpz(A->exps + AN*A->length, v, A->bits, nfieldsAC, 1);
        A->length++;
    }

    _fmpz_vec_clear(u, nfieldsB);
    _fmpz_vec_clear(v, nfieldsAC + 1);

    /* terms were appended in B's order; sort them and merge duplicates */
    fq_nmod_mpoly_sort_terms(A, ctxAC);
    fq_nmod_mpoly_combine_like_terms(A, ctxAC);
}