import re
import struct
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from functools import wraps
from datetime import datetime, date, time
from io import BytesIO
from cbor2.compat import iteritems, timezone, long, unicode, as_unicode, bytes_from_list
from cbor2.types import CBORTag, undefined, CBORSimpleValue
class CBOREncodeError(Exception):
    """Raised when a value cannot be serialized by the encoders in this module."""
def shareable_encoder(func):
    """
    Wrap a container encoder so that repeated or cyclic references are handled.

    The wrapper tracks every container currently known to the encoder in
    ``encoder._shared_containers``, keyed by ``id(value)``.

    * With ``encoder.value_sharing`` enabled, the first occurrence of a container is prefixed
      with the bytes ``0xd8 0x1c`` and then encoded with ``func``; any later occurrence is
      replaced with ``0xd8 0x1d`` followed by the index assigned to the first occurrence.
    * With value sharing disabled, a container that is encountered again while it is still
      being encoded raises :exc:`CBOREncodeError`.

    :param func: the encoder callable to wrap, called as ``func(encoder, value, ...)``
    :return: the wrapping callable (it returns ``None``, like the encoders themselves)
    :raises CBOREncodeError: from the wrapper, if a cycle is found and value sharing is off

    """
    @wraps(func)
    def wrapper(encoder, value, *args, **kwargs):
        value_id = id(value)
        container, container_index = encoder._shared_containers.get(value_id, (None, None))
        if encoder.value_sharing:
            if container is value:
                # Already written once: emit a reference to its index instead
                encoder.write(encode_length(0xd8, 0x1d))
                encode_int(encoder, container_index)
            else:
                # First occurrence: register it under the next free index and mark it
                encoder._shared_containers[value_id] = (value, len(encoder._shared_containers))
                encoder.write(encode_length(0xd8, 0x1c))
                func(encoder, value, *args, **kwargs)
        else:
            if container is value:
                raise CBOREncodeError('cyclic data structure detected but value sharing is '
                                      'disabled')

            encoder._shared_containers[value_id] = (value, None)
            try:
                func(encoder, value, *args, **kwargs)
            finally:
                # Always drop the in-progress marker, even when encoding fails; otherwise a
                # later attempt to encode the same container with this encoder would be
                # misreported as a cyclic structure.
                del encoder._shared_containers[value_id]

    return wrapper
def encode_length(major_tag, length):
    """
    Build the header bytes for the given major type bits and length/argument.

    Lengths below 24 are folded into the initial byte. Larger lengths are written after the
    initial byte as a big-endian unsigned integer of 1, 2, 4 or 8 bytes, with the initial
    byte carrying the marker 24, 25, 26 or 27 respectively.

    :param major_tag: the high bits of the initial byte (e.g. ``0x40`` or ``0x80``)
    :param length: the non-negative length or argument to encode
    :return: the packed header as a byte string

    """
    if length < 24:
        return struct.pack('>B', major_tag | length)

    # (exclusive upper bound, width marker, struct format) in ascending order of size
    sized_forms = (
        (256, 24, '>BB'),
        (65536, 25, '>BH'),
        (4294967296, 26, '>BL'),
    )
    for upper_bound, marker, fmt in sized_forms:
        if length < upper_bound:
            return struct.pack(fmt, major_tag | marker, length)

    return struct.pack('>BQ', major_tag | 27, length)
