#pragma once
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cutlass/fast_math.h>
namespace cute
{
// Returns the larger of two arithmetic values.
// If t is not less than u (including when they compare equal), t is returned.
// The return type is whatever the conditional expression over T and U yields.
template <class T, class U,
          __CUTE_REQUIRES(is_arithmetic<T>::value &&
                          is_arithmetic<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
max(T const& t, U const& u) {
  bool const pick_u = t < u;
  // Keep a single conditional expression so that `auto` deduces the common type.
  return pick_u ? u : t;
}
// Returns the smaller of two arithmetic values.
// If t is not less than u (including when they compare equal), u is returned.
// The return type is whatever the conditional expression over T and U yields.
template <class T, class U,
          __CUTE_REQUIRES(is_arithmetic<T>::value &&
                          is_arithmetic<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
min(T const& t, U const& u) {
  bool const pick_t = t < u;
  // Keep a single conditional expression so that `auto` deduces the common type.
  return pick_t ? t : u;
}
// Absolute value of an arithmetic value.
// Unsigned types are returned unchanged; signed types are negated when below zero.
template <class T,
          __CUTE_REQUIRES(is_arithmetic<T>::value)>
CUTE_HOST_DEVICE constexpr
auto
abs(T const& t) {
  if constexpr (!is_signed<T>::value) {
    // Nothing to do: an unsigned value is never negative.
    return t;
  } else {
    return t < T(0) ? -t : t;
  }

  CUTE_GCC_UNREACHABLE;
}
// Sign of x as an int: 1 if x > 0, -1 if x < 0, and 0 otherwise.
// For unsigned T the result can only be 0 or 1.
template <class T,
          __CUTE_REQUIRES(is_arithmetic<T>::value)>
CUTE_HOST_DEVICE constexpr
int
signum(T const& x) {
  if constexpr (!is_signed<T>::value) {
    // Unsigned: only the "positive" test is meaningful.
    return T(0) < x;
  } else {
    int const is_positive = T(0) < x;
    int const is_negative = x < T(0);
    return is_positive - is_negative;
  }

  CUTE_GCC_UNREACHABLE;
}
// Greatest common divisor of two integers, computed with the Euclidean algorithm.
// The two operands are reduced against each other in alternation until one of
// them becomes zero; the other one is then the result.
// gcd(0, u) yields u and gcd(t, 0) yields t.
template <class T, class U,
          __CUTE_REQUIRES(is_std_integral<T>::value &&
                          is_std_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
cute::common_type_t<T, U>
gcd(T t, U u) {
  while (t != 0) {
    u %= t;
    if (u == 0) {
      return t;
    }
    t %= u;
  }
  // t reached zero, so u holds the answer.
  return u;
}
// Least common multiple of two integers.
// Returns 0 when either operand is zero (the same convention as std::lcm).
// Without the guard, lcm(0, 0) would divide by gcd(0, 0) == 0.
// The division is performed before the multiplication to keep the intermediate
// value as small as possible.
template <class T, class U,
          __CUTE_REQUIRES(is_std_integral<T>::value &&
                          is_std_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
cute::common_type_t<T, U>
lcm(T const& t, U const& u) {
  if (t == 0 || u == 0) {
    return 0;
  }
  return (t / gcd(t,u)) * u;
}
// True iff x is a power of two, i.e. exactly one bit of x is set.
template <class T>
CUTE_HOST_DEVICE constexpr
bool
has_single_bit(T x) {
  if (x == 0) {
    return false;
  }
  // Clearing the lowest set bit leaves zero only when it was the sole set bit.
  return (x & (x - 1)) == 0;
}
// Number of bits needed to represent x: 0 when x == 0, otherwise
// 1 + floor(log2(x)). Only unsigned types are accepted.
template <class T>
CUTE_HOST_DEVICE constexpr
int
bit_width(T x) {
static_assert(is_unsigned<T>::value, "Only to be used for unsigned types.");
// N = log2(number of value bits in T); only 8/16/32/64-bit types are listed.
// NOTE(review): if NDEBUG turns assert() into a no-op, an unlisted width would
// silently yield N == 0 instead of a compile-time error -- confirm this is intended.
constexpr int N = (numeric_limits<T>::digits == 64 ? 6 :
(numeric_limits<T>::digits == 32 ? 5 :
(numeric_limits<T>::digits == 16 ? 4 :
(numeric_limits<T>::digits == 8 ? 3 : (assert(false),0)))));
// r accumulates the index of the highest set bit.
T r = 0;
// Branch-free binary search over step sizes 2^(N-1), ..., 2, 1.
for (int i = N - 1; i >= 0; --i) {
// shift is 2^i when x still has a set bit at position >= 2^i, else 0.
T shift = (x > ((T(1) << (T(1) << i))-1)) << i;
x >>= shift;
r |= shift;
}
// After the loop x is 0 or 1; add one bit unless the input was zero.
return r + (x != 0);
}
// Smallest power of two that is not smaller than x; returns 1 for x == 0.
// The operand passed to bit_width is cast back to T: for types narrower than
// int, `x - 1` is promoted to (signed) int, which bit_width rejects.
// NOTE(review): when the result is not representable in T, bit_width(x - 1)
// equals the width of T and the shift is out of range; that case is not handled.
template <class T>
CUTE_HOST_DEVICE constexpr
T
bit_ceil(T x) {
  if (x == 0) {
    return T(1);
  }
  return static_cast<T>(T(1) << bit_width(T(x - 1)));
}
// Largest power of two that is not greater than x; returns 0 for x == 0.
template <class T>
CUTE_HOST_DEVICE constexpr
T
bit_floor(T x) {
  if (x == 0) {
    return T(0);
  }
  // bit_width(x) - 1 is the index of the highest set bit of x.
  return T(1) << (bit_width(x) - 1);
}
// Forward declaration of rotl (rotate x left by s bits).
// rotl and rotr are declared ahead of their definitions because each one
// calls the other when given a negative shift count.
template <class T>
CUTE_HOST_DEVICE constexpr T rotl(T x, int s);
// Forward declaration of rotr (rotate x right by s bits).
// Needed because the definition of rotl calls rotr for negative shift counts
// before rotr itself has been defined.
template <class T>
CUTE_HOST_DEVICE constexpr T rotr(T x, int s);
// Rotates the bits of x to the left by s positions.
// A negative s rotates to the right instead.
// The count is first reduced modulo the number of value bits N of T (as
// std::rotl does), so |s| >= N no longer produces an out-of-range shift and
// negating the count can no longer overflow.
template <class T>
CUTE_HOST_DEVICE constexpr
T
rotl(T x, int s) {
  constexpr int N = numeric_limits<T>::digits;
  int const r = s % N;   // r lies in (-N, N) and has the sign of s
  if (r == 0) {
    return x;
  }
  if (r < 0) {
    return rotr(x, -r);
  }
  // 0 < r < N, so both shift counts are within range.
  return static_cast<T>((x << r) | (x >> (N - r)));
}
// Rotates the bits of x to the right by s positions.
// A negative s rotates to the left instead.
// The count is first reduced modulo the number of value bits N of T (as
// std::rotr does), so |s| >= N no longer produces an out-of-range shift and
// negating the count can no longer overflow.
template <class T>
CUTE_HOST_DEVICE constexpr
T
rotr(T x, int s) {
  constexpr int N = numeric_limits<T>::digits;
  int const r = s % N;   // r lies in (-N, N) and has the sign of s
  if (r == 0) {
    return x;
  }
  if (r < 0) {
    return rotl(x, -r);
  }
  // 0 < r < N, so both shift counts are within range.
  return static_cast<T>((x >> r) | (x << (N - r)));
}
// Number of consecutive zero bits in x, starting from the most significant bit.
// Equals the number of value bits of T when x == 0.
template <class T>
CUTE_HOST_DEVICE constexpr
int
countl_zero(T x) {
  constexpr int total_bits = numeric_limits<T>::digits;
  int const used_bits = bit_width(x);
  return total_bits - used_bits;
}
// Number of consecutive one bits in x, starting from the most significant bit.
// Implemented as the leading-zero count of the complement. The complement is
// cast back to T because `~x` is promoted to (signed) int for types narrower
// than int, which would make countl_zero operate on the wrong type.
template <class T>
CUTE_HOST_DEVICE constexpr
int
countl_one(T x) {
  return countl_zero(T(~x));
}
// Number of consecutive zero bits in x, starting from the least significant bit.
// Equals the number of value bits of T when x == 0.
template <class T>
CUTE_HOST_DEVICE constexpr
int
countr_zero(T x) {
  if (x == 0) {
    return numeric_limits<T>::digits;
  }
  // x & -x isolates the lowest set bit; its bit width minus one is its index.
  return bit_width(T(x & T(-x))) - 1;
}
// Number of consecutive one bits in x, starting from the least significant bit.
// Implemented as the trailing-zero count of the complement. The complement is
// cast back to T because `~x` is promoted to (signed) int for types narrower
// than int, which would make countr_zero operate on the wrong type.
template <class T>
CUTE_HOST_DEVICE constexpr
int
countr_one(T x) {
  return countr_zero(T(~x));
}
// Number of one bits in x.
// Each iteration clears the lowest set bit, so the loop runs once per set bit.
template <class T>
CUTE_HOST_DEVICE constexpr
int
popcount(T x) {
  int ones = 0;
  for (; x; x &= x - 1) {
    ++ones;
  }
  return ones;
}
// Shifts x left by s bits; a negative s shifts right by -s bits instead.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
shiftl(T x, int s) {
  // A single conditional expression keeps the deduced return type unchanged.
  return (s < 0) ? (x >> -s)
                 : (x << s);
}
// Shifts x right by s bits; a negative s shifts left by -s bits instead.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
shiftr(T x, int s) {
  // A single conditional expression keeps the deduced return type unchanged.
  return (s < 0) ? (x << -s)
                 : (x >> s);
}
// Integer division t / u.
// NOTE(review): despite the name, no divisibility or zero-divisor check is
// performed here; presumably callers guarantee that u divides t -- confirm.
template <class T, class U,
          __CUTE_REQUIRES(is_std_integral<T>::value &&
                          is_std_integral<U>::value)>
CUTE_HOST_DEVICE constexpr
auto
safe_div(T const& t, U const& u) {
  auto const quotient = t / u;
  return quotient;
}
// Base-2 logarithm of x rounded down, computed as bit_width(x) - 1.
// Only unsigned types are accepted; x must be positive (checked with assert).
template <class T>
CUTE_HOST_DEVICE constexpr
int32_t
log_2(T x) {
  static_assert(is_unsigned<T>::value, "Only to be used for unsigned integral types.");
  assert(x > 0);
  int32_t const width = static_cast<int32_t>(bit_width(x));
  return width - 1;
}
// Result pair returned by divmod(): the quotient and the remainder.
// The two members may have different types.
template <class IntDiv, class IntMod>
struct DivModReturnType {
  IntDiv div_;   // quotient
  IntMod mod_;   // remainder

  // Stores copies of the given quotient and remainder.
  CUTE_HOST_DEVICE constexpr
  DivModReturnType(IntDiv const& quotient, IntMod const& remainder)
      : div_(quotient),
        mod_(remainder)
  {}
};
// Computes a / b and a % b and returns both in a DivModReturnType.
// The member types are deduced from the types of the two expressions.
template <class CInt0, class CInt1>
CUTE_HOST_DEVICE constexpr
auto
divmod(CInt0 const& a, CInt1 const& b) {
  return DivModReturnType{
    a / b,    // div_
    a % b     // mod_
  };
}
// divmod overload taking a cutlass::FastDivmod as the divisor.
// Both outputs start at zero and are passed, together with a, to b's call
// operator; the values left in them are returned as a DivModReturnType.
// NOTE(review): presumably FastDivmod::operator() writes the quotient and the
// remainder of a by its stored divisor -- confirm against cutlass/fast_math.h.
template <class CInt>
CUTE_HOST_DEVICE constexpr
auto
divmod(CInt const& a, cutlass::FastDivmod const& b) {
  typename cutlass::FastDivmod::value_div_type quotient  = 0;
  typename cutlass::FastDivmod::value_mod_type remainder = 0;
  b(quotient, remainder, a);
  return DivModReturnType{quotient, remainder};
}
}