#pragma once
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/numeric/integer_sequence.hpp>
#include <cute/container/cuda_types.hpp>
#include <cute/container/type_list.hpp>
namespace cute
{
namespace detail
{
template <bool IsFirstEmpty, bool IsRestEmpty, class... T>
struct ESO;
template <class First, class... Rest>
static constexpr bool is_first_empty_v = cute::is_empty<First>::value;
template <class First, class... Rest>
static constexpr bool is_rest_empty_v = (cute::is_empty<Rest>::value && ...);
template <class... T>
using ESO_t = ESO<is_first_empty_v<T...>, is_rest_empty_v<T...>, T...>;
template <class First, class... Rest>
struct ESO<true, true, First, Rest...> {
CUTE_HOST_DEVICE constexpr
ESO() {}
CUTE_HOST_DEVICE constexpr
ESO(First const&, Rest const&...) {}
};
template <class First, class... Rest>
struct ESO<false, true, First, Rest...> {
CUTE_HOST_DEVICE constexpr
ESO() : first_{} {}
CUTE_HOST_DEVICE constexpr
ESO(First const& first, Rest const&...) : first_{first} {}
First first_;
};
template <class First, class... Rest>
struct ESO<true, false, First, Rest...> {
CUTE_HOST_DEVICE constexpr
ESO() : rest_{} {}
CUTE_HOST_DEVICE constexpr
ESO(First const&, Rest const&... rest) : rest_{rest...} {}
ESO_t<Rest...> rest_;
};
template <class First, class... Rest>
struct ESO<false, false, First, Rest...> {
CUTE_HOST_DEVICE constexpr
ESO() : first_{}, rest_{} {}
CUTE_HOST_DEVICE constexpr
ESO(First const& first, Rest const&... rest) : first_{first}, rest_{rest...} {}
First first_;
ESO_t<Rest...> rest_;
};
// getv<N>(eso): fetch element N of an ESO.
//
// Empty element: nothing is stored for it, so a `{}`-initialized object is
// returned by value. Because this overload takes `const&` and the
// reference-returning overloads below are disabled for empty elements, it
// serves const, non-const and rvalue ESO arguments alike.
template <size_t N, bool FirstEmpty, bool RestEmpty, class Head, class... Tail>
CUTE_HOST_DEVICE constexpr
cute::enable_if_t<cute::is_empty<cute::tuple_element_t<N, cute::type_list<Head, Tail...>>>::value,
                  cute::tuple_element_t<N, cute::type_list<Head, Tail...>>>
getv(ESO<FirstEmpty, RestEmpty, Head, Tail...> const&)
{
  return {};
}

// Non-empty element, const lvalue access: const reference into the storage.
template <size_t N, bool FirstEmpty, bool RestEmpty, class Head, class... Tail>
CUTE_HOST_DEVICE constexpr
cute::enable_if_t<not cute::is_empty<cute::tuple_element_t<N, cute::type_list<Head, Tail...>>>::value,
                  cute::tuple_element_t<N, cute::type_list<Head, Tail...>> const&>
getv(ESO<FirstEmpty, RestEmpty, Head, Tail...> const& eso)
{
  if constexpr (N != 0) {
    // Element N is in the tail, which is stored because it holds a non-empty element.
    return getv<N-1>(eso.rest_);
  } else {
    return static_cast<Head const&>(eso.first_);
  }
}

// Non-empty element, mutable lvalue access: reference into the storage.
template <size_t N, bool FirstEmpty, bool RestEmpty, class Head, class... Tail>
CUTE_HOST_DEVICE constexpr
cute::enable_if_t<not cute::is_empty<cute::tuple_element_t<N, cute::type_list<Head, Tail...>>>::value,
                  cute::tuple_element_t<N, cute::type_list<Head, Tail...>> &>
getv(ESO<FirstEmpty, RestEmpty, Head, Tail...>& eso)
{
  if constexpr (N != 0) {
    return getv<N-1>(eso.rest_);
  } else {
    return static_cast<Head&>(eso.first_);
  }
}

// Non-empty element, rvalue access: rvalue reference into the storage; the
// rvalue-ness is propagated down the recursion.
template <size_t N, bool FirstEmpty, bool RestEmpty, class Head, class... Tail>
CUTE_HOST_DEVICE constexpr
cute::enable_if_t<not cute::is_empty<cute::tuple_element_t<N, cute::type_list<Head, Tail...>>>::value,
                  cute::tuple_element_t<N, cute::type_list<Head, Tail...>> &&>
getv(ESO<FirstEmpty, RestEmpty, Head, Tail...>&& eso)
{
  if constexpr (N != 0) {
    return getv<N-1>(static_cast<ESO_t<Tail...>&&>(eso.rest_));
  } else {
    return static_cast<Head&&>(eso.first_);
  }
}
// findt<X, N>(eso): position of the first element whose type is exactly X,
// counting the ESO's first element as index N. Returns C<index> on a match.
// If X does not occur, returns C<N + element count> (one past the last element).
template <class X, size_t N,
          bool IsFirstEmpty, bool IsRestEmpty, class First, class... Rest>
CUTE_HOST_DEVICE constexpr
auto
findt(ESO<IsFirstEmpty, IsRestEmpty, First, Rest...> const& eso) noexcept
{
  if constexpr (cute::is_same_v<X, First>) {
    return C<N>{};
  } else if constexpr (sizeof...(Rest) > 0) {
    if constexpr (IsRestEmpty) {
      // No `rest_` member exists; a fresh empty tail has the same type.
      return cute::detail::findt<X, N+1>(ESO_t<Rest...>{});
    } else {
      return cute::detail::findt<X, N+1>(eso.rest_);
    }
  } else {
    // Last element and no match: report one past the end.
    return C<N+1>{};
  }
}
}
// cute::tuple<T...>: all element storage lives in the ESO base, so element
// types that are empty occupy no data members.
template <class... T>
struct tuple : detail::ESO_t<T...>
{
  // Default-constructs the ESO base, which brace-initializes each stored element.
  CUTE_HOST_DEVICE constexpr
  tuple() {}

