flint-sys 0.9.0

Bindings to the FLINT C library
Documentation
/*
    Copyright (C) 2019 Daniel Schultz

    This file is part of FLINT.

    FLINT is free software: you can redistribute it and/or modify it under
    the terms of the GNU Lesser General Public License (LGPL) as published
    by the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.  See <https://www.gnu.org/licenses/>.
*/

#include "nmod.h"
#include "fmpz.h"
#include "fmpz_mod.h"

void _fmpz_mod_mul1(fmpz_t a, const fmpz_t b, const fmpz_t c,
                                                     const fmpz_mod_ctx_t ctx)
{
    ulong a0, b0, c0;

    FLINT_ASSERT(fmpz_mod_is_canonical(b, ctx));
    FLINT_ASSERT(fmpz_mod_is_canonical(c, ctx));

    b0 = fmpz_get_ui(b);
    c0 = fmpz_get_ui(c);
    a0 = nmod_mul(b0, c0, ctx->mod);
    fmpz_set_ui(a, a0);

    FLINT_ASSERT(fmpz_mod_is_canonical(a, ctx));
}

/*
    Multiplication modulo 2^FLINT_BITS is easy.
*/
void _fmpz_mod_mul2s(fmpz_t a, const fmpz_t b, const fmpz_t c,
                                                     const fmpz_mod_ctx_t FLINT_UNUSED(ctx))
{
    ulong a0, b0, c0;

    FLINT_ASSERT(fmpz_mod_is_canonical(b, ctx));
    FLINT_ASSERT(fmpz_mod_is_canonical(c, ctx));

    b0 = fmpz_get_ui(b);
    c0 = fmpz_get_ui(c);
    a0 = b0*c0;
    fmpz_set_ui(a, a0);

    FLINT_ASSERT(fmpz_mod_is_canonical(a, ctx));
}

/*
    Standard Barrett reduction: (set r = FLINT_BITS)

    We have n fits into 2 words and 2^r < n < 2^(2r). Therefore
    2^(3r) > 2^(4r) / n > 2^(2r) and the precomputed number
    ninv = floor(2^(4r) / n) fits into 3 words.
    The inputs b and c are < n and therefore fit into 2 words.

    The computation of a = b*c mod n is:

        x = b*c             x < n^2 and therefore fits into 4 words
        z = (x >> r)*ninv   z <= n*2^(3*r) and therefore fits into 5 words
        q = (z >> (3r))*n   q fits into 4 words
        x = x - q           x fits into 3 words after the subtraction
        at this point the canonical reduction in the range [0, n) is one of
            a = x, a = x - n, or a = x - 2n
*/
void _fmpz_mod_mul2(fmpz_t a, const fmpz_t b, const fmpz_t c,
                                                     const fmpz_mod_ctx_t ctx)
{
    ulong a1, a0, b1, b0, c1, c0;
    ulong x3, x2, x1, x0;
    ulong q2, q1, q0;
    ulong z4, z3, z2, z1, z0;
    ulong t4, t3, t2, t1;
    ulong s3, s2, s1;
    ulong u4, u3, u2, u1;
    ulong v4, v3, v2, v1;

    FLINT_ASSERT(fmpz_mod_is_canonical(b, ctx));
    FLINT_ASSERT(fmpz_mod_is_canonical(c, ctx));

    fmpz_get_uiui(&b1, &b0, b);
    fmpz_get_uiui(&c1, &c0, c);

    /* x[3:0] = b[1:0]*c[1:0] */
    umul_ppmm(t2, t1, b0, c1);
    umul_ppmm(s2, s1, b1, c0);
    umul_ppmm(x3, x2, b1, c1);
    umul_ppmm(x1, x0, b0, c0);
    t3 = 0;
    add_sssaaaaaa(t3, t2, t1, t3, t2, t1, 0, s2, s1);
    add_sssaaaaaa(x3, x2, x1, x3, x2, x1, t3, t2, t1);

    /* z[5:0] = x[3:1] * ninv[2:0], z[5] should end up zero */
    umul_ppmm(z1, z0, x1, ctx->ninv_limbs[0]);
    umul_ppmm(z3, z2, x2, ctx->ninv_limbs[1]);
    z4 = x3 * ctx->ninv_limbs[2];
    umul_ppmm(t3, t2, x3, ctx->ninv_limbs[0]);
    umul_ppmm(s3, s2, x1, ctx->ninv_limbs[2]);
    t4 = 0;
    add_sssaaaaaa(t4, t3, t2, t4, t3, t2, 0, s3, s2);
    umul_ppmm(u2, u1, x2, ctx->ninv_limbs[0]);
    umul_ppmm(u4, u3, x3, ctx->ninv_limbs[1]);
    add_sssaaaaaa(z4, z3, z2, z4, z3, z2, t4, t3, t2);
    umul_ppmm(v2, v1, x1, ctx->ninv_limbs[1]);
    umul_ppmm(v4, v3, x2, ctx->ninv_limbs[2]);
    add_ssssaaaaaaaa(z4, z3, z2, z1, z4, z3, z2, z1, u4, u3, u2, u1);
    add_ssssaaaaaaaa(z4, z3, z2, z1, z4, z3, z2, z1, v4, v3, v2, v1);

    /* q[3:0] = z[4:3] * n[1:0], q[3] is not needed */
    /* x[3:0] -= q[3:0], x[3] should end up zero */
    umul_ppmm(t2, t1, z3, ctx->n_limbs[1]);
    umul_ppmm(s2, s1, z4, ctx->n_limbs[0]);
    umul_ppmm(q1, q0, z3, ctx->n_limbs[0]);
    sub_ddmmss(x2, x1, x2, x1, t2, t1);
    q2 = z4 * ctx->n_limbs[1];
    sub_ddmmss(x2, x1, x2, x1, s2, s1);
    sub_dddmmmsss(x2, x1, x0, x2, x1, x0, q2, q1, q0);

    /* at most two subtractions of n, use q as temp space */
    sub_dddmmmsss(q2, q1, q0, x2, x1, x0, 0, ctx->n_limbs[1], ctx->n_limbs[0]);
    if ((slong)(q2) >= 0)
    {
        sub_dddmmmsss(x2, x1, x0, q2, q1, q0, 0, ctx->n_limbs[1], ctx->n_limbs[0]);
        if ((slong)(x2) >= 0)
        {
            a1 = x1;
            a0 = x0;
        }
        else
        {
            a1 = q1;
            a0 = q0;
        }
    }
    else
    {
        a1 = x1;
        a0 = x0;
    }

    fmpz_set_uiui(a, a1, a0);

    FLINT_ASSERT(fmpz_mod_is_canonical(a, ctx));

}

