/*
    flint-sys 0.9.0 - Bindings to the FLINT C library
    (documentation source listing)
*/
/*
    Copyright (C) 2024 Fredrik Johansson

    This file is part of FLINT.

    FLINT is free software: you can redistribute it and/or modify it under
    the terms of the GNU Lesser General Public License (LGPL) as published
    by the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.  See <https://www.gnu.org/licenses/>.
*/

#include "mpn_mod.h"

/*
    Packs len coefficients of nlimbs limbs each (stored contiguously in x)
    into res, placing coefficient k at bit offset k * bits.

    Preconditions:
    - res is zero on entry: limbs lying between coefficients are never
      written.
    - res has room for one extra limb past the last coefficient. When a
      coefficient is not limb-aligned, the bits shifted out of its top limb
      are stored at index (start limb + nlimbs).
    - Coefficients are written with plain stores (copy / shift), not OR'ed
      in. The limb where coefficient k starts is therefore overwritten, so
      coefficient k - 1 must have no nonzero bits in that limb. One
      sufficient condition is that every coefficient's bit length is at
      most bits - (FLINT_BITS - 1). This implies bits >= FLINT_BITS.
*/
static void
_mpn_mod_poly_bit_pack(nn_ptr res, nn_srcptr x, slong len, flint_bitcnt_t bits, slong nlimbs)
{
    slong k;
    flint_bitcnt_t pos = 0;     /* bit offset of coefficient k in res */
    nn_srcptr src = x;          /* coefficient k in x */

    for (k = 0; k < len; k++, pos += bits, src += nlimbs)
    {
        nn_ptr dst = res + pos / FLINT_BITS;
        unsigned int off = pos % FLINT_BITS;

        if (off != 0)
            dst[nlimbs] = mpn_lshift(dst, src, nlimbs, off);  /* carry-out limb */
        else
            flint_mpn_copyi(dst, src, nlimbs);
    }
}

/*
    Inverse of bit packing for a window of coefficients: for each index k
    with nlo <= k < nhi, extracts the bits-wide field starting at bit
    k * bits of x. It stores that field into res + (k - nlo) * nlimbs via
    mpn_mod_set_mpn, presumably reducing it modulo the ctx modulus.

    One field spans at most ceil(bits / FLINT_BITS) + 1 limbs of x. These
    limbs are staged in a fixed scratch buffer of 2 * MPN_MOD_MAX_LIMBS + 3
    limbs, so bits must be small enough for that span to fit.
*/
static void
_mpn_mod_poly_bit_unpack(nn_ptr res, nn_srcptr x, slong nlo, slong nhi, flint_bitcnt_t bits, slong nlimbs, gr_ctx_t ctx)
{
    ulong tmp[2 * MPN_MOD_MAX_LIMBS + 3];
    slong k, words;
    ulong top_mask;
    flint_bitcnt_t rem, start, stop;
    nn_ptr out;

    /* number of limbs holding one bits-wide field once it is aligned */
    words = (bits + FLINT_BITS - 1) / FLINT_BITS;

    /* keeps only the bits of the field's top limb that belong to it */
    rem = bits % FLINT_BITS;
    top_mask = (rem != 0) ? ((~UWORD(0)) >> (FLINT_BITS - rem)) : ~UWORD(0);

    out = res;
    for (k = nlo; k < nhi; k++, out += nlimbs)
    {
        slong first, last, n;
        unsigned int off;

        /* the field occupies bits [start, stop) of x */
        start = bits * k;
        stop = start + bits;

        /* ... which live in limbs [first, last) of x */
        first = start / FLINT_BITS;
        last = (stop + FLINT_BITS - 1) / FLINT_BITS;
        off = start % FLINT_BITS;

        /* bring the field down to bit 0 of tmp */
        if (off != 0)
            mpn_rshift(tmp, x + first, last - first, off);
        else
            flint_mpn_copyi(tmp, x + first, last - first);

        /* drop the neighbouring field's bits, strip leading zero limbs */
        n = words;
        tmp[n - 1] &= top_mask;
        MPN_NORM(tmp, n);

        mpn_mod_set_mpn(out, tmp, n, ctx);
    }
}

