flint-sys 0.9.0

Bindings to the FLINT C library
Documentation
/*
    Copyright (C) 2012 William Hart
    Copyright (C) 2025 Fredrik Johansson

    This file is part of FLINT.

    FLINT is free software: you can redistribute it and/or modify it under
    the terms of the GNU Lesser General Public License (LGPL) as published
    by the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.  See <https://www.gnu.org/licenses/>.
*/

#include "mpn_extras.h"

/*
    Hack: flint_mpn_mod_preinvn is currently slow for divisions of
    size (n+2,n) or (n+1,n), so we use the following code adapted from
    flint_mpn_mod_preinv1. It would be better to improve flint_mpn_mod_preinvn,
    but this is not straightforward; flint_mpn_mod_preinvn is only
    allowed to overwrite n limbs while the following overwrites m limbs
    (which is fine for the use local to this file).

    Note: flint_mpn_divrem21_preinv is documented as requiring the
    precomputed inverse generated by flint_mpn_preinv1, but it turns out
    that it works to take the top limb of an inverse computed by
    flint_mpn_preinvn.
*/
/*
    Reduce the m-limb integer {a, m} modulo the n-limb integer {b, n},
    in place.  The remainder is left in the low n limbs of a; the limbs
    above them are clobbered and should not be read afterwards.

    Parameters:
      a    - dividend, overwritten (all m limbs may be written)
      m    - number of limbs of a; callers in this file pass n + 1 or n + 2
      b    - divisor, n limbs
      n    - number of limbs of b
      dinv - single-limb inverse handed to flint_mpn_divrem21_preinv;
             callers in this file pass the top limb of an n-limb inverse

    NOTE(review): the single-limb quotient estimate presumably requires b
    to be normalized (top bit of b[n-1] set), as usual for preinv-based
    division; confirm at call sites.
*/
static void flint_mpn_mod_preinv1(mp_ptr a, mp_size_t m,
                                  mp_srcptr b, mp_size_t n, mp_limb_t dinv)
{
    mp_size_t i;
    mp_limb_t q;

    /* Bring the top n limbs below b so that each quotient digit below
       fits in a single limb. */
    if (mpn_cmp(a + m - n, b, n) >= 0)
        mpn_sub_n(a + m - n, a + m - n, b, n);

    /* Eliminate one limb per iteration, from the top down. */
    for (i = m - 1; i >= n; i--)
    {
        /* Estimate the next quotient limb from the top two limbs and
           subtract q * b from the current window a[i - n .. i]. */
        flint_mpn_divrem21_preinv(q, a[i], a[i - 1], dinv);
        a[i] -= mpn_submul_1(a + i - n, b, n, q);

        /* Only one correction step is applied, i.e. the code assumes the
           estimate is at most one too small. */
        if (mpn_cmp(a + i - n, b, n) >= 0 || a[i] != 0)
            a[i] -= mpn_sub_n(a + i - n, a + i - n, b, n);
    }
}

/*
    Number of limbs to allocate for the matrix that
    flint_mpn_mulmod_precond_matrix_precompute fills in for an n-limb
    modulus.

    The matrix itself is n rows of n limbs (n * n limbs).  One extra limb
    is included because flint_mpn_mulmod_precond_matrix_precompute builds
    each row in place by copying the previous row one limb higher before
    reducing it.  This avoids a separate scratch copy, but for the last
    row the write reaches limb n * n.
*/
mp_size_t flint_mpn_mulmod_precond_matrix_alloc(mp_size_t n)
{
    return n * n + 1;
}

