#include "src/naive/matrix_mul/opr_impl.h"
#include "./matrix_mul_helper.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_matmul)
namespace megdnn {
namespace naive {
template <bool TA, bool TB>
void dispatch_ta_tb(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace, const MatrixMul::Param& param) {
auto M = C.layout.shape[0], N = C.layout.shape[1];
auto K = A.layout.shape[param.transposeA ? 0 : 1];
auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], LDC = C.layout.stride[0];
dispatch_ta_tb<TA, TB>(
A.raw_ptr(), B.raw_ptr(), C.raw_ptr(), workspace.raw_ptr, M, N, K, LDA, LDB,
LDC, A.layout.dtype, B.layout.dtype, C.layout.dtype, param.format,
param.compute_mode);
}
void exec_internal(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace, const MatrixMulForward::Param& param) {
#define DISPATCH(TA, TB) \
if (param.transposeA == TA && param.transposeB == TB) { \
dispatch_ta_tb<TA, TB>(A, B, C, workspace, param); \
return; \
}
DISPATCH(true, true);
DISPATCH(true, false);
DISPATCH(false, true);
DISPATCH(false, false);
#undef DISPATCH
megdnn_assert_internal(0);
}
void MatrixMulForwardImpl::exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) {
MIDOUT_BEGIN(megdnn_naive_matmul, midout_iv("MatrixMulForwardImpl::exec"_hash)) {
check_exec(A.layout, B.layout, C.layout, workspace.size);
auto p = param();
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal(A, B, C, workspace, p));
}
MIDOUT_END();
}
size_t MatrixMulForwardImpl::get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout&) {
MIDOUT_BEGIN(
megdnn_naive_matmul,
midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) {
if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm ||
A.dtype.enumv() == DTypeEnum::QuantizedS4) {
return (A.span().dist_elem() + B.span().dist_elem()) * sizeof(uint8_t);
}
return 0;
}
MIDOUT_END();
}
std::vector<MatrixMulForward::Algorithm*> MatrixMulForwardImpl::get_all_algorithms(
const TensorLayout& , const TensorLayout& ,
const TensorLayout& ) {
return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()};
}
std::vector<MatrixMulForward::Algorithm*> MatrixMulForwardImpl::get_all_algorithms_safe(
const TensorLayout& , const TensorLayout& ,
const TensorLayout& ) {
return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()};
}
MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& , const TensorLayout& , const TensorLayout& ,
size_t , const AlgoAttribute& ,
const AlgoAttribute& ) {
return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo();
}
MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
Algorithm* ret = static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo();
megdnn_assert(desc == ret->info().desc);
return ret;
}
} }