import numpy as np
from megengine.data.transform import *
# Shape of every generated input sample and of its accompanying label array.
data_shape = (100, 100, 3)
label_shape = (4,)

# Expected sample shape after ToMode(mode="CHW"): the last entry of
# data_shape moved to the front.
ToMode_target_shape = data_shape[-1:] + data_shape[:-1]

# Crop sizes handed to the transforms, and the sample shapes expected back
# (the crop size followed by the 3 trailing channels).
CenterCrop_size = (90, 70)
CenterCrop_target_shape = (*CenterCrop_size, 3)
RandomResizedCrop_size = (50, 50)
RandomResizedCrop_target_shape = (*RandomResizedCrop_size, 3)
def generate_data():
    """Build a list of random ``(image, label)`` sample pairs.

    Each image is a ``uint8`` array of shape ``data_shape``; each label is an
    integer array of shape ``label_shape`` drawn with ``np.random.randint(10)``.

    Returns:
        A list of ``(image, label)`` tuples.
    """
    samples = []
    # NOTE: label_shape is unpacked into range(), so the label shape doubles
    # as the number of samples produced.
    for _ in range(*label_shape):
        image = (np.random.rand(*data_shape) * 255).astype(np.uint8)
        label = np.random.randint(10, size=label_shape)
        samples.append((image, label))
    return samples
def test_ToMode():
    """ToMode(mode="CHW") must yield ToMode_target_shape samples and keep label shapes."""
    transform = ToMode(mode="CHW")
    batch = transform.apply_batch(generate_data())
    observed = [(image.shape, label.shape) for image, label in batch]
    expected = 4 * [(ToMode_target_shape, label_shape)]
    assert observed == expected
def test_CenterCrop():
    """CenterCrop must yield CenterCrop_target_shape samples and keep label shapes."""
    transform = CenterCrop(output_size=CenterCrop_size)
    batch = transform.apply_batch(generate_data())
    observed = [(image.shape, label.shape) for image, label in batch]
    expected = 4 * [(CenterCrop_target_shape, label_shape)]
    assert observed == expected
def test_ColorJitter():
    """A default-constructed ColorJitter must leave sample and label shapes unchanged."""
    transform = ColorJitter()
    batch = transform.apply_batch(generate_data())
    observed = [(image.shape, label.shape) for image, label in batch]
    expected = 4 * [(data_shape, label_shape)]
    assert observed == expected
def test_RandomHorizontalFlip():
    """RandomHorizontalFlip(prob=1) must leave sample and label shapes unchanged."""
    transform = RandomHorizontalFlip(prob=1)
    batch = transform.apply_batch(generate_data())
    observed = [(image.shape, label.shape) for image, label in batch]
    expected = 4 * [(data_shape, label_shape)]
    assert observed == expected
def test_RandomVerticalFlip():
    """RandomVerticalFlip(prob=1) must leave sample and label shapes unchanged."""
    transform = RandomVerticalFlip(prob=1)
    batch = transform.apply_batch(generate_data())
    observed = [(image.shape, label.shape) for image, label in batch]
    expected = 4 * [(data_shape, label_shape)]
    assert observed == expected
def test_RandomResizedCrop():
    """RandomResizedCrop must yield RandomResizedCrop_target_shape samples and keep label shapes."""
    transform = RandomResizedCrop(output_size=RandomResizedCrop_size)
    batch = transform.apply_batch(generate_data())
    observed = [(image.shape, label.shape) for image, label in batch]
    expected = 4 * [(RandomResizedCrop_target_shape, label_shape)]
    assert observed == expected
def test_Normalize():
    """A default-constructed Normalize must leave sample and label shapes unchanged."""
    transform = Normalize()
    batch = transform.apply_batch(generate_data())
    observed = [(image.shape, label.shape) for image, label in batch]
    expected = 4 * [(data_shape, label_shape)]
    assert observed == expected
def test_RandomCrop():
    """RandomCrop to (150, 120) must yield (150, 120, 3) samples and keep label shapes."""
    # NOTE(review): the requested crop is larger than data_shape even after
    # padding_size=10 on each side; presumably RandomCrop pads further as
    # needed -- confirm against the transform's documentation.
    transform = RandomCrop((150, 120), padding_size=10, padding_value=[1, 2, 3])
    batch = transform.apply_batch(generate_data())
    observed = [(image.shape, label.shape) for image, label in batch]
    expected = 4 * [((150, 120, 3), label_shape)]
    assert observed == expected
def test_Compose():
    """A Compose of CenterCrop, RandomHorizontalFlip and ToMode must yield (3, 90, 70) samples."""
    # Build the pipeline stages in application order.
    stages = [
        CenterCrop(output_size=CenterCrop_size),
        RandomHorizontalFlip(prob=1),
        ToMode(mode="CHW"),
    ]
    pipeline = Compose(stages)
    batch = pipeline.apply_batch(generate_data())
    observed = [(image.shape, label.shape) for image, label in batch]
    expected = 4 * [((3, 90, 70), label_shape)]
    assert observed == expected, "aug {}, target {}".format(observed, expected)