#include "megdnn/named_tensor.h"
#include "test/common/utils.h"
#include "test/common/fix_gtest_on_platforms_without_exception.inl"
using namespace megdnn;
using megdnn::test::MegDNNError;
TEST(NAMED_TENSOR, NAMED_TENSOR_SHAPE_BASIC) {
ASSERT_EQ(Dimension::NR_NAMES, 10);
Dimension dim0 = {"C"}, dim1 = {"C//32"}, dim2 = {"C//4"}, dim3 = {"C%32"},
dim4 = {"C%4"}, dim5 = {"C//4%8"};
ASSERT_TRUE(dim0 == dim1 * dim3);
ASSERT_TRUE(dim2 == dim1 * dim5);
ASSERT_THROW(dim0 * dim0, MegDNNError);
ASSERT_TRUE(dim1 == dim0 / dim3);
ASSERT_TRUE(dim3 == dim0 / dim1);
ASSERT_TRUE(dim4 == dim3 / dim5);
ASSERT_TRUE(dim5 == dim3 / dim4);
ASSERT_TRUE(dim5 == dim2 / dim1);
ASSERT_THROW(dim5 / dim1, MegDNNError);
ASSERT_TRUE(dim1 < dim4);
ASSERT_TRUE(dim5 < dim4);
ASSERT_FALSE(dim4 < dim5);
ASSERT_TRUE(dim1 < dim2);
ASSERT_FALSE(dim2 < dim1);
auto shape0 =
NamedTensorShape::make_named_tensor_shape(NamedTensorShape::Format::NCHW);
SmallVector<Dimension> dims = {{"N"}, {"C"}, {"H"}, {"W"}};
NamedTensorShape shape1(dims);
NamedTensorShape shape2{{"N"}, {"C"}, {"H"}, {"W"}};
ASSERT_TRUE(shape0.eq_shape(shape1));
ASSERT_TRUE(shape0.eq_shape(shape2));
ASSERT_TRUE(shape1.eq_shape(shape2));
auto shape3 =
NamedTensorShape::make_named_tensor_shape(NamedTensorShape::Format::NCHW4);
ASSERT_FALSE(shape0.eq_shape(shape3));
auto shape4 = NamedTensorShape::make_named_tensor_shape(
NamedTensorShape::Format::NCHW44_DOT);
std::sort(shape4.dims.begin(), shape4.dims.begin() + shape4.ndim);
NamedTensorShape shape5{{"N"}, {"C//32"}, {"H"}, {"W"}, {"C//8%4"}, {"C%8"}};
std::sort(shape5.dims.begin(), shape5.dims.begin() + shape5.ndim);
}