use std::fmt::Write;
/// One attribute of a generated model.
///
/// `field_type` is spliced verbatim into the generated Python as a type
/// annotation (e.g. `"str"`, `"UUID"`). `default`, when present, is spliced
/// verbatim as a Python expression, so string defaults must carry their own
/// quotes (e.g. `"\"pending\""`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Field {
    /// Python attribute / column name.
    pub name: String,
    /// Python type annotation for the attribute.
    pub field_type: String,
    /// Whether the attribute is wrapped in `Optional[...]` / the column is nullable.
    pub optional: bool,
    /// Python expression used as the default value, if any.
    pub default: Option<String>,
}
/// One HTTP route for which pytest cases are generated.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Endpoint {
    /// HTTP verb such as `"GET"`; lower-cased to select the `httpx.AsyncClient` method.
    pub method: String,
    /// Route path; `{param}` segments mark path parameters.
    pub path: String,
    /// Handler name; snake-cased into the generated test function names.
    pub handler: String,
    /// Status code asserted by the generated happy-path test.
    pub status_code: u16,
}
/// Stateless namespace for the Python/FastAPI source generators.
///
/// All generators are associated functions, so no instance is required; the
/// derives exist so the type can still be stored, copied and debug-printed.
#[derive(Debug, Clone, Copy, Default)]
pub struct PythonGenerator;
impl PythonGenerator {
    /// Generates the FastAPI application module (`<snake>_app.py`) for a service.
    ///
    /// The module defines a `lifespan` context manager, builds the `FastAPI`
    /// app, mounts `<snake>_router` under `/api/v1`, exposes `GET /health`, and
    /// installs a catch-all exception handler that logs and returns HTTP 500.
    /// `service_name` and `description` are escaped before being embedded in
    /// Python string literals, so quotes, backslashes and newlines in them
    /// cannot break the generated source.
    ///
    /// # Errors
    /// Returns the formatter's error message if writing fails; writing to an
    /// in-memory `String` does not fail in practice.
    pub fn generate_fastapi_service(
        service_name: &str,
        port: u16,
        description: &str,
    ) -> Result<String, String> {
        let snake = to_snake(service_name);
        let name = py_escape(service_name);
        let desc = py_escape(description);
        render(|out| {
            writeln!(out, "\"\"\"")?;
            writeln!(out, "{name} — {desc}")?;
            writeln!(out, "\nAuto-generated by ggen. Edit with care.")?;
            writeln!(out, "\"\"\"\n")?;
            writeln!(out, "from contextlib import asynccontextmanager")?;
            writeln!(out, "import logging\n")?;
            writeln!(out, "from fastapi import FastAPI, HTTPException, status")?;
            writeln!(out, "from fastapi.responses import JSONResponse")?;
            writeln!(out, "from {snake}_router import router as {snake}_router\n")?;
            writeln!(out, "logger = logging.getLogger(__name__)\n")?;
            writeln!(out, "\n@asynccontextmanager")?;
            writeln!(out, "async def lifespan(app: FastAPI):")?;
            writeln!(out, "    logger.info(\"{name} starting on port {port}\")")?;
            writeln!(out, "    # Initialise DB pool, cache, OTEL exporter here")?;
            writeln!(out, "    yield")?;
            writeln!(out, "    # Armstrong: cleanup — let downstream errors propagate")?;
            writeln!(out, "    logger.info(\"{name} shutdown complete\")\n")?;
            writeln!(out, "\napp = FastAPI(")?;
            writeln!(out, "    title=\"{name}\",")?;
            writeln!(out, "    description=\"{desc}\",")?;
            writeln!(out, "    version=\"1.0.0\",")?;
            writeln!(out, "    lifespan=lifespan,")?;
            writeln!(out, ")\n")?;
            writeln!(out, "app.include_router({snake}_router, prefix=\"/api/v1\")\n")?;
            writeln!(out, "\n@app.get(\"/health\", tags=[\"ops\"])")?;
            writeln!(out, "async def health():")?;
            writeln!(
                out,
                "    \"\"\"Liveness probe — returns 200 when the process is running.\"\"\""
            )?;
            writeln!(out, "    return {{\"status\": \"ok\", \"service\": \"{name}\"}}\n")?;
            writeln!(out, "\n@app.exception_handler(Exception)")?;
            writeln!(out, "async def unhandled_exception_handler(request, exc: Exception):")?;
            writeln!(out, "    logger.error(\"Unhandled exception: %s\", exc, exc_info=True)")?;
            writeln!(out, "    return JSONResponse(")?;
            writeln!(out, "        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,")?;
            writeln!(out, "        content={{\"detail\": \"Internal server error\"}},")?;
            writeln!(out, "    )")?;
            Ok(())
        })
    }