def encode_int(encoder, value):
    """
    Write an integer to the encoder's output.

    Values in the range ``-2**64 .. 2**64 - 1`` are written directly with
    :func:`encode_length` (major type bits ``0`` for non-negative values, ``0x20`` for
    negative ones, which are stored as ``abs(value) - 1``). Anything outside that range is
    converted to a big-endian byte string and wrapped in tag 2 (non-negative) or tag 3
    (negative, stored as ``-value - 1``).

    :param encoder: the encoder instance to write to
    :param value: the integer to encode

    """
    if -18446744073709551616 <= value < 18446744073709551616:
        if value >= 0:
            encoder.write(encode_length(0, value))
        else:
            encoder.write(encode_length(0x20, abs(value) - 1))
        return

    # Too large for 8 bytes: fall back to the tagged byte string representation
    if value >= 0:
        tag_number = 0x02
    else:
        tag_number = 0x03
        value = -value - 1

    octets = []
    while value > 0:
        value, low_byte = divmod(value, 256)
        octets.append(low_byte)

    # The loop collects the least significant byte first; the payload is big-endian
    octets.reverse()
    encode_semantic(encoder, CBORTag(tag_number, bytes_from_list(octets)))
def encode_bytestring(encoder, value):
    """
    Write a byte string: a length header with major type bits ``0x40``, then the raw bytes.

    :param encoder: the encoder instance to write to
    :param value: the bytes to write

    """
    header = encode_length(0x40, len(value))
    encoder.write(header + value)
def encode_bytearray(encoder, value):
    """
    Write a bytearray by converting it to ``bytes`` and delegating to
    :func:`encode_bytestring`.

    :param encoder: the encoder instance to write to
    :param value: the bytearray to write

    """
    immutable_copy = bytes(value)
    encode_bytestring(encoder, immutable_copy)
def encode_string(encoder, value):
    """
    Write a text string: a length header with major type bits ``0x60``, then the UTF-8 data.

    The length in the header is the number of encoded bytes, not the number of characters.

    :param encoder: the encoder instance to write to
    :param value: the text string to write

    """
    utf8_data = value.encode('utf-8')
    header = encode_length(0x60, len(utf8_data))
    encoder.write(header + utf8_data)
@shareable_encoder
def encode_array(encoder, value):
    """
    Write a sequence: a length header with major type bits ``0x80``, then each element.

    Elements are dispatched back through ``encoder.encode()`` in iteration order.

    :param encoder: the encoder instance to write to
    :param value: the sized iterable to write

    """
    header = encode_length(0x80, len(value))
    encoder.write(header)
    for element in value:
        encoder.encode(element)
@shareable_encoder
def encode_map(encoder, value):
    """
    Write a mapping: a length header with major type bits ``0xa0``, then each key/value pair.

    Every key is immediately followed by its value; both are dispatched through
    ``encoder.encode()`` in the mapping's own iteration order.

    :param encoder: the encoder instance to write to
    :param value: the mapping to write

    """
    header = encode_length(0xa0, len(value))
    encoder.write(header)
    for entry in iteritems(value):
        encoder.encode(entry[0])
        encoder.encode(entry[1])
def encode_semantic(encoder, value):
    """
    Write a tagged value: the tag number with major type bits ``0xc0``, then the payload.

    :param encoder: the encoder instance to write to
    :param value: an object exposing ``tag`` (the tag number) and ``value`` (the payload,
        dispatched through ``encoder.encode()``)

    """
    tag_header = encode_length(0xc0, value.tag)
    encoder.write(tag_header)
    encoder.encode(value.value)
def encode_datetime(encoder, value):
    """
    Write a datetime either as a tag 1 numeric timestamp or as a tag 0 ISO 8601 string.

    Naive datetimes get ``encoder.timezone`` attached when one is configured. Which
    representation is written depends on ``encoder.datetime_as_timestamp``.

    In timestamp mode a datetime without microseconds is written as an integer number of
    seconds since the epoch; one with microseconds is written as a float so the fractional
    part is preserved.

    :param encoder: the encoder instance to write to
    :param value: the datetime to encode
    :raises CBOREncodeError: if ``value`` is naive and the encoder has no default timezone

    """
    if not value.tzinfo:
        if encoder.timezone:
            value = value.replace(tzinfo=encoder.timezone)
        else:
            raise CBOREncodeError(
                'naive datetime encountered and no default timezone has been set')

    if encoder.datetime_as_timestamp:
        from calendar import timegm
        timestamp = timegm(value.utctimetuple())
        if value.microsecond:
            # utctimetuple() has whole-second resolution, so add the fraction explicitly.
            # True division is required: floor division by 1000000 would always yield 0.
            timestamp += value.microsecond / 1000000.0

        encode_semantic(encoder, CBORTag(1, timestamp))
    else:
        # Use the shorter "Z" suffix for UTC offsets
        datestring = as_unicode(value.isoformat().replace('+00:00', 'Z'))
        encode_semantic(encoder, CBORTag(0, datestring))
