#pragma once
#include <cute/config.hpp>
#include <cute/container/tuple.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/algorithm/functional.hpp>
#include <cute/algorithm/tuple_algorithms.hpp>
#include <cute/util/type_traits.hpp>
namespace cute
{
// ArithmeticTuple<T...>: a cute::tuple that additionally participates in the
// element-wise arithmetic operators defined in this header (+, binary -, unary -).
// It adds no state of its own; all storage lives in the public tuple<T...> base.
template <class... T>
struct ArithmeticTuple : tuple<T...>
{
  // Converting construction from another ArithmeticTuple: view the source as
  // its tuple<U...> base and forward that to our base.
  template <class... U>
  CUTE_HOST_DEVICE constexpr
  ArithmeticTuple(ArithmeticTuple<U...> const& other)
      : tuple<T...>(static_cast<tuple<U...> const&>(other)) {}

  // Construction from a plain cute::tuple.
  template <class... U>
  CUTE_HOST_DEVICE constexpr
  ArithmeticTuple(tuple<U...> const& src)
      : tuple<T...>(src) {}

  // Element-wise construction: each argument initializes one mode.
  template <class... U>
  CUTE_HOST_DEVICE constexpr
  ArithmeticTuple(U const&... elems)
      : tuple<T...>(elems...) {}
};
// ArithmeticTuple is a tuple for CuTe's tuple-detection trait.
template <class... Elems>
struct is_tuple<ArithmeticTuple<Elems...>> : true_type {};

// Flatness of an ArithmeticTuple is exactly that of the equivalent cute::tuple.
template <class... Elems>
struct is_flat<ArithmeticTuple<Elems...>> : is_flat<tuple<Elems...>> {};
// Build an ArithmeticTuple whose element types are exactly the argument types.
template <class... T>
CUTE_HOST_DEVICE constexpr
auto
make_arithmetic_tuple(T const&... t) {
  using result_type = ArithmeticTuple<T...>;
  return result_type(t...);
}
// Convert t into arithmetic-tuple form.
//  - Non-tuple values are returned unchanged (by value).
//  - Tuples are rebuilt as ArithmeticTuples, converting each element recursively
//    through as_arithmetic_tuple (so nested tuples are converted too).
template <class T>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(T const& t) {
  if constexpr (!is_tuple<T>::value) {
    return t;
  } else {
    // Per-element conversion (recursive) and re-packing into an ArithmeticTuple.
    auto convert = [](auto const& x){ return as_arithmetic_tuple(x); };
    auto repack  = [](auto const&... a){ return make_arithmetic_tuple(a...); };
    return detail::tapply(t, convert, repack, tuple_seq<T>{});
  }
}
// Mode-wise sum of two ArithmeticTuples. Both operands are first extended to
// R = max(rank(t), rank(u)) modes via append<R>(..., Int<0>{}), so the shorter
// operand contributes static zeros in its missing modes.
template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ArithmeticTuple<T...> const& t, ArithmeticTuple<U...> const& u) {
  constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U)));
  auto const repack = [](auto const&... a){ return make_arithmetic_tuple(a...); };
  return transform_apply(append<R>(t, Int<0>{}), append<R>(u, Int<0>{}), plus{}, repack);
}

// ArithmeticTuple + tuple: promote the plain tuple, then add.
template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ArithmeticTuple<T...> const& t, tuple<U...> const& u) {
  ArithmeticTuple<U...> const rhs(u);
  return t + rhs;
}

// tuple + ArithmeticTuple: promote the plain tuple, then add.
template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(tuple<T...> const& t, ArithmeticTuple<U...> const& u) {
  ArithmeticTuple<T...> const lhs(t);
  return lhs + u;
}
// Mode-wise difference of two ArithmeticTuples. As with operator+, both operands
// are extended to R = max(rank(t), rank(u)) modes with static zeros first.
template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator-(ArithmeticTuple<T...> const& t, ArithmeticTuple<U...> const& u) {
  constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U)));
  auto const repack = [](auto const&... a){ return make_arithmetic_tuple(a...); };
  return transform_apply(append<R>(t, Int<0>{}), append<R>(u, Int<0>{}), minus{}, repack);
}

