from __future__ import absolute_import, division, print_function
import codecs
import functools
import inspect
import re
import sys
import py
import _pytest
from _pytest.outcomes import TEST_OUTCOME
from six import text_type
import six
try:
import enum
except ImportError: enum = None
# Interpreter-generation flags used to pick between the Python 2 and Python 3
# code paths below.
_PY3 = sys.version_info[0] >= 3
_PY2 = not _PY3

# Signature introspection and the collection ABCs live in different places on
# Python 2 (the ``funcsigs`` backport / plain ``collections``) and Python 3.
# The ``Parameter as Parameter`` spelling is kept from the original
# (commonly used as an explicit re-export marker).
if _PY3:
    from inspect import Parameter as Parameter, signature
    from collections.abc import Mapping, Sequence
    from collections.abc import MutableMapping as MappingMixin
else:
    from funcsigs import Parameter as Parameter, signature
    from collections import Mapping, Sequence
    from collections import MutableMapping as MappingMixin

NoneType = type(None)
# Unique sentinel object, distinguishable from None.
NOTSET = object()

PY35 = sys.version_info >= (3, 5)
PY36 = sys.version_info >= (3, 6)

# Name of the exception class raised for a missing module on this interpreter.
if PY36:
    MODULE_NOT_FOUND_ERROR = "ModuleNotFoundError"
else:
    MODULE_NOT_FOUND_ERROR = "ImportError"
def _format_args(func):
    """Render the call signature of *func* as text, e.g. ``"(a, b=1)"``."""
    sig = signature(func)
    return str(sig)
# Short aliases for the ``inspect`` predicates.
isfunction, isclass = inspect.isfunction, inspect.isclass

# ``sys.exc_clear`` exists only on Python 2; elsewhere substitute a no-op.
try:
    exc_clear = sys.exc_clear
except AttributeError:

    def exc_clear():
        return None


# Type of compiled regular-expression objects.
REGEX_TYPE = type(re.compile(""))
def is_generator(func):
    """Return True if *func* is a generator function that
    iscoroutinefunction() does not also report as a coroutine function.
    """
    if not inspect.isgeneratorfunction(func):
        return False
    return not iscoroutinefunction(func)
def iscoroutinefunction(func):
    """Return a truthy value if *func* should be treated as a coroutine function.

    A truthy ``_is_coroutine`` attribute on *func* is returned as-is.
    Otherwise the result of ``inspect.iscoroutinefunction(func)`` is returned
    when that predicate exists, and False when it does not.
    """
    # NOTE(review): checking the ``_is_coroutine`` marker directly presumably
    # avoids importing asyncio just for this test -- confirm.
    marker = getattr(func, "_is_coroutine", False)
    if marker:
        return marker
    if not hasattr(inspect, "iscoroutinefunction"):
        return False
    return inspect.iscoroutinefunction(func)
def getlocation(function, curdir):
    """Describe where *function* is defined as ``"<path>:<lineno>"``.

    The path comes from ``inspect.getfile`` and is shown relative to *curdir*
    when ``py.path.local.relto`` returns a non-empty relative path; otherwise
    the full path is used. The reported number is ``co_firstlineno + 1``.
    """
    # NOTE(review): co_firstlineno is already 1-based and, for decorated
    # functions, points at the first decorator; the +1 presumably targets the
    # line after a single decorator -- confirm before changing.
    path = py.path.local(inspect.getfile(function))
    first_lineno = function.__code__.co_firstlineno
    relative = path.relto(curdir)
    return "%s:%d" % (relative or path, first_lineno + 1)