/*
    Fill apre with the n x n preconditioning matrix for multiplying by a
    modulo {d, n}.

    Row 0 (limbs 0 .. n - 1) is a shifted left by norm bits.  Row k, for
    k >= 1 (limbs k*n .. k*n + n - 1), is row k - 1 multiplied by one limb
    (2^FLINT_BITS) and reduced modulo d.

    apre must provide flint_mpn_mulmod_precond_matrix_alloc(n) limbs,
    because each row is formed in place and the last one touches limb
    n * n.  Only the top limb of dinv is read.  Requires n >= 2.
*/
void
flint_mpn_mulmod_precond_matrix_precompute(mp_ptr apre, mp_srcptr a, mp_size_t n, mp_srcptr d, mp_srcptr dinv, ulong norm)
{
    mp_ptr row;
    mp_limb_t dinv_top;

    FLINT_ASSERT(n >= 2);

    /* Row 0: the operand, scaled by 2^norm. */
    if (norm != 0)
        mpn_lshift(apre, a, n, norm);
    else
        flint_mpn_copyi(apre, a, n);

    dinv_top = dinv[n - 1];

    /* Row k = (row k-1) * 2^FLINT_BITS mod d: place the previous row one
       limb higher with a zero limb below it, then reduce those n + 1 limbs.
       flint_mpn_mod_preinvn(row, row, n + 1, d, n, dinv) would also work
       here but is currently slower for this (n + 1, n) shape. */
    for (row = apre + n; row != apre + n * n; row += n)
    {
        row[0] = 0;
        flint_mpn_copyi(row + 1, row - n, n);
        flint_mpn_mod_preinv1(row, n + 1, d, n, dinv_top);
    }
}

/* p-mulmod_precond */
int
flint_mpn_mulmod_want_precond(mp_size_t n, slong num, ulong norm)
{
    /*
        Choose a preconditioning strategy for num modular multiplications
        by the same operand modulo an n-limb modulus; norm is the
        normalization shift of the modulus.  The thresholds are tuning
        constants (presumably obtained from benchmarks) and are checked
        in order; the first matching rule wins.
    */

    /* Presumably too few multiplications to amortize any precomputation. */
    if (num < 4)
        return MPN_MULMOD_PRECOND_NONE;

    /* Two-limb modulus that needs no normalization shift. */
    if (n == 2 && norm == 0)
        return MPN_MULMOD_PRECOND_NONE;

    /* Small sizes. */
    if (n <= 10)
        return MPN_MULMOD_PRECOND_SHOUP;
    if (n <= 12 && num <= 12)
        return MPN_MULMOD_PRECOND_SHOUP;

    /* Medium sizes. */
    if (n <= 64)
        return MPN_MULMOD_PRECOND_MATRIX;
    if (n <= 128 && num >= 6)
        return MPN_MULMOD_PRECOND_MATRIX;
    if (n <= 192 && num >= 20)
        return MPN_MULMOD_PRECOND_MATRIX;

    /* Larger sizes, only with enough multiplications. */
    if (n <= 320 && num >= 9)
        return MPN_MULMOD_PRECOND_SHOUP;
    if (n <= 768 && num >= 20)
        return MPN_MULMOD_PRECOND_SHOUP;

    return MPN_MULMOD_PRECOND_NONE;
}

/*
    n = 2 path of flint_mpn_mulmod_precond_matrix, fully inlined.

    Forms t = row0 * b[0] + row1 * b[1] (at most four limbs) from the
    2 x 2 matrix apre.  It then reduces t modulo {d, 2} with the same
    operation sequence as flint_mpn_mulmod_preinvn_2: estimate the
    quotient from the top two limbs of t using dinv, subtract q * d, and
    correct.  Finally it stores the result shifted right by norm in
    rp[0], rp[1].
*/
static void
mulmod_precond_matrix_n2(mp_ptr rp, mp_srcptr apre, mp_srcptr b,
                         mp_srcptr d, mp_srcptr dinv, ulong norm)
{
    mp_limb_t cy, r0, r1;
    mp_limb_t t[10];
    mp_limb_t u[3];

    /* t[0..3] = apre[0..1] * b[0] + apre[2..3] * b[1]
       (takes the place of mpn_mul_n(t, a, b, n)) */
    FLINT_MPN_MUL_2X1(t[2], t[1], t[0], apre[1], apre[0], b[0]);
    FLINT_MPN_MUL_2X1(u[2], u[1], u[0], apre[3], apre[2], b[1]);
    add_ssssaaaaaaaa(t[3], t[2], t[1], t[0], 0, t[2], t[1], t[0], 0, u[2], u[1], u[0]);

    /* mpn_mul_n(t + 3*n, t + n, dinv, n) */
    FLINT_MPN_MUL_2X2(t[9], t[8], t[7], t[6], t[3], t[2], dinv[1], dinv[0]);

    /* mpn_add_n(t + 4*n, t + 4*n, t + n, n); (t[9], t[8]) is now the
       quotient estimate */
    add_ssaaaa(t[9], t[8], t[9], t[8], t[3], t[2]);

    /* mpn_mul_n(t + 2*n, t + 4*n, d, n), but only the low three limbs
       t[4..6] are formed, since the subtraction below uses only three */
    FLINT_MPN_MUL_3P2X2(t[6], t[5], t[4], t[9], t[8], d[1], d[0]);

    /* cy = t[n] - t[3*n] - mpn_sub_n(r, t, t + 2*n, n) */
    sub_dddmmmsss(cy, r1, r0, t[2], t[1], t[0], t[6], t[5], t[4]);

    /* The quotient was too small while the remainder spills into a third
       limb; subtract d until it fits in two limbs. */
    while (cy > 0)
    {
        /* cy -= mpn_sub_n(r, r, d, n) */
        sub_dddmmmsss(cy, r1, r0, cy, r1, r0, 0, d[1], d[0]);
    }

    if ((r1 > d[1]) || (r1 == d[1] && r0 >= d[0]))
    {
        /* mpn_sub_n(r, r, d, n) */
        sub_ddmmss(r1, r0, r1, r0, d[1], d[0]);
    }

    /* Undo the normalization shift. */
    if (norm)
    {
        rp[0] = (r0 >> norm) | (r1 << (FLINT_BITS - norm));
        rp[1] = (r1 >> norm);
    }
    else
    {
        rp[0] = r0;
        rp[1] = r1;
    }
}

