#ifndef EXAMPLE_GROW_ARRAY
#define EXAMPLE_GROW_ARRAY
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <Kokkos_Core.hpp>
#if defined(KOKKOS_ENABLE_CUDA)
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#endif
namespace Example {
/// Sorts the half-open index range [begin, end) of a rank-1 View.
///
/// All work happens in the constructor, so a temporary such as
/// `SortView<Space>( v , b , e );` is used like a function call.
/// This primary template hands the raw pointers obtained from
/// ptr_on_device() straight to std::sort.
///
/// NOTE(review): std::sort runs in the calling thread, so this presumably
/// requires the View's memory to be accessible from the caller - confirm
/// for every execution space this is instantiated with.
///
/// @tparam ExecSpace  Execution space of the View being sorted.
template< class ExecSpace >
struct SortView {

  /// @param v      View whose entries are sorted in place.
  /// @param begin  Index of the first entry to sort.
  /// @param end    One past the index of the last entry to sort.
  template< typename ValueType >
  SortView( const Kokkos::View<ValueType*,ExecSpace> v , int begin , int end )
  {
    ValueType * const first = v.ptr_on_device() + begin ;
    ValueType * const last  = v.ptr_on_device() + end ;

    std::sort( first , last );
  }
};
#if defined(KOKKOS_ENABLE_CUDA)
/// Specialization of SortView for Kokkos::Cuda.
///
/// Sorts the half-open index range [begin, end) of the View with
/// thrust::sort, wrapping the raw pointers returned by ptr_on_device()
/// in thrust::device_ptr. As with the primary template, the constructor
/// does all the work.
template<>
struct SortView< Kokkos::Cuda > {

  /// @param v      View whose entries are sorted in place.
  /// @param begin  Index of the first entry to sort.
  /// @param end    One past the index of the last entry to sort.
  template< typename ValueType >
  SortView( const Kokkos::View<ValueType*,Kokkos::Cuda> v , int begin , int end )
  {
    typedef thrust::device_ptr<ValueType> device_iterator ;

    const device_iterator first( v.ptr_on_device() + begin );
    const device_iterator last( v.ptr_on_device() + end );

    thrust::sort( first , last );
  }
};
#endif
/// Team-parallel functor that collects the indices of all set bits of a
/// packed flag array into a dense output array.
///
/// The search range [0, m_search_total) is split into consecutive chunks of
/// m_search_team_chunk indices; the team with league rank r scans chunk r.
/// Every index whose flag bit is set is appended to m_search_array, and
/// m_search_count is used as the shared running count of appended entries.
/// The order of the entries in m_search_array is not defined by this functor
/// (the driver sorts them afterwards).
///
/// @tparam ExecSpace  Kokkos execution space the functor runs in.
template< class ExecSpace >
struct GrowArrayFunctor {

  typedef ExecSpace execution_space ;

  // One int of m_search_flags holds (1 << SHIFT) flags:
  //   word = index >> SHIFT , bit = index & MASK.
  enum { SHIFT = sizeof(int) == 8 ? 6 : 5 };
  enum { MASK  = ( 1 << SHIFT ) - 1 };

  const Kokkos::View<int*,ExecSpace> m_search_flags ;      // packed input flags
  const Kokkos::View<int*,ExecSpace> m_search_array ;      // output: indices found
  const Kokkos::View<int,ExecSpace>  m_search_count ;      // output: number of indices found
  const int                          m_search_total ;      // number of indices to search
  const int                          m_search_team_chunk ; // indices searched per team

  /// @param array_length   Capacity of the output array m_search_array.
  /// @param search_length  Number of flag indices to search; the flag array
  ///                       is sized to hold that many bits, rounded up to a
  ///                       whole int.
  /// @param print          Not used by the functor.
  GrowArrayFunctor( int array_length , int search_length , int print = 1 )
    : m_search_flags( "flags" , ( search_length + MASK ) >> SHIFT )
    , m_search_array( "array" , array_length )
    , m_search_count( "count" )
    , m_search_total( search_length )
    , m_search_team_chunk( 2048 )
    {}

  /// Returns true when the flag bit for 'index' is set.
  /// An index whose word lies past the end of the flag array reads as unset.
  KOKKOS_INLINE_FUNCTION
  bool flag_is_set( const int index ) const
    {
      const int j = index >> SHIFT ;           // word holding the flag
      const int k = 1 << ( index & MASK );     // bit mask inside that word
      const int s = ( j < int(m_search_flags.dimension_0()) ) &&
                    ( 0 != ( m_search_flags(j) & k ) );
      return s ;
    }

  typedef typename Kokkos::TeamPolicy<ExecSpace>::member_type team_member ;

