#pragma once
#include <cuda_runtime.h>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wreorder"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#include "cutlass/cutlass.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/epilogue/epilogue.h"
#include "cutlass/gemm/gemm.h"
#pragma GCC diagnostic pop
namespace cutlass {
namespace library {
/// Identifies a tensor or matrix layout by enumerator.
///
/// Enumerators take consecutive values starting at zero, so their order is
/// part of the contract and must not be changed.
enum class LayoutTypeID {
    kUnknown,

    // Plain matrix layouts.
    kColumnMajor,
    kRowMajor,

    // Interleaved matrix layouts; the numeric suffix distinguishes variants.
    kColumnMajorInterleavedK2,
    kRowMajorInterleavedK2,
    kColumnMajorInterleavedK4,
    kRowMajorInterleavedK4,
    kColumnMajorInterleavedK16,
    kRowMajorInterleavedK16,
    kColumnMajorInterleavedK32,
    kRowMajorInterleavedK32,
    kColumnMajorInterleavedK64,
    kRowMajorInterleavedK64,

    // Tensor layouts.
    kTensorNCHW,
    kTensorNCDHW,
    kTensorNHWC,
    kTensorNDHWC,

    // Paired NCxHWx / CxRSKx tensor layouts.
    kTensorNC4HW4,
    kTensorC4RSK4,
    kTensorNC8HW8,
    kTensorC8RSK8,
    kTensorNC16HW16,
    kTensorC16RSK16,
    kTensorNC32HW32,
    kTensorC32RSK32,
    kTensorNC64HW64,
    kTensorC64RSK64,

    // Remaining tensor layouts.
    kTensorK4RSC4,
    kTensorCK4RS4,
    kTensorCK8RS8,
    kTensorCK16RS16,

    /// Sentinel placed after every real layout.
    kInvalid
};
/// Identifies a numeric element type by enumerator.
///
/// Enumerators take consecutive values starting at zero, so their order is
/// part of the contract and must not be changed.
enum class NumericTypeID {
    kUnknown,
    kVoid,
    kB1,

    // U-prefixed types.
    kU2,
    kU4,
    kU8,
    kU16,
    kU32,
    kU64,

    // S-prefixed types.
    kS2,
    kS4,
    kS8,
    kS16,
    kS32,
    kS64,

    // F-family types.
    kF16,
    kBF16,
    kTF32,
    kF32,
    kF64,

    // C-prefixed counterparts of the F-family types.
    kCF16,
    kCBF16,
    kCF32,
    kCTF32,
    kCF64,

    // C-prefixed counterparts of the S-prefixed types.
    kCS2,
    kCS4,
    kCS8,
    kCS16,
    kCS32,
    kCS64,

    // C-prefixed counterparts of the U-prefixed types.
    kCU2,
    kCU4,
    kCU8,
    kCU16,
    kCU32,
    kCU64,

    /// Sentinel placed after every real type.
    kInvalid
};
/// Enumerates the complex transforms; kInvalid is the trailing sentinel.
enum class ComplexTransform {
    kNone,
    kConjugate,
    kInvalid
};
/// Enumerates the providers an operation can be attributed to.
///
/// Values are consecutive from zero; kInvalid is the trailing sentinel.
enum class Provider {
    kNone,
    kCUTLASS,
    kReferenceHost,
    kReferenceDevice,
    kCUBLAS,
    kCUDNN,
    kInvalid
};
/// Enumerates the kinds of operation a description can refer to.
///
/// Values are consecutive from zero; kInvalid is the trailing sentinel.
enum class OperationKind {
    kGemm,
    kConv2d,
    kConv3d,
    kConvolution,
    kEqGemm,
    kSparseGemm,
    kReduction,
    kInvalid
};
/// Enumerates scalar pointer modes (host or device); kInvalid is the trailing
/// sentinel.
enum class ScalarPointerMode {
    kHost,
    kDevice,
    kInvalid
};
/// Enumerates the split-K modes; kInvalid is the trailing sentinel.
enum class SplitKMode {
    kNone,
    kSerial,
    kParallel,
    kParallelSerial,
    kInvalid
};
/// Enumerates the opcode classes; kInvalid is the trailing sentinel.
enum class OpcodeClassID {
    kSimt,
    kTensorOp,
    kWmmaTensorOp,
    kSparseTensorOp,
    kInvalid
};
/// Enumerates architecture tags.
///
/// The enumerator value is a plain index (0, 1, 2, ...), not the numeric SM
/// version embedded in the name. kInvalid is the trailing sentinel.
enum class ArchTagID {
    kSm50,
    kSm60,
    kSm61,
    kSm70,
    kSm72,
    kSm75,
    kSm80,
    kSm86,
    kInvalid
};
/// Enumerates the math operations an instruction can perform.
///
/// Values are consecutive from zero; kInvalid is the trailing sentinel.
enum class MathOperationID {
    kAdd,

