#include "blis.h"
/*
 * Split the index range [0, n) among the threads of `thread`, handing out
 * whole blocks of `bf` indices and giving the leftover n % bf indices (the
 * "edge") to exactly one thread.
 *
 * With q = n / bf whole blocks and nt threads, every thread receives q / nt
 * blocks and q % nt threads receive one additional block:
 *   - handle_edge_low == FALSE: the extra blocks go to the lowest-index
 *     threads and the edge is appended to the last thread (id nt-1).
 *   - handle_edge_low == TRUE:  the extra blocks go to the highest-index
 *     threads and the edge is prepended to thread 0, shifting every other
 *     thread's range right by the edge size.
 *
 * On return, [*start, *end) is the calling thread's half-open sub-range.
 * A single-thread team always receives [0, n).
 */
void bli_thread_range_sub
     (
       const thrinfo_t* thread,
             dim_t      n,
             dim_t      bf,
             bool       handle_edge_low,
             dim_t*     start,
             dim_t*     end
     )
{
	const dim_t nt = bli_thrinfo_n_way( thread );

	// One thread owns everything; no need to query its id.
	if ( nt == 1 )
	{
		*start = 0;
		*end   = n;
		return;
	}

	const dim_t tid = bli_thrinfo_work_id( thread );

	const dim_t n_blocks = n / bf;        // whole blocks of bf indices
	const dim_t n_edge   = n % bf;        // leftover indices (< bf)
	const dim_t per_thr  = n_blocks / nt; // blocks every thread receives
	const dim_t n_extra  = n_blocks % nt; // threads receiving one more block

	dim_t blk_before; // whole blocks owned by the threads preceding tid
	dim_t blk_mine;   // whole blocks owned by tid
	dim_t pad_before; // edge indices located before tid's range
	dim_t pad_mine;   // edge indices located inside tid's range

	if ( handle_edge_low )
	{
		// The last n_extra threads are the "big" ones. Because
		// n_extra < nt, thread 0 is never big, and it alone takes the edge.
		const dim_t first_big = nt - n_extra;
		const bool  is_big    = ( tid >= first_big );

		blk_before = tid * per_thr + ( is_big ? tid - first_big : 0 );
		blk_mine   = per_thr + ( is_big ? 1 : 0 );
		pad_before = ( tid >  0 ? n_edge : 0 );
		pad_mine   = ( tid == 0 ? n_edge : 0 );
	}
	else
	{
		// The first n_extra threads are the "big" ones. The last thread is
		// never big (n_extra < nt), and it alone takes the edge.
		const bool is_big = ( tid < n_extra );

		blk_before = tid * per_thr + ( is_big ? tid : n_extra );
		blk_mine   = per_thr + ( is_big ? 1 : 0 );
		pad_before = 0;
		pad_mine   = ( tid == nt - 1 ? n_edge : 0 );
	}

	*start = blk_before * bf + pad_before;
	*end   = *start + blk_mine * bf + pad_mine;
}
/*
 * Partition the columns of `a` (its width as reported by
 * bli_obj_width_after_trans()) among the threads of `thr` in multiples of
 * the default blocksize of `bmult`. handle_edge_low is FALSE, so any partial
 * block goes to the last thread.
 *
 * Returns the number of elements in the caller's sub-range
 * (rows * assigned columns).
 */
siz_t bli_thread_range_l2r
     (
       const thrinfo_t* thr,
       const obj_t*     a,
       const blksz_t*   bmult,
             dim_t*     start,
             dim_t*     end
     )
{
	const dim_t bf = bli_blksz_get_def( bli_obj_dt( a ), bmult );

	bli_thread_range_sub( thr, bli_obj_width_after_trans( a ), bf,
	                      FALSE, start, end );

	const dim_t n_cols = *end - *start;

	return bli_obj_length_after_trans( a ) * n_cols;
}
/*
 * Column partitioning like bli_thread_range_l2r(), except that
 * handle_edge_low is TRUE, so any partial block is placed at the start of
 * thread 0's range.
 *
 * Returns the number of elements in the caller's sub-range
 * (rows * assigned columns).
 */
