#pragma once
#include <cstdint>
#include <concepts>
#include <utility>
#if defined(__CUDACC__)
// NVIDIA CUDA backend: 32-lane warps. The *_sync intrinsics are given an
// explicit full-warp participant mask, so every lane must reach the call.
#define WT_DEVICE __device__ __forceinline__
#define WT_HOST_DEVICE __host__ __device__ __forceinline__
#define WT_SHFL_XOR(val, mask) __shfl_xor_sync(0xFFFFFFFFu, (val), (mask))
#define WT_SHFL_DOWN(val, delta) __shfl_down_sync(0xFFFFFFFFu, (val), (delta))
#define WT_BALLOT(pred) __ballot_sync(0xFFFFFFFFu, (pred))
#define WT_WARP_SIZE 32
#elif defined(__HIPCC__)
// AMD HIP backend: the cross-lane intrinsics take no participant mask, and
// __ballot returns a 64-bit mask regardless of the wavefront size.
#define WT_DEVICE __device__ __attribute__((always_inline))
#define WT_HOST_DEVICE __host__ __device__ __attribute__((always_inline))
#define WT_SHFL_XOR(val, mask) __shfl_xor((val), (mask))
#define WT_SHFL_DOWN(val, delta) __shfl_down((val), (delta))
#define WT_BALLOT(pred) __ballot((pred))
// The wavefront size must come from the device compiler. Clang's AMDGPU family
// macros are upper-case (__GFX9__), so a lower-case __gfx9__ test never
// matches. RDNA targets (gfx10+) also default to wave32 in HIP. Prefer the
// explicit wavefront-size macro and fall back to the GCN/CDNA families, which
// only run wave64.
// NOTE(review): newer clang releases may flag __AMDGCN_WAVEFRONT_SIZE as
// deprecated; confirm against the targeted ROCm version.
#if defined(__AMDGCN_WAVEFRONT_SIZE__)
#define WT_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE__
#elif defined(__AMDGCN_WAVEFRONT_SIZE)
#define WT_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
#elif defined(__GFX8__) || defined(__GFX9__)
#define WT_WARP_SIZE 64
#else
// Host-side pass of a HIP compile, or an RDNA target in its default wave32 mode.
#define WT_WARP_SIZE 32
#endif
#else
// Plain host compiler: single-thread stand-ins. Shuffles return the caller's
// own value, and ballot sets only bit 0 when the predicate is true.
#define WT_DEVICE inline
#define WT_HOST_DEVICE inline
#define WT_SHFL_XOR(val, mask) (val)
#define WT_SHFL_DOWN(val, delta) (val)
#define WT_BALLOT(pred) static_cast<uint32_t>((pred) ? 1u : 0u)
#define WT_WARP_SIZE 32
#endif
namespace warp_types {

// ============================================================================
// Active lane sets
// ============================================================================

/// An active lane set is a type exposing a compile-time lane bitmask `MASK`
/// (bit i corresponds to lane i, e.g. Lane0 is bit 0) and a printable `NAME`.
template<typename S>
concept ActiveSet = requires {
    { S::MASK } -> std::convertible_to<uint64_t>;
    { S::NAME } -> std::convertible_to<const char*>;
};

// --- 32-lane sets -----------------------------------------------------------
struct All      { static constexpr uint64_t MASK = 0xFFFF'FFFF;
                  static constexpr const char* NAME = "All"; };
struct Even     { static constexpr uint64_t MASK = 0x5555'5555;
                  static constexpr const char* NAME = "Even"; };
struct Odd      { static constexpr uint64_t MASK = 0xAAAA'AAAA;
                  static constexpr const char* NAME = "Odd"; };
struct LowHalf  { static constexpr uint64_t MASK = 0x0000'FFFF;
                  static constexpr const char* NAME = "LowHalf"; };
struct HighHalf { static constexpr uint64_t MASK = 0xFFFF'0000;
                  static constexpr const char* NAME = "HighHalf"; };
struct Lane0    { static constexpr uint64_t MASK = 0x0000'0001;
                  static constexpr const char* NAME = "Lane0"; };
struct NotLane0 { static constexpr uint64_t MASK = 0xFFFF'FFFE;
                  static constexpr const char* NAME = "NotLane0"; };
struct EvenLow  { static constexpr uint64_t MASK = 0x0000'5555;
                  static constexpr const char* NAME = "EvenLow"; };
struct EvenHigh { static constexpr uint64_t MASK = 0x5555'0000;
                  static constexpr const char* NAME = "EvenHigh"; };
struct OddLow   { static constexpr uint64_t MASK = 0x0000'AAAA;
                  static constexpr const char* NAME = "OddLow"; };
struct OddHigh  { static constexpr uint64_t MASK = 0xAAAA'0000;
                  static constexpr const char* NAME = "OddHigh"; };
struct Empty    { static constexpr uint64_t MASK = 0x0000'0000;
                  static constexpr const char* NAME = "Empty"; };

// --- 64-lane sets -----------------------------------------------------------
struct All64    { static constexpr uint64_t MASK = 0xFFFF'FFFF'FFFF'FFFF;
                  static constexpr const char* NAME = "All64"; };
struct Even64   { static constexpr uint64_t MASK = 0x5555'5555'5555'5555;
                  static constexpr const char* NAME = "Even64"; };
struct Odd64    { static constexpr uint64_t MASK = 0xAAAA'AAAA'AAAA'AAAA;
                  static constexpr const char* NAME = "Odd64"; };

// ============================================================================
// Complement relationship (compile-time)
// ============================================================================

/// Two sets are complements within a parent if together they cover exactly the
/// parent's lanes and share no lane.
template<typename S1, typename S2, typename Parent>
concept ComplementWithin = ActiveSet<S1> && ActiveSet<S2> && ActiveSet<Parent>
    && (S1::MASK | S2::MASK) == Parent::MASK
    && (S1::MASK & S2::MASK) == 0;

/// S1 and S2 partition the 32-lane `All` set.
template<typename S1, typename S2>
concept ComplementOf = ComplementWithin<S1, S2, All>;

/// S1 and S2 partition the 64-lane `All64` set.
template<typename S1, typename S2>
concept ComplementOf64 = ComplementWithin<S1, S2, All64>;

// ============================================================================
// Value wrappers
// ============================================================================

/// Wraps a value meant to be identical on every lane (for example the result
/// of Warp::reduce_sum, whose XOR butterfly leaves the total on all lanes).
template<typename T>
struct Uniform {
    T value;
    WT_HOST_DEVICE constexpr T get() const { return value; }
    WT_HOST_DEVICE static constexpr Uniform from_const(T v) { return {v}; }
};

/// Wraps a value that may differ from lane to lane.
template<typename T>
struct PerLane {
    T value;
    WT_HOST_DEVICE constexpr T get() const { return value; }
    WT_HOST_DEVICE static constexpr PerLane from(T v) { return {v}; }
};

// ============================================================================
// Warp handle
// ============================================================================

/// Zero-size, compile-time handle recording which lanes (S) are active.
/// Collective operations are only available when S is `All`, because the
/// backend shuffle/ballot macros assume every lane of the warp participates.
template<ActiveSet S>
class Warp {
public:
    /// Lane mask of the active set this handle represents.
    static constexpr uint64_t MASK = S::MASK;