void _fmpz_mod_mulN(fmpz_t a, const fmpz_t b, const fmpz_t c,
                                                     const fmpz_mod_ctx_t ctx)
{
    FLINT_ASSERT(fmpz_mod_is_canonical(b, ctx));
    FLINT_ASSERT(fmpz_mod_is_canonical(c, ctx));

    fmpz_mul(a, b, c);

    if (ctx->ninv_huge == NULL)
        fmpz_mod(a, a, ctx->n);
    else
        fmpz_fdiv_r_preinvn(a, a, ctx->n, ctx->ninv_huge);

    FLINT_ASSERT(fmpz_mod_is_canonical(a, ctx));
}

void fmpz_mod_mul_fmpz(fmpz_t a, const fmpz_t b, const fmpz_t c,
                                                      const fmpz_mod_ctx_t ctx)
{
    FLINT_ASSERT(fmpz_mod_is_canonical(b, ctx));

    fmpz_mul(a, b, c);

    if (ctx->ninv_huge == NULL)
        fmpz_mod(a, a, ctx->n);
    else
        fmpz_fdiv_r_preinvn(a, a, ctx->n, ctx->ninv_huge);

    FLINT_ASSERT(fmpz_mod_is_canonical(a, ctx));
}

void fmpz_mod_mul_ui(fmpz_t a, const fmpz_t b, ulong c,
                                                      const fmpz_mod_ctx_t ctx)
{
    FLINT_ASSERT(fmpz_mod_is_canonical(b, ctx));

    fmpz_mul_ui(a, b, c);
    fmpz_mod(a, a, ctx->n);

    FLINT_ASSERT(fmpz_mod_is_canonical(a, ctx));
}

void fmpz_mod_mul_si(fmpz_t a, const fmpz_t b, slong c,
                                                      const fmpz_mod_ctx_t ctx)
{
    FLINT_ASSERT(fmpz_mod_is_canonical(b, ctx));

    fmpz_mul_si(a, b, c);
    fmpz_mod(a, a, ctx->n);

    FLINT_ASSERT(fmpz_mod_is_canonical(a, ctx));
}