siz_t bli_thread_range_r2l
     (
       const thrinfo_t* thr,
       const obj_t*     a,
       const blksz_t*   bmult,
             dim_t*     start,
             dim_t*     end
     )
{
	const dim_t bf = bli_blksz_get_def( bli_obj_dt( a ), bmult );

	bli_thread_range_sub( thr, bli_obj_width_after_trans( a ), bf,
	                      TRUE, start, end );

	const dim_t n_cols = *end - *start;

	return bli_obj_length_after_trans( a ) * n_cols;
}
/*
 * Partition the rows of `a` (its length as reported by
 * bli_obj_length_after_trans()) among the threads of `thr` in multiples of
 * the default blocksize of `bmult`. handle_edge_low is FALSE, so any partial
 * block goes to the last thread.
 *
 * Returns the number of elements in the caller's sub-range
 * (assigned rows * columns).
 */
siz_t bli_thread_range_t2b
     (
       const thrinfo_t* thr,
       const obj_t*     a,
       const blksz_t*   bmult,
             dim_t*     start,
             dim_t*     end
     )
{
	const dim_t bf = bli_blksz_get_def( bli_obj_dt( a ), bmult );

	bli_thread_range_sub( thr, bli_obj_length_after_trans( a ), bf,
	                      FALSE, start, end );

	const dim_t n_rows = *end - *start;

	return bli_obj_width_after_trans( a ) * n_rows;
}
/*
 * Row partitioning like bli_thread_range_t2b(), except that
 * handle_edge_low is TRUE, so any partial block is placed at the start of
 * thread 0's range.
 *
 * Returns the number of elements in the caller's sub-range
 * (assigned rows * columns).
 */
siz_t bli_thread_range_b2t
     (
       const thrinfo_t* thr,
       const obj_t*     a,
       const blksz_t*   bmult,
             dim_t*     start,
             dim_t*     end
     )
{
	const dim_t bf = bli_blksz_get_def( bli_obj_dt( a ), bmult );

	bli_thread_range_sub( thr, bli_obj_length_after_trans( a ), bf,
	                      TRUE, start, end );

	const dim_t n_rows = *end - *start;

	return bli_obj_width_after_trans( a ) * n_rows;
}
/*
 * Adjust a candidate subpartition width to the blocking-factor grid.
 *
 * For the first subpartition (j == 0) when the edge is handled at the low
 * end, the width is moved down to a multiple of bf and then bf_left is
 * added, so that width % bf == bf_left. Otherwise a width that is not
 * already a multiple of bf is rounded with bli_round_to_mult().
 */
static dim_t bli_thread_range_snap_width_l
     (
       dim_t width,
       dim_t j,
       dim_t bf,
       dim_t bf_left,
       bool  handle_edge_low
     )
{
	if ( j == 0 && handle_edge_low )
	{
		if ( width % bf != bf_left ) width += bf_left - ( width % bf );
	}
	else
	{
		if ( width % bf != 0 ) width = bli_round_to_mult( width, bf );
	}

	return width;
}

/*
 * Compute the width of the j-th (of n_way) column subpartition of the
 * remaining part of a lower-stored m x n_j region with diagonal offset
 * diagoff_j. The goal is for the subpartition to cover roughly area_per_thr
 * stored elements.
 *
 * Parameters:
 *   diagoff_j       diagonal offset of the remaining region
 *   m               number of rows
 *   n_j             number of columns still to be partitioned
 *   j               index of the subpartition being sized (0 .. n_way-1)
 *   n_way           total number of subpartitions
 *   bf, bf_left     blocking factor and size of the partial (edge) block
 *   area_per_thr    target area for each subpartition
 *   handle_edge_low whether subpartition 0 should absorb the edge block
 *
 * Returns a width in [0, n_j]. The last subpartition always takes all of n_j.
 */
