#include "blis.h"

#include <stdio.h>
#include <stdlib.h>
#ifdef BLIS_ENABLE_PTHREADS
/*
 * Per-thread launch record passed (as a void*) to bli_posix_thread_entry().
 * The entry point unpacks it and invokes func( gl_comm, tid, params ).
 */
struct thread_data
{
	dim_t         tid;     // id of the thread this record belongs to
	thrcomm_t*    gl_comm; // communicator handed to func
	thread_func_t func;    // function the thread executes
	const void*   params;  // opaque argument forwarded to func
};
typedef struct thread_data thread_data_t;
/*
 * Start routine for threads created by bli_thread_launch_pthreads(); it is
 * also called directly for tid 0 on the launching thread.
 *
 * arg - points to a thread_data_t describing the work for this thread.
 *
 * Invokes the record's func with its communicator, thread id and params,
 * then returns NULL (the thread result is not used).
 */
static void* bli_posix_thread_entry( void* arg )
{
	const thread_data_t* record = arg;

	record->func( record->gl_comm, record->tid, record->params );

	return NULL;
}
/*
 * Run func on n_threads threads that share one BLIS_POSIX thread
 * communicator.
 *
 * Threads 1 .. n_threads-1 are spawned with bli_pthread_create(); thread 0
 * runs on the calling thread. After its own share of the work returns, the
 * caller joins every spawned thread, then releases the communicator and the
 * per-thread bookkeeping arrays.
 *
 * n_threads - total number of threads, including the calling thread.
 * func      - thread body, invoked as func( gl_comm, tid, params ).
 * params    - opaque pointer forwarded unchanged to every invocation of func.
 *
 * If the bookkeeping arrays cannot be allocated, or a thread cannot be
 * created, a diagnostic is printed to stderr and the process is aborted.
 * Continuing would dereference NULL or join an uninitialized bli_pthread_t.
 */
void bli_thread_launch_pthreads( dim_t n_threads, thread_func_t func, const void* params )
{
	err_t r_val;
	const timpl_t ti = BLIS_POSIX;

	// The communicator is created without a memory pool.
	pool_t* gl_comm_pool = NULL;
	thrcomm_t* gl_comm = bli_thrcomm_create( ti, gl_comm_pool, n_threads );

#ifdef BLIS_ENABLE_MEM_TRACING
	printf( "bli_l3_thread_decorator().pth: " );
#endif
	bli_pthread_t* pthreads = bli_malloc_intl( sizeof( bli_pthread_t ) * n_threads, &r_val );

#ifdef BLIS_ENABLE_MEM_TRACING
	printf( "bli_l3_thread_decorator().pth: " );
#endif
	thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads, &r_val );

	// Guard against a NULL allocation before the arrays are written below.
	// NOTE(review): presumably only reachable when BLIS error checking is
	// disabled; confirm against bli_malloc_intl().
	if ( 0 < n_threads && ( pthreads == NULL || datas == NULL ) )
	{
		fprintf( stderr,
		         "libblis: bli_thread_launch_pthreads(): failed to allocate "
		         "bookkeeping for %lld threads.\n",
		         ( long long )n_threads );
		abort();
	}

	// Walk the ids downward. Calling the entry point inline for tid 0 does
	// not return until func finishes, so it must happen only after every
	// other thread has been spawned.
	for ( dim_t tid = n_threads - 1; 0 <= tid; tid-- )
	{
		datas[tid].tid     = tid;
		datas[tid].gl_comm = gl_comm;
		datas[tid].func    = func;
		datas[tid].params  = params;

		if ( tid != 0 )
		{
			const int err = bli_pthread_create( &pthreads[tid], NULL,
			                                    &bli_posix_thread_entry,
			                                    &datas[tid] );

			// On failure pthreads[tid] is not a valid handle and must not be
			// joined later. The threads already running would presumably wait
			// on the communicator for the missing one, so stop here loudly.
			if ( err != 0 )
			{
				fprintf( stderr,
				         "libblis: bli_thread_launch_pthreads(): failed to "
				         "create thread %lld (error %d).\n",
				         ( long long )tid, err );
				abort();
			}
		}
		else
		{
			bli_posix_thread_entry( &datas[0] );
		}
	}

	// Thread 0 was the caller itself, so only ids 1 .. n_threads-1 are joined.
	for ( dim_t tid = 1; tid < n_threads; tid++ )
	{
		bli_pthread_join( pthreads[tid], NULL );
	}

	bli_thrcomm_free( gl_comm_pool, gl_comm );

#ifdef BLIS_ENABLE_MEM_TRACING
	printf( "bli_l3_thread_decorator().pth: " );
#endif
	bli_free_intl( pthreads );

#ifdef BLIS_ENABLE_MEM_TRACING
	printf( "bli_l3_thread_decorator().pth: " );
#endif
	bli_free_intl( datas );
}
#endif