#include "test/fallback/fixture.h"
#include <chrono>
#include <ctime>
#include "test/common/checker.h"
#include "test/common/elemwise.h"
#include "test/common/task_record_check.h"
#include "test/common/tensor.h"
using namespace megdnn;
using namespace test;
// Typed test fixture: it derives from the FALLBACK fixture so typed tests can
// use its handle(). It is instantiated once per type in elemwise::test_types.
// The template parameter only selects the instantiation and is otherwise unused.
template <typename Tag>
struct FALLBACK_ELEMWISE : FALLBACK {};
TYPED_TEST_CASE(FALLBACK_ELEMWISE, elemwise::test_types);
// Run the shared elemwise test routine for the current TypeParam against the
// fallback handle provided by the fixture.
TYPED_TEST(FALLBACK_ELEMWISE, run) {
    auto* const fallback_handle = this->handle();
    elemwise::run_test<TypeParam>(fallback_handle);
}
// Check an Elemwise ADD on Float32 tensors through TaskRecordChecker.
// Operand 0 and operand 1 are explicit {10, 10, 32} shapes. Operand 2 is the
// output, whose shape is left empty ({}) for the checker to fill in.
TEST_F(FALLBACK, ELEMWISE_RECORD) {
    TaskRecordChecker<Elemwise> checker{1};
    checker.set_param({Elemwise::Mode::ADD});

    // Every operand slot (two inputs and the output) uses Float32 and shares
    // one integer-valued RNG in [-100, 100].
    UniformIntRNG rng{-100, 100};
    for (size_t idx = 0; idx < 3; ++idx) {
        checker.set_dtype(idx, dtype::Float32());
        checker.set_rng(idx, &rng);
    }

    checker.execs({{10, 10, 32}, {10, 10, 32}, {}});
}
#if MEGDNN_WITH_BENCHMARK
// Benchmark Elemwise ADD on Float32 for several (possibly broadcast) input
// shape pairs. For each pair, this prints the time and effective bandwidth of
// the fixture's handle (labelled "fallback") and of the handle created with
// create_cpu_handle(2) (labelled "naive").
TEST_F(FALLBACK, BENCHMARK_ELEMWISE) {
    // Presumably debug level 2 selects the naive reference backend, matching
    // the "naive" label printed below. TODO confirm against create_cpu_handle.
    auto naive_handle = create_cpu_handle(2);
    auto run = [&](const TensorShape& shp0, const TensorShape& shp1) {
        TensorShape shpo;
        Elemwise::deduce_shape({shp0, shp1}, shpo);
        Tensor<> op0(handle(), {shp0, dtype::Float32()}),
                op1(handle(), {shp1, dtype::Float32()}),
                out(handle(), {shpo, dtype::Float32()});
        auto opr_cur = handle()->create_operator<Elemwise>();
        auto opr_naive = naive_handle->create_operator<Elemwise>();
        opr_cur->param() = {Elemwise::Mode::ADD};
        opr_naive->param() = {Elemwise::Mode::ADD};
        // Execute once as a warm-up, then return the elapsed wall-clock time
        // of a second execution in milliseconds. steady_clock is used instead
        // of std::clock(): clock() reports CPU time consumed by the process
        // (summed over all threads), not the elapsed time the printed
        // bandwidth figures are meant to reflect.
        auto timeit = [&](Elemwise* opr) {
            opr->exec({op0.tensornd(), op1.tensornd()}, out.tensornd());
            const auto start = std::chrono::steady_clock::now();
            opr->exec({op0.tensornd(), op1.tensornd()}, out.tensornd());
            const auto stop = std::chrono::steady_clock::now();
            return std::chrono::duration<double, std::milli>(stop - start)
                    .count();
        };
        auto t0 = timeit(opr_cur.get()), t1 = timeit(opr_naive.get());
        // Total bytes touched (both inputs plus the output), expressed in GiB
        // and scaled by 1e3, so that dividing by a time in milliseconds
        // yields GiB/s.
        double tot_size_gb_ms =
                (op0.layout().span().dist_byte() + op1.layout().span().dist_byte() +
                 out.layout().span().dist_byte()) /
                1024.0 / 1024.0 / 1024.0 * 1e3;
        printf("%15s+%-15s: fallback=%7.3fms,%5.2fGiB/s "
               "naive=%7.3fms,%5.2fGiB/s\n",
               shp0.to_string().c_str(), shp1.to_string().c_str(), t0,
               tot_size_gb_ms / t0, t1, tot_size_gb_ms / t1);
    };
    run({1024, 1024, 32}, {1024, 1024, 32});
    run({1024, 1024, 32}, {1, 1024, 1});
    run({4096 * 4, 1024}, {4096 * 4, 1});
    run({4096 * 4, 1024}, {1, 1024});
    run({1024, 1024, 32}, {1024, 1, 32});
}
#endif