dim_t bli_thread_range_width_l
     (
       doff_t diagoff_j,
       dim_t  m,
       dim_t  n_j,
       dim_t  j,
       dim_t  n_way,
       dim_t  bf,
       dim_t  bf_left,
       double area_per_thr,
       bool   handle_edge_low
     )
{
	dim_t width;

	// The final subpartition simply takes whatever columns remain.
	if ( j == n_way - 1 ) return n_j;

	// If there are no more blocks left (counting a partial edge block as
	// one) than subpartitions still to size, hand out a single block
	// (or the edge block for subpartition 0 when the edge is low).
	{
		const dim_t n_j_bf = n_j / bf + ( bf_left > 0 ? 1 : 0 );

		if ( n_j_bf <= n_way - j )
		{
			if ( j == 0 && handle_edge_low )
				width = ( bf_left > 0 ? bf_left : bf );
			else
				width = bf;

			// Never exceed the remaining columns; later subpartitions
			// may consequently receive empty ranges.
			if ( width > n_j ) width = n_j;

			return width;
		}
	}

	// First estimate: treat the subpartition as a dense m x width rectangle.
	width = ( dim_t )bli_round( ( double )area_per_thr / ( double )m );
	width = bli_thread_range_snap_width_l( width, j, bf, bf_left,
	                                       handle_edge_low );

	// If that rectangle would reach the diagonal, re-estimate the width
	// while accounting for the unstored triangle above the diagonal.
	if ( diagoff_j < width )
	{
		// Only offm_inc is produced by the pruning call. The former
		// offn_inc was never written, so reading it (even via a void cast)
		// read an indeterminate value; it has been removed.
		dim_t offm_inc = 0;

		bli_prune_unstored_region_top_l( &diagoff_j, &m, &n_j, &offm_inc );
		( void )offm_inc;

		// The coefficients model the stored area of a width-x panel as a
		// quadratic in x, set equal to area_per_thr.
		const double a = -0.5;
		const double b = ( double )m + ( double )diagoff_j + 0.5;
		const double c = -0.5 * (   ( double )diagoff_j *
		                          ( ( double )diagoff_j + 1.0 )
		                        ) - area_per_thr;
		const double r = b * b - 4.0 * a * c;

		// With a < 0, this root is the smaller of the two. A negative
		// discriminant leaves the rectangle estimate in place.
		if ( r >= 0.0 )
		{
			const double x = ( -b + sqrt( r ) ) / ( 2.0 * a );

			width = ( dim_t )bli_round( x );
			if ( width == 0 ) width = 1;
		}

		width = bli_thread_range_snap_width_l( width, j, bf, bf_left,
		                                       handle_edge_low );
	}

	// Rounding may have pushed the width past the remaining columns.
	if ( width > n_j ) width = n_j;

	return width;
}
/*
 * Estimate the number of stored elements in a lower-stored m x n region
 * with diagonal offset `diagoff`, blocked by `bf`.
 *
 * Rows above the point where the diagonal meets the left edge are pruned
 * first. If the diagonal then crosses the region, the strictly-upper
 * triangle is subtracted from the bounding rectangle, and the upper parts
 * of the bf x bf diagonal blocks are added back (tri_dim * (bf-1) / 2).
 * Otherwise the region is treated as a full rectangle.
 */
