#pragma once
#include "cutlass/gemm/device/gemm.h"
#include "src/cuda/cutlass/library_internal.h"
namespace cutlass {
namespace library {
/// Compile-time probe that classifies Operator by whether it declares a
/// nested `ReductionKernel` type; such operators are reported as parallel
/// split-K, all others as having no split-K.
template <typename Operator>
struct split_k_mode {
    // Only viable when `T::ReductionKernel` names a type (substitution failure
    // removes it otherwise). Its char return type makes the chosen overload
    // observable through sizeof.
    template <typename T>
    static char check(typename T::ReductionKernel*);

    // Catch-all overload selected when the one above is not viable.
    template <typename T>
    static int check(...);

    /// Returns SplitKMode::kParallel when Operator has a ReductionKernel,
    /// SplitKMode::kNone otherwise.
    SplitKMode operator()() {
        bool const has_reduction_kernel =
                sizeof(check<Operator>(0)) == sizeof(char);
        return has_reduction_kernel ? SplitKMode::kParallel : SplitKMode::kNone;
    }
};
template <typename Operator_>
class GemmOperationBase : public Operation {
public:
using Operator = Operator_;
using ElementA = typename Operator::ElementA;
using LayoutA = typename Operator::LayoutA;
using ElementB = typename Operator::ElementB;
using LayoutB = typename Operator::LayoutB;
using ElementC = typename Operator::ElementC;
using LayoutC = typename Operator::LayoutC;
using ElementAccumulator = typename Operator::ElementAccumulator;
GemmOperationBase(char const* name = "unknown_gemm") {
m_description.name = name;
m_description.provider = Provider::kCUTLASS;
m_description.kind = OperationKind::kGemm;
m_description.gemm_kind = GemmKind::kGemm;
m_description.tile_description.threadblock_shape = make_Coord(
Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN,
Operator::ThreadblockShape::kK);
m_description.tile_description.threadblock_stages = Operator::kStages;
m_description.tile_description.warp_count = make_Coord(
Operator::GemmKernel::WarpCount::kM,
Operator::GemmKernel::WarpCount::kN,
Operator::GemmKernel::WarpCount::kK);
m_description.tile_description.math_instruction.instruction_shape = make_Coord(
Operator::InstructionShape::kM, Operator::InstructionShape::kN,
Operator::InstructionShape::kK);
m_description.tile_description.math_instruction.element_accumulator =
NumericTypeMap<ElementAccumulator>::kId;
m_description.tile_description.math_instruction.opcode_class =
OpcodeClassMap<typename Operator::OperatorClass>::kId;
m_description.tile_description.math_instruction.math_operation =
MathOperationMap<typename Operator::Operator>::kId;
m_description.tile_description.minimum_compute_capability =
ArchMap<typename Operator::ArchTag,
typename Operator::OperatorClass>::kMin;
m_description.tile_description.maximum_compute_capability =
ArchMap<typename Operator::ArchTag,
typename Operator::OperatorClass>::kMax;
m_description.A =
make_TensorDescription<ElementA, LayoutA>(Operator::kAlignmentA);
m_description.B =
make_TensorDescription<ElementB, LayoutB>(Operator::kAlignmentB);
m_description.C =
make_TensorDescription<ElementC, LayoutC>(Operator::kAlignmentC);
m_description.stages = Operator::kStages;
split_k_mode<Operator> mode;
m_description.split_k_mode = mode();
}
virtual OperationDescription const& description() const { return m_description; }
protected:
GemmDescription m_description;
};
/// Library operation that executes a CUTLASS device-level GEMM (`Operator_`)
/// from type-erased GemmArguments.
template <typename Operator_>
class GemmOperation : public GemmOperationBase<Operator_> {
public:
    using Operator = Operator_;
    using ElementA = typename Operator::ElementA;
    using LayoutA = typename Operator::LayoutA;
    using ElementB = typename Operator::ElementB;
    using LayoutB = typename Operator::LayoutB;
    using ElementC = typename Operator::ElementC;
    using LayoutC = typename Operator::LayoutC;
    using ElementAccumulator = typename Operator::ElementAccumulator;
    /// Scalar type alpha and beta are interpreted as (the epilogue's compute type).
    using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
    using OperatorArguments = typename Operator::Arguments;

    GemmOperation(char const* name = "unknown_gemm")
            : GemmOperationBase<Operator_>(name) {}

    /// Runs the GEMM described by `*arguments_ptr`.
    ///
    /// \param arguments_ptr     points to a GemmArguments. Its A/B/C/D pointers
    ///                          are cast to the Operator's element types. Its
    ///                          alpha/beta are dereferenced on the host as
    ///                          ElementCompute.
    /// \param device_workspace  forwarded to Operator::initialize.
    /// \param stream            stream passed to Operator::run.
    /// \return Status::kErrorInvalidProblem if arguments_ptr, alpha or beta is
    ///         null. Otherwise returns the first non-success Status from
    ///         can_implement or initialize, or else the result of run.
    virtual Status run(
            void const* arguments_ptr, void* device_workspace = nullptr,
            cudaStream_t stream = nullptr) const {
        if (arguments_ptr == nullptr) {
            return Status::kErrorInvalidProblem;
        }
        // static_cast is the idiomatic conversion from void const*.
        GemmArguments const* gemm_args =
                static_cast<GemmArguments const*>(arguments_ptr);

        // alpha/beta are dereferenced below; report missing scalars instead
        // of crashing on a null pointer.
        if (gemm_args->alpha == nullptr || gemm_args->beta == nullptr) {
            return Status::kErrorInvalidProblem;
        }

        OperatorArguments args;
        args.problem_size = gemm_args->problem_size;
        // NOTE(review): leading dimensions are converted to int here; confirm
        // callers never pass strides that exceed INT_MAX.
        args.ref_A = {static_cast<ElementA const*>(gemm_args->A), int(gemm_args->lda)};
        args.ref_B = {static_cast<ElementB const*>(gemm_args->B), int(gemm_args->ldb)};
        args.ref_C = {static_cast<ElementC const*>(gemm_args->C), int(gemm_args->ldc)};
        args.ref_D = {static_cast<ElementC*>(gemm_args->D), int(gemm_args->ldd)};
        args.split_k_slices = gemm_args->split_k_slices;
        args.epilogue = {
                *static_cast<ElementCompute const*>(gemm_args->alpha),
                *static_cast<ElementCompute const*>(gemm_args->beta)};

        Operator op;

        // Let the Operator reject problems it cannot handle (e.g. unsupported
        // operand alignment or split-K configuration) before anything is
        // launched, so the failure surfaces as a Status rather than a device
        // fault.
        Status status = op.can_implement(args);
        if (status != Status::kSuccess) {
            return status;
        }

        // NOTE(review): `stream` is not forwarded to initialize(). Not every
        // CUTLASS device operator's initialize() accepts a stream. If an
        // Operator enqueues workspace setup there, it would use the default
        // stream — confirm this is acceptable for the operators registered.
        status = op.initialize(args, device_workspace);
        if (status != Status::kSuccess) {
            return status;
        }
        return op.run(stream);
    }
};
} }