#include "test_helpers.h"
#include "thread_pool.h"
#include "fmpz.h"
/* Work item for worker1: the factors residue, residue + modulus,
   residue + 2*modulus, ... that do not exceed n are multiplied into ans. */
typedef struct
{
    ulong modulus, residue, n;  /* step, first factor, inclusive upper bound */
    fmpz_t ans;                 /* output: product of the selected factors */
} worker1_arg_struct;
/* Thread entry point used by test1.  Sets arg->ans to the product of
   arg->residue, arg->residue + arg->modulus, ... over every such value that
   is <= arg->n; arg->ans is 1 when arg->residue already exceeds arg->n. */
void worker1(void * varg)
{
    worker1_arg_struct * job = (worker1_arg_struct *) varg;
    ulong factor = job->residue;

    fmpz_one(job->ans);
    while (factor <= job->n)
    {
        fmpz_mul_ui(job->ans, job->ans, factor);
        factor += job->modulus;
    }
}
/* Compute x = n! using the global thread pool.

   The factors 1, ..., n are partitioned by residue modulo (w + 1), where w
   is the number of helper threads granted by thread_pool_request (possibly
   none).  Helper k (0 <= k < w) multiplies the factors congruent to k + 1,
   while the calling thread multiplies the nonzero multiples of w + 1.  The
   partial products are combined after waiting on each helper, and every
   helper is handed back to the pool before returning. */
void test1(fmpz_t x, ulong n)
{
    ulong i, modulus;
    slong k, req, num_workers;
    worker1_arg_struct * args;
    thread_pool_handle * handles;

    req = thread_pool_get_size(global_thread_pool);

    /* req may be 0 when the pool is empty.  malloc(0) is allowed to return
       NULL, which flint_malloc may treat as an allocation failure, so always
       request room for at least one handle -- the same guard args uses. */
    handles = (thread_pool_handle *) flint_malloc(FLINT_MAX(req, 1)
                                                  *sizeof(thread_pool_handle));
    num_workers = thread_pool_request(global_thread_pool, handles, req);

    args = (worker1_arg_struct *) flint_malloc(FLINT_MAX(num_workers, 1)
                                                 *sizeof(worker1_arg_struct));

    modulus = num_workers + 1;

    /* Start helper k on the factors congruent to k + 1 (mod modulus). */
    for (k = 0; k < num_workers; k++)
    {
        args[k].residue = k + 1;
        args[k].modulus = modulus;
        args[k].n = n;
        fmpz_init(args[k].ans);
        thread_pool_wake(global_thread_pool, handles[k], 0, worker1, &args[k]);
    }

    /* Meanwhile, this thread takes the residue-0 class (skipping 0 itself). */
    fmpz_one(x);
    for (i = modulus; i <= n; i += modulus)
    {
        fmpz_mul_ui(x, x, i);
    }

    /* Gather the partial products and return each helper to the pool. */
    for (k = 0; k < num_workers; k++)
    {
        thread_pool_wait(global_thread_pool, handles[k]);
        fmpz_mul(x, x, args[k].ans);
        fmpz_clear(args[k].ans);
        thread_pool_give_back(global_thread_pool, handles[k]);
    }

    flint_free(args);
    flint_free(handles);
}
/* Work item for worker2: ans receives the product of the integers in the
   half-open range (min, max]. */
typedef struct
{
    ulong min, max;  /* exclusive lower bound, inclusive upper bound */
    fmpz_t ans;      /* output: product over (min, max] */
} worker2_arg_struct;
/* Forward declaration: worker2 and test2_helper call each other. */
void test2_helper(fmpz_t x, ulong min, ulong max);

/* Thread entry point used by test2_helper: stores the product of the
   integers in (arg->min, arg->max] into arg->ans by delegating to
   test2_helper, which may split the range further. */
void worker2(void * varg)
{
    worker2_arg_struct * job = varg;

    test2_helper(job->ans, job->min, job->max);
}
/* Set x to the product of the integers in (min, max], requires max >= min.

   If the range spans more than 20 integers and the global pool can supply
   one helper thread, the lower half (min, mid] is handed to that helper
   (via worker2) while this thread recursively handles (mid, max]; the two
   results are then multiplied.  Otherwise the product is computed serially. */
void test2_helper(fmpz_t x, ulong min, ulong max)
{
    thread_pool_handle helper[1];
    worker2_arg_struct lower[1];
    slong granted;
    ulong i, split;

    FLINT_ASSERT(max >= min);

    /* Short ranges, or no free thread: plain serial product.  The pool is
       only queried when the range is long enough to be worth splitting. */
    if (max - min <= UWORD(20)
        || (granted = thread_pool_request(global_thread_pool, helper, 1)) == 0)
    {
        fmpz_one(x);
        for (i = max; i > min; i--)
        {
            fmpz_mul_ui(x, x, i);
        }
        return;
    }

    FLINT_ASSERT(granted == 1);

    /* The helper takes (min, split]; this thread takes (split, max]. */
    split = min + ((max - min)/UWORD(2));
    lower[0].min = min;
    lower[0].max = split;
    fmpz_init(lower[0].ans);
    thread_pool_wake(global_thread_pool, helper[0], 0, worker2, lower);

    test2_helper(x, split, max);

    thread_pool_wait(global_thread_pool, helper[0]);
    fmpz_mul(x, x, lower[0].ans);
    fmpz_clear(lower[0].ans);
    thread_pool_give_back(global_thread_pool, helper[0]);
}
/* Compute x = n! as the product over (0, n] via recursive range splitting. */
void test2(fmpz_t x, ulong n)
{
    test2_helper(x, UWORD(0), n);
}
/* Randomised check that test1 and test2 agree with fmpz_fac_ui for a random
   total thread count in [1, 10] and random n in [0, 1000). */
TEST_FUNCTION_START(thread_pool, state)
{
    slong iter, trial;

    for (iter = 0; iter < 10*flint_test_multiplier(); iter++)
    {
        fmpz_t got, expected;

        fmpz_init(got);
        fmpz_init(expected);

        flint_set_num_threads(n_randint(state, 10) + 1);

        for (trial = 0; trial < 10; trial++)
        {
            ulong n = n_randint(state, 1000);

            fmpz_fac_ui(expected, n);

            test1(got, n);
            if (!fmpz_equal(got, expected))
            {
                TEST_FUNCTION_FAIL(
                        "Test 1 failed\n"
                        "n: %wu\n"
                        "x: %{fmpz}\n"
                        "y: %{fmpz}\n",
                        n, got, expected);
            }

            test2(got, n);
            if (!fmpz_equal(got, expected))
            {
                TEST_FUNCTION_FAIL(
                        "Test 2 failed\n"
                        "n: %wu\n"
                        "x: %{fmpz}\n"
                        "y: %{fmpz}\n",
                        n, got, expected);
            }
        }

        fmpz_clear(expected);
        fmpz_clear(got);
    }

    TEST_FUNCTION_END(state);
}