siz_t bli_find_area_trap_l
     (
       doff_t diagoff,
       dim_t  m,
       dim_t  n,
       dim_t  bf
     )
{
	dim_t offm_inc = 0;

	// The row-offset delta is not needed because nothing is indexed here.
	bli_prune_unstored_region_top_l( &diagoff, &m, &n, &offm_inc );
	( void )offm_inc;

	// No diagonal crossing: the area is that of a plain rectangle.
	if ( !bli_intersects_diag_n( diagoff, m, n ) )
	{
		const double rect_area = ( double )m * ( double )n;
		return ( siz_t )rect_area;
	}

	// Side of the unstored triangle, capped at m - 1.
	double tri_dim = ( double )( n - diagoff - 1 );
	tri_dim = bli_min( tri_dim, m - 1 );

	const double utri_area   = tri_dim * ( tri_dim + 1.0 ) / 2.0;
	const double blktri_area = tri_dim * ( bf - 1.0 ) / 2.0;

	const double area = ( double )m * ( double )n - utri_area + blktri_area;

	return ( siz_t )area;
}
/*
 * Compute the calling thread's column range [*j_start_thr, *j_end_thr) of
 * an m x n lower- or upper-stored region with diagonal offset `diagoff`.
 * Ranges are chosen so that each of the n_way threads covers roughly the
 * same estimated stored area, rather than the same number of columns.
 *
 * The callers in this file invoke it only when the region intersects its
 * diagonal and is upper- or lower-stored.
 *
 * Lower-stored regions are handled directly. Upper-stored regions are
 * rotated 180 degrees into a lower-stored one, handle_edge_low is toggled,
 * the function recurses, and bli_reverse_index_direction() maps the result
 * back relative to n.
 *
 * `uplo_orig` is only forwarded through the recursion.
 *
 * Returns the estimated area of the caller's subpartition, from
 * bli_find_area_trap_l(). Returns 0 if no iteration matched the caller's id.
 *
 * Fix: the collapsed, malformed one-line `#if 0 #endif` / `#if 0 #else`
 * directives (dead debugging code) were removed. They are not valid
 * preprocessor conditionals and left #if groups unterminated.
 */
siz_t bli_thread_range_weighted_sub
     (
       const thrinfo_t* thread,
             doff_t     diagoff,
             uplo_t     uplo,
             uplo_t     uplo_orig,
             dim_t      m,
             dim_t      n,
             dim_t      bf,
             bool       handle_edge_low,
             dim_t*     j_start_thr,
             dim_t*     j_end_thr
     )
{
	dim_t n_way = bli_thrinfo_n_way( thread );
	dim_t my_id = bli_thrinfo_work_id( thread );

	dim_t bf_left = n % bf;

	// Offset deltas from pruning are not needed (no indexing happens here).
	// They are initialized because offn_inc is never written on the
	// handle_edge_low path.
	dim_t offm_inc = 0;
	dim_t offn_inc = 0;

	siz_t area = 0;

	if ( bli_is_lower( uplo ) )
	{
		// Drop the unstored rows above the diagonal.
		bli_prune_unstored_region_top_l( &diagoff, &m, &n, &offm_inc );

		if ( !handle_edge_low )
		{
			// Drop the unstored columns to the right of the diagonal.
			bli_prune_unstored_region_right_l( &diagoff, &m, &n, &offn_inc );
		}
		else
		{
			// Trim n on the right in units of bf while keeping the leading
			// partial block of bf_left columns.
			// NOTE(review): this appears intended to keep the edge block
			// aligned at the low end when pruning — confirm.
			const dim_t n_inner = ( diagoff + bli_min( m, n - diagoff ) - bf_left );

			const dim_t n_bf_iter_br = n_inner / bf;
			const dim_t n_bf_left_br = n_inner % bf;
			const dim_t n_bf_br      = ( bf_left > 0 ? 1 : 0 ) +
			                           n_bf_iter_br +
			                           ( n_bf_left_br > 0 ? 1 : 0 );

			// Columns counted in n_bf_br * bf that lie beyond the real
			// leading partial block.
			const dim_t bf_extra = ( bf_left > 0 ? bf - bf_left : 0 );

			n = n_bf_br * bf - bf_extra;
		}

		( void )offm_inc;
		( void )offn_inc;

		// Total area = bounding rectangle
		//            - strictly-upper triangle
		//            + upper parts of the bf x bf diagonal blocks (weighted)
		//            - any rectangle right of where the diagonal exits the
		//              bottom edge.
		double tri_dim   = ( double )( n - diagoff - 1 );
		       tri_dim   = bli_min( tri_dim, m - 1 );
		double utri_area = tri_dim * ( tri_dim + 1.0 ) / 2.0;

		double blktri_area = tri_dim * ( bf - 1.0 ) / 2.0;

		double beyondtri_dim = n - diagoff - m;
		double beyondtri_area;
		if ( 0 < beyondtri_dim ) beyondtri_area = beyondtri_dim * m;
		else                     beyondtri_area = 0.0;

		// Heuristic extra weight for diagonal blocks.
		// NOTE(review): 1.5 looks like an empirical factor — confirm.
		blktri_area *= 1.5;

		double area_total = ( double )m * ( double )n - utri_area + blktri_area
		                                              - beyondtri_area;

		double area_per_thr = area_total / ( double )n_way;

		dim_t  off_j     = 0;
		doff_t diagoff_j = diagoff;
		dim_t  n_left    = n;

		// Size subpartitions left to right until reaching the caller's.
		for ( dim_t j = 0; j < n_way; ++j )
		{
			dim_t width_j
			=
			bli_thread_range_width_l
			(
			  diagoff_j, m, n_left,
			  j, n_way,
			  bf, bf_left,
			  area_per_thr,
			  handle_edge_low
			);

			if ( j == my_id )
			{
				*j_start_thr = off_j;
				*j_end_thr   = off_j + width_j;

				area = bli_find_area_trap_l( diagoff_j, m, width_j, bf );

				break;
			}

			// Advance past subpartition j.
			off_j     += width_j;
			diagoff_j -= width_j;
			n_left    -= width_j;
		}
	}
	else
	{
		// Upper-stored: rotate into a lower-stored problem. The rotation
		// flips the column order, so the edge placement flips as well.
		bli_rotate180_trapezoid( &diagoff, &uplo, &m, &n );
		bli_toggle_bool( &handle_edge_low );

		area = bli_thread_range_weighted_sub
		(
		  thread, diagoff, uplo, uplo_orig, m, n, bf,
		  handle_edge_low,
		  j_start_thr, j_end_thr
		);

		// Map the range back to the unrotated column order.
		bli_reverse_index_direction( n, j_start_thr, j_end_thr );
	}

	return area;
}
/*
 * Choose the caller's row range for a level-3 operation, dispatching on the
 * operation family taken from `cntl`.
 *
 *   - gemmt partitions c; every other family partitions a.
 *   - gemmt and trmm use area-weighted partitioning; all others use plain
 *     blocked partitioning.
 *   - For trsm, the blocksize id is overridden: BLIS_MR if the root of a is
 *     triangular, BLIS_NR otherwise.
 *
 * `direct == BLIS_FWD` selects top-to-bottom; anything else bottom-to-top.
 * `b` is not consulted. Returns the element count reported by the selected
 * range function.
 */
