from __future__ import absolute_import, division, print_function
from collections import Mapping, OrderedDict, Sequence
import pytest
from hypothesis import HealthCheck, assume, given, settings
from hypothesis import strategies as st
import attr
from attr import asdict, assoc, astuple, evolve, fields, has
from attr._compat import TYPE
from attr.exceptions import AttrsAttributeNotFoundError
from attr.validators import instance_of
from .strategies import nested_classes, simple_classes
# Mapping classes the tests sample from, both as ``dict_factory`` values and
# as container types wrapped around nested instances.
MAPPING_TYPES = (dict, OrderedDict)
# Sequence classes the tests sample from, both as ``tuple_factory`` values
# and as container types wrapped around nested instances.
SEQUENCE_TYPES = (list, tuple)
class TestAsDict(object):
    """
    Tests for `asdict`.
    """

    @given(st.sampled_from(MAPPING_TYPES))
    def test_shallow(self, C, dict_factory):
        """
        Passing ``False`` as the second positional argument turns a flat
        instance into a flat mapping.
        """
        flat = C(x=1, y=2)
        expected = {"x": 1, "y": 2}

        assert expected == asdict(flat, False, dict_factory=dict_factory)

    @given(st.sampled_from(MAPPING_TYPES))
    def test_recurse(self, C, dict_class):
        """
        Instances stored in attributes are converted as well.
        """
        outer = C(C(1, 2), C(3, 4))
        expected = {
            "x": {"x": 1, "y": 2},
            "y": {"x": 3, "y": 4},
        }

        assert expected == asdict(outer, dict_factory=dict_class)

    @given(nested_classes, st.sampled_from(MAPPING_TYPES))
    @settings(suppress_health_check=[HealthCheck.too_slow])
    def test_recurse_property(self, cls, dict_class):
        """
        Property test: every dumped attrs instance, at any nesting level, is
        an instance of the requested mapping class.
        """
        instance = cls()
        dumped_instance = asdict(instance, dict_factory=dict_class)

        def check_dict_class(inst, dumped):
            # Walk the instance and its dump in lockstep.
            assert isinstance(dumped, dict_class)
            for attribute in fields(inst.__class__):
                name = attribute.name
                value = getattr(inst, name)
                if has(value.__class__):
                    # The attribute holds an attrs instance: recurse.
                    check_dict_class(value, dumped[name])
                elif isinstance(value, Sequence):
                    # Only elements that are attrs instances get checked.
                    dumped_items = dumped[name]
                    for element, dumped_element in zip(value, dumped_items):
                        if has(element.__class__):
                            check_dict_class(element, dumped_element)
                elif isinstance(value, Mapping):
                    # The dumped mapping itself must use the factory too.
                    assert isinstance(dumped[name], dict_class)
                    for key, element in value.items():
                        if has(element.__class__):
                            check_dict_class(element, dumped[name][key])

        check_dict_class(instance, dumped_instance)

    @given(st.sampled_from(MAPPING_TYPES))
    def test_filter(self, C, dict_factory):
        """
        Attributes rejected by the filter are left out at every level.
        """
        def drop_y(a, v):
            return a.name != "y"

        outer = C(C(1, 2), C(3, 4))
        expected = {"x": {"x": 1}}

        assert expected == asdict(
            outer, filter=drop_y, dict_factory=dict_factory
        )

    @given(container=st.sampled_from(SEQUENCE_TYPES))
    def test_lists_tuples(self, container, C):
        """
        Instances inside lists and tuples are converted; without further
        arguments the sequence itself is dumped as a list.
        """
        members = container([C(2, 3), C(4, 5), "a"])
        expected = {
            "x": 1,
            "y": [{"x": 2, "y": 3}, {"x": 4, "y": 5}, "a"],
        }

        assert expected == asdict(C(1, members))

    @given(container=st.sampled_from(SEQUENCE_TYPES))
    def test_lists_tuples_retain_type(self, container, C):
        """
        With ``retain_collection_types=True`` the dumped sequence keeps the
        type of the original one.
        """
        members = container([C(2, 3), C(4, 5), "a"])
        expected = {
            "x": 1,
            "y": container([{"x": 2, "y": 3}, {"x": 4, "y": 5}, "a"]),
        }

        assert expected == asdict(
            C(1, members), retain_collection_types=True
        )

    @given(st.sampled_from(MAPPING_TYPES))
    def test_dicts(self, C, dict_factory):
        """
        Instances stored as dict values are converted, and the result is an
        instance of the requested factory.
        """
        dumped = asdict(C(1, {"a": C(4, 5)}), dict_factory=dict_factory)
        expected = {
            "x": 1,
            "y": {"a": {"x": 4, "y": 5}},
        }

        assert expected == dumped
        assert isinstance(dumped, dict_factory)

    @given(simple_classes(private_attrs=False), st.sampled_from(MAPPING_TYPES))
    def test_roundtrip(self, cls, dict_class):
        """
        Feeding the dumped mapping back into the class as keyword arguments
        yields an equal instance.
        """
        before = cls()
        dumped = asdict(before, dict_factory=dict_class)

        assert isinstance(dumped, dict_class)

        after = cls(**dumped)

        assert before == after

    @given(simple_classes())
    def test_asdict_preserve_order(self, cls):
        """
        With ``OrderedDict`` as factory, keys follow the field order.
        """
        dumped = asdict(cls(), dict_factory=OrderedDict)
        field_order = [a.name for a in fields(cls)]

        assert field_order == list(dumped.keys())