  // Copies each argument into the corresponding element.
  CUTE_HOST_DEVICE constexpr
  tuple(T const&... elems) : detail::ESO_t<T...>(elems...) {}
};

// The empty tuple: no ESO base and no members.
template <>
struct tuple<> {};
// get<I>(tuple): element access in the style of std::get. Stored (non-empty)
// elements are returned by reference carrying the tuple's value category
// (const&, &, &&); empty elements are returned by value.
template <size_t I, class... T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
get(tuple<T...> const& tup) noexcept
{
  static_assert(I < sizeof...(T), "Index out of range");
  return detail::getv<I>(tup);
}

template <size_t I, class... T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
get(tuple<T...>& tup) noexcept
{
  static_assert(I < sizeof...(T), "Index out of range");
  return detail::getv<I>(tup);
}

template <size_t I, class... T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
get(tuple<T...>&& tup) noexcept
{
  static_assert(I < sizeof...(T), "Index out of range");
  // Hand the ESO base over as an rvalue so the && getv overload is selected.
  using Storage = detail::ESO_t<T...>;
  return detail::getv<I>(static_cast<Storage&&>(tup));
}
// find<X>(tuple): compile-time index of the first element whose type is
// exactly X, as C<index>. If X does not occur, returns C<sizeof...(T)>
// (one past the last element). This includes the empty tuple, which yields C<0>.
template <class X, class... T>
CUTE_HOST_DEVICE constexpr
auto
find(tuple<T...> const& t) noexcept
{
  if constexpr (sizeof...(T) == 0) {
    // tuple<> has no ESO base, so detail::findt cannot be deduced for it.
    return C<0>{};
  } else {
    return detail::findt<X, 0>(t);
  }
  CUTE_GCC_UNREACHABLE;
}
namespace detail {
// SFINAE probe: the pointer overload is viable only when tuple_size<T>::value
// is well-formed; otherwise the variadic fallback is chosen.
template <class T>
auto has_tuple_size( T*) -> bool_constant<(0 <= tuple_size<T>::value)>;
auto has_tuple_size(...) -> false_type;
}

// is_tuple<T>::value is true exactly when tuple_size<T>::value is well-formed.
template <class T>
struct is_tuple : decltype(detail::has_tuple_size(static_cast<T*>(nullptr))) {};

template <class T>
constexpr bool is_tuple_v = cute::is_tuple<T>::value;
// make_tuple(args...): a cute::tuple holding copies of args, with element
// types deduced from the arguments.
template <class... T>
CUTE_HOST_DEVICE constexpr
tuple<T...>
make_tuple(T const&... t)
{
  return tuple<T...>{t...};
}
#if 0
#endif
#if 1
namespace detail {
// Fixed-arity helpers behind cute::tuple_cat. Each takes the operands plus one
// index_sequence per operand covering that operand's elements, and forwards
// every element, operand by operand, to cute::make_tuple.
template <class T0, class T1,
          size_t... I0, size_t... I1>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& x0, T1 const& x1,
          index_sequence<I0...>, index_sequence<I1...>)
{
  return cute::make_tuple(get<I0>(x0)...,
                          get<I1>(x1)...);
}

template <class T0, class T1, class T2,
          size_t... I0, size_t... I1, size_t... I2>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& x0, T1 const& x1, T2 const& x2,
          index_sequence<I0...>, index_sequence<I1...>, index_sequence<I2...>)
{
  return cute::make_tuple(get<I0>(x0)...,
                          get<I1>(x1)...,
                          get<I2>(x2)...);
}

