use crate::codegen::doc_emission::doc_first_paragraph_joined;
use crate::codegen::generators;
use crate::codegen::shared::binding_fields;
use crate::core::config::{DtoConfig, PythonDtoStyle};
use crate::core::hash::{self, CommentStyle};
use crate::core::ir::ApiSurface;
use ahash::AHashSet;
use super::enums::{class_name_to_docstring, sanitize_python_doc};
/// Renders a single-line `from … import …` statement for `module_name` through
/// the `import_from_module.jinja` template, with `imports` joined by `", "`.
fn render_relative_import(module_name: &str, imports: &[String]) -> String {
    let joined_imports = imports.join(", ");
    crate::backends::pyo3::template_env::render(
        "import_from_module.jinja",
        minijinja::context! {
            module_name => module_name,
            imports => joined_imports,
        },
    )
}
/// Renders a single-line `from … import …` statement for `module_name` through
/// the `import_from_absolute_module.jinja` template, with `imports` joined by `", "`.
fn render_absolute_import(module_name: &str, imports: &[String]) -> String {
    let joined_imports = imports.join(", ");
    crate::backends::pyo3::template_env::render(
        "import_from_absolute_module.jinja",
        minijinja::context! {
            module_name => module_name,
            imports => joined_imports,
        },
    )
}
/// Returns `true` when a rendered import statement, ignoring trailing newlines,
/// is longer than 88 bytes and should be wrapped in parentheses.
///
/// The limit presumably mirrors the 88-column default of Black/ruff (TODO confirm).
/// Note that the length is measured in bytes (`str::len`), not characters.
fn is_long_import(import_statement: &str) -> bool {
    const MAX_LINE_LENGTH: usize = 88;
    let line = import_statement.trim_end_matches('\n');
    line.len() > MAX_LINE_LENGTH
}
/// Maps an error-type method name to the `@property` stub emitted for it on the
/// generated exception class, as `(property name, Python return type, docstring)`.
///
/// Returns `None` for methods that get no stub.
fn exception_property_stub(method_name: &str) -> Option<(&'static str, &'static str, &'static str)> {
    const STUBS: [(&str, &str, &str); 3] = [
        (
            "status_code",
            "int",
            "HTTP status code for this error (0 means no associated status).",
        ),
        (
            "is_transient",
            "bool",
            "Returns True if the error is transient and a retry may succeed.",
        ),
        (
            "error_type",
            "str",
            "Machine-readable error category string for matching and logging.",
        ),
    ];
    STUBS
        .iter()
        .copied()
        .find(|&(name, _, _)| name == method_name)
}
pub(super) fn gen_exceptions_py(api: &ApiSurface) -> String {
let mut out = String::with_capacity(1024);
let mut seen_classes: AHashSet<String> = AHashSet::new();
out.push_str(&hash::header(CommentStyle::Hash));
out.push_str("\"\"\"Exception hierarchy.\"\"\"\n\n\n");
for error in &api.errors {
if !seen_classes.insert(error.name.clone()) {
continue; }
let doc = if !error.doc.is_empty() {
let first_line = sanitize_python_doc(&doc_first_paragraph_joined(&error.doc));
if first_line.ends_with('.') {
first_line
} else {
format!("{}.", first_line)
}
} else {
class_name_to_docstring(&error.name)
};
out.push_str(&crate::backends::pyo3::template_env::render(
"exception_base_class.jinja",
minijinja::context! { name => &error.name, doc => doc },
));
for method in &error.methods {
if let Some((name, return_type, doc)) = exception_property_stub(&method.name) {
out.push_str(&crate::backends::pyo3::template_env::render(
"exception_property_stub.jinja",
minijinja::context! {
name => name,
return_type => return_type,
doc => doc,
},
));
}
}
for variant in &error.variants {
let variant_name = crate::codegen::error_gen::python_exception_name(&variant.name, &error.name);
if !seen_classes.insert(variant_name.clone()) {
continue; }
let doc = if !variant.doc.is_empty() {
let first_line = sanitize_python_doc(&doc_first_paragraph_joined(&variant.doc));
if first_line.ends_with('.') {
first_line
} else {
format!("{}.", first_line)
}
} else {
class_name_to_docstring(&variant_name)
};
out.push_str(&crate::backends::pyo3::template_env::render(
"exception_variant_class.jinja",
minijinja::context! { name => &variant_name, base => &error.name, doc => doc },
));
}
}
out
}
#[allow(clippy::too_many_arguments)]
pub(super) fn gen_init_py(
api: &ApiSurface,
module_name: &str,
version: &str,
dto: &DtoConfig,
trait_bridges: &[crate::core::config::TraitBridgeConfig],
extra_init_imports: &std::collections::BTreeMap<String, Vec<String>>,
capsule_types: &std::collections::HashMap<String, crate::core::config::CapsuleTypeConfig>,
adapters: &[crate::core::config::AdapterConfig],
opaque_types: &std::collections::HashMap<String, String>,
exclude_functions: &AHashSet<String>,
) -> String {
use crate::core::ir::TypeRef;
let mut out = String::with_capacity(1024);
out.push_str(&hash::header(CommentStyle::Hash));
out.push_str(&crate::backends::pyo3::template_env::render(
"init_header.jinja",
minijinja::context! { module_name => module_name, version => version },
));
out.push('\n');
let enum_names: AHashSet<&str> = api.enums.iter().map(|e| e.name.as_str()).collect();
let data_enum_names: AHashSet<&str> = api
.enums
.iter()
.filter(|e| generators::enum_has_data_variants(e))
.map(|e| e.name.as_str())
.collect();
let output_style = dto.python_output_style();
let mut needed_enums: Vec<String> = Vec::new();
let mut needed_data_enums: Vec<String> = Vec::new();
let mut config_types: Vec<String> = Vec::new();
let mut native_return_types: Vec<String> = Vec::new();
for typ in api.types.iter().filter(|typ| !typ.is_trait && !typ.binding_excluded) {
if typ.name.ends_with("Builder") {
continue;
}
if typ.has_default && !typ.name.ends_with("Update") && !typ.fields.is_empty() {
let is_native_return = typ.is_return_type && output_style != PythonDtoStyle::TypedDict;
if is_native_return {
native_return_types.push(typ.name.clone());
} else {
config_types.push(typ.name.clone());
}
for field in binding_fields(&typ.fields) {
let inner_name = match &field.ty {
TypeRef::Named(n) => Some(n.as_str()),
TypeRef::Optional(inner) => {
if let TypeRef::Named(n) = inner.as_ref() {
Some(n.as_str())
} else {
None
}
}
_ => None,
};
if let Some(name) = inner_name {
if data_enum_names.contains(&name) {
if !needed_data_enums.iter().any(|n| n == name) {
needed_data_enums.push(name.to_string());
}
} else if enum_names.contains(&name) && !needed_enums.contains(&name.to_string()) {
needed_enums.push(name.to_string());
}
}
}
}
}
let mut imports_from_native: Vec<String> = Vec::new();
let options_type_set: AHashSet<&str> = config_types.iter().map(|s| s.as_str()).collect();
let error_type_set: AHashSet<&str> = api.errors.iter().map(|e| e.name.as_str()).collect();
for typ in api.types.iter().filter(|t| !t.is_trait && !t.binding_excluded) {
if typ.name.ends_with("Update") || typ.name.ends_with("Builder") {
continue;
}
if error_type_set.contains(typ.name.as_str()) {
continue;
}
if options_type_set.contains(typ.name.as_str()) {
continue;
}
if native_return_types.iter().any(|n| n == &typ.name) {
continue;
}
if !needed_data_enums.iter().any(|n| n == &typ.name) {
imports_from_native.push(typ.name.clone());
}
}
for enum_def in &api.enums {
if needed_data_enums.iter().any(|n| n == &enum_def.name) {
continue;
}
if !imports_from_native.iter().any(|n| n == &enum_def.name) {
imports_from_native.push(enum_def.name.clone());
}
}
let mut imports_from_api = Vec::new();
let mut imports_from_options = Vec::new();
let mut imports_from_exceptions = Vec::new();
{
let mut names: Vec<_> = api
.functions
.iter()
.filter(|f| !exclude_functions.contains(&f.name))
.map(|f| f.name.clone())
.collect();
names.extend(crate::backends::pyo3::trait_bridge::collect_bridge_register_fns(
trait_bridges,
));
names.extend(crate::backends::pyo3::trait_bridge::collect_bridge_unregister_fns(
trait_bridges,
));
names.extend(crate::backends::pyo3::trait_bridge::collect_bridge_clear_fns(
trait_bridges,
));
names.extend(adapters.iter().map(|a| a.name.clone()));
names.sort();
names.dedup();
imports_from_api.extend(names);
}
needed_data_enums.sort();
imports_from_native.extend(needed_data_enums.iter().cloned());
native_return_types.sort();
imports_from_native.extend(native_return_types.iter().cloned());
imports_from_native.retain(|n| !capsule_types.contains_key(n));
let python_capsule_type_names: ahash::AHashSet<&str> = capsule_types.keys().map(|k| k.as_str()).collect();
imports_from_native.retain(|n| {
if opaque_types.contains_key(n) {
!python_capsule_type_names.contains(n.as_str())
} else {
true
}
});
imports_from_native.sort_by_key(|a| a.to_lowercase());
imports_from_native.dedup();
let mut opt_imports: Vec<String> = config_types.to_vec();
opt_imports.sort();
imports_from_options.extend(opt_imports);
let mut exc_names = Vec::new();
for error in &api.errors {
exc_names.push(error.name.clone());
for variant in &error.variants {
let variant_name = crate::codegen::error_gen::python_exception_name(&variant.name, &error.name);
exc_names.push(variant_name);
}
}
exc_names.sort();
exc_names.dedup();
imports_from_exceptions.extend(exc_names.clone());
if !imports_from_native.is_empty() {
let import_statement = render_relative_import(module_name, &imports_from_native);
if is_long_import(&import_statement) {
out.push_str(&crate::backends::pyo3::template_env::render(
"import_from_relative_module_header.jinja",
minijinja::context! { module_name => module_name },
));
for name in &imports_from_native {
out.push_str(&crate::backends::pyo3::template_env::render(
"trait_bridge/indented_import_item.jinja",
minijinja::context! { name => name },
));
}
out.push_str(")\n");
} else {
out.push_str(&import_statement);
}
}
if !imports_from_api.is_empty() {
let import_statement = render_relative_import("api", &imports_from_api);
if is_long_import(&import_statement) {
out.push_str("from .api import (\n");
for name in &imports_from_api {
out.push_str(&crate::backends::pyo3::template_env::render(
"trait_bridge/indented_import_item.jinja",
minijinja::context! { name => name },
));
}
out.push_str(")\n");
} else {
out.push_str(&import_statement);
}
}
if !imports_from_exceptions.is_empty() {
let import_statement = render_relative_import("exceptions", &imports_from_exceptions);
if is_long_import(&import_statement) {
out.push_str("from .exceptions import (\n");
for name in &imports_from_exceptions {
out.push_str(&crate::backends::pyo3::template_env::render(
"trait_bridge/indented_import_item.jinja",
minijinja::context! { name => name },
));
}
out.push_str(")\n");
} else {
out.push_str(&import_statement);
}
}
if !imports_from_options.is_empty() {
let import_statement = render_relative_import("options", &imports_from_options);
if is_long_import(&import_statement) {
out.push_str("from .options import (\n");
for name in &imports_from_options {
out.push_str(&crate::backends::pyo3::template_env::render(
"trait_bridge/indented_import_item.jinja",
minijinja::context! { name => name },
));
}
out.push_str(")\n");
} else {
out.push_str(&import_statement);
}
}
let mut service_owners: Vec<String> = api.services.iter().map(|s| s.name.clone()).collect();
service_owners.sort();
service_owners.dedup();
if !service_owners.is_empty() {
let import_statement = render_relative_import("service", &service_owners);
if is_long_import(&import_statement) {
out.push_str("from .service import (\n");
for name in &service_owners {
out.push_str(&crate::backends::pyo3::template_env::render(
"trait_bridge/indented_import_item.jinja",
minijinja::context! { name => name },
));
}
out.push_str(")\n");
} else {
out.push_str(&import_statement);
}
}
let mut extra_all_items: Vec<String> = Vec::new();
for (module, symbols) in extra_init_imports {
if symbols.is_empty() {
continue;
}
let import_statement = render_absolute_import(module, symbols);
if is_long_import(&import_statement) {
out.push_str(&crate::backends::pyo3::template_env::render(
"import_from_module_header.jinja",
minijinja::context! { module_name => module },
));
for name in symbols {
out.push_str(&crate::backends::pyo3::template_env::render(
"trait_bridge/indented_import_item.jinja",
minijinja::context! { name => name },
));
}
out.push_str(")\n");
} else {
out.push_str(&import_statement);
}
extra_all_items.extend(symbols.iter().cloned());
}
let mut all_items = Vec::new();
for f in &api.functions {
all_items.push(f.name.clone());
}
all_items.extend(crate::backends::pyo3::trait_bridge::collect_bridge_register_fns(
trait_bridges,
));
all_items.extend(crate::backends::pyo3::trait_bridge::collect_bridge_unregister_fns(
trait_bridges,
));
all_items.extend(crate::backends::pyo3::trait_bridge::collect_bridge_clear_fns(
trait_bridges,
));
all_items.extend(adapters.iter().map(|a| a.name.clone()));
all_items.extend(needed_enums);
all_items.extend(imports_from_native.iter().cloned());
all_items.extend(config_types);
all_items.extend(exc_names);
all_items.extend(service_owners);
all_items.extend(extra_all_items);
all_items.sort();
all_items.dedup();
all_items.retain(|n| {
if opaque_types.contains_key(n) {
!python_capsule_type_names.contains(n.as_str())
} else {
true
}
});
out.push_str("\n__all__ = [\n");
for name in &all_items {
out.push_str(&crate::backends::pyo3::template_env::render(
"init_all_entry.jinja",
minijinja::context! { name => name },
));
}
out.push_str("]\n\n");
out.push_str(&crate::backends::pyo3::template_env::render(
"version_declaration.jinja",
minijinja::context! { version => version },
));
out
}
#[cfg(test)]
mod tests {
    use super::{gen_exceptions_py, gen_init_py};
    use crate::core::config::DtoConfig;
    use crate::core::ir::ApiSurface;
    use std::collections::{BTreeMap, HashMap, HashSet};