class TestAsTuple(object):
    """
    Tests for `astuple`.
    """

    @given(st.sampled_from(SEQUENCE_TYPES))
    def test_shallow(self, C, tuple_factory):
        """
        Passing ``False`` as the second positional argument turns a flat
        instance into a flat sequence.
        """
        assert (tuple_factory([1, 2]) ==
                astuple(C(x=1, y=2), False, tuple_factory=tuple_factory))

    @given(st.sampled_from(SEQUENCE_TYPES))
    def test_recurse(self, C, tuple_factory):
        """
        Instances stored in attributes are converted as well.
        """
        expected = tuple_factory([
            tuple_factory([1, 2]),
            tuple_factory([3, 4]),
        ])

        assert expected == astuple(
            C(C(1, 2), C(3, 4)), tuple_factory=tuple_factory
        )

    @given(nested_classes, st.sampled_from(SEQUENCE_TYPES))
    @settings(suppress_health_check=[HealthCheck.too_slow])
    def test_recurse_property(self, cls, tuple_class):
        """
        Property test: every attrs instance held directly in an attribute is
        dumped as an instance of the requested sequence class.
        """
        obj = cls()
        obj_tuple = astuple(obj, tuple_factory=tuple_class)

        def assert_proper_tuple_class(obj, obj_tuple):
            assert isinstance(obj_tuple, tuple_class)
            for index, field in enumerate(fields(obj.__class__)):
                field_val = getattr(obj, field.name)
                if has(field_val.__class__):
                    # The attribute holds an attrs instance: recurse.
                    assert_proper_tuple_class(field_val, obj_tuple[index])

        assert_proper_tuple_class(obj, obj_tuple)

    @given(nested_classes, st.sampled_from(SEQUENCE_TYPES))
    @settings(suppress_health_check=[HealthCheck.too_slow])
    def test_recurse_retain(self, cls, tuple_class):
        """
        Property test: with ``retain_collection_types=True`` lists, tuples
        and dicts keep their exact type in the dump.
        """
        obj = cls()
        obj_tuple = astuple(obj, tuple_factory=tuple_class,
                            retain_collection_types=True)

        def assert_proper_col_class(obj, obj_tuple):
            # Walk the attributes; wherever the original holds a list, tuple
            # or dict, the dumped value must be of the very same class.
            for index, field in enumerate(fields(obj.__class__)):
                field_val = getattr(obj, field.name)
                if has(field_val.__class__):
                    # The attribute holds an attrs instance: recurse.
                    assert_proper_col_class(field_val, obj_tuple[index])
                elif isinstance(field_val, (list, tuple)):
                    # Exact type identity is intended, hence no isinstance.
                    expected_type = type(obj_tuple[index])
                    assert type(field_val) is expected_type  # noqa: E721
                    for obj_e, obj_tuple_e in zip(field_val,
                                                  obj_tuple[index]):
                        if has(obj_e.__class__):
                            assert_proper_col_class(obj_e, obj_tuple_e)
                elif isinstance(field_val, dict):
                    orig = field_val
                    tupled = obj_tuple[index]
                    assert type(orig) is type(tupled)  # noqa: E721
                    for obj_e, obj_tuple_e in zip(orig.items(),
                                                  tupled.items()):
                        # Index 0 is the dict key, index 1 the dict value.
                        if has(obj_e[0].__class__):
                            assert_proper_col_class(obj_e[0], obj_tuple_e[0])
                        if has(obj_e[1].__class__):
                            assert_proper_col_class(obj_e[1], obj_tuple_e[1])

        assert_proper_col_class(obj, obj_tuple)

    @given(st.sampled_from(SEQUENCE_TYPES))
    def test_filter(self, C, tuple_factory):
        """
        Attributes rejected by the filter are left out at every level.
        """
        assert tuple_factory([tuple_factory([1, ]), ]) == astuple(C(
            C(1, 2),
            C(3, 4),
        ), filter=lambda a, v: a.name != "y", tuple_factory=tuple_factory)

    @given(container=st.sampled_from(SEQUENCE_TYPES))
    def test_lists_tuples(self, container, C):
        """
        Instances inside lists and tuples are converted; without further
        arguments the sequence itself is dumped as a list.
        """
        assert ((1, [(2, 3), (4, 5), "a"])
                == astuple(C(1, container([C(2, 3), C(4, 5), "a"])))
                )

    @given(st.sampled_from(SEQUENCE_TYPES))
    def test_dicts(self, C, tuple_factory):
        """
        Instances stored as dict values are converted, and the result is an
        instance of the requested factory.
        """
        res = astuple(C(1, {"a": C(4, 5)}), tuple_factory=tuple_factory)

        assert tuple_factory([1, {"a": tuple_factory([4, 5])}]) == res
        assert isinstance(res, tuple_factory)

    @given(container=st.sampled_from(SEQUENCE_TYPES))
    def test_lists_tuples_retain_type(self, container, C):
        """
        With ``retain_collection_types=True`` the dumped sequence keeps the
        type of the original one.
        """
        assert (
            (1, container([(2, 3), (4, 5), "a"]))
            == astuple(C(1, container([C(2, 3), C(4, 5), "a"])),
                       retain_collection_types=True))

    @given(container=st.sampled_from(MAPPING_TYPES))
    def test_dicts_retain_type(self, container, C):
        """
        With ``retain_collection_types=True`` the dumped mapping keeps the
        type of the original one.
        """
        assert (
            (1, container({"a": (4, 5)}))
            == astuple(C(1, container({"a": C(4, 5)})),
                       retain_collection_types=True))

    @given(simple_classes(), st.sampled_from(SEQUENCE_TYPES))
    def test_roundtrip(self, cls, tuple_class):
        """
        Feeding the dumped sequence back into the class as positional
        arguments yields an equal instance.
        """
        instance = cls()
        tuple_instance = astuple(instance, tuple_factory=tuple_class)

        assert isinstance(tuple_instance, tuple_class)

        roundtrip_instance = cls(*tuple_instance)

        assert instance == roundtrip_instance