def encode_date(encoder, value):
    """
    Write a date as the datetime for midnight UTC on that day.

    The date is combined with a default ``time()`` value, marked as UTC and then passed to
    :func:`encode_datetime`.

    :param encoder: the encoder instance to write to
    :param value: the date to encode

    """
    midnight = datetime.combine(value, time())
    encode_datetime(encoder, midnight.replace(tzinfo=timezone.utc))
def encode_decimal(encoder, value):
    """
    Write a decimal number.

    NaN and the infinities are written as the fixed three-byte sequences also used by
    :func:`encode_float`. Finite values are written as tag 4 wrapping the two-element list
    ``[exponent, mantissa]``, where the mantissa is the signed integer formed from the
    decimal's digits.

    :param encoder: the encoder instance to write to
    :param value: the decimal value to encode (must offer ``is_nan()``, ``is_infinite()``
        and ``as_tuple()``)

    """
    if value.is_nan():
        encoder.write(b'\xf9\x7e\x00')
    elif value.is_infinite():
        encoder.write(b'\xf9\x7c\x00' if value > 0 else b'\xf9\xfc\x00')
    else:
        dt = value.as_tuple()
        mantissa = 0
        for digit in dt.digits:
            mantissa = mantissa * 10 + digit

        # as_tuple() keeps the sign separate from the digits; without this the sign would
        # be lost and negative decimals would be written as positive ones
        if dt.sign:
            mantissa = -mantissa

        # The helper list is a throwaway object, so it must not be marked as shareable
        with encoder.disable_value_sharing():
            encode_semantic(encoder, CBORTag(4, [dt.exponent, mantissa]))
def encode_rational(encoder, value):
    """
    Write a rational number as tag 30 wrapping the list ``[numerator, denominator]``.

    Value sharing is switched off while the helper list is being written.

    :param encoder: the encoder instance to write to
    :param value: an object exposing ``numerator`` and ``denominator``

    """
    with encoder.disable_value_sharing():
        parts = [value.numerator, value.denominator]
        encode_semantic(encoder, CBORTag(30, parts))
def encode_regexp(encoder, value):
    """
    Write a compiled regular expression as tag 35 wrapping its pattern as text.

    :param encoder: the encoder instance to write to
    :param value: an object exposing the expression source as ``pattern``

    """
    pattern_text = as_unicode(value.pattern)
    encode_semantic(encoder, CBORTag(35, pattern_text))
def encode_mime(encoder, value):
    """
    Write a message object as tag 36 wrapping the text returned by its ``as_string()``.

    :param encoder: the encoder instance to write to
    :param value: the message object to encode

    """
    message_text = as_unicode(value.as_string())
    encode_semantic(encoder, CBORTag(36, message_text))
def encode_uuid(encoder, value):
    """
    Write a UUID as tag 37 wrapping the content of its ``bytes`` attribute.

    :param encoder: the encoder instance to write to
    :param value: the UUID object to encode

    """
    raw_bytes = value.bytes
    encode_semantic(encoder, CBORTag(37, raw_bytes))
def encode_simple_value(encoder, value):
    """
    Write a simple value.

    Numbers below 24 fit in the low bits of the initial byte and are written as a single
    byte (``0xe0 | number``). Larger numbers are written as ``0xf8`` followed by the number
    in one byte.

    :param encoder: the encoder instance to write to
    :param value: an object exposing the simple value's number as ``value``

    """
    number = value.value
    # The one-byte form covers 0..23. Numbers 20..23 must use it too, because the two-byte
    # form is not valid for numbers below 32.
    if number < 24:
        encoder.write(struct.pack('>B', 0xe0 | number))
    else:
        encoder.write(struct.pack('>BB', 0xf8, number))
