import collections
import gc
import math
import multiprocessing
import os
import platform
import queue
import random
import threading
import time
from typing import Callable, Union
import numpy as np
from ..device import _sh, get_default_device
from ..functional.tensor import copy
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from ..tensor import Tensor
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform
# ``thread`` is the Python 2 name of the low-level threading module; fall back
# to its Python 3 name. Only a failed import should trigger the fallback, so
# the handler is limited to ImportError instead of swallowing everything.
try:
    import thread
except ImportError:
    import _thread as thread
# Logger for this module, obtained from the project's ``get_logger`` helper
# using this module's ``__name__``.
logger = get_logger(__name__)
# Timeout passed to the blocking ``get`` calls on the worker queues in this
# module; when it expires the worker loops re-check the shutdown flag and retry.
GLOBAL_TIMEOUT = 5
def raise_timeout_error():
    """Default ``timeout_event`` callback: unconditionally raise ``RuntimeError``."""
    message = "dataloader timeout"
    raise RuntimeError(message)
class DataLoader:
    """Provide an iterable of batches over a dataset.

    ``iter(loader)`` returns one of four iterator implementations, chosen by
    the dataset kind (stream vs. map-style) and by whether ``num_workers`` is
    zero (serial) or positive (parallel).

    Args:
        dataset: the data source; a ``StreamDataset`` or a map-style ``Dataset``.
        sampler: batch sampler. When ``None``, ``StreamSampler(batch_size=1)``
            is used for stream datasets and
            ``SequentialSampler(dataset, batch_size=1, drop_last=False)``
            otherwise.
        transform: transform applied to the samples; defaults to
            ``PseudoTransform()``.
        collator: merges transformed samples into a batch; defaults to
            ``Collator()``.
        num_workers: number of worker processes; ``0`` selects the serial
            iterators.
        timeout: non-negative timeout handed to the iterators; the iterators
            only apply it when it is greater than zero.
        timeout_event: callback stored for the iterators to use on timeout.
        divide: when true, every batch is split across all workers instead of
            whole batches being assigned to workers in turn. Requires
            ``num_workers > 1`` and a batch size of at least ``num_workers``.
        preload: forwarded to the iterator, which then converts batches to
            tensors ahead of consumption.

    Raises:
        ValueError: if ``num_workers`` or ``timeout`` is negative, if
            ``divide`` is set with ``num_workers <= 1``, or if ``divide`` is
            set and the sampler's batch size is smaller than ``num_workers``.
    """

    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        timeout_event: Callable = raise_timeout_error,
        divide: bool = False,
        preload: bool = False,
    ):
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")
        if timeout < 0:
            raise ValueError("timeout should not be negative")
        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.timeout_event = timeout_event
        self.divide = divide
        self.preload = preload

        # The default sampler is chosen with an explicit ``is None`` test: a
        # sampler object may be falsy (e.g. ``len(sampler) == 0``) and must not
        # be silently replaced when the caller passed it on purpose.
        if isinstance(dataset, StreamDataset):
            self.sampler = (
                sampler if sampler is not None else StreamSampler(batch_size=1)
            )
            assert isinstance(
                self.sampler, StreamSampler
            ), "types of dataset and sampler do not match"
        else:
            assert isinstance(
                dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(dataset)
            self.sampler = (
                sampler
                if sampler is not None
                else SequentialSampler(dataset, batch_size=1, drop_last=False)
            )
            assert isinstance(
                self.sampler, MapSampler
            ), "types of dataset and sampler do not match"

        if divide:
            # The message promises "not smaller than", so a batch size equal
            # to num_workers (one sample per worker) is accepted.
            if self.sampler.batch_size < self.num_workers:
                raise ValueError(
                    "batch size must not smaller than num_workers in divide mode."
                )
            elif self.sampler.batch_size % self.num_workers:
                logger.warning(
                    "batch size is not divisible by num_workers, may lose performance in divide mode."
                )

        self.transform = PseudoTransform() if transform is None else transform
        self.collator = Collator() if collator is None else collator

        self.__initialized = True

    def __iter__(self):
        """Return a fresh iterator matching the dataset kind and worker count.

        On Windows, or when the ``TERMUX_VERSION`` environment variable is set,
        ``num_workers`` is reset to zero on this loader before the iterator is
        chosen, so the serial implementation is used.
        """
        if platform.system() == "Windows" and self.num_workers > 0:
            logger.warning(
                "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero"
            )
            self.num_workers = 0
        if os.getenv("TERMUX_VERSION"):
            logger.warning(
                "pyarrow do not support on termux env now, changing num_workers to be zero"
            )
            self.num_workers = 0

        if isinstance(self.dataset, StreamDataset):
            if not self.num_workers:
                return _SerialStreamDataLoaderIter(self, self.preload)
            return _ParallelStreamDataLoaderIter(self, self.preload)

        assert isinstance(
            self.dataset, Dataset
        ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
        if not self.num_workers:
            return _SerialMapDataLoaderIter(self, self.preload)
        return _ParallelMapDataLoaderIter(self, self.preload)

    def __len__(self):
        """Return ``len(self.sampler)``."""
        return len(self.sampler)
class PreLoader:
    """Base class that can stage a batch as tensors before it is consumed.

    When ``preload`` is true, numpy arrays in a batch are converted to
    ``Tensor`` objects on a secondary device string built from the default
    device, kept in ``pre_load_device_cache``, and copied to the default
    device when the batch is handed out.
    """

    def __init__(self, preload):
        # Always define the cache slot so reading it is safe even when
        # preloading is disabled.
        self.pre_load_device_cache = None
        if preload:
            self.default_device = get_default_device()
            # Default device name suffixed with ":" and the value returned by
            # ``_sh.get_next()``.
            self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
        self.preload = preload

    def _load_tensor(self, batch, cached=True):
        """Recursively convert the numpy arrays inside ``batch`` to tensors.

        Args:
            batch: an array, mapping, namedtuple, sequence, or any other value.
            cached: place tensors on ``pre_load_device`` when true, otherwise
                on ``default_device``.

        Returns:
            A structure mirroring ``batch``: mappings become ``dict``,
            namedtuples keep their type, other sequences become ``list``.
            Strings, bytes and all other values are returned unchanged.
        """
        if isinstance(batch, np.ndarray):
            device = self.pre_load_device if cached else self.default_device
            return Tensor(batch, device=device)
        # ``str``/``bytes`` are Sequences too; iterating a one-character
        # string yields itself, so they must be leaves or the recursion below
        # never terminates.
        if isinstance(batch, (str, bytes)):
            return batch
        if isinstance(batch, collections.abc.Mapping):
            return {k: self._load_tensor(v, cached) for k, v in batch.items()}
        if isinstance(batch, tuple) and hasattr(batch, "_fields"):
            # namedtuple: rebuild with the same type.
            return type(batch)(*(self._load_tensor(value, cached) for value in batch))
        if isinstance(batch, collections.abc.Sequence):
            return [self._load_tensor(value, cached) for value in batch]
        return batch

    def _load_cache(self, data):
        """Recursively bring the tensors inside ``data`` to ``default_device``.

        Tensors already on the default device are returned as-is; others are
        passed through ``copy(..., device=self.default_device)``. Containers
        are rebuilt the same way as in ``_load_tensor``.
        """
        if isinstance(data, Tensor):
            if data.device == self.default_device:
                return data
            return copy(data, device=self.default_device)
        # Same leaf rule as in ``_load_tensor`` to avoid endless recursion.
        if isinstance(data, (str, bytes)):
            return data
        if isinstance(data, collections.abc.Mapping):
            return {k: self._load_cache(v) for k, v in data.items()}
        if isinstance(data, tuple) and hasattr(data, "_fields"):
            # namedtuple: rebuild with the same type.
            return type(data)(*(self._load_cache(value) for value in data))
        if isinstance(data, collections.abc.Sequence):
            return [self._load_cache(value) for value in data]
        return data

    def _swap_out_cache(self):
        """Return the cached batch moved to the default device and clear the cache."""
        out = self._load_cache(self.pre_load_device_cache)
        self.pre_load_device_cache = None
        return out
class _BaseMapDataLoaderIter(PreLoader):
    """Shared iteration logic for map-style datasets.

    Subclasses implement ``_get_next_batch``; this class counts the batches
    fetched so far and stops after ``len(self.sampler)`` of them, optionally
    keeping one batch staged ahead when preloading is enabled.
    """

    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.seed = next(_random_seed_generator())
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event
        self.divide = loader.divide
        # Number of batches requested from ``_get_next_batch`` so far.
        self.num_processed = 0

    def _get_next_batch(self):
        """Produce the next collated batch; provided by subclasses."""
        raise NotImplementedError

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        return self

    def __next__(self):
        if not self.preload:
            if self.num_processed >= len(self):
                raise StopIteration
            batch = self._get_next_batch()
            self.num_processed += 1
            return batch

        if self.pre_load_device_cache is None:
            if self.num_processed >= len(self):
                # Nothing staged and nothing left to fetch.
                raise StopIteration
            if self.num_processed == 0:
                # Very first batch: load it straight onto the default device.
                self._try_load_tensor(cached=False)
        ready = self._swap_out_cache()
        # Stage the following batch (a no-op once the sampler is exhausted).
        self._try_load_tensor()
        return ready

    def _try_load_tensor(self, cached=True):
        """Fetch one more batch into the preload cache if any remain."""
        if self.num_processed < len(self):
            self.num_processed += 1
            fetched = self._get_next_batch()
            self.pre_load_device_cache = self._load_tensor(fetched, cached)
class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    """Single-process iterator over a map-style dataset."""

    def __init__(self, loader, preload):
        super().__init__(loader, preload)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
        """Load, transform and collate the samples of the next index batch."""
        batch_indices = next(self.indices_iter)
        samples = []
        for index in batch_indices:
            samples.append(self.dataset[index])
        transformed = self.transform.apply_batch(samples)
        return self.collator.apply(transformed)
class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
    """Multi-process iterator over a map-style dataset.

    One process feeds index batches into per-worker task queues, one process
    per worker loads and transforms samples, and one collecting process
    collates the results into ``batch_queue`` in batch order.
    """

    __initialized = False

    def __init__(self, loader, preload):
        super().__init__(loader, preload)

        worker_count = self.num_workers
        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(worker_count)
        ]

        # Shared counters/flags visible to all helper processes.
        self.feed_batch_idx = multiprocessing.Value("i", 0)
        self.target_batch_idx = multiprocessing.Value("i", 0)
        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(worker_count)
        ]

        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        feeder_args = (
            iter(self.sampler),
            self.task_queues,
            self.num_workers,
            self.divide,
            self.shutdown_flag,
            self.feed_batch_idx,
        )
        self.task_feeding_worker = multiprocessing.Process(
            target=_task_feeding_loop, args=feeder_args, daemon=True,
        )
        gc.collect()
        self.task_feeding_worker.start()

        self.workers = []
        for worker_id in range(self.num_workers):
            worker_args = (
                self.dataset,
                self.task_queues[worker_id],
                self.trans_data_queues[worker_id],
                self.transform,
                # Each worker gets its own seed derived from the iterator seed.
                self.seed + worker_id + 1,
                self.shutdown_flag,
            )
            proc = multiprocessing.Process(
                target=_worker_loop, args=worker_args, daemon=True,
            )
            gc.collect()
            proc.start()
            self.workers.append(proc)

        # Divide mode merges the parts produced by all workers; otherwise each
        # batch is taken whole from a single worker.
        collect_target = _data_gathering_loop if self.divide else _data_selecting_loop
        self.data_collecting_worker = multiprocessing.Process(
            target=collect_target,
            args=(
                self.trans_data_queues,
                self.batch_queue,
                self.collator,
                len(self),
                self.num_workers,
                self.shutdown_flag,
                self.target_batch_idx,
            ),
            daemon=True,
        )
        gc.collect()
        self.data_collecting_worker.start()

        self.__initialized = True

    def _check_workers(self):
        """Raise ``RuntimeError`` if any helper process exited with a non-zero code."""
        singletons = (
            ("data collecting worker died. {}", self.data_collecting_worker),
            ("task feeding worker died. {}", self.task_feeding_worker),
        )
        for template, proc in singletons:
            if proc.is_alive():
                continue
            code = proc.exitcode
            if code != 0:
                raise RuntimeError(template.format(code))

        for worker_id, proc in enumerate(self.workers):
            if proc.is_alive():
                continue
            code = proc.exitcode
            if code != 0:
                raise RuntimeError("worker:{} died. {}".format(worker_id, code))

        logger.debug("all workers are alive.")

    def _get_next_batch(self):
        """Wait for the next batch from ``batch_queue``.

        Raises:
            RuntimeError: if a helper process died, or if ``timeout`` is
                positive and has been exceeded.
        """
        began = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            elapsed = time.time() - began
            if self.timeout > 0 and elapsed > self.timeout:
                raise RuntimeError("get_next_batch timeout!")

    def _shutdown(self):
        """Signal shutdown, stop all helper processes and close the queues."""
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        helpers = [self.task_feeding_worker, self.data_collecting_worker]
        helpers.extend(self.workers)
        for proc in helpers:
            if proc.is_alive():
                proc.terminate()
            proc.join()

        for pending in self.trans_data_queues + self.task_queues:
            pending.cancel_join_thread()
            pending.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
