use super::GraphQLGenerator;
use crate::codegen::common::case_conversion::to_snake_case;
use crate::codegen::common::escaping::{EscapeContext, escape_for_docstring};
use crate::codegen::graphql::sdl::{SdlBuilder, TargetLanguage, TypeMapper};
use crate::codegen::graphql::spec_parser::GraphQLSchema;
use anyhow::Result;
/// Stateless generator that renders a parsed GraphQL schema as Python source:
/// msgspec type definitions, async resolver stubs, and an Ariadne schema module.
#[derive(Clone, Copy, Debug, Default)]
pub struct PythonGenerator;
impl GraphQLGenerator for PythonGenerator {
fn generate_types(&self, schema: &GraphQLSchema) -> Result<String> {
use crate::codegen::graphql::spec_parser::TypeKind;
let mut code = String::new();
code.push_str("#!/usr/bin/env python3\n");
code.push_str("# ruff: noqa: EXE001, I001\n");
code.push_str("# DO NOT EDIT - Auto-generated by Spikard CLI\n");
code.push_str("#\n");
code.push_str("# This file was automatically generated from your GraphQL schema.\n");
code.push_str("# Any manual changes will be overwritten on the next generation.\n");
code.push_str("\"\"\"GraphQL types generated from schema.\"\"\"\n\n");
let has_enums = schema.types.values().any(|t| t.kind == TypeKind::Enum);
let has_structs = schema.types.values().any(|t| {
matches!(t.kind, TypeKind::InputObject | TypeKind::Object)
&& t.name != "Query"
&& t.name != "Mutation"
&& t.name != "Subscription"
});
let has_unions = schema.types.values().any(|t| t.kind == TypeKind::Union);
code.push_str("from __future__ import annotations\n");
if has_enums {
code.push_str("from enum import Enum\n");
}
if has_structs {
code.push_str("from msgspec import Struct\n");
}
if has_unions {
code.push_str("from typing import TypeAlias\n");
}
code.push('\n');
let mapper = TypeMapper::new(TargetLanguage::Python, Some(schema));
for type_def in schema.types.values() {
if type_def.kind == TypeKind::Enum {
code.push_str(&format!("class {}(str, Enum):\n", type_def.name));
if let Some(desc) = &type_def.description {
code.push_str(&format!(
" \"\"\"{}\"\"\"\n",
escape_for_docstring(desc, EscapeContext::Python)
));
}
for value in &type_def.enum_values {
if let Some(desc) = &value.description {
code.push_str(&format!(" # {desc}\n"));
}
code.push_str(&format!(" {} = \"{}\"\n", value.name, value.name));
}
code.push_str("\n\n");
}
}
for type_def in schema.types.values() {
if type_def.kind == TypeKind::InputObject {
code.push_str(&format!(
"class {}(Struct, frozen=True, kw_only=True):\n",
type_def.name
));
if let Some(desc) = &type_def.description {
code.push_str(&format!(
" \"\"\"{}\"\"\"\n",
escape_for_docstring(desc, EscapeContext::Python)
));
} else {
code.push_str(&format!(" \"\"\"GraphQL input type {}.\"\"\"\n", type_def.name));
}
if type_def.input_fields.is_empty() {
code.push_str(" pass\n");
} else {
for field in &type_def.input_fields {
if let Some(desc) = &field.description {
code.push_str(&format!(" # {desc}\n"));
}
let py_type = mapper.map_type_with_list_nullability(
&field.type_name,
field.is_nullable,
field.is_list,
field.list_item_nullable,
);
code.push_str(&format!(" {}: {}\n", field.name, py_type));
}
}
code.push_str("\n\n");
} else if type_def.kind == TypeKind::Object
&& type_def.name != "Query"
&& type_def.name != "Mutation"
&& type_def.name != "Subscription"
{
code.push_str(&format!(
"class {}(Struct, frozen=True, kw_only=True):\n",
type_def.name
));
if let Some(desc) = &type_def.description {
code.push_str(&format!(
" \"\"\"{}\"\"\"\n",
escape_for_docstring(desc, EscapeContext::Python)
));
} else {
code.push_str(&format!(" \"\"\"GraphQL object type {}.\"\"\"\n", type_def.name));
}
if type_def.fields.is_empty() {
code.push_str(" pass\n");
} else {
for field in &type_def.fields {
if let Some(desc) = &field.description {
code.push_str(&format!(" # {desc}\n"));
}
let py_type = mapper.map_type_with_list_nullability(
&field.type_name,
field.is_nullable,
field.is_list,
field.list_item_nullable,
);
code.push_str(&format!(" {}: {}\n", field.name, py_type));
}
}
code.push_str("\n\n");
} else if type_def.kind == TypeKind::Union {
let members = type_def.possible_types.join(" | ");
code.push_str(&format!("{}: TypeAlias = \"{}\"\n", type_def.name, members));
code.push('\n');
}
}
Ok(code)
}
fn generate_resolvers(&self, schema: &GraphQLSchema) -> Result<String> {
let mut code = String::new();
code.push_str("#!/usr/bin/env python3\n");
code.push_str("# ruff: noqa: EXE001\n");
code.push_str("# DO NOT EDIT - Auto-generated by Spikard CLI\n");
code.push_str("#\n");
code.push_str("# This file was automatically generated from your GraphQL schema.\n");
code.push_str("# Any manual changes will be overwritten on the next generation.\n");
code.push_str("\"\"\"GraphQL resolver functions.\"\"\"\n\n");
code.push_str("from __future__ import annotations\n\n");
if !schema.subscriptions.is_empty() {
code.push_str("from collections.abc import AsyncIterator\n");
}
code.push_str("from typing import TYPE_CHECKING\n\n");
code.push_str("if TYPE_CHECKING:\n");
code.push_str(" from graphql import GraphQLResolveInfo\n");
let mut used_types: std::collections::HashSet<String> = std::collections::HashSet::new();
for query in &schema.queries {
if let Some(type_name) = extract_base_type_name(&query.type_name)
&& is_custom_type(&type_name, schema)
{
used_types.insert(type_name);
}
for arg in &query.arguments {
if let Some(type_name) = extract_base_type_name(&arg.type_name)
&& is_custom_type(&type_name, schema)
{
used_types.insert(type_name);
}
}
}
for mutation in &schema.mutations {
if let Some(type_name) = extract_base_type_name(&mutation.type_name)
&& is_custom_type(&type_name, schema)
{
used_types.insert(type_name);
}
for arg in &mutation.arguments {
if let Some(type_name) = extract_base_type_name(&arg.type_name)
&& is_custom_type(&type_name, schema)
{
used_types.insert(type_name);
}
}
}
for subscription in &schema.subscriptions {
if let Some(type_name) = extract_base_type_name(&subscription.type_name)
&& is_custom_type(&type_name, schema)
{
used_types.insert(type_name);
}
for arg in &subscription.arguments {
if let Some(type_name) = extract_base_type_name(&arg.type_name)
&& is_custom_type(&type_name, schema)
{
used_types.insert(type_name);
}
}
}
if !used_types.is_empty() {
let mut sorted_types: Vec<_> = used_types.iter().collect();
sorted_types.sort();
let types_list = sorted_types.iter().map(|s| s.as_str()).collect::<Vec<_>>().join(", ");
code.push_str(&format!("from .types import {types_list}\n"));
}
code.push('\n');
let _mapper = TypeMapper::new(TargetLanguage::Python, Some(schema));
let format_resolver =
|name: &str, field: &crate::codegen::graphql::spec_parser::GraphQLField, schema: &GraphQLSchema| {
let mut sig = format!(
"async def resolve_{}(parent: dict[str, object], info: GraphQLResolveInfo",
to_snake_case(name)
);
let needs_builtin_noqa = field.arguments.iter().any(|arg| is_python_builtin_name(&arg.name));
let mapper = TypeMapper::new(TargetLanguage::Python, Some(schema));
for arg in &field.arguments {
let arg_type = mapper.map_type_with_list_nullability(
&arg.type_name,
arg.is_nullable,
arg.is_list,
arg.list_item_nullable,
);
sig.push_str(&format!(", {}: {}", arg.name, arg_type));
}
let py_type = mapper.map_type_with_list_nullability(
&field.type_name,
field.is_nullable,
field.is_list,
field.list_item_nullable,
);
sig.push_str(&format!(") -> {py_type}:"));
if needs_builtin_noqa {
sig.push_str(" # noqa: A002");
}
sig
};
if !schema.queries.is_empty() {
code.push_str("# Query resolvers\n\n");
for field in &schema.queries {
code.push_str(&format_resolver(&field.name, field, schema));
code.push('\n');
code.push_str(" \"\"\"Resolve query field.\"\"\"\n");
code.push_str(" raise NotImplementedError\n\n");
}
code.push('\n');
}
if !schema.mutations.is_empty() {
code.push_str("# Mutation resolvers\n\n");
for field in &schema.mutations {
code.push_str(&format_resolver(&field.name, field, schema));
code.push('\n');
code.push_str(" \"\"\"Resolve mutation field.\"\"\"\n");
code.push_str(" raise NotImplementedError\n\n");
}
}
if !schema.subscriptions.is_empty() {
code.push_str("\n# Subscription resolvers\n\n");
for field in &schema.subscriptions {
let mapper = TypeMapper::new(TargetLanguage::Python, Some(schema));
let field_type = mapper.map_type_with_list_nullability(
&field.type_name,
field.is_nullable,
field.is_list,
field.list_item_nullable,
);
let needs_builtin_noqa = field.arguments.iter().any(|arg| is_python_builtin_name(&arg.name));
let mut source_signature = format!(
"async def subscribe_{}(parent: dict[str, object], info: GraphQLResolveInfo",
to_snake_case(&field.name)
);
for arg in &field.arguments {
let arg_type = mapper.map_type_with_list_nullability(
&arg.type_name,
arg.is_nullable,
arg.is_list,
arg.list_item_nullable,
);
source_signature.push_str(&format!(", {}: {}", arg.name, arg_type));
}
source_signature.push_str(&format!(") -> AsyncIterator[{field_type}]:"));
if needs_builtin_noqa {
source_signature.push_str(" # noqa: A002");
}
code.push_str(&source_signature);
code.push('\n');
code.push_str(" \"\"\"Stream subscription events.\"\"\"\n");
code.push_str(" raise NotImplementedError\n\n");
let mut resolver_signature = format!(
"async def resolve_{}(value: {}, info: GraphQLResolveInfo",
to_snake_case(&field.name),
field_type
);
for arg in &field.arguments {
let arg_type = mapper.map_type_with_list_nullability(
&arg.type_name,
arg.is_nullable,
arg.is_list,
arg.list_item_nullable,
);
resolver_signature.push_str(&format!(", {}: {}", arg.name, arg_type));
}
resolver_signature.push_str(&format!(") -> {field_type}:"));
if needs_builtin_noqa {
resolver_signature.push_str(" # noqa: A002");
}
code.push_str(&resolver_signature);
code.push('\n');
code.push_str(" \"\"\"Resolve a streamed subscription event.\"\"\"\n");
code.push_str(" raise NotImplementedError\n\n");
}
}
Ok(code)
}
fn generate_schema_definition(&self, schema: &GraphQLSchema) -> Result<String> {
let mut code = String::new();
code.push_str("#!/usr/bin/env python3\n");
code.push_str("# ruff: noqa: EXE001\n");
code.push_str("# DO NOT EDIT - Auto-generated by Spikard CLI\n");
code.push_str("#\n");
code.push_str("# This file was automatically generated from your GraphQL schema.\n");
code.push_str("# Any manual changes will be overwritten on the next generation.\n");
code.push_str("\"\"\"GraphQL Schema Definition.\"\"\"\n\n");
code.push_str("from __future__ import annotations\n\n");
let mut ariadne_imports = vec!["make_executable_schema", "QueryType"];
if !schema.mutations.is_empty() {
ariadne_imports.push("MutationType");
}
if !schema.subscriptions.is_empty() {
ariadne_imports.push("SubscriptionType");
}
code.push_str(&format!("from ariadne import {}\n\n", ariadne_imports.join(", ")));
let sdl = SdlBuilder::new(schema).build();
code.push_str("# GraphQL Schema Definition Language (SDL)\n");
code.push_str("#\n");
code.push_str("# Defines all types, queries, mutations, and subscriptions\n");
code.push_str("# in the GraphQL schema.\n");
code.push_str("type_defs = \"\"\"\n");
for line in sdl.lines() {
if line.is_empty() {
code.push('\n');
} else {
code.push_str(" ");
code.push_str(line);
code.push('\n');
}
}
code.push_str("\"\"\"\n\n");
code.push_str("# Query resolvers\n");
code.push_str("query = QueryType()\n\n");
code.push_str("def _register_query_resolvers() -> None:\n");
if schema.queries.is_empty() {
code.push_str(" return\n\n");
} else {
for field in &schema.queries {
let resolver_name = format!("resolve_{}", to_snake_case(&field.name));
code.push_str(&format!(" resolver = globals().get(\"{resolver_name}\")\n"));
code.push_str(" if resolver is not None:\n");
code.push_str(&format!(" query.set_field(\"{}\", resolver)\n", field.name));
}
code.push('\n');
}
code.push_str("_register_query_resolvers()\n\n");
if !schema.mutations.is_empty() {
code.push_str("# Mutation resolvers\n");
code.push_str("mutation = MutationType()\n\n");
code.push_str("def _register_mutation_resolvers() -> None:\n");
for field in &schema.mutations {
let resolver_name = format!("resolve_{}", to_snake_case(&field.name));
code.push_str(&format!(" resolver = globals().get(\"{resolver_name}\")\n"));
code.push_str(" if resolver is not None:\n");
code.push_str(&format!(" mutation.set_field(\"{}\", resolver)\n", field.name));
}
code.push('\n');
code.push_str("_register_mutation_resolvers()\n\n");
}
if !schema.subscriptions.is_empty() {
code.push_str("# Subscription resolvers\n");
code.push_str("subscription = SubscriptionType()\n\n");
code.push_str("def _register_subscription_resolvers() -> None:\n");
for field in &schema.subscriptions {
let source_name = format!("subscribe_{}", to_snake_case(&field.name));
let resolver_name = format!("resolve_{}", to_snake_case(&field.name));
code.push_str(&format!(" source = globals().get(\"{source_name}\")\n"));
code.push_str(" if source is not None:\n");
code.push_str(&format!(
" subscription.set_source(\"{}\", source)\n",
field.name
));
code.push_str(&format!(" resolver = globals().get(\"{resolver_name}\")\n"));
code.push_str(" if resolver is not None:\n");
code.push_str(&format!(
" subscription.set_field(\"{}\", resolver)\n",
field.name
));
}
code.push('\n');
code.push_str("_register_subscription_resolvers()\n\n");
}
code.push_str("# Executable GraphQL Schema\n");
code.push_str("#\n");
code.push_str("# Combines the type definitions with resolvers to create\n");
code.push_str("# a fully functional GraphQL schema ready for use with\n");
code.push_str("# Ariadne GraphQL or similar frameworks.\n");
let mut resolvers = vec!["query".to_string()];
if !schema.mutations.is_empty() {
resolvers.push("mutation".to_string());
}
if !schema.subscriptions.is_empty() {
resolvers.push("subscription".to_string());
}
code.push_str("schema = make_executable_schema(type_defs, ");
code.push_str(&resolvers.join(", "));
code.push_str(")\n\n");
code.push_str("# Exported for advanced use cases where the SDL\n");
code.push_str("# string might be needed directly.\n");
code.push_str("__all__ = [\"schema\", \"type_defs\"]\n");
Ok(code)
}
fn generate_complete(&self, schema: &GraphQLSchema) -> Result<String> {
let types = self.generate_types(schema)?;
let resolvers = self.generate_resolvers(schema)?;
let schema_def = self.generate_schema_definition(schema)?;
fn extract_header_imports_and_code(s: &str) -> (Vec<String>, Vec<String>, Vec<String>) {
let mut header_lines: Vec<String> = Vec::new(); let mut imports: Vec<String> = Vec::new();
let mut code: Vec<String> = Vec::new();
let mut in_header_docstring = false;
let mut in_type_checking_block = false;
let mut past_header = false;
let mut found_import_section = false;
for line in s.lines() {
let trimmed = line.trim();
if !past_header {
if trimmed.starts_with("\"\"\"") {
if trimmed.len() >= 6 && trimmed.ends_with("\"\"\"") && !trimmed.starts_with("\"\"\"\"") {
header_lines.push(line.to_string());
continue;
}
header_lines.push(line.to_string());
in_header_docstring = !in_header_docstring;
continue;
}
if in_header_docstring {
header_lines.push(line.to_string());
continue;
}
if trimmed.starts_with("#!/")
|| trimmed.starts_with("# ruff:")
|| trimmed.starts_with("# DO NOT EDIT")
|| trimmed == "#"
|| (trimmed.starts_with("# This file") && trimmed.contains("generated"))
|| (trimmed.starts_with("# Any manual"))
{
header_lines.push(line.to_string());
continue;
}
if trimmed.is_empty() {
continue;
}
past_header = true;
}
if in_type_checking_block {
if trimmed.is_empty() || line.starts_with(' ') || line.starts_with('\t') {
code.push(line.to_string());
if trimmed.is_empty() {
in_type_checking_block = false;
}
continue;
}
in_type_checking_block = false;
}
if trimmed.starts_with("from __future__") {
if !found_import_section {
imports.push(line.to_string());
found_import_section = true;
}
continue;
}
if trimmed == "if TYPE_CHECKING:" {
code.push(line.to_string());
in_type_checking_block = true;
found_import_section = true;
continue;
}
if trimmed.starts_with("import ") || trimmed.starts_with("from ") {
imports.push(line.to_string());
found_import_section = true;
} else {
if trimmed.is_empty() && !found_import_section && imports.is_empty() {
continue;
}
code.push(line.to_string());
}
}
(header_lines, imports, code)
}
let (types_header, types_imports, types_code) = extract_header_imports_and_code(&types);
let (_resolvers_header, resolvers_imports, resolvers_code) = extract_header_imports_and_code(&resolvers);
let (_schema_def_header, schema_def_imports, schema_def_code) = extract_header_imports_and_code(&schema_def);
let mut all_imports: Vec<String> = types_imports;
for imp in resolvers_imports.iter().chain(schema_def_imports.iter()) {
let trimmed = imp.trim();
if trimmed.starts_with("from .types") {
continue;
}
if !all_imports.contains(imp) {
all_imports.push(imp.clone());
}
}
let mut result = String::new();
for header_line in types_header {
result.push_str(&header_line);
result.push('\n');
}
for imp in &all_imports {
result.push_str(imp);
result.push('\n');
}
if !types_code.is_empty() {
result.push('\n');
for line in types_code {
result.push_str(&line);
result.push('\n');
}
}
if !resolvers_code.is_empty() {
result.push('\n');
for line in resolvers_code {
result.push_str(&line);
result.push('\n');
}
}
if !schema_def_code.is_empty() {
result.push('\n');
for line in schema_def_code {
result.push_str(&line);
result.push('\n');
}
}
result = result.trim_end().to_string();
result.push('\n');
Ok(result)
}
}
/// Returns `true` when `name` is a name in Python's builtins namespace.
///
/// Used as a generated argument name, such a name shadows the builtin and is reported by
/// ruff's A002 rule. Covers the built-in functions/types, the built-in constants
/// `Ellipsis` and `NotImplemented`, and the `site`-added helpers (`help`, `license`,
/// `copyright`, `credits`, `exit`, `quit`). Keywords (`None`, `True`, `False`) are not
/// included — they cannot be argument names at all.
fn is_python_builtin_name(name: &str) -> bool {
    matches!(
        name,
        "Ellipsis"
            | "NotImplemented"
            | "__import__"
            | "abs"
            | "aiter"
            | "all"
            | "anext"
            | "any"
            | "ascii"
            | "bin"
            | "bool"
            | "breakpoint"
            | "bytearray"
            | "bytes"
            | "callable"
            | "chr"
            | "classmethod"
            | "compile"
            | "complex"
            | "copyright"
            | "credits"
            | "delattr"
            | "dict"
            | "dir"
            | "divmod"
            | "enumerate"
            | "eval"
            | "exec"
            | "exit"
            | "filter"
            | "float"
            | "format"
            | "frozenset"
            | "getattr"
            | "globals"
            | "hasattr"
            | "hash"
            | "help"
            | "hex"
            | "id"
            | "input"
            | "int"
            | "isinstance"
            | "issubclass"
            | "iter"
            | "len"
            | "license"
            | "list"
            | "locals"
            | "map"
            | "max"
            | "memoryview"
            | "min"
            | "next"
            | "object"
            | "oct"
            | "open"
            | "ord"
            | "pow"
            | "print"
            | "property"
            | "quit"
            | "range"
            | "repr"
            | "reversed"
            | "round"
            | "set"
            | "setattr"
            | "slice"
            | "sorted"
            | "staticmethod"
            | "str"
            | "sum"
            | "super"
            | "tuple"
            | "type"
            | "vars"
            | "zip"
    )
}
/// Strips GraphQL list (`[`, `]`) and non-null (`!`) wrapper characters from both ends of
/// a type reference such as `[User!]!`, returning the named type (`User`), or `None`
/// when nothing is left.
fn extract_base_type_name(type_name: &str) -> Option<String> {
    const WRAPPER_CHARS: &[char] = &['!', '[', ']'];
    let base = type_name.trim_matches(WRAPPER_CHARS);
    (!base.is_empty()).then(|| base.to_owned())
}
/// Reports whether `type_name` is a key of `schema.types`. The built-in GraphQL scalars
/// and the scalar names `DateTime`, `Date`, `Time`, `JSON` and `Upload` are always
/// treated as non-custom, even if the schema declares them.
fn is_custom_type(type_name: &str, schema: &GraphQLSchema) -> bool {
    const NON_CUSTOM_SCALARS: [&str; 10] = [
        "String", "Int", "Float", "Boolean", "ID", "DateTime", "Date", "Time", "JSON", "Upload",
    ];
    !NON_CUSTOM_SCALARS.contains(&type_name) && schema.types.contains_key(type_name)
}