def encode_float(encoder, value):
    """
    Write a floating point number.

    NaN and the two infinities are written as fixed three-byte sequences; every other value
    is written as ``0xfb`` followed by the 8-byte big-endian double.

    :param encoder: the encoder instance to write to
    :param value: the float to encode

    """
    import math

    if math.isnan(value):
        payload = b'\xf9\x7e\x00'
    elif math.isinf(value):
        if value > 0:
            payload = b'\xf9\x7c\x00'
        else:
            payload = b'\xf9\xfc\x00'
    else:
        payload = struct.pack('>Bd', 0xfb, value)

    encoder.write(payload)
def encode_boolean(encoder, value):
    """
    Write ``0xf5`` for a truthy value and ``0xf4`` for a falsy one.

    :param encoder: the encoder instance to write to
    :param value: the boolean to encode

    """
    if value:
        encoder.write(b'\xf5')
    else:
        encoder.write(b'\xf4')
def encode_none(encoder, value):
    """
    Write the single byte ``0xf6``.

    :param encoder: the encoder instance to write to
    :param value: ignored; present only to match the common encoder signature

    """
    null_marker = b'\xf6'
    encoder.write(null_marker)
def encode_undefined(encoder, value):
    """
    Write the single byte ``0xf7``.

    :param encoder: the encoder instance to write to
    :param value: ignored; present only to match the common encoder signature

    """
    undefined_marker = b'\xf7'
    encoder.write(undefined_marker)