template <class T0, class T1, class T2, class T3,
          size_t... I0, size_t... I1, size_t... I2, size_t... I3>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& x0, T1 const& x1, T2 const& x2, T3 const& x3,
          index_sequence<I0...>, index_sequence<I1...>, index_sequence<I2...>, index_sequence<I3...>)
{
  return cute::make_tuple(get<I0>(x0)...,
                          get<I1>(x1)...,
                          get<I2>(x2)...,
                          get<I3>(x3)...);
}

template <class T0, class T1, class T2, class T3, class T4,
          size_t... I0, size_t... I1, size_t... I2, size_t... I3, size_t... I4>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& x0, T1 const& x1, T2 const& x2, T3 const& x3, T4 const& x4,
          index_sequence<I0...>, index_sequence<I1...>, index_sequence<I2...>, index_sequence<I3...>, index_sequence<I4...>)
{
  return cute::make_tuple(get<I0>(x0)...,
                          get<I1>(x1)...,
                          get<I2>(x2)...,
                          get<I3>(x3)...,
                          get<I4>(x4)...);
}

// tuple_cat_static<A, B>::type is the cute::tuple whose element types are A's
// followed by B's. It is only defined when both A and B are cute::tuple.
template <class T0, class T1>
struct tuple_cat_static;

template <class... T0s, class... T1s>
struct tuple_cat_static<tuple<T0s...>, tuple<T1s...>> {
  using type = tuple<T0s..., T1s...>;
};
}
// tuple_cat(): concatenating nothing yields the empty tuple.
CUTE_HOST_DEVICE constexpr
tuple<>
tuple_cat()
{
  return tuple<>{};
}

// tuple_cat(t): a single tuple is returned unchanged, by reference.
template <class Tuple,
          __CUTE_REQUIRES(is_tuple<Tuple>::value)>
CUTE_HOST_DEVICE constexpr
Tuple const&
tuple_cat(Tuple const& t)
{
  return t;
}

// tuple_cat(t0, t1): if both operands are static tuples, the result type is
// computed via detail::tuple_cat_static and default-constructed. Otherwise the
// elements are gathered by index.
template <class T0, class T1>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1)
{
  constexpr bool all_static_tuples = is_static<T0>::value && is_static<T1>::value &&
                                     is_tuple<T0>::value && is_tuple<T1>::value;
  if constexpr (all_static_tuples) {
    return typename detail::tuple_cat_static<T0, T1>::type{};
  } else {
    using Seq0 = make_index_sequence<tuple_size<T0>::value>;
    using Seq1 = make_index_sequence<tuple_size<T1>::value>;
    return detail::tuple_cat(t0, t1, Seq0{}, Seq1{});
  }
  CUTE_GCC_UNREACHABLE;
}