class TestHas(object):
    """
    Tests for `has`.
    """

    def test_positive(self, C):
        """
        The ``C`` fixture class is recognized.
        """
        recognized = has(C)

        assert recognized

    def test_positive_empty(self):
        """
        A decorated class without any attributes is recognized too.
        """
        @attr.s
        class D(object):
            pass

        recognized = has(D)

        assert recognized

    def test_negative(self):
        """
        Plain ``object`` is not recognized.
        """
        recognized = has(object)

        assert not recognized
class TestAssoc(object):
    """
    Tests for `assoc`.

    Every call is wrapped in a deprecation check because these tests expect
    `assoc` to emit a ``DeprecationWarning``.
    """

    @given(slots=st.booleans(), frozen=st.booleans())
    def test_empty(self, slots, frozen):
        """
        A class without attributes yields an equal but distinct copy.
        """
        @attr.s(slots=slots, frozen=frozen)
        class C(object):
            pass

        i1 = C()
        with pytest.deprecated_call():
            i2 = assoc(i1)

        assert i1 is not i2
        assert i1 == i2

    @given(simple_classes())
    def test_no_changes(self, C):
        """
        Without changes, the result is an equal but distinct copy.
        """
        i1 = C()
        with pytest.deprecated_call():
            i2 = assoc(i1)

        assert i1 is not i2
        assert i1 == i2

    @given(simple_classes(), st.data())
    def test_change(self, C, data):
        """
        Every requested change shows up on the returned instance.
        """
        # Skip generated classes without attributes: there is nothing to
        # sample a field name from.
        assume(fields(C))
        field_names = [a.name for a in fields(C)]
        original = C()
        chosen_names = data.draw(st.sets(st.sampled_from(field_names)))
        change_dict = {name: data.draw(st.integers())
                       for name in chosen_names}

        with pytest.deprecated_call():
            changed = assoc(original, **change_dict)

        for k, v in change_dict.items():
            assert getattr(changed, k) == v

    @given(simple_classes())
    def test_unknown(self, C):
        """
        An unknown attribute name raises `AttrsAttributeNotFoundError` with
        a message that names the attribute and the class.
        """
        with pytest.raises(AttrsAttributeNotFoundError) as e:
            with pytest.deprecated_call():
                assoc(C(), aaaa=2)

        assert (
            "aaaa is not an attrs attribute on {cls!r}.".format(cls=C),
        ) == e.value.args

    def test_frozen(self):
        """
        Frozen classes are supported.
        """
        @attr.s(frozen=True)
        class C(object):
            x = attr.ib()
            y = attr.ib()

        with pytest.deprecated_call():
            assert C(3, 2) == assoc(C(1, 2), x=3)

    def test_warning(self):
        """
        The ``DeprecationWarning`` is reported against this file.
        """
        @attr.s
        class C(object):
            x = attr.ib()

        with pytest.warns(DeprecationWarning) as wi:
            assert C(2) == assoc(C(1), x=2)

        assert __file__ == wi.list[0].filename
