#include "src/cuda/cutlass/operation_table.h"
#include "src/common/utils.h"
namespace cutlass {
namespace library {
/// Builds a GemmKey from the fields of a GemmDescription.
///
/// Element types, layouts and alignments of the operands, the accumulator
/// type, the threadblock and instruction shapes, the stage count and the
/// split-k mode are copied from \p desc. The warp shape is not stored in the
/// description directly; it is computed per dimension as the threadblock
/// shape divided by the warp count.
///
/// \param desc description of a GEMM operation
/// \return key populated from \p desc
GemmKey get_gemm_key_from_desc(const GemmDescription& desc) {
    // Short aliases for the nested members that are read repeatedly below.
    const auto& tile = desc.tile_description;
    const auto& math = tile.math_instruction;
    const auto& tb_shape = tile.threadblock_shape;
    const auto& warps = tile.warp_count;
    const auto& inst_shape = math.instruction_shape;

    GemmKey key;

    // Operand data types and layouts.
    key.element_A = desc.A.element;
    key.layout_A = desc.A.layout;
    key.element_B = desc.B.element;
    key.layout_B = desc.B.layout;
    key.element_C = desc.C.element;
    key.layout_C = desc.C.layout;
    key.element_accumulator = math.element_accumulator;

    // Threadblock tile.
    key.threadblock_shape_m = tb_shape.m();
    key.threadblock_shape_n = tb_shape.n();
    key.threadblock_shape_k = tb_shape.k();

    // Warp tile: threadblock tile split across the warps in each dimension.
    key.warp_shape_m = tb_shape.m() / warps.m();
    key.warp_shape_n = tb_shape.n() / warps.n();
    key.warp_shape_k = tb_shape.k() / warps.k();

    // Math instruction tile.
    key.instruction_shape_m = inst_shape.m();
    key.instruction_shape_n = inst_shape.n();
    key.instruction_shape_k = inst_shape.k();

    // Remaining scalar attributes.
    key.stages = desc.stages;
    key.alignment_A = desc.A.alignment;
    key.alignment_B = desc.B.alignment;
    key.split_k_mode = desc.split_k_mode;

    return key;
}
/// Builds a ConvolutionKey from the fields of a ConvolutionDescription.
///
/// The convolution operator and type, the element types and layouts of the
/// src/filter/dst/bias tensors, the accumulator type, the threadblock and
/// instruction shapes, the epilogue type, the stage count (taken from
/// tile_description.threadblock_stages), the special optimization setting,
/// the src/filter alignments and the without_shared_load setting are copied
/// from \p desc. The warp shape is computed per dimension as the threadblock
/// shape divided by the warp count.
///
/// \param desc description of a convolution operation
/// \return key populated from \p desc
ConvolutionKey get_convolution_key_from_desc(const ConvolutionDescription& desc) {
    // Short aliases for the nested members that are read repeatedly below.
    const auto& tile = desc.tile_description;
    const auto& math = tile.math_instruction;
    const auto& tb_shape = tile.threadblock_shape;
    const auto& warps = tile.warp_count;
    const auto& inst_shape = math.instruction_shape;

    ConvolutionKey key;

    key.conv_op = desc.conv_op;

    // Tensor data types and layouts.
    key.element_src = desc.src.element;
    key.layout_src = desc.src.layout;
    key.element_filter = desc.filter.element;
    key.layout_filter = desc.filter.layout;
    key.element_dst = desc.dst.element;
    key.layout_dst = desc.dst.layout;
    key.element_bias = desc.bias.element;
    key.layout_bias = desc.bias.layout;
    key.element_accumulator = math.element_accumulator;

    key.convolution_type = desc.convolution_type;

    // Threadblock tile.
    key.threadblock_shape_m = tb_shape.m();
    key.threadblock_shape_n = tb_shape.n();
    key.threadblock_shape_k = tb_shape.k();

    // Warp tile: threadblock tile split across the warps in each dimension.
    key.warp_shape_m = tb_shape.m() / warps.m();
    key.warp_shape_n = tb_shape.n() / warps.n();
    key.warp_shape_k = tb_shape.k() / warps.k();

    // Math instruction tile.
    key.instruction_shape_m = inst_shape.m();
    key.instruction_shape_n = inst_shape.n();
    key.instruction_shape_k = inst_shape.k();

    // Remaining attributes.
    key.epilogue_type = desc.epilogue_type;
    key.stages = tile.threadblock_stages;
    key.special_optimization = desc.special_optimization;
    key.alignment_src = desc.src.alignment;
    key.alignment_filter = desc.filter.alignment;
    key.without_shared_load = desc.without_shared_load;

    return key;
}
/// Registers every operation of \p manifest in the lookup tables.
///
/// Each operation's description is inspected: GEMM operations are appended
/// to gemm_operations and convolution operations to convolution_operations,
/// under the key derived from their description. Operations of any other
/// kind are ignored. The tables store the raw pointer obtained from
/// operation.get().
void OperationTable::append(Manifest const& manifest) {
    for (auto const& operation : manifest) {
        OperationDescription const& desc = operation->description();

        switch (desc.kind) {
            case OperationKind::kGemm: {
                auto const& gemm_desc = static_cast<GemmDescription const&>(desc);
                gemm_operations[get_gemm_key_from_desc(gemm_desc)].push_back(
                        operation.get());
                break;
            }
            case OperationKind::kConvolution: {
                auto const& conv_desc =
                        static_cast<ConvolutionDescription const&>(desc);
                convolution_operations[get_convolution_key_from_desc(conv_desc)]
                        .push_back(operation.get());
                break;
            }
            default:
                // Other operation kinds are not indexed by this table.
                break;
        }
    }
}
/// Looks up the GEMM operation registered under \p key.
///
/// \param key key identifying the requested GEMM kernel
/// \return the operation stored for \p key, or nullptr when the table has no
///         entry for it. megdnn_assert checks that exactly one operation is
///         registered under the key.
Operation const* OperationTable::find_op(GemmKey const& key) const {
    // A single find() replaces the previous count() + at() pair, so the key
    // is hashed and searched only once.
    auto const it = gemm_operations.find(key);
    if (it == gemm_operations.end()) {
        return nullptr;
    }

    auto const& ops = it->second;
    megdnn_assert(
            ops.size() == 1, "exactly one kernel expected, got %zu", ops.size());
    return ops[0];
}
/// Looks up the convolution operation registered under \p key.
///
/// \param key key identifying the requested convolution kernel
/// \return the operation stored for \p key, or nullptr when the table has no
///         entry for it. megdnn_assert checks that exactly one operation is
///         registered under the key.
Operation const* OperationTable::find_op(ConvolutionKey const& key) const {
    // A single find() replaces the previous count() + at() pair, so the key
    // is hashed and searched only once.
    auto const it = convolution_operations.find(key);
    if (it == convolution_operations.end()) {
        return nullptr;
    }

    auto const& ops = it->second;
    megdnn_assert(
            ops.size() == 1, "exactly one kernel expected, got %zu", ops.size());
    return ops[0];
}
} }