    /// Generates a Pydantic v2 model class named `model_name`.
    ///
    /// Optional fields are annotated `Optional[T]` and default to their
    /// `default` expression or `None`. Required fields only get a default when
    /// one is given. A single `model_config` enables `from_attributes` and
    /// carries a `json_schema_extra` example (one entry per field, taken from
    /// `example_value_for_type`). An `after` model validator stub closes the class.
    ///
    /// # Errors
    /// Returns the formatter's error message if writing fails (not expected).
    pub fn generate_pydantic_model(model_name: &str, fields: &[Field]) -> Result<String, String> {
        render(|out| {
            writeln!(out, "\"\"\"")?;
            writeln!(out, "Pydantic v2 model for {model_name}.")?;
            writeln!(out, "\nAuto-generated by ggen. Edit with care.")?;
            writeln!(out, "\"\"\"\n")?;
            writeln!(out, "from __future__ import annotations\n")?;
            writeln!(out, "from datetime import datetime")?;
            writeln!(out, "from typing import Optional")?;
            writeln!(out, "from uuid import UUID\n")?;
            writeln!(out, "from pydantic import BaseModel, Field, model_validator\n")?;
            writeln!(out, "\nclass {model_name}(BaseModel):")?;
            // Emit `model_config` exactly once; a second class-level assignment
            // would silently replace the first.
            writeln!(out, "    model_config = {{")?;
            writeln!(out, "        \"from_attributes\": True,")?;
            writeln!(out, "        \"json_schema_extra\": {{")?;
            writeln!(out, "            \"example\": {{")?;
            for field in fields {
                let example = example_value_for_type(&field.field_type, field.optional);
                writeln!(out, "                \"{}\": {},", field.name, example)?;
            }
            writeln!(out, "            }}")?;
            writeln!(out, "        }}")?;
            writeln!(out, "    }}\n")?;
            for field in fields {
                let annotation = if field.optional {
                    format!("Optional[{}]", field.field_type)
                } else {
                    field.field_type.clone()
                };
                let default = match (&field.default, field.optional) {
                    (Some(d), _) => format!(" = Field(default={d})"),
                    (None, true) => String::from(" = Field(default=None)"),
                    (None, false) => String::new(),
                };
                writeln!(out, "    {}: {}{}", field.name, annotation, default)?;
            }
            writeln!(out)?;
            writeln!(out, "    @model_validator(mode=\"after\")")?;
            writeln!(out, "    def validate_model(self) -> \"{model_name}\":")?;
            writeln!(
                out,
                "        \"\"\"Cross-field validation hook — add domain invariants here.\"\"\""
            )?;
            writeln!(out, "        return self")?;
            Ok(())
        })
    }