siz_t bli_thread_range_mdim
     (
             dir_t      direct,
       const thrinfo_t* thr,
       const obj_t*     a,
       const obj_t*     b,
       const obj_t*     c,
       const cntl_t*    cntl,
       const cntx_t*    cntx,
             dim_t*     start,
             dim_t*     end
     )
{
	const opid_t family = bli_cntl_family( cntl );
	bszid_t      bszid  = bli_cntl_bszid( cntl );

	if ( family == BLIS_TRSM )
		bszid = ( bli_obj_root_is_triangular( a ) ? BLIS_MR : BLIS_NR );

	const blksz_t* bmult        = bli_cntx_get_bmult( bszid, cntx );
	const obj_t*   x            = ( family == BLIS_GEMMT ? c : a );
	const bool     use_weighted = ( family == BLIS_GEMMT || family == BLIS_TRMM );
	const bool     fwd          = ( direct == BLIS_FWD );

	if ( use_weighted )
		return fwd ? bli_thread_range_weighted_t2b( thr, x, bmult, start, end )
		           : bli_thread_range_weighted_b2t( thr, x, bmult, start, end );

	return fwd ? bli_thread_range_t2b( thr, x, bmult, start, end )
	           : bli_thread_range_b2t( thr, x, bmult, start, end );
}
/*
 * Choose the caller's column range for a level-3 operation, dispatching on
 * the operation family taken from `cntl`.
 *
 *   - gemmt partitions c; every other family partitions b.
 *   - gemmt and trmm use area-weighted partitioning; all others use plain
 *     blocked partitioning.
 *   - For trsm, the blocksize id is overridden: BLIS_MR if the root of b is
 *     triangular, BLIS_NR otherwise.
 *
 * `direct == BLIS_FWD` selects left-to-right; anything else right-to-left.
 * `a` is not consulted. Returns the element count reported by the selected
 * range function.
 */