  /// Scans this team's chunk of the search range and appends every flagged
  /// index to m_search_array.
  KOKKOS_INLINE_FUNCTION
  void operator()( const team_member & member ) const
    {
      enum { LOCAL_BUFFER_LENGTH = 16 };

      // Per-thread staging buffer. A thread adds at most one entry per loop
      // iteration and the buffer is flushed every LOCAL_BUFFER_LENGTH
      // iterations, so it cannot overflow.
      int local_buffer[ LOCAL_BUFFER_LENGTH ] ;
      int local_count = 0 ;

      // Half-open range [search_team_begin, search_team_end) owned by this team.
      int       search_team_begin = member.league_rank() * m_search_team_chunk ;
      const int search_team_end   = search_team_begin + m_search_team_chunk ;

      int k = 0 ; // iterations since the last flush

      // The loop bounds depend only on league rank and team size, so every
      // thread of the team runs the same number of iterations and reaches
      // each team_scan call together.
      while ( search_team_begin < search_team_end ) {

        // This iteration the team looks at team_size() consecutive indices,
        // one per thread.
        const int thread_search_index = search_team_begin + member.team_rank();

        // The index must also stay inside this team's chunk. When the team
        // size does not divide the chunk length, the last stride would
        // otherwise reach into the next team's chunk and those indices would
        // be recorded twice (possibly overrunning m_search_array).
        if ( thread_search_index < search_team_end &&
             thread_search_index < m_search_total &&
             flag_is_set(thread_search_index) ) {
          local_buffer[ local_count ] = thread_search_index ;
          ++local_count ;
        }

        search_team_begin += member.team_size();
        ++k ;

        // Flush when a buffer might be full or the chunk is exhausted.
        if ( k == LOCAL_BUFFER_LENGTH ||
             ! ( search_team_begin < search_team_end ) ) {

          // team_scan yields the position in m_search_array where this
          // thread's entries start; per the Kokkos team API it combines a
          // scan over the team's local_count values with an update of the
          // global counter m_search_count.
          const int team_offset = member.team_scan( local_count , & *m_search_count );

          for ( int i = 0 ; i < local_count ; ++i ) {
            m_search_array( team_offset + i ) = local_buffer[i] ;
          }

          k = 0 ;
          local_count = 0 ;
        }
      }
    }
};
/// Runs GrowArrayFunctor on randomly generated flags and checks the result.
///
/// Sets up to 'array_length' random flag bits in [0, search_length), runs the
/// functor over a team policy with one team per chunk of the search range,
/// sorts the indices it produced, and verifies on the host that
///   - every reported index has its flag set (else result_error_count), and
///   - no set flag is left unreported (else flags_error_count).
/// A summary is written to std::cerr only when at least one error was found.
///
/// @tparam ExecSpace     Kokkos execution space to run in.
/// @param array_length   Number of random flags to set; also the capacity of
///                       the functor's output array.
/// @param search_length  Number of flag indices to search. Nothing is done
///                       when it is not positive.
/// @param print          When nonzero, each individual mismatch is written to
///                       std::cerr in addition to the summary.
template< class ExecSpace >
void grow_array( int array_length , int search_length , int print = 1 )
{
  typedef GrowArrayFunctor< ExecSpace > FunctorType ;

  // An empty search range has no flag word at all; setting a flag below
  // would write past the end of a zero-length view.
  if ( search_length <= 0 ) return ;

  FunctorType functor( array_length , search_length , print );

  typename Kokkos::View<int,ExecSpace>::HostMirror  count =
    Kokkos::create_mirror_view( functor.m_search_count );
  typename Kokkos::View<int*,ExecSpace>::HostMirror flags =
    Kokkos::create_mirror_view( functor.m_search_flags );

  // Set at most 'array_length' flags (random indices may repeat).
  for ( int i = 0 ; i < array_length ; ++i ) {
    // lrand48() yields a value in [0, 2^31), so scaling by search_length and
    // dividing by 2^31 gives an index in [0, search_length). The product is
    // formed in 'long long' because it overflows where 'long' is 32 bits.
    const int index = static_cast<int>(
      ( static_cast<long long>( lrand48() ) * search_length ) >> 31 );
    flags( index >> FunctorType::SHIFT ) |= ( 1 << ( index & FunctorType::MASK ) );
  }

  Kokkos::deep_copy( functor.m_search_flags , flags );

  // One team per chunk of the search range, rounded up.
  Kokkos::TeamPolicy< ExecSpace >
    work( ( search_length + functor.m_search_team_chunk - 1 ) / functor.m_search_team_chunk
        , Kokkos::TeamPolicy< ExecSpace >::team_size_max( functor ) );

  Kokkos::parallel_for( work , functor );

  // Fetch how many indices were found, then sort just that prefix so the
  // output can be compared independently of the order teams wrote it in.
  Kokkos::deep_copy( count , functor.m_search_count );
  const int found_count = *count ;

  SortView< ExecSpace >( functor.m_search_array , 0 , found_count );

  typename Kokkos::View<int*,ExecSpace>::HostMirror results =
    Kokkos::create_mirror_view( functor.m_search_array );
  Kokkos::deep_copy( results , functor.m_search_array );

  int result_error_count = 0 ;
  int flags_error_count = 0 ;

  // Every reported index must have its flag set. Clearing the flag as we go
  // makes a duplicated index show up as an error on its second occurrence.
  for ( int i = 0 ; i < found_count ; ++i ) {
    const int index = results(i);
    const int entry = index >> FunctorType::SHIFT ;
    const int bit   = 1 << ( index & FunctorType::MASK );
    const bool flag = 0 != ( flags( entry ) & bit );
    if ( ! flag ) {
      if ( print ) std::cerr << "result( " << i << " : " << index << " )";
      ++result_error_count ;
    }
    flags( entry ) &= ~bit ;
  }

  // Any flag word that is still nonzero holds indices the functor missed.
  for ( int i = 0 ; i < int(flags.dimension_0()) ; ++i ) {
    if ( flags(i) ) {
      if ( print ) std::cerr << "flags( " << i << " : " << flags(i) << " )" ;
      ++flags_error_count ;
    }
  }

  if ( result_error_count || flags_error_count ) {
    std::cerr << std::endl << "Example::GrowArrayFunctor( " << array_length
              << " , " << search_length
              << " ) result_error_count( " << result_error_count
              << " ) flags_error_count( " << flags_error_count << " )"
              << std::endl ;
  }
}
}
#endif