#include "test_helpers.h"
#include "thread_support.h"
#include "fmpz.h"
/* Result slot used by the parallel product: one (partial) product value. */
typedef struct { fmpz r; } product_res_t;
/* Parameters shared by every callback; the callbacks below only read them. */
typedef struct
{
    nn_srcptr factors;  /* the word-sized factors being multiplied together */
    int left_inplace;   /* 1 if FLINT_PARALLEL_BSPLIT_LEFT_INPLACE was requested, else 0 */
}
product_args_t;
/* Init callback: initialise the fmpz held in a scratch result slot. */
static void
product_init(product_res_t * x, product_args_t * args)
{
    fmpz * slot = &x->r;

    (void) args;    /* unused; present only to match the callback signature */
    fmpz_init(slot);
}
/* Clear callback: release the fmpz held in a scratch result slot. */
static void
product_clear(product_res_t * x, product_args_t * args)
{
    fmpz * slot = &x->r;

    (void) args;    /* unused; present only to match the callback signature */
    fmpz_clear(slot);
}
/*
   Merge callback: res = left * right.

   Also checks the aliasing contract the test relies on. The output must
   alias the left operand exactly when LEFT_INPLACE was requested, and must
   never alias the right operand. Any violation aborts the test.
*/
static void
product_combine(product_res_t * res, product_res_t * left, product_res_t * right, product_args_t * args)
{
    int output_is_left = (res == left);

    if (res == right || output_is_left != args->left_inplace)
    {
        flint_abort();
    }

    fmpz_mul(&res->r, &left->r, &right->r);
}
/* Basecase callback: res = factors[a] * factors[a+1] * ... * factors[b-1]
   (the empty product 1 when a >= b). */
static void
product_basecase(product_res_t * res, slong a, slong b, product_args_t * args)
{
    nn_srcptr f = args->factors;
    slong k;

    fmpz_one(&res->r);

    for (k = a; k < b; k++)
        fmpz_mul_ui(&res->r, &res->r, f[k]);
}
/*
   Compute r = product of factors[0 .. len-1] using
   flint_parallel_binary_splitting.

   The caller's fmpz value is shallow-copied into the top-level result slot,
   so the splitting routine works directly on r's storage. It is copied back
   afterwards because the callbacks may have replaced that storage.

   thread_limit and flags are forwarded unchanged. flags may contain
   FLINT_PARALLEL_BSPLIT_LEFT_INPLACE, which product_combine checks against.
*/
static void
bsplit_product(fmpz_t r, nn_srcptr factors, slong len, slong thread_limit, int flags)
{
    /* NOTE(review): presumably the basecase length cutoff -- confirm
       against the flint_parallel_binary_splitting documentation. */
    const slong basecase_cutoff = 4;
    product_args_t args;
    product_res_t res;

    args.left_inplace = ((flags & FLINT_PARALLEL_BSPLIT_LEFT_INPLACE) != 0);
    args.factors = factors;

    res.r = *r;

    flint_parallel_binary_splitting(&res,
            (bsplit_basecase_func_t) product_basecase,
            (bsplit_merge_func_t) product_combine,
            sizeof(product_res_t),
            (bsplit_init_func_t) product_init,
            (bsplit_clear_func_t) product_clear,
            &args, 0, len, basecase_cutoff, thread_limit, flags);

    *r = res.r;
}
TEST_FUNCTION_START(thread_support_parallel_binary_splitting, state)
{
slong iter;
for (iter = 0; iter < 100 * flint_test_multiplier(); iter++)
{
fmpz_t r, s;
nn_ptr factors;
slong i, n;
int flags;
n = n_randint(state, 100);
flint_set_num_threads(n_randint(state, 10) + 1);
factors = flint_malloc(n * sizeof(ulong));
fmpz_init(r);
fmpz_init(s);
for (i = 0; i < n; i++)
factors[i] = n_randint(state, 300);
flags = 0;
if (n_randint(state, 2))
flags = FLINT_PARALLEL_BSPLIT_LEFT_INPLACE;
bsplit_product(r, factors, n, n_randint(state, 5), flags);
fmpz_one(s);
for (i = 0; i < n; i++)
fmpz_mul_ui(s, s, factors[i]);
if (!fmpz_equal(r, s))
TEST_FUNCTION_FAIL(
"num_threads = %wd, i = %wd/%wd\n",
flint_get_num_threads(), i, n);
flint_free(factors);
fmpz_clear(r);
fmpz_clear(s);
}
TEST_FUNCTION_END(state);
}