#include "src/cuda/cutlass/singleton.h"
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 10010
using namespace megdnn;
using namespace cuda;
bool MatrixMulForwardImpl::AlgoFloat16TensorOp::is_available(
const SizeArgs& args) const {
bool available = args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
args.layout_b.dtype == dtype::Float16() &&
args.layout_c.dtype == dtype::Float16();
int n = args.layout_c.shape[1];
auto&& device_prop = cuda::current_device_prop();
int y_grid_limit = device_prop.maxGridSize[1];
available &=
((n + m_algo_param.threadblock_n - 1) / m_algo_param.threadblock_n <=
y_grid_limit);
if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 &&
m_algo_param.instruction_k == 4) {
available &= is_compute_capability_required(7, 0);
} else {
megdnn_assert(
m_algo_param.instruction_m == 16 && m_algo_param.instruction_n == 8 &&
m_algo_param.instruction_k == 8);
available &= is_compute_capability_required(7, 5);
}
return available;
}
// Computes the workspace size, in bytes, that this algorithm requests for
// `args`.
//
// construct_aligned_layouts() yields a (flag, layouts) pair. When the flag is
// false the result is 0; otherwise the result is the sum of
// span().dist_byte() over every layout in the list.
// NOTE(review): presumably the flag means "aligned copies of the operands are
// required" -- confirm against construct_aligned_layouts().
size_t MatrixMulForwardImpl::AlgoFloat16TensorOp::get_workspace_in_bytes(
        const SizeArgs& args) const {
    auto padded = construct_aligned_layouts(args);
    size_t total_bytes = 0;
    if (padded.first) {
        for (const auto& layout : padded.second) {
            total_bytes += layout.span().dist_byte();
        }
    }
    return total_bytes;
}
// Executes the fp16 GEMM for `args`: validates alignment, builds a GemmKey
// describing the wanted kernel, looks that key up in the cutlass operation
// table held by Singleton, and runs the single matching operation on the
// operator's CUDA stream.
void MatrixMulForwardImpl::AlgoFloat16TensorOp::do_exec(const ExecArgs& args) const {
// Leading strides are taken from stride[0] of each tensor's layout.
int64_t lda = args.tensor_a.layout.stride[0], ldb = args.tensor_b.layout.stride[0],
ldc = args.tensor_c.layout.stride[0];
int alignment = max_alignment(args);
int min_alignment = min_alignment_requirement();
auto&& param = args.opr->param();
// m and n come from C's shape; k comes from A's shape, using dimension 0
// when A is transposed and dimension 1 otherwise.
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
// Every stride and every problem dimension must be a multiple of
// `alignment`, and `alignment` must be at least the algorithm's minimum.
megdnn_assert(
lda % alignment == 0 && ldb % alignment == 0 && ldc % alignment == 0 &&
m % alignment == 0 && n % alignment == 0 && k % alignment == 0 &&
alignment >= min_alignment);
cutlass::gemm::GemmCoord problem_size{m, n, k};
auto&& stream = cuda_stream(args.opr->handle());
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
// Host-side 1 and 0 are kept in both float and fp16 form; the branch below
// picks the pair whose type matches element_accumulator and passes it on as
// untyped pointers.
// NOTE(review): presumably the kernel reinterprets these pointers as its own
// scalar type (epilogue alpha/beta), so a type mismatch would be read as
// garbage -- confirm against the cutlass library's GemmArguments.
float one = 1.f, zero = 0.f;
dt_float16 one_f16 = static_cast<dt_float16>(one),
zero_f16 = static_cast<dt_float16>(zero);
using namespace cutlass::library;
// A transposed operand is described as column-major, otherwise row-major.
auto layoutA =
param.transposeA ? LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor;
auto layoutB =
param.transposeB ? LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor;
void *host_one, *host_zero;
NumericTypeID element_accumulator;
// DEFAULT compute mode accumulates in fp16; the only other accepted mode is
// FLOAT32, which accumulates in fp32.
if (param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT) {
element_accumulator = NumericTypeID::kF16;
host_one = &one_f16;
host_zero = &zero_f16;
} else {
megdnn_assert(param.compute_mode == param::MatrixMul::ComputeMode::FLOAT32);
element_accumulator = NumericTypeID::kF32;
host_one = &one;
host_zero = &zero;
}
// GemmKey is initialised positionally, so the order of the entries below
// must match the field order of GemmKey (declared outside this file).
GemmKey key{
// A, B and C are all fp16; C is always row-major.
NumericTypeID::kF16,
layoutA,
NumericTypeID::kF16,
layoutB,
NumericTypeID::kF16,
LayoutTypeID::kRowMajor,
element_accumulator,
// Threadblock, warp and instruction tile shapes of this algorithm.
m_algo_param.threadblock_m,
m_algo_param.threadblock_n,
m_algo_param.threadblock_k,
m_algo_param.warp_m,
m_algo_param.warp_n,
m_algo_param.warp_k,
m_algo_param.instruction_m,
m_algo_param.instruction_n,
m_algo_param.instruction_k,
// NOTE(review): presumably the pipeline stage count -- TODO confirm.
2,
// NOTE(review): presumably the alignments of A and B -- TODO confirm.
alignment,
alignment,
SplitKMode::kNone};
// The key must be registered and must map to exactly one operation.
const auto& table = Singleton::get().operation_table;
megdnn_assert(
table.gemm_operations.count(key) > 0,
"key not found in cutlass operation table");
const auto& ops = table.gemm_operations.at(key);
megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", ops.size());
// GemmArguments is also initialised positionally. The C pointer and ldc are
// each passed twice.
// NOTE(review): presumably the source and destination matrices alias the
// same buffer -- confirm against GemmArguments.
GemmArguments gemm_args{
problem_size,
args.tensor_a.raw_ptr(),
args.tensor_b.raw_ptr(),
args.tensor_c.raw_ptr(),
args.tensor_c.raw_ptr(),
lda,
ldb,
ldc,
ldc,
// NOTE(review): meaning of this literal is not visible here (possibly the
// split-k slice count, given SplitKMode::kNone above) -- TODO confirm.
1,
host_one,
host_zero};
// Run the selected operation on `stream`; cutlass_check handles the status
// returned by run().
cutlass_check(ops[0]->run(&gemm_args, workspace, stream));
}
#endif