    /// An `ApiSurface` for crate `test-lib` 0.1.0 with every collection empty.
    fn empty_api() -> ApiSurface {
        ApiSurface {
            crate_name: "test-lib".to_string(),
            version: "0.1.0".to_string(),
            types: vec![],
            functions: vec![],
            enums: vec![],
            errors: vec![],
            excluded_type_paths: HashMap::new(),
            excluded_trait_names: HashSet::new(),
            services: vec![],
            handler_contracts: vec![],
            unsupported_public_items: Vec::new(),
        }
    }

    /// Runs `gen_init_py` for native module `_mod` at version `1.2.3`.
    ///
    /// It uses the default DTO config and no trait bridges, capsule types,
    /// adapters, opaque types or excluded functions. Only `extra` varies
    /// between tests.
    fn render_init(api: &ApiSurface, extra: &BTreeMap<String, Vec<String>>) -> String {
        let dto = DtoConfig::default();
        let capsules = HashMap::new();
        let opaque = HashMap::new();
        gen_init_py(
            api,
            "_mod",
            "1.2.3",
            &dto,
            &[],
            extra,
            &capsules,
            &[],
            &opaque,
            &ahash::AHashSet::new(),
        )
    }

    #[test]
    fn gen_exceptions_py_property_stubs_use_raise_not_implemented_error() {
        use crate::core::ir::{ErrorDef, ErrorVariant, MethodDef, PrimitiveType, TypeRef};
        let make_method = |name: &str| MethodDef {
            name: name.to_string(),
            params: vec![],
            return_type: TypeRef::Primitive(PrimitiveType::Bool),
            is_async: false,
            is_static: false,
            error_type: None,
            doc: String::new(),
            receiver: None,
            sanitized: false,
            trait_source: None,
            returns_ref: false,
            returns_cow: false,
            return_newtype_wrapper: None,
            has_default_impl: false,
            binding_excluded: false,
            binding_exclusion_reason: None,
            version: Default::default(),
        };
        let io_variant = ErrorVariant {
            name: "Io".to_string(),
            message_template: None,
            fields: vec![],
            has_source: false,
            has_from: false,
            is_unit: true,
            is_tuple: false,
            doc: "I/O error.".to_string(),
        };
        let error = ErrorDef {
            name: "LibError".to_string(),
            rust_path: "lib::LibError".to_string(),
            original_rust_path: String::new(),
            variants: vec![io_variant],
            doc: "Library errors.".to_string(),
            methods: vec![
                make_method("status_code"),
                make_method("is_transient"),
                make_method("error_type"),
            ],
            binding_excluded: false,
            binding_exclusion_reason: None,
            version: Default::default(),
        };
        let api = ApiSurface {
            errors: vec![error],
            ..empty_api()
        };

        let result = gen_exceptions_py(&api);
        let occurrences = result.matches("raise NotImplementedError").count();

        assert!(
            occurrences > 0,
            "@property body must use raise NotImplementedError, got:\n{result}",
        );
        assert_eq!(
            occurrences, 3,
            "expected 3 `raise NotImplementedError` for status_code/is_transient/error_type, got {occurrences} in:\n{result}",
        );
        assert!(
            !result.contains("\n ...\n"),
            "property body must not use `...` (ruff PIE790 strips it, causing mypy empty-body), got:\n{result}",
        );
    }

