from __future__ import annotations
import enum
import io
import struct
from dataclasses import dataclass
from decimal import Decimal
from typing import Iterable, List, Optional, Type
import pytest
from _pytest.fixtures import SubRequest
from _pytest.mark.structures import ParameterSet
from typing_extensions import final
from chia_rs.sized_ints import (
int8,
int16,
int32,
int64,
int512,
uint8,
uint16,
uint32,
uint64,
uint128,
)
from chia_rs.struct_stream import StructStream, parse_metadata_from_name
def dataclass_parameter(instance: object) -> ParameterSet:
    """Wrap a dataclass instance in ``pytest.param`` with a readable test id.

    The id is the instance's generated dataclass repr with the leading
    ``ClassName(`` and trailing ``)`` removed, e.g.
    ``BadName(name='int7', error='x')`` becomes ``name='int7', error='x'``.

    Args:
        instance: a dataclass instance that uses the default generated
            ``__repr__``.

    Returns:
        A ``ParameterSet`` whose single value is *instance*.
    """
    # The generated dataclass __repr__ starts with the class's __qualname__
    # (not __name__), so measure that to strip the prefix correctly for nested
    # or locally defined dataclasses too.
    prefix_length = len(type(instance).__qualname__) + len("(")
    return pytest.param(instance, id=repr(instance)[prefix_length:-1])
def dataclass_parameters(instances: Iterable[object]) -> List[ParameterSet]:
    """Apply ``dataclass_parameter`` to each element of *instances*, in order."""
    return list(map(dataclass_parameter, instances))
class StreamAndBytesMatchMode(enum.Enum):
    """Which point of a type's value range to use when comparing encodings.

    The string values are also used as the pytest parametrization ids.
    """

    minimum = "minimum"  # the type's minimum value
    middle_low = "middle low"  # about 30% of the way from minimum to maximum
    middle_high = "middle high"  # about 70% of the way from minimum to maximum
    maximum = "maximum"  # the type's maximum value
@dataclass(frozen=True)
class BadName:
    """A StructStream subclass name that metadata parsing must reject.

    Used with ``pytest.raises(ValueError, match=error)``, so ``error`` is a
    regular expression searched for in the raised message.
    """

    name: str  # name given to the dynamically created StructStream subclass
    error: str  # pattern expected somewhere in the ValueError message
@final
@dataclass(frozen=True)
class Good:
    """A name that metadata parsing accepts, paired with the expected results.

    ``cls`` is the parsed class. The remaining fields are the values the tests
    compare against its SIZE, BITS, SIGNED, MAXIMUM and MINIMUM attributes.
    """

    name: str  # class name given to parse_metadata_from_name, e.g. "uint24"
    cls: Type[StructStream]  # class returned by parse_metadata_from_name
    size: int  # expected SIZE (bytes)
    bits: int  # expected BITS; create() always sets this to 8 * size
    signed: bool  # expected SIGNED
    maximum: int  # expected MAXIMUM
    minimum: int  # expected MINIMUM

    @classmethod
    def create(
        cls,
        name: str,
        size: int,
        signed: bool,
        maximum: int,
        minimum: int,
    ) -> Good:
        """Create a StructStream subclass called *name*, parse it, and record expectations."""
        bare: Type[StructStream] = type(name, (StructStream,), {})
        return cls(
            name=name,
            cls=parse_metadata_from_name(bare),
            size=size,
            bits=8 * size,
            signed=signed,
            maximum=maximum,
            minimum=minimum,
        )
# Names that parse_metadata_from_name must accept, together with the metadata
# each should produce. The bounds are spelled out as literals rather than
# computed, so the expectations stay independent of the code under test.
# Columns: (name, size in bytes, signed, maximum, minimum)
good_classes = [
    Good.create(
        name=name,
        size=size,
        signed=signed,
        maximum=maximum,
        minimum=minimum,
    )
    for name, size, signed, maximum, minimum in (
        ("uint8", 1, False, 0xFF, 0),
        ("int8", 1, True, 0x7F, -0x80),
        ("uint16", 2, False, 0xFFFF, 0),
        ("int16", 2, True, 0x7FFF, -0x8000),
        ("uint24", 3, False, 0xFF_FFFF, 0),
        ("int24", 3, True, 0x7F_FFFF, -0x80_0000),
        (
            "uint128",
            16,
            False,
            0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            0,
        ),
        (
            "int128",
            16,
            True,
            0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            -0x8000_0000_0000_0000_0000_0000_0000_0000,
        ),
    )
]
@pytest.fixture(name="good", params=dataclass_parameters(good_classes))
def good_fixture(request: SubRequest) -> Good:
    """Provide each entry of ``good_classes`` in turn as the ``good`` fixture.

    Each parametrization id is the entry's dataclass repr without the class name.
    """
    current: Good = request.param
    return current
