import logging
import multiprocessing
import os
import re
from dataclasses import dataclass
from io import BytesIO
from itertools import repeat, takewhile
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union
from urllib.parse import ParseResult as UrlParseResult
from urllib.parse import urlparse
import numpy as np
if TYPE_CHECKING:
    from scipy import sparse
else:
    # scipy is an optional dependency; read_sparse() raises a helpful
    # ImportError when it is absent (sparse is None then).
    try:
        from scipy import sparse
    except ImportError:
        sparse = None

# Native (Rust) readers compiled into the bed_reader extension module.
from .bed_reader import (
    check_file_cloud,
    read_cloud_f32,
    read_cloud_f64,
    read_cloud_i8,
    read_f32,
    read_f64,
    read_i8,
    url_to_bytes,
)
def _rawincount(f):
f.seek(0)
bufgen = takewhile(lambda x: x, (f.read(1024 * 1024) for _ in repeat(None)))
return sum(buf.count(b"\n") for buf in bufgen)
@dataclass
class _MetaMeta:
    """Schema entry describing one metadata property of a .fam/.bim file."""

    # File the property lives in: "fam" (individuals) or "bim" (variants).
    suffix: str
    # Zero-based column index of the property within that file.
    column: int
    # numpy dtype used to store the parsed column.
    dtype: type
    # Fill value used where data is missing (None means "generate a sequence").
    missing_value: object
    # Factory: (key, length, missing_value, dtype) -> filled numpy array.
    fill_sequence: object
def _all_same(key, length, missing, dtype):
if np.issubdtype(dtype, np.str_):
dtype = f"<U{len(missing)}"
return np.full(length, missing, dtype=dtype)
def _sequence(key, length, missing, dtype):
if np.issubdtype(dtype, np.str_):
longest = len(f"{key}{length}")
dtype = f"<U{longest}"
return np.fromiter(
(f"{key}{i + 1}" for i in range(length)), dtype=dtype, count=length,
)
# Column delimiters: .fam files are whitespace-separated, .bim files tab-separated.
_delimiters = {"fam": r"\s+", "bim": "\t"}
# Maps metadata-file suffix to the name of the corresponding dimension count.
_count_name = {"fam": "iid_count", "bim": "sid_count"}
# Schema of every supported metadata property: which file it lives in, its
# column, its dtype, its missing-value fill, and the factory used to
# synthesize the column when it is absent.
_meta_meta = {
    "fid": _MetaMeta("fam", 0, np.str_, "0", _all_same),
    "iid": _MetaMeta("fam", 1, np.str_, None, _sequence),
    "father": _MetaMeta("fam", 2, np.str_, "0", _all_same),
    "mother": _MetaMeta("fam", 3, np.str_, "0", _all_same),
    "sex": _MetaMeta("fam", 4, np.int32, 0, _all_same),
    "pheno": _MetaMeta("fam", 5, np.str_, "0", _all_same),
    "chromosome": _MetaMeta("bim", 0, np.str_, "0", _all_same),
    "sid": _MetaMeta("bim", 1, np.str_, None, _sequence),
    "cm_position": _MetaMeta("bim", 2, np.float32, 0, _all_same),
    "bp_position": _MetaMeta("bim", 3, np.int32, 0, _all_same),
    "allele_1": _MetaMeta("bim", 4, np.str_, "A1", _all_same),
    "allele_2": _MetaMeta("bim", 5, np.str_, "A2", _all_same),
}
def get_num_threads(num_threads=None):
    """Resolve the thread count: explicit argument, then environment, then CPUs.

    Environment variables are consulted in priority order:
    PST_NUM_THREADS, NUM_THREADS, MKL_NUM_THREADS.
    """
    if num_threads is not None:
        return num_threads
    for env_var in ("PST_NUM_THREADS", "NUM_THREADS", "MKL_NUM_THREADS"):
        value = os.environ.get(env_var)
        if value is not None:
            return int(value)
    return multiprocessing.cpu_count()
def get_max_concurrent_requests(max_concurrent_requests=None):
    """Return the given concurrent-request cap, defaulting to 10 when None."""
    return 10 if max_concurrent_requests is None else max_concurrent_requests
def get_max_chunk_bytes(max_chunk_bytes=None):
    """Return the given per-request chunk size, defaulting to 8 MB when None."""
    return 8_000_000 if max_chunk_bytes is None else max_chunk_bytes