class _BaseStreamDataLoaderIter(PreLoader):
    """Shared iteration logic for stream datasets.

    Subclasses implement ``_get_next_batch``. Iteration itself never raises
    ``StopIteration`` here; it ends only if ``_get_next_batch`` raises.
    """

    def __init__(self, loader, preload):
        super().__init__(preload)
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.timeout_event = loader.timeout_event

    def _get_next_batch(self):
        """Produce the next collated batch; provided by subclasses."""
        raise NotImplementedError

    def _process_raw_data(self, raw_data):
        """Split one dataset item into a list of per-sample tuples.

        ``raw_data`` is a pair ``(batched, payload)``. For unbatched data each
        field of ``payload`` is wrapped in a 1-tuple first, so the result is a
        single sample; for batched data ``payload`` is used as given and the
        i-th sample collects the i-th element of every field.
        """
        assert len(raw_data) == 2 and isinstance(
            raw_data[0], bool
        ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
        if raw_data[0]:
            fields = raw_data[1]
        else:
            fields = [(x,) for x in raw_data[1]]
        sample_count = len(fields[0])
        return [tuple(field[i] for field in fields) for i in range(sample_count)]

    def __iter__(self):
        return self

    def __next__(self):
        if not self.preload:
            return self._get_next_batch()

        if self.pre_load_device_cache is None:
            # Nothing staged yet: load directly onto the default device.
            self._try_load_tensor(cached=False)
        ready = self._swap_out_cache()
        # Stage the following batch before handing out the current one.
        self._try_load_tensor()
        return ready

    def _try_load_tensor(self, cached=True):
        """Fetch one batch and store it, converted to tensors, in the cache."""
        fetched = self._get_next_batch()
        self.pre_load_device_cache = self._load_tensor(fetched, cached)
class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    """Single-process iterator over a stream dataset."""

    def __init__(self, loader, preload):
        super().__init__(loader, preload)
        self.dataset_iter = iter(self.dataset)
        self.idx = 0
        # Samples left over from the last raw item once a batch was full.
        self.unused = []

    def _try_get_raw_data(self, start_time):
        """Fetch the next raw item from the dataset iterator.

        When ``timeout`` is positive, a watchdog timer interrupts the main
        thread after ``timeout`` seconds; that interrupt, or an error raised
        by the dataset after the deadline counted from ``start_time``, makes
        this method return ``self.timeout_event()`` instead. Errors before
        the deadline cause a retry.

        Args:
            start_time: ``time.time()`` value taken when the current batch
                was started.

        Returns:
            The raw item, or the value returned by ``timeout_event``.
        """
        raw_data = None
        while not raw_data:
            timer = None
            if self.timeout > 0:
                timer = threading.Timer(self.timeout, thread.interrupt_main)
                timer.start()
            try:
                raw_data = next(self.dataset_iter)
            except KeyboardInterrupt:
                if timer is None:
                    # No watchdog was armed, so this is a genuine interrupt
                    # and must not be reported as a dataloader timeout.
                    raise
                timed_out = True
            except Exception:
                # SystemExit/GeneratorExit are deliberately not caught here.
                timed_out = (
                    timer is not None and time.time() - start_time > self.timeout
                )
            else:
                timed_out = False
            finally:
                # Cancel on every path so a still-pending watchdog can never
                # interrupt the main thread later on.
                if timer is not None:
                    timer.cancel()
            if timed_out:
                raw_data = self.timeout_event()
        return raw_data

    def _get_next_batch(self):
        """Collect ``sampler.batch_size`` transformed samples and collate them."""
        ret = []
        start_time = time.time()
        while len(ret) < self.sampler.batch_size:
            if len(self.unused) != 0:
                batch_data = self.unused
            else:
                raw_data = self._try_get_raw_data(start_time)
                batch_data = self._process_raw_data(raw_data)

            # NOTE(review): ``pop()`` takes samples from the end of the list,
            # so samples of one raw item are consumed in reverse order —
            # confirm this is intended.
            while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
                data = batch_data.pop()
                ret.append(self.transform.apply(data))
            self.unused = batch_data

        return self.collator.apply(ret)
class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    """Multi-process iterator over a stream dataset.

    A receive process reads the dataset and spreads samples over the raw data
    queues, one transform process per worker moves samples from its raw queue
    to its transformed queue, and a collect process groups
    ``sampler.batch_size`` transformed samples into a batch on ``batch_queue``.
    """

    __initialized = False

    def __init__(self, loader, preload):
        super().__init__(loader, preload)

        self.shutdown_flag = multiprocessing.Value("i", 0)

        self.raw_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        self.trans_data_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]

        from .tools._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)

        self.recieve_worker = multiprocessing.Process(
            target=self._worker_to_raw_data_queues, daemon=True
        )
        gc.collect()
        self.recieve_worker.start()

        self.transform_workers = []
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
            )
            gc.collect()
            worker.start()
            self.transform_workers.append(worker)

        self.collect_worker = multiprocessing.Process(
            target=self._worker_to_batch_queue, daemon=True
        )
        gc.collect()
        self.collect_worker.start()

        self.__initialized = True

    def _put_raw_data_queues(self, raw_data, qidx):
        """Distribute the samples of ``raw_data`` over the raw data queues.

        Starting at queue ``qidx``, each sample is offered to the queues in
        turn; a queue that stays full for one second is skipped in favour of
        the next one. A sample is dropped if shutdown is requested while
        waiting.

        Returns:
            The queue index to continue with on the next call.
        """
        batch_data = self._process_raw_data(raw_data)
        for data in batch_data:
            while True:
                qidx = qidx % self.num_workers
                try:
                    # A put() without timeout blocks forever and never raises
                    # queue.Full, which would make the handler below (and its
                    # shutdown check) unreachable.
                    self.raw_data_queues[qidx].put(data, timeout=1)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("raw data queue %d is full" % qidx)
                finally:
                    qidx += 1
        return qidx

    def _worker_to_raw_data_queues(self):
        """Process body: read the dataset and feed the raw data queues."""
        dataset_iter = iter(self.dataset)
        qidx = 0
        while True:
            if self.shutdown_flag.value == 1:
                break
            raw_data = next(dataset_iter)
            qidx = self._put_raw_data_queues(raw_data, qidx)

    def _worker_to_trans_data_queues(self, worker_id):
        """Process body: transform samples of raw queue ``worker_id``."""
        while True:
            if self.shutdown_flag.value == 1:
                break
            try:
                data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
            except queue.Empty:
                continue
            trans_data = self.transform.apply(data)
            while True:
                try:
                    # Bounded wait so the shutdown flag is re-checked while
                    # the queue is full (a plain put() never raises Full).
                    self.trans_data_queues[worker_id].put(trans_data, timeout=1)
                    break
                except queue.Full:
                    if self.shutdown_flag.value == 1:
                        break
                    logger.debug("batch queue if full")

    def _worker_to_batch_queue(self):
        """Process body: gather transformed samples into collated batches."""
        cnt = -1
        trans_items = []
        while True:
            if self.shutdown_flag.value == 1:
                break
            cnt += 1
            # Visit the transformed queues in turn; an empty one is skipped
            # after the timeout.
            queue_id = cnt % self.num_workers
            try:
                trans_item = self.trans_data_queues[queue_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
            except queue.Empty:
                continue
            trans_items.append(trans_item)
            if len(trans_items) == self.sampler.batch_size:
                batch_data = self.collator.apply(trans_items)
                while True:
                    try:
                        self.batch_queue.put(batch_data, timeout=1)
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug("batch queue is full")
                trans_items = []

    def _check_workers(self):
        """Raise ``RuntimeError`` if the collect or a transform process failed."""
        if not self.collect_worker.is_alive():
            exitcode = self.collect_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("collator worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.transform_workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError(
                        "worker: {} died. {}".format(worker_id, exitcode)
                    )

    def _get_next_batch(self):
        """Wait for the next batch from ``batch_queue``.

        Once ``timeout`` (if positive) has elapsed, the value returned by
        ``timeout_event()`` is pushed into the raw data queues as if it came
        from the dataset, and waiting continues.
        """
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                self._put_raw_data_queues(self.timeout_event(), 0)

    def _shutdown(self):
        """Signal shutdown, stop all helper processes and close the queues."""
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.recieve_worker.is_alive():
            self.recieve_worker.terminate()
        self.recieve_worker.join()

        if self.collect_worker.is_alive():
            self.collect_worker.terminate()
        self.collect_worker.join()

        for worker in self.transform_workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.raw_data_queues:
            q.cancel_join_thread()
            q.close()

        for q in self.trans_data_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
def _task_feeding_loop(
indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
while True:
if shutdown_flag.value == 1:
break
batch_idx = feed_batch_idx.value
try:
indices = next(indices_iter)
except StopIteration:
break
if divide:
while any([q.full() for q in task_queues]):
if shutdown_flag.value == 1:
return
sub_num = math.ceil(len(indices) / num_workers)
for worker_id in range(num_workers):
sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
task_queues[worker_id].put((batch_idx, sub_indices))
else:
target_id = batch_idx % num_workers
while task_queues[target_id].full():
if shutdown_flag.value == 1:
return
task_queues[target_id].put((batch_idx, indices))
with feed_batch_idx.get_lock():
feed_batch_idx.value += 1
def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
random.seed(seed)
np.random.seed(seed)
while True:
if shutdown_flag.value == 1:
break
try:
batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
except queue.Empty:
continue
if len(indices) > 0:
items = [dataset[idx] for idx in indices]
trans_items = transform.apply_batch(items)
else:
trans_items = ()
while True:
try:
trans_data_queue.put((batch_idx, trans_items), timeout=1)
break
except queue.Full:
if shutdown_flag.value == 1:
break
logger.debug("batch part queue is full!")
def _data_gathering_loop(
trans_data_queues,
batch_queue,
collator,
length,
num_workers,
shutdown_flag,
target_idx,
):
while True:
if shutdown_flag.value == 1:
break
target_batch_idx = target_idx.value
if target_batch_idx >= length:
break
full_trans_items = []
for worker_id in range(num_workers):
while True:
try:
batch_idx, trans_items = trans_data_queues[worker_id].get(
timeout=GLOBAL_TIMEOUT
)
break
except queue.Empty:
if shutdown_flag.value == 1:
break
logger.debug(
"worker:{} data queue get timeout! target batch idx:{}".format(
worker_id, target_batch_idx
)
)
if batch_idx != target_batch_idx:
raise RuntimeError(
"Unexperted batch_idx in data gathering loop. worker_id:{}.".format(
worker_id
)
)
else:
full_trans_items.extend(trans_items)
full_batch = collator.apply(full_trans_items)
while True:
try:
batch_queue.put(full_batch, timeout=1)
break
except queue.Full:
if shutdown_flag.value == 1:
break
logger.debug("batch queue is full!")
with target_idx.get_lock():
target_idx.value += 1
batch_queue.disconnect_client()
def _data_selecting_loop(
trans_data_queues,
batch_queue,
collator,
length,
num_workers,
shutdown_flag,
target_idx,
):
while True:
if shutdown_flag.value == 1:
break
target_batch_idx = target_idx.value
if target_batch_idx >= length:
break
target_worker_id = target_batch_idx % num_workers
while True:
try:
batch_idx, trans_items = trans_data_queues[target_worker_id].get(
timeout=GLOBAL_TIMEOUT
)
batch_data = collator.apply(trans_items)
break
except queue.Empty:
if shutdown_flag.value == 1:
break
logger.debug(
"worker:{} data queue get timeout! target batch idx:{}".format(
target_worker_id, target_batch_idx
)
)
if batch_idx != target_batch_idx:
raise RuntimeError(
"batch_idx {} mismatch the target_batch_idx {}".format(
batch_idx, target_batch_idx
)
)
while True:
try:
batch_queue.put(batch_data, timeout=1)
break
except queue.Full:
if shutdown_flag.value == 1:
break
logger.debug("batch queue is full!")
with target_idx.get_lock():
target_idx.value += 1
batch_queue.disconnect_client()