/*
    Set {rp, n} to the product of a and {b, n} modulo the modulus.  Here a
    is represented by the matrix apre built by
    flint_mpn_mulmod_precond_matrix_precompute with the same d, dinv and
    norm.

    d is the modulus used for the reduction (presumably already shifted
    left by norm, i.e. normalized; confirm with callers), and dinv is its
    n-limb inverse as computed by flint_mpn_preinvn.  The reduced result
    is shifted right by norm before it is stored.

    For n = 2 a fully inlined path with 2-limb arithmetic is used, taken
    from flint_mpn_mulmod_preinvn_2.  Otherwise the sum of b[i] times
    row i of apre is accumulated in n + 2 limbs and reduced with
    flint_mpn_mod_preinv1, which reads only the top limb of dinv.
*/
void
flint_mpn_mulmod_precond_matrix(mp_ptr rp, mp_srcptr apre, mp_srcptr b, mp_size_t n, mp_srcptr d, mp_srcptr dinv, ulong norm)
{
    mp_ptr tmp;
    mp_limb_t cy, cy1, cy2;
    slong i, rn;
    TMP_INIT;

    if (n == 2)
    {
        mulmod_precond_matrix_n2(rp, apre, b, d, dinv, norm);
        return;
    }

    TMP_START;
    tmp = TMP_ALLOC((n + 2) * sizeof(mp_limb_t));

    /* tmp[0..n-1] accumulates the low limbs of sum_i b[i] * row_i, and
       (cy2, cy1) accumulates the carry-out limbs.  At most n carries are
       added, so cy2 stays below n. */
    cy1 = mpn_mul_1(tmp, apre, n, b[0]);
    cy2 = 0;
    for (i = 1; i < n; i++)
    {
        cy = mpn_addmul_1(tmp, apre + i * n, n, b[i]);
        add_ssaaaa(cy2, cy1, cy2, cy1, 0, cy);
    }

    tmp[n] = cy1;
    tmp[n + 1] = cy2;
    /* Drop the top limb when it is zero so the reduction does less work. */
    rn = (n + 2) - (tmp[n + 1] == 0);

    /* flint_mpn_mod_preinvn(tmp, tmp, rn, d, n, dinv) is the general
       alternative, but it is currently slower for these (n + 1, n) and
       (n + 2, n) shapes. */
    flint_mpn_mod_preinv1(tmp, rn, d, n, dinv[n - 1]);

    if (norm == 0)
        flint_mpn_copyi(rp, tmp, n);
    else
        mpn_rshift(rp, tmp, n, norm);

    TMP_END;
}

