#pragma once
#include <cute/config.hpp>
#include <cute/tensor_impl.hpp>
#include <cute/atom/copy_atom.hpp>
namespace cute
{
/// Cooperatively issue prefetches covering a global-memory tensor.
///
/// Work is split across `NumThreads` participants: the participant with index
/// `tid` handles linear indices tid, tid + NumThreads, tid + 2*NumThreads, ...
/// of the (possibly recast) view of `src`.
///
/// If `max_common_vector(src, src)` is larger than one element, `src` is first
/// recast into wider units and one prefetch is issued per unit instead of per
/// element:
///   - a unit is one full common vector when that vector is narrower than
///     `FetchBytes` bytes;
///   - otherwise a unit is a raw `FetchBytes`-byte chunk.
/// Otherwise every element is prefetched individually.
///
/// @tparam NumThreads  number of cooperating participants striding over the work
/// @tparam FetchBytes  size in bytes of one prefetch unit when the common vector
///                     is at least that wide (default 64)
/// @param  tid         this participant's index; presumably in [0, NumThreads)
///                     — TODO confirm against callers
/// @param  src         tensor to prefetch; must live in global memory
///                     (enforced by static_assert)
template <uint32_t NumThreads, uint32_t FetchBytes = 64,
          class GEngine, class GLayout>
CUTE_HOST_DEVICE
void
cooperative_prefetch(uint32_t                 const& tid,
                     Tensor<GEngine, GLayout> const& src)
{
  static_assert(is_gmem<GEngine>::value, "Expected global tensor for prefetch");

  // Number of elements in the common vector of src with itself.
  constexpr int vec_len = decltype(max_common_vector(src, src))::value;

  if constexpr (vec_len <= 1) {
    // No usable vectorization: one prefetch per element.
    CUTE_UNROLL
    for (int idx = tid; idx < size(src); idx += NumThreads) {
      prefetch(raw_pointer_cast(&src(idx)));
    }
  } else {
    using Element = typename GEngine::value_type;

    // Choose the prefetch unit: the whole common vector if it is smaller than
    // FetchBytes, otherwise a FetchBytes-sized byte chunk.
    constexpr bool vec_is_narrower =
        (vec_len * sizeof_bits_v<Element>) < (FetchBytes * 8);
    using Chunk = conditional_t<vec_is_narrower,
                                ArrayEngine<Element, vec_len>,
                                uint8_t[FetchBytes]>;

    Tensor chunks = recast<Chunk const>(src);
    CUTE_UNROLL
    for (int idx = tid; idx < size(chunks); idx += NumThreads) {
      prefetch(raw_pointer_cast(&chunks(idx)));
    }
  }
}
/// Prefetch a global-memory tensor from a single thread.
///
/// Equivalent to `cooperative_prefetch` with NumThreads = 1 and tid = 0, so
/// this one caller walks every index of `src` itself.
template <class GEngine, class GLayout>
CUTE_HOST_DEVICE
void
prefetch(Tensor<GEngine, GLayout> const& src)
{
  constexpr uint32_t single_participant = 1;
  uint32_t const only_tid = 0;
  cooperative_prefetch<single_participant>(only_tid, src);
}
namespace detail {

// has_prefetch<Op> is true exactly when `Op` declares a nested type `PREFETCH`.
// The primary template answers false. The partial specialization below is only
// viable when `typename Op::PREFETCH` is well-formed (void_t detection idiom).
template <class Op, class Enable = void>
constexpr bool has_prefetch = false;

template <class Op>
constexpr bool has_prefetch<Op, void_t<typename Op::PREFETCH>> = true;

} // end namespace detail
/// Prefetch `src` using the prefetch variant of a copy atom, when one exists.
///
/// If `CopyOp` declares a nested `PREFETCH` op (see `detail::has_prefetch`), a
/// `Copy_Atom<Copy_Traits<CopyOp::PREFETCH, CT_Args...>, CA_Args...>` is
/// constructed from `atom` and driven through `copy`. Otherwise this falls
/// back to the plain tensor `prefetch(src)`.
template <class CopyOp, class... CT_Args, class... CA_Args,
          class GEngine, class GLayout>
CUTE_HOST_DEVICE
void
prefetch(Copy_Atom<Copy_Traits<CopyOp, CT_Args...>, CA_Args...> const& atom,
         Tensor<GEngine, GLayout>                                const& src)
{
  if constexpr (!detail::has_prefetch<CopyOp>) {
    // No dedicated prefetch op: use the generic tensor prefetch.
    return prefetch(src);
  } else {
    using PrefetchOp   = typename CopyOp::PREFETCH;
    using PrefetchAtom = Copy_Atom<Copy_Traits<PrefetchOp, CT_Args...>, CA_Args...>;

    PrefetchAtom prefetch_atom{atom};

    // copy() needs a destination, so src itself is passed as dst.
    // NOTE(review): this const_cast is only sound if the PREFETCH op never
    // writes through dst — presumably true for prefetch ops; confirm.
    auto& dst = const_cast<Tensor<GEngine, GLayout>&>(src);
    return copy(prefetch_atom, src, dst);
  }
}
#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
/// L2-prefetch a global-memory tensor through an SM90 bulk-copy (BLKCP) atom.
///
/// The tile is `max_common_layout(src, src)`. A concrete
/// `SM90_BULK_COPY_G2S` atom is built whose bit width equals that tile's
/// element count times the element size, and the work is delegated to the
/// Copy_Atom overload of `prefetch` on `logical_divide(src, tile)`.
/// The tile must span at least 128 bits (static_assert).
///
/// `atom` only selects this overload; its value is not read.
template <class... CT_Args,
          class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void
prefetch(Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const& atom,
         Tensor<SrcEngine, SrcLayout>                 const& src)
{
  static_assert(is_gmem<SrcEngine>::value, "Expected global tensor for L2 prefetch");

  using Elem = typename SrcEngine::value_type;

  auto vec_layout = max_common_layout(src, src);
  constexpr int elems_per_vec = decltype(size(vec_layout))::value;
  constexpr int bits_per_vec  = elems_per_vec * sizeof_bits_v<Elem>;
  static_assert(bits_per_vec >= 128, "Expected at least 128-bits for BLKCP");

  // Concrete atom sized to exactly one common vector.
  using BulkAtom = Copy_Atom<Copy_Traits<SM90_BULK_COPY_G2S, Int<bits_per_vec>>, Elem>;
  BulkAtom bulk_atom{};

  return prefetch(bulk_atom, logical_divide(src, vec_layout));
}
/// Backwards-compatible entry point for a Copy_Atom-wrapped SM90_BULK_COPY_AUTO.
///
/// Any extra Copy_Atom arguments (CA_Args...) are dropped. The atom is
/// static_cast to `Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const&` and
/// forwarded to the traits-based overload.
template <class... CT_Args, class... CA_Args,
          class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void
prefetch(Copy_Atom<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>, CA_Args...> const& atom,
         Tensor<SrcEngine, SrcLayout>                                         const& src)
{
  using AutoTraits = Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>;
  return prefetch(static_cast<AutoTraits const&>(atom), src);
}
#endif
}