import asyncio
import contextvars
import inspect
import warnings
from .case import TestCase
__unittest = True
class IsolatedAsyncioTestCase(TestCase):
loop_factory = None
def __init__(self, methodName='runTest'):
super().__init__(methodName)
self._asyncioRunner = None
self._asyncioTestContext = contextvars.copy_context()
async def asyncSetUp(self):
pass
async def asyncTearDown(self):
pass
def addAsyncCleanup(self, func, /, *args, **kwargs):
self.addCleanup(*(func, *args), **kwargs)
async def enterAsyncContext(self, cm):
cls = type(cm)
try:
enter = cls.__aenter__
exit = cls.__aexit__
except AttributeError:
msg = (f"'{cls.__module__}.{cls.__qualname__}' object does "
"not support the asynchronous context manager protocol")
try:
cls.__enter__
cls.__exit__
except AttributeError:
pass
else:
msg += (" but it supports the context manager protocol. "
"Did you mean to use enterContext()?")
raise TypeError(msg) from None
result = await enter(cm)
self.addAsyncCleanup(exit, cm, None, None, None)
return result
def _callSetUp(self):
self._asyncioRunner.get_loop()
self._asyncioTestContext.run(self.setUp)
self._callAsync(self.asyncSetUp)
def _callTestMethod(self, method):
result = self._callMaybeAsync(method)
if result is not None:
msg = (
f'It is deprecated to return a value that is not None '
f'from a test case ({method} returned {type(result).__name__!r})',
)
warnings.warn(msg, DeprecationWarning, stacklevel=4)
def _callTearDown(self):
self._callAsync(self.asyncTearDown)
self._asyncioTestContext.run(self.tearDown)
def _callCleanup(self, function, *args, **kwargs):
self._callMaybeAsync(function, *args, **kwargs)
def _callAsync(self, func, /, *args, **kwargs):
assert self._asyncioRunner is not None, 'asyncio runner is not initialized'
assert inspect.iscoroutinefunction(func), f'{func!r} is not an async function'
return self._asyncioRunner.run(
func(*args, **kwargs),
context=self._asyncioTestContext
)
def _callMaybeAsync(self, func, /, *args, **kwargs):
assert self._asyncioRunner is not None, 'asyncio runner is not initialized'
if inspect.iscoroutinefunction(func):
return self._asyncioRunner.run(
func(*args, **kwargs),
context=self._asyncioTestContext,
)
else:
return self._asyncioTestContext.run(func, *args, **kwargs)
def _setupAsyncioRunner(self):
assert self._asyncioRunner is None, 'asyncio runner is already initialized'
runner = asyncio.Runner(debug=True, loop_factory=self.loop_factory)
self._asyncioRunner = runner
def _tearDownAsyncioRunner(self):
runner = self._asyncioRunner
runner.close()
def run(self, result=None):
self._setupAsyncioRunner()
try:
return super().run(result)
finally:
self._tearDownAsyncioRunner()
def debug(self):
self._setupAsyncioRunner()
super().debug()
self._tearDownAsyncioRunner()
def __del__(self):
if self._asyncioRunner is not None:
self._tearDownAsyncioRunner()