#include "blis.h"
// Communicator object passed to bli_thrcomm_init() in bli_thread_init() and
// to bli_thrcomm_cleanup() in bli_thread_finalize(). It is zero-initialized
// with "{ 0 }" because an empty initializer list "{}" is only valid ISO C
// starting with C23 (earlier compilers accept it only as an extension).
thrcomm_t BLIS_SINGLE_COMM = { 0 };
// Signature shared by every entry of the thread_launch_fpa table below:
// a function taking a thread count (nt), the function to run (func), and an
// opaque, read-only parameter pointer (params), and returning nothing.
typedef void (*thread_launch_t)
(
dim_t nt,
thread_func_t func,
const void* params
);
// Table of launch functions indexed by timpl_t value. An entry is NULL when
// the corresponding threading implementation was not enabled at compile
// time. The table is only ever read (see bli_thread_launch()), so it is
// declared const.
static const thread_launch_t thread_launch_fpa[ BLIS_NUM_THREAD_IMPLS ] =
{
    [BLIS_SINGLE] = bli_thread_launch_single,

#if defined(BLIS_ENABLE_OPENMP)
    [BLIS_OPENMP] = bli_thread_launch_openmp,
#else
    [BLIS_OPENMP] = NULL,
#endif

#if defined(BLIS_ENABLE_PTHREADS)
    [BLIS_POSIX]  = bli_thread_launch_pthreads,
#else
    [BLIS_POSIX]  = NULL,
#endif

#if defined(BLIS_ENABLE_HPX)
    [BLIS_HPX]    = bli_thread_launch_hpx,
#else
    [BLIS_HPX]    = NULL,
#endif
};
// Initialize BLIS_SINGLE_COMM via bli_thrcomm_init() using the BLIS_SINGLE
// implementation and a count of 1. Always returns 0.
int bli_thread_init( void )
{
    // This routine has no failure path; the status is always zero.
    const int status = 0;

    bli_thrcomm_init( BLIS_SINGLE, 1, &BLIS_SINGLE_COMM );

    return status;
}
// Release BLIS_SINGLE_COMM via bli_thrcomm_cleanup(). Always returns 0.
int bli_thread_finalize( void )
{
    // This routine has no failure path; the status is always zero.
    const int status = 0;

    bli_thrcomm_cleanup( &BLIS_SINGLE_COMM );

    return status;
}
// Dispatch to the launch function registered for threading implementation
// ti, forwarding nt, func, and params unchanged.
//
// NOTE: the table entry is called without any check. Entries for
// implementations that were not enabled at compile time are NULL, and ti is
// not range-checked, so the caller must pass an enabled, in-range value.
void bli_thread_launch
     (
       timpl_t       ti,
       dim_t         nt,
       thread_func_t func,
       const void*   params
     )
{
    ( *thread_launch_fpa[ ti ] )( nt, func, params );
}
// Prepare a factorization state object for the number n.
//
// Stores n, the integer-truncated square root of n (computed in double
// precision), and the first candidate factor (2) into *factors. The actual
// factors are then retrieved one at a time with bli_next_prime_factor().
void bli_prime_factorization( dim_t n, bli_prime_factors_t* factors )
{
    // Trial division only ever needs to test candidates up to sqrt(n).
    const dim_t root = ( dim_t )sqrt( ( double )n );

    factors->f      = 2;
    factors->sqrt_n = root;
    factors->n      = n;
}
// Return the next factor of the number held in *factors, using trial
// division with increasing candidates.
//
// Each call that finds a divisor divides it out of factors->n and returns
// it without advancing the candidate, so repeated factors are returned on
// consecutive calls. Once the candidate exceeds factors->sqrt_n, the
// remaining value of factors->n is returned and factors->n is set to 1;
// every later call therefore returns 1.
dim_t bli_next_prime_factor( bli_prime_factors_t* factors )
{
    while ( factors->f <= factors->sqrt_n )
    {
        const dim_t cand = factors->f;

        // If the current candidate divides what is left, peel it off and
        // report it. The candidate is left unchanged so that it is tried
        // again on the next call.
        if ( factors->n % cand == 0 )
        {
            factors->n /= cand;
            return cand;
        }

        // Otherwise advance to the next candidate. The first few steps
        // jump straight from prime to prime (2, 3, 5, 7, 11); after that
        // candidates are simply incremented by one.
        switch ( cand )
        {
            case 2:  factors->f = 3;        break;
            case 3:  factors->f = 5;        break;
            case 5:  factors->f = 7;        break;
            case 7:  factors->f = 11;       break;
            default: factors->f = cand + 1; break;
        }
    }

    // No candidates remain: hand back whatever is left of n (which is 1 if
    // everything was already divided out) and leave 1 behind.
    const dim_t remainder = factors->n;
    factors->n = 1;

    return remainder;
}
// Return TRUE if the first factor produced for n is n itself, and FALSE
// otherwise.
//
// NOTE: because the test is "first factor equals n", the values 0 and 1
// also yield TRUE (the factorization loop is never entered for them and n
// is returned unchanged).
bool bli_is_prime( dim_t n )
{
    bli_prime_factors_t state;
    bli_prime_factorization( n, &state );

    const dim_t first = bli_next_prime_factor( &state );

    return ( first == n ) ? TRUE : FALSE;
}
// Split n_thread into two factors *nt1 and *nt2 according to the relative
// sizes of work1 and work2.
//
// - n_thread: total number of threads to partition.
// - work1, work2: amounts of work along the two dimensions.
// - nt1, nt2: outputs; receive the number of threads for each dimension.
//
// For fewer than four threads, all threads go to the dimension with more
// work (ties go to the first dimension) and the other dimension gets 1.
// Otherwise the work is delegated to bli_thread_partition_2x2_fast().
void bli_thread_partition_2x2
     (
       dim_t           n_thread,
       dim_t           work1,
       dim_t           work2,
       dim_t* restrict nt1,
       dim_t* restrict nt2
     )
{
    // Small thread counts: give everything to the larger dimension.
    if ( n_thread < 4 )
    {
        *nt1 = ( work1 >= work2 ? n_thread : 1 );
        *nt2 = ( work1 <  work2 ? n_thread : 1 );
        return;
    }

    // The previous "#if 1 ... #else#endif" wrapper was malformed (the
    // "#else" and "#endif" directives shared one line, leaving the "#if"
    // unterminated). Only this branch was ever active, so call it directly.
    bli_thread_partition_2x2_fast( n_thread, work1, work2, nt1, nt2 );
}
// Heuristically split n_thread into two factors *nt1 and *nt2.
//
// The prime factors of n_thread are handed out one at a time: each factor
// goes to whichever dimension currently has more remaining work (ties go to
// the second dimension), and that dimension's work is divided by the
// factor. Afterwards, a single factor of 2 may be moved from one dimension
// to the other if doing so brings the two remaining work values closer.
//
// Defining PRINT_FACTORS enables diagnostic output via printf().
void bli_thread_partition_2x2_fast
     (
       dim_t           n_thread,
       dim_t           work1,
       dim_t           work2,
       dim_t* restrict nt1,
       dim_t* restrict nt2
     )
{
    // Running thread counts for the two dimensions.
    dim_t ways1 = 1;
    dim_t ways2 = 1;

    bli_prime_factors_t factors;
    bli_prime_factorization( n_thread, &factors );

#ifdef PRINT_FACTORS
    printf( "w1 w2 = %d %d (initial)\n", (int)work1, (int)work2 );
#endif

    // Consume factors until the generator reports 1 (exhausted).
    for ( dim_t f = bli_next_prime_factor( &factors );
          f > 1;
          f = bli_next_prime_factor( &factors ) )
    {
#ifdef PRINT_FACTORS
        printf( "w1 w2 = %4d %4d nt1 nt2 = %d %d ... f = %d\n",
                (int)work1, (int)work2, (int)ways1, (int)ways2, (int)f );
#endif

        // Assign this factor to the dimension with more work left.
        if ( work1 > work2 )
        {
            work1 /= f;
            ways1 *= f;
        }
        else
        {
            work2 /= f;
            ways2 *= f;
        }
    }

#ifdef PRINT_FACTORS
    printf( "w1 w2 = %4d %4d nt1 nt2 = %d %d\n",
            (int)work1, (int)work2, (int)ways1, (int)ways2 );
#endif

    // Final adjustment: if the dimension with less remaining work holds an
    // even thread count, check whether shifting a factor of 2 to the other
    // dimension would reduce the imbalance, and do so if it would.
    if ( work1 > work2 && ways2 % 2 == 0 )
    {
        const dim_t gap_now   = work1 - work2;
        const dim_t gap_moved = bli_abs( work1/2 - work2*2 );

        if ( gap_moved < gap_now )
        {
            ways1 *= 2;
            ways2 /= 2;
        }
    }
    else if ( work1 < work2 && ways1 % 2 == 0 )
    {
        const dim_t gap_now   = work2 - work1;
        const dim_t gap_moved = bli_abs( work2/2 - work1*2 );

        if ( gap_moved < gap_now )
        {
            ways1 /= 2;
            ways2 *= 2;
        }
    }

#ifdef PRINT_FACTORS
    printf( "w1 w2 = %4d %4d nt1 nt2 = %d %d (final)\n",
            (int)work1, (int)work2, (int)ways1, (int)ways2 );
#endif

    *nt1 = ways1;
    *nt2 = ways2;
}
#include "limits.h"
// Exhaustively split n_thread into two factors *nt1 and *nt2 such that
// |nt1*work2 - nt2*work1| is minimized.
//
// The prime factorization of n_thread is collected as (factor,
// multiplicity) pairs. Every way of giving 0..mult[i] copies of each prime
// to the first dimension (and the rest to the second) is then enumerated,
// and the first candidate achieving the smallest difference is kept.
//
// NOTE: the factor tables hold at most 8 distinct primes; no bounds check
// is performed, so n_thread must not have more distinct prime factors than
// that.
void bli_thread_partition_2x2_slow
     (
       dim_t           n_thread,
       dim_t           work1,
       dim_t           work2,
       dim_t* restrict nt1,
       dim_t* restrict nt2
     )
{
    dim_t tn1 = 1;
    dim_t tn2 = 1;

    bli_prime_factors_t factors;
    bli_prime_factorization( n_thread, &factors );

    // Gather distinct factors and their multiplicities. Equal factors are
    // returned consecutively by bli_next_prime_factor(), so comparing
    // against the most recent entry is sufficient.
    dim_t fact[8];
    dim_t mult[8];
    dim_t nfact = 1;

    fact[0] = bli_next_prime_factor( &factors );
    mult[0] = 1;

    dim_t f;
    while ( ( f = bli_next_prime_factor( &factors ) ) > 1 )
    {
        if ( f == fact[nfact-1] )
        {
            mult[nfact-1]++;
        }
        else
        {
            nfact++;
            fact[nfact-1] = f;
            mult[nfact-1] = 1;
        }
    }

    // ntake[i] is the number of copies of fact[i] given to the first
    // dimension in the candidate currently being evaluated.
    dim_t ntake[8] = {0};

    // The best difference is seeded from the first candidate rather than
    // from a fixed constant such as INT_MAX: dim_t may be wider than int,
    // in which case every candidate difference could exceed INT_MAX and no
    // candidate would ever be accepted (leaving a 1x1 partition).
    dim_t min_diff  = 0;
    bool  have_best = FALSE;

    bool done = FALSE;
    while ( !done )
    {
        // Build the candidate pair (x, y) for the current ntake vector.
        dim_t x = 1;
        dim_t y = 1;
        for ( dim_t i = 0 ; i < nfact ; i++ )
        {
            x *= bli_ipow( fact[i], ntake[i] );
            y *= bli_ipow( fact[i], mult[i]-ntake[i] );
        }

        dim_t diff = llabs( x*work2 - y*work1 );
        if ( !have_best || diff < min_diff )
        {
            have_best = TRUE;
            min_diff  = diff;
            tn1       = x;
            tn2       = y;
        }

        // Advance ntake like an odometer: bump the first digit; on
        // overflow reset it and carry into the next. Overflow of the last
        // digit means all combinations have been visited.
        for ( dim_t i = 0 ; i < nfact ; i++ )
        {
            if ( ++ntake[i] > mult[i] )
            {
                ntake[i] = 0;
                if ( i == nfact-1 ) done = TRUE;
                else continue;
            }
            break;
        }
    }

    *nt1 = tn1;
    *nt2 = tn2;
}
#if 0#endif
dim_t bli_gcd( dim_t x, dim_t y )
{
while ( y != 0 )
{
dim_t t = y;
y = x % y;
x = t;
}
return x;
}
// Return the least common multiple of x and y, computed as
// ( x / gcd( x, y ) ) * y.
//
// Dividing before multiplying yields the same value as x * y / gcd (the
// gcd divides x exactly) but avoids overflowing the intermediate product
// when the final result is representable. If both x and y are zero, the
// gcd is zero; 0 is returned in that case instead of dividing by zero.
dim_t bli_lcm( dim_t x, dim_t y)
{
    const dim_t g = bli_gcd( x, y );

    // g is zero only when x and y are both zero.
    if ( g == 0 ) return 0;

    return ( x / g ) * y;
}
// Return base raised to the integer power "power", using binary
// exponentiation (square-and-multiply).
//
// A power of zero, or any negative power, yields 1.
//
// The exponent is consumed by shifting it right, and base is squared only
// while bits of the exponent remain. This avoids two sources of signed
// overflow (undefined behavior) present in a mask-shifting formulation:
// squaring base one extra time after the last bit has been processed, and
// shifting the mask past the most significant bit of a very large power.
dim_t bli_ipow( dim_t base, dim_t power )
{
    dim_t p = 1;

    while ( power > 0 )
    {
        // Multiply in the current square if the low bit is set.
        if ( power & 0x1 ) p *= base;

        power >>= 1;

        // Only square when another bit is still to be processed.
        if ( power > 0 ) base *= base;
    }

    return p;
}
// Return the value reported by bli_rntm_jc_ways() for the global rntm.
dim_t bli_thread_get_jc_nt( void )
{
    // bli_init_once() is called before the global rntm is read.
    bli_init_once();

    const dim_t jc_ways = bli_rntm_jc_ways( bli_global_rntm() );

    return jc_ways;
}
// Return the value reported by bli_rntm_pc_ways() for the global rntm.
dim_t bli_thread_get_pc_nt( void )
{
    // bli_init_once() is called before the global rntm is read.
    bli_init_once();

    const dim_t pc_ways = bli_rntm_pc_ways( bli_global_rntm() );

    return pc_ways;
}
// Return the value reported by bli_rntm_ic_ways() for the global rntm.
dim_t bli_thread_get_ic_nt( void )
{
    // bli_init_once() is called before the global rntm is read.
    bli_init_once();

    const dim_t ic_ways = bli_rntm_ic_ways( bli_global_rntm() );

    return ic_ways;
}
// Return the value reported by bli_rntm_jr_ways() for the global rntm.
dim_t bli_thread_get_jr_nt( void )
{
    // bli_init_once() is called before the global rntm is read.
    bli_init_once();

    const dim_t jr_ways = bli_rntm_jr_ways( bli_global_rntm() );

    return jr_ways;
}
// Return the value reported by bli_rntm_ir_ways() for the global rntm.
dim_t bli_thread_get_ir_nt( void )
{
    // bli_init_once() is called before the global rntm is read.
    bli_init_once();

    const dim_t ir_ways = bli_rntm_ir_ways( bli_global_rntm() );

    return ir_ways;
}
// Return the value reported by bli_rntm_num_threads() for the global rntm.
dim_t bli_thread_get_num_threads( void )
{
    // bli_init_once() is called before the global rntm is read.
    bli_init_once();

    const dim_t n_threads = bli_rntm_num_threads( bli_global_rntm() );

    return n_threads;
}
// Return the value reported by bli_rntm_thread_impl() for the global rntm.
timpl_t bli_thread_get_thread_impl( void )
{
    // bli_init_once() is called before the global rntm is read.
    bli_init_once();

    const timpl_t impl = bli_rntm_thread_impl( bli_global_rntm() );

    return impl;
}
// Human-readable names for each threading implementation, indexed by
// timpl_t value. Both the strings and the pointer table itself are
// read-only, so the array elements are declared const as well.
static const char* const bli_timpl_string[BLIS_NUM_THREAD_IMPLS] =
{
    [BLIS_SINGLE] = "single",
    [BLIS_OPENMP] = "openmp",
    [BLIS_POSIX]  = "pthreads",
    [BLIS_HPX]    = "hpx",
};
// Return the name string stored in bli_timpl_string for implementation ti.
//
// NOTE: ti is used directly as an array index without a range check, so it
// must be a valid index into the BLIS_NUM_THREAD_IMPLS-element table.
const char* bli_thread_get_thread_impl_str( timpl_t ti )
{
    const char* impl_name = bli_timpl_string[ ti ];

    return impl_name;
}
// Set the per-loop ways of parallelism in the global rntm.
//
// When BLIS_ENABLE_MULTITHREADING is defined, jc, ic, jr, and ir are
// stored via bli_rntm_set_ways_only() and the global rntm is then passed to
// bli_rntm_sanitize(). When BLIS_DISABLE_TLS is also defined, the update is
// performed while holding the global rntm mutex. Without
// BLIS_ENABLE_MULTITHREADING, only bli_init_once() is called and the
// arguments are ignored.
//
// NOTE(review): the pc argument is not forwarded; the constant 1 is passed
// in its position instead. This looks intentional -- confirm.
void bli_thread_set_ways( dim_t jc, dim_t pc, dim_t ic, dim_t jr, dim_t ir )
{
    bli_init_once();

#if defined(BLIS_ENABLE_MULTITHREADING)

#if defined(BLIS_DISABLE_TLS)
    bli_pthread_mutex_lock( bli_global_rntm_mutex() );
#endif

    bli_rntm_set_ways_only( jc, 1, ic, jr, ir, bli_global_rntm() );
    bli_rntm_sanitize( bli_global_rntm() );

#if defined(BLIS_DISABLE_TLS)
    bli_pthread_mutex_unlock( bli_global_rntm_mutex() );
#endif

#endif
}
// Set the total number of threads in the global rntm.
//
// When BLIS_ENABLE_MULTITHREADING is defined, n_threads is stored via
// bli_rntm_set_num_threads_only() and the global rntm is then passed to
// bli_rntm_sanitize(). When BLIS_DISABLE_TLS is also defined, the update is
// performed while holding the global rntm mutex. Without
// BLIS_ENABLE_MULTITHREADING, only bli_init_once() is called and n_threads
// is ignored.
void bli_thread_set_num_threads( dim_t n_threads )
{
    bli_init_once();

#if defined(BLIS_ENABLE_MULTITHREADING)

#if defined(BLIS_DISABLE_TLS)
    bli_pthread_mutex_lock( bli_global_rntm_mutex() );
#endif

    bli_rntm_set_num_threads_only( n_threads, bli_global_rntm() );
    bli_rntm_sanitize( bli_global_rntm() );

#if defined(BLIS_DISABLE_TLS)
    bli_pthread_mutex_unlock( bli_global_rntm_mutex() );
#endif

#endif
}
// Store the threading implementation ti in the global rntm via
// bli_rntm_set_thread_impl_only(). When BLIS_DISABLE_TLS is defined, the
// update is performed while holding the global rntm mutex. Unlike the other
// setters in this file, this one is not conditional on
// BLIS_ENABLE_MULTITHREADING and does not call bli_rntm_sanitize().
void bli_thread_set_thread_impl( timpl_t ti )
{
    bli_init_once();

#if defined(BLIS_DISABLE_TLS)
    bli_pthread_mutex_lock( bli_global_rntm_mutex() );
#endif

    bli_rntm_set_thread_impl_only( ti, bli_global_rntm() );

#if defined(BLIS_DISABLE_TLS)
    bli_pthread_mutex_unlock( bli_global_rntm_mutex() );
#endif
}
// Copy the threading settings from the rntm returned by
// bli_global_rntm_at_init() into the rntm returned by bli_global_rntm():
// the thread implementation, the auto-factor flag, the number of threads,
// and the five ways values (jc, pc, ic, jr, ir). When BLIS_DISABLE_TLS is
// defined, the copy is performed while holding the global rntm mutex.
void bli_thread_reset( void )
{
    bli_init_once();

#if defined(BLIS_DISABLE_TLS)
    bli_pthread_mutex_lock( bli_global_rntm_mutex() );
#endif

    rntm_t* init_rntm = bli_global_rntm_at_init();
    rntm_t* live_rntm = bli_global_rntm();

    // First group: implementation, auto-factor flag, and thread count.
    // All three are read before any of them is written.
    timpl_t init_impl        = bli_rntm_thread_impl( init_rntm );
    bool    init_auto_factor = bli_rntm_auto_factor( init_rntm );
    dim_t   init_nt          = bli_rntm_num_threads( init_rntm );

    bli_rntm_set_thread_impl_only( init_impl, live_rntm );
    bli_rntm_set_auto_factor_only( init_auto_factor, live_rntm );
    bli_rntm_set_num_threads_only( init_nt, live_rntm );

    // Second group: the per-loop ways, read and then written together.
    dim_t init_jc = bli_rntm_jc_ways( init_rntm );
    dim_t init_pc = bli_rntm_pc_ways( init_rntm );
    dim_t init_ic = bli_rntm_ic_ways( init_rntm );
    dim_t init_jr = bli_rntm_jr_ways( init_rntm );
    dim_t init_ir = bli_rntm_ir_ways( init_rntm );

    bli_rntm_set_ways_only( init_jc, init_pc, init_ic, init_jr, init_ir, live_rntm );

#if defined(BLIS_DISABLE_TLS)
    bli_pthread_mutex_unlock( bli_global_rntm_mutex() );
#endif
}