#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 9020
using namespace megdnn;
using namespace cuda;
// Formats the six tile extents as
// "<threadblock_m>X<threadblock_n>X<threadblock_k>_<warp_m>X<warp_n>X<warp_k>".
std::string MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::AlgoParam::to_string()
        const {
    std::string tile_desc = ssprintf(
            "%dX%dX%d_%dX%dX%d",
            threadblock_m, threadblock_n, threadblock_k,
            warp_m, warp_n, warp_k);
    return tile_desc;
}
std::pair<bool, TensorLayoutArray> MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::
construct_aligned_layouts(const SizeArgs& args) const {
int alignment = max_alignment(args);
int min_alignment = min_alignment_requirement();
bool aligned = alignment >= min_alignment;
if (aligned)
return std::make_pair(!aligned, TensorLayoutArray{{}});
auto&& param = args.opr->param();
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
size_t align_m = get_aligned_power2(m, min_alignment);
size_t align_n = get_aligned_power2(n, min_alignment);
size_t align_k = get_aligned_power2(k, min_alignment);
TensorLayoutArray layouts;
layouts.emplace_back(TensorLayout{{align_m, align_k}, args.layout_a.dtype});
layouts.emplace_back(TensorLayout{{align_k, align_n}, args.layout_b.dtype});
layouts.emplace_back(TensorLayout{{align_m, align_n}, args.layout_c.dtype});
return std::make_pair(!aligned, std::move(layouts));
}
void MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::exec(const ExecArgs& args) const {
auto aligned = construct_aligned_layouts(args);
if (!aligned.first)
return do_exec(args);
const auto& layouts = aligned.second;
size_t copy_size = 0;
for (const auto& ly : layouts)
copy_size += ly.span().dist_byte();
auto&& param = args.opr->param();
auto&& stream = cuda_stream(args.opr->handle());
auto workspace = args.workspace;
cuda_check(cudaMemsetAsync(workspace.raw_ptr, 0, copy_size, stream));
auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>();
auto copy_stride = [](const TensorLayout& src, TensorLayout& dst, bool trans) {
dst.stride[0] = src.stride[0], dst.stride[1] = src.stride[1];
if (trans)
std::swap(dst.stride[0], dst.stride[1]);
};
TensorND tensor_a{workspace.raw_ptr, args.tensor_a.layout};
copy_stride(layouts[0], tensor_a.layout, param.transposeA);
relayout->exec(args.tensor_a, tensor_a);
workspace.raw_ptr += layouts[0].span().dist_byte();
workspace.size -= layouts[0].span().dist_byte();
TensorND tensor_b{workspace.raw_ptr, args.tensor_b.layout};
copy_stride(layouts[1], tensor_b.layout, param.transposeB);
relayout->exec(args.tensor_b, tensor_b);
workspace.raw_ptr += layouts[1].span().dist_byte();
workspace.size -= layouts[1].span().dist_byte();
decltype(tensor_a) tensor_c{workspace.raw_ptr, layouts[2]};
workspace.raw_ptr += layouts[2].span().dist_byte();
workspace.size -= layouts[2].span().dist_byte();
auto&& matmul = args.opr->handle()->create_operator<MatrixMulForward>();
matmul->param().transposeA = false;
matmul->param().transposeB = false;
matmul->param().compute_mode = args.opr->param().compute_mode;
tensor_a.layout = layouts[0];
tensor_b.layout = layouts[1];
ExecArgs args_{
static_cast<MatrixMulForwardImpl*>(matmul.get()), tensor_a, tensor_b,
tensor_c, workspace};
do_exec(args_);
tensor_c.layout.TensorShape::operator=(args.layout_c);
relayout->exec(tensor_c, args.tensor_c);
}
// Returns the largest alignment, counted in elements, that is satisfied by
// all of: the leading strides of A, B and C, and the problem extents m, n, k.
//
// For one extent the alignment is found by taking the widest power-of-two bit
// width, starting at 128 and halving down to 1, that divides
// (extent * element size in bits), and dividing that width by the element
// size in bits. The result is the minimum over all tested extents.
int MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::max_alignment(
        const SizeArgs& args) const {
    auto&& dtype_a = args.layout_a.dtype;
    auto&& dtype_b = args.layout_b.dtype;
    auto&& dtype_c = args.layout_c.dtype;

    auto get_alignment = [](const DType& dt, int len) {
        int size_bits = dt.size(1) * 8;
        // Form the product in 64 bits so a very large extent cannot overflow
        // a signed int.
        const long long total_bits = static_cast<long long>(len) * size_bits;
        int align = 128;
        while (align > 1 && total_bits % align != 0)
            align /= 2;
        return align / size_bits;
    };

    const int lda = args.layout_a.stride[0];
    const int ldb = args.layout_b.stride[0];
    const int ldc = args.layout_c.stride[0];

    auto&& param = args.opr->param();
    const int m = args.layout_c.shape[0];
    const int n = args.layout_c.shape[1];
    const int k = args.layout_a.shape[param.transposeA ? 0 : 1];

    // NOTE(review): m, n and k are all measured with A's dtype, although n
    // and m are also extents of B and C; confirm this is intended when the
    // operand dtypes differ in size.
    const int candidates[] = {
            get_alignment(dtype_a, lda), get_alignment(dtype_a, m),
            get_alignment(dtype_a, n),   get_alignment(dtype_a, k),
            get_alignment(dtype_b, ldb), get_alignment(dtype_c, ldc)};

    int max_align = candidates[0];
    for (int candidate : candidates)
        max_align = std::min(candidate, max_align);
    return max_align;
}
#endif