#pragma once
#include <cuda.h>
#include "./algos.h"
#include "megdnn/basic_types.h"
#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"
#include "src/cuda/utils.h"
// Everything below is compiled only when CUDA_VERSION >= 10010, which is also
// the only configuration in which <cublasLt.h> is included.
#if CUDA_VERSION >= 10010
#include <cublasLt.h>
namespace megdnn {
namespace cuda {

/*!
 * \brief Groups the cublasLt handles and bookkeeping values used for one
 * matrix-multiplication problem: the matmul descriptor, the data types of the
 * operands, the matrix layout handles and three workspace sizes.
 *
 * The constructor null-initializes every cublasLt handle, zeroes the workspace
 * sizes and then calls set(). The out-of-line member functions are only
 * declared here; their exact behavior is defined in the implementation file.
 *
 * NOTE(review): this type declares a destructor but no copy operations. If the
 * destructor releases the cublasLt handles (its body is not visible in this
 * header), copying an instance would be unsafe -- confirm against the
 * implementation before copying or storing instances by value.
 */
struct CUBLASLTMatmulDesc {
    /*!
     * \brief Problem description passed to the descriptor: the CUDA handle, the
     * transpose flags of both inputs and the layouts of A, B and C.
     */
    struct SizeArgs {
        using MMSizeArgs = MatrixMulForwardImpl::AlgoBase::SizeArgs;

        HandleImpl* handle;
        bool transposeA, transposeB;
        TensorLayout layout_a, layout_b, layout_c;

        //! String form of these arguments (defined out of line).
        std::string to_string() const;

        //! Builds the arguments from explicitly supplied values.
        SizeArgs(
                HandleImpl* handle, bool transposeA, bool transposeB,
                const TensorLayout& A, const TensorLayout& B, const TensorLayout& C)
                : handle(handle),
                  transposeA(transposeA),
                  transposeB(transposeB),
                  layout_a(A),
                  layout_b(B),
                  layout_c(C) {}

        //! Builds the arguments from the matrix-mul algorithm SizeArgs: the
        //! handle and the transpose flags are read from args.opr, the layouts
        //! are copied from args.
        explicit SizeArgs(const MMSizeArgs& args)
                : handle(concrete_handle(args.opr->handle())),
                  transposeA(args.opr->param().transposeA),
                  transposeB(args.opr->param().transposeB),
                  layout_a(args.layout_a),
                  layout_b(args.layout_b),
                  layout_c(args.layout_c) {}
    };

    //! Value of the `batched` argument given to the constructor.
    bool is_batched;
    cublasLtMatmulDesc_t matmul_desc;
    cudaDataType_t dt_a, dt_b, dt_c;
    // The compute type is a cublasComputeType_t when CUDA_VERSION >= 11000 and
    // a cudaDataType_t otherwise.
#if CUDA_VERSION >= 11000
    cublasComputeType_t dt_compute;
#else
    cudaDataType_t dt_compute;
#endif
    cublasLtMatrixLayout_t layout_a, layout_b, layout_c;
    cublasLtMatrixLayout_t layout_trans_a, layout_trans_b, layout_trans_c;
    size_t workspace_a, workspace_b, workspace_c;

    /*!
     * \brief Starts from null handles and zero workspace sizes, records
     * \p batched in is_batched and then calls set(args, batched).
     *
     * dt_a, dt_b, dt_c and dt_compute are not initialized here; presumably
     * set() assigns them -- confirm in the implementation.
     */
    CUBLASLTMatmulDesc(const SizeArgs& args, bool batched = false)
            : is_batched(batched),
              matmul_desc(nullptr),
              layout_a(nullptr),
              layout_b(nullptr),
              layout_c(nullptr),
              layout_trans_a(nullptr),
              layout_trans_b(nullptr),
              layout_trans_c(nullptr),
              workspace_a(0),
              workspace_b(0),
              workspace_c(0) {
        set(args, batched);
    }

    ~CUBLASLTMatmulDesc();

    //! Configures this descriptor for \p args (defined out of line).
    void set(const SizeArgs& args, bool batched = false);

    //! Defined out of line; presumably releases/clears the held handles --
    //! TODO confirm.
    void reset();

    //! Defined out of line. Judging by the signature, writes a chosen algorithm
    //! to \p algo given the workspace limit \p ws_limit and returns whether
    //! that succeeded -- TODO confirm.
    bool get_algorithm_heuristic(
            const SizeArgs& args, size_t ws_limit, cublasLtMatmulAlgo_t& algo);

    //! Defined out of line; returns a WorkspaceBundle for \p args and \p algo.
    WorkspaceBundle get_workspace_bundle(
            const SizeArgs& args, const cublasLtMatmulAlgo_t& algo);

    //! Defined out of line; presumably reports whether \p args can be handled
    //! within \p ws_limit -- TODO confirm.
    bool is_available(const SizeArgs& args, size_t ws_limit);
};

}  // namespace cuda
}  // namespace megdnn
#endif  // CUDA_VERSION >= 10010