# Default mapping of Python types to encoder callables; every CBOREncoder starts out with
# its own copy of this table.
#
# A key is either a type or a (module name, type name) tuple. Tuple keys are resolved by
# CBOREncoder._find_encoder(), which looks the module up in sys.modules and skips the entry
# if the module is not loaded, so the modules named here are not imported by this file.
#
# Insertion order matters for the issubclass() fallback in CBOREncoder._find_encoder(): the
# first matching entry wins, so a subclass must be listed before its base class (datetime
# before date).
default_encoders = OrderedDict([
    (bytes, encode_bytestring),
    (bytearray, encode_bytearray),
    (unicode, encode_string),
    (int, encode_int),
    (long, encode_int),
    (float, encode_float),
    (('decimal', 'Decimal'), encode_decimal),
    (bool, encode_boolean),
    (type(None), encode_none),
    (tuple, encode_array),
    (list, encode_array),
    (dict, encode_map),
    (defaultdict, encode_map),
    (OrderedDict, encode_map),
    (type(undefined), encode_undefined),
    (datetime, encode_datetime),
    (date, encode_date),
    # The compiled pattern type is obtained from an instance
    (type(re.compile('')), encode_regexp),
    (('fractions', 'Fraction'), encode_rational),
    (('email.message', 'Message'), encode_mime),
    (('uuid', 'UUID'), encode_uuid),
    (CBORSimpleValue, encode_simple_value),
    (CBORTag, encode_semantic)
])
class CBOREncoder(object):
    """
    Serializes objects to a byte stream by dispatching on their type.

    :param fp: the object to write the output to (only its ``write()`` method is used)
    :param datetime_as_timestamp: if true, datetimes are written as numeric timestamps
        instead of ISO 8601 strings
    :param timezone: the tzinfo to attach to naive datetimes before encoding them
    :param value_sharing: if true, containers handled by a shareable encoder are written
        once and referenced by index afterwards; if false, cyclic structures raise
        :exc:`CBOREncodeError`
    :param default: a callable invoked as ``default(encoder, obj)`` for objects whose type
        has no registered encoder

    """

    # NOTE: 'json_compatible' is declared as a slot but never assigned by this class
    __slots__ = ('fp', 'datetime_as_timestamp', 'timezone', 'default', 'value_sharing',
                 'json_compatible', '_shared_containers', '_encoders')

    def __init__(self, fp, datetime_as_timestamp=False, timezone=None, value_sharing=False,
                 default=None):
        self.fp = fp
        self.datetime_as_timestamp = datetime_as_timestamp
        self.timezone = timezone
        self.value_sharing = value_sharing
        self.default = default
        # Maps id(container) -> (container, index or None); see shareable_encoder()
        self._shared_containers = {}
        # Private copy, because _find_encoder() caches its findings in it
        self._encoders = default_encoders.copy()

    def _find_encoder(self, obj_type):
        """
        Find an encoder for ``obj_type`` by checking it against every registered type.

        Entries keyed by a ``(module name, type name)`` tuple are resolved against
        ``sys.modules`` first. If the type is available, the tuple key is replaced with the
        real type; otherwise the entry is skipped. The first entry that ``obj_type`` is a
        subclass of wins and is cached under ``obj_type`` for later direct lookups.

        :param obj_type: the class of the object being encoded
        :return: the matching encoder callable, or ``None`` if there is none

        """
        from sys import modules

        # Iterate over a snapshot since the mapping is modified inside the loop
        for type_, enc in list(iteritems(self._encoders)):
            if type(type_) is tuple:
                modname, typename = type_
                imported_type = getattr(modules.get(modname), typename, None)
                if imported_type is None:
                    # The module has not been imported, so obj_type cannot come from it
                    continue

                del self._encoders[type_]
                self._encoders[imported_type] = enc
                type_ = imported_type

            if issubclass(obj_type, type_):
                self._encoders[obj_type] = enc
                return enc

        return None

    @contextmanager
    def disable_value_sharing(self):
        """
        Turn value sharing off for the duration of the ``with`` block.

        The previous setting is restored when the block exits, whether it finishes normally
        or raises an exception.

        """
        old_value_sharing = self.value_sharing
        self.value_sharing = False
        try:
            yield
        finally:
            self.value_sharing = old_value_sharing

    def write(self, data):
        """
        Write bytes to the output.

        :param data: the bytes to pass to ``fp.write()``

        """
        self.fp.write(data)

    def encode(self, obj):
        """
        Encode the given object using the encoder registered for its type.

        The lookup order is: exact type match, subclass match via :meth:`_find_encoder`,
        then the ``default`` callable.

        :param obj: the object to encode
        :raises CBOREncodeError: if no encoder could be found for the object's type

        """
        obj_type = obj.__class__
        encoder = self._encoders.get(obj_type) or self._find_encoder(obj_type) or self.default
        if not encoder:
            raise CBOREncodeError('cannot serialize type %s' % obj_type.__name__)

        encoder(self, obj)

    def encode_to_bytes(self, obj):
        """
        Encode the given object into a temporary buffer and return the resulting bytes.

        Nothing is written to the encoder's regular output. The regular output is put back
        in place afterwards, even if encoding fails.

        :param obj: the object to encode
        :return: the encoded form of ``obj`` as a byte string

        """
        old_fp = self.fp
        self.fp = fp = BytesIO()
        try:
            self.encode(obj)
        finally:
            self.fp = old_fp

        return fp.getvalue()
def dumps(obj, **kwargs):
    """
    Serialize an object to a byte string.

    The object is written to an in-memory buffer with :func:`dump`.

    :param obj: the object to serialize
    :param kwargs: keyword arguments passed through to :func:`dump`
    :return: the serialized output as a byte string

    """
    buffer = BytesIO()
    dump(obj, buffer, **kwargs)
    return buffer.getvalue()
def dump(obj, fp, **kwargs):
    """
    Serialize an object to a file-like object.

    :param obj: the object to serialize
    :param fp: the object to write the output to
    :param kwargs: keyword arguments passed to the :class:`CBOREncoder` constructor

    """
    encoder = CBOREncoder(fp, **kwargs)
    encoder.encode(obj)