// ArithmeticTuple - tuple: promote the plain tuple, then subtract.
template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator-(ArithmeticTuple<T...> const& t, tuple<U...> const& u) {
  ArithmeticTuple<U...> const rhs(u);
  return t - rhs;
}

// tuple - ArithmeticTuple: promote the plain tuple, then subtract.
template <class... T, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator-(tuple<T...> const& t, ArithmeticTuple<U...> const& u) {
  ArithmeticTuple<T...> const lhs(t);
  return lhs - u;
}

// Unary minus: negate every mode; the result is a new ArithmeticTuple.
template <class... T>
CUTE_HOST_DEVICE constexpr
auto
operator-(ArithmeticTuple<T...> const& t) {
  auto const repack = [](auto const&... a){ return make_arithmetic_tuple(a...); };
  return transform_apply(t, negate{}, repack);
}
// Identity overloads with a static integer: C<0> is the additive identity for an
// ArithmeticTuple. Any other static value is rejected at compile time.
//
// NOTE: the overloads that return an operand unchanged return a reference to
// that argument; binding the result to a reference that outlives a temporary
// argument would dangle.

// 0 + u  ->  u
template <auto t, class... U>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple<U...> const&
operator+(C<t>, ArithmeticTuple<U...> const& u) {
  static_assert(t == 0, "Arithmetic tuple op+ error!");
  return u;
}

// t + 0  ->  t
template <class... T, auto u>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple<T...> const&
operator+(ArithmeticTuple<T...> const& t, C<u>) {
  static_assert(u == 0, "Arithmetic tuple op+ error!");
  return t;
}

// 0 - u  ->  -u
// Must return by value: -u is a freshly built tuple (from unary operator-), and
// its element types may differ from U... after negation. Returning it as
// `ArithmeticTuple<U...> const&` would bind a reference to a temporary that dies
// at the end of the return statement, or force a conversion back to U...
template <auto t, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator-(C<t>, ArithmeticTuple<U...> const& u) {
  static_assert(t == 0, "Arithmetic tuple op- error!");
  return -u;
}

// t - 0  ->  t
template <class... T, auto u>
CUTE_HOST_DEVICE constexpr
ArithmeticTuple<T...> const&
operator-(ArithmeticTuple<T...> const& t, C<u>) {
  static_assert(u == 0, "Arithmetic tuple op- error!");
  return t;
}
// Iterator-like wrapper over an arithmetic-tuple coordinate.
//  - operator* returns the current coordinate by value (reference is ArithTuple, not a ref).
//  - operator+ adds an offset to the coordinate and returns a new iterator whose
//    coordinate type is the (decayed) type of that sum.
//  - operator[] is *(it + offset).
template <class ArithTuple>
struct ArithmeticTupleIterator
{
  using value_type   = ArithTuple;
  using element_type = ArithTuple;
  using reference    = ArithTuple;

  ArithTuple coord_;

  CUTE_HOST_DEVICE constexpr
  ArithmeticTupleIterator(ArithTuple const& coord = {}) : coord_(coord) {}

  CUTE_HOST_DEVICE constexpr
  ArithTuple operator*() const { return coord_; }

  template <class Coord>
  CUTE_HOST_DEVICE constexpr
  auto operator[](Coord const& offset) const {
    auto const shifted = *this + offset;
    return *shifted;
  }