// tuple_cat for three operands: gather all elements by index.
template <class T0, class T1, class T2>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2)
{
  using Seq0 = make_index_sequence<tuple_size<T0>::value>;
  using Seq1 = make_index_sequence<tuple_size<T1>::value>;
  using Seq2 = make_index_sequence<tuple_size<T2>::value>;
  return detail::tuple_cat(t0, t1, t2, Seq0{}, Seq1{}, Seq2{});
}

// tuple_cat for four operands: gather all elements by index.
template <class T0, class T1, class T2, class T3>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3)
{
  using Seq0 = make_index_sequence<tuple_size<T0>::value>;
  using Seq1 = make_index_sequence<tuple_size<T1>::value>;
  using Seq2 = make_index_sequence<tuple_size<T2>::value>;
  using Seq3 = make_index_sequence<tuple_size<T3>::value>;
  return detail::tuple_cat(t0, t1, t2, t3, Seq0{}, Seq1{}, Seq2{}, Seq3{});
}

// tuple_cat for five operands: gather all elements by index.
template <class T0, class T1, class T2, class T3, class T4>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4)
{
  using Seq0 = make_index_sequence<tuple_size<T0>::value>;
  using Seq1 = make_index_sequence<tuple_size<T1>::value>;
  using Seq2 = make_index_sequence<tuple_size<T2>::value>;
  using Seq3 = make_index_sequence<tuple_size<T3>::value>;
  using Seq4 = make_index_sequence<tuple_size<T4>::value>;
  return detail::tuple_cat(t0, t1, t2, t3, t4, Seq0{}, Seq1{}, Seq2{}, Seq3{}, Seq4{});
}

// Six or more operands: concatenate the first five and the remainder
// (recursively), then join the two partial results.
template <class T0, class T1, class T2, class T3, class T4, class T5, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, T5 const& t5, Ts const&... ts)
{
  return cute::tuple_cat(cute::tuple_cat(t0, t1, t2, t3, t4),
                         cute::tuple_cat(t5, ts...));
}
#endif
#if 0
#endif
namespace detail {
// Element-wise equality of two tuple-likes over the indices I...: left-folds
// `get<I>(lhs) == get<I>(rhs)` with &&, seeded with cute::true_type{} so an
// empty index set yields true_type. Presumably the seed also lets all-static
// comparisons stay static through cute's && overloads; confirm in integral_constant.hpp.
template <class TupleA, class TupleB, size_t... I>
CUTE_HOST_DEVICE constexpr
auto
equal_impl(TupleA const& lhs, TupleB const& rhs, index_sequence<I...>)
{
  return (cute::true_type{} && ... && (get<I>(lhs) == get<I>(rhs)));
}
}
// Tuple == tuple. Tuples of different sizes compare unequal, as the static
// cute::false_type. Otherwise the elements are compared pairwise.
template <class TupleT, class TupleU,
          __CUTE_REQUIRES(is_tuple<TupleT>::value && is_tuple<TupleU>::value)>
CUTE_HOST_DEVICE constexpr
auto
operator==(TupleT const& t, TupleU const& u)
{
  constexpr auto len_t = tuple_size<TupleT>::value;
  constexpr auto len_u = tuple_size<TupleU>::value;
  if constexpr (len_t != len_u) {
    return cute::false_type{};
  } else {
    return detail::equal_impl(t, u, make_index_sequence<len_t>{});
  }
  CUTE_GCC_UNREACHABLE;
}

// Tuple == non-tuple (in either order): never equal.
template <class TupleT, class TupleU,
          __CUTE_REQUIRES(is_tuple<TupleT>::value ^ is_tuple<TupleU>::value)>
