import collections.abc
import math
from abc import ABC, abstractmethod
from typing import Any, Generator, Iterator, List, Union
import numpy as np
from .. import distributed as dist
class Sampler(ABC):
    """Abstract root of every sampler defined in this module.

    The constructor is declared abstract so that ``Sampler`` itself cannot be
    instantiated; each concrete subclass must supply its own ``__init__``.
    """

    @abstractmethod
    def __init__(self):
        """Initialise the sampler; concrete subclasses must override this."""
class MapSampler(Sampler):
    r"""Sampler for map-style datasets that yields batches of indices.

    Subclasses implement :meth:`sample` to produce one epoch's index order.
    :meth:`batch` then, when ``world_size > 1``, keeps only this rank's share
    via :meth:`scatter` and chunks the result into lists of ``batch_size``.

    Args:
        dataset: dataset to sample from; ``len(dataset)`` is the default
            ``num_samples``.
        batch_size: number of indices per batch; must be a positive int.
            Default: 1
        drop_last: if True, drop a trailing batch shorter than
            ``batch_size``. Default: False
        num_samples: total number of indices across all ranks; must be a
            positive int when given. Default: ``len(dataset)``
        world_size: number of ranks. Default: ``dist.get_world_size()`` when
            distributed, otherwise 1.
        rank: rank of this process. Default: ``dist.get_rank()`` when
            distributed, otherwise 0.
        seed: seed for ``self.rng``. When None and ``world_size > 1`` it is
            forced to 0 so every rank seeds its RNG identically.

    Raises:
        ValueError: if ``batch_size``, ``drop_last`` or ``num_samples`` is
            invalid.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        num_samples=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        # bool is a subclass of int, so it is rejected explicitly.
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                "drop_last should be a boolean value, but got "
                "drop_last={}".format(drop_last)
            )
        if num_samples is not None and (
            not isinstance(num_samples, int)
            or isinstance(num_samples, bool)
            or num_samples <= 0
        ):
            raise ValueError(
                "num_samples should be a positive integer "
                "value, but got num_samples={}".format(num_samples)
            )

        self.batch_size = batch_size
        self.dataset = dataset
        self.drop_last = drop_last

        if world_size is None:
            world_size = dist.get_world_size() if dist.is_distributed() else 1
        self.world_size = world_size
        if rank is None:
            rank = dist.get_rank() if dist.is_distributed() else 0
        self.rank = rank

        if num_samples is None:
            num_samples = len(self.dataset)
        # Per-rank share, rounded up so every rank gets the same count;
        # scatter() pads the index list to make up the difference.
        self.num_samples = int(math.ceil(num_samples / self.world_size))

        # Identical seeds keep the sampled order consistent across ranks.
        if seed is None and self.world_size > 1:
            seed = 0
        self.rng = np.random.RandomState(seed)

    def __iter__(self) -> Union[Generator, Iterator]:
        return self.batch()

    def __len__(self) -> int:
        """Return the number of batches this rank yields per epoch."""
        if self.drop_last:
            return self.num_samples // self.batch_size
        return int(math.ceil(self.num_samples / self.batch_size))

    def sample(self):
        """Return one epoch of indices; implemented by subclasses."""
        raise NotImplementedError

    def scatter(self, indices) -> List:
        """Return this rank's ``num_samples``-long share of ``indices``.

        ``indices`` is repeated cyclically (or truncated) to exactly
        ``num_samples * world_size`` entries. Every ``world_size``-th entry,
        starting at ``rank``, is then kept. The caller's list is not modified.

        Raises:
            ValueError: if ``indices`` is empty although a non-empty share is
                required, or if ``rank`` lies outside ``[0, world_size)``.
        """
        total_size = self.num_samples * self.world_size
        indices = list(indices)
        if total_size and not indices:
            raise ValueError(
                "cannot scatter an empty index list across "
                "world_size={} ranks".format(self.world_size)
            )
        # Repeat until evenly divisible: a single pass is not enough when
        # fewer than half of total_size indices were supplied.
        while len(indices) < total_size:
            indices += indices[: total_size - len(indices)]
        del indices[total_size:]

        indices = indices[self.rank : total_size : self.world_size]
        if len(indices) != self.num_samples:
            raise ValueError(
                "rank should be in [0, world_size), but got rank={} "
                "with world_size={}".format(self.rank, self.world_size)
            )
        return indices

    def batch(self) -> Iterator[List[Any]]:
        """Return an iterator over this rank's batches of indices."""
        indices = list(self.sample())
        # world_size may be given explicitly without dist being initialised,
        # so test self.world_size rather than dist.is_distributed().
        if self.world_size > 1:
            indices = self.scatter(indices)
        step = self.batch_size
        batches = [indices[i : i + step] for i in range(0, len(indices), step)]
        # An empty epoch has no last batch to inspect.
        if self.drop_last and batches and len(batches[-1]) < step:
            batches.pop()
        return iter(batches)
class StreamSampler(Sampler):
    """Endless sampler for stream-style datasets.

    The sampler is its own iterator. Every step returns a fresh iterator over
    ``0 .. batch_size - 1``, and iteration never stops on its own.

    Args:
        batch_size: number of positions produced per step. Default: 1
    """

    def __init__(self, batch_size=1):
        self.batch_size = batch_size

    def __iter__(self):
        return self

    def __next__(self):
        positions = range(self.batch_size)
        return iter(positions)
