# chia 0.20.0
#
# A meta-crate that exports all of the Chia crates in the workspace.
# Documentation
from __future__ import annotations

import enum
import io
import struct
from dataclasses import dataclass
from decimal import Decimal
from typing import Iterable, Optional

import pytest

# TODO: update after resolution in https://github.com/pytest-dev/pytest/issues/7469
from _pytest.fixtures import SubRequest

# TODO: update after resolution in https://github.com/pytest-dev/pytest/issues/7469
from _pytest.mark.structures import ParameterSet
from typing_extensions import final

from chia_rs.sized_ints import (
    int8,
    int16,
    int32,
    int64,
    int512,
    uint8,
    uint16,
    uint32,
    uint64,
    uint128,
)
from chia_rs.struct_stream import StructStream, parse_metadata_from_name


def dataclass_parameter(instance: object) -> ParameterSet:
    """Wrap *instance* in a ``pytest.param`` with an id taken from its repr.

    The id is ``repr(instance)`` with ``len(type(instance).__name__) + 1``
    leading characters and one trailing character sliced off. For a repr of
    the form ``ClassName(a=1, b=2)`` that leaves ``a=1, b=2``.
    """
    class_name = type(instance).__name__
    # Skip the class name plus the opening parenthesis that follows it.
    prefix_length = len(class_name) + 1
    parameter_id = repr(instance)[prefix_length:-1]
    return pytest.param(instance, id=parameter_id)


def dataclass_parameters(instances: Iterable[object]) -> list[ParameterSet]:
    """Return one ``dataclass_parameter`` result per item of *instances*, in order."""
    return list(map(dataclass_parameter, instances))


class StreamAndBytesMatchMode(enum.Enum):
    """Selects which value of a type's range the stream/bytes match test uses."""

    # The member values are also used as the pytest ids of the parametrized
    # test (``ids=lambda mode: mode.value``).
    minimum = "minimum"
    middle_low = "middle low"
    middle_high = "middle high"
    maximum = "maximum"


@dataclass(frozen=True)
class BadName:
    """A class name expected to be rejected, paired with the expected error text."""

    # Name given to the throwaway StructStream subclass built in the test.
    name: str
    # Passed as ``match=`` to ``pytest.raises(ValueError, ...)``.
    error: str


@final
@dataclass(frozen=True)
class Good:
    """Expected metadata for one dynamically created ``StructStream`` subclass."""

    # Class name the subclass was created with.
    name: str
    # Result of ``parse_metadata_from_name`` for the created subclass.
    cls: type[StructStream]
    # Expected width in bytes.
    size: int
    # Expected width in bits; ``create`` fills this in as ``size * 8``.
    bits: int
    # Expected signedness.
    signed: bool
    # Expected largest value.
    maximum: int
    # Expected smallest value.
    minimum: int

    @classmethod
    def create(
        cls,
        name: str,
        size: int,
        signed: bool,
        maximum: int,
        minimum: int,
    ) -> Good:
        """Build a ``Good`` record together with the class it describes.

        A new ``StructStream`` subclass called *name* with an empty body is
        created and handed to ``parse_metadata_from_name``. Whatever that call
        returns is stored in ``cls`` next to the expected values given here.
        """
        # Only the name is supplied; the class namespace is deliberately empty.
        unparsed: type[StructStream] = type(name, (StructStream,), {})
        return cls(
            name=name,
            cls=parse_metadata_from_name(unparsed),
            size=size,
            bits=8 * size,
            signed=signed,
            maximum=maximum,
            minimum=minimum,
        )


# Expected metadata for StructStream subclasses built purely from their names.
# uint24, int24 and int128 are not among the sized ints imported at the top of
# this file, so these entries also cover widths beyond those imported classes.
good_classes = [
    Good.create(name="uint8", size=1, signed=False, maximum=0xFF, minimum=0),
    Good.create(name="int8", size=1, signed=True, maximum=0x7F, minimum=-0x80),
    Good.create(name="uint16", size=2, signed=False, maximum=0xFFFF, minimum=0),
    Good.create(name="int16", size=2, signed=True, maximum=0x7FFF, minimum=-0x8000),
    Good.create(name="uint24", size=3, signed=False, maximum=0xFFFFFF, minimum=0),
    Good.create(name="int24", size=3, signed=True, maximum=0x7FFFFF, minimum=-0x800000),
    Good.create(
        name="uint128",
        size=16,
        signed=False,
        maximum=0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF,
        minimum=0,
    ),
    Good.create(
        name="int128",
        size=16,
        signed=True,
        maximum=0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF,
        minimum=-0x80000000000000000000000000000000,
    ),
]


@pytest.fixture(
    name="good",
    params=dataclass_parameters(good_classes),
)
def good_fixture(request: SubRequest) -> Good:
    """Provide each entry of ``good_classes`` in turn under the fixture name ``good``."""
    # request.param is untyped, hence the ignore for the declared return type.
    return request.param  # type: ignore[no-any-return]


class TestStructStream:
    """Tests for the sized integer classes and ``parse_metadata_from_name``."""

    def _test_impl(
        self,
        cls: type[StructStream],
        upper_boundary: int,
        lower_boundary: int,
        length: int,
        struct_format: Optional[str],
    ) -> None:
        """Run the shared boundary and serialization checks for one class.

        Args:
            cls: The sized integer class under test.
            upper_boundary: Largest value ``cls`` is expected to accept.
            lower_boundary: Smallest value ``cls`` is expected to accept.
            length: Expected serialized length in bytes.
            struct_format: ``struct`` format whose packed output the streamed
                bytes are compared with, or ``None`` to skip that comparison.
        """
        # One step outside either boundary must be rejected.
        with pytest.raises(ValueError):
            cls(upper_boundary + 1)

        with pytest.raises(ValueError):
            cls(lower_boundary - 1)

        # Both boundaries and zero must be accepted and compare equal.
        for accepted in (upper_boundary, lower_boundary, 0):
            assert cls(accepted) == accepted

        # Input that is one byte short must be rejected by both parsers, and
        # from_bytes must also reject one byte too many.
        with pytest.raises(ValueError):
            cls.parse(io.BytesIO(b"\0" * (length - 1)))

        with pytest.raises(ValueError):
            cls.from_bytes(b"\0" * (length - 1))

        with pytest.raises(ValueError):
            cls.from_bytes(b"\0" * (length + 1))

        if struct_format is not None:
            for boundary in (lower_boundary, upper_boundary):
                bytes_io = io.BytesIO()
                cls(boundary).stream(bytes_io)
                assert bytes_io.getvalue() == struct.pack(struct_format, boundary)

            # The reference struct format must itself reject the same
            # out-of-range values, otherwise the boundaries would not match.
            for out_of_range in (lower_boundary - 1, upper_boundary + 1):
                with pytest.raises(struct.error):
                    struct.pack(struct_format, out_of_range)

        # The class-level limits must be instances of the class itself.
        assert type(cls.MINIMUM) is cls
        assert type(cls.MAXIMUM) is cls

    def test_int512(self) -> None:
        """Check int512 boundaries and its 65-byte serialized length."""
        # int512 is special. it uses 65 bytes to allow positive and negative
        # "uint512"
        int512_max = 2**512 - 1
        self._test_impl(
            int512,
            int512_max,
            -int512_max,
            length=65,
            struct_format=None,
        )

    def test_uint128(self) -> None:
        """Check uint128 boundaries; no struct format is compared."""
        self._test_impl(
            uint128,
            0xFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF,
            0,
            length=16,
            struct_format=None,
        )

    def test_uint64(self) -> None:
        """Check uint64 against the ``!Q`` struct format."""
        self._test_impl(uint64, 0xFFFFFFFFFFFFFFFF, 0, length=8, struct_format="!Q")

    def test_int64(self) -> None:
        """Check int64 against the ``!q`` struct format."""
        self._test_impl(
            int64, 0x7FFFFFFFFFFFFFFF, -0x8000000000000000, length=8, struct_format="!q"
        )

    def test_uint32(self) -> None:
        """Check uint32 against the ``!L`` struct format."""
        self._test_impl(uint32, 0xFFFFFFFF, 0, length=4, struct_format="!L")

    def test_int32(self) -> None:
        """Check int32 against the ``!l`` struct format."""
        self._test_impl(int32, 0x7FFFFFFF, -0x80000000, length=4, struct_format="!l")

    def test_uint16(self) -> None:
        """Check uint16 against the ``!H`` struct format."""
        self._test_impl(uint16, 0xFFFF, 0, length=2, struct_format="!H")

    def test_int16(self) -> None:
        """Check int16 against the ``!h`` struct format."""
        self._test_impl(int16, 0x7FFF, -0x8000, length=2, struct_format="!h")

    def test_uint8(self) -> None:
        """Check uint8 against the ``!B`` struct format."""
        self._test_impl(uint8, 0xFF, 0, length=1, struct_format="!B")

    def test_int8(self) -> None:
        """Check int8 against the ``!b`` struct format."""
        self._test_impl(int8, 0x7F, -0x80, length=1, struct_format="!b")

    def test_roundtrip(self) -> None:
        """Stream the boundary values of each sized type and parse them back."""

        def roundtrip(v: StructStream) -> None:
            s = io.BytesIO()
            v.stream(s)
            s.seek(0)
            cls = type(v)
            v2 = cls.parse(s)
            assert v2 == v

        # int512 is special. it uses 65 bytes to allow positive and negative
        # "uint512"
        int512_max = 2**512 - 1

        values: list[StructStream] = [
            int512(int512_max),
            int512(-int512_max),
            uint128(0xFFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF),
            uint128(0),
            uint64(0xFFFFFFFFFFFFFFFF),
            uint64(0),
            int64(0x7FFFFFFFFFFFFFFF),
            int64(-0x8000000000000000),
            uint32(0xFFFFFFFF),
            uint32(0),
            int32(0x7FFFFFFF),
            int32(-0x80000000),
            uint16(0xFFFF),
            uint16(0),
            int16(0x7FFF),
            int16(-0x8000),
            uint8(0xFF),
            uint8(0),
            int8(0x7F),
            int8(-0x80),
        ]
        for value in values:
            roundtrip(value)

    def test_uint32_from_decimal(self) -> None:
        """A ``Decimal`` argument is accepted by the uint32 constructor."""
        assert uint32(Decimal("137")) == 137

    def test_uint32_from_float(self) -> None:
        """A ``float`` argument is accepted by the uint32 constructor."""
        assert uint32(4.0) == 4

    def test_uint32_from_str(self) -> None:
        """A decimal ``str`` argument is accepted by the uint32 constructor."""
        assert uint32("43") == 43

    def test_uint32_from_bytes(self) -> None:
        """ASCII digit ``bytes`` are accepted by the uint32 constructor."""
        assert uint32(b"273") == 273

    def test_struct_stream_cannot_be_instantiated_directly(self) -> None:
        """Instantiating the ``StructStream`` base class raises AttributeError."""
        with pytest.raises(AttributeError, match="object has no attribute"):
            StructStream(0)

    @pytest.mark.parametrize(
        argnames="bad_name",
        argvalues=dataclass_parameters(
            instances=[
                BadName(name="uint", error="expected integer suffix but got: ''"),
                BadName(name="blue", error="expected integer suffix but got"),
                BadName(name="blue8", error="expected integer suffix but got: ''"),
                BadName(name="sint8", error="expected class name"),
                BadName(name="redint8", error="expected class name"),
                BadName(name="int7", error="must be a multiple of 8"),
                BadName(name="int9", error="must be a multiple of 8"),
                BadName(name="int31", error="must be a multiple of 8"),
                BadName(name="int0", error="bit size must greater than zero"),
                # below could not happen in a hard coded class name, but testing for good measure
                BadName(name="int-1", error="bit size must greater than zero"),
            ],
        ),
    )
    def test_parse_metadata_from_name_raises(self, bad_name: BadName) -> None:
        """A subclass with an invalid name is rejected with the expected error."""
        cls = type(bad_name.name, (StructStream,), {})
        with pytest.raises(ValueError, match=bad_name.error):
            parse_metadata_from_name(cls)

    def test_parse_metadata_from_name_correct_size(self, good: Good) -> None:
        """The parsed class reports the expected ``SIZE``."""
        assert good.cls.SIZE == good.size

    def test_parse_metadata_from_name_correct_bits(self, good: Good) -> None:
        """The parsed class reports the expected ``BITS``."""
        assert good.cls.BITS == good.bits

    def test_parse_metadata_from_name_correct_signedness(self, good: Good) -> None:
        """The parsed class reports the expected ``SIGNED``."""
        assert good.cls.SIGNED == good.signed

    def test_parse_metadata_from_name_correct_maximum(self, good: Good) -> None:
        """The parsed class reports the expected ``MAXIMUM``."""
        assert good.cls.MAXIMUM == good.maximum

    def test_parse_metadata_from_name_correct_minimum(self, good: Good) -> None:
        """The parsed class reports the expected ``MINIMUM``."""
        assert good.cls.MINIMUM == good.minimum

    @pytest.mark.parametrize(
        "mode", list(StreamAndBytesMatchMode), ids=lambda mode: mode.value
    )
    def test_stream_to_bytes_and_bytes_match_minimum(
        self, good: Good, mode: StreamAndBytesMatchMode
    ) -> None:
        """``bytes(instance)`` equals ``instance.stream_to_bytes()`` across the range.

        Despite the name, every ``StreamAndBytesMatchMode`` is covered: both
        limits plus points roughly 30% and 70% of the way through the range.
        """
        # Integer arithmetic keeps the middle values exact. Scaling by a float
        # would leave only about 53 significant bits, so for the 128-bit types
        # the low-order bytes of the value under test would always be zero.
        span = good.maximum - good.minimum
        if mode is StreamAndBytesMatchMode.minimum:
            value = good.minimum
        elif mode is StreamAndBytesMatchMode.middle_low:
            value = good.minimum + span * 3 // 10
        elif mode is StreamAndBytesMatchMode.middle_high:
            value = good.minimum + span * 7 // 10
        elif mode is StreamAndBytesMatchMode.maximum:
            value = good.maximum
        else:
            raise Exception(f"unhandled parametrization: {mode!r}")

        instance = good.cls(value)
        assert bytes(instance) == instance.stream_to_bytes()