use anyhow::Result;
use crate::{file_manager::save_file_with_content, project_info::ProjectInfo};
/// Renders the contents of the generated project's `tests/core/test_config.py`.
///
/// The Python module path is taken from `project_info.module_name()` and substituted
/// into the `from {module}.core.config import Settings` import; every other `{{`/`}}`
/// pair is escaped so it appears literally in the generated Python f-strings.
fn create_config_test_file(project_info: &ProjectInfo) -> String {
    let module = project_info.module_name();
    format!(
        r#"import pytest
from pydantic import AnyUrl, SecretStr
from {module}.core.config import Settings
def test_check_default_secret_production():
with pytest.raises(ValueError):
Settings(
FIRST_SUPERUSER_EMAIL="user@email.com",
FIRST_SUPERUSER_PASSWORD=SecretStr("Abc$123be"),
FIRST_SUPERUSER_NAME="Some Name",
POSTGRES_HOST="http://localhost",
POSTGRES_USER="postgres",
POSTGRES_PASSWORD=SecretStr("Somepassword!"),
POSTGRES_DB="test_db",
VALKEY_HOST="http://localhost",
VALKEY_PASSWORD=SecretStr("Somepassword!"),
ENVIRONMENT="production",
SECRET_KEY=SecretStr("changethis"),
)
def test_check_default_secret_testing():
with pytest.raises(ValueError):
Settings(
FIRST_SUPERUSER_EMAIL="user@email.com",
FIRST_SUPERUSER_PASSWORD=SecretStr("Abc$123be"),
FIRST_SUPERUSER_NAME="Some Name",
POSTGRES_HOST="http://localhost",
POSTGRES_USER="postgres",
POSTGRES_PASSWORD=SecretStr("Somepassword!"),
POSTGRES_DB="test_db",
VALKEY_HOST="http://localhost",
VALKEY_PASSWORD=SecretStr("Somepassword!"),
ENVIRONMENT="testing",
SECRET_KEY=SecretStr("changethis"),
)
def test_check_default_secret_local():
with pytest.warns(
UserWarning,
match='The value of SECRET_KEY is "changethis", for security, please change it, at least for deployments.',
):
Settings(
FIRST_SUPERUSER_EMAIL="user@email.com",
FIRST_SUPERUSER_PASSWORD=SecretStr("Abc$123be"),
FIRST_SUPERUSER_NAME="Some Name",
POSTGRES_HOST="http://localhost",
POSTGRES_USER="postgres",
POSTGRES_PASSWORD=SecretStr("Somepassword!"),
POSTGRES_DB="test_db",
VALKEY_HOST="http://localhost",
VALKEY_PASSWORD=SecretStr("Somepassword!"),
ENVIRONMENT="local",
SECRET_KEY=SecretStr("changethis"),
)
def test_server_host_production():
settings = Settings(
FIRST_SUPERUSER_EMAIL="user@email.com",
FIRST_SUPERUSER_PASSWORD=SecretStr("Abc$123be"),
FIRST_SUPERUSER_NAME="Some Name",
POSTGRES_HOST="http://localhost",
POSTGRES_USER="postgres",
POSTGRES_PASSWORD=SecretStr("Somepassword!"),
POSTGRES_DB="test_db",
VALKEY_HOST="http://localhost",
VALKEY_PASSWORD=SecretStr("Somepassword!"),
SECRET_KEY=SecretStr("Somesecretkey"),
ENVIRONMENT="production",
)
assert settings.server_host == f"https://{{settings.DOMAIN}}"
def test_server_host_testing():
settings = Settings(
FIRST_SUPERUSER_EMAIL="user@email.com",
FIRST_SUPERUSER_PASSWORD=SecretStr("Abc$123be"),
FIRST_SUPERUSER_NAME="Some Name",
POSTGRES_HOST="http://localhost",
POSTGRES_USER="postgres",
POSTGRES_PASSWORD=SecretStr("Somepassword!"),
POSTGRES_DB="test_db",
VALKEY_HOST="http://localhost",
VALKEY_PASSWORD=SecretStr("Somepassword!"),
SECRET_KEY=SecretStr("Somesecretkey"),
ENVIRONMENT="testing",
)
assert settings.server_host == f"https://{{settings.DOMAIN}}"
def test_server_host_local():
settings = Settings(
FIRST_SUPERUSER_EMAIL="user@email.com",
FIRST_SUPERUSER_PASSWORD=SecretStr("Abc$123be"),
FIRST_SUPERUSER_NAME="Some Name",
POSTGRES_HOST="http://localhost",
POSTGRES_USER="postgres",
POSTGRES_PASSWORD=SecretStr("Somepassword!"),
POSTGRES_DB="test_db",
VALKEY_HOST="http://localhost",
VALKEY_PASSWORD=SecretStr("Somepassword!"),
SECRET_KEY=SecretStr("Somesecretkey"),
ENVIRONMENT="local",
)
assert settings.server_host == f"http://{{settings.DOMAIN}}"
def test_parse_cors_error():
with pytest.raises(ValueError):
Settings(
FIRST_SUPERUSER_EMAIL="user@email.com",
FIRST_SUPERUSER_PASSWORD=SecretStr("Abc$123be"),
FIRST_SUPERUSER_NAME="Some Name",
POSTGRES_HOST="http://localhost",
POSTGRES_USER="postgres",
POSTGRES_PASSWORD=SecretStr("Somepassword!"),
POSTGRES_DB="test_db",
VALKEY_HOST="http://localhost",
VALKEY_PASSWORD=SecretStr("Somepassword!"),
SECRET_KEY=SecretStr("Somesecretkey"),
BACKEND_CORS_ORIGINS=1, # type: ignore
)
def test_parse_cors_string():
settings = Settings(
FIRST_SUPERUSER_EMAIL="user@email.com",
FIRST_SUPERUSER_PASSWORD=SecretStr("Abc$123be"),
FIRST_SUPERUSER_NAME="Some Name",
POSTGRES_HOST="http://localhost",
POSTGRES_USER="postgres",
POSTGRES_PASSWORD=SecretStr("Somepassword!"),
POSTGRES_DB="test_db",
VALKEY_HOST="http://localhost",
VALKEY_PASSWORD=SecretStr("Somepassword!"),
SECRET_KEY=SecretStr("Somesecretkey"),
BACKEND_CORS_ORIGINS="http://localhost, http://127.0.0.1",
)
assert settings.BACKEND_CORS_ORIGINS == [AnyUrl("http://localhost"), AnyUrl("http://127.0.0.1")]
"#
    )
}
/// Writes the generated `tests/core/test_config.py` under the project's base directory.
///
/// # Errors
///
/// Propagates any error reported by `save_file_with_content`.
pub fn save_config_test_file(project_info: &ProjectInfo) -> Result<()> {
    let target = project_info.base_dir().join("tests/core/test_config.py");
    let rendered = create_config_test_file(project_info);
    save_file_with_content(&target, &rendered)?;
    Ok(())
}
/// Renders the contents of the generated project's `tests/conftest.py`.
///
/// The template provides pytest fixtures for a per-test Postgres database (created and
/// migrated with the `sqlx` CLI, then dropped), a per-worker Valkey db index, an httpx
/// `AsyncClient` wired to the FastAPI app, and helper users/tokens.
///
/// Generated-code notes:
/// * `_db_counters` is keyed by pytest-xdist's `worker_id`, which is a string.
/// * The `sqlx` subprocesses run with `check=True`, so a failed database creation or
///   migration aborts fixture setup instead of surfacing later as unrelated errors.
/// * The test database name prefix is derived from the project's module name rather
///   than being hard-coded.
fn create_conftest_file(project_info: &ProjectInfo) -> String {
    let module = project_info.module_name();
    format!(
        r#"from __future__ import annotations
import itertools
import os
import subprocess
from pathlib import Path
from unittest.mock import patch
from uuid import uuid4
import asyncpg
import pytest
from httpx import ASGITransport, AsyncClient
from {module}.api.deps import get_cache_client, get_db_pool
from {module}.core.cache import cache
from {module}.core.config import settings
from {module}.core.db import Database
from {module}.main import app
from {module}.models.users import UserCreate
from {module}.services.db import user_services
from tests.utils import (
get_superuser_token_headers,
random_email,
random_lower_string,
random_password,
)
ROOT_PATH = Path().absolute()
ASSETS_DIR = ROOT_PATH / "tests" / "assets"
async def user_authentication_headers(test_client, email, password):
data = {{"username": email, "password": password}}
result = await test_client.post("/login/access-token", data=data)
response = result.json()
auth_token = response["access_token"]
return {{"Authorization": f"Bearer {{auth_token}}"}}
@pytest.fixture(scope="session")
def valkey_db_index(worker_id):
if worker_id == "master":
return 0
else:
return int(worker_id.lstrip("gw")) + 1
DBS_PER_WORKER = 5
MAX_DB_INDEX = 99
MAX_WORKERS = MAX_DB_INDEX // DBS_PER_WORKER
_db_counters: dict[str, itertools.count[int]] = {{}}
@pytest.fixture
def next_db(worker_id):
"""Calculate db number per worker so data doesn't clash in parallel tests."""
if worker_id == "master":
return 1
worker_num = int(worker_id.lstrip("gw") or "0")
if worker_num >= MAX_WORKERS:
raise RuntimeError(
f"Worker {{worker_id}} exceeds DB allocation limit (max {{MAX_WORKERS}} workers). "
f"Either reduce number of workers or decrease DBS_PER_WORKER."
)
base = 1 + (worker_num * DBS_PER_WORKER) # skip db=0
if base + DBS_PER_WORKER - 1 > MAX_DB_INDEX:
raise RuntimeError(f"Worker {{worker_id}} would exceed MAX_DB_INDEX with base {{base}}")
if worker_id not in _db_counters:
_db_counters[worker_id] = itertools.count(0)
offset = next(_db_counters[worker_id]) % DBS_PER_WORKER
db_index = base + offset
return db_index
@pytest.fixture
def db_name(worker_id):
base_name = "{module}_test"
unique_suffix = str(uuid4()).replace("-", "")[:8]
if worker_id == "master":
return f"{{base_name}}_{{unique_suffix}}"
return f"{{base_name}}_{{worker_id}}_{{unique_suffix}}"
@pytest.fixture(autouse=True)
async def test_cache(next_db):
await cache.create_client(db=next_db)
yield cache
await cache.client.flushdb() # type: ignore
await cache.close_client()
@pytest.fixture
def apply_migrations(db_name):
test_db_url = f"postgresql://{{settings.POSTGRES_USER}}:{{settings.POSTGRES_PASSWORD.get_secret_value()}}@{{settings.POSTGRES_HOST}}:5432/{{db_name}}"
migration_dir = ROOT_PATH
with patch.dict(os.environ, {{"DATABASE_URL": test_db_url}}):
subprocess.run(["sqlx", "database", "create"], cwd=migration_dir, check=True)
subprocess.run(["sqlx", "migrate", "run"], cwd=migration_dir, check=True)
yield
@pytest.fixture
async def test_db(db_name, apply_migrations):
test_db = Database(db_name=db_name)
await test_db.create_pool(min_size=1, max_size=2)
await test_db.create_first_superuser()
yield test_db
await test_db.close_pool()
# Need to connect to "postgres" db instead of the db being dropped
conn = await asyncpg.connect(
database="postgres",
user=settings.POSTGRES_USER,
password=settings.POSTGRES_PASSWORD.get_secret_value(),
host=settings.POSTGRES_HOST,
)
# Terminate any remaining connections to the test database
await conn.execute(
"""
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE datname = $1 AND pid <> pg_backend_pid()
""",
db_name,
)
await conn.execute(f'DROP DATABASE "{{db_name}}"')
await conn.close()
@pytest.fixture
async def test_client(test_db, test_cache):
app.dependency_overrides[get_cache_client] = lambda: test_cache.client
app.dependency_overrides[get_db_pool] = lambda: test_db.db_pool
async with AsyncClient(
transport=ASGITransport(app=app), base_url=f"http://127.0.0.1{{settings.API_V1_PREFIX}}"
) as client:
yield client
app.dependency_overrides.clear()
@pytest.fixture
async def superuser_token_headers(test_client):
return await get_superuser_token_headers(test_client)
@pytest.fixture
def normal_user_credentials():
return {{
"password": random_password(),
"full_name": random_lower_string(),
"email": random_email(),
}}
@pytest.fixture
async def normal_user_token_headers(test_db, test_client, test_cache, normal_user_credentials):
user = await user_services.get_user_by_email(
pool=test_db.db_pool, email=normal_user_credentials["email"]
)
if not user:
user = await user_services.create_user(
pool=test_db.db_pool,
cache_client=test_cache.client,
user=UserCreate(
email=normal_user_credentials["email"],
password=normal_user_credentials["password"],
full_name=normal_user_credentials["full_name"],
),
)
return await user_authentication_headers(
test_client=test_client,
email=normal_user_credentials["email"],
password=normal_user_credentials["password"],
)
@pytest.fixture
async def test_user(test_db, test_cache):
email = random_email()
password = random_password()
full_name = random_lower_string()
user = await user_services.create_user(
pool=test_db.db_pool,
cache_client=test_cache.client,
user=UserCreate(
email=email,
password=password,
full_name=full_name,
),
)
return user
"#
    )
}
/// Writes the generated `tests/conftest.py` under the project's base directory.
///
/// # Errors
///
/// Propagates any error reported by `save_file_with_content`.
pub fn save_conftest_file(project_info: &ProjectInfo) -> Result<()> {
    let target = project_info.base_dir().join("tests/conftest.py");
    let rendered = create_conftest_file(project_info);
    save_file_with_content(&target, &rendered)?;
    Ok(())
}
/// Renders the contents of the generated project's `tests/api/routes/test_health_route.py`.
///
/// The template overrides the app's db-pool and cache-client dependencies with failing
/// mocks. It then checks the `health` endpoint's per-component status fields
/// (`server`, `db`, `cache`).
fn create_health_route_test_file(project_info: &ProjectInfo) -> String {
    format!(
        r#"from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
from {module}.api.deps import get_cache_client, get_db_pool
from {module}.core.config import settings
from {module}.main import app
@pytest.fixture
def failing_db_pool():
mock_pool = MagicMock()
mock_acquire = AsyncMock()
mock_acquire.__aenter__.side_effect = Exception("DB down")
mock_pool.acquire.return_value = mock_acquire
return mock_pool
@pytest.fixture
async def test_client_bad_db(failing_db_pool, test_cache):
app.dependency_overrides[get_cache_client] = lambda: test_cache.client
app.dependency_overrides[get_db_pool] = lambda: failing_db_pool
async with AsyncClient(
transport=ASGITransport(app=app), base_url=f"http://127.0.0.1{{settings.API_V1_PREFIX}}"
) as client:
yield client
app.dependency_overrides.clear()
@pytest.fixture
def failing_cache_client():
mock_client = MagicMock()
mock_acquire = AsyncMock()
mock_acquire.__aenter__.side_effect = Exception("Cache down")
mock_client.acquire.return_value = mock_acquire
return mock_client
@pytest.fixture
async def test_client_bad_cache(failing_cache_client, test_db):
app.dependency_overrides[get_cache_client] = lambda: failing_cache_client
app.dependency_overrides[get_db_pool] = lambda: test_db.db_pool
async with AsyncClient(
transport=ASGITransport(app=app), base_url=f"http://127.0.0.1{{settings.API_V1_PREFIX}}"
) as client:
yield client
app.dependency_overrides.clear()
async def test_health(test_client):
result = await test_client.get("health")
assert result.status_code == 200
assert result.json()["server"] == "healthy"
assert result.json()["db"] == "healthy"
assert result.json()["cache"] == "healthy"
async def test_health_no_db(test_client_bad_db):
result = await test_client_bad_db.get("health")
assert result.status_code == 200
assert result.json()["server"] == "healthy"
assert result.json()["db"] == "unhealthy"
assert result.json()["cache"] == "healthy"
async def test_health_no_cache(test_client_bad_cache):
result = await test_client_bad_cache.get("health")
assert result.status_code == 200
assert result.json()["server"] == "healthy"
assert result.json()["db"] == "healthy"
assert result.json()["cache"] == "unhealthy"
"#,
        module = project_info.module_name()
    )
}
/// Writes the generated `tests/api/routes/test_health_route.py` under the project's
/// base directory.
///
/// # Errors
///
/// Propagates any error reported by `save_file_with_content`.
pub fn save_health_route_test_file(project_info: &ProjectInfo) -> Result<()> {
    let target = project_info
        .base_dir()
        .join("tests/api/routes/test_health_route.py");
    let rendered = create_health_route_test_file(project_info);
    save_file_with_content(&target, &rendered)?;
    Ok(())
}
/// Renders the contents of the generated project's `tests/api/routes/test_login_routes.py`.
///
/// The template covers:
/// * obtaining an access token with the first superuser's credentials;
/// * rejecting a wrong password;
/// * validating a token via `/login/test-token`;
/// * refusing login for a user who has been deactivated.
fn create_login_route_test_file(project_info: &ProjectInfo) -> String {
    format!(
        r#"from unittest.mock import Mock
from fastapi import Request
from {module}.api.deps import get_current_user
from {module}.core.config import settings
from tests.utils import random_password
async def test_get_access_token(test_client):
login_data = {{
"username": settings.FIRST_SUPERUSER_EMAIL,
"password": settings.FIRST_SUPERUSER_PASSWORD.get_secret_value(),
}}
response = await test_client.post("/login/access-token", data=login_data)
tokens = response.json()
assert response.status_code == 200
assert "access_token" in tokens
assert tokens["access_token"]
async def test_get_access_token_incorrect_password(test_client):
login_data = {{
"username": settings.FIRST_SUPERUSER_EMAIL,
"password": random_password(),
}}
response = await test_client.post("/login/access-token", data=login_data)
assert response.status_code == 400
async def test_use_access_token(test_client, superuser_token_headers):
response = await test_client.post(
"/login/test-token",
headers=superuser_token_headers,
)
result = response.json()
assert response.status_code == 200
assert "email" in result
async def test_access_token_inactive_user(
test_client,
superuser_token_headers,
normal_user_token_headers,
normal_user_credentials,
test_db,
test_cache,
):
mock_request = Mock(spec=Request)
mock_request.url.path = "/api/v1/users/me"
user = await get_current_user(
test_db.db_pool,
test_cache.client,
normal_user_token_headers["Authorization"].split(" ", 1)[1],
)
test_client.cookies.clear()
response = await test_client.patch(
f"/users/{{user.id}}",
headers=superuser_token_headers,
json={{"fullName": user.full_name, "isActive": False}},
)
assert response.status_code == 200
login_data = {{
"username": user.email,
"password": normal_user_credentials["password"],
}}
test_client.cookies.clear()
response = await test_client.post("/login/access-token", data=login_data)
assert response.status_code == 401
"#,
        module = project_info.module_name()
    )
}
/// Writes the generated `tests/api/routes/test_login_routes.py` under the project's
/// base directory.
///
/// # Errors
///
/// Propagates any error reported by `save_file_with_content`.
pub fn save_login_route_test_file(project_info: &ProjectInfo) -> Result<()> {
    let target = project_info
        .base_dir()
        .join("tests/api/routes/test_login_routes.py");
    let rendered = create_login_route_test_file(project_info);
    save_file_with_content(&target, &rendered)?;
    Ok(())
}
/// Renders the contents of the generated project's `tests/utils.py`.
///
/// The template defines helpers that return random emails, lowercase strings and
/// passwords. It also defines a coroutine that logs in as the first superuser and
/// returns a `Bearer` authorization header.
fn create_test_utils_file(project_info: &ProjectInfo) -> String {
    format!(
        r#"import random
import string
from {module}.core.config import settings
def random_email() -> str:
return f"{{random_lower_string()}}@{{random_lower_string()}}.com"
def random_lower_string() -> str:
return "".join(random.choices(string.ascii_lowercase, k=32))
def random_password() -> str:
password = "".join(random.choices(string.ascii_lowercase, k=32))
return f"A{{password}}1_"
async def get_superuser_token_headers(test_client):
login_data = {{
"username": settings.FIRST_SUPERUSER_EMAIL,
"password": settings.FIRST_SUPERUSER_PASSWORD.get_secret_value(),
}}
response = await test_client.post("/login/access-token", data=login_data)
tokens = response.json()
a_token = tokens["access_token"]
headers = {{"Authorization": f"Bearer {{a_token}}"}}
return headers
"#,
        module = project_info.module_name()
    )
}
/// Writes the generated `tests/utils.py` under the project's base directory.
///
/// # Errors
///
/// Propagates any error reported by `save_file_with_content`.
pub fn save_test_utils_file(project_info: &ProjectInfo) -> Result<()> {
    let target = project_info.base_dir().join("tests/utils.py");
    let rendered = create_test_utils_file(project_info);
    save_file_with_content(&target, &rendered)?;
    Ok(())
}
/// Renders the contents of the generated project's `tests/api/test_deps.py`.
///
/// The template exercises the API dependency functions:
/// * missing or malformed `Authorization` headers;
/// * invalid tokens and inactive users passed to `get_current_user`;
/// * `get_db_pool` / `get_cache_client` yielding a non-`None` resource.
fn create_test_deps_file(project_info: &ProjectInfo) -> String {
    format!(
        r#"from unittest.mock import Mock
import pytest
from fastapi import HTTPException, Request
from {module}.api.deps import get_cache_client, get_current_user, get_db_pool
from {module}.core.cache import cache
from {module}.core.db import db
async def test_auth_no_authorization_in_header(test_client, normal_user_token_headers):
del normal_user_token_headers["Authorization"]
test_client.cookies.clear()
response = await test_client.get(
"/users/me",
headers=normal_user_token_headers,
)
assert response.status_code == 401
async def test_auth_no_bearer(test_client, normal_user_token_headers):
normal_user_token_headers["Authorization"] = normal_user_token_headers[
"Authorization"
].removeprefix("Bearer ")
test_client.cookies.clear()
response = await test_client.get(
"/users/me",
headers=normal_user_token_headers,
)
assert response.status_code == 401
async def test_get_current_user_invalid_token(test_db, test_cache):
mock_request = Mock(spec=Request)
mock_request.url.path = "/api/v1/users/me"
with pytest.raises(HTTPException) as ex:
await get_current_user(
test_db.db_pool,
test_cache.client,
"e",
)
assert ex.value.status_code == 403
async def test_get_current_user_inactive(
test_client, test_cache, normal_user_token_headers, superuser_token_headers, test_db
):
mock_request = Mock(spec=Request)
mock_request.url.path = "/api/v1/users/me"
user = await get_current_user(
test_db.db_pool,
test_cache.client,
normal_user_token_headers["Authorization"].split(" ", 1)[1],
)
test_client.cookies.clear()
response = await test_client.patch(
f"/users/{{user.id}}",
headers=superuser_token_headers,
json={{"fullName": user.full_name, "isActive": False}},
)
assert response.status_code == 200
with pytest.raises(HTTPException) as ex:
await get_current_user(
test_db.db_pool,
test_cache.client,
normal_user_token_headers["Authorization"].split(" ", 1)[1],
)
assert ex.value.status_code == 403
@pytest.fixture
async def temp_db_pool():
await db.create_pool()
yield
await db.close_pool()
@pytest.mark.usefixtures("temp_db_pool")
async def test_get_db_pool_success():
async for pool in get_db_pool():
assert pool is not None
@pytest.fixture
async def temp_cache_client():
await cache.create_client()
yield
await cache.close_client()
@pytest.mark.usefixtures("temp_cache_client")
async def test_get_cache_client_success():
async for client in get_cache_client():
assert client is not None
"#,
        module = project_info.module_name()
    )
}
/// Writes the generated `tests/api/test_deps.py` under the project's base directory.
///
/// # Errors
///
/// Propagates any error reported by `save_file_with_content`.
pub fn save_test_deps_file(project_info: &ProjectInfo) -> Result<()> {
    let target = project_info.base_dir().join("tests/api/test_deps.py");
    let rendered = create_test_deps_file(project_info);
    save_file_with_content(&target, &rendered)?;
    Ok(())
}
/// Renders the contents of the generated project's `tests/models/test_users.py`.
///
/// The template checks that `UserCreate` and `UserUpdate` reject passwords missing a
/// required character class or shorter than 8 characters. Each `*_create_*` test builds
/// a `UserCreate` and each `*_update_*` test builds a `UserUpdate`; previously the two
/// short-password tests had the models swapped relative to their names.
fn create_user_model_test_file(project_info: &ProjectInfo) -> String {
    let module = project_info.module_name();
    format!(
        r#"import pytest
from {module}.models.users import UserCreate, UserUpdate
from tests.utils import random_email, random_lower_string
@pytest.mark.parametrize("password", ("loweronly1.", "UPPER1*ONLY", "no@Number", "nospEcial4"))
def test_user_create_invalid_password(password):
with pytest.raises(ValueError) as e:
UserCreate(
email=random_email(),
full_name=random_lower_string(),
password=password,
)
assert (
"Password must contain at least one uppercase letter, one lowercase letter, one number, and one special character"
in (str(e.value))
)
def test_user_create_short_password():
with pytest.raises(ValueError) as e:
UserCreate(
email=random_email(),
full_name=random_lower_string(),
password="Short1_",
)
assert "at least 8 characters" in (str(e.value))
@pytest.mark.parametrize("password", ("loweronly1.", "UPPER1*ONLY", "no@Number", "nospEcial4"))
def test_user_update_invalid_password(password):
with pytest.raises(ValueError) as e:
UserUpdate(
email=random_email(),
full_name=random_lower_string(),
password=password,
)
assert (
"Password must contain at least one uppercase letter, one lowercase letter, one number, and one special character"
in (str(e.value))
)
def test_user_update_short_password():
with pytest.raises(ValueError) as e:
UserUpdate(
email=random_email(),
full_name=random_lower_string(),
password="Short1_",
)
assert "at least 8 characters" in (str(e.value))
"#
    )
}
/// Writes the generated `tests/models/test_users.py` under the project's base directory.
///
/// # Errors
///
/// Propagates any error reported by `save_file_with_content`.
pub fn save_user_model_test_file(project_info: &ProjectInfo) -> Result<()> {
    let target = project_info.base_dir().join("tests/models/test_users.py");
    let rendered = create_user_model_test_file(project_info);
    save_file_with_content(&target, &rendered)?;
    Ok(())
}
/// Renders the contents of the generated project's `tests/api/routes/test_users.py`.
///
/// The template exercises the `/users` routes: reading, creating, updating, changing
/// passwords, deleting and privilege checks.
///
/// All JSON request bodies use the camelCase aliases the API exposes (`fullName`,
/// `isActive`, `currentPassword`, `newPassword`). Previously the password-change tests
/// mixed snake_case and camelCase keys for the same endpoint.
fn create_user_routes_test_file(project_info: &ProjectInfo) -> String {
    let module = project_info.module_name();
    format!(
        r#"from uuid import uuid4
from {module}.core.config import settings
from {module}.core.security import verify_password
from {module}.models.users import UserCreate
from {module}.services.db import user_services
from tests.utils import random_email, random_lower_string, random_password
async def test_get_users_superuser_me(test_client, superuser_token_headers):
response = await test_client.get("/users/me", headers=superuser_token_headers)
current_user = response.json()
assert current_user is not None
assert current_user["isActive"] is True
assert current_user["isSuperuser"]
assert current_user["email"] == settings.FIRST_SUPERUSER_EMAIL
assert current_user["fullName"] == settings.FIRST_SUPERUSER_NAME
async def test_get_users_normal_user_me(test_client, normal_user_token_headers):
response = await test_client.get("/users/me", headers=normal_user_token_headers)
current_user = response.json()
assert current_user is not None
assert current_user["isActive"] is True
assert current_user["isSuperuser"] is False
assert current_user["email"] is not None
async def test_get_existing_user(test_db, test_client, superuser_token_headers, test_user):
user_id = test_user.id
response = await test_client.get(
f"/users/{{user_id}}",
headers=superuser_token_headers,
)
assert 200 <= response.status_code < 300
api_user = response.json()
existing_user = await user_services.get_user_by_email(
pool=test_db.db_pool, email=test_user.email
)
assert existing_user
assert existing_user.email == api_user["email"]
async def test_get_user_not_found(test_client, superuser_token_headers):
response = await test_client.get(
"/users/bad",
headers=superuser_token_headers,
)
assert response.status_code == 404
async def test_get_existing_user_current_user(test_client, test_db, test_cache):
email = random_email()
password = random_password()
full_name = random_lower_string()
user = await user_services.create_user(
pool=test_db.db_pool,
cache_client=test_cache.client,
user=UserCreate(
email=email,
password=password,
full_name=full_name,
),
)
user_id = user.id
login_data = {{
"username": email,
"password": password,
}}
response = await test_client.post("/login/access-token", data=login_data)
tokens = response.json()
access_token = tokens["access_token"]
headers = {{"Authorization": f"Bearer {{access_token}}"}}
response = await test_client.get(
f"/users/{{user_id}}",
headers=headers,
)
assert 200 <= response.status_code < 300
api_user = response.json()
existing_user = await user_services.get_user_by_email(pool=test_db.db_pool, email=email)
assert existing_user
assert existing_user.email == api_user["email"]
async def test_get_existing_user_permissions_error(
test_client, normal_user_token_headers, test_user
):
response = await test_client.get(
f"/users/{{test_user.id}}",
headers=normal_user_token_headers,
)
assert response.status_code == 403
assert response.json() == {{"detail": "The user doesn't have enough privileges"}}
async def test_create_user(test_client):
username = random_email()
password = random_password()
full_name = random_lower_string()
data = {{
"email": username,
"password": password,
"fullName": full_name,
}}
response = await test_client.post(
"/users/",
json=data,
)
assert response.status_code == 200
async def test_create_user_existing_username(test_client, test_db, test_cache):
username = random_email()
password = random_password()
full_name = random_lower_string()
await user_services.create_user(
pool=test_db.db_pool,
cache_client=test_cache.client,
user=UserCreate(
email=username,
password=password,
full_name=full_name,
),
)
data = {{
"email": username,
"password": password,
"fullName": full_name,
}}
response = await test_client.post(
"/users/",
json=data,
)
created_user = response.json()
assert response.status_code == 400
assert "A user with this email address already exists" in created_user["detail"]
async def test_read_users(test_client, superuser_token_headers, test_db, test_cache):
username = random_email()
password = random_password()
full_name = random_lower_string()
username2 = random_email()
password2 = random_password()
full_name2 = random_lower_string()
await user_services.create_user(
pool=test_db.db_pool,
cache_client=test_cache.client,
user=UserCreate(
email=username,
password=password,
full_name=full_name,
),
)
await user_services.create_user(
pool=test_db.db_pool,
cache_client=test_cache.client,
user=UserCreate(
email=username2,
password=password2,
full_name=full_name2,
),
)
response = await test_client.get("/users/", headers=superuser_token_headers)
all_users = response.json()
assert len(all_users["data"]) > 1
assert "count" in all_users
for item in all_users["data"]:
assert "email" in item
assert all_users["totalUsers"] >= 2
async def test_update_user_me(test_client, normal_user_token_headers, test_db):
full_name = "Updated"
email = random_email()
data = {{"fullName": full_name, "email": email}}
response = await test_client.patch(
"/users/me",
headers=normal_user_token_headers,
json=data,
)
assert response.status_code == 200
updated_user = response.json()
assert updated_user["email"] == email
assert updated_user["fullName"] == full_name
user_db = await user_services.get_user_by_email(pool=test_db.db_pool, email=email)
assert user_db
assert user_db.email == email
assert user_db.full_name == full_name
async def test_update_password_me(test_client, superuser_token_headers, test_db):
new_password = random_password()
data = {{
"currentPassword": settings.FIRST_SUPERUSER_PASSWORD.get_secret_value(),
"newPassword": new_password,
}}
response = await test_client.patch(
"/users/me/password",
headers=superuser_token_headers,
json=data,
)
assert response.status_code == 204
user_db = await user_services.get_user_by_email(
pool=test_db.db_pool, email=settings.FIRST_SUPERUSER_EMAIL
)
assert user_db
assert user_db.email == settings.FIRST_SUPERUSER_EMAIL
assert verify_password(new_password, user_db.hashed_password)
async def test_update_password_me_incorrect_password(test_client, superuser_token_headers):
bad_password = random_password()
new_password = random_password()
data = {{"currentPassword": bad_password, "newPassword": new_password}}
response = await test_client.patch(
"/users/me/password",
headers=superuser_token_headers,
json=data,
)
assert response.status_code == 400
updated_user = response.json()
assert updated_user["detail"] == "Incorrect password"
async def test_update_user_me_email_exists(
test_client, test_db, normal_user_token_headers, test_cache
):
email = random_email()
password = random_password()
full_name = random_lower_string()
await user_services.create_user(
pool=test_db.db_pool,
cache_client=test_cache.client,
user=UserCreate(
email=email,
password=password,
full_name=full_name,
),
)
data = {{"email": email}}
response = await test_client.patch(
"/users/me",
headers=normal_user_token_headers,
json=data,
)
assert response.status_code == 409
assert response.json()["detail"] == "A user with this email address already exists"
async def test_update_password_me_same_password_error(test_client, superuser_token_headers):
data = {{
"currentPassword": settings.FIRST_SUPERUSER_PASSWORD.get_secret_value(),
"newPassword": settings.FIRST_SUPERUSER_PASSWORD.get_secret_value(),
}}
response = await test_client.patch(
"/users/me/password",
headers=superuser_token_headers,
json=data,
)
assert response.status_code == 400
updated_user = response.json()
assert updated_user["detail"] == "New password cannot be the same as the current one"
async def test_update_user(test_client, superuser_token_headers, test_db, test_user):
data = {{"fullName": "Updated_full_name"}}
response = await test_client.patch(
f"/users/{{test_user.id}}",
headers=superuser_token_headers,
json=data,
)
assert response.status_code == 200
updated_user = response.json()
assert updated_user["fullName"] == "Updated_full_name"
user_db = await user_services.get_user_by_email(pool=test_db.db_pool, email=test_user.email)
assert user_db
assert user_db.full_name == "Updated_full_name"
async def test_update_user_password(test_client, superuser_token_headers, test_user):
data = {{"password": "Test_password1"}}
response = await test_client.patch(
f"/users/{{test_user.id}}",
headers=superuser_token_headers,
json=data,
)
assert response.status_code == 200
async def test_update_user_not_exists(test_client, superuser_token_headers):
data = {{"fullName": "Updated_full_name"}}
response = await test_client.patch(
f"/users/{{str(uuid4())}}",
headers=superuser_token_headers,
json=data,
)
assert response.status_code == 404
assert response.json()["detail"] == "The user with this id does not exist in the system"
async def test_update_user_email_exists(test_client, superuser_token_headers, test_db, test_cache):
username = random_email()
password = random_password()
full_name = random_lower_string()
username2 = random_email()
password2 = random_password()
full_name_2 = random_lower_string()
user = await user_services.create_user(
pool=test_db.db_pool,
cache_client=test_cache.client,
user=UserCreate(
email=username,
password=password,
full_name=full_name,
),
)
user2 = await user_services.create_user(
pool=test_db.db_pool,
cache_client=test_cache.client,
user=UserCreate(
email=username2,
password=password2,
full_name=full_name_2,
),
)
data = {{"email": user2.email}}
response = await test_client.patch(
f"/users/{{user.id}}",
headers=superuser_token_headers,
json=data,
)
assert response.status_code == 409
assert response.json()["detail"] == "User with this email already exists"
async def test_delete_user_me(test_client, test_db, test_cache):
username = random_email()
password = random_password()
full_name = random_lower_string()
user = await user_services.create_user(
pool=test_db.db_pool,
cache_client=test_cache.client,
user=UserCreate(
email=username,
password=password,
full_name=full_name,
),
)
user_id = user.id
login_data = {{
"username": username,
"password": password,
}}
response = await test_client.post("/login/access-token", data=login_data)
tokens = response.json()
access_token = tokens["access_token"]
headers = {{"Authorization": f"Bearer {{access_token}}"}}
response = await test_client.delete(
"/users/me",
headers=headers,
)
assert response.status_code == 204
result = await user_services.get_user_by_id(
pool=test_db.db_pool, cache_client=test_cache.client, user_id=user_id
)
assert result is None
async def test_delete_user_me_as_superuser(test_client, superuser_token_headers):
response = await test_client.delete(
"/users/me",
headers=superuser_token_headers,
)
assert response.status_code == 400
response = response.json()
assert response["detail"] == "Super users are not allowed to delete themselves"
async def test_delete_user_super_user(
test_client, superuser_token_headers, test_db, test_user, test_cache
):
user_id = test_user.id
response = await test_client.delete(
f"/users/{{user_id}}",
headers=superuser_token_headers,
)
assert response.status_code == 200
deleted_user = response.json()
assert deleted_user["message"] == "User deleted successfully"
result = await user_services.get_user_by_id(
pool=test_db.db_pool, cache_client=test_cache.client, user_id=user_id
)
assert result is None
async def test_delete_user_not_found(test_client, superuser_token_headers):
response = await test_client.delete(
f"/users/{{str(uuid4())}}",
headers=superuser_token_headers,
)
assert response.status_code == 404
assert response.json()["detail"] == "User not found"
async def test_delete_user_current_super_user_error(test_client, superuser_token_headers, test_db):
super_user = await user_services.get_user_by_email(
pool=test_db.db_pool, email=settings.FIRST_SUPERUSER_EMAIL
)
assert super_user
user_id = super_user.id
response = await test_client.delete(
f"/users/{{user_id}}",
headers=superuser_token_headers,
)
assert response.status_code == 403
assert response.json()["detail"] == "Super users are not allowed to delete themselves"
async def test_delete_user_without_privileges(test_client, normal_user_token_headers, test_user):
response = await test_client.delete(
f"/users/{{test_user.id}}",
headers=normal_user_token_headers,
)
assert response.status_code == 403
assert response.json()["detail"] == "The user doesn't have enough privileges"
"#
    )
}
/// Writes the rendered user-routes test module to
/// `tests/api/routes/test_users.py` under the project's base directory.
///
/// # Errors
///
/// Returns any error produced by `save_file_with_content`.
pub fn save_user_routes_test_file(project_info: &ProjectInfo) -> Result<()> {
    let target = project_info
        .base_dir()
        .join("tests/api/routes/test_users.py");
    let rendered = create_user_routes_test_file(project_info);
    save_file_with_content(&target, &rendered)?;
    Ok(())
}
/// Renders the version-route test module.
///
/// The generated test requests `version` through the `test_client` fixture
/// and asserts a 200 response whose `version` field equals the package's
/// `__version__`. `{module}` is filled with `project_info.module_name()`.
fn create_version_route_test_file(project_info: &ProjectInfo) -> String {
    format!(
        r#"from {module} import __version__
async def test_read_version(test_client):
response = await test_client.get("version")
assert response.status_code == 200
assert response.json()["version"] == __version__
"#,
        module = project_info.module_name()
    )
}
/// Persists the rendered version-route test to
/// `tests/api/routes/test_version.py` inside the project's base directory.
///
/// # Errors
///
/// Fails when `save_file_with_content` fails.
pub fn save_version_route_test_file(project_info: &ProjectInfo) -> Result<()> {
    save_file_with_content(
        &project_info.base_dir().join("tests/api/routes/test_version.py"),
        &create_version_route_test_file(project_info),
    )?;
    Ok(())
}
/// Renders the cache-layer test module for the user services.
///
/// The generated test proceeds in three steps:
/// 1. Calls `user_services.get_users_public`.
/// 2. Asserts that at least one `users:public:*` key exists, then calls
///    `delete_all_users_public`.
/// 3. Asserts that no `users:public:*` keys remain.
///
/// `{module}` is filled with `project_info.module_name()`.
fn create_user_services_cache_test_file(project_info: &ProjectInfo) -> String {
    format!(
        r#"import pytest
from {module}.services.cache.user_cache_services import delete_all_users_public
from {module}.services.db import user_services
@pytest.mark.usefixtures("test_user")
async def test_delete_all_users_public(test_db, test_cache):
await user_services.get_users_public(pool=test_db.db_pool, cache_client=test_cache.client)
keys_before = [key async for key in test_cache.client.scan_iter("users:public:*")]
assert len(keys_before) > 0
await delete_all_users_public(cache_client=test_cache.client)
keys_after = [key async for key in test_cache.client.scan_iter("users:public:*")]
assert len(keys_after) == 0
"#,
        module = project_info.module_name()
    )
}
/// Writes the rendered cache-service test module to
/// `tests/services/cache/test_user_services.py` under the project's base directory.
///
/// # Errors
///
/// Propagates any error returned by `save_file_with_content`.
pub fn save_user_services_cache_test_file(project_info: &ProjectInfo) -> Result<()> {
    let relative = "tests/services/cache/test_user_services.py";
    let destination = project_info.base_dir().join(relative);
    let body = create_user_services_cache_test_file(project_info);
    save_file_with_content(&destination, &body)?;
    Ok(())
}
/// Renders the database-layer test module for the user services.
///
/// The generated module contains three tests:
/// - calls `get_users_public` twice and asserts both results are equal;
/// - looks up the `test_user` fixture by email via `get_user_public_by_email`;
/// - asserts that a lookup with a random email returns `None`.
///
/// `{module}` is filled with `project_info.module_name()`.
fn create_user_services_db_test_file(project_info: &ProjectInfo) -> String {
    format!(
        r#"import pytest
from {module}.services.db.user_services import get_user_public_by_email, get_users_public
from tests.utils import random_email
@pytest.mark.usefixtures("test_user")
async def test_get_users_public_cache(test_db, test_cache):
result = await get_users_public(pool=test_db.db_pool, cache_client=test_cache.client)
# retrieve again to hit cache
result_cache = await get_users_public(pool=test_db.db_pool, cache_client=test_cache.client)
assert result == result_cache
async def test_get_user_public_by_email(test_db, test_user):
result = await get_user_public_by_email(pool=test_db.db_pool, email=test_user.email)
assert result is not None
assert result.email == test_user.email
async def test_get_user_public_by_email_not_found(test_db):
result = await get_user_public_by_email(pool=test_db.db_pool, email=random_email())
assert result is None
"#,
        module = project_info.module_name()
    )
}
/// Writes the rendered database-service test module to
/// `tests/services/db/test_user_services.py` under the project's base directory.
///
/// # Errors
///
/// Propagates any error returned by `save_file_with_content`.
pub fn save_user_services_db_test_file(project_info: &ProjectInfo) -> Result<()> {
    let relative = "tests/services/db/test_user_services.py";
    let destination = project_info.base_dir().join(relative);
    let body = create_user_services_db_test_file(project_info);
    save_file_with_content(&destination, &body)?;
    Ok(())
}
/// Renders the top-level `main` test module for the generated project.
///
/// The generated module contains two tests:
/// - `test_http_exception_handler` patches
///   `{module}.services.db.user_services.get_user_by_id` to raise. It then
///   calls `users/me` and asserts a 500 response whose error text reaches
///   pytest's `caplog` through a temporary loguru sink.
/// - `test_cors_middleware_added` patches `settings.all_cors_origins` and
///   reloads `main`. It then checks the CORS preflight response for
///   `https://example.com`.
///
/// `{module}` placeholders are filled from `project_info.module_name()`.
/// Doubled braces (`{{` / `}}`) render as literal braces in the Python output.
fn create_main_test_file(project_info: &ProjectInfo) -> String {
    let module = &project_info.module_name();
    // The loguru sink bound to `caplog.handler` is removed in a `finally`
    // after the caplog assertions. Otherwise it would stay registered on the
    // global loguru logger after this test's `caplog` fixture is torn down
    // and keep receiving ERROR records from every later test.
    format!(
        r#"import importlib
from unittest.mock import patch
from fastapi.testclient import TestClient
from loguru import logger
from {module} import main
from {module}.core.config import settings
async def test_http_exception_handler(test_client, normal_user_token_headers, caplog):
handler_id = logger.add(caplog.handler, level="ERROR", format="{{message}}")
try:
with patch(
"{module}.services.db.user_services.get_user_by_id",
side_effect=Exception("Server crashed"),
):
response = await test_client.get("users/me", headers=normal_user_token_headers)
assert response.status_code == 500
assert "Server crashed" in caplog.text
finally:
logger.remove(handler_id)
def test_cors_middleware_added(test_client):
with patch.object(
type(settings),
"all_cors_origins",
new=property(lambda _: ["https://example.com"]),
):
importlib.reload(main)
app = main.app
client = TestClient(app)
resp = client.options(
"/",
headers={{
"Origin": "https://example.com",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "Authorization, Content-Type",
}},
)
assert resp.status_code == 200
assert resp.headers.get("access-control-allow-origin") == "https://example.com"
"#
    )
}
/// Persists the rendered `main` test module to `tests/test_main.py` inside
/// the project's base directory.
///
/// # Errors
///
/// Fails when `save_file_with_content` fails.
pub fn save_main_test_file(project_info: &ProjectInfo) -> Result<()> {
    save_file_with_content(
        &project_info.base_dir().join("tests/test_main.py"),
        &create_main_test_file(project_info),
    )?;
    Ok(())
}