CUTE_HOST_DEVICE constexpr
auto
operator==(TupleT const&, TupleU const&)
{
  return cute::false_type{};
}
// Tuple != tuple: the negation of operator==.
template <class TupleT, class TupleU,
          __CUTE_REQUIRES(is_tuple<TupleT>::value && is_tuple<TupleU>::value)>
CUTE_HOST_DEVICE constexpr
auto
operator!=(TupleT const& t, TupleU const& u)
{
  return !(t == u);
}

// Tuple != non-tuple (in either order): always unequal.
template <class TupleT, class TupleU,
          __CUTE_REQUIRES(is_tuple<TupleT>::value ^ is_tuple<TupleU>::value)>
CUTE_HOST_DEVICE constexpr
auto
operator!=(TupleT const&, TupleU const&)
{
  return cute::true_type{};
}
namespace detail {
// Prints the elements of t at indices Is... as `s e0,e1,... e`, which is
// "(e0,e1,...)" with the default delimiters. An empty index set prints just `s e`.
// Elements are printed with cute::print or an ADL-found print.
template <class Tuple, size_t... Is>
CUTE_HOST_DEVICE void print_tuple(Tuple const& t, index_sequence<Is...>, char s = '(', char e = ')')
{
  using cute::print;
  print(s);
  // Every element except the first is preceded by a comma.
  (((Is == 0 ? void() : void(print(','))),
    void(print(get<Is>(t)))), ...);
  print(e);
}

#if !defined(__CUDACC_RTC__)
// Host-only ostream counterpart of print_tuple, with the same output format.
template <class Tuple, std::size_t... Is>
CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, index_sequence<Is...>, char s = '(', char e = ')')
{
  os << s;
  (void((Is == 0 ? os : os << ',') << get<Is>(t)), ...);
  return os << e;
}
#endif
}
// print(tuple): prints a tuple-like value as "(e0,e1,...)".
template <class Tuple,
          __CUTE_REQUIRES(is_tuple<Tuple>::value)>
CUTE_HOST_DEVICE void print(Tuple const& t)
{
  constexpr size_t num_elems = tuple_size<Tuple>::value;
  detail::print_tuple(t, make_index_sequence<num_elems>{});
}
#if !defined(__CUDACC_RTC__)
// Host-only stream insertion for tuple-like values, formatted as "(e0,e1,...)".
template <class Tuple,
          __CUTE_REQUIRES(is_tuple<Tuple>::value)>
CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t)
{
  using Indices = make_index_sequence<tuple_size<Tuple>::value>;
  return detail::print_tuple_os(os, t, Indices{});
}
#endif
}
namespace CUTE_STL_NAMESPACE
{
// Make cute::tuple tuple-like for CUTE_STL_NAMESPACE: its size is the number
// of element types.
template <class... Ts>
struct tuple_size<cute::tuple<Ts...>>
  : CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(Ts)>
{};

// Element types are reported exactly as for CUTE_STL_NAMESPACE::tuple<Ts...>.
template <size_t Idx, class... Ts>
struct tuple_element<Idx, cute::tuple<Ts...>>
  : CUTE_STL_NAMESPACE::tuple_element<Idx, CUTE_STL_NAMESPACE::tuple<Ts...>>
{};
}
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
// When CUTE_STL_NAMESPACE is cuda::std, also specialize the std:: tuple traits
// so that code querying std::tuple_size / std::tuple_element sees cute::tuple
// as tuple-like. The traits forward to the CUTE_STL_NAMESPACE definitions.
namespace std
{
#if defined(__CUDACC_RTC__)
// Under NVRTC the primary templates are declared here so they can be
// specialized below (presumably because no standard <tuple> is available there).
template <class... _Tp>
struct tuple_size;

template <size_t _Ip, class... _Tp>
struct tuple_element;
#endif

template <class... T>
struct tuple_size<cute::tuple<T...>>
  : CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};

template <size_t I, class... T>
struct tuple_element<I, cute::tuple<T...>>
  : CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
{};
} // namespace std
// A directive must begin its own line: "} #endif" is a stray '#' token and
// leaves the #ifdef above unterminated.
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD