serde-generate 0.33.0

Library to generate (de)serialization code in multiple languages
Documentation
# Copyright (c) Facebook, Inc. and its affiliates
# SPDX-License-Identifier: MIT OR Apache-2.0

import unittest
import serde_types as st
import bcs
import typing
import sys
from dataclasses import dataclass
from collections import OrderedDict


def encode_u32_as_uleb128(value: int) -> bytes:
    serializer = bcs.BcsSerializer()
    serializer.serialize_u32_as_uleb128(value)
    return serializer.get_buffer()


def decode_uleb128_as_u32(content: bytes) -> int:
    deserializer = bcs.BcsDeserializer(content)
    value = deserializer.deserialize_uleb128_as_u32()
    assert deserializer.get_remaining_buffer() == b""
    return value


def encode_length(value: int) -> bytes:
    serializer = bcs.BcsSerializer()
    serializer.serialize_len(value)
    return serializer.get_buffer()


def decode_length(content: bytes) -> int:
    deserializer = bcs.BcsDeserializer(content)
    value = deserializer.deserialize_len()
    assert deserializer.get_remaining_buffer() == b""
    return value


class BcsTestCase(unittest.TestCase):
    def test_bcs_bool(self):
        self.assertEqual(bcs.serialize(False, bool), b"\x00")
        self.assertEqual(bcs.serialize(True, bool), b"\x01")
        self.assertEqual(bcs.deserialize(b"\x00", bool), (False, b""))
        self.assertEqual(bcs.deserialize(b"\x01", bool), (True, b""))
        with self.assertRaises(st.DeserializationError):
            bcs.deserialize(b"\x02", bool)
        with self.assertRaises(st.DeserializationError):
            bcs.deserialize(b"", bool)

    def test_bcs_u8(self):
        self.assertEqual(bcs.serialize(0x1, st.uint8), b"\x01")
        self.assertEqual(bcs.deserialize(b"\xff", st.uint8), (255, b""))

    def test_bcs_u16(self):
        self.assertEqual(bcs.serialize(0x0102, st.uint16), b"\x02\x01")
        self.assertEqual(bcs.deserialize(b"\xff\xff", st.uint16), (65535, b""))

    def test_bcs_u32(self):
        self.assertEqual(bcs.serialize(0x01020304, st.uint32), b"\x04\x03\x02\x01")
        self.assertEqual(
            bcs.deserialize(b"\xff\xff\xff\xff", st.uint32), (4294967295, b"")
        )

    def test_bcs_u64(self):
        self.assertEqual(
            bcs.serialize(0x0102030405060708, st.uint64),
            b"\x08\x07\x06\x05\x04\x03\x02\x01",
        )
        self.assertEqual(bcs.deserialize(b"\xff" * 8, st.uint64), ((1 << 64) - 1, b""))

    def test_bcs_u128(self):
        self.assertEqual(
            bcs.serialize(st.uint128(0x0102030405060708090A0B0C0D0E0F10), st.uint128),
            b"\x10\x0f\x0e\r\x0c\x0b\n\t\x08\x07\x06\x05\x04\x03\x02\x01",
        )
        self.assertEqual(
            bcs.deserialize(b"\xff" * 16, st.uint128), (st.uint128((1 << 128) - 1), b"")
        )

    def test_bcs_i8(self):
        self.assertEqual(bcs.serialize(4, st.int8), b"\x04")
        self.assertEqual(bcs.serialize(-2, st.int8), b"\xfe")
        self.assertEqual(bcs.deserialize(b"\xff", st.int8), (-1, b""))

    def test_bcs_i16(self):
        self.assertEqual(bcs.serialize(0x0102, st.int16), b"\x02\x01")
        self.assertEqual(bcs.deserialize(b"\xff\xff", st.int16), (-1, b""))

    def test_bcs_i32(self):
        self.assertEqual(bcs.serialize(0x01020304, st.int32), b"\x04\x03\x02\x01")
        self.assertEqual(bcs.deserialize(b"\xff\xff\xff\xff", st.int32), (-1, b""))

    def test_bcs_i64(self):
        self.assertEqual(
            bcs.serialize(0x0102030405060708, st.int64),
            b"\x08\x07\x06\x05\x04\x03\x02\x01",
        )
        self.assertEqual(bcs.deserialize(b"\xff" * 8, st.int64), (-1, b""))

    def test_bcs_i128(self):
        self.assertEqual(
            bcs.serialize(st.int128(0x0102030405060708090A0B0C0D0E0F10), st.int128),
            b"\x10\x0f\x0e\r\x0c\x0b\n\t\x08\x07\x06\x05\x04\x03\x02\x01",
        )
        self.assertEqual(bcs.deserialize(b"\xff" * 16, st.int128), (st.int128(-1), b""))

    def test_encode_u32_as_uleb128(self):
        self.assertEqual(encode_u32_as_uleb128(0), b"\x00")
        self.assertEqual(encode_u32_as_uleb128(3), b"\x03")
        self.assertEqual(encode_u32_as_uleb128(0x7F), b"\x7f")
        self.assertEqual(encode_u32_as_uleb128(0x3F01), b"\x81\x7e")
        self.assertEqual(encode_u32_as_uleb128(0x8001), b"\x81\x80\x02")
        self.assertEqual(encode_u32_as_uleb128(bcs.MAX_U32), b"\xff\xff\xff\xff\x0f")

        self.assertEqual(decode_uleb128_as_u32(b"\x00"), 0)
        self.assertEqual(decode_uleb128_as_u32(b"\x03"), 3)
        self.assertEqual(decode_uleb128_as_u32(b"\x7f"), 0x7F)
        self.assertEqual(decode_uleb128_as_u32(b"\x81\x7e"), 0x3F01)
        self.assertEqual(decode_uleb128_as_u32(b"\x80\x80\x01"), 0x4000)
        self.assertEqual(decode_uleb128_as_u32(b"\x80\x80\x01"), 0x4000)
        self.assertEqual(decode_uleb128_as_u32(b"\x81\x80\x02"), 0x8001)
        self.assertEqual(decode_uleb128_as_u32(b"\xff\xff\xff\xff\x0f"), bcs.MAX_U32)
        with self.assertRaises(st.DeserializationError):
            decode_uleb128_as_u32(b"\x80\x00")
        with self.assertRaises(st.DeserializationError):
            decode_uleb128_as_u32(b"\xff\xff\xff\xff\x10")

    def test_encode_length(self):
        self.assertEqual(encode_length(bcs.MAX_LENGTH), b"\xff\xff\xff\xff\x07")
        with self.assertRaises(st.SerializationError):
            encode_length(bcs.MAX_LENGTH + 1)
        with self.assertRaises(st.DeserializationError):
            decode_length(b"\xff\xff\xff\xff\x08")

    def test_serialize_bytes(self):
        self.assertEqual(bcs.serialize(b"", bytes), b"\x00")
        self.assertEqual(bcs.serialize(b"\x00\x00", bytes), b"\x02\x00\x00")
        self.assertEqual(
            bcs.serialize(b"\x00" * 128, bytes), b"\x80\x01" + b"\x00" * 128
        )

        self.assertEqual(bcs.deserialize(b"\x00", bytes), (b"", b""))

    def test_serialize_tuple(self):
        T = typing.Tuple[st.uint8, st.uint16]
        self.assertEqual(bcs.serialize((0, 1), T), b"\x00\x01\x00")
        self.assertEqual(bcs.deserialize(b"\x02\x01\x00", T), ((2, 1), b""))

    def test_serialize_option(self):
        T = typing.Optional[st.uint16]
        self.assertEqual(bcs.serialize(None, T), b"\x00")
        self.assertEqual(bcs.serialize(6, T), b"\x01\x06\x00")
        self.assertEqual(bcs.deserialize(b"\x00", T), (None, b""))
        self.assertEqual(bcs.deserialize(b"\x01\x02\x00", T), (2, b""))
        with self.assertRaisesRegex(st.DeserializationError, "Wrong tag.*"):
            # Must enforce canonical encoding.
            bcs.deserialize(b"\x02\x06\x00", T)

    def test_serialize_sequence(self):
        Seq = typing.Sequence[st.uint16]
        self.assertEqual(bcs.serialize([], Seq), b"\x00")
        self.assertEqual(bcs.serialize([0, 1], Seq), b"\x02\x00\x00\x01\x00")
        self.assertEqual(
            bcs.serialize([256] * 128, Seq), b"\x80\x01" + b"\x00\x01" * 128
        )
        self.assertEqual(bcs.deserialize(b"\x01\x03\x00", Seq), ([3], b""))

    def test_serialize_str(self):
        self.assertEqual(bcs.serialize("ABC\u0394", str), b"\x05ABC\xce\x94")
        self.assertEqual(bcs.deserialize(b"\x05ABC\xce\x94A", str), ("ABC\u0394", b"A"))
        with self.assertRaises(st.DeserializationError):
            bcs.deserialize(b"\x03AB", str)
        with self.assertRaises(st.DeserializationError):
            bcs.deserialize(b"\x03\x80ab", str)

    def test_deserialize_long_sequence(self):
        Seq = typing.Sequence[st.uint16]
        five = st.uint16(5)
        v = [five] * 1000000
        b = bcs.serialize(v, Seq)
        self.assertEqual(bcs.deserialize(b, Seq), (v, b""))

    def test_serialize_map(self):
        Map = typing.Dict[st.uint16, st.uint8]
        m = OrderedDict([(1, 5), (256, 3)])
        e = bcs.serialize(m, Map)
        self.assertEqual(e, b"\x02\x00\x01\x03\x01\x00\x05")
        self.assertEqual(
            (m, b""), bcs.deserialize(b"\x02\x00\x01\x03\x01\x00\x05", Map)
        )
        m2 = OrderedDict([(256, 3), (1, 5)])
        e2 = bcs.serialize(m2, Map)
        self.assertEqual(e2, e)
        with self.assertRaises(st.DeserializationError):
            # Must enforce canonical encoding.
            bcs.deserialize(b"\x02\x01\x00\x05\x00\x01\x03", Map)

    def test_serialize_set(self):
        Set = typing.Dict[st.uint16, st.unit]
        m = {256: None, 1: None}
        e = bcs.serialize(m, Set)
        self.assertEqual(e, b"\x02\x00\x01\x01\x00")
        self.assertEqual((m, b""), bcs.deserialize(b"\x02\x00\x01\x01\x00", Set))
        with self.assertRaises(st.DeserializationError):
            # Must enforce canonical encoding.
            bcs.deserialize(b"\x02\x01\x00\x00\x01", Set)

    @dataclass
    class Foo:
        x: st.uint8
        y: st.uint16

    def test_struct(self):
        self.assertEqual(
            bcs.serialize(BcsTestCase.Foo(x=0, y=1), BcsTestCase.Foo), b"\x00\x01\x00"
        )
        self.assertEqual(
            bcs.deserialize(b"\x02\x01\x00", BcsTestCase.Foo),
            (BcsTestCase.Foo(x=2, y=1), b""),
        )

    class Bar:
        VARIANTS = []  # type: typing.Sequence[typing.Type['Bar']]

    @dataclass
    class Bar1(Bar):
        INDEX = 1
        x: st.uint8
        y: st.uint16

    Bar.VARIANTS = [None, Bar1, None]

    def test_enum(self):
        self.assertEqual(
            bcs.serialize(BcsTestCase.Bar1(x=0, y=1), BcsTestCase.Bar),
            b"\x01\x00\x01\x00",
        )
        self.assertEqual(
            bcs.deserialize(b"\x01\x02\x01\x00", BcsTestCase.Bar),
            (BcsTestCase.Bar1(x=2, y=1), b""),
        )

    @dataclass
    class List:
        next: typing.Optional[typing.Tuple[st.uint64, "BcsTestCase.List"]]

        @staticmethod
        def empty() -> "BcsTestCase.List":
            return BcsTestCase.List(next=None)

        @staticmethod
        def cons(value: st.uint64, tail: "BcsTestCase.List") -> "BcsTestCase.List":
            return BcsTestCase.List(next=(value, tail))

        @staticmethod
        def integers(size: int) -> "BcsTestCase.List":
            result = BcsTestCase.List.empty()
            for i in range(size):
                result = BcsTestCase.List.cons(st.uint64(i), result)
            return result

    def test_max_container_depth(self):
        # Required to avoid RecursionError's in python.
        sys.setrecursionlimit(bcs.MAX_CONTAINER_DEPTH * 5)

        l1 = BcsTestCase.List.integers(4)
        b1 = bcs.serialize(l1, BcsTestCase.List)
        self.assertEqual(
            b1,
            bytes(
                [
                    # fmt: off
                    1, 3, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0
                ]
            ),
        )
        self.assertEqual(bcs.deserialize(b1, BcsTestCase.List), (l1, b""))

        l2 = BcsTestCase.List.integers(bcs.MAX_CONTAINER_DEPTH - 1)
        b2 = bcs.serialize(l2, BcsTestCase.List)
        self.assertEqual(bcs.deserialize(b2, BcsTestCase.List), (l2, b""))

        l3 = BcsTestCase.List.integers(bcs.MAX_CONTAINER_DEPTH)
        with self.assertRaises(st.SerializationError):
            bcs.serialize(l3, BcsTestCase.List)

        b3 = bytes([1, 243, 1, 0, 0, 0, 0, 0, 0]) + b2
        with self.assertRaisesRegex(
            st.DeserializationError, "Exceeded maximum container depth.*"
        ):
            self.assertEqual(bcs.deserialize(b3, BcsTestCase.List))

        # Pairs don't count in "container depth".
        P = typing.Tuple[BcsTestCase.List, BcsTestCase.List]
        self.assertEqual(bcs.deserialize(b2 + b2, P), ((l2, l2), b""))
        with self.assertRaisesRegex(
            st.DeserializationError, "Exceeded maximum container depth.*"
        ):
            bcs.deserialize(b2 + b3, P)