    /// Integer type returned by ballot(). It is 64-bit on 64-lane wavefronts
    /// and 32-bit otherwise, which keeps the CUDA/host interface unchanged.
#if WT_WARP_SIZE == 64
    using BallotMask = uint64_t;
#else
    using BallotMask = uint32_t;
#endif

    /// Creates the full-warp handle. Only `Warp<All>` can be created directly.
    WT_HOST_DEVICE static constexpr Warp kernel_entry()
        requires std::same_as<S, All>
    {
        return {};
    }

    /// Butterfly exchange: each lane receives `data` from lane (laneId ^ mask).
    /// On the plain host fallback this returns the caller's own value.
    template<typename T>
    WT_DEVICE PerLane<T> shuffle_xor(PerLane<T> data, uint32_t mask) const
        requires std::same_as<S, All>
    {
        return {WT_SHFL_XOR(data.value, mask)};
    }

    /// Each lane receives `data` from lane (laneId + delta). Per the backend
    /// intrinsic's semantics, lanes whose source would fall outside the warp
    /// keep their own value.
    template<typename T>
    WT_DEVICE PerLane<T> shuffle_down(PerLane<T> data, uint32_t delta) const
        requires std::same_as<S, All>
    {
        return {WT_SHFL_DOWN(data.value, delta)};
    }