/*
    Set {rp, n} to a1 * {b1, n} + a2 * {b2, n} modulo the modulus.  Here
    a1 and a2 are represented by the matrices apre1 and apre2 built by
    flint_mpn_mulmod_precond_matrix_precompute.  Both matrices must have
    been precomputed for the same d, dinv and norm, because a single
    reduction and a single final shift are applied to the sum.

    Both matrix-vector products are accumulated into one (n + 2)-limb
    value.  That value is reduced once with flint_mpn_mod_preinv1, which
    reads only the top limb of dinv, and the result is shifted right by
    norm.
*/
void
flint_mpn_fmmamod_precond_matrix(mp_ptr rp, mp_srcptr apre1, mp_srcptr b1, mp_srcptr apre2, mp_srcptr b2, mp_size_t n, mp_srcptr d, mp_srcptr dinv, ulong norm)
{
    mp_ptr tmp;
    mp_limb_t cy, cy1, cy2;
    slong i, rn;
    TMP_INIT;

    /*
    A special case for n = 2 could look like the following (not enabled).
    The four partial products are each at most three limbs, so their sum
    fits in four limbs.

    if (n == 2)
    {
        ulong t[4];
        ulong u[3];

        FLINT_MPN_MUL_2X1(t[2], t[1], t[0], apre1[1], apre1[0], b1[0]);
        FLINT_MPN_MUL_2X1(u[2], u[1], u[0], apre1[3], apre1[2], b1[1]);
        add_ssssaaaaaaaa(t[3], t[2], t[1], t[0], 0, t[2], t[1], t[0], 0, u[2], u[1], u[0]);
        FLINT_MPN_MUL_2X1(u[2], u[1], u[0], apre2[1], apre2[0], b2[0]);
        add_ssssaaaaaaaa(t[3], t[2], t[1], t[0], t[3], t[2], t[1], t[0], 0, u[2], u[1], u[0]);
        FLINT_MPN_MUL_2X1(u[2], u[1], u[0], apre2[3], apre2[2], b2[1]);
        add_ssssaaaaaaaa(t[3], t[2], t[1], t[0], t[3], t[2], t[1], t[0], 0, u[2], u[1], u[0]);

        rn = (n + 2) - (t[n + 1] == 0);
        flint_mpn_mod_preinv1(t, rn, d, n, dinv[n - 1]);

        if (norm)
        {
            rp[0] = (t[0] >> norm) | (t[1] << (FLINT_BITS - norm));
            rp[1] = (t[1] >> norm);
        }
        else
        {
            rp[0] = t[0];
            rp[1] = t[1];
        }

        return;
    }
    */

    TMP_START;
    tmp = TMP_ALLOC((n + 2) * sizeof(mp_limb_t));

    /* tmp[0..n-1] accumulates the low limbs of
       sum_i b1[i] * row_i(apre1) + sum_i b2[i] * row_i(apre2), and
       (cy2, cy1) accumulates the carry-out limbs.  At most 2n carries are
       added, so cy2 stays below 2n. */
    cy1 = mpn_mul_1(tmp, apre1, n, b1[0]);
    cy2 = 0;
    for (i = 1; i < n; i++)
    {
        cy = mpn_addmul_1(tmp, apre1 + i * n, n, b1[i]);
        add_ssaaaa(cy2, cy1, cy2, cy1, 0, cy);
    }
    for (i = 0; i < n; i++)
    {
        cy = mpn_addmul_1(tmp, apre2 + i * n, n, b2[i]);
        add_ssaaaa(cy2, cy1, cy2, cy1, 0, cy);
    }

    tmp[n] = cy1;
    tmp[n + 1] = cy2;
    /* Drop the top limb when it is zero so the reduction does less work. */
    rn = (n + 2) - (tmp[n + 1] == 0);

    /* flint_mpn_mod_preinvn(tmp, tmp, rn, d, n, dinv) is the general
       alternative, but it is currently slower for these (n + 1, n) and
       (n + 2, n) shapes. */
    flint_mpn_mod_preinv1(tmp, rn, d, n, dinv[n - 1]);

    if (norm == 0)
        flint_mpn_copyi(rp, tmp, n);
    else
        mpn_rshift(rp, tmp, n, norm);

    TMP_END;
}