class TestStructStream:
    """Tests for the chia_rs sized-integer types and StructStream name parsing."""

    def _test_impl(
        self,
        cls: Type[StructStream],
        upper_boundary: int,
        lower_boundary: int,
        length: int,
        struct_format: Optional[str],
    ) -> None:
        """Check range enforcement, fixed-width parsing and encoding for *cls*.

        Args:
            cls: the sized-int type under test.
            upper_boundary: largest value *cls* must accept.
            lower_boundary: smallest value *cls* must accept.
            length: exact serialized width in bytes.
            struct_format: a ``struct`` format with the same width and range,
                used to cross-check the streamed bytes. Pass ``None`` when no
                such format exists (e.g. the 128- and 512-bit types).
        """
        # Values just outside the inclusive range must be rejected.
        with pytest.raises(ValueError):
            cls(upper_boundary + 1)
        with pytest.raises(ValueError):
            cls(lower_boundary - 1)

        # Both boundaries and zero must be representable.
        assert cls(upper_boundary) == upper_boundary
        assert cls(lower_boundary) == lower_boundary
        assert cls(0) == 0

        # Parsing requires exactly `length` bytes.
        with pytest.raises(ValueError):
            cls.parse(io.BytesIO(b"\0" * (length - 1)))
        with pytest.raises(ValueError):
            cls.from_bytes(b"\0" * (length - 1))
        with pytest.raises(ValueError):
            cls.from_bytes(b"\0" * (length + 1))

        if struct_format is not None:
            # The streamed encoding must equal the struct encoding at both
            # boundaries...
            for boundary in (lower_boundary, upper_boundary):
                bytes_io = io.BytesIO()
                cls(boundary).stream(bytes_io)
                assert bytes_io.getvalue() == struct.pack(struct_format, boundary)
            # ...and the struct format must share exactly the same range, so
            # the comparison above really covers the extremes.
            with pytest.raises(struct.error):
                struct.pack(struct_format, lower_boundary - 1)
            with pytest.raises(struct.error):
                struct.pack(struct_format, upper_boundary + 1)

        assert type(cls.MINIMUM) is cls
        assert type(cls.MAXIMUM) is cls

    def test_int512(self) -> None:
        # int512 is serialized in 65 bytes and accepts the full magnitude of an
        # unsigned 512-bit value with either sign: +/-(2**512 - 1).
        limit = 2**512 - 1
        self._test_impl(int512, limit, -limit, length=65, struct_format=None)

    def test_uint128(self) -> None:
        # struct has no 128-bit format, so the byte-level cross-check is skipped.
        self._test_impl(
            uint128,
            0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            0,
            length=16,
            struct_format=None,
        )

    def test_uint64(self) -> None:
        self._test_impl(
            uint64, 0xFFFF_FFFF_FFFF_FFFF, 0, length=8, struct_format="!Q"
        )

    def test_int64(self) -> None:
        self._test_impl(
            int64,
            0x7FFF_FFFF_FFFF_FFFF,
            -0x8000_0000_0000_0000,
            length=8,
            struct_format="!q",
        )

    def test_uint32(self) -> None:
        self._test_impl(uint32, 0xFFFF_FFFF, 0, length=4, struct_format="!L")

    def test_int32(self) -> None:
        self._test_impl(
            int32, 0x7FFF_FFFF, -0x8000_0000, length=4, struct_format="!l"
        )

    def test_uint16(self) -> None:
        self._test_impl(uint16, 0xFFFF, 0, length=2, struct_format="!H")

    def test_int16(self) -> None:
        self._test_impl(int16, 0x7FFF, -0x8000, length=2, struct_format="!h")

    def test_uint8(self) -> None:
        self._test_impl(uint8, 0xFF, 0, length=1, struct_format="!B")

    def test_int8(self) -> None:
        self._test_impl(int8, 0x7F, -0x80, length=1, struct_format="!b")

    def test_roundtrip(self) -> None:
        """Streaming a boundary value and parsing it back yields an equal value."""

        def roundtrip(v: StructStream) -> None:
            s = io.BytesIO()
            v.stream(s)
            s.seek(0)
            assert type(v).parse(s) == v

        int512_limit = 2**512 - 1
        boundary_values: List[StructStream] = [
            int512(int512_limit),
            int512(-int512_limit),
            uint128(0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF),
            uint128(0),
            uint64(0xFFFF_FFFF_FFFF_FFFF),
            uint64(0),
            int64(0x7FFF_FFFF_FFFF_FFFF),
            int64(-0x8000_0000_0000_0000),
            uint32(0xFFFF_FFFF),
            uint32(0),
            int32(0x7FFF_FFFF),
            int32(-0x8000_0000),
            uint16(0xFFFF),
            uint16(0),
            int16(0x7FFF),
            int16(-0x8000),
            uint8(0xFF),
            uint8(0),
            int8(0x7F),
            int8(-0x80),
        ]
        for value in boundary_values:
            roundtrip(value)

    def test_uint32_from_decimal(self) -> None:
        assert uint32(Decimal("137")) == 137

    def test_uint32_from_float(self) -> None:
        assert uint32(4.0) == 4

    def test_uint32_from_str(self) -> None:
        assert uint32("43") == 43

    def test_uint32_from_bytes(self) -> None:
        assert uint32(b"273") == 273

    def test_struct_stream_cannot_be_instantiated_directly(self) -> None:
        # The bare base class has no size/range metadata to work with.
        with pytest.raises(AttributeError, match="object has no attribute"):
            StructStream(0)

    @pytest.mark.parametrize(
        argnames="bad_name",
        argvalues=dataclass_parameters(
            instances=[
                BadName(name="uint", error="expected integer suffix but got: ''"),
                BadName(name="blue", error="expected integer suffix but got"),
                BadName(name="blue8", error="expected integer suffix but got: ''"),
                BadName(name="sint8", error="expected class name"),
                BadName(name="redint8", error="expected class name"),
                BadName(name="int7", error="must be a multiple of 8"),
                BadName(name="int9", error="must be a multiple of 8"),
                BadName(name="int31", error="must be a multiple of 8"),
                BadName(name="int0", error="bit size must greater than zero"),
                BadName(name="int-1", error="bit size must greater than zero"),
            ],
        ),
    )
    def test_parse_metadata_from_name_raises(self, bad_name: BadName) -> None:
        cls = type(bad_name.name, (StructStream,), {})
        with pytest.raises(ValueError, match=bad_name.error):
            parse_metadata_from_name(cls)

    def test_parse_metadata_from_name_correct_size(self, good: Good) -> None:
        assert good.cls.SIZE == good.size

    def test_parse_metadata_from_name_correct_bits(self, good: Good) -> None:
        assert good.cls.BITS == good.bits

    def test_parse_metadata_from_name_correct_signedness(self, good: Good) -> None:
        assert good.cls.SIGNED == good.signed

    def test_parse_metadata_from_name_correct_maximum(self, good: Good) -> None:
        assert good.cls.MAXIMUM == good.maximum

    def test_parse_metadata_from_name_correct_minimum(self, good: Good) -> None:
        assert good.cls.MINIMUM == good.minimum

    @pytest.mark.parametrize(
        "mode", list(StreamAndBytesMatchMode), ids=lambda mode: mode.value
    )
    def test_stream_to_bytes_and_bytes_match_minimum(
        self, good: Good, mode: StreamAndBytesMatchMode
    ) -> None:
        """``bytes(x)`` and ``x.stream_to_bytes()`` agree across the value range.

        Despite the ``_minimum`` suffix, this runs for every
        ``StreamAndBytesMatchMode``.
        """
        span = good.maximum - good.minimum
        if mode == StreamAndBytesMatchMode.minimum:
            value = good.minimum
        elif mode == StreamAndBytesMatchMode.middle_low:
            # Exact integer arithmetic: a float product keeps only 53
            # significant bits, which for 128-bit ranges zeroes the low bytes
            # and makes a poor encoding probe.
            value = good.minimum + span * 3 // 10
        elif mode == StreamAndBytesMatchMode.middle_high:
            value = good.minimum + span * 7 // 10
        elif mode == StreamAndBytesMatchMode.maximum:
            value = good.maximum
        else:
            raise Exception(f"unhandled parametrization: {mode!r}")

        instance = good.cls(value)
        assert bytes(instance) == instance.stream_to_bytes()