class TestEvolve(object):
    """
    Tests for `evolve`.
    """

    @given(slots=st.booleans(), frozen=st.booleans())
    def test_empty(self, slots, frozen):
        """
        A class without attributes yields an equal but distinct copy.
        """
        @attr.s(slots=slots, frozen=frozen)
        class C(object):
            pass

        i1 = C()
        i2 = evolve(i1)

        assert i1 is not i2
        assert i1 == i2

    @given(simple_classes())
    def test_no_changes(self, C):
        """
        Without changes, the result is an equal but distinct copy.
        """
        i1 = C()
        i2 = evolve(i1)

        assert i1 is not i2
        assert i1 == i2

    @given(simple_classes(), st.data())
    def test_change(self, C, data):
        """
        Every requested change shows up on the returned instance; private
        attributes are addressed by their ``__init__`` argument name.
        """
        # Skip generated classes without attributes: there is nothing to
        # sample a field name from.
        assume(fields(C))
        field_names = [a.name for a in fields(C)]
        original = C()
        chosen_names = data.draw(st.sets(st.sampled_from(field_names)))
        # `evolve` is called with `__init__` argument names, for which attrs
        # strips the leading underscores of private attributes.
        init_names = {name: name.lstrip("_") for name in chosen_names}
        change_dict = {init_names[name]: data.draw(st.integers())
                       for name in chosen_names}

        changed = evolve(original, **change_dict)

        for name in chosen_names:
            assert getattr(changed, name) == change_dict[init_names[name]]

    @given(simple_classes())
    def test_unknown(self, C):
        """
        An unknown keyword raises the ``TypeError`` coming from
        ``__init__``.
        """
        with pytest.raises(TypeError) as e:
            evolve(C(), aaaa=2)

        expected = "__init__() got an unexpected keyword argument 'aaaa'"

        assert (expected,) == e.value.args

    def test_validator_failure(self):
        """
        A change that violates a validator raises ``TypeError``.
        """
        @attr.s
        class C(object):
            a = attr.ib(validator=instance_of(int))

        with pytest.raises(TypeError) as e:
            evolve(C(a=1), a="some string")

        m = e.value.args[0]

        assert m.startswith("'a' must be <{type} 'int'>".format(type=TYPE))

    def test_private(self):
        """
        A private attribute is changed through its underscore-less name;
        passing the underscored name raises ``TypeError``.
        """
        @attr.s
        class C(object):
            _a = attr.ib()

        assert evolve(C(1), a=2)._a == 2

        with pytest.raises(TypeError):
            evolve(C(1), _a=2)

        with pytest.raises(TypeError):
            evolve(C(1), a=3, _a=2)

    def test_non_init_attrs(self):
        """
        Classes with ``init=False`` attributes can still be evolved.
        """
        @attr.s
        class C(object):
            a = attr.ib()
            b = attr.ib(init=False, default=0)

        assert evolve(C(1), a=2).a == 2