  template <class Coord>
  CUTE_HOST_DEVICE constexpr
  auto operator+(Coord const& offset) const {
    // `auto` decays the sum, matching remove_cvref_t of the sum's type.
    auto sum = coord_ + offset;
    return ArithmeticTupleIterator<decltype(sum)>(sum);
  }
};
// Build an ArithmeticTupleIterator whose coordinate is as_arithmetic_tuple(t).
template <class Tuple>
CUTE_HOST_DEVICE constexpr
auto
make_inttuple_iter(Tuple const& t) {
  auto coord = as_arithmetic_tuple(t);
  return ArithmeticTupleIterator<decltype(coord)>(coord);
}

// Variadic convenience form: make_inttuple_iter(a, b, ...) packs its arguments
// with cute::make_tuple and defers to the single-argument overload. It requires
// at least two arguments so it never competes with that overload.
template <class T0, class T1, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
make_inttuple_iter(T0 const& t0, T1 const& t1, Ts const&... ts) {
  auto const packed = cute::make_tuple(t0, t1, ts...);
  return make_inttuple_iter(packed);
}
// ScaledBasis<T,N>: a scale value of type T attached to the static mode index N.
// The scale is held in a private tuple<T> base (presumably so empty scales such
// as static integers take no storage — TODO confirm) and exposed via value().
template <class T, int N>
struct ScaledBasis : private tuple<T>
{
  CUTE_HOST_DEVICE constexpr
  ScaledBasis(T const& scale = {}) : tuple<T>(scale) {}

  // Access the stored scale; decltype(auto) preserves whatever get<0> returns.
  CUTE_HOST_DEVICE constexpr
  decltype(auto) value()       { return get<0>(static_cast<tuple<T>      &>(*this)); }
  CUTE_HOST_DEVICE constexpr
  decltype(auto) value() const { return get<0>(static_cast<tuple<T> const&>(*this)); }

  // The mode index as a static integer.
  CUTE_HOST_DEVICE static constexpr
  auto mode() { return Int<N>{}; }
};
// is_scaled_basis<X>: true exactly for ScaledBasis specializations.
template <class X>
struct is_scaled_basis : false_type {};
template <class X, int Mode>
struct is_scaled_basis<ScaledBasis<X,Mode>> : true_type {};

// CuTe's is_integral trait treats every ScaledBasis as integral.
template <class X, int Mode>
struct is_integral<ScaledBasis<X,Mode>> : true_type {};
// Peel every ScaledBasis layer off e and return the innermost scale (by value).
// A non-ScaledBasis argument is returned unchanged.
template <class SB>
CUTE_HOST_DEVICE constexpr auto
basis_value(SB const& e)
{
  if constexpr (!is_scaled_basis<SB>::value) {
    return e;
  } else {
    return basis_value(e.value());
  }
  CUTE_GCC_UNREACHABLE;
}
// Index into t along the path of modes encoded by e: each ScaledBasis layer
// selects get<mode>() of the current sub-tuple. When e is not a ScaledBasis,
// t itself is returned. decltype(auto) + forwarding preserve the value category
// of t (and of the selected element) all the way down.
template <class SB, class Tuple>
CUTE_HOST_DEVICE decltype(auto)
basis_get(SB const& e, Tuple&& t)
{
  if constexpr (!is_scaled_basis<SB>::value) {
    return static_cast<Tuple&&>(t);
  } else {
    return basis_get(e.value(), get<SB::mode()>(static_cast<Tuple&&>(t)));
  }
  CUTE_GCC_UNREACHABLE;
}
namespace detail {
// Build ArithmeticTuple<Int<0>, ..., Int<0>, T>: one static zero per index in
// Is..., followed by `last`.
template <class T, int... Is>
CUTE_HOST_DEVICE constexpr
auto
to_atuple_i(T const& last, seq<Is...>) {
  // void(Is) only consumes the pack element so it can drive the expansion.
  return make_arithmetic_tuple((void(Is), Int<0>{})..., last);
}
}