    #[test]
    fn gen_exceptions_py_empty_api_produces_header_only() {
        let result = gen_exceptions_py(&empty_api());
        assert!(result.contains("Exception hierarchy"));
        assert!(!result.contains("class "));
    }

    #[test]
    fn gen_init_py_empty_api_has_version() {
        let result = render_init(&empty_api(), &BTreeMap::new());
        assert!(result.contains("__version__ = \"1.2.3\""));
        assert!(result.contains("__all__"));
    }

    #[test]
    fn gen_init_py_dedups_overlapping_exception_names() {
        use crate::core::ir::{ErrorDef, ErrorVariant};
        let make_variant = |name: &str| ErrorVariant {
            name: name.to_string(),
            message_template: None,
            fields: vec![],
            has_source: false,
            has_from: false,
            is_unit: true,
            is_tuple: false,
            doc: String::new(),
        };
        let make_error = |name: &str, variants: Vec<&str>| ErrorDef {
            name: name.to_string(),
            rust_path: format!("lib::{name}"),
            original_rust_path: String::new(),
            variants: variants.into_iter().map(make_variant).collect(),
            doc: String::new(),
            methods: vec![],
            binding_excluded: false,
            binding_exclusion_reason: None,
            version: Default::default(),
        };
        let shared_variants = || vec!["Validation", "ComplexityLimitExceeded", "DepthLimitExceeded"];
        let api = ApiSurface {
            errors: vec![
                make_error("ParseError", shared_variants()),
                make_error("QueryError", shared_variants()),
            ],
            ..empty_api()
        };

        let result = render_init(&api, &BTreeMap::new());

        for symbol in [
            "ValidationError",
            "ComplexityLimitExceededError",
            "DepthLimitExceededError",
        ] {
            let import_occurrences = result.matches(&format!(" {symbol},\n")).count();
            assert_eq!(
                import_occurrences, 1,
                "{symbol} must be imported once, got {import_occurrences} in:\n{result}",
            );
            let all_occurrences = result.matches(&format!("\"{symbol}\"")).count();
            assert_eq!(
                all_occurrences, 1,
                "{symbol} must appear once in __all__, got {all_occurrences} in:\n{result}",
            );
        }
    }

    #[test]
    fn gen_init_py_extra_imports_are_emitted_and_in_all() {
        let extra: BTreeMap<String, Vec<String>> = std::iter::once((
            "._supported_languages".to_string(),
            vec!["SupportedLanguage".to_string()],
        ))
        .collect();

        let result = render_init(&empty_api(), &extra);

        assert!(
            result.contains("from ._supported_languages import SupportedLanguage"),
            "missing import line in:\n{result}",
        );
        assert!(
            result.contains("\"SupportedLanguage\""),
            "SupportedLanguage missing from __all__ in:\n{result}",
        );
    }
}