flint-sys 0.9.0

Bindings to the FLINT C library
Documentation
/*
    Copyright (C) 2014 Fredrik Johansson
    Copyright (C) 2020 William Hart

    This file is part of FLINT.

    FLINT is free software: you can redistribute it and/or modify it under
    the terms of the GNU Lesser General Public License (LGPL) as published
    by the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.  See <https://www.gnu.org/licenses/>.
*/

#include "thread_pool.h"
#include "thread_support.h"
#include "nmod.h"
#include "nmod_poly.h"
#include "fmpz.h"
#include "fmpz_vec.h"
#include "fmpz_poly.h"

/* Work description for one call of _fmpz_vec_multi_mod_ui_worker.
   The worker handles the indices [n0, n1) of vec; residues holds one
   array per prime, indexed like vec. */
typedef struct
{
    fmpz * vec;          /* integers: read when reducing, written when lifting */
    nn_ptr * residues;   /* residues[j][i] corresponds to vec[i] mod primes[j] */
    slong n0;            /* first index handled (inclusive) */
    slong n1;            /* end of the handled range (exclusive) */
    nn_srcptr primes;    /* the moduli */
    slong num_primes;    /* number of entries in primes and in residues */
    int crt;             /* 0: reduce vec into residues; 1: lift residues into vec */
} mod_ui_arg_t;

/*
    Thread worker: converts between vec[n0 .. n1) and its residues modulo
    primes[0 .. num_primes), in the direction selected by arg->crt.

    crt == 0: residues[j][i] is set to vec[i] reduced modulo primes[j].
    crt == 1: vec[i] is reconstructed by CRT from residues[0 .. num_primes)[i];
              the final argument 1 to fmpz_multi_CRT_ui selects the signed
              (symmetric) representative.

    Only entries with index in [n0, n1) of vec and of each residues[j] are
    touched. The comb and its temporaries are private to this call.
*/
static void
_fmpz_vec_multi_mod_ui_worker(void * arg_ptr)
{
    const mod_ui_arg_t * arg = (const mod_ui_arg_t *) arg_ptr;
    nn_ptr tmp;
    slong i, j;
    fmpz_comb_t comb;
    fmpz_comb_temp_t comb_temp;

    /* Empty range: skip building the comb, whose construction cost depends
       only on the primes and not on how many entries this worker handles. */
    if (arg->n0 >= arg->n1)
        return;

    tmp = flint_malloc(sizeof(ulong) * arg->num_primes);
    fmpz_comb_init(comb, arg->primes, arg->num_primes);
    fmpz_comb_temp_init(comb_temp, comb);

    /* The direction is fixed for the whole call, so branch once. */
    if (arg->crt)
    {
        for (i = arg->n0; i < arg->n1; i++)
        {
            /* Gather the residues of entry i from the per-prime arrays. */
            for (j = 0; j < arg->num_primes; j++)
                tmp[j] = arg->residues[j][i];
            fmpz_multi_CRT_ui(arg->vec + i, tmp, comb, comb_temp, 1);
        }
    }
    else
    {
        for (i = arg->n0; i < arg->n1; i++)
        {
            fmpz_multi_mod_ui(tmp, arg->vec + i, comb, comb_temp);
            /* Scatter the residues of entry i to the per-prime arrays. */
            for (j = 0; j < arg->num_primes; j++)
                arg->residues[j][i] = tmp[j];
        }
    }

    /* Release in reverse order of construction. */
    fmpz_comb_temp_clear(comb_temp);
    fmpz_comb_clear(comb);
    flint_free(tmp);
}