class open_bed:
    """Open a PLINK .bed file (local path or cloud URL) for reading.

    Metadata (.fam/.bim) is parsed lazily on first access. Genotype values
    are read via the native readers, with a pure-Python fallback for local
    files when ``force_python_only`` is requested. Usable as a context
    manager, although no resources are held open between reads.
    """

    def __init__(
        self,
        location: Union[str, Path, UrlParseResult],
        iid_count: Optional[int] = None,
        sid_count: Optional[int] = None,
        properties: Mapping[str, List[Any]] = {},  # noqa: B006 - never mutated
        count_A1: bool = True,
        num_threads: Optional[int] = None,
        skip_format_check: bool = False,
        fam_location: Optional[Union[str, Path, UrlParseResult]] = None,
        bim_location: Optional[Union[str, Path, UrlParseResult]] = None,
        cloud_options: Mapping[str, str] = {},  # noqa: B006 - never mutated
        max_concurrent_requests: Optional[int] = None,
        max_chunk_bytes: Optional[int] = None,
        filepath: Optional[Union[str, Path]] = None,
        fam_filepath: Optional[Union[str, Path]] = None,
        bim_filepath: Optional[Union[str, Path]] = None,
    ) -> None:
        """Remember the file locations and options; optionally check the header.

        The ``*_filepath`` parameters are legacy aliases for the
        ``*_location`` parameters; setting both forms raises ValueError.
        """
        location = self._combined(location, filepath, "location", "filepath")
        fam_location = self._combined(
            fam_location, fam_filepath, "fam_location", "fam_filepath",
        )
        bim_location = self._combined(
            bim_location, bim_filepath, "bim_location", "bim_filepath",
        )
        self.location = self._path_or_url(location)
        self.cloud_options = cloud_options
        self.count_A1 = count_A1
        self._num_threads = num_threads
        self._max_concurrent_requests = max_concurrent_requests
        self._max_chunk_bytes = max_chunk_bytes
        self.skip_format_check = skip_format_check
        # Default the .fam/.bim locations to the .bed location with the
        # extension swapped.
        self._fam_location = (
            self._path_or_url(fam_location)
            if fam_location is not None
            else self._replace_extension(self.location, "fam")
        )
        self._bim_location = (
            self._path_or_url(bim_location)
            if bim_location is not None
            else self._replace_extension(self.location, "bim")
        )
        self.properties_dict, self._counts = open_bed._fix_up_properties(
            properties, iid_count, sid_count, use_fill_sequence=False,
        )
        self._iid_range = None
        self._sid_range = None
        if not self.skip_format_check:
            if self._is_url(self.location):
                check_file_cloud(self.location.geturl(), self.cloud_options)
            else:
                with open(self.location, "rb") as filepointer:
                    self._mode = self._check_file(filepointer)

    @staticmethod
    def _combined(location, filepath, location_name, filepath_name):
        """Merge a location with its legacy alias; both set is an error."""
        if location is not None and filepath is not None:
            msg = f"Cannot set both {location_name} and {filepath_name}"
            raise ValueError(msg)
        return location if location is not None else filepath

    @staticmethod
    def _replace_extension(location, extension):
        """Return *location* (Path or URL) with its extension replaced."""
        if open_bed._is_url(location):
            path, _ = os.path.splitext(location.path)
            new_path = f"{path}.{extension}"
            return UrlParseResult(
                scheme=location.scheme,
                netloc=location.netloc,
                path=new_path,
                params=location.params,
                query=location.query,
                fragment=location.fragment,
            )
        assert isinstance(location, Path)  # real assert
        return location.parent / (location.stem + "." + extension)

    @staticmethod
    def _is_url(location):
        """True when *location* is a parsed URL (cloud file)."""
        return isinstance(location, UrlParseResult)

    @staticmethod
    def _path_or_url(input):
        """Normalize *input* to a Path (local) or UrlParseResult (cloud)."""
        if isinstance(input, Path):
            return input
        if isinstance(input, UrlParseResult):
            return input
        assert isinstance(
            input, str,
        ), "Expected a string or Path object or UrlParseResult"
        parsed = urlparse(input)
        # Require an explicit "scheme://" so Windows drive letters ("C:\...")
        # are not mistaken for URL schemes.
        if parsed.scheme and "://" in input:
            return parsed
        return Path(input)

    def read(
        self,
        index: Optional[Any] = None,
        dtype: Optional[Union[type, str]] = "float32",
        order: Optional[str] = "F",
        force_python_only: Optional[bool] = False,
        num_threads=None,
        max_concurrent_requests=None,
        max_chunk_bytes=None,
    ) -> np.ndarray:
        """Read genotype values into a (iid, sid) numpy array.

        *index* selects individuals/SNPs (int, list, slice, or a pair of
        them). *dtype* must be 'int8', 'float32', or 'float64'; missing
        values become -127 for int8 and NaN for floats.

        Raises ValueError for an unknown *order* or *dtype*.
        """
        iid_index_or_slice_etc, sid_index_or_slice_etc = self._split_index(index)
        dtype = np.dtype(dtype)
        if order not in {"F", "C"}:
            msg = f"order '{order}' not known, only 'F', 'C'"
            raise ValueError(msg)
        # Lazily build full-range index arrays so fancy indexing below can
        # normalize any slice/list/int selector into positional indices.
        if self._iid_range is None:
            self._iid_range = np.arange(self.iid_count, dtype="intp")
        if self._sid_range is None:
            self._sid_range = np.arange(self.sid_count, dtype="intp")
        iid_index = np.ascontiguousarray(
            self._iid_range[iid_index_or_slice_etc],
            dtype="intp",
        )
        sid_index = np.ascontiguousarray(
            self._sid_range[sid_index_or_slice_etc], dtype="intp",
        )
        if not force_python_only or open_bed._is_url(self.location):
            # Native reader path (always used for cloud files).
            num_threads = get_num_threads(
                self._num_threads if num_threads is None else num_threads,
            )
            max_concurrent_requests = get_max_concurrent_requests(
                self._max_concurrent_requests
                if max_concurrent_requests is None
                else max_concurrent_requests,
            )
            max_chunk_bytes = get_max_chunk_bytes(
                self._max_chunk_bytes if max_chunk_bytes is None else max_chunk_bytes,
            )
            val = np.zeros((len(iid_index), len(sid_index)), order=order, dtype=dtype)
            if self.iid_count > 0 and self.sid_count > 0:
                reader, location_str, is_cloud = self._pick_reader(dtype)
                if not is_cloud:
                    reader(
                        location_str,
                        self.cloud_options,
                        iid_count=self.iid_count,
                        sid_count=self.sid_count,
                        is_a1_counted=self.count_A1,
                        iid_index=iid_index,
                        sid_index=sid_index,
                        val=val,
                        num_threads=num_threads,
                    )
                else:
                    reader(
                        location_str,
                        self.cloud_options,
                        iid_count=self.iid_count,
                        sid_count=self.sid_count,
                        is_a1_counted=self.count_A1,
                        iid_index=iid_index,
                        sid_index=sid_index,
                        val=val,
                        num_threads=num_threads,
                        max_concurrent_requests=max_concurrent_requests,
                        max_chunk_bytes=max_chunk_bytes,
                    )
        else:
            # Pure-Python fallback for local files.
            if not self.count_A1:
                byteZero = 0
                byteThree = 2
            else:
                byteZero = 2
                byteThree = 0
            missing = -127 if dtype == np.int8 else np.nan
            # "minor" is the dimension packed within each record; "major"
            # selects which record to read.
            if self.major == "SNP":
                minor_count = self.iid_count
                minor_index = iid_index
                major_index = sid_index
            else:
                minor_count = self.sid_count
                minor_index = sid_index
                major_index = iid_index
            # Each record holds ceil(minor_count/4) bytes, 4 genotypes/byte;
            # allocate the rounded-up minor dimension, slice it down later.
            val = np.zeros(
                ((int(np.ceil(0.25 * minor_count)) * 4), len(major_index)),
                order=order,
                dtype=dtype,
            )
            nbyte = int(np.ceil(0.25 * minor_count))
            with open(self.location, "rb") as filepointer:
                for out_column, record_number in enumerate(major_index):
                    # Records start after the 3-byte header.
                    startbit = int(np.ceil(0.25 * minor_count) * record_number + 3)
                    filepointer.seek(startbit)
                    bytes = np.array(bytearray(filepointer.read(nbyte))).reshape(
                        (int(np.ceil(0.25 * minor_count)), 1), order="F",
                    )
                    # Decode the four 2-bit genotype codes per byte, peeling
                    # two bits at a time from the high end.
                    val[3::4, out_column : out_column + 1] = byteZero
                    val[3::4, out_column : out_column + 1][bytes >= 64] = missing
                    val[3::4, out_column : out_column + 1][bytes >= 128] = 1
                    val[3::4, out_column : out_column + 1][bytes >= 192] = byteThree
                    bytes = np.mod(bytes, 64)
                    val[2::4, out_column : out_column + 1] = byteZero
                    val[2::4, out_column : out_column + 1][bytes >= 16] = missing
                    val[2::4, out_column : out_column + 1][bytes >= 32] = 1
                    val[2::4, out_column : out_column + 1][bytes >= 48] = byteThree
                    bytes = np.mod(bytes, 16)
                    val[1::4, out_column : out_column + 1] = byteZero
                    val[1::4, out_column : out_column + 1][bytes >= 4] = missing
                    val[1::4, out_column : out_column + 1][bytes >= 8] = 1
                    val[1::4, out_column : out_column + 1][bytes >= 12] = byteThree
                    bytes = np.mod(bytes, 4)
                    val[0::4, out_column : out_column + 1] = byteZero
                    val[0::4, out_column : out_column + 1][bytes >= 1] = missing
                    val[0::4, out_column : out_column + 1][bytes >= 2] = 1
                    val[0::4, out_column : out_column + 1][bytes >= 3] = byteThree
            # Drop the padding rows and keep only the requested minor indices.
            val = val[minor_index, :]
            assert val.dtype == np.dtype(dtype)  # real assert
            if not open_bed._array_properties_are_ok(val, order):
                val = val.copy(order=order)
            if self.major == "individual":
                val = val.T
        return val

    def _pick_reader(self, dtype):
        """Select the native reader for *dtype*; return (reader, location, is_cloud)."""
        if dtype == np.int8:
            file_reader = read_i8
            cloud_reader = read_cloud_i8
        elif dtype == np.float64:
            file_reader = read_f64
            cloud_reader = read_cloud_f64
        elif dtype == np.float32:
            file_reader = read_f32
            cloud_reader = read_cloud_f32
        else:
            raise ValueError(
                f"dtype '{dtype}' not known, only "
                + "'int8', 'float32', and 'float64' are allowed.",
            )
        if open_bed._is_url(self.location):
            reader = cloud_reader
            location_str = self.location.geturl()
            is_cloud = True
        else:
            reader = file_reader
            location_str = str(self.location.as_posix())
            is_cloud = False
        return reader, location_str, is_cloud

    def __str__(self) -> str:
        return f"{self.__class__.__name__}('{self.location}',...)"

    @property
    def major(self) -> str:
        """File layout: "SNP" (SNP-major) or "individual" (individual-major)."""
        if self._is_url(self.location):
            msg = "Cannot determine major mode for cloud files"
            raise ValueError(msg)
        # Bug fix: the original tested hasattr(self, "mode"), which never
        # exists, so the header was re-read on every access.
        if not hasattr(self, "_mode"):
            with open(self.location, "rb") as filepointer:
                self._mode = self._check_file(filepointer)
        return "individual" if self._mode == b"\x00" else "SNP"

    @property
    def fid(self) -> np.ndarray:
        """Family id of each individual (from the .fam file)."""
        return self.property_item("fid")

    @property
    def iid(self) -> np.ndarray:
        """Individual id of each individual (from the .fam file)."""
        return self.property_item("iid")

    @property
    def father(self) -> np.ndarray:
        """Father id of each individual (from the .fam file)."""
        return self.property_item("father")

    @property
    def mother(self) -> np.ndarray:
        """Mother id of each individual (from the .fam file)."""
        return self.property_item("mother")

    @property
    def sex(self) -> np.ndarray:
        """Sex code of each individual (from the .fam file)."""
        return self.property_item("sex")

    @property
    def pheno(self) -> np.ndarray:
        """Phenotype column of each individual (from the .fam file)."""
        return self.property_item("pheno")

    @property
    def properties(self) -> Mapping[str, np.array]:
        """All metadata properties, loading any not yet cached."""
        for key in _meta_meta:
            self.property_item(key)
        return self.properties_dict

    def property_item(self, name: str) -> np.ndarray:
        """Return one metadata property, reading its file on first access."""
        if name not in self.properties_dict:
            mm = _meta_meta[name]
            self._read_fam_or_bim(suffix=mm.suffix)
        return self.properties_dict[name]

    @property
    def chromosome(self) -> np.ndarray:
        """Chromosome of each SNP (from the .bim file)."""
        return self.property_item("chromosome")

    @property
    def sid(self) -> np.ndarray:
        """SNP id of each SNP (from the .bim file)."""
        return self.property_item("sid")

    @property
    def cm_position(self) -> np.ndarray:
        """Centimorgan position of each SNP (from the .bim file)."""
        return self.property_item("cm_position")

    @property
    def bp_position(self) -> np.ndarray:
        """Base-pair position of each SNP (from the .bim file)."""
        return self.property_item("bp_position")

    @property
    def allele_1(self) -> np.ndarray:
        """First allele of each SNP (from the .bim file)."""
        return self.property_item("allele_1")

    @property
    def allele_2(self) -> np.ndarray:
        """Second allele of each SNP (from the .bim file)."""
        return self.property_item("allele_2")

    @property
    def iid_count(self) -> np.ndarray:
        """Number of individuals (lines in the .fam file)."""
        return self._count("fam")

    @property
    def sid_count(self) -> np.ndarray:
        """Number of SNPs/variants (lines in the .bim file)."""
        return self._count("bim")

    def _property_location(self, suffix):
        """Return the location of the .fam or .bim file."""
        if suffix == "fam":
            return self._fam_location
        assert suffix == "bim"  # real assert
        return self._bim_location

    def _count(self, suffix):
        """Return (and cache) the number of lines in the .fam/.bim file."""
        count = self._counts[suffix]
        if count is None:
            location = self._property_location(suffix)
            if open_bed._is_url(location):
                # Cloud: either count the downloaded bytes' newlines or use
                # an already-loaded id column.
                if suffix == "fam":
                    if self.property_item("iid") is None:
                        file_bytes = bytes(
                            url_to_bytes(location.geturl(), self.cloud_options),
                        )
                        count = _rawincount(BytesIO(file_bytes))
                    else:
                        count = len(self.iid)
                elif suffix == "bim":
                    if self.property_item("sid") is None:
                        file_bytes = bytes(
                            url_to_bytes(location.geturl(), self.cloud_options),
                        )
                        count = _rawincount(BytesIO(file_bytes))
                    else:
                        count = len(self.sid)
                else:
                    msg = "real assert"
                    raise ValueError(msg)
            else:
                with open(location, "rb") as f:
                    count = _rawincount(f)
            self._counts[suffix] = count
        return count

    @staticmethod
    def _check_file(filepointer):
        """Validate the 3-byte .bed header; return the mode byte."""
        magic_number = filepointer.read(2)
        if magic_number != b"l\x1b":
            msg = "Not a valid .bed file"
            raise ValueError(msg)
        # b"\x00" = individual-major, b"\x01" = SNP-major.
        mode = filepointer.read(1)
        if mode not in (b"\x00", b"\x01"):
            msg = "Not a valid .bed file"
            raise ValueError(msg)
        return mode

    def __del__(self) -> None:
        self.__exit__()

    def __enter__(self):
        return self

    def __exit__(self, *_):
        # Nothing is held open between reads; kept for with-statement support.
        pass

    @staticmethod
    def _array_properties_are_ok(val, order):
        """True when *val* is contiguous in the requested memory *order*."""
        if order == "F":
            return val.flags["F_CONTIGUOUS"]
        assert order == "C"  # real assert
        return val.flags["C_CONTIGUOUS"]

    @property
    def shape(self):
        """(iid_count, sid_count)."""
        return (self.iid_count, self.sid_count)

    @staticmethod
    def _split_index(index):
        """Split a read() index into (iid_selector, sid_selector)."""
        if not isinstance(index, tuple):
            # A bare selector applies to SNPs; individuals get "all".
            index = (None, index)
        iid_index = open_bed._fix_up_index(index[0])
        sid_index = open_bed._fix_up_index(index[1])
        return iid_index, sid_index

    @staticmethod
    def _fix_up_index(index):
        """Normalize one selector: None -> all; a scalar -> one-item list."""
        if index is None:
            return slice(None)
        try:
            index = index.__index__()  # (int, np.int32, etc.)
            return [index]
        except Exception:
            pass
        return index

    @staticmethod
    def _write_fam_or_bim(base_filepath, properties, suffix, property_filepath) -> None:
        """Write the .fam or .bim metadata file from *properties*."""
        assert suffix in {"fam", "bim"}, "real assert"
        filepath = (
            Path(property_filepath)
            if property_filepath is not None
            else base_filepath.parent / (base_filepath.stem + "." + suffix)
        )
        fam_bim_list = []
        for key, mm in _meta_meta.items():
            if mm.suffix == suffix:
                assert len(fam_bim_list) == mm.column, "real assert"
                fam_bim_list.append(properties[key])
        sep = " " if suffix == "fam" else "\t"
        with open(filepath, "w") as filepointer:
            for index in range(len(fam_bim_list[0])):
                filepointer.write(
                    sep.join(str(seq[index]) for seq in fam_bim_list) + "\n",
                )

    @staticmethod
    def _fix_up_properties_array(input, dtype, missing_value, key):
        """Normalize a user-supplied property list/array to a 1-D array of *dtype*.

        NOTE(review): when *input* is a float ndarray, NaN entries are
        replaced in place, mutating the caller's array — confirm this is
        intended before changing.
        """
        if input is None:
            return None
        if len(input) == 0:
            return np.zeros([0], dtype=dtype)
        if not isinstance(input, np.ndarray):
            return open_bed._fix_up_properties_array(
                np.array(input), dtype, missing_value, key,
            )
        if len(input.shape) != 1:
            msg = f"{key} should be one dimensional"
            raise ValueError(msg)
        do_missing_values = True
        if np.issubdtype(input.dtype, np.floating) and np.issubdtype(dtype, int):
            # NaN cannot be represented as int; replace before the cast.
            input[input != input] = missing_value
            old_settings = np.seterr(invalid="warn")
            try:
                output = np.array(input, dtype=dtype)
            finally:
                np.seterr(**old_settings)
        elif not np.issubdtype(input.dtype, dtype):
            old_settings = np.seterr(invalid="warn")
            try:
                output = np.array(input, dtype=dtype)
            finally:
                np.seterr(**old_settings)
        else:
            output = input
        if do_missing_values and np.issubdtype(input.dtype, np.floating):
            output[input != input] = missing_value
        return output

    @staticmethod
    def _fix_up_properties(properties, iid_count, sid_count, use_fill_sequence):
        """Validate *properties*; return (properties_dict, counts-by-suffix)."""
        for key in properties:
            if key not in _meta_meta:
                msg = f"properties key '{key}' not known"
                raise KeyError(msg)
        count_dict = {"fam": iid_count, "bim": sid_count}
        properties_dict = {}
        for key, mm in _meta_meta.items():
            count = count_dict[mm.suffix]
            if key not in properties or (use_fill_sequence and properties[key] is None):
                if use_fill_sequence:
                    output = mm.fill_sequence(key, count, mm.missing_value, mm.dtype)
                else:
                    # Leave unmentioned properties to lazy file loading.
                    continue
            else:
                output = open_bed._fix_up_properties_array(
                    properties[key], mm.dtype, mm.missing_value, key,
                )
            if output is not None:
                if count is None:
                    count_dict[mm.suffix] = len(output)
                elif count != len(output):
                    raise ValueError(
                        f"The length of override {key}, {len(output)}, should not "
                        + "be different from the current "
                        + f"{_count_name[mm.suffix]}, {count}",
                    )
                properties_dict[key] = output
        return properties_dict, count_dict

    def _read_fam_or_bim(self, suffix) -> None:
        """Parse the .fam or .bim file and cache its still-missing columns."""
        property_location = self._property_location(suffix)
        logging.info(f"Loading {suffix} file {property_location}")
        count = self._counts[suffix]
        delimiter = _delimiters[suffix]
        if delimiter == r"\s+":
            # np.loadtxt treats delimiter=None as "any whitespace".
            delimiter = None
        usecolsdict = {}
        dtype_dict = {}
        for key, mm in _meta_meta.items():
            # Bug fix: was "mm.suffix is suffix" (string identity compare).
            if mm.suffix == suffix and key not in self.properties_dict:
                usecolsdict[key] = mm.column
                dtype_dict[mm.column] = mm.dtype
        assert list(usecolsdict.values()) == sorted(usecolsdict.values())  # real assert
        assert len(usecolsdict) > 0  # real assert
        if self._is_url(property_location):
            # Cloud: download the whole metadata file, then parse from memory.
            file_bytes = bytes(
                url_to_bytes(property_location.geturl(), self.cloud_options),
            )
            if len(file_bytes) == 0:
                columns, row_count = [], 0
            else:
                columns, row_count = _read_csv(
                    BytesIO(file_bytes),
                    delimiter=delimiter,
                    dtype=dtype_dict,
                    usecols=usecolsdict.values(),
                )
        elif os.path.getsize(property_location) == 0:
            columns, row_count = [], 0
        else:
            columns, row_count = _read_csv(
                property_location,
                delimiter=delimiter,
                dtype=dtype_dict,
                usecols=usecolsdict.values(),
            )
        if count is None:
            self._counts[suffix] = row_count
        elif count != row_count:
            # Bug fix: the last fragment was a plain string with a stray "f"
            # inside the quotes; it is now a real f-string.
            raise ValueError(
                f"The number of lines in the *.{suffix} file, {row_count}, "
                + "should not be different from the current "
                + f"{_count_name[suffix]}, {count}",
            )
        for i, key in enumerate(usecolsdict.keys()):
            mm = _meta_meta[key]
            if row_count == 0:
                output = np.array([], dtype=mm.dtype)
            else:
                output = columns[i]
                if not np.issubdtype(output.dtype, mm.dtype):
                    output = np.array(output, dtype=mm.dtype)
            self.properties_dict[key] = output

    def read_sparse(
        self,
        index: Optional[Any] = None,
        dtype: Optional[Union[type, str]] = "float32",
        batch_size: Optional[int] = None,
        format: Optional[str] = "csc",
        num_threads=None,
        max_concurrent_requests=None,
        max_chunk_bytes=None,
    ) -> (Union[sparse.csc_matrix, sparse.csr_matrix]) if sparse is not None else None:
        """Read genotype values into a scipy sparse matrix ('csc' or 'csr').

        Values are read in dense batches of *batch_size* columns (csc) or
        rows (csr), then the nonzeros are accumulated. Requires scipy.

        Raises ImportError without scipy; ValueError for an unknown *format*
        or when a dimension exceeds int32 range.
        """
        if sparse is None:
            raise ImportError(
                "The function read_sparse() requires scipy. "
                + "Install it with 'pip install --upgrade bed-reader[sparse]'.",
            )
        iid_index_or_slice_etc, sid_index_or_slice_etc = self._split_index(index)
        dtype = np.dtype(dtype)
        if self._iid_range is None:
            self._iid_range = np.arange(self.iid_count, dtype="intp")
        if self._sid_range is None:
            self._sid_range = np.arange(self.sid_count, dtype="intp")
        iid_index = np.ascontiguousarray(
            self._iid_range[iid_index_or_slice_etc],
            dtype="intp",
        )
        sid_index = np.ascontiguousarray(
            self._sid_range[sid_index_or_slice_etc], dtype="intp",
        )
        # Sparse index arrays are int32, so dimensions must fit in int32.
        if (
            len(iid_index) > np.iinfo(np.int32).max
            or len(sid_index) > np.iinfo(np.int32).max
        ):
            # Bug fix: this message was missing its f-string prefix.
            msg = (
                "Too many Individuals or SNPs (variants) requested. "
                f"Maximum is {np.iinfo(np.int32).max}."
            )
            raise ValueError(
                msg,
            )
        if batch_size is None:
            batch_size = round(np.sqrt(len(sid_index)))
        num_threads = get_num_threads(
            self._num_threads if num_threads is None else num_threads,
        )
        max_concurrent_requests = get_max_concurrent_requests(
            self._max_concurrent_requests
            if max_concurrent_requests is None
            else max_concurrent_requests,
        )
        max_chunk_bytes = get_max_chunk_bytes(
            self._max_chunk_bytes if max_chunk_bytes is None else max_chunk_bytes,
        )
        if format == "csc":
            order = "F"
            indptr = np.zeros(len(sid_index) + 1, dtype=np.int32)
        elif format == "csr":
            order = "C"
            indptr = np.zeros(len(iid_index) + 1, dtype=np.int32)
        else:
            msg = f"format '{format}' not known. Expected 'csc' or 'csr'."
            raise ValueError(msg)
        # Seed with empty arrays so np.concatenate works for all-zero data.
        data = [np.empty(0, dtype=dtype)]
        indices = [np.empty(0, dtype=np.int32)]
        if self.iid_count > 0 and self.sid_count > 0:
            reader, location_str, is_cloud = self._pick_reader(dtype)
            if format == "csc":
                val = np.zeros((len(iid_index), batch_size), order=order, dtype=dtype)
                for batch_start in range(0, len(sid_index), batch_size):
                    batch_end = batch_start + batch_size
                    if batch_end > len(sid_index):
                        batch_end = len(sid_index)
                        del val
                        val = np.zeros(
                            (len(iid_index), batch_end - batch_start),
                            order=order,
                            dtype=dtype,
                        )
                    batch_slice = np.s_[batch_start:batch_end]
                    batch_index = sid_index[batch_slice]
                    if not is_cloud:
                        reader(
                            location_str,
                            self.cloud_options,
                            iid_count=self.iid_count,
                            sid_count=self.sid_count,
                            is_a1_counted=self.count_A1,
                            iid_index=iid_index,
                            sid_index=batch_index,
                            val=val,
                            num_threads=num_threads,
                        )
                    else:
                        reader(
                            location_str,
                            self.cloud_options,
                            iid_count=self.iid_count,
                            sid_count=self.sid_count,
                            is_a1_counted=self.count_A1,
                            iid_index=iid_index,
                            sid_index=batch_index,
                            val=val,
                            num_threads=num_threads,
                            max_concurrent_requests=max_concurrent_requests,
                            max_chunk_bytes=max_chunk_bytes,
                        )
                    self.sparsify(
                        val, order, iid_index, batch_slice, data, indices, indptr,
                    )
            else:
                assert format == "csr"  # real assert
                val = np.zeros((batch_size, len(sid_index)), order=order, dtype=dtype)
                for batch_start in range(0, len(iid_index), batch_size):
                    batch_end = batch_start + batch_size
                    if batch_end > len(iid_index):
                        batch_end = len(iid_index)
                        del val
                        val = np.zeros(
                            (batch_end - batch_start, len(sid_index)),
                            order=order,
                            dtype=dtype,
                        )
                    batch_slice = np.s_[batch_start:batch_end]
                    batch_index = iid_index[batch_slice]
                    if not is_cloud:
                        reader(
                            location_str,
                            self.cloud_options,
                            iid_count=self.iid_count,
                            sid_count=self.sid_count,
                            is_a1_counted=self.count_A1,
                            iid_index=batch_index,
                            sid_index=sid_index,
                            val=val,
                            num_threads=num_threads,
                        )
                    else:
                        reader(
                            location_str,
                            self.cloud_options,
                            iid_count=self.iid_count,
                            sid_count=self.sid_count,
                            is_a1_counted=self.count_A1,
                            iid_index=batch_index,
                            sid_index=sid_index,
                            val=val,
                            num_threads=num_threads,
                            max_concurrent_requests=max_concurrent_requests,
                            max_chunk_bytes=max_chunk_bytes,
                        )
                    self.sparsify(
                        val, order, sid_index, batch_slice, data, indices, indptr,
                    )
        data = np.concatenate(data)
        indices = np.concatenate(indices)
        if format == "csc":
            return sparse.csc_matrix(
                (data, indices, indptr), (len(iid_index), len(sid_index)),
            )
        assert format == "csr"  # real assert
        return sparse.csr_matrix(
            (data, indices, indptr), (len(iid_index), len(sid_index)),
        )

    def sparsify(self, val, order, minor_index, batch_slice, data, indices, indptr) -> None:
        """Append the nonzeros of one dense batch to the CSC/CSR accumulators."""
        flatten = np.ravel(val, order=order)
        nz_indices = np.flatnonzero(flatten).astype(np.int32)
        # Major-axis position (column for csc, row for csr) of each nonzero.
        column_indexes = nz_indices // len(minor_index)
        counts = np.bincount(
            column_indexes, minlength=batch_slice.stop - batch_slice.start,
        ).astype(np.int32)
        # Seed the running cumulative sum with the pointer before this batch.
        counts_with_initial = np.r_[
            indptr[batch_slice.start : batch_slice.start + 1], counts,
        ]
        data.append(flatten[nz_indices])
        indices.append(np.mod(nz_indices, len(minor_index)))
        indptr[1:][batch_slice] = np.cumsum(counts_with_initial)[1:]
def _read_csv(filepath, delimiter=None, dtype=None, usecols=None):
    """Read selected columns of a .fam/.bim text file.

    Returns (columns, row_count): one 1-D numpy array per entry of *usecols*,
    each converted per the *dtype* column->dtype mapping (default np.str_),
    and the number of data rows.
    """
    # Strip a literal "np.something(...)" wrapper if a value was written that
    # way (only the first entry is probed; the column is assumed uniform).
    pattern = re.compile(r"^np\.\w+\((.+?)\)$")
    # Bug fix: the default dtype=None crashed at dtype.get() below.
    dtype = dtype if dtype is not None else {}
    usecols_indices = list(usecols)
    transposed = np.loadtxt(
        filepath,
        dtype=np.str_,
        delimiter=delimiter,
        usecols=usecols_indices,
        unpack=True,
    )
    if transposed.ndim == 1:
        # A single data row comes back 1-D; restore the (columns, rows) shape.
        transposed = transposed.reshape(-1, 1)
    row_count = transposed.shape[1]
    columns = []
    for output_index, input_index in enumerate(usecols_indices):
        col = transposed[output_index]
        if len(col) > 0 and pattern.fullmatch(col[0]):
            col = np.array([pattern.fullmatch(x).group(1) for x in col])
        col_dtype = dtype.get(input_index, np.str_)
        columns.append(_convert_to_dtype(col, col_dtype))
    return columns, row_count
def _convert_to_dtype(str_arr, dtype):
assert dtype in [np.str_, np.float32, np.int32]
if dtype == np.str_:
return str_arr
try:
new_arr = str_arr.astype(dtype)
except ValueError as e:
if dtype == np.float32:
raise
try:
assert dtype == np.int32 float_arr = str_arr.astype(np.float32)
except ValueError:
raise e
new_arr = float_arr.astype(np.int32)
if not np.array_equal(new_arr, float_arr):
msg = f"invalid literal for int: '{str_arr[np.where(new_arr != float_arr)][:1]}')"
raise ValueError(
msg,
)
return new_arr
if __name__ == "__main__":
    # When run directly, execute this module's doctests via pytest.
    import pytest

    logging.basicConfig(level=logging.INFO)
    pytest.main(["--doctest-modules", __file__])