from __future__ import annotations
import asyncio
from typing import Any, NamedTuple, Self
from urllib.parse import urljoin
import httpx
from .exceptions import (
AuthFrameworkError,
NetworkError,
TimeoutError as AuthTimeoutError,
create_error_from_response,
is_retryable_error,
)
# First status code that is NOT treated as success: responses whose status is
# strictly below this value are decoded as successful results.
HTTP_SUCCESS_THRESHOLD: int = 400
class RequestConfig(NamedTuple):
    """Optional per-request settings passed to ``BaseClient.make_request``.

    Every field defaults to ``None``. For ``timeout`` and ``retries``, ``None``
    means "use the client-wide default".
    """

    # JSON body; not sent when ``form_data`` is non-empty.
    json_data: dict[str, Any] | None = None
    # Form-encoded body; takes precedence over ``json_data`` when non-empty.
    form_data: dict[str, str | None] | None = None
    # Query-string parameters.
    params: dict[str, Any] | None = None
    # Per-request timeout override.
    timeout: float | None = None
    # Per-request override for the number of additional attempts.
    retries: int | None = None
class BaseClient:
    """Asynchronous HTTP transport shared by the AuthFramework SDK clients.

    Wraps one ``httpx.AsyncClient`` and does the following:

    - sends the SDK ``User-Agent``, an optional ``X-API-Key`` header and an
      optional bearer token;
    - maps transport failures and HTTP error responses onto the SDK's
      exceptions;
    - retries failures that ``is_retryable_error`` classifies as retryable.
    """

    def __init__(
        self,
        base_url: str,
        timeout: float = 30.0,
        retries: int = 3,
        api_key: str | None = None,
    ) -> None:
        """Create the client.

        Args:
            base_url: Root URL that endpoints are resolved against. A trailing
                slash is ignored; a path component (e.g. ``/v1``) is kept.
            timeout: Default per-request timeout passed to httpx.
            retries: Default number of additional attempts after the first.
            api_key: Optional key sent as ``X-API-Key`` on every request.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.retries = retries
        self.api_key = api_key
        self._access_token: str | None = None
        headers = {"User-Agent": "AuthFramework-Python-SDK/1.0.0"}
        if api_key:
            headers["X-API-Key"] = api_key
        self._client = httpx.AsyncClient(
            timeout=timeout,
            headers=headers,
        )

    async def __aenter__(self) -> Self:
        """Return the client itself for use in ``async with``."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: object,
    ) -> None:
        """Close the underlying HTTP client when the ``async with`` exits."""
        await self.close()

    async def close(self) -> None:
        """Close the underlying ``httpx.AsyncClient``."""
        await self._client.aclose()

    def set_access_token(self, token: str) -> None:
        """Store a token to send as ``Authorization: Bearer <token>``."""
        self._access_token = token

    def clear_access_token(self) -> None:
        """Stop sending an ``Authorization`` header."""
        self._access_token = None

    def get_access_token(self) -> str | None:
        """Return the stored bearer token, or ``None`` if none is set."""
        return self._access_token

    async def make_request(
        self,
        method: str,
        endpoint: str,
        *,
        config: RequestConfig | None = None,
    ) -> dict[str, Any]:
        """Send a request, retrying retryable failures, and return the body.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            endpoint: Path relative to ``base_url``; a leading slash is
                ignored. An absolute URL is used unchanged.
            config: Optional body/query/timeout/retry overrides.

        Returns:
            The decoded JSON body of a response with status below
            ``HTTP_SUCCESS_THRESHOLD``, or ``{}`` when that body is empty.

        Raises:
            Whatever the final attempt raised (timeout, network, API or
            request error). Failures are retried only while attempts remain
            and ``is_retryable_error`` accepts them.
            AuthFrameworkError: ``"Max retries exceeded"`` if every attempt
                ended in a retryable unexpected error.
        """
        if config is None:
            config = RequestConfig()
        # Join onto "<base_url>/": with the slash stripped, urljoin would
        # replace the last path segment of base_url (e.g. drop "/v1").
        url = urljoin(f"{self.base_url}/", endpoint.lstrip("/"))
        request_timeout = (
            config.timeout if config.timeout is not None else self.timeout
        )
        request_retries = (
            config.retries if config.retries is not None else self.retries
        )
        # Always make at least one attempt, even for a negative retry count.
        request_retries = max(request_retries, 0)
        headers: dict[str, str] = {}
        if self._access_token:
            headers["Authorization"] = f"Bearer {self._access_token}"
        for attempt in range(request_retries + 1):
            try:
                response = await self._attempt_request(
                    method,
                    url,
                    headers,
                    config,
                    request_timeout,
                )
            except (AuthFrameworkError, AuthTimeoutError, NetworkError) as exc:
                # Retry only failures the SDK classifies as retryable; the
                # final failure keeps its original exception type.
                if attempt >= request_retries or not is_retryable_error(exc):
                    raise
            else:
                # ``is not None``: an empty JSON object is a valid success.
                if response is not None:
                    return response
            if attempt < request_retries:
                # Exponential backoff: 1, 2, 4, 8 seconds, then capped at 10.
                await asyncio.sleep(min(2**attempt, 10))
        retries_msg = "Max retries exceeded"
        raise AuthFrameworkError(retries_msg)

    async def _attempt_request(
        self,
        method: str,
        url: str,
        headers: dict[str, str],
        config: RequestConfig,
        request_timeout: float,
    ) -> dict[str, Any] | None:
        """Perform a single HTTP attempt.

        Returns:
            The decoded body on success (``{}`` for an empty body). Returns
            ``None`` when an unexpected exception occurred that
            ``is_retryable_error`` deems retryable.

        Raises:
            AuthTimeoutError: The request timed out.
            NetworkError: httpx reported a network error.
            The error built by ``create_error_from_response`` for a status at
            or above ``HTTP_SUCCESS_THRESHOLD``.
            AuthFrameworkError: ``"Request failed"`` for a non-retryable
                unexpected error.
        """
        try:
            response = await self._execute_request(
                method,
                url,
                headers,
                config,
                request_timeout,
            )
            if response.status_code >= HTTP_SUCCESS_THRESHOLD:
                error_info = self._parse_error_response(response)
                self._raise_api_error(response.status_code, error_info)
            # 204 No Content and other empty success bodies are not JSON.
            if not response.content:
                return {}
            return response.json()
        except httpx.TimeoutException as e:
            timeout_msg = "Request timeout"
            raise AuthTimeoutError(timeout_msg) from e
        except httpx.NetworkError as e:
            network_msg = "Network error"
            raise NetworkError(network_msg) from e
        except AuthFrameworkError:
            raise
        except Exception as e:
            if not is_retryable_error(e):
                failed_msg = "Request failed"
                raise AuthFrameworkError(failed_msg) from e
            return None

    async def _execute_request(
        self,
        method: str,
        url: str,
        headers: dict[str, str],
        config: RequestConfig,
        timeout: float,
    ) -> httpx.Response:
        """Send one request through the shared httpx client.

        A non-empty ``config.form_data`` is sent form-encoded. Otherwise
        ``config.json_data`` is sent as the JSON body.
        """
        if config.form_data:
            # Copy instead of mutating: the caller reuses ``headers`` on retry.
            form_headers = {
                **headers,
                "Content-Type": "application/x-www-form-urlencoded",
            }
            return await self._client.request(
                method,
                url,
                data=config.form_data,
                params=config.params,
                headers=form_headers,
                timeout=timeout,
            )
        return await self._client.request(
            method,
            url,
            json=config.json_data,
            params=config.params,
            headers=headers,
            timeout=timeout,
        )

    @staticmethod
    def _parse_error_response(response: httpx.Response) -> dict[str, Any]:
        """Extract the ``error`` object from an error response.

        Returns the body's ``error`` member when it is a JSON object, and
        ``{}`` when a JSON object body has no ``error`` member. In any other
        case, returns ``{"message": <raw text>, "code": "UNKNOWN_ERROR"}``:
        a non-JSON body, a non-object body, or a non-object ``error`` member
        (e.g. ``{"error": "invalid_grant"}``).
        """
        try:
            error_data = response.json()
        except ValueError:  # includes json.JSONDecodeError
            error_data = None
        if isinstance(error_data, dict):
            error_info = error_data.get("error", {})
            if isinstance(error_info, dict):
                return error_info
        return {"message": response.text, "code": "UNKNOWN_ERROR"}

    @staticmethod
    def _raise_api_error(status_code: int, error_info: dict[str, Any]) -> None:
        """Raise the SDK error built by ``create_error_from_response``."""
        raise create_error_from_response(status_code, error_info)