// Expand ScaledBasis<T,N> into arithmetic-tuple form: N leading Int<0> modes,
// then the (recursively converted) scale in mode N.
template <class T, int N>
CUTE_HOST_DEVICE constexpr
auto
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
  auto const inner = as_arithmetic_tuple(t.value());
  return detail::to_atuple_i(inner, make_seq<N>{});
}
namespace detail {
// Basis<N0, N1, ..., Nk>::type is
//   ScaledBasis<ScaledBasis<... ScaledBasis<Int<1>, Nk> ..., N1>, N0>
// and Basis<>::type is Int<1>.
template <int... Modes>
struct Basis;

template <>
struct Basis<> {
  using type = Int<1>;
};

template <int Head, int... Tail>
struct Basis<Head, Tail...> {
  using type = ScaledBasis<typename Basis<Tail...>::type, Head>;
};
}

// E<N...>: unit basis element (scale Int<1>) addressing the mode path N...
template <int... Modes>
using E = typename detail::Basis<Modes...>::type;
// Build a basis "shape" congruent to `shape`.
//  - An integral shape maps to Int<1>.
//  - For a tuple shape, mode i is built recursively, and every leaf of the
//    result is wrapped in ScaledBasis<leaf type, i>. The wrapper is
//    default-constructed, not built from the leaf value.
// E.g. (a, (b, c)) becomes (E<0>, (E<1,0>, E<1,1>)).
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
make_basis_like(Shape const& shape)
{
  if constexpr (!is_integral<Shape>::value) {
    return transform(tuple_seq<Shape>{}, shape, [](auto mode_idx, auto sub_shape) {
      // Record the mode index as a type so the inner lambda can use it as a constant.
      using ModeT = decltype(mode_idx);
      return transform_leaf(make_basis_like(sub_shape), [](auto leaf) {
        constexpr int m = ModeT::value;
        return ScaledBasis<decltype(leaf), m>{};
      });
    });
  } else {
    return Int<1>{};
  }
  CUTE_GCC_UNREACHABLE;
}
// safe_div / ceil_div / abs lifted through ScaledBasis: the operation is
// applied to the scale and the mode index is preserved.

template <class T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
safe_div(ScaledBasis<T,M> const& b, U const& u)
{
  auto quotient = safe_div(b.value(), u);
  return ScaledBasis<decltype(quotient),M>{quotient};
}

template <class T, int M, class U>
CUTE_HOST_DEVICE constexpr
auto
ceil_div(ScaledBasis<T,M> const& b, U const& u)
{
  auto quotient = ceil_div(b.value(), u);
  return ScaledBasis<decltype(quotient),M>{quotient};
}

template <class T, int N>
CUTE_HOST_DEVICE constexpr
auto
abs(ScaledBasis<T,N> const& e)
{
  auto magnitude = abs(e.value());
  return ScaledBasis<decltype(magnitude),N>{magnitude};
}
// Two ScaledBasis are equal when their modes match (checked statically via
// bool_constant<M == N>) and their scales compare equal.
template <class T, int N, class U, int M>
CUTE_HOST_DEVICE constexpr
auto
operator==(ScaledBasis<T,N> const& lhs, ScaledBasis<U,M> const& rhs) {
  return bool_constant<M == N>{} && lhs.value() == rhs.value();
}

// A ScaledBasis never equals a non-ScaledBasis value: statically false.
template <class T, int N, class U>
CUTE_HOST_DEVICE constexpr
false_type
operator==(ScaledBasis<T,N> const&, U const&) {
  return {};
}

// Mirror of the above with the operands swapped.
template <class T, class U, int M>
CUTE_HOST_DEVICE constexpr
false_type
operator==(T const&, ScaledBasis<U,M> const&) {
  return {};
}
// Scaling a ScaledBasis multiplies its scale and keeps its mode. Operand order
// is preserved in the underlying product (a * value, value * b).

template <class A, class T, int N>
CUTE_HOST_DEVICE constexpr
auto
operator*(A const& a, ScaledBasis<T,N> const& e) {
  auto product = a * e.value();
  return ScaledBasis<decltype(product),N>{product};
}

