#pragma once
#include <cute/config.hpp>
#include <cute/layout_composed.hpp>
#include <cute/pointer.hpp>
#include <cute/pointer_sparse.hpp>
#include <cute/pointer_swizzle.hpp>
#include <cute/arch/util.hpp>
#include <cute/numeric/integral_constant.hpp>
namespace cute
{
// Placeholder tag type used where a shared-memory pointer has not been supplied yet.
// It is an empty struct deriving from Int<0>; the only information it carries is
// the compile-time parameter Bits.
// The overloads in this file match it as the second template argument of a
// ComposedLayout, and the make_tensor overload checks Bits against the bit width
// of the value type of the iterator it receives.
template <int Bits>
struct smem_ptr_flag_bits : Int<0> {};
// Shorthand for the flag with Bits == 1.
using smem_ptr_flag = smem_ptr_flag_bits<1>;
// Builds a tensor from a shared-memory iterator and a ComposedLayout whose
// second component is the smem_ptr_flag_bits<B> placeholder.
//
// Compile-time requirements:
//   - Iterator satisfies is_smem.
//   - B equals the bit width of the iterator's value type.
//
// The layout's first component (layout_a) is handed to make_smem_ptr together
// with the underlying pointer; the resulting tensor is built over layout_b.
template <class Iterator, class SwizzleFn, int B, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_tensor(Iterator const& ptr,
            ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout)
{
  using ElemT = iter_value_t<Iterator>;

  static_assert(is_smem<Iterator>::value, "Expected smem.");
  static_assert(B == sizeof_bits<ElemT>::value, "Expected a B-bit pointer type.");

  auto const& swizzle_fn = layout.layout_a();
  auto const& inner      = layout.layout_b();
  return make_tensor(make_smem_ptr(ptr.get(), swizzle_fn), inner);
}
// upcast<N> for a ComposedLayout carrying the smem_ptr_flag_bits<B> placeholder.
// The placeholder is kept in the result with its bit count multiplied by N,
// the first component (layout_a) is passed through unchanged, and upcast<N>
// is applied to the inner layout (layout_b).
template <int N, class SwizzleFn, int B, class Layout>
CUTE_HOST_DEVICE constexpr
auto
upcast(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout)
{
  using WiderFlag = smem_ptr_flag_bits<B*N>;

  auto const& swizzle_fn = layout.layout_a();
  return composition(swizzle_fn, WiderFlag{}, upcast<N>(layout.layout_b()));
}
// downcast<N> for a ComposedLayout carrying the smem_ptr_flag_bits<B> placeholder.
// The placeholder is kept in the result with its bit count divided by N
// (integer division), the first component (layout_a) is passed through
// unchanged, and downcast<N> is applied to the inner layout (layout_b).
template <int N, class SwizzleFn, int B, class Layout>
CUTE_HOST_DEVICE constexpr
auto
downcast(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout)
{
  using NarrowerFlag = smem_ptr_flag_bits<B/N>;

  auto const& swizzle_fn = layout.layout_a();
  return composition(swizzle_fn, NarrowerFlag{}, downcast<N>(layout.layout_b()));
}
// Converts a ComposedLayout that carries the smem_ptr_flag_bits<B> placeholder
// into one whose second component is the constant Int<0>.
//
// The first component (layout_a) is run through
// recast_layout<uint8_t, uint_bit_t<B>> before being recomposed with the
// unchanged inner layout (layout_b).
// NOTE(review): presumably this re-expresses the swizzle from byte units to
// units of B-bit elements -- confirm against recast_layout.
template <class SwizzleFn, int B, class Layout>
CUTE_HOST_DEVICE
auto
as_position_independent_swizzle_layout(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout)
{
  using ElemBitsT = uint_bit_t<B>;

  auto recast_swizzle = recast_layout<uint8_t, ElemBitsT>(layout.layout_a());
  return composition(recast_swizzle, Int<0>{}, layout.layout_b());
}
// Rebuilds a shared-memory tensor so that its swizzle is part of the layout
// (composed with a constant Int<0> second component) and its iterator is a
// plain make_smem_ptr<value_type> over the tensor's raw pointer.
//
// Compile-time requirement: the (cv/ref-stripped) tensor type satisfies is_smem.
//
// If the tensor's swizzle type reports num_bits == 0 the tensor is returned
// as-is. Otherwise, in builds without NDEBUG, the shared-memory address is
// first asserted to have none of the low num_base bits nor any of the
// swizzle_code bits set.
template <class Tensor>
CUTE_HOST_DEVICE
auto
as_position_independent_swizzle_tensor(Tensor&& tensor)
{
  using TensorT = remove_cvref_t<Tensor>;
  static_assert(is_smem<TensorT>::value, "Expected smem tensor.");
  using SwizzleFn = get_swizzle_t<TensorT>;

  if constexpr (SwizzleFn::num_bits != 0) {
#if !defined(NDEBUG)
    {
      // Debug-only check of the address against the bits the swizzle touches.
      uint32_t address = cast_smem_ptr_to_uint(raw_pointer_cast(static_cast<Tensor&&>(tensor).data()));
      uint32_t mask = ((uint32_t(1) << SwizzleFn::num_base) - 1) | SwizzleFn::swizzle_code;
      assert((address & mask) == 0);
    }
#endif
    using ElemT = typename TensorT::value_type;

    // Swizzle recast via recast_layout<uint8_t, ElemT>.
    auto elem_swizzle = recast_layout<uint8_t, ElemT>(SwizzleFn{});
    // Fresh smem pointer of the element type built from the raw pointer.
    auto elem_ptr = make_smem_ptr<ElemT>(raw_pointer_cast(static_cast<Tensor&&>(tensor).data()));

    return make_tensor(elem_ptr, composition(elem_swizzle, Int<0>{}, tensor.layout()));
  } else {
    // No swizzle bits: nothing to convert.
    return tensor;
  }
  CUTE_GCC_UNREACHABLE;
}
// Placeholder tag type for the sparse case: an empty struct deriving from Int<0>
// that carries two compile-time parameters, Sparsity and Bits.
// The overloads in this file match it as the second template argument of a
// ComposedLayout; the sparse make_tensor overload checks Sparsity against the
// iterator value type's `sparsity` and Bits against the bit width of its `raw_type`.
template <int Sparsity, int Bits>
struct smem_sparse_ptr_flag_bits : Int<0> {};
// Shorthand for the sparse flag with Bits == 1 and a caller-chosen Sparsity.
template <int Sparsity>
using smem_sparse_ptr_flag = smem_sparse_ptr_flag_bits<Sparsity, 1>;
// Builds a tensor from a sparse shared-memory iterator and a ComposedLayout
// whose second component is the smem_sparse_ptr_flag_bits<S,B> placeholder.
//
// Compile-time requirements:
//   - Iterator satisfies is_smem and is_sparse_ptr.
//   - The iterator's value type satisfies is_sparse, its `sparsity` equals S,
//     and the bit width of its `raw_type` equals B.
//
// The iterator is wrapped with make_swizzle_ptr using the layout's first
// component (layout_a); the resulting tensor is built over layout_b.
template <class Iterator, class SwizzleFn, int S, int B, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_tensor(Iterator const& ptr,
            ComposedLayout<SwizzleFn,smem_sparse_ptr_flag_bits<S,B>,Layout> const& layout)
{
  using SparseElemT = iter_value_t<Iterator>;

  static_assert(is_smem<Iterator>::value, "Expected smem.");
  static_assert(is_sparse_ptr<Iterator>::value, "Expected sparse iter");
  static_assert(is_sparse<SparseElemT>::value, "Expected sparse elem");
  static_assert(S == SparseElemT::sparsity, "Expected sparsity S");
  static_assert(B == sizeof_bits<typename SparseElemT::raw_type>::value, "Expected B-bit pointer type");

  auto const& swizzle_fn = layout.layout_a();
  auto const& inner      = layout.layout_b();
  return make_tensor(make_swizzle_ptr(ptr, swizzle_fn), inner);
}
// upcast<N> is deliberately unavailable for layouts carrying the sparse
// placeholder: instantiating this overload always trips the static_assert.
template <int N, class SwizzleFn, int S, int B, class Layout>
CUTE_HOST_DEVICE constexpr
auto
upcast(ComposedLayout<SwizzleFn,smem_sparse_ptr_flag_bits<S,B>,Layout> const& /*layout*/)
{
  static_assert(dependent_false<SwizzleFn>, "Not implemented for safety");
}
// downcast<N> is deliberately unavailable for layouts carrying the sparse
// placeholder: instantiating this overload always trips the static_assert.
template <int N, class SwizzleFn, int S, int B, class Layout>
CUTE_HOST_DEVICE constexpr
auto
downcast(ComposedLayout<SwizzleFn,smem_sparse_ptr_flag_bits<S,B>,Layout> const& /*layout*/)
{
  static_assert(dependent_false<SwizzleFn>, "Not implemented for safety");
}
// print_layout for a ComposedLayout carrying the smem_ptr_flag_bits<B>
// placeholder: the layout is first converted with
// as_position_independent_swizzle_layout and the result is printed.
template <class SwizzleFn, int B, class Layout>
CUTE_HOST_DEVICE
void
print_layout(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout)
{
  auto const printable = as_position_independent_swizzle_layout(layout);
  print_layout(printable);
}
// print_latex for a ComposedLayout carrying the smem_ptr_flag_bits<B>
// placeholder: the layout is first converted with
// as_position_independent_swizzle_layout and the result is printed.
template <class SwizzleFn, int B, class Layout>
CUTE_HOST_DEVICE
void
print_latex(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout)
{
  auto const printable = as_position_independent_swizzle_layout(layout);
  print_latex(printable);
}
// Prints the placeholder as "smem_ptr[<B>b](unset)".
// The argument is used only for overload selection, so it is left unnamed;
// this avoids an unused-parameter warning and matches the sparse overload.
template <int B>
CUTE_HOST_DEVICE void print(smem_ptr_flag_bits<B>)
{
  printf("smem_ptr[%db](unset)", B);
}
// Prints the sparse placeholder as "smem_sparse<S>_ptr[<B>b](unset)".
// The argument is used only for overload selection and is left unnamed.
template <int S, int B>
CUTE_HOST_DEVICE
void
print(smem_sparse_ptr_flag_bits<S,B>)
{
  printf("smem_sparse<%d>_ptr[%db](unset)", S, B);
}
}