/*
    Converts the length-len vector vec to or from its residues modulo
    primes[0 .. num_primes). The index range is split evenly between the
    calling thread and helper threads taken from the global thread pool.

    residues must point to num_primes arrays of at least len limbs each.
    crt == 0 reduces vec into residues; crt == 1 lifts residues into vec.

    At most len threads (helpers plus caller) are used. Each worker builds
    its own comb, so waking a helper for an empty index range would be
    pure overhead.
*/
static void
_fmpz_vec_multi_mod_ui_threaded(nn_ptr * residues, fmpz * vec, slong len,
    nn_srcptr primes, slong num_primes, int crt)
{
    mod_ui_arg_t * args;
    slong i, num_threads;
    thread_pool_handle * threads;

    /* The limit is the total thread count: the helpers returned here plus
       the calling thread, which processes the last chunk below. */
    num_threads = flint_request_threads(&threads,
                                   FLINT_MIN(flint_get_num_threads(), len));

    args = flint_malloc(sizeof(mod_ui_arg_t) * (num_threads + 1));

    for (i = 0; i <= num_threads; i++)
    {
        args[i].vec = vec;
        args[i].residues = residues;
        args[i].n0 = (len * i) / (num_threads + 1);
        args[i].n1 = (len * (i + 1)) / (num_threads + 1);
        args[i].primes = primes;
        args[i].num_primes = num_primes;
        args[i].crt = crt;
    }

    for (i = 0; i < num_threads; i++)
        thread_pool_wake(global_thread_pool, threads[i], 0,
                               _fmpz_vec_multi_mod_ui_worker, &args[i]);

    /* The calling thread handles the final chunk itself. */
    _fmpz_vec_multi_mod_ui_worker(&args[num_threads]);

    for (i = 0; i < num_threads; i++)
        thread_pool_wait(global_thread_pool, threads[i]);

    flint_give_back_threads(threads, num_threads);

    /* Safe to free only after every worker has been waited on. */
    flint_free(args);
}

/* Work description for one call of _fmpz_poly_multi_taylor_shift_worker:
   the worker shifts the residue polynomials residues[p0 .. p1). */
typedef struct
{
    nn_ptr * residues;   /* residues[k]: len coefficients modulo primes[k] */
    slong len;           /* length of each residue polynomial */
    nn_srcptr primes;    /* the moduli, one per residue polynomial */
    slong num_primes;    /* total number of primes (not just this range) */
    slong p0;            /* first prime index handled (inclusive) */
    slong p1;            /* end of the handled prime range (exclusive) */
    const fmpz * c;      /* shift amount; only read (reduced mod each prime) */
}
taylor_shift_arg_t;

static void
_fmpz_poly_multi_taylor_shift_worker(void * arg_ptr)
{
    taylor_shift_arg_t arg = *((taylor_shift_arg_t *) arg_ptr);
    slong i;

    for (i = arg.p0; i < arg.p1; i++)
    {
        nmod_t mod;
        ulong p, cm;

        p = arg.primes[i];
        nmod_init(&mod, p);
        cm = fmpz_fdiv_ui(arg.c, p);
        _nmod_poly_taylor_shift(arg.residues[i], cm, arg.len, mod);
    }
}

/*
    Taylor-shifts by c each of the num_primes residue polynomials residues[k]
    (length len, coefficients modulo primes[k]) in place. The primes are split
    evenly between the calling thread and helper threads from the global pool.

    At most num_primes threads (helpers plus caller) are used. A worker
    always receives whole primes, so any extra thread would only get an
    empty range.
*/
static void
_fmpz_poly_multi_taylor_shift_threaded(nn_ptr * residues, slong len,
        const fmpz_t c, nn_srcptr primes, slong num_primes)
{
    taylor_shift_arg_t * args;
    slong i, num_threads;
    thread_pool_handle * threads;

    /* The limit is the total thread count: the helpers returned here plus
       the calling thread, which processes the last chunk below. */
    num_threads = flint_request_threads(&threads,
                            FLINT_MIN(flint_get_num_threads(), num_primes));

    args = flint_malloc(sizeof(taylor_shift_arg_t) * (num_threads + 1));

    for (i = 0; i <= num_threads; i++)
    {
        args[i].residues = residues;
        args[i].len = len;
        args[i].p0 = (num_primes * i) / (num_threads + 1);
        args[i].p1 = (num_primes * (i + 1)) / (num_threads + 1);
        args[i].primes = primes;
        args[i].num_primes = num_primes;
        /* c is only read by the workers (reduced with fmpz_fdiv_ui). */
        args[i].c = (fmpz *) c;
    }

    for (i = 0; i < num_threads; i++)
        thread_pool_wake(global_thread_pool, threads[i], 0,
                               _fmpz_poly_multi_taylor_shift_worker, &args[i]);

    /* The calling thread handles the final chunk itself. */
    _fmpz_poly_multi_taylor_shift_worker(&args[num_threads]);

    for (i = 0; i < num_threads; i++)
        thread_pool_wait(global_thread_pool, threads[i]);

    flint_give_back_threads(threads, num_threads);

    /* Safe to free only after every worker has been waited on. */
    flint_free(args);
}