    /// Generates a SQLAlchemy 2.0 ORM mapping plus an async CRUD repository.
    ///
    /// The table is named `<snake_model>s`. A field named `id` becomes the
    /// primary key; when no such field is supplied, an
    /// `id: Mapped[UUID]` primary key defaulting to `uuid4` is synthesised,
    /// because SQLAlchemy cannot map a class without a primary key and every
    /// repository query filters on `.id`. Optional fields are mapped as
    /// `Mapped[Optional[T]]` with `nullable=True`. Every query in the repository
    /// carries `execution_options(timeout=30)`.
    ///
    /// # Errors
    /// Returns the formatter's error message if writing fails (not expected).
    pub fn generate_sqlalchemy_repository(
        model_name: &str,
        fields: &[Field],
    ) -> Result<String, String> {
        let snake_model = to_snake(model_name);
        let table_name = format!("{snake_model}s");
        let has_id = fields.iter().any(|f| f.name == "id");
        let uuid_names = if has_id { "UUID" } else { "UUID, uuid4" };
        render(|out| {
            writeln!(out, "\"\"\"")?;
            writeln!(out, "SQLAlchemy 2.0 async repository for {model_name}.")?;
            writeln!(out, "\nWvdA: all queries use execution_options(timeout=30).")?;
            writeln!(out, "\nAuto-generated by ggen. Edit with care.")?;
            writeln!(out, "\"\"\"\n")?;
            writeln!(out, "from __future__ import annotations\n")?;
            // `datetime` must be importable so `Mapped[datetime]` annotations resolve.
            writeln!(out, "from datetime import datetime")?;
            writeln!(out, "from typing import List, Optional")?;
            writeln!(out, "from uuid import {uuid_names}\n")?;
            writeln!(
                out,
                "from sqlalchemy import Column, String, Integer, Float, Boolean, DateTime, select, update, delete"
            )?;
            writeln!(out, "from sqlalchemy.dialects.postgresql import UUID as PG_UUID")?;
            writeln!(out, "from sqlalchemy.ext.asyncio import AsyncSession")?;
            writeln!(out, "from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column\n")?;
            writeln!(out, "\nclass Base(DeclarativeBase):")?;
            writeln!(out, "    pass\n")?;
            writeln!(out, "\nclass {model_name}ORM(Base):")?;
            writeln!(out, "    \"\"\"SQLAlchemy ORM mapping for {model_name}.\"\"\"")?;
            writeln!(out, "\n    __tablename__ = \"{table_name}\"\n")?;
            if !has_id {
                writeln!(
                    out,
                    "    id: Mapped[UUID] = mapped_column(PG_UUID(as_uuid=True), primary_key=True, default=uuid4)"
                )?;
            }
            for field in fields {
                let sa_type = sqlalchemy_type_for(&field.field_type);
                // The primary key is never nullable, even if flagged optional.
                let (mapped, extra): (String, &str) = if field.name == "id" {
                    (field.field_type.clone(), ", primary_key=True")
                } else if field.optional {
                    (format!("Optional[{}]", field.field_type), ", nullable=True")
                } else {
                    (field.field_type.clone(), "")
                };
                writeln!(
                    out,
                    "    {}: Mapped[{}] = mapped_column({}{})",
                    field.name, mapped, sa_type, extra
                )?;
            }
            writeln!(out)?;
            writeln!(out, "\nclass {model_name}Repository:")?;
            writeln!(out, "    \"\"\"Async CRUD repository for {model_name}.\"\"\"\n")?;
            writeln!(out, "    def __init__(self, session: AsyncSession) -> None:")?;
            writeln!(out, "        self._session = session\n")?;

            writeln!(
                out,
                "    async def create(self, {snake_model}: {model_name}ORM) -> {model_name}ORM:"
            )?;
            writeln!(
                out,
                "        \"\"\"Persist a new {model_name} and return the managed instance.\"\"\""
            )?;
            writeln!(out, "        self._session.add({snake_model})")?;
            writeln!(out, "        await self._session.flush()")?;
            writeln!(out, "        await self._session.refresh({snake_model})")?;
            writeln!(out, "        return {snake_model}\n")?;

            writeln!(
                out,
                "    async def get_by_id(self, record_id: UUID) -> Optional[{model_name}ORM]:"
            )?;
            writeln!(
                out,
                "        \"\"\"Retrieve a {model_name} by primary key. Returns None if not found.\"\"\""
            )?;
            writeln!(out, "        # WvdA: execution_options(timeout=30) prevents query deadlock")?;
            writeln!(out, "        stmt = (")?;
            writeln!(out, "            select({model_name}ORM)")?;
            writeln!(out, "            .where({model_name}ORM.id == record_id)")?;
            writeln!(out, "            .execution_options(timeout=30)")?;
            writeln!(out, "        )")?;
            writeln!(out, "        result = await self._session.execute(stmt)")?;
            writeln!(out, "        return result.scalars().first()\n")?;

            writeln!(
                out,
                "    async def list_all(self, limit: int = 100) -> List[{model_name}ORM]:"
            )?;
            writeln!(
                out,
                "        \"\"\"Return up to *limit* records (WvdA boundedness: default 100).\"\"\""
            )?;
            writeln!(out, "        stmt = (")?;
            writeln!(out, "            select({model_name}ORM)")?;
            writeln!(out, "            .limit(limit)")?;
            writeln!(out, "            .execution_options(timeout=30)")?;
            writeln!(out, "        )")?;
            writeln!(out, "        result = await self._session.execute(stmt)")?;
            writeln!(out, "        return list(result.scalars().all())\n")?;

            writeln!(
                out,
                "    async def update(self, record_id: UUID, values: dict) -> Optional[{model_name}ORM]:"
            )?;
            writeln!(out, "        \"\"\"Update fields on a {model_name} by primary key.\"\"\"")?;
            writeln!(out, "        stmt = (")?;
            writeln!(out, "            update({model_name}ORM)")?;
            writeln!(out, "            .where({model_name}ORM.id == record_id)")?;
            writeln!(out, "            .values(**values)")?;
            writeln!(out, "            .execution_options(timeout=30)")?;
            writeln!(out, "        )")?;
            writeln!(out, "        await self._session.execute(stmt)")?;
            writeln!(out, "        return await self.get_by_id(record_id)\n")?;

            writeln!(out, "    async def delete(self, record_id: UUID) -> bool:")?;
            writeln!(
                out,
                "        \"\"\"Delete a {model_name} by primary key. Returns True if a row was deleted.\"\"\""
            )?;
            writeln!(out, "        stmt = (")?;
            writeln!(out, "            delete({model_name}ORM)")?;
            writeln!(out, "            .where({model_name}ORM.id == record_id)")?;
            writeln!(out, "            .execution_options(timeout=30)")?;
            writeln!(out, "        )")?;
            writeln!(out, "        result = await self._session.execute(stmt)")?;
            writeln!(out, "        return result.rowcount > 0")?;
            Ok(())
        })
    }