class SequentialSampler(MapSampler):
    r"""Sample indices in their natural (or a caller-given) order.

    Args:
        dataset: dataset to sample from.
        batch_size: number of indices per batch. Default: 1
        drop_last: drop a trailing short batch. Default: False
        indices: optional sequence of indices to iterate over instead of
            ``range(len(dataset))``.
        world_size: number of ranks; see :class:`MapSampler`.
        rank: rank of this process; see :class:`MapSampler`.

    Raises:
        ValueError: if ``indices`` is neither None nor a sequence.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        indices=None,
        world_size=None,
        rank=None,
    ):
        super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
        if indices is not None and not isinstance(indices, collections.abc.Sequence):
            raise ValueError(
                "indices should be None or a sequence, "
                "but got indices={}".format(indices)
            )
        self.indices = indices
        if indices is not None:
            # Size the epoch by the explicit indices rather than len(dataset),
            # so __len__ and scatter() agree with what sample() yields.
            self.num_samples = int(math.ceil(len(indices) / self.world_size))

    def sample(self) -> Iterator[Any]:
        """Return ``range(len(dataset))`` as an iterator, or ``indices`` as given."""
        if self.indices is None:
            return iter(range(len(self.dataset)))
        return self.indices
class RandomSampler(MapSampler):
    r"""Sample a random permutation of the dataset (or of given indices).

    Args:
        dataset: dataset to sample from.
        batch_size: number of indices per batch. Default: 1
        drop_last: drop a trailing short batch. Default: False
        indices: optional sequence of indices to permute instead of
            ``range(len(dataset))``.
        world_size: number of ranks; see :class:`MapSampler`.
        rank: rank of this process; see :class:`MapSampler`.
        seed: seed for the permutation RNG; see :class:`MapSampler`.

    Raises:
        ValueError: if ``indices`` is neither None nor a sequence.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        indices=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed)
        if indices is not None and not isinstance(indices, collections.abc.Sequence):
            raise ValueError(
                "indices should be None or a sequence, "
                "but got indices={}".format(indices)
            )
        self.indices = indices
        if indices is not None:
            # Size the epoch by the explicit indices rather than len(dataset),
            # so __len__ and scatter() agree with what sample() yields.
            self.num_samples = int(math.ceil(len(indices) / self.world_size))

    def sample(self) -> List:
        """Return a fresh random permutation as a Python list."""
        if self.indices is None:
            return self.rng.permutation(len(self.dataset)).tolist()
        return self.rng.permutation(self.indices).tolist()
class ReplacementSampler(MapSampler):
    r"""Sample dataset indices with replacement, optionally weighted.

    Args:
        dataset: dataset to sample from.
        batch_size: number of indices per batch. Default: 1
        drop_last: drop a trailing short batch. Default: False
        num_samples: total number of draws across all ranks.
            Default: ``len(dataset)``
        weights: optional per-item sampling weights, one per dataset element
            (a sequence or 1-D ``np.ndarray``). They are normalised to sum to
            1. ``None`` samples uniformly.
        world_size: number of ranks; see :class:`MapSampler`.
        rank: rank of this process; see :class:`MapSampler`.
        seed: seed for the sampling RNG; see :class:`MapSampler`.

    Raises:
        ValueError: if ``weights`` is not a sequence/array, has a length other
            than ``len(dataset)``, contains negative values, or does not have
            a positive sum.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        num_samples=None,
        weights=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        super().__init__(
            dataset, batch_size, drop_last, num_samples, world_size, rank, seed
        )
        if weights is not None:
            if not isinstance(weights, (collections.abc.Sequence, np.ndarray)):
                raise ValueError(
                    "weights should be None or a sequence, "
                    "but got weights={}".format(weights)
                )
            if len(weights) != len(dataset):
                raise ValueError(
                    "len(dataset)={} should be equal to "
                    "len(weights)={}".format(len(dataset), len(weights))
                )
            weights = np.asarray(weights, dtype=np.float64)
            total = weights.sum()
            # `not total > 0` also rejects a NaN sum.
            if (weights < 0).any() or not total > 0:
                raise ValueError(
                    "weights should be non-negative with a positive sum, "
                    "but got weights={}".format(weights)
                )
            weights = weights / total
        self.weights = weights

    def sample(self) -> List:
        """Return a list of dataset indices drawn with replacement.

        Draws ``num_samples * world_size`` indices, i.e. the quota for all
        ranks. :meth:`scatter` can then hand each rank its own stride without
        padding.
        """
        n = len(self.dataset)
        size = self.num_samples * self.world_size
        if self.weights is None:
            return self.rng.randint(n, size=size).tolist()
        # choice() returns indices; multinomial() would return count vectors.
        return self.rng.choice(n, size=size, replace=True, p=self.weights).tolist()
class Infinite(MapSampler):
    r"""Wrap a sampler so that iteration never ends.

    When the wrapped sampler is exhausted, a fresh iterator is created from it
    and iteration continues. ``MapSampler.__init__`` is not called, so
    attributes such as ``batch_size`` are not set on this wrapper.

    Args:
        sampler: iterable sampler to cycle over.
    """

    def __init__(self, sampler):
        self.sampler = sampler
        self.sampler_iter = iter(self.sampler)

    def sample(self):
        raise NotImplementedError("sample method not supported in Infinite")

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.sampler_iter)
        except StopIteration:
            pass
        # Wrapped sampler exhausted: restart it. If it yields nothing at all,
        # this second next() raises StopIteration and iteration ends.
        self.sampler_iter = iter(self.sampler)
        return next(self.sampler_iter)

    def __len__(self):
        # Effectively unbounded: the largest int64 value.
        return np.iinfo(np.int64).max