def num_mock_patch_args(function):
    """Return how many positional arguments of *function* come from ``mock.patch``.

    ``mock.patch`` decorators record their patchers on ``function.patchings``.
    When the ``mock`` and/or ``unittest.mock`` module is loaded, only patchers
    are counted whose ``attribute_name`` is empty and whose ``new`` is one of
    those modules' ``DEFAULT`` sentinels. Those are the patchers that inject
    an extra argument. If neither module is loaded, every patching is counted.

    Returns 0 when *function* has no (or empty) ``patchings``.
    """
    patchings = getattr(function, "patchings", None)
    if not patchings:
        return 0

    mock_modules = [sys.modules.get("mock"), sys.modules.get("unittest.mock")]
    sentinels = [m.DEFAULT for m in mock_modules if m is not None]
    if not sentinels:
        return len(patchings)

    # Compare by identity: ``p.new`` may be any user object (e.g. a numpy
    # array) whose ``__eq__`` does not return a plain bool, so the previous
    # ``p.new in sentinels`` could raise during collection.
    return sum(
        1
        for p in patchings
        if not p.attribute_name and any(p.new is s for s in sentinels)
    )
def getfuncargnames(function, is_method=False, cls=None):
arg_names = tuple(
p.name
for p in signature(function).parameters.values()
if (
p.kind is Parameter.POSITIONAL_OR_KEYWORD
or p.kind is Parameter.KEYWORD_ONLY
)
and p.default is Parameter.empty
)
if (
is_method
or (
cls
and not isinstance(cls.__dict__.get(function.__name__, None), staticmethod)
)
):
arg_names = arg_names[1:]
if hasattr(function, "__wrapped__"):
arg_names = arg_names[num_mock_patch_args(function):]
return arg_names
def get_default_arg_names(function):
    """Return, in signature order, the names of *function*'s positional-or-keyword
    and keyword-only parameters that declare a default value.
    """
    wanted_kinds = (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
    names = []
    for param in signature(function).parameters.values():
        if param.kind not in wanted_kinds:
            continue
        if param.default is Parameter.empty:
            continue
        names.append(param.name)
    return tuple(names)
UNICODE_TYPES = six.text_type

if _PY3:
    STRING_TYPES = bytes, str

    if PY35:

        def _bytes_to_ascii(val):
            # Non-ASCII bytes become ``\xNN`` escapes; ASCII bytes pass through.
            return val.decode("ascii", "backslashreplace")

    else:

        def _bytes_to_ascii(val):
            # Empty input is short-circuited rather than passed to escape_encode.
            if not val:
                return ""
            escaped, _ = codecs.escape_encode(val)
            return escaped.decode("ascii")

    def ascii_escaped(val):
        """Return an ASCII-only ``str`` rendering of *val*.

        ``bytes`` go through _bytes_to_ascii(); anything else is encoded with
        the ``unicode_escape`` codec and decoded back as ASCII.
        """
        if not isinstance(val, bytes):
            return val.encode("unicode_escape").decode("ascii")
        return _bytes_to_ascii(val)

else:
    STRING_TYPES = six.string_types

    def ascii_escaped(val):
        """Return an ASCII-only byte-string rendering of *val* (Python 2).

        Byte strings that are pure ASCII are returned as ``val.encode("ascii")``;
        otherwise they are escaped with ``string-escape``. Unicode objects are
        encoded with ``unicode-escape``.
        """
        if not isinstance(val, bytes):
            return val.encode("unicode-escape")
        try:
            return val.encode("ascii")
        except UnicodeDecodeError:
            return val.encode("string-escape")
def get_real_func(obj):
    """Return the innermost object behind *obj*'s ``__wrapped__`` chain.

    Up to 100 ``__wrapped__`` links are followed (as set by e.g.
    ``functools.wraps``). If the result is a ``functools.partial``, its
    ``func`` is returned instead (one level only).

    Raises:
        ValueError: if all 100 followed links still had a ``__wrapped__``.
    """
    start_obj = obj
    depth = 0
    while depth < 100:
        inner = getattr(obj, "__wrapped__", None)
        if inner is None:
            break
        obj = inner
        depth += 1
    else:
        # Guards against objects whose attribute lookup always yields a
        # __wrapped__ (endless chains).
        raise ValueError(
            "could not find real function of {start}\nstopped at {current}".format(
                start=py.io.saferepr(start_obj), current=py.io.saferepr(obj)
            )
        )
    return obj.func if isinstance(obj, functools.partial) else obj
def getfslineno(obj):
    """Return ``_pytest._code.getfslineno`` for the real object behind *obj*.

    Wrappers are stripped with get_real_func(). An object carrying a
    ``place_as`` attribute is located through that attribute instead.
    The second item of the returned value is asserted to be an ``int``.
    """
    target = get_real_func(obj)
    if hasattr(target, "place_as"):
        target = target.place_as
    location = _pytest._code.getfslineno(target)
    assert isinstance(location[1], int), target
    return location
def getimfunc(func):
    """Return ``func.__func__`` (the function behind a method) if present, else *func*."""
    return getattr(func, "__func__", func)
def safe_getattr(object, name, default):
    """Like ``getattr(object, name, default)``, but also return *default* when
    the attribute access raises anything matched by ``TEST_OUTCOME``.
    """
    # NOTE(review): TEST_OUTCOME comes from _pytest.outcomes; presumably it
    # covers pytest's outcome exceptions as well as ordinary exceptions -- confirm.
    try:
        value = getattr(object, name, default)
    except TEST_OUTCOME:
        value = default
    return value
def _is_unittest_unexpected_success_a_failure():
    """Return True when running on Python 3.4 or newer.

    Per the function name, this is where an unexpected success of an
    ``expectedFailure`` unittest test counts as a failure.
    """
    return sys.version_info[:2] >= (3, 4)
if _PY3:

    def safe_str(v):
        """Return ``str(v)``."""
        return str(v)

else:

    def safe_str(v):
        """Return ``str(v)``; on UnicodeError, return *v* as text encoded to
        UTF-8 with undecodable characters replaced.
        """
        try:
            return str(v)
        except UnicodeError:
            text = v if isinstance(v, text_type) else text_type(v)
            return text.encode("utf-8", "replace")
# Names that _setup_collect_fakemodule() copies from ``pytest`` onto the
# synthetic ``pytest.collect`` module.
COLLECT_FAKEMODULE_ATTRIBUTES = tuple(
    "Collector Module Generator Function Instance Session Item Class File "
    "_fillfuncargs".split()
)
def _setup_collect_fakemodule():
    """Attach a synthetic ``pytest.collect`` module to ``pytest``.

    Each name in COLLECT_FAKEMODULE_ATTRIBUTES is copied from the top-level
    ``pytest`` module onto it. Its ``__all__`` is set to an empty list.
    """
    from types import ModuleType
    import pytest

    collect = ModuleType("pytest.collect")
    collect.__all__ = []
    pytest.collect = collect
    for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES:
        setattr(collect, attr_name, getattr(pytest, attr_name))
if _PY2:
    from py.io import TextIO

    class CaptureIO(TextIO):
        """``py.io.TextIO`` whose ``encoding`` falls back to ``"UTF-8"``."""

        @property
        def encoding(self):
            # Use an explicitly stored ``_encoding`` when one is set.
            return getattr(self, "_encoding", "UTF-8")

else:
    import io

    class CaptureIO(io.TextIOWrapper):
        """In-memory text stream: UTF-8 over a ``BytesIO``, with no newline
        translation (``newline=""``) and ``write_through=True``.
        """

        def __init__(self):
            super().__init__(
                io.BytesIO(), encoding="UTF-8", newline="", write_through=True
            )

        def getvalue(self):
            """Return everything written so far, decoded from UTF-8."""
            raw = self.buffer.getvalue()
            return raw.decode("UTF-8")
class FuncargnamesCompatAttr(object):
    """Mixin exposing ``funcargnames`` as an alias of ``self.fixturenames``."""

    @property
    def funcargnames(self):
        """Same object as ``self.fixturenames``."""
        names = self.fixturenames
        return names