    /// Generates a pytest module exercising the service through `httpx.AsyncClient`.
    ///
    /// Always emits an `async_client` fixture and a `/health` test. For every
    /// endpoint it emits a happy-path test that requests `path` verbatim. For
    /// endpoints with `{param}` segments it also emits a not-found test that
    /// requests the same route with every parameter replaced by the nil UUID.
    ///
    /// # Errors
    /// Returns the formatter's error message if writing fails (not expected).
    pub fn generate_pytest_tests(
        service_name: &str,
        endpoints: &[Endpoint],
    ) -> Result<String, String> {
        let snake = to_snake(service_name);
        let name = py_escape(service_name);
        render(|out| {
            writeln!(out, "\"\"\"")?;
            writeln!(out, "pytest test suite for {service_name}.")?;
            writeln!(out, "\nChicago TDD: behaviour verification via real FastAPI TestClient.")?;
            writeln!(out, "\nAuto-generated by ggen. Edit with care.")?;
            writeln!(out, "\"\"\"\n")?;
            writeln!(out, "from __future__ import annotations\n")?;
            writeln!(out, "import pytest")?;
            writeln!(out, "import pytest_asyncio")?;
            writeln!(out, "from httpx import AsyncClient, ASGITransport\n")?;
            writeln!(out, "from {snake}_app import app  # FastAPI application instance\n")?;
            writeln!(out, "\n@pytest_asyncio.fixture")?;
            writeln!(out, "async def async_client() -> AsyncClient:  # type: ignore[override]")?;
            writeln!(
                out,
                "    \"\"\"Fixture — yields an HTTPX async client wired to the FastAPI app.\"\"\""
            )?;
            writeln!(out, "    async with AsyncClient(")?;
            writeln!(out, "        transport=ASGITransport(app=app), base_url=\"http://test\"")?;
            writeln!(out, "    ) as client:")?;
            writeln!(out, "        yield client\n")?;
            writeln!(out, "\n@pytest.mark.asyncio")?;
            writeln!(out, "async def test_health_check(async_client: AsyncClient) -> None:")?;
            writeln!(out, "    \"\"\"Liveness probe returns 200 with service name.\"\"\"")?;
            writeln!(out, "    response = await async_client.get(\"/health\")")?;
            writeln!(out, "    assert response.status_code == 200")?;
            writeln!(out, "    assert response.json()[\"status\"] == \"ok\"")?;
            writeln!(out, "    assert response.json()[\"service\"] == \"{name}\"\n")?;
            for ep in endpoints {
                let method = ep.method.to_lowercase();
                let handler = to_snake(&ep.handler);
                // The path is passed as a format *argument*, so its braces must
                // not be doubled; it is only escaped for the Python literal.
                let path = py_escape(&ep.path);
                writeln!(out, "\n@pytest.mark.asyncio")?;
                writeln!(
                    out,
                    "async def test_{handler}_happy_path(async_client: AsyncClient) -> None:"
                )?;
                writeln!(
                    out,
                    "    \"\"\"Happy path: {} {} returns {}.\"\"\"",
                    ep.method, ep.path, ep.status_code
                )?;
                writeln!(out, "    # Provide valid payload / path params")?;
                writeln!(out, "    response = await async_client.{method}(\"{path}\")")?;
                writeln!(out, "    assert response.status_code == {}\n", ep.status_code)?;
                if ep.path.contains('{') {
                    // Hit the endpoint's own route with an id that cannot exist.
                    let missing = py_escape(&fill_path_params(&ep.path, NIL_UUID));
                    writeln!(out, "\n@pytest.mark.asyncio")?;
                    writeln!(
                        out,
                        "async def test_{handler}_not_found(async_client: AsyncClient) -> None:"
                    )?;
                    writeln!(out, "    \"\"\"Error path: non-existent resource returns 404.\"\"\"")?;
                    writeln!(out, "    response = await async_client.{method}(\"{missing}\")")?;
                    writeln!(out, "    assert response.status_code == 404\n")?;
                }
            }
            Ok(())
        })
    }