siz_t bli_thread_range_ndim
     (
             dir_t      direct,
       const thrinfo_t* thr,
       const obj_t*     a,
       const obj_t*     b,
       const obj_t*     c,
       const cntl_t*    cntl,
       const cntx_t*    cntx,
             dim_t*     start,
             dim_t*     end
     )
{
	const opid_t family = bli_cntl_family( cntl );
	bszid_t      bszid  = bli_cntl_bszid( cntl );

	if ( family == BLIS_TRSM )
		bszid = ( bli_obj_root_is_triangular( b ) ? BLIS_MR : BLIS_NR );

	const blksz_t* bmult        = bli_cntx_get_bmult( bszid, cntx );
	const obj_t*   x            = ( family == BLIS_GEMMT ? c : b );
	const bool     use_weighted = ( family == BLIS_GEMMT || family == BLIS_TRMM );
	const bool     fwd          = ( direct == BLIS_FWD );

	if ( use_weighted )
		return fwd ? bli_thread_range_weighted_l2r( thr, x, bmult, start, end )
		           : bli_thread_range_weighted_r2l( thr, x, bmult, start, end );

	return fwd ? bli_thread_range_l2r( thr, x, bmult, start, end )
	           : bli_thread_range_r2l( thr, x, bmult, start, end );
}
/*
 * Left-to-right column partitioning, weighted by stored area.
 *
 * If `a` is upper- or lower-stored and intersects its diagonal, the range
 * comes from bli_thread_range_weighted_sub() (handle_edge_low = FALSE),
 * after reflecting the geometry when `a` carries an implicit transpose.
 * Otherwise this falls back to bli_thread_range_l2r().
 *
 * Returns the area/element count reported by the chosen routine.
 */
siz_t bli_thread_range_weighted_l2r
     (
       const thrinfo_t* thr,
       const obj_t*     a,
       const blksz_t*   bmult,
             dim_t*     start,
             dim_t*     end
     )
{
	if ( !bli_obj_intersects_diag( a ) || !bli_obj_is_upper_or_lower( a ) )
		return bli_thread_range_l2r( thr, a, bmult, start, end );

	const dim_t bf      = bli_blksz_get_def( bli_obj_dt( a ), bmult );
	doff_t      diagoff = bli_obj_diag_offset( a );
	uplo_t      uplo    = bli_obj_uplo( a );
	dim_t       m       = bli_obj_length( a );
	dim_t       n       = bli_obj_width( a );

	// Honor an implicit transpose.
	if ( bli_obj_has_trans( a ) )
		bli_reflect_about_diag( &diagoff, &uplo, &m, &n );

	return bli_thread_range_weighted_sub( thr, diagoff, uplo, uplo, m, n, bf,
	                                      FALSE, start, end );
}
/*
 * Right-to-left column partitioning, weighted by stored area.
 *
 * If `a` is upper- or lower-stored and intersects its diagonal, the
 * geometry is reflected for an implicit transpose (if any), rotated 180
 * degrees, and handed to bli_thread_range_weighted_sub() with
 * handle_edge_low = TRUE. Otherwise this falls back to
 * bli_thread_range_r2l().
 *
 * Returns the area/element count reported by the chosen routine.
 */