template <class T, int N, class B>
CUTE_HOST_DEVICE constexpr
auto
operator*(ScaledBasis<T,N> const& e, B const& b) {
  auto product = e.value() * b;
  return ScaledBasis<decltype(product),N>{product};
}
// Adding ScaledBasis values: each ScaledBasis operand is expanded with
// as_arithmetic_tuple, and the ArithmeticTuple operator+ does the mode-wise sum.

template <class T, int N, class U, int M>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
  auto const lhs = as_arithmetic_tuple(t);
  auto const rhs = as_arithmetic_tuple(u);
  return lhs + rhs;
}

template <class T, int N, class... U>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& t, ArithmeticTuple<U...> const& u) {
  auto const lhs = as_arithmetic_tuple(t);
  return lhs + u;
}

template <class... T, class U, int M>
CUTE_HOST_DEVICE constexpr
auto
operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,M> const& u) {
  auto const rhs = as_arithmetic_tuple(u);
  return t + rhs;
}

// C<0> is the additive identity for ScaledBasis (returned by value);
// any other static value is a compile-time error.
template <auto z, class U, int M>
CUTE_HOST_DEVICE constexpr
auto
operator+(C<z>, ScaledBasis<U,M> const& sb) {
  static_assert(z == 0, "ScaledBasis op+ error!");
  return sb;
}

template <class T, int N, auto z>
CUTE_HOST_DEVICE constexpr
auto
operator+(ScaledBasis<T,N> const& sb, C<z>) {
  static_assert(z == 0, "ScaledBasis op+ error!");
  return sb;
}
// Print an ArithmeticTupleIterator as the literal "ArithTuple" followed by its coordinate.
template <class ArithTuple>
CUTE_HOST_DEVICE void print(ArithmeticTupleIterator<ArithTuple> const& iter)
{
  printf("ArithTuple");
  print(iter.coord_);
}

// Print a ScaledBasis as its scale followed by "@<mode>".
template <class T, int N>
CUTE_HOST_DEVICE void print(ScaledBasis<T,N> const& e)
{
  print(e.value());
  printf("@%d", N);
}
#if !defined(__CUDACC_RTC__)
// Stream form of print(ArithmeticTupleIterator): "ArithTuple" then the coordinate.
template <class ArithTuple>
CUTE_HOST std::ostream& operator<<(std::ostream& os, ArithmeticTupleIterator<ArithTuple> const& iter)
{
  os << "ArithTuple";
  return os << iter.coord_;
}

// Stream form of print(ScaledBasis): scale, then "@", then the mode index.
template <class T, int N>
CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis<T,N> const& e)
{
  os << e.value();
  return os << "@" << N;
}
#endif
}
namespace CUTE_STL_NAMESPACE
{
template <class... T>
struct tuple_size<cute::ArithmeticTuple<T...>>
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};
template <size_t I, class... T>
struct tuple_element<I, cute::ArithmeticTuple<T...>>
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
{};
}
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
// When CUTE_STL_NAMESPACE is cuda::std, additionally specialize the std::
// tuple-protocol traits so code that goes through std::tuple_size /
// std::tuple_element also recognizes ArithmeticTuple.
namespace std
{
#if defined(__CUDACC_RTC__)
// Under NVRTC the std primary templates are presumably not otherwise declared;
// forward-declare them so the specializations below are well-formed.
template <class... _Tp>
struct tuple_size;

template <size_t _Ip, class... _Tp>
struct tuple_element;
#endif

template <class... T>
struct tuple_size<cute::ArithmeticTuple<T...>>
    : CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
{};

template <size_t I, class... T>
struct tuple_element<I, cute::ArithmeticTuple<T...>>
    : CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
{};
} // namespace std
// The #endif must start its own line: a directive is only recognized when '#'
// is the first token on the line, so `} #endif` left this #ifdef unterminated.
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD