import io
from os import PathLike
from _zstd import ZstdCompressor, ZstdDecompressor, ZSTD_DStreamOutSize
from compression._common import _streams
# Names exported by ``from <module> import *``.
__all__ = ('ZstdFile', 'open')

# States stored in ZstdFile._mode: closed (the initial and final state),
# open for reading, and open for writing ('w', 'a' or 'x').
_MODE_CLOSED, _MODE_READ, _MODE_WRITE = range(3)
def _nbytes(dat, /):
    """Return the size of the bytes-like object *dat* in bytes.

    ``bytes`` and ``bytearray`` are measured with ``len()``. Any other
    object is viewed through a ``memoryview``, which is released when the
    ``with`` block exits, and that view's ``nbytes`` is returned. This
    accounts for item sizes larger than one byte.
    """
    if not isinstance(dat, (bytes, bytearray)):
        with memoryview(dat) as view:
            return view.nbytes
    return len(dat)
class ZstdFile(_streams.BaseStream):
    """A file object that transparently Zstandard-(de)compresses a binary file.

    In read mode (``'r'`` or ``'rb'``), decompressed data is served by an
    :class:`io.BufferedReader` layered over a ``_streams.DecompressReader``
    that reads the underlying file. In write modes (``'w'``, ``'a'`` or
    ``'x'``, each optionally suffixed with ``'b'``), data passed to
    :meth:`write` is compressed with a :class:`ZstdCompressor`. The result is
    written to the underlying file.

    *file* may be a path (``str``, ``bytes`` or ``PathLike``), which is opened
    here and closed again by :meth:`close`. It may instead be an existing file
    object with a ``read`` method (read mode) or a ``write`` method (write
    modes). Such an object is used as-is and is left open by :meth:`close`.
    """

    # Flush modes accepted by flush(), taken from ZstdCompressor.
    FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK
    FLUSH_FRAME = ZstdCompressor.FLUSH_FRAME

    def __init__(self, file, /, mode='r', *,
                 level=None, options=None, zstd_dict=None):
        """Open *file* for Zstandard reading or writing.

        *level* (an int or None) is only accepted in write modes. *options*
        (a dict or None) and *zstd_dict* go to the compressor in write modes
        and to the decompressor in read mode.

        Raises ValueError for a non-str or unrecognised *mode*. Raises
        TypeError for a bad *options* or *level* type, for a *level* given
        in read mode, or for a *file* that is neither a path nor a suitable
        file object.
        """
        # Set every attribute close() reads before anything can fail, so
        # close() stays safe even if __init__ raises part-way through.
        self._fp = None
        self._close_fp = False
        self._mode = _MODE_CLOSED
        self._buffer = None

        if not isinstance(mode, str):
            raise ValueError('mode must be a str')
        if options is not None and not isinstance(options, dict):
            raise TypeError('options must be a dict or None')
        mode = mode.removesuffix('b')  # 'rb' -> 'r', 'wb' -> 'w', ...
        if mode == 'r':
            if level is not None:
                raise TypeError('level is illegal in read mode')
            self._mode = _MODE_READ
        elif mode in {'w', 'a', 'x'}:
            if level is not None and not isinstance(level, int):
                raise TypeError('level must be int or None')
            self._mode = _MODE_WRITE
            # Built before any file is opened, so bad compression arguments
            # fail without leaving an open file handle behind.
            self._compressor = ZstdCompressor(level=level, options=options,
                                              zstd_dict=zstd_dict)
            # Uncompressed bytes accepted by write(); reported by tell().
            self._pos = 0
        else:
            raise ValueError(f'Invalid mode: {mode!r}')

        if isinstance(file, (str, bytes, PathLike)):
            self._fp = io.open(file, f'{mode}b')
            self._close_fp = True
        elif ((mode == 'r' and hasattr(file, 'read'))
                or (mode != 'r' and hasattr(file, 'write'))):
            self._fp = file
        else:
            raise TypeError('file must be a file-like object '
                            'or a str, bytes, or PathLike object')

        if self._mode == _MODE_READ:
            try:
                raw = _streams.DecompressReader(
                    self._fp,
                    ZstdDecompressor,
                    zstd_dict=zstd_dict,
                    options=options,
                )
                self._buffer = io.BufferedReader(raw)
            except BaseException:
                # Unlike write mode, the (de)compression arguments are only
                # consumed after the file is open. Presumably
                # DecompressReader builds the ZstdDecompressor eagerly, so a
                # bad options/zstd_dict fails here. Release the file now
                # instead of waiting for garbage collection. close() only
                # closes files this object opened itself.
                self.close()
                raise

    def close(self):
        """Flush (in write mode) and close this object.

        In write mode, the current frame is finished with
        ``flush(FLUSH_FRAME)``. The underlying file is closed only if it was
        opened by this object from a path. Calling close() again is a no-op.
        """
        if self._fp is None:
            return
        try:
            if self._mode == _MODE_READ:
                if getattr(self, '_buffer', None):
                    self._buffer.close()
                    self._buffer = None
            elif self._mode == _MODE_WRITE:
                self.flush(self.FLUSH_FRAME)
                self._compressor = None
        finally:
            # Mark closed and drop the file reference even if flushing or
            # closing the buffer raised.
            self._mode = _MODE_CLOSED
            try:
                if self._close_fp:
                    self._fp.close()
            finally:
                self._fp = None
                self._close_fp = False

    def write(self, data, /):
        """Compress the bytes-like *data* and write the output to the file.

        Returns the number of uncompressed bytes consumed, which is also
        added to the position reported by tell(). Use flush() to force
        pending compressed output to the underlying file.
        """
        self._check_can_write()
        length = _nbytes(data)
        compressed = self._compressor.compress(data)
        self._fp.write(compressed)
        self._pos += length
        return length

    def flush(self, mode=FLUSH_BLOCK):
        """Write out data held by the compressor.

        *mode* is FLUSH_BLOCK (the default) or FLUSH_FRAME; close() uses
        FLUSH_FRAME. This is a no-op in read mode, and also a no-op when the
        compressor's ``last_mode`` already equals *mode*. After writing, the
        underlying file's own ``flush()`` is called if it has one.

        Raises ValueError for any other *mode*.
        """
        if self._mode == _MODE_READ:
            return
        self._check_not_closed()
        if mode not in {self.FLUSH_BLOCK, self.FLUSH_FRAME}:
            raise ValueError('Invalid mode argument, expected either '
                             'ZstdFile.FLUSH_FRAME or '
                             'ZstdFile.FLUSH_BLOCK')
        if self._compressor.last_mode == mode:
            # Presumably nothing new has been compressed since the last
            # flush of this kind.
            return
        data = self._compressor.flush(mode)
        self._fp.write(data)
        if hasattr(self._fp, 'flush'):
            self._fp.flush()

    def read(self, size=-1):
        """Read up to *size* decompressed bytes.

        A negative *size* or None reads all remaining data.
        """
        if size is None:
            size = -1
        self._check_can_read()
        return self._buffer.read(size)

    def read1(self, size=-1):
        """Read up to *size* decompressed bytes via ``BufferedReader.read1``.

        A negative *size* or None requests ``ZSTD_DStreamOutSize`` bytes.
        That is presumably zstd's recommended streaming output size.
        """
        self._check_can_read()
        # None is accepted for consistency with read(); previously it
        # crashed on the ``< 0`` comparison.
        if size is None or size < 0:
            size = ZSTD_DStreamOutSize
        return self._buffer.read1(size)

    def readinto(self, b):
        """Read decompressed bytes into the writable buffer *b*; return the count."""
        self._check_can_read()
        return self._buffer.readinto(b)

    def readinto1(self, b):
        """Like readinto(), but delegates to ``BufferedReader.readinto1``."""
        self._check_can_read()
        return self._buffer.readinto1(b)

    def readline(self, size=-1):
        """Read one line of decompressed data, at most *size* bytes if given."""
        self._check_can_read()
        return self._buffer.readline(size)

    def seek(self, offset, whence=io.SEEK_SET):
        """Seek within the decompressed stream (read mode only).

        This delegates to the BufferedReader over the DecompressReader.
        Returns the new position.
        """
        self._check_can_read()
        return self._buffer.seek(offset, whence)

    def peek(self, size=-1):
        """Return buffered decompressed bytes without advancing the position."""
        self._check_can_read()
        return self._buffer.peek(size)

    def __next__(self):
        """Return the next line of decompressed data.

        Raises StopIteration at end of data. Like the other read methods,
        it first checks that the file is open for reading.
        """
        # Without this check, a write-mode or closed file fails with an
        # AttributeError on ``None.readline``.
        self._check_can_read()
        if ret := self._buffer.readline():
            return ret
        raise StopIteration

    def tell(self):
        """Return the current position, counted in uncompressed bytes.

        In read mode this is the read position in the decompressed stream.
        In write mode it is the total number of bytes passed to write().
        """
        self._check_not_closed()
        if self._mode == _MODE_READ:
            return self._buffer.tell()
        elif self._mode == _MODE_WRITE:
            return self._pos

    def fileno(self):
        """Return the file descriptor of the underlying file."""
        self._check_not_closed()
        return self._fp.fileno()

    @property
    def name(self):
        """The ``name`` attribute of the underlying file."""
        self._check_not_closed()
        return self._fp.name

    @property
    def mode(self):
        """``'wb'`` for any write mode ('w', 'a', 'x'); ``'rb'`` otherwise."""
        return 'wb' if self._mode == _MODE_WRITE else 'rb'

    @property
    def closed(self):
        """True when this object is in the closed state."""
        return self._mode == _MODE_CLOSED

    def seekable(self):
        """True only in read mode and when the read buffer is seekable."""
        return self.readable() and self._buffer.seekable()

    def readable(self):
        """True if the file was opened for reading."""
        self._check_not_closed()
        return self._mode == _MODE_READ

    def writable(self):
        """True if the file was opened for writing ('w', 'a' or 'x')."""
        self._check_not_closed()
        return self._mode == _MODE_WRITE
def open(file, /, mode='rb', *, level=None, options=None, zstd_dict=None,
         encoding=None, errors=None, newline=None):
    """Open a Zstandard-compressed file in binary or text mode.

    *file*, *level*, *options* and *zstd_dict* are forwarded to
    :class:`ZstdFile`. *mode* is a mode ZstdFile accepts ('r', 'w', 'a',
    'x', optionally with 'b'). A 't' may be added instead of 'b' to select
    text mode; combining 't' and 'b' is rejected.

    In binary mode the ZstdFile itself is returned, and *encoding*, *errors*
    and *newline* must be None. In text mode the ZstdFile is wrapped in an
    :class:`io.TextIOWrapper` built from those three arguments.

    Raises ValueError when *mode* contains both 't' and 'b', or when a
    text-only argument is given in binary mode.
    """
    text_mode = 't' in mode
    binary_mode = mode.replace('t', '')

    if text_mode:
        if 'b' in binary_mode:
            # Report the mode the caller passed. The stripped form (e.g.
            # 'rb' for 'rbt') would look like a valid mode.
            raise ValueError(f'Invalid mode: {mode!r}')
    else:
        if encoding is not None:
            raise ValueError('Argument "encoding" not supported in binary mode')
        if errors is not None:
            raise ValueError('Argument "errors" not supported in binary mode')
        if newline is not None:
            raise ValueError('Argument "newline" not supported in binary mode')

    binary_file = ZstdFile(file, binary_mode, level=level, options=options,
                           zstd_dict=zstd_dict)
    if not text_mode:
        return binary_file

    # Resolve encoding=None the way io.open() does. This also emits
    # EncodingWarning under -X warn_default_encoding (PEP 597), matching
    # bz2.open and lzma.open.
    encoding = io.text_encoding(encoding)
    try:
        return io.TextIOWrapper(binary_file, encoding, errors, newline)
    except BaseException:
        # For example, an unknown encoding name. Do not leave the ZstdFile
        # (and any file it opened) for the finalizer to close.
        binary_file.close()
        raise