/*
    Middle product by Kronecker substitution (KS).

    Writes to res the coefficients of index k, nlo <= k < nhi, of the
    product poly1 * poly2. poly1 has len1 coefficients and poly2 has len2
    coefficients, each stored as nlimbs = MPN_MOD_CTX_NLIMBS(ctx) limbs.
    res receives (nhi - nlo) * nlimbs limbs.

    Both operands are packed into big integers with `bits` bits per
    coefficient. They are multiplied with a single mpn multiplication (a
    squaring when the operands coincide). The requested window of the
    integer product is then unpacked and converted back into ctx.

    An empty or inverted output range (nlo >= nhi) writes nothing.
    Always returns GR_SUCCESS.
*/
int
_mpn_mod_poly_mulmid_KS(nn_ptr res, nn_srcptr poly1, slong len1, nn_srcptr poly2, slong len2, slong nlo, slong nhi, gr_ctx_t ctx)
{
    slong bits, nbits, nlimbs, limbs1, limbs2;
    nn_ptr arr1, arr2, arr;
    int squaring;

    /* Nothing to write. Returning here avoids a useless full multiplication.
       It also prevents the low-end truncation below from driving len1/len2
       non-positive when nlo lies beyond the product length. That would
       produce a negative limb count for the allocation and out-of-range
       pointer arithmetic. */
    if (nlo >= nhi)
        return GR_SUCCESS;

    /* Input coefficients of index >= nhi cannot reach outputs below nhi. */
    len1 = FLINT_MIN(len1, nhi);
    len2 = FLINT_MIN(len2, nhi);

    nlimbs = MPN_MOD_CTX_NLIMBS(ctx);
    nbits = MPN_MOD_CTX_MODULUS_BITS(ctx);

    if (nlo != 0)
    {
        /* An output of index k >= nlo only uses poly1[a] with
           a >= nlo - (len2 - 1), and symmetrically for poly2. The low input
           coefficients that can never contribute are skipped, and the
           output window is shifted down by the same amount. nlo2 (product
           length minus nlo) is unchanged by either adjustment. */
        slong nlo2 = (len1 + len2 - 1) - nlo;

        if (len1 > nlo2)
        {
            slong trunc = len1 - nlo2;
            poly1 += trunc * nlimbs;
            len1 -= trunc;
            nlo -= trunc;
            nhi -= trunc;
        }

        if (len2 > nlo2)
        {
            slong trunc = len2 - nlo2;
            poly2 += trunc * nlimbs;
            len2 -= trunc;
            nlo -= trunc;
            nhi -= trunc;
        }
    }

    /* Identical operands: pack once and square. */
    squaring = (poly1 == poly2 && len1 == len2);

    /* Each product coefficient is a sum of at most min(len1, len2) terms.
       Assuming reduced inputs (< 2^nbits), each term is below 2^(2 nbits),
       so this field width leaves no carries between packed slots. */
    bits = 2 * nbits + FLINT_BIT_COUNT(FLINT_MIN(len1, len2));

    /* ceil(bits * len / FLINT_BITS) limbs per packed operand */
    limbs1 = (bits * len1 - 1) / FLINT_BITS + 1;
    limbs2 = (bits * len2 - 1) / FLINT_BITS + 1;

    /* bit packing stores one carry limb past the last coefficient's start */
    FLINT_ASSERT(limbs1 >= (bits * (len1 - 1) / FLINT_BITS + nlimbs + 1));
    FLINT_ASSERT(limbs2 >= (bits * (len2 - 1) / FLINT_BITS + nlimbs + 1));

    /* bit packing requires zero-initialised destinations, hence calloc */
    arr1 = flint_calloc(squaring ? limbs1 : limbs1 + limbs2, sizeof(ulong));
    arr2 = squaring ? arr1 : arr1 + limbs1;

    _mpn_mod_poly_bit_pack(arr1, poly1, len1, bits, nlimbs);
    if (!squaring)
        _mpn_mod_poly_bit_pack(arr2, poly2, len2, bits, nlimbs);

    arr = flint_malloc((limbs1 + limbs2) * sizeof(ulong));

    /* flint_mpn_mul is called with the longer operand first */
    if (squaring)
        flint_mpn_sqr(arr, arr1, limbs1);
    else if (limbs1 >= limbs2)
        flint_mpn_mul(arr, arr1, limbs1, arr2, limbs2);
    else
        flint_mpn_mul(arr, arr2, limbs2, arr1, limbs1);

    _mpn_mod_poly_bit_unpack(res, arr, nlo, nhi, bits, nlimbs, ctx);

    flint_free(arr1);
    flint_free(arr);

    return GR_SUCCESS;
}

/*
    Low product by Kronecker substitution: writes to res the first len
    coefficients of poly1 * poly2. This is the middle product over the
    output window [0, len).
*/
int
_mpn_mod_poly_mullow_KS(nn_ptr res, nn_srcptr poly1, slong len1, nn_srcptr poly2, slong len2, slong len, gr_ctx_t ctx)
{
    const slong lo = 0;
    const slong hi = len;

    return _mpn_mod_poly_mulmid_KS(res, poly1, len1, poly2, len2, lo, hi, ctx);
}