#pragma once
#include <cute/config.hpp>
#include <cute/pointer_base.hpp>
#include <cute/pointer_sparse.hpp>
#include <cute/container/array_subbyte.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/numeric/numeric_types.hpp>
namespace cute
{
/// Reinterpret an untyped, mutable address as an iterator over NewT.
///
/// The deduced return type depends on NewT:
///   - is_sparse<NewT>      : whatever make_sparse_ptr<NewT::sparsity>() yields for a NewT*
///   - is_subbyte_v<NewT>   : a subbyte_iterator<NewT> constructed from @p ptr
///   - anything else        : a plain NewT*
template <class NewT>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(void* ptr)
{
  if constexpr (is_sparse<NewT>::value) {
    // Sparse element types carry their sparsity factor as a static member.
    NewT* typed = reinterpret_cast<NewT*>(ptr);
    constexpr int S = NewT::sparsity;
    return make_sparse_ptr<S>(typed);
  }
  else if constexpr (cute::is_subbyte_v<NewT>) {
    return subbyte_iterator<NewT>(ptr);
  }
  else {
    return reinterpret_cast<NewT*>(ptr);
  }

  CUTE_GCC_UNREACHABLE;
}
/// Reinterpret an untyped, read-only address as an iterator over NewT const.
///
/// The deduced return type depends on NewT:
///   - is_sparse<NewT>      : whatever make_sparse_ptr<NewT::sparsity>() yields for a NewT const*
///   - is_subbyte_v<NewT>   : a subbyte_iterator<NewT const> constructed from @p ptr
///   - anything else        : a plain NewT const*
template <class NewT>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(void const* ptr)
{
  if constexpr (is_sparse<NewT>::value) {
    // Sparse element types carry their sparsity factor as a static member.
    NewT const* typed = reinterpret_cast<NewT const*>(ptr);
    constexpr int S = NewT::sparsity;
    return make_sparse_ptr<S>(typed);
  }
  else if constexpr (cute::is_subbyte_v<NewT>) {
    return subbyte_iterator<NewT const>(ptr);
  }
  else {
    return reinterpret_cast<NewT const*>(ptr);
  }

  CUTE_GCC_UNREACHABLE;
}
/// Overload for a literal nullptr argument.
///
/// The null is first given the concrete type NewT* and then forwarded to the
/// pointer-taking recast_ptr overloads, so the result type follows the same
/// rules as for any other NewT* argument.
template <class NewT>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(decltype(nullptr))
{
  NewT* typed_null = nullptr;
  return recast_ptr<NewT>(typed_null);
}
// Tag type that wraps an iterator P via iter_adaptor so that the wrapped
// iterator can be recognised by the is_gmem trait.
// The struct adds no members of its own; all behaviour comes from iter_adaptor.
// NOTE(review): the "gmem" name suggests global memory -- the type itself only
// carries the tag, nothing here enforces a memory space.
template <class P>
struct gmem_ptr : iter_adaptor<P, gmem_ptr<P>> {
// Inherit every constructor of the iter_adaptor base.
using iter_adaptor<P, gmem_ptr<P>>::iter_adaptor;
};
// Trait: is T tagged as a gmem_ptr?
// Primary template: false for any type not matched by a specialization below.
template <class T, class = void>
struct is_gmem : false_type {};
// Direct match: a gmem_ptr<P> is gmem.
template <class P> struct is_gmem<gmem_ptr<P>> : true_type {};
// Any type exposing a nested ::iterator typedef defers to the trait of that iterator type.
template <class P> struct is_gmem<P, void_t<typename P::iterator>> : is_gmem<typename P::iterator> {};
// Convenience variable template for is_gmem<P>::value.
template <class P>
constexpr bool is_gmem_v = is_gmem<P>::value;
/// Tag an iterator as gmem.
///
/// If @p iter is already recognised by is_gmem it is returned unchanged,
/// otherwise it is wrapped in a gmem_ptr<Iterator>.
template <class Iterator>
CUTE_HOST_DEVICE constexpr
auto
make_gmem_ptr(Iterator iter)
{
  if constexpr (not is_gmem<Iterator>::value) {
    return gmem_ptr<Iterator>{iter};
  }
  else {
    // Already tagged: avoid double wrapping.
    return iter;
  }

  CUTE_GCC_UNREACHABLE;
}
/// Build a gmem-tagged iterator over T from an untyped mutable address.
/// The address is first converted with recast_ptr<T>, then tagged.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_gmem_ptr(void* ptr)
{
  auto typed = recast_ptr<T>(ptr);
  return make_gmem_ptr(typed);
}
/// Build a gmem-tagged iterator over T const from an untyped read-only address.
/// The address is first converted with recast_ptr<T const>, then tagged.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_gmem_ptr(void const* ptr)
{
  auto typed = recast_ptr<T const>(ptr);
  return make_gmem_ptr(typed);
}
/// Overload for a literal nullptr argument: produces the gmem-tagged form of
/// recast_ptr<T>(nullptr).
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_gmem_ptr(decltype(nullptr))
{
  auto typed_null = recast_ptr<T>(nullptr);
  return make_gmem_ptr(typed_null);
}
/// Recast a gmem-tagged pointer to a new element type.
/// The wrapped iterator is recast to NewT and the result is tagged as gmem again.
template <class NewT, class P>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(gmem_ptr<P> const& ptr)
{
  auto&& wrapped = ptr.get();
  auto recasted  = recast_ptr<NewT>(wrapped);
  return make_gmem_ptr(recasted);
}
// Tag type that wraps an iterator P via iter_adaptor so that the wrapped
// iterator can be recognised by the is_smem trait.
// The struct adds no members of its own; all behaviour comes from iter_adaptor.
// NOTE(review): the "smem" name suggests shared memory -- the type itself only
// carries the tag, nothing here enforces a memory space.
template <class P>
struct smem_ptr : iter_adaptor<P, smem_ptr<P>> {
// Inherit every constructor of the iter_adaptor base.
using iter_adaptor<P, smem_ptr<P>>::iter_adaptor;
};
// Trait: is T tagged as a smem_ptr?
// Primary template: false for any type not matched by a specialization below.
template <class T, class = void>
struct is_smem : false_type {};
// Direct match: a smem_ptr<P> is smem.
template <class P> struct is_smem<smem_ptr<P>> : true_type {};
// Any type exposing a nested ::iterator typedef defers to the trait of that iterator type.
template <class P> struct is_smem<P, void_t<typename P::iterator>> : is_smem<typename P::iterator> {};
// Convenience variable template for is_smem<P>::value.
template <class P>
constexpr bool is_smem_v = is_smem<P>::value;
/// Tag an iterator as smem.
///
/// If @p iter is already recognised by is_smem it is returned unchanged,
/// otherwise it is wrapped in a smem_ptr<Iterator>.
template <class Iterator>
CUTE_HOST_DEVICE constexpr
auto
make_smem_ptr(Iterator iter)
{
  if constexpr (not is_smem<Iterator>::value) {
    return smem_ptr<Iterator>{iter};
  }
  else {
    // Already tagged: avoid double wrapping.
    return iter;
  }

  CUTE_GCC_UNREACHABLE;
}
// Convenience overload: tag ptr as smem (see the single-argument
// make_smem_ptr) and pass the tagged iterator together with sw to
// make_swizzle_ptr. The return type is whatever make_swizzle_ptr yields.
template <class Iterator, class Swizzle>
CUTE_HOST_DEVICE constexpr
auto
make_smem_ptr(Iterator ptr, Swizzle sw)
{
return make_swizzle_ptr(make_smem_ptr(ptr), sw);
}
/// Build a smem-tagged iterator over T from an untyped mutable address.
/// The address is first converted with recast_ptr<T>, then tagged.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_smem_ptr(void* ptr)
{
  auto typed = recast_ptr<T>(ptr);
  return make_smem_ptr(typed);
}
/// Build a smem-tagged iterator over T const from an untyped read-only address.
/// The address is first converted with recast_ptr<T const>, then tagged.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_smem_ptr(void const* ptr)
{
  auto typed = recast_ptr<T const>(ptr);
  return make_smem_ptr(typed);
}
/// Overload for a literal nullptr argument: produces the smem-tagged form of
/// recast_ptr<T>(nullptr).
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_smem_ptr(decltype(nullptr))
{
  auto typed_null = recast_ptr<T>(nullptr);
  return make_smem_ptr(typed_null);
}
/// Recast a smem-tagged pointer to a new element type.
/// The wrapped iterator is recast to NewT and the result is tagged as smem again.
template <class NewT, class P>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(smem_ptr<P> const& ptr)
{
  auto&& wrapped = ptr.get();
  auto recasted  = recast_ptr<NewT>(wrapped);
  return make_smem_ptr(recasted);
}
// Tag type that wraps an iterator P via iter_adaptor and is always reported
// as rmem by the is_rmem trait.
// The struct adds no members of its own; all behaviour comes from iter_adaptor.
// NOTE(review): the "rmem" name suggests register memory -- the type itself
// only carries the tag, nothing here enforces a memory space.
template <class P>
struct rmem_ptr : iter_adaptor<P, rmem_ptr<P>> {
// Inherit every constructor of the iter_adaptor base.
using iter_adaptor<P, rmem_ptr<P>>::iter_adaptor;
};
// Trait: is T considered rmem?
// Primary template: rmem is the fallback category -- true for every T that is
// neither is_gmem nor is_smem (this includes untagged iterators such as raw pointers).
// NOTE(review): is_tmem is not consulted here, so tmem_ptr<T> (which has no
// nested ::iterator and is not gmem/smem) also satisfies is_rmem -- confirm
// this is intended.
template <class T, class = void>
struct is_rmem : bool_constant<not (is_gmem<T>::value || is_smem<T>::value)> {};
// Direct match: an rmem_ptr<P> is always rmem, regardless of what P is.
template <class P>
struct is_rmem<rmem_ptr<P>> : true_type {};
// Convenience variable template for is_rmem<P>::value.
template <class P>
constexpr bool is_rmem_v = is_rmem<P>::value;
/// Tag an iterator as rmem.
///
/// If is_rmem already holds for @p iter it is returned unchanged, otherwise
/// it is wrapped in an rmem_ptr<Iterator>.
template <class Iterator>
CUTE_HOST_DEVICE constexpr
auto
make_rmem_ptr(Iterator iter)
{
  if constexpr (not is_rmem<Iterator>::value) {
    return rmem_ptr<Iterator>{iter};
  }
  else {
    // Already classified as rmem: hand it back untouched.
    return iter;
  }

  CUTE_GCC_UNREACHABLE;
}
/// Build an rmem iterator over T from an untyped mutable address.
/// The address is first converted with recast_ptr<T>, then passed through
/// the iterator overload of make_rmem_ptr.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_rmem_ptr(void* ptr)
{
  auto typed = recast_ptr<T>(ptr);
  return make_rmem_ptr(typed);
}
/// Build an rmem iterator over T const from an untyped read-only address.
/// The address is first converted with recast_ptr<T const>, then passed
/// through the iterator overload of make_rmem_ptr.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_rmem_ptr(void const* ptr)
{
  auto typed = recast_ptr<T const>(ptr);
  return make_rmem_ptr(typed);
}
/// Recast an rmem-tagged pointer to a new element type.
/// The wrapped iterator is recast to NewT and the result is passed through
/// make_rmem_ptr again.
template <class NewT, class P>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(rmem_ptr<P> const& ptr)
{
  auto&& wrapped = ptr.get();
  auto recasted  = recast_ptr<NewT>(wrapped);
  return make_rmem_ptr(recasted);
}
// Pointer-like handle that stores a 32-bit address value (addr_) for elements
// of type T. It supports offset arithmetic but cannot be dereferenced:
// instantiating operator* (directly or through operator[]) fails a static_assert.
// Use get() / raw_pointer_cast() to obtain the address value.
template <class T>
struct tmem_ptr
{
using value_type = remove_cv_t<T>;
using element_type = T;
// Note: 'reference' is T by value, not T&.
using reference = T;
// log2 of (bits in uint32_t / bits in T); used as the rotate amount when
// converting an element offset into an address offset in operator+.
// NOTE(review): presumably the address counts 32-bit words -- confirm.
static constexpr int32_t OffsetShift = log_2(trait_ratio(sizeof_bits<uint32_t>{}, sizeof_bits<T>{}));
// Construct from a raw address value; defaults to 0. Not explicit, so a
// uint32_t converts implicitly (operator+ relies on this for its return).
CUTE_HOST_DEVICE constexpr
tmem_ptr(uint32_t addr = 0) : addr_(addr) {}
// Read-only access to the stored address.
CUTE_HOST_DEVICE constexpr
uint32_t const& get() const {
return addr_;
}
// Mutable access to the stored address.
CUTE_HOST_DEVICE constexpr
uint32_t& get() {
return addr_;
}
// Dereferencing is intentionally unsupported. The defaulted template
// parameter makes the static_assert condition dependent, so it fires only
// when this operator is actually instantiated.
template <class T_ = T>
CUTE_HOST_DEVICE constexpr
value_type operator*() const {
static_assert(dependent_false<T_>, "Attempting to dereference a tmem_ptr, want raw_pointer_cast() for address instead?");
return value_type{};
}
// Defined as *(*this + i); it therefore goes through operator* and hits the
// same static_assert when instantiated.
CUTE_HOST_DEVICE constexpr
reference operator[](uint32_t const& i) const { return *(*this + i); }
// Offset by i elements: i is passed through rotr(i, OffsetShift) and the
// result is added to addr_.
// NOTE(review): assuming rotr is a bitwise rotate-right, the low OffsetShift
// bits of i are moved to the top bits of the offset instead of being
// discarded as a plain right shift would -- confirm against rotr's definition.
CUTE_HOST_DEVICE constexpr
tmem_ptr operator+(uint32_t const& i) const {
return {addr_ + rotr(i, OffsetShift)}; }
// The 32-bit address overlaid with three sub-fields (16 + 8 + 8 bits).
// Only addr_ is read or written by the code in this struct.
// NOTE(review): field names suggest column / data-path / index parts of the
// address; which bits each field aliases depends on the target's byte order,
// and the anonymous struct inside the union is a compiler extension -- confirm.
union {
uint32_t addr_;
struct {
uint16_t col_;
uint8_t dp_;
uint8_t idx_; };
};
};
// Trait: is T a tmem_ptr?
// Primary template: false for any type not matched by a specialization below.
template <class T, class = void>
struct is_tmem : false_type {};
// Direct match: a tmem_ptr<T> is tmem.
template <class T> struct is_tmem<tmem_ptr<T>> : true_type {};
// Any type exposing a nested ::iterator typedef defers to the trait of that iterator type.
template <class P> struct is_tmem<P, void_t<typename P::iterator>> : is_tmem<typename P::iterator> {};
// Convenience variable template for is_tmem<P>::value.
template <class P>
constexpr bool is_tmem_v = is_tmem<P>::value;
/// Create a tmem_ptr<T> holding the given raw address value (0 by default).
template <class T>
CUTE_HOST_DEVICE constexpr
tmem_ptr<T>
make_tmem_ptr(uint32_t addr = 0)
{
  tmem_ptr<T> result(addr);
  return result;
}
/// Extract the raw 32-bit address value stored in a tmem_ptr.
/// The value is returned by copy.
template <class T>
CUTE_HOST_DEVICE constexpr
uint32_t
raw_pointer_cast(tmem_ptr<T> const& ptr)
{
  uint32_t const address = ptr.get();
  return address;
}
/// Recast a tmem_ptr to a different element type.
/// The stored address value is carried over unchanged; only the element type
/// (and with it tmem_ptr's OffsetShift) differs in the result.
template <class NewT, class T>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(tmem_ptr<T> const& ptr)
{
  uint32_t const address = ptr.addr_;
  return tmem_ptr<NewT>{address};
}
/// Print a gmem_ptr: the literal prefix "gmem_" followed by whatever print()
/// emits for the wrapped iterator.
template <class T>
CUTE_HOST_DEVICE void print(gmem_ptr<T> ptr)
{
  printf("gmem_");
  print(ptr.get());
}
/// Print a smem_ptr: the literal prefix "smem_" followed by whatever print()
/// emits for the wrapped iterator.
template <class T>
CUTE_HOST_DEVICE void print(smem_ptr<T> ptr)
{
  printf("smem_");
  print(ptr.get());
}
/// Print an rmem_ptr: the literal prefix "rmem_" followed by whatever print()
/// emits for the wrapped iterator.
template <class T>
CUTE_HOST_DEVICE void print(rmem_ptr<T> ptr)
{
  printf("rmem_");
  print(ptr.get());
}
/// Print a tmem_ptr as "tmem_[<bits>b](0x<hi>.<lo>)", where <bits> is what
/// print() emits for sizeof_bits<T>::value and <hi>/<lo> are the upper and
/// lower 16 bits of the stored address, each in zero-padded 4-digit hex.
template <class T>
CUTE_HOST_DEVICE void print(tmem_ptr<T> ptr)
{
  auto const upper_half = ptr.addr_ >> 16;
  auto const lower_half = ptr.addr_ & 0xFFFF;

  printf("tmem_[");
  print(sizeof_bits<T>::value);
  printf("b](0x%04x.%04x)", upper_half, lower_half);
}
#if !defined(__CUDACC_RTC__)
/// Stream a gmem_ptr as "gmem_[<bits>b]", where <bits> is the bit-width of
/// the wrapped iterator's value type. The pointer value itself is not written.
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr<T> ptr)
{
  os << "gmem_[";
  os << int(sizeof_bits<iter_value_t<T>>::value);
  os << "b]";
  return os;
}
/// Stream a smem_ptr as "smem_[<bits>b]", where <bits> is the bit-width of
/// the wrapped iterator's value type. The pointer value itself is not written.
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr<T> ptr)
{
  os << "smem_[";
  os << int(sizeof_bits<iter_value_t<T>>::value);
  os << "b]";
  return os;
}
/// Stream an rmem_ptr as "rmem_[<bits>b]", where <bits> is the bit-width of
/// the wrapped iterator's value type. The pointer value itself is not written.
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr<T> ptr)
{
  os << "rmem_[";
  os << int(sizeof_bits<iter_value_t<T>>::value);
  os << "b]";
  return os;
}
/// Stream a tmem_ptr as "tmem_[<bits>b](<addr>)", where <bits> is the
/// bit-width of T and <addr> is the stored 32-bit address formatted with the
/// stream's current integer settings.
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, tmem_ptr<T> ptr)
{
  os << "tmem_[";
  os << int(sizeof_bits<T>::value);
  os << "b](";
  os << ptr.addr_;
  os << ")";
  return os;
}
#endif
}