    /// Generates a pinned `requirements.txt`.
    ///
    /// The core FastAPI/pydantic/httpx group is always present. The
    /// `"postgres"`, `"redis"`, `"otel"` and `"test"` entries in `features` each
    /// add their own group; other feature names are ignored.
    ///
    /// # Errors
    /// Returns the formatter's error message if writing fails (not expected).
    pub fn generate_requirements_txt(
        service_name: &str,
        features: &[&str],
    ) -> Result<String, String> {
        render(|out| {
            writeln!(out, "# requirements.txt for {service_name}")?;
            writeln!(out, "# Auto-generated by ggen. Pin versions before production deploy.")?;
            writeln!(out, "# WvdA: explicit versions prevent unbounded dependency drift.\n")?;
            write_requirement_group(
                out,
                "# Core framework",
                &[
                    "fastapi==0.115.0",
                    "uvicorn[standard]==0.30.6",
                    "pydantic==2.9.2",
                    "pydantic-settings==2.5.2",
                    "httpx==0.27.2",
                ],
            )?;
            if features.contains(&"postgres") {
                write_requirement_group(
                    out,
                    "# PostgreSQL / SQLAlchemy",
                    &["sqlalchemy==2.0.36", "asyncpg==0.30.0", "alembic==1.13.3"],
                )?;
            }
            if features.contains(&"redis") {
                write_requirement_group(out, "# Redis cache", &["redis==5.2.0", "hiredis==3.0.0"])?;
            }
            if features.contains(&"otel") {
                write_requirement_group(
                    out,
                    "# OpenTelemetry",
                    &[
                        "opentelemetry-api==1.27.0",
                        "opentelemetry-sdk==1.27.0",
                        "opentelemetry-exporter-otlp==1.27.0",
                        "opentelemetry-instrumentation-fastapi==0.48b0",
                    ],
                )?;
            }
            if features.contains(&"test") {
                write_requirement_group(
                    out,
                    "# Testing",
                    &["pytest==8.3.3", "pytest-asyncio==0.24.0", "anyio==4.6.2"],
                )?;
            }
            Ok(())
        })
    }

    /// Generates `main.py`, which runs `<snake>_app:app` under uvicorn on `port`.
    ///
    /// SIGINT/SIGTERM handlers set `should_exit` on the running server. The
    /// module starts with `from __future__ import annotations` so the
    /// `uvicorn.Server | None` annotation is not evaluated at import time,
    /// which would otherwise require Python 3.10+.
    ///
    /// # Errors
    /// Returns the formatter's error message if writing fails (not expected).
    pub fn generate_main(service_name: &str, port: u16) -> Result<String, String> {
        let snake = to_snake(service_name);
        let name = py_escape(service_name);
        render(|out| {
            writeln!(out, "\"\"\"")?;
            writeln!(out, "main.py — entry point for {service_name}.")?;
            writeln!(out, "\nArmstrong: lifespan context manager handles startup/shutdown.")?;
            writeln!(out, "Let-it-crash: any startup failure exits with code 1.")?;
            writeln!(out, "\nAuto-generated by ggen. Edit with care.")?;
            writeln!(out, "\"\"\"\n")?;
            writeln!(out, "from __future__ import annotations\n")?;
            writeln!(out, "import logging")?;
            writeln!(out, "import signal")?;
            writeln!(out, "import sys\n")?;
            writeln!(out, "import uvicorn\n")?;
            writeln!(out, "from {snake}_app import app  # noqa: F401\n")?;
            writeln!(out, "logging.basicConfig(")?;
            writeln!(out, "    level=logging.INFO,")?;
            writeln!(out, "    format=\"%(asctime)s [%(levelname)s] %(name)s: %(message)s\",")?;
            writeln!(out, ")")?;
            writeln!(out, "logger = logging.getLogger(\"{snake}\")\n")?;
            writeln!(out, "# Armstrong: graceful shutdown on SIGINT / SIGTERM")?;
            writeln!(out, "_server: uvicorn.Server | None = None\n")?;
            writeln!(out, "\ndef _handle_signal(signum: int, _frame) -> None:")?;
            writeln!(out, "    logger.info(\"{name} received signal %d, shutting down\", signum)")?;
            writeln!(out, "    if _server:")?;
            writeln!(out, "        _server.should_exit = True\n")?;
            writeln!(out, "\ndef main() -> None:")?;
            writeln!(out, "    global _server")?;
            writeln!(out, "    signal.signal(signal.SIGINT, _handle_signal)")?;
            writeln!(out, "    signal.signal(signal.SIGTERM, _handle_signal)\n")?;
            writeln!(out, "    config = uvicorn.Config(")?;
            writeln!(out, "        app=\"{snake}_app:app\",")?;
            writeln!(out, "        host=\"0.0.0.0\",")?;
            writeln!(out, "        port={port},")?;
            writeln!(out, "        log_level=\"info\",")?;
            writeln!(out, "        access_log=True,")?;
            writeln!(out, "    )")?;
            writeln!(out, "    _server = uvicorn.Server(config)")?;
            writeln!(out, "    logger.info(\"{name} starting on port {port}\")")?;
            writeln!(out, "    _server.run()")?;
            writeln!(out, "    logger.info(\"{name} stopped\")\n")?;
            writeln!(out, "\nif __name__ == \"__main__\":")?;
            writeln!(out, "    main()")?;
            Ok(())
        })
    }
}