    // Multiply-add family.
    kMultiplyAdd,
    kMultiplyAddSaturate,
    kMultiplyAddFastBF16,
    kMultiplyAddFastF16,
    kMultiplyAddComplex,
    kMultiplyAddGaussianComplex,

    kXorPopc,
    kInvalid
};
/// Enumerates the threadblock swizzle variants.
///
/// Values are consecutive from zero; kInvalid is the trailing sentinel.
enum class ThreadblockSwizzleID {
    // GEMM swizzles.
    kGemmIdentity,
    kGemmHorizontal,
    kGemmBatchedIdentity,
    kGemmSplitKIdentity,
    kGemmSplitKHorizontal,

    // Batched-strided GEMV swizzles.
    kGemvBatchedStridedDefault,
    kGemvBatchedStridedReduction,

    // Convolution swizzles.
    kConvolutionFpropCxRSKx,
    kConvolutionDgradCxRSKx,
    kConvolutionFpropNCxHWx,
    kConvolutionFpropTrans,
    kConvolutionDgradNCxHWx,
    kConvolutionDgradTrans,

    // Depthwise convolution swizzles.
    kDepthwiseConvolutionFprop,
    kDepthwiseConvolutionDgrad,
    kDepthwiseConvolutionWgrad,

    kInvalid
};
/// Enumerates the GEMM kinds.
///
/// Values are consecutive from zero; kInvalid is the trailing sentinel.
enum class GemmKind {
    kGemm,
    kSparse,
    kUniversal,
    kPlanarComplex,
    kPlanarComplexArray,
    kInvalid
};
/// Makes cutlass::gemm::GemmUniversalMode available as
/// cutlass::library::GemmUniversalMode.
using GemmUniversalMode = cutlass::gemm::GemmUniversalMode;
/// Enumerates the convolution kinds; kInvalid is the trailing sentinel.
enum class ConvKind {
    kUnknown,
    kFprop,
    kDgrad,
    kWgrad,
    kInvalid
};
/// Enumerates the convolution modes; kInvalid is the trailing sentinel.
enum class ConvModeID {
    kCrossCorrelation,
    kConvolution,
    kInvalid
};
/// Enumerates the iterator algorithms; kInvalid is the trailing sentinel.
enum class IteratorAlgorithmID {
    kNone,
    kAnalytic,
    kOptimized,
    kInvalid
};
/// Enumerates the epilogue kinds.
///
/// Values are consecutive from zero; kInvalid is the trailing sentinel.
///
/// NOTE: four enumerators spell "LInear" with a capital 'I'. The spelling is
/// kept exactly as-is because renaming an enumerator would break every piece
/// of code that refers to it.
enum class EpilogueKind {
    kUnknown,

    // Bias-add variants.
    kBiasAddLinearCombination,
    kBiasAddLinearCombinationClamp,
    kBiasAddLInearCombinationHSwish,
    kBiasAddLInearCombinationHSwishClamp,
    kBiasAddLInearCombinationRelu,
    kBiasAddLInearCombinationReluClamp,

    kConversion,

    // Linear-combination variants without the BiasAdd prefix.
    kLinearCombination,
    kLinearCombinationClamp,
    kLinearCombinationPlanarComplex,
    kLinearCombinationRelu,
    kLinearCombinationSigmoid,

    kInvalid
};
/// Bundles the properties that describe a math instruction: its shape, the
/// accumulator element type, the opcode class and the math operation.
struct MathInstructionDescription {
    /// Instruction shape, stored as a GemmCoord.
    cutlass::gemm::GemmCoord instruction_shape;

    /// Element type of the accumulator.
    NumericTypeID element_accumulator;

    /// Opcode class of the instruction.
    OpcodeClassID opcode_class;

    /// Math operation performed by the instruction.
    MathOperationID math_operation;

    /// Builds a description from its four properties.
    ///
    /// With no arguments the shape is a default-constructed GemmCoord, the
    /// accumulator type and opcode class are kInvalid, and the math operation
    /// is MathOperationID::kMultiplyAdd.
    MathInstructionDescription(
            cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(),
            NumericTypeID element_accumulator = NumericTypeID::kInvalid,
            OpcodeClassID opcode_class = OpcodeClassID::kInvalid,
            MathOperationID math_operation = MathOperationID::kMultiplyAdd)
            : instruction_shape(instruction_shape),
              element_accumulator(element_accumulator),
              opcode_class(opcode_class),
              math_operation(math_operation) {}