    /// Sums `data` across the whole warp using an XOR butterfly. Every lane
    /// ends up holding the total, so the result is returned as Uniform.
    /// On the host fallback the shuffles are identity, so the result is
    /// `value * WT_WARP_SIZE`.
    template<typename T>
    WT_DEVICE Uniform<T> reduce_sum(PerLane<T> data) const
        requires std::same_as<S, All>
    {
        T val = data.value;
#if WT_WARP_SIZE == 64
        val += WT_SHFL_XOR(val, 32);  // extra step only for 64-lane wavefronts
#endif
        val += WT_SHFL_XOR(val, 16);
        val += WT_SHFL_XOR(val, 8);
        val += WT_SHFL_XOR(val, 4);
        val += WT_SHFL_XOR(val, 2);
        val += WT_SHFL_XOR(val, 1);
        return {val};
    }

    /// Bitmask of the lanes whose `predicate` is true.
    WT_DEVICE Uniform<BallotMask> ballot(bool predicate) const
        requires std::same_as<S, All>
    {
        // Explicit cast: HIP's __ballot yields a 64-bit value. Brace-initializing
        // a narrower mask from it without a cast is an ill-formed narrowing
        // conversion.
        return {static_cast<BallotMask>(WT_BALLOT(predicate))};
    }

    /// Wraps a lane-independent value as PerLane. Performs no communication.
    template<typename T>
    WT_HOST_DEVICE constexpr PerLane<T> broadcast(T value) const
        requires std::same_as<S, All>
    {
        return {value};
    }

    /// Splits the full warp into even- and odd-lane handles (no runtime work).
    WT_HOST_DEVICE constexpr std::pair<Warp<Even>, Warp<Odd>>
    diverge_even_odd() const
        requires std::same_as<S, All>
    {
        return {{}, {}};
    }

    /// Splits the full warp into low-half and high-half handles (no runtime work).
    WT_HOST_DEVICE constexpr std::pair<Warp<LowHalf>, Warp<HighHalf>>
    diverge_low_high() const
        requires std::same_as<S, All>
    {
        return {{}, {}};
    }
};

/// Rejoins two complementary 32-lane handles into the full-warp handle.
template<ActiveSet S1, ActiveSet S2>
    requires ComplementOf<S1, S2>
WT_HOST_DEVICE constexpr Warp<All> merge(Warp<S1>, Warp<S2>) {
    return {};
}

/// Rejoins two complementary 64-lane handles into the `All64` handle.
template<ActiveSet S1, ActiveSet S2>
    requires ComplementOf64<S1, S2>
WT_HOST_DEVICE constexpr Warp<All64> merge64(Warp<S1>, Warp<S2>) {
    return {};
}

// Compile-time guarantees: every diverge_* split can be merged back, and the
// predefined partial sets partition their parents as their names imply.
static_assert(ComplementOf<Even, Odd>);
static_assert(ComplementOf<LowHalf, HighHalf>);
static_assert(ComplementOf<Lane0, NotLane0>);
static_assert(ComplementWithin<EvenLow, OddLow, LowHalf>);
static_assert(ComplementWithin<EvenHigh, OddHigh, HighHalf>);
static_assert(ComplementOf64<Even64, Odd64>);

} // namespace warp_types
// Remove the backend helper macros so they do not leak into files that include
// this header. Define WT_KEEP_MACROS before inclusion to keep them.
// WT_WARP_SIZE is not undefined here and remains visible to includers.
#if !defined(WT_KEEP_MACROS)
#undef WT_BALLOT
#undef WT_SHFL_DOWN
#undef WT_SHFL_XOR
#undef WT_HOST_DEVICE
#undef WT_DEVICE
#endif // !WT_KEEP_MACROS