/// Value substituted for path parameters in generated "not found" tests
/// (the all-zero, nil UUID).
const NIL_UUID: &str = "00000000-0000-0000-0000-000000000000";

/// Runs `build` against a fresh buffer and returns the text it produced.
///
/// Lets every generator use `?` on `writeln!` directly and converts the
/// (in practice impossible) `fmt::Error` into the public `String` error once.
fn render<F>(build: F) -> Result<String, String>
where
    F: FnOnce(&mut String) -> std::fmt::Result,
{
    let mut out = String::new();
    build(&mut out).map_err(|e| e.to_string())?;
    Ok(out)
}

/// Writes one `requirements.txt` section: a header comment, one pin per line,
/// then a blank separator line.
fn write_requirement_group(out: &mut String, header: &str, pins: &[&str]) -> std::fmt::Result {
    writeln!(out, "{header}")?;
    for pin in pins {
        writeln!(out, "{pin}")?;
    }
    writeln!(out)
}

/// Escapes `s` for use inside a double-quoted Python string literal
/// (backslash, double quote, and `\n`/`\r`/`\t` control characters).
fn py_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            _ => escaped.push(c),
        }
    }
    escaped
}

/// Replaces every `{param}` segment of a route path with `value`.
///
/// An unmatched `{` (no closing `}`) and everything after it is copied unchanged.
fn fill_path_params(path: &str, value: &str) -> String {
    let mut out = String::with_capacity(path.len());
    let mut rest = path;
    while let Some(start) = rest.find('{') {
        match rest[start..].find('}') {
            Some(len) => {
                out.push_str(&rest[..start]);
                out.push_str(value);
                rest = &rest[start + len + 1..];
            }
            None => break,
        }
    }
    out.push_str(rest);
    out
}

/// Converts PascalCase/camelCase to snake_case by inserting `_` before every
/// uppercase character except the first and lower-casing it.
///
/// Consecutive capitals are split individually (`"HTTPServer"` becomes
/// `"h_t_t_p_server"`). Non-ASCII uppercase letters are lower-cased too.
fn to_snake(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 4);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i != 0 {
                out.push('_');
            }
            out.extend(c.to_lowercase());
        } else {
            out.push(c);
        }
    }
    out
}