    /// Returns true when all four properties compare equal.
    bool operator==(MathInstructionDescription const& rhs) const {
        // Fields are checked in declaration order; the first mismatch decides.
        if (!(instruction_shape == rhs.instruction_shape)) {
            return false;
        }
        if (!(element_accumulator == rhs.element_accumulator)) {
            return false;
        }
        if (!(opcode_class == rhs.opcode_class)) {
            return false;
        }
        return math_operation == rhs.math_operation;
    }

    /// Returns true when at least one property differs.
    bool operator!=(MathInstructionDescription const& rhs) const {
        return !operator==(rhs);
    }
};
/// Describes a tile: threadblock shape and stage count, warp count, the math
/// instruction used, and a minimum/maximum compute capability pair.
struct TileDescription {
    /// Threadblock shape, stored as a GemmCoord.
    cutlass::gemm::GemmCoord threadblock_shape;

    /// Number of threadblock stages.
    int threadblock_stages;

    /// Warp count, stored as a GemmCoord.
    cutlass::gemm::GemmCoord warp_count;

    /// Math instruction associated with the tile.
    MathInstructionDescription math_instruction;

    /// Minimum compute capability (encoding is not defined in this header).
    int minimum_compute_capability;

    /// Maximum compute capability (encoding is not defined in this header).
    int maximum_compute_capability;

    /// Builds a description from its six properties.
    ///
    /// With no arguments both coordinates and the math instruction are
    /// default-constructed and every integer field is zero.
    TileDescription(
            cutlass::gemm::GemmCoord threadblock_shape = cutlass::gemm::GemmCoord(),
            int threadblock_stages = 0,
            cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(),
            MathInstructionDescription math_instruction = MathInstructionDescription(),
            int minimum_compute_capability = 0, int maximum_compute_capability = 0)
            : threadblock_shape(threadblock_shape),
              threadblock_stages(threadblock_stages),
              warp_count(warp_count),
              math_instruction(math_instruction),
              minimum_compute_capability(minimum_compute_capability),
              maximum_compute_capability(maximum_compute_capability) {}

    /// Returns true when all six properties compare equal.
    bool operator==(TileDescription const& rhs) const {
        // Fields are checked in declaration order; the first mismatch decides.
        if (!(threadblock_shape == rhs.threadblock_shape)) {
            return false;
        }
        if (!(threadblock_stages == rhs.threadblock_stages)) {
            return false;
        }
        if (!(warp_count == rhs.warp_count)) {
            return false;
        }
        if (!(math_instruction == rhs.math_instruction)) {
            return false;
        }
        if (!(minimum_compute_capability == rhs.minimum_compute_capability)) {
            return false;
        }
        return maximum_compute_capability == rhs.maximum_compute_capability;
    }

    /// Returns true when at least one property differs.
    bool operator!=(TileDescription const& rhs) const { return !operator==(rhs); }
};
struct OperationDescription {
char const* name;
Provider provider;
OperationKind kind;
TileDescription tile_description;
OperationDescription(
char const* name = "unknown", OperationKind kind = OperationKind::kInvalid,
TileDescription const& tile_description = TileDescription())
: name(name), kind(kind), tile_description(tile_description) {}
};
/// Describes a tensor operand: element type, layout, alignment and two
/// log-range values.
struct TensorDescription {
    /// Element type of the tensor.
    NumericTypeID element;

    /// Layout of the tensor.
    LayoutTypeID layout;

    /// Alignment of the tensor (unit is not stated in this header).
    int alignment;

    /// Log range of the extent — presumably a base-2 exponent; TODO confirm.
    int log_extent_range;

    /// Log range of the stride — presumably a base-2 exponent; TODO confirm.
    int log_stride_range;

    /// Builds a description from its five properties.
    ///
    /// With no arguments the element type and layout are kInvalid, the
    /// alignment is 1 and both log ranges are 24.
    TensorDescription(
            NumericTypeID element = NumericTypeID::kInvalid,
            LayoutTypeID layout = LayoutTypeID::kInvalid,
            int alignment = 1,
            int log_extent_range = 24,
            int log_stride_range = 24)
            : element(element),
              layout(layout),
              alignment(alignment),
              log_extent_range(log_extent_range),
              log_stride_range(log_stride_range) {}
};
/// Description of a GEMM operation; extends OperationDescription with the
/// GEMM kind, three tensor descriptions, a stage count and the split-K mode.
///
/// NOTE(review): no constructor or default member initializer is declared
/// here, so `gemm_kind`, `stages` and `split_k_mode` are not set by default
/// initialization and must be assigned by whoever creates the object.
struct GemmDescription : OperationDescription {
    /// Kind of GEMM.
    GemmKind gemm_kind;

