#include "test/common/matrix_mul.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
#include "test/common/task_record_check.h"
#include "test/fallback/fixture.h"
namespace megdnn {
namespace test {
// Runs MatrixMul through Checker on the fixture's handle(): for every case
// returned by matrix_mul::get_matmul_args(), builds Float32 layouts for A, B
// and C according to the case's transpose flags and passes them to execl().
TEST_F(FALLBACK, MATRIX_MUL) {
    using Param = MatrixMul::Param;
    Checker<MatrixMul> checker(handle());

    // Every operand in this test uses Float32.
    const auto f32_layout = [](const TensorShape& shape) {
        return TensorLayout(shape, dtype::Float32());
    };

    auto cases = matrix_mul::get_matmul_args();
    for (const auto& tc : cases) {
        const auto rows = tc.m, cols = tc.n, inner = tc.k;

        // Bit 0 of the mask drives transposeA, bit 1 drives transposeB.
        Param param;
        param.transposeA = tc.mask & 1;
        param.transposeB = tc.mask & 2;

        // A transposed operand is described with its two dimensions swapped.
        const TensorShape a_shape = param.transposeA ? TensorShape{inner, rows}
                                                     : TensorShape{rows, inner};
        const TensorShape b_shape = param.transposeB ? TensorShape{cols, inner}
                                                     : TensorShape{inner, cols};
        const TensorShape c_shape{rows, cols};

        checker.set_param(param);
        checker.execl({f32_layout(a_shape), f32_layout(b_shape), f32_layout(c_shape)});
    }
}
// Same case sweep as the plain MatrixMul test, but driven through
// TaskRecordChecker instead of Checker.
TEST_F(FALLBACK, MATRIX_MUL_RECORD) {
    using Param = MatrixMul::Param;
    // NOTE(review): the meaning of the constructor argument `1` is defined by
    // TaskRecordChecker and is not visible from this file.
    TaskRecordChecker<MatrixMul> checker(1);

    auto cases = matrix_mul::get_matmul_args();
    for (const auto& tc : cases) {
        const auto rows = tc.m, cols = tc.n, inner = tc.k;
        const auto flags = tc.mask;

        // Bit 0 of the mask drives transposeA, bit 1 drives transposeB.
        Param param;
        param.transposeA = flags & 1;
        param.transposeB = flags & 2;

        // A transposed operand is described with its two dimensions swapped.
        const TensorShape a_shape = param.transposeA ? TensorShape{inner, rows}
                                                     : TensorShape{rows, inner};
        const TensorShape b_shape = param.transposeB ? TensorShape{cols, inner}
                                                     : TensorShape{inner, cols};
        const TensorShape c_shape{rows, cols};

        // All three operands are Float32.
        const TensorLayout a_layout(a_shape, dtype::Float32());
        const TensorLayout b_layout(b_shape, dtype::Float32());
        const TensorLayout c_layout(c_shape, dtype::Float32());

        checker.set_param(param);
        checker.execl({a_layout, b_layout, c_layout});
    }
}
// Delegates to matrix_mul::check_matrix_mul with Float32 for all three dtype
// arguments, the "FB_NAIVE" string and the MK4 format.
TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4) {
    using Format = param::MatrixMul::Format;
    const auto f32 = dtype::Float32{};
    // NOTE(review): the trailing `1` is interpreted by check_matrix_mul; its
    // meaning is not visible from this file.
    matrix_mul::check_matrix_mul(f32, f32, f32, handle(), "FB_NAIVE", Format::MK4, 1);
}
// Delegates to matrix_mul::check_matrix_mul with Float32 for all three dtype
// arguments, the "FB_NAIVE" string and the MK8 format.
TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK8) {
    using Format = param::MatrixMul::Format;
    const auto f32 = dtype::Float32{};
    // NOTE(review): the trailing `1` is interpreted by check_matrix_mul; its
    // meaning is not visible from this file.
    matrix_mul::check_matrix_mul(f32, f32, f32, handle(), "FB_NAIVE", Format::MK8, 1);
}
// Delegates to matrix_mul::check_matrix_mul with Float32 for all three dtype
// arguments, the "FB_NAIVE" string and the MK4_DOT format.
TEST_F(FALLBACK, MATRIX_MUL_NAIVE_MK4_DOT) {
    using Format = param::MatrixMul::Format;
    const auto f32 = dtype::Float32{};
    // NOTE(review): the trailing `1` is interpreted by check_matrix_mul; its
    // meaning is not visible from this file.
    matrix_mul::check_matrix_mul(
            f32, f32, f32, handle(), "FB_NAIVE", Format::MK4_DOT, 1);
}
// Sweeps the cases from matrix_mul::get_matmul_args() through Checker with an
// AlgoChecker<MatrixMul>("FB_NAIVE") installed as the before-exec callback
// (presumably asserting that the "FB_NAIVE" algorithm is the one selected;
// confirm against AlgoChecker).
TEST_F(FALLBACK, MATRIX_MUL_NAIVE) {
    using Param = MatrixMul::Param;
    Checker<MatrixMul> checker(handle());
    checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_NAIVE"));

    // Every operand in this test uses Float32.
    const auto f32_layout = [](const TensorShape& shape) {
        return TensorLayout(shape, dtype::Float32());
    };

    auto cases = matrix_mul::get_matmul_args();
    for (const auto& tc : cases) {
        const auto rows = tc.m, cols = tc.n, inner = tc.k;

        // Bit 0 of the mask drives transposeA, bit 1 drives transposeB.
        Param param;
        param.transposeA = tc.mask & 1;
        param.transposeB = tc.mask & 2;

        // A transposed operand is described with its two dimensions swapped.
        const TensorShape a_shape = param.transposeA ? TensorShape{inner, rows}
                                                     : TensorShape{rows, inner};
        const TensorShape b_shape = param.transposeB ? TensorShape{cols, inner}
                                                     : TensorShape{inner, cols};
        const TensorShape c_shape{rows, cols};

        checker.set_param(param);
        checker.execl({f32_layout(a_shape), f32_layout(b_shape), f32_layout(c_shape)});
    }
}
// Runs BatchedMatrixMul through Checker on the fixture's handle(): for every
// case returned by matrix_mul::get_batched_matmul_args(), builds 3-D Float32
// layouts for A and B (batch dimension first) according to the case's
// transpose flags and passes them to execs() with an empty third entry.
TEST_F(FALLBACK, BATCHED_MATRIX_MUL) {
    using Param = MatrixMul::Param;
    Checker<BatchedMatrixMul> checker(handle());

    auto cases = matrix_mul::get_batched_matmul_args();
    // Iterate by const reference: the cases are only read, so no per-iteration
    // copy is needed.
    for (const auto& tc : cases) {
        const auto batch = tc.b, rows = tc.m, cols = tc.n, inner = tc.k;

        // Bit 0 of the mask drives transposeA, bit 1 drives transposeB.
        Param param;
        param.transposeA = tc.mask & 1;
        param.transposeB = tc.mask & 2;

        // A transposed operand has its last two dimensions swapped; the batch
        // dimension always stays in front.
        const TensorShape a_shape = param.transposeA
                                          ? TensorShape{batch, inner, rows}
                                          : TensorShape{batch, rows, inner};
        const TensorShape b_shape = param.transposeB
                                          ? TensorShape{batch, cols, inner}
                                          : TensorShape{batch, inner, cols};

        const TensorLayout a_layout(a_shape, dtype::Float32());
        const TensorLayout b_layout(b_shape, dtype::Float32());

        checker.set_param(param);
        // The third (output) entry is left empty; presumably the checker
        // deduces it from the inputs -- confirm against Checker::execs.
        checker.execs({a_layout, b_layout, {}});
    }
}
} }