/// Maps a Python type annotation to the SQLAlchemy column type expression used
/// in `mapped_column(...)`. Unknown annotations fall back to `String`.
///
/// Every returned name is imported by the module that
/// `generate_sqlalchemy_repository` emits.
fn sqlalchemy_type_for(py_type: &str) -> &'static str {
    match py_type {
        "str" => "String",
        "int" => "Integer",
        "bool" => "Boolean",
        "float" => "Float",
        "datetime" => "DateTime",
        "UUID" => "PG_UUID(as_uuid=True)",
        _ => "String",
    }
}
/// Returns a sample value, as Python source text, for a field of type `py_type`.
///
/// The result is spliced into the `json_schema_extra` example *Python dict
/// literal* of the generated Pydantic model, so it must be a Python
/// expression. JSON spellings such as `null`/`true` would raise `NameError`
/// when the class body runs. Optional fields always use `None`. Unknown types
/// fall back to the string `"example"`.
fn example_value_for_type(py_type: &str, optional: bool) -> &'static str {
    if optional {
        return "None";
    }
    match py_type {
        "str" => "\"example\"",
        "int" => "0",
        "bool" => "True",
        "float" => "0.0",
        "datetime" => "\"2024-01-01T00:00:00Z\"",
        "UUID" => "\"00000000-0000-0000-0000-000000000000\"",
        _ => "\"example\"",
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`Field`] fixture from borrowed parts.
    fn field(name: &str, field_type: &str, optional: bool, default: Option<&str>) -> Field {
        Field {
            name: name.to_string(),
            field_type: field_type.to_string(),
            optional,
            default: default.map(|d| d.to_string()),
        }
    }

    /// Builds an [`Endpoint`] fixture from borrowed parts.
    fn endpoint(method: &str, path: &str, handler: &str, status_code: u16) -> Endpoint {
        Endpoint {
            method: method.to_string(),
            path: path.to_string(),
            handler: handler.to_string(),
            status_code,
        }
    }

    /// `Order` model fixture: two required UUIDs, a defaulted string, an optional float.
    fn order_fields() -> Vec<Field> {
        vec![
            field("id", "UUID", false, None),
            field("customer_id", "UUID", false, None),
            field("status", "str", false, Some("\"pending\"")),
            field("total", "float", true, None),
        ]
    }

    /// Order routes fixture: list/create on the collection, get/delete on one item.
    fn order_endpoints() -> Vec<Endpoint> {
        vec![
            endpoint("GET", "/api/v1/orders", "list_orders", 200),
            endpoint("POST", "/api/v1/orders", "create_order", 201),
            endpoint("GET", "/api/v1/orders/{order_id}", "get_order", 200),
            endpoint("DELETE", "/api/v1/orders/{order_id}", "delete_order", 204),
        ]
    }

    /// Unwraps a generator result, failing the test with `context` on `Err`.
    fn expect_code(result: Result<String, String>, context: &str) -> String {
        match result {
            Ok(code) => code,
            Err(err) => panic!("{}: {}", context, err),
        }
    }

    /// Asserts that every `(needle, message)` needle occurs in `code`.
    fn assert_contains_all(code: &str, checks: &[(&str, &str)]) {
        for &(needle, message) in checks {
            assert!(code.contains(needle), "{}", message);
        }
    }

    #[test]
    fn test_fastapi_service_contains_health_endpoint() {
        let code = expect_code(
            PythonGenerator::generate_fastapi_service("OrderService", 8001, "Manages orders"),
            "generate_fastapi_service should succeed",
        );
        assert_contains_all(
            &code,
            &[
                ("@app.get(\"/health\"", "must contain @app.get(\"/health\")"),
                ("\"status\": \"ok\"", "health response must include status: ok"),
                ("\"service\": \"OrderService\"", "health response must include service name"),
            ],
        );
    }

    #[test]
    fn test_fastapi_service_uses_lifespan() {
        let code = expect_code(
            PythonGenerator::generate_fastapi_service("OrderService", 8001, "Manages orders"),
            "generate_fastapi_service should succeed",
        );
        assert_contains_all(
            &code,
            &[
                ("@asynccontextmanager", "must use @asynccontextmanager"),
                ("async def lifespan(app: FastAPI):", "must define lifespan function"),
                ("lifespan=lifespan", "FastAPI() must receive lifespan="),
                ("yield", "lifespan must yield"),
            ],
        );
    }

    #[test]
    fn test_fastapi_service_non_empty_with_service_name() {
        let code = expect_code(
            PythonGenerator::generate_fastapi_service("PaymentService", 9001, "Handles payments"),
            "generate_fastapi_service should succeed",
        );
        assert!(!code.is_empty(), "generated code must not be empty");
        assert_contains_all(
            &code,
            &[
                ("PaymentService", "generated code must reference service name"),
                ("9001", "generated code must reference the port"),
            ],
        );
    }

    #[test]
    fn test_pydantic_model_required_and_optional_fields() {
        let code = expect_code(
            PythonGenerator::generate_pydantic_model("Order", &order_fields()),
            "generate_pydantic_model should succeed",
        );
        assert_contains_all(
            &code,
            &[
                ("id: UUID", "must emit required UUID field"),
                ("customer_id: UUID", "must emit customer_id"),
                ("Optional[float]", "optional float must be Optional[float]"),
                ("from_attributes", "must include from_attributes ORM mode"),
                ("@model_validator", "must include @model_validator stub"),
            ],
        );
    }

    #[test]
    fn test_sqlalchemy_repository_has_crud_methods() {
        let code = expect_code(
            PythonGenerator::generate_sqlalchemy_repository("Order", &order_fields()),
            "generate_sqlalchemy_repository should succeed",
        );
        for method in ["create", "get_by_id", "list_all", "update", "delete"] {
            let signature = format!("async def {}(", method);
            assert!(code.contains(&signature), "must have {} method", method);
        }
    }

    #[test]
    fn test_sqlalchemy_repository_uses_execution_timeout() {
        let code = expect_code(
            PythonGenerator::generate_sqlalchemy_repository("Order", &order_fields()),
            "generate_sqlalchemy_repository should succeed",
        );
        let timeout_count = code.matches("execution_options(timeout=30)").count();
        assert!(
            timeout_count >= 3,
            "at least 3 queries must use execution_options(timeout=30), found {}",
            timeout_count
        );
    }

    #[test]
    fn test_pytest_file_has_fixtures_and_health_test() {
        let code = expect_code(
            PythonGenerator::generate_pytest_tests("OrderService", &order_endpoints()),
            "generate_pytest_tests should succeed",
        );
        assert_contains_all(
            &code,
            &[
                ("@pytest_asyncio.fixture", "must have async_client fixture"),
                ("async def async_client(", "fixture must define async_client"),
                ("test_health_check", "must include health check test"),
                ("@pytest.mark.asyncio", "tests must be marked asyncio"),
            ],
        );
    }

    #[test]
    fn test_requirements_txt_contains_fastapi_and_pydantic() {
        let code = expect_code(
            PythonGenerator::generate_requirements_txt("OrderService", &["postgres", "test"]),
            "generate_requirements_txt should succeed",
        );
        assert_contains_all(
            &code,
            &[
                ("fastapi==", "must pin fastapi"),
                ("pydantic==", "must pin pydantic"),
                ("sqlalchemy==", "postgres feature must add sqlalchemy"),
                ("pytest==", "test feature must add pytest"),
            ],
        );
    }

    #[test]
    fn test_requirements_txt_no_optional_features() {
        let code = expect_code(
            PythonGenerator::generate_requirements_txt("MinimalService", &[]),
            "generate_requirements_txt should succeed",
        );
        assert_contains_all(
            &code,
            &[
                ("fastapi==", "must always contain fastapi"),
                ("pydantic==", "must always contain pydantic"),
            ],
        );
        assert!(
            !code.contains("sqlalchemy"),
            "sqlalchemy must NOT appear without postgres feature"
        );
        assert!(!code.contains("pytest"), "pytest must NOT appear without test feature");
    }

    #[test]
    fn test_main_has_uvicorn_and_signal_handling() {
        let code = expect_code(
            PythonGenerator::generate_main("OrderService", 8001),
            "generate_main should succeed",
        );
        assert_contains_all(
            &code,
            &[
                ("uvicorn", "must use uvicorn"),
                ("signal.SIGINT", "must handle SIGINT"),
                ("signal.SIGTERM", "must handle SIGTERM"),
                ("8001", "must reference the configured port"),
                ("if __name__ == \"__main__\":", "must have __main__ guard"),
            ],
        );
    }

    #[test]
    fn test_pytest_generates_per_endpoint_tests() {
        let code = expect_code(
            PythonGenerator::generate_pytest_tests("OrderService", &order_endpoints()),
            "generate_pytest_tests should succeed",
        );
        assert_contains_all(
            &code,
            &[
                ("test_list_orders_happy_path", "must have test for list_orders"),
                ("test_create_order_happy_path", "must have test for create_order"),
                ("test_get_order_happy_path", "must have test for get_order"),
                ("test_get_order_not_found", "must have not-found test for get_order"),
                ("test_delete_order_not_found", "must have not-found test for delete_order"),
            ],
        );
    }

    #[test]
    fn test_to_snake_converts_pascal_case() {
        let cases = [
            ("OrderService", "order_service"),
            ("PaymentGateway", "payment_gateway"),
            ("order", "order"),
            ("A", "a"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake(input), expected);
        }
    }
}