/*
    Replaces poly[0 .. len) in place by the coefficients of poly(x + c),
    using a multimodular algorithm:

      1. bound the bit size of the output coefficients;
      2. choose enough primes, each larger than 2^(FLINT_BITS - 1), that their
         product exceeds 2^(that bound);
      3. reduce poly modulo every prime, Taylor-shift each image, and rebuild
         the integer coefficients by CRT.

    Step 3 runs on the global thread pool. The function returns immediately
    when len <= 1, c == 0 or poly is zero, since the shift is then the
    identity.
*/
void
_fmpz_poly_taylor_shift_multi_mod(fmpz * poly, const fmpz_t c, slong len)
{
    slong in_bits, out_bits, nprimes, k;
    nn_ptr moduli;
    nn_ptr * images;

    if (len <= 1 || fmpz_is_zero(c))
        return;

    in_bits = _fmpz_vec_max_bits(poly, len);
    if (in_bits == 0)
        return;

    /* With degree D = len - 1 and input coefficients bounded by |C|, output
       coefficients are bounded by D * |C| * 2^D * |c|^D. Each factor is
       rounded up below: one extra bit on |C|, len bits for 2^D,
       bit_count(len) for D, and the bit size of c^len for |c|^D. */
    in_bits = FLINT_ABS(in_bits) + 1;
    out_bits = in_bits + len + FLINT_BIT_COUNT(len);

    if (!fmpz_is_pm1(c))
    {
        fmpz_t cpow;
        fmpz_init(cpow);
        fmpz_pow_ui(cpow, c, len);
        out_bits += fmpz_bits(cpow);
        fmpz_clear(cpow);
    }

    /* Ceiling of out_bits / (FLINT_BITS - 1). Every prime exceeds
       2^(FLINT_BITS - 1), so the product of nprimes of them exceeds
       2^out_bits. */
    nprimes = (out_bits + FLINT_BITS - 2) / (FLINT_BITS - 1);
    moduli = flint_malloc(sizeof(ulong) * nprimes);
    for (k = 0; k < nprimes; k++)
        moduli[k] = n_nextprime(k == 0 ? UWORD(1) << (FLINT_BITS - 1)
                                       : moduli[k - 1], 1);

    /* images[k] holds poly reduced modulo moduli[k]. */
    images = flint_malloc(sizeof(nn_ptr) * nprimes);
    for (k = 0; k < nprimes; k++)
        images[k] = flint_malloc(sizeof(ulong) * len);

    _fmpz_vec_multi_mod_ui_threaded(images, poly, len, moduli, nprimes, 0);
    _fmpz_poly_multi_taylor_shift_threaded(images, len, c, moduli, nprimes);
    _fmpz_vec_multi_mod_ui_threaded(images, poly, len, moduli, nprimes, 1);

    for (k = nprimes - 1; k >= 0; k--)
        flint_free(images[k]);
    flint_free(images);
    flint_free(moduli);
}