from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Iterable, Iterator, Optional, Protocol
# Module-level logger named after this module.
log = logging.getLogger(__name__)
@dataclass(frozen=True)
class User:
    """Immutable user record.

    Fields cannot be reassigned (``frozen=True``), but ``tags`` is a plain
    list whose contents can still be mutated in place.

    Attributes:
        id: identifier; InMemoryRepository uses it as its storage key.
        email: email address.
        name: display name.
        tags: free-form labels; defaults to a fresh empty list per instance.
    """

    id: str
    email: str
    name: str
    # hash=False: a list is unhashable, so including ``tags`` in the
    # generated __hash__ made hash(user) raise TypeError despite frozen=True.
    # Equality still compares tags; hashing a subset of the compared fields
    # preserves the "equal objects hash equal" invariant.
    tags: list[str] = field(default_factory=list, hash=False)
class UserNotFoundError(Exception):
    """Raised when no user exists for a requested id.

    Attributes:
        user_id: the id that was looked up.
    """

    def __init__(self, user_id: str) -> None:
        self.user_id = user_id
        super().__init__(f"user {user_id} not found")
class ConflictError(Exception):
    """Raised when a value that must be unique is already taken.

    Attributes:
        field_name: name of the conflicting field (e.g. ``"email"``).
        value: the value that collided.
    """

    def __init__(self, field_name: str, value: str) -> None:
        self.field_name, self.value = field_name, value
        super().__init__(f"conflict on {field_name}={value!r}")
class UserRepository(Protocol):
    """Structural interface for the user storage that UserService relies on."""

    def find_by_id(self, user_id: str) -> Optional[User]:
        """Return the user stored under ``user_id``, or ``None``."""
        ...

    def find_by_email(self, email: str) -> Optional[User]:
        """Return the user whose email is ``email``, or ``None``."""
        ...

    def insert(self, user: User) -> User:
        """Store ``user`` and return the stored record.

        Implementations may raise ``ConflictError`` for duplicates.
        """
        ...

    def scan(self) -> Iterator[User]:
        """Iterate over every stored user."""
        ...
class UserService:
    """User operations layered over a ``UserRepository``."""

    def __init__(self, repo: UserRepository) -> None:
        self._repo = repo

    def get(self, user_id: str) -> User:
        """Return the user with ``user_id``.

        Raises:
            UserNotFoundError: if the repository returns ``None`` for the id.
        """
        found = self._repo.find_by_id(user_id)
        if found is not None:
            return found
        raise UserNotFoundError(user_id)

    def create(self, email: str, name: str, tags: Iterable[str] = ()) -> User:
        """Build a new user, store it, and return what the repository returns.

        The id comes from ``make_id(email)``; ``tags`` is copied into a new
        list.

        Raises:
            ConflictError: if a user with ``email`` already exists. The
                repository's ``insert`` may also raise it.
        """
        existing = self._repo.find_by_email(email)
        if existing is not None:
            raise ConflictError("email", email)
        candidate = User(
            id=make_id(email),
            email=email,
            name=name,
            tags=list(tags),
        )
        log.info("creating user", extra={"email": email})
        return self._repo.insert(candidate)

    def with_tag(self, tag: str) -> list[User]:
        """Return every stored user whose ``tags`` contain ``tag``, in scan order."""
        return [
            member
            for member in self._repo.scan()
            if tag in member.tags
        ]
class InMemoryRepository:
    """``UserRepository`` implementation that keeps users in a dict keyed by id."""

    def __init__(self) -> None:
        self._by_id: dict[str, User] = {}

    def find_by_id(self, user_id: str) -> Optional[User]:
        """Return the stored user with ``user_id``, or ``None``."""
        return self._by_id.get(user_id)

    def find_by_email(self, email: str) -> Optional[User]:
        """Return the first stored user whose email equals ``email``, or ``None``.

        Linear scan over all users; the comparison is exact (case-sensitive).
        """
        return next((u for u in self._by_id.values() if u.email == email), None)

    def insert(self, user: User) -> User:
        """Store ``user`` under ``user.id`` and return it.

        Raises:
            ConflictError: ``("email", ...)`` if a stored user already has the
                same email, or ``("id", ...)`` if a stored user already has
                the same id.
        """
        if self.find_by_email(user.email) is not None:
            raise ConflictError("email", user.email)
        # Ids are not guaranteed unique by callers: make_id() keeps only the
        # email's local part, so distinct emails can map to the same id.
        # Refuse instead of silently overwriting the existing user.
        if user.id in self._by_id:
            raise ConflictError("id", user.id)
        self._by_id[user.id] = user
        return user

    def scan(self) -> Iterator[User]:
        """Yield every stored user in insertion order."""
        yield from self._by_id.values()
def make_id(email: str) -> str:
    """Derive a user id from an email address.

    Returns the lower-cased text before the first ``"@"``. If that text is
    empty (no local part, or an empty string), returns the whole email
    lower-cased. Emails that differ only in domain or letter case yield the
    same id.
    """
    local_part = email.split("@", 1)[0]
    chosen = local_part or email
    return chosen.lower()