    /// Tensor description of operand A.
    TensorDescription A;

    /// Tensor description of operand B.
    TensorDescription B;

    /// Tensor description of operand C.
    TensorDescription C;

    /// Stage count. NOTE(review): the inherited tile_description also has a
    /// threadblock_stages field — confirm how the two relate.
    int stages;

    /// Split-K mode.
    SplitKMode split_k_mode;
};
/// Argument bundle for a GEMM: problem size, type-erased operand pointers,
/// leading dimensions, split-K slice count and scalar pointers.
///
/// Every pointer is stored as-is; nothing in this struct owns or frees the
/// memory it refers to.
struct GemmArguments {
    /// Problem size as a GemmCoord.
    gemm::GemmCoord problem_size;

    // Read-only operands A, B and C.
    void const* A;
    void const* B;
    void const* C;

    /// Writable operand D.
    void* D;

    // 64-bit leading dimensions, one per operand (A, B, C, D).
    int64_t lda;
    int64_t ldb;
    int64_t ldc;
    int64_t ldd;

    /// Number of split-K slices.
    int split_k_slices;

    // Type-erased scalar pointers. Whether they refer to host or device
    // memory is not recorded in this struct — TODO confirm against callers.
    void const* alpha;
    void const* beta;
};
/// Description of a convolution operation; extends OperationDescription with
/// the convolution operator and type, four tensor descriptions, and the
/// architecture, epilogue, swizzle and optimization settings.
///
/// NOTE(review): no constructor or default member initializer is declared
/// here, so the scalar and enum fields below are not set by default
/// initialization and must be assigned by whoever creates the object.
struct ConvolutionDescription : OperationDescription {
    /// Convolution operator.
    conv::Operator conv_op;

    // Tensor descriptions of the source, filter, destination and bias.
    TensorDescription src;
    TensorDescription filter;
    TensorDescription dst;
    TensorDescription bias;

    /// Convolution type.
    conv::ConvType convolution_type;

    /// Architecture tag.
    ArchTagID arch_tag;

    /// Epilogue type.
    epilogue::EpilogueType epilogue_type;

    /// Epilogue count.
    int epilogue_count;

    /// Threadblock swizzle variant.
    ThreadblockSwizzleID threadblock_swizzle;

    /// Special optimization descriptor.
    conv::SpecialOptimizeDesc special_optimization;

    /// Implicit GEMM mode.
    conv::ImplicitGemmMode gemm_mode;

    /// Flag named "without shared load"; exact effect is defined by the
    /// kernels that read it, not by this header.
    bool without_shared_load;
};
/// Argument bundle for a convolution: problem size, type-erased tensor
/// pointers and a set of type-erased scalar/parameter pointers.
///
/// Every pointer is stored as-is; nothing in this struct owns or frees the
/// memory it refers to.
struct ConvolutionArguments {
    /// Problem size as a Conv2dProblemSize.
    conv::Conv2dProblemSize problem_size;

    // Read-only tensors.
    void const* src;
    void const* filter;
    void const* bias;
    void const* z;

    /// Writable destination tensor.
    void* dst;

    // Read-only epilogue parameters. Which of them an operation actually
    // reads presumably depends on its epilogue type — TODO confirm.
    void const* alpha;
    void const* beta;
    void const* gamma;
    void const* delta;
    void const* theta;
    void const* threshold;
    void const* scale;

    /// Additional opaque parameter; its meaning is not defined in this header.
    void const* extra_param;
};
/// Abstract interface of an operation: exposes its description and a `run`
/// entry point.
class Operation {
public:
    /// Virtual so that derived operations are destroyed correctly through a
    /// pointer to this base class.
    virtual ~Operation() = default;

    /// Returns the description of this operation.
    virtual OperationDescription const& description() const = 0;

    /// Runs the operation and returns a Status.
    ///
    /// \param arguments         type-erased argument pointer; presumably
    ///                          points at the argument struct matching the
    ///                          operation (e.g. GemmArguments or
    ///                          ConvolutionArguments) — TODO confirm against
    ///                          the implementations.
    /// \param device_workspace  optional workspace pointer, null by default.
    /// \param stream            CUDA stream, null by default.
    virtual Status run(
            void const* arguments,
            void* device_workspace = nullptr,
            cudaStream_t stream = nullptr) const = 0;
};
} }