megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
import os
import sys

import numpy as np
import pytest

from megengine.data.dataset import ArrayDataset
from megengine.data.sampler import RandomSampler, ReplacementSampler, SequentialSampler


def test_sequential_sampler():
    """SequentialSampler must yield the dataset indices in their original order.

    The first element of every item yielded by the sampler is collected and
    compared against the list the dataset was built from.
    """
    expected = list(range(100))
    sampler = SequentialSampler(ArrayDataset(expected))
    visited = [batch[0] for batch in sampler]
    assert expected == visited


def test_RandomSampler():
    """RandomSampler must yield a shuffled permutation of the dataset indices.

    The sampler is iterated exactly once and both properties are checked on
    that single pass: the order differs from the sequential order, and every
    index appears exactly once.
    """
    indices = list(range(20))
    # The dataset gets its own copy so ``indices`` stays a pristine reference.
    sampler = RandomSampler(ArrayDataset(copy.deepcopy(indices)))

    # Materialize one pass up front. Re-iterating the sampler inside each
    # assertion would check "shuffled" and "complete" on two different
    # passes, which need not be the same ordering.
    sampled = [each[0] for each in sampler]

    assert sampled != indices
    assert sorted(sampled) == indices


def test_random_sampler_seed():
    """Check that the ordering produced by RandomSampler is governed by ``seed``.

    Two samplers built with the same seed must produce the same order, a
    sampler built with a different seed must produce a different order, and
    every order must be a shuffled permutation of the dataset indices.
    """
    indices = list(range(20))

    def first_pass(seed):
        # Iterate a fresh sampler exactly once and keep the result, so every
        # assertion below inspects the same pass. Re-iterating the samplers
        # inside each assertion would end up comparing different passes
        # (e.g. a fourth pass of one sampler against a third pass of
        # another), which says nothing about the seed if a pass advances the
        # sampler's random state.
        sampler = RandomSampler(ArrayDataset(copy.deepcopy(indices)), seed=seed)
        return [each[0] for each in sampler]

    order_seed0_a = first_pass(0)
    order_seed0_b = first_pass(0)
    order_seed1 = first_pass(1)

    for order in (order_seed0_a, order_seed0_b, order_seed1):
        assert order != indices
        assert sorted(order) == indices

    # Same seed -> identical order; different seed -> different order.
    assert order_seed0_a == order_seed0_b
    assert order_seed0_a != order_seed1


def test_ReplacementSampler():
    """ReplacementSampler must yield exactly ``num_samples`` items.

    The sampler is built over 20 indices with per-index weights 0..19 and
    asked for 30 samples, i.e. more samples than the dataset has entries.
    """
    num_samples = 30
    population = list(range(20))
    weights = list(range(20))
    dataset = ArrayDataset(population)
    sampler = ReplacementSampler(dataset, num_samples=num_samples, weights=weights)
    drawn = [item[0] for item in sampler]
    assert len(drawn) == num_samples


def test_sampler_drop_last_false():
    """``len(sampler)`` must equal the number of items yielded with drop_last=False.

    Uses 24 samples with a batch size of 5, so the sample count is not a
    multiple of the batch size.
    """
    dataset = ArrayDataset(list(range(24)))
    sampler = SequentialSampler(dataset, batch_size=5, drop_last=False)
    # Count the yielded items first, then compare with the reported length.
    yielded = sum(1 for _ in sampler)
    assert yielded == len(sampler)


def test_sampler_drop_last_true():
    """``len(sampler)`` must equal the number of items yielded with drop_last=True.

    Uses 24 samples with a batch size of 5, so the sample count is not a
    multiple of the batch size.
    """
    dataset = ArrayDataset(list(range(24)))
    sampler = SequentialSampler(dataset, batch_size=5, drop_last=True)
    # Count the yielded items first, then compare with the reported length.
    yielded = sum(1 for _ in sampler)
    assert yielded == len(sampler)