#include "blis.h"
#ifdef BLIS_ENABLE_HPX
#include <hpx/synchronization/barrier.hpp>
extern "C" {
// Create an HPX barrier that expects n_threads arrivals per phase and store
// it in the opaque handle of *barrier. A null barrier pointer is a no-op.
// The allocated barrier is released by hpx_barrier_destroy().
void hpx_barrier_init( hpx_barrier_t* barrier, dim_t n_threads )
{
	if ( barrier != nullptr )
	{
		barrier->handle = new hpx::barrier<>( n_threads );
	}
}
// Release the HPX barrier owned by *barrier (if any) and reset the handle.
// A null barrier pointer is a no-op; a null handle is safe because deleting
// a null pointer does nothing.
void hpx_barrier_destroy( hpx_barrier_t* barrier )
{
	if ( barrier != nullptr )
	{
		// Detach first so the struct never holds a dangling handle.
		hpx::barrier<>* const owned = reinterpret_cast<hpx::barrier<>*>( barrier->handle );
		barrier->handle = nullptr;
		delete owned;
	}
}
// Arrive at the HPX barrier stored in *barrier and block until the current
// phase completes. Does nothing if either the struct pointer or its handle
// is null (e.g. never initialized, or already destroyed).
void hpx_barrier_arrive_and_wait( hpx_barrier_t* barrier )
{
	hpx::barrier<>* const impl =
	    ( barrier == nullptr )
	        ? nullptr
	        : reinterpret_cast<hpx::barrier<>*>( barrier->handle );

	if ( impl != nullptr )
		impl->arrive_and_wait();
}
// Initialize a thread communicator for the HPX backend:
// - clear the shared broadcast slot,
// - record the team size and the threading implementation tag,
// - create a barrier sized for n_threads participants.
// A null comm is ignored.
void bli_thrcomm_init_hpx( dim_t n_threads, thrcomm_t* comm )
{
	if ( comm != nullptr )
	{
		comm->sent_object = nullptr;
		comm->n_threads   = n_threads;
		comm->ti          = BLIS_HPX;
		hpx_barrier_init( &comm->barrier, n_threads );
	}
}
// Tear down the HPX-specific state of a thread communicator by destroying
// its barrier. A null comm is ignored.
void bli_thrcomm_cleanup_hpx( thrcomm_t* comm )
{
	if ( comm != nullptr )
		hpx_barrier_destroy( &comm->barrier );
}
// Block the calling thread until every participant of comm's barrier (the
// count passed to hpx_barrier_init when comm was initialized) has arrived.
//
// t_id: caller's thread id within the communicator; the HPX barrier does not
//       need it, so it is ignored.
// comm: communicator to synchronize on. A null comm is ignored, matching the
//       guards in bli_thrcomm_init_hpx / bli_thrcomm_cleanup_hpx.
void bli_thrcomm_barrier_hpx( dim_t t_id, thrcomm_t* comm )
{
	( void )t_id;

	// Must check before forming &comm->barrier: member access through a null
	// pointer is undefined behavior, and the resulting address is generally
	// not null, so hpx_barrier_arrive_and_wait's own null check can't catch it.
	if ( comm == nullptr ) return;

	hpx_barrier_arrive_and_wait( &comm->barrier );
}
}
#endif