use super::{pyi_docstring, python_safe_name};
use crate::backends::pyo3::type_map::python_type;
use crate::core::ir::EnumDef;
/// Builds the Python-side member name for an enum variant.
///
/// The variant name is first converted to `SHOUTY_SNAKE_CASE` and the result
/// is then passed through `python_str_enum_ident`, whose return value is
/// handed back unchanged.
fn to_python_enum_variant(name: &str) -> String {
    use heck::ToShoutySnakeCase;

    let shouty = name.to_shouty_snake_case();
    crate::core::keywords::python_str_enum_ident(&shouty)
}
/// Renders the Python stub text for `enum_def` as a single newline-joined string.
///
/// Enums for which `enum_has_data_variants` returns `true` are delegated
/// entirely to `gen_data_enum_typeddicts`. Otherwise a single class is emitted
/// with one `NAME: EnumName = ...` attribute per variant, followed by an
/// `__init__` accepting `int | str`.
///
/// When `emit_docstrings` is `true`, whatever `pyi_docstring` yields for the
/// enum and for each variant is inserted right after the corresponding line;
/// when it is `false`, `pyi_docstring` is not called at all.
pub(super) fn gen_enum_stub(enum_def: &EnumDef, emit_docstrings: bool) -> String {
    use crate::codegen::generators::enum_has_data_variants;

    let mut out: Vec<String> = Vec::new();

    // Data-carrying enums take a completely separate rendering path.
    if enum_has_data_variants(enum_def) {
        gen_data_enum_typeddicts(&mut out, enum_def);
        return out.join("\n");
    }

    out.push(format!("class {}:", enum_def.name));
    if emit_docstrings {
        // `extend` over an `Option` appends the docstring only when one exists.
        out.extend(pyi_docstring(&enum_def.doc, " "));
    }

    for variant in &enum_def.variants {
        let member = to_python_enum_variant(&variant.name);
        out.push(format!(" {}: {} = ...", member, enum_def.name));
        if emit_docstrings {
            out.extend(pyi_docstring(&variant.doc, " "));
        }
    }

    out.push(String::from(
        " def __init__(self, value: int | str) -> None: ...",
    ));

    out.join("\n")
}
/// Appends the stub lines for a data-carrying enum to `lines`.
///
/// For every variant, a `TypedDict` class named `{Enum}{Variant}Variant` is
/// emitted. Its first entry is the tag field, typed as a `Literal` of the
/// value returned by `wire_variant_value`; one entry per variant field
/// follows, and then an empty separator line.
///
/// After all variants, a plain class named after the enum is emitted. It
/// exposes the tag field as `str` together with `__str__` and `__repr__`.
///
/// The tag field name is taken from `enum_def.serde_tag` and falls back to
/// `"type"` when that is `None`.
fn gen_data_enum_typeddicts(lines: &mut Vec<String>, enum_def: &EnumDef) {
    let tag_field = enum_def.serde_tag.as_deref().unwrap_or("type");
    let rename_all = enum_def.serde_rename_all.as_deref();

    for variant in &enum_def.variants {
        let class_name = format!("{}{}Variant", enum_def.name, variant.name);
        let tag_value = crate::codegen::naming::wire_variant_value(
            &variant.name,
            variant.serde_rename.as_deref(),
            rename_all,
        );

        lines.push(format!("class {}(TypedDict):", class_name));
        // The tag entry is always present, so the class body is never empty
        // and no placeholder statement is needed for field-less variants.
        lines.push(format!(" {}: Literal[\"{}\"]", tag_field, tag_value));

        for field in &variant.fields {
            let mut field_type = python_type(&field.ty);
            // Mark optional fields as nullable unless the mapped type already is.
            if field.optional && !field_type.contains("| None") {
                field_type.push_str(" | None");
            }
            lines.push(format!(" {}: {}", python_safe_name(&field.name), field_type));
        }

        // Empty line separating consecutive class definitions.
        lines.push(String::new());
    }

    lines.push(format!("class {}:", enum_def.name));
    lines.push(format!(" {}: str", tag_field));
    lines.push(" def __str__(self) -> str: ... # noqa: PYI029".to_string());
    lines.push(" def __repr__(self) -> str: ... # noqa: PYI029".to_string());
}