#pragma once
#include <cute/config.hpp>
#include <cute/pointer_base.hpp>
#include <cute/swizzle.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/container/array_subbyte.hpp>
namespace cute
{
// swizzle_ptr: an iterator adaptor that stores an ordinary Iterator and applies
// SwizzleFn to the underlying address only when an element is accessed
// (operator* and operator[]). The stored iterator itself is left unswizzled.
//
//   SwizzleFn  type exposing a static apply() that is called with the address
//              as a uintptr_t and whose result is cast back to a pointer
//   Iterator   the wrapped iterator: a raw pointer, a subbyte_iterator, or a
//              wrapper that exposes get() and can be rebuilt from the wrapped value
//
// NOTE(review): iter_adaptor is declared outside this file; it is presumably what
// supplies get() and the iterator arithmetic -- confirm there.
template <class SwizzleFn, class Iterator>
struct swizzle_ptr : iter_adaptor<Iterator, swizzle_ptr<SwizzleFn, Iterator>>
{
  using iterator     = Iterator;
  using reference    = typename iterator_traits<Iterator>::reference;
  using element_type = typename iterator_traits<Iterator>::element_type;
  using value_type   = typename iterator_traits<Iterator>::value_type;

  // Reuse the constructors of the base adaptor.
  using iter_adaptor<Iterator, swizzle_ptr<SwizzleFn, Iterator>>::iter_adaptor;

  // Raw pointer: swizzle the address as an integer, then turn it back into a T*.
  template <class T>
  CUTE_HOST_DEVICE constexpr static
  T* apply_swizzle(T* ptr) {
    uintptr_t const address = reinterpret_cast<uintptr_t>(ptr);
    return reinterpret_cast<T*>(SwizzleFn::apply(address));
  }

  // Sub-byte iterator: swizzle only the pointer member; the index member is
  // carried over unchanged.
  template <class T>
  CUTE_HOST_DEVICE constexpr static
  subbyte_iterator<T> apply_swizzle(subbyte_iterator<T> ptr) {
    return {apply_swizzle(ptr.ptr_), ptr.idx_};
  }

  // Any other iterator wrapper: unwrap with get(), swizzle what is inside,
  // and rebuild the same wrapper type around the swizzled result.
  template <class Iter>
  CUTE_HOST_DEVICE constexpr static
  Iter apply_swizzle(Iter ptr) {
    return {apply_swizzle(ptr.get())};
  }

  // Dereference the swizzled image of the current position.
  CUTE_HOST_DEVICE constexpr
  reference operator*() const {
    auto swizzled = apply_swizzle(this->get());
    return *swizzled;
  }

  // Offset the unswizzled iterator by i first, then swizzle and dereference.
  template <class Int>
  CUTE_HOST_DEVICE constexpr
  reference operator[](Int const& i) const {
    auto swizzled = apply_swizzle(this->get() + i);
    return *swizzled;
  }
};
// get_swizzle specialization for swizzle_ptr: the nested `type` is the pointer's own SwizzleFn.
// NOTE(review): the primary get_swizzle template is declared outside this file.
template <class SwizzleFn, class P> struct get_swizzle<swizzle_ptr<SwizzleFn,P>> { using type = SwizzleFn; };
// get_swizzle specialization for any T that has a nested `T::iterator` type (detected with void_t):
// it inherits from get_swizzle of that iterator type, so the lookup is forwarded to the iterator.
template <class T> struct get_swizzle<T, void_t<typename T::iterator>> : get_swizzle<typename T::iterator> {};
// Wrap `ptr` in a swizzle_ptr parameterized on SwizzleFn.
// The second argument is used only for its type; its value is never read.
template <class Iterator, class SwizzleFn>
CUTE_HOST_DEVICE constexpr
swizzle_ptr<SwizzleFn,Iterator>
make_swizzle_ptr(Iterator ptr, SwizzleFn)
{
  return swizzle_ptr<SwizzleFn,Iterator>{ptr};
}
// Overload for Swizzle<0,M,S>: no swizzle_ptr wrapper is created and the
// iterator is handed back unchanged.
// NOTE(review): presumably a Swizzle whose first parameter is 0 is the identity
// mapping, which is why the wrapper is skipped -- confirm against Swizzle.
template <class Iterator, int M, int S>
CUTE_HOST_DEVICE constexpr
Iterator make_swizzle_ptr(Iterator ptr, Swizzle<0,M,S>)
{
  return ptr;
}
// raw_pointer_cast for a swizzle_ptr: delegates to raw_pointer_cast of the
// wrapped iterator. SwizzleFn is not applied here, so the result is derived
// from the unswizzled iterator.
template <class SwizzleFn, class P>
CUTE_HOST_DEVICE constexpr
auto raw_pointer_cast(swizzle_ptr<SwizzleFn,P> const& ptr)
{
  return raw_pointer_cast(ptr.get());
}
// recast_ptr for a swizzle_ptr: recast the wrapped iterator to NewT, then wrap
// the result again with the same SwizzleFn via make_swizzle_ptr.
template <class NewT, class SwizzleFn, class P>
CUTE_HOST_DEVICE constexpr
auto recast_ptr(swizzle_ptr<SwizzleFn,P> const& ptr)
{
  auto retyped = recast_ptr<NewT>(ptr.get());
  return make_swizzle_ptr(retyped, SwizzleFn{});
}
// max_alignment for a swizzle_ptr: eight times the max_alignment of SwizzleFn.
// The pointer argument is used only for its type.
// NOTE(review): the factor of 8 presumably converts a byte alignment into
// bits -- confirm against the other max_alignment overloads.
template <class SwizzleFn, class P>
CUTE_HOST_DEVICE constexpr
auto max_alignment(swizzle_ptr<SwizzleFn,P> const&)
{
  auto fn_alignment = max_alignment(SwizzleFn{});
  return Int<8>{} * fn_alignment;
}
// Print a swizzle_ptr as "<SwizzleFn>_<wrapped iterator>", using the print
// overloads of the swizzle functor and of the wrapped iterator.
template <class SwizzleFn, class P>
CUTE_HOST_DEVICE void print(swizzle_ptr<SwizzleFn,P> ptr)
{
  print(SwizzleFn{});
  printf("_");
  print(ptr.get());
}
#if !defined(__CUDACC_RTC__)
// Stream a swizzle_ptr as "<SwizzleFn>_<wrapped iterator>".
// Compiled only when __CUDACC_RTC__ is not defined.
template <class SwizzleFn, class P>
CUTE_HOST std::ostream& operator<<(std::ostream& os, swizzle_ptr<SwizzleFn,P> ptr)
{
  os << SwizzleFn{};
  os << "_";
  os << ptr.get();
  return os;
}
#endif
}