siz_t bli_thread_range_weighted_r2l
     (
       const thrinfo_t* thr,
       const obj_t*     a,
       const blksz_t*   bmult,
             dim_t*     start,
             dim_t*     end
     )
{
	if ( !bli_obj_intersects_diag( a ) || !bli_obj_is_upper_or_lower( a ) )
		return bli_thread_range_r2l( thr, a, bmult, start, end );

	const dim_t bf      = bli_blksz_get_def( bli_obj_dt( a ), bmult );
	doff_t      diagoff = bli_obj_diag_offset( a );
	uplo_t      uplo    = bli_obj_uplo( a );
	dim_t       m       = bli_obj_length( a );
	dim_t       n       = bli_obj_width( a );

	// Honor an implicit transpose.
	if ( bli_obj_has_trans( a ) )
		bli_reflect_about_diag( &diagoff, &uplo, &m, &n );

	// Rotate so that right-to-left becomes left-to-right.
	bli_rotate180_trapezoid( &diagoff, &uplo, &m, &n );

	return bli_thread_range_weighted_sub( thr, diagoff, uplo, uplo, m, n, bf,
	                                      TRUE, start, end );
}
/*
 * Top-to-bottom row partitioning, weighted by stored area.
 *
 * If `a` is upper- or lower-stored and intersects its diagonal, the
 * geometry is reflected for an implicit transpose (if any) and then
 * reflected once more so that rows can be partitioned by the column-based
 * bli_thread_range_weighted_sub() (handle_edge_low = FALSE). Otherwise this
 * falls back to bli_thread_range_t2b().
 *
 * Returns the area/element count reported by the chosen routine.
 */
siz_t bli_thread_range_weighted_t2b
     (
       const thrinfo_t* thr,
       const obj_t*     a,
       const blksz_t*   bmult,
             dim_t*     start,
             dim_t*     end
     )
{
	if ( !bli_obj_intersects_diag( a ) || !bli_obj_is_upper_or_lower( a ) )
		return bli_thread_range_t2b( thr, a, bmult, start, end );

	const dim_t bf      = bli_blksz_get_def( bli_obj_dt( a ), bmult );
	doff_t      diagoff = bli_obj_diag_offset( a );
	uplo_t      uplo    = bli_obj_uplo( a );
	dim_t       m       = bli_obj_length( a );
	dim_t       n       = bli_obj_width( a );

	// Honor an implicit transpose.
	if ( bli_obj_has_trans( a ) )
		bli_reflect_about_diag( &diagoff, &uplo, &m, &n );

	// Swap the roles of rows and columns for the column partitioner.
	bli_reflect_about_diag( &diagoff, &uplo, &m, &n );

	return bli_thread_range_weighted_sub( thr, diagoff, uplo, uplo, m, n, bf,
	                                      FALSE, start, end );
}
/*
 * Bottom-to-top row partitioning, weighted by stored area.
 *
 * If `a` is upper- or lower-stored and intersects its diagonal, the
 * geometry is:
 *   1. reflected for an implicit transpose (if any);
 *   2. reflected again so that rows become columns;
 *   3. rotated 180 degrees.
 * It is then passed to bli_thread_range_weighted_sub() with
 * handle_edge_low = TRUE. Otherwise this falls back to
 * bli_thread_range_b2t().
 *
 * Returns the area/element count reported by the chosen routine.
 */
siz_t bli_thread_range_weighted_b2t
     (
       const thrinfo_t* thr,
       const obj_t*     a,
       const blksz_t*   bmult,
             dim_t*     start,
             dim_t*     end
     )
{
	if ( !bli_obj_intersects_diag( a ) || !bli_obj_is_upper_or_lower( a ) )
		return bli_thread_range_b2t( thr, a, bmult, start, end );

	const dim_t bf      = bli_blksz_get_def( bli_obj_dt( a ), bmult );
	doff_t      diagoff = bli_obj_diag_offset( a );
	uplo_t      uplo    = bli_obj_uplo( a );
	dim_t       m       = bli_obj_length( a );
	dim_t       n       = bli_obj_width( a );

	// Honor an implicit transpose.
	if ( bli_obj_has_trans( a ) )
		bli_reflect_about_diag( &diagoff, &uplo, &m, &n );

	// Rows -> columns, then reverse the traversal direction.
	bli_reflect_about_diag( &diagoff, &uplo, &m, &n );
	bli_rotate180_trapezoid( &diagoff, &uplo, &m, &n );

	return bli_thread_range_weighted_sub( thr, diagoff, uplo, uplo, m, n, bf,
	                                      TRUE, start, end );
}