#pragma once
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/algorithm/functional.hpp>
#include <cute/tensor_impl.hpp>
#include <cute/atom/mma_atom.hpp>
namespace cute
{
/// Three-operand convenience overload.
///
/// Uses C as both the destination and the accumulator input by forwarding to
/// the four-operand form gemm(C, A, B, C).
///
/// @param A  left-hand operand (read-only)
/// @param B  right-hand operand (read-only)
/// @param C  tensor passed as both the first and the last operand of the
///           forwarded call
template <class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout>      & C)
{
  // D and C alias the same tensor.
  gemm(C, A, B, C);
}
/// Three-operand convenience overload with an explicit MMA atom.
///
/// Uses C as both the destination and the accumulator input by forwarding to
/// gemm(mma, C, A, B, C).
///
/// @param mma  atom forwarded unchanged to the five-argument overloads
/// @param A    left-hand operand (read-only)
/// @param B    right-hand operand (read-only)
/// @param C    tensor passed as both the D and the C operand of the forwarded call
template <class MMA,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA>       const& mma,
     Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout>      & C)
{
  // D and C alias the same tensor.
  gemm(mma, C, A, B, C);
}
/// Three-operand overload that accepts C as a temporary (rvalue) tensor.
///
/// Inside the function C is a named object and therefore an lvalue, so the
/// call below resolves to the lvalue-taking four-operand gemm(C, A, B, C).
///
/// @param A  left-hand operand (read-only)
/// @param B  right-hand operand (read-only)
/// @param C  temporary tensor used as both destination and accumulator input
template <class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout>     && C)
{
  gemm(C, A, B, C);
}
/// Four-operand overload that accepts the destination D as a temporary
/// (rvalue) tensor.
///
/// D is a named lvalue inside the function, so the call below resolves to the
/// overload taking D by non-const lvalue reference.
///
/// @param D  temporary destination tensor
/// @param A  left-hand operand (read-only)
/// @param B  right-hand operand (read-only)
/// @param C  accumulator input (read-only)
template <class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(Tensor<TD, DLayout>     && D,
     Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout> const& C)
{
  gemm(D, A, B, C);
}
/// Three-operand overload with an explicit MMA atom that accepts C as a
/// temporary (rvalue) tensor.
///
/// C is a named lvalue inside the function; the call forwards to
/// gemm(mma, C, A, B, C) with C serving as both destination and accumulator.
///
/// @param mma  atom forwarded unchanged
/// @param A    left-hand operand (read-only)
/// @param B    right-hand operand (read-only)
/// @param C    temporary tensor used as both D and C of the forwarded call
template <class MMA,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA>       const& mma,
     Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout>     && C)
{
  gemm(mma, C, A, B, C);
}
/// Four-operand overload with an explicit MMA atom that accepts the
/// destination D as a temporary (rvalue) tensor.
///
/// D is a named lvalue inside the function, so the call below is resolved
/// among the overloads that take D by non-const lvalue reference.
///
/// @param mma  atom forwarded unchanged
/// @param D    temporary destination tensor
/// @param A    left-hand operand (read-only)
/// @param B    right-hand operand (read-only)
/// @param C    accumulator input (read-only)
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA>       const& mma,
     Tensor<TD, DLayout>     && D,
     Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout> const& C)
{
  gemm(mma, D, A, B, C);
}
/// Four-operand overload without an explicit atom.
///
/// Builds a default atom, MMA_Atom<UniversalFMA<...>>, parameterized by the
/// value types of D, A, B and C (in that order), and forwards to the
/// atom-taking overloads with a default-constructed instance of it.
///
/// @param D  destination tensor
/// @param A  left-hand operand (read-only)
/// @param B  right-hand operand (read-only)
/// @param C  accumulator input (read-only)
template <class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout>
CUTE_HOST_DEVICE
void
gemm(Tensor<TD, DLayout>      & D,
     Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout> const& C)
{
  // Element types of the four operands.
  using ValD = typename Tensor<TD,DLayout>::value_type;
  using ValA = typename Tensor<TA,ALayout>::value_type;
  using ValB = typename Tensor<TB,BLayout>::value_type;
  using ValC = typename Tensor<TC,CLayout>::value_type;

  using DefaultAtom = MMA_Atom<UniversalFMA<ValD, ValA, ValB, ValC>>;

  gemm(DefaultAtom{}, D, A, B, C);
}
/// Rank-1 dispatch: selected when D, A, B and C are all rank-1 tensors whose
/// engines satisfy is_rmem.
///
/// The operands are handed directly to the atom through mma.call(D, A, B, C).
/// No shape checks are made in this function.
///
/// @param mma  atom whose call() is invoked
/// @param D    destination tensor
/// @param A    left-hand operand (read-only)
/// @param B    right-hand operand (read-only)
/// @param C    accumulator input (read-only)
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout,
          __CUTE_REQUIRES(DLayout::rank == 1 && is_rmem<TD>::value &&
                          ALayout::rank == 1 && is_rmem<TA>::value &&
                          BLayout::rank == 1 && is_rmem<TB>::value &&
                          CLayout::rank == 1 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA>       const& mma,
     Tensor<TD, DLayout>      & D,
     Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout> const& C)
{
  mma.call(D, A, B, C);
}
/// Rank-(2,1,1,2) dispatch: selected when D and C are rank-2, A and B are
/// rank-1, and all four engines satisfy is_rmem.
///
/// After checking that the extents agree, A and B are re-viewed with their
/// layouts extended to rank 2 via append<2>, and the call is forwarded to the
/// overload set again with D and C unchanged.
///
/// @param mma  atom forwarded unchanged
/// @param D    rank-2 destination tensor
/// @param A    rank-1 left-hand operand (read-only)
/// @param B    rank-1 right-hand operand (read-only)
/// @param C    rank-2 accumulator input (read-only)
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout,
          __CUTE_REQUIRES(DLayout::rank == 2 && is_rmem<TD>::value &&
                          ALayout::rank == 1 && is_rmem<TA>::value &&
                          BLayout::rank == 1 && is_rmem<TB>::value &&
                          CLayout::rank == 2 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA>       const& mma,
     Tensor<TD, DLayout>      & D,
     Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout> const& C)
{
  // Mode 0 of A must match mode 0 of C.
  CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C));
  // Mode 0 of B must match mode 1 of C.
  CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C));
  // D and C must agree in both modes.
  CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));

  // Views over the same data as A and B, with layouts extended to rank 2.
  auto A_rank2 = make_tensor(A.data(), append<2>(A.layout()));
  auto B_rank2 = make_tensor(B.data(), append<2>(B.layout()));

  gemm(mma, D, A_rank2, B_rank2, C);
}
/// Rank-2 dispatch: selected when D, A, B and C are all rank-2 tensors whose
/// engines satisfy is_rmem.
///
/// After checking extents and the atom's TV layouts, every operand is
/// re-viewed with its layout extended to rank 3 via prepend<3>, and the call
/// is forwarded to the overload set again.
///
/// @param mma  atom forwarded unchanged
/// @param D    rank-2 destination tensor
/// @param A    rank-2 left-hand operand (read-only)
/// @param B    rank-2 right-hand operand (read-only)
/// @param C    rank-2 accumulator input (read-only)
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout,
          __CUTE_REQUIRES(DLayout::rank == 2 && is_rmem<TD>::value &&
                          ALayout::rank == 2 && is_rmem<TA>::value &&
                          BLayout::rank == 2 && is_rmem<TB>::value &&
                          CLayout::rank == 2 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA>       const& mma,
     Tensor<TD, DLayout>      & D,
     Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout> const& C)
{
  // Mode 0 of A must match mode 0 of C.
  CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C));
  // Mode 0 of B must match mode 1 of C.
  CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C));
  // A and B must share the extent of mode 1.
  CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B));
  // D and C must agree in both modes.
  CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));

  // Mode 1 of each of the atom's TV layouts must have size 1
  // (presumably: the atom handles a single value per operand — confirm).
  CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
  CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
  CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});

  // Read-only operands re-viewed with rank-3 layouts.
  auto A_rank3 = make_tensor(A.data(), prepend<3>(A.layout()));
  auto B_rank3 = make_tensor(B.data(), prepend<3>(B.layout()));
  auto C_rank3 = make_tensor(C.data(), prepend<3>(C.layout()));

  // The rank-3 view of D is built in place so it is passed as a temporary.
  gemm(mma,
       make_tensor(D.data(), prepend<3>(D.layout())),
       A_rank3,
       B_rank3,
       C_rank3);
}
/// Rank-(3,2,2,3) dispatch: selected when D and C are rank-3, A and B are
/// rank-2, and all four engines satisfy is_rmem.
///
/// Statically checks that mode 1 of A matches mode 1 of C, that mode 1 of B
/// matches mode 2 of C, and that D and C agree in all three modes. It then
/// visits every (m, n) pair exactly once and issues one gemm on the slices
/// D(_,m,n), A(_,m), B(_,n), C(_,m,n).
///
/// The visiting order is chosen at compile time from the byte width of mode 0
/// of A and of B (size<0> * sizeof(value_type)):
///
///   A bytes | B bytes | order
///   --------+---------+------------------------------------------------------
///      8    |  8 or 4 | m outer; n direction reverses on every odd m
///      4    |    4    | m outer, two rows per step; n direction reverses on
///           |         | every other row pair
///    other combinations | n outer; m direction reverses on every odd n
///
/// NOTE(review): the orders are presumably tuned for operand reuse between
/// consecutive atom invocations; nothing in this function depends on it for
/// correctness.
///
/// @param mma  atom forwarded unchanged to the per-slice gemm
/// @param D    rank-3 destination tensor
/// @param A    rank-2 left-hand operand (read-only)
/// @param B    rank-2 right-hand operand (read-only)
/// @param C    rank-3 accumulator input (read-only)
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout,
          __CUTE_REQUIRES(DLayout::rank == 3 && is_rmem<TD>::value &&
                          ALayout::rank == 2 && is_rmem<TA>::value &&
                          BLayout::rank == 2 && is_rmem<TB>::value &&
                          CLayout::rank == 3 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA>       const& mma,
     Tensor<TD, DLayout>      & D,
     Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout> const& C)
{
  // Mode 1 of A must match mode 1 of C.
  CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C));
  // Mode 1 of B must match mode 2 of C.
  CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C));
  // D and C must agree in all three modes.
  CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));

  auto M = size<1>(A);
  auto N = size<1>(B);

  // Byte width of mode 0 of each input operand; drives the traversal choice.
  constexpr auto a_mode0_bytes = decltype(size<0>(A))::value * sizeof(typename TA::value_type);
  constexpr auto b_mode0_bytes = decltype(size<0>(B))::value * sizeof(typename TB::value_type);

  if constexpr (a_mode0_bytes == 8 && (b_mode0_bytes == 8 || b_mode0_bytes == 4))
  {
    // m outer, n inner; the n direction reverses on odd rows.
    CUTE_UNROLL
    for (int m = 0; m < M; ++m) {
      CUTE_UNROLL
      for (int n = 0; n < N; ++n) {
        int ns = (m & 1) ? N-1-n : n;
        gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns));
      }
    }
  }
  else if constexpr (a_mode0_bytes == 4 && b_mode0_bytes == 4)
  {
    // m outer in steps of two rows; the n direction reverses on every other
    // row pair (m & 2).
    CUTE_UNROLL
    for (int m = 0; m < M; m += 2) {
      CUTE_UNROLL
      for (int n = 0; n < N; ++n) {
        int ns = (m & 2) ? N-1-n : n;
        gemm(mma, D(_,m+0,ns), A(_,m+0), B(_,ns), C(_,m+0,ns));

        // The second row of the pair does not exist when M is odd.
        if (m+1 < M) {
          gemm(mma, D(_,m+1,ns), A(_,m+1), B(_,ns), C(_,m+1,ns));
        }
      }
    }
  }
  else
  {
    // n outer, m inner; the m direction reverses on odd columns.
    // Used for the 4-byte A / 8-byte B case and for every other combination.
    CUTE_UNROLL
    for (int n = 0; n < N; ++n) {
      CUTE_UNROLL
      for (int m = 0; m < M; ++m) {
        int ms = (n & 1) ? M-1-m : m;
        gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n));
      }
    }
  }
}
/// Rank-3 dispatch: selected when D, A, B and C are all rank-3 tensors whose
/// engines satisfy is_rmem.
///
/// After checking extents, iterates over mode 2 of A and issues one gemm per
/// index k on the rank-2 slices A(_,_,k) and B(_,_,k), with D and C passed
/// whole each time.
///
/// NOTE(review): every k step is given the same C. If the atom computes
/// D = A*B + C, the result only accumulates across k when D and C refer to the
/// same storage — confirm this is the intended contract for callers passing
/// distinct D and C.
///
/// @param mma  atom forwarded unchanged
/// @param D    rank-3 destination tensor
/// @param A    rank-3 left-hand operand (read-only)
/// @param B    rank-3 right-hand operand (read-only)
/// @param C    rank-3 accumulator input (read-only)
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout,
          __CUTE_REQUIRES(DLayout::rank == 3 && is_rmem<TD>::value &&
                          ALayout::rank == 3 && is_rmem<TA>::value &&
                          BLayout::rank == 3 && is_rmem<TB>::value &&
                          CLayout::rank == 3 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA>       const& mma,
     Tensor<TD, DLayout>      & D,
     Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout> const& C)
{
  // Mode 1 of A must match mode 1 of C.
  CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C));
  // Mode 1 of B must match mode 2 of C.
  CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C));
  // A and B must share the extent of mode 2.
  CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B));
  // D and C must agree in all three modes.
  CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));

  auto k_extent = size<2>(A);

  CUTE_UNROLL
  for (int k = 0; k < k_extent; ++k)
  {
    gemm(mma, D, A(_,_,k), B(_,_,k), C);
  }
}
/// Rank-2 dispatch for mixed engines: selected when all four tensors are
/// rank-2, D and C satisfy is_rmem, and A and B satisfy is_smem.
///
/// After checking extents and the atom's TV layouts, every operand is
/// re-viewed with its layout extended to rank 3 via prepend<3>, and the call
/// is forwarded to the overload set again.
///
/// @param mma  atom forwarded unchanged
/// @param D    rank-2 destination tensor
/// @param A    rank-2 left-hand operand (read-only)
/// @param B    rank-2 right-hand operand (read-only)
/// @param C    rank-2 accumulator input (read-only)
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout,
          __CUTE_REQUIRES(DLayout::rank == 2 && is_rmem<TD>::value &&
                          ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value &&
                          CLayout::rank == 2 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA>       const& mma,
     Tensor<TD, DLayout>      & D,
     Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout> const& C)
{
  // Mode 0 of A must match mode 0 of C.
  CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C));
  // Mode 0 of B must match mode 1 of C.
  CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C));
  // A and B must share the extent of mode 1.
  CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B));
  // D and C must agree in both modes.
  CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D));

  // Mode 1 of each of the atom's TV layouts must have size 1
  // (presumably: the atom handles a single value per operand — confirm).
  CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutC_TV{}) == Int<1>{});
  CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutA_TV{}) == Int<1>{});
  CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom<MMA>::LayoutB_TV{}) == Int<1>{});

  // Read-only operands re-viewed with rank-3 layouts.
  auto A_rank3 = make_tensor(A.data(), prepend<3>(A.layout()));
  auto B_rank3 = make_tensor(B.data(), prepend<3>(B.layout()));
  auto C_rank3 = make_tensor(C.data(), prepend<3>(C.layout()));

  // The rank-3 view of D is built in place so it is passed as a temporary.
  gemm(mma,
       make_tensor(D.data(), prepend<3>(D.layout())),
       A_rank3,
       B_rank3,
       C_rank3);
}
/// Rank-3 dispatch for mixed engines: selected when all four tensors are
/// rank-3, D and C satisfy is_rmem, and A and B satisfy is_smem.
///
/// Creates fragments for A and B through the atom's make_fragment_A /
/// make_fragment_B, then for each index k of mode 2 copies the k-th slice of
/// A and B into the matching fragment slice and issues a gemm on those
/// fragment slices, with D and C passed whole each time.
///
/// NOTE(review): every k step is given the same C. If the atom computes
/// D = A*B + C, the result only accumulates across k when D and C refer to the
/// same storage — confirm this is the intended contract for callers passing
/// distinct D and C.
///
/// @param mma  atom forwarded unchanged
/// @param D    rank-3 destination tensor
/// @param A    rank-3 left-hand operand (read-only source of the copies)
/// @param B    rank-3 right-hand operand (read-only source of the copies)
/// @param C    rank-3 accumulator input (read-only)
template <class MMA,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout,
          __CUTE_REQUIRES(DLayout::rank == 3 && is_rmem<TD>::value &&
                          ALayout::rank == 3 && is_smem<TA>::value &&
                          BLayout::rank == 3 && is_smem<TB>::value &&
                          CLayout::rank == 3 && is_rmem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(MMA_Atom<MMA>       const& mma,
     Tensor<TD, DLayout>      & D,
     Tensor<TA, ALayout> const& A,
     Tensor<TB, BLayout> const& B,
     Tensor<TC, CLayout> const& C)
{
  // Mode 1 of A must match mode 1 of C.
  CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C));
  // Mode 1 of B must match mode 2 of C.
  CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C));
  // A and B must share the extent of mode 2.
  CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B));
  // D and C must agree in all three modes.
  CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D));

  using Atom = MMA_Atom<MMA>;

  // Staging fragments shaped by the atom from A and B.
  auto frag_A = Atom::make_fragment_A(A);
  auto frag_B = Atom::make_fragment_B(B);

  auto k_extent = size<2>(A);

  CUTE_UNROLL
  for (int k = 0; k < k_extent; ++k) {
    // Stage slice k of each input, then multiply from the staged copies.
    copy(A(_,_,k), frag_A(_,_,k));
    copy(B(_,_,k), frag_B(_,_,k));
    gemm(mma, D, frag_A(_,_,k), frag_B(_,_,k), C);
  }
}
}