use crate::backends::pyo3::gen_bindings::enums::sanitize_python_doc;
use crate::backends::pyo3::type_map::python_type;
use crate::codegen::shared::binding_fields;
use crate::core::config::workspace::ClientConstructorConfig;
use crate::core::config::{AdapterPattern, Language, ResolvedCrateConfig, TraitBridgeConfig};
use crate::core::hash::{self, CommentStyle};
use crate::core::ir::{ApiSurface, EnumDef, FunctionDef, MethodDef, TypeDef, TypeRef};
/// Renders the first paragraph of `doc` as a single-line Python docstring.
///
/// The first paragraph (text up to the first blank line) is flattened onto one
/// line: each line is trimmed, leading `///` markers are stripped, empty lines
/// are dropped and the remainder is joined with single spaces. The result is
/// passed through `sanitize_python_doc`, escaped, and wrapped in `"""`,
/// prefixed with `indent`.
///
/// Returns `None` when there is nothing left to emit.
fn pyi_docstring(doc: &str, indent: &str) -> Option<String> {
    let trimmed = doc.trim();
    if trimmed.is_empty() {
        return None;
    }
    let first_paragraph = trimmed.split("\n\n").next().unwrap_or(trimmed);
    let joined: String = first_paragraph
        .lines()
        .map(|l| l.trim().trim_start_matches("///").trim())
        .filter(|l| !l.is_empty())
        .collect::<Vec<_>>()
        .join(" ");
    if joined.is_empty() {
        return None;
    }
    let sanitized = sanitize_python_doc(&joined);
    // Backslashes first, so the escapes added for quotes are not doubled.
    let mut escaped = sanitized.replace('\\', "\\\\").replace("\"\"\"", "\\\"\\\"\\\"");
    // A summary ending in an unescaped `"` would merge with the closing
    // delimiter (`""""`), which Python rejects. Backslashes were doubled above,
    // so an even run of them before the quote means the quote is unescaped.
    if let Some(body) = escaped.strip_suffix('"') {
        let preceding_backslashes = body.bytes().rev().take_while(|&b| b == b'\\').count();
        if preceding_backslashes % 2 == 0 {
            let quote_at = escaped.len() - 1;
            escaped.insert(quote_at, '\\');
        }
    }
    Some(format!("{indent}\"\"\"{escaped}\"\"\""))
}
/// Maps `name` to the identifier used for it in generated Python code by
/// delegating to `crate::core::keywords::python_ident`.
fn python_safe_name(name: &str) -> String {
    use crate::core::keywords::python_ident;

    python_ident(name)
}
/// Returns `true` when `name` is one of the Python builtin names this
/// generator treats specially (parameters with such names get a `noqa`
/// variant of the wrapped-parameter template).
///
/// The comparison is exact and case-sensitive.
pub fn is_python_builtin_name(name: &str) -> bool {
    matches!(
        name,
        "id" | "type"
            | "input"
            | "hash"
            | "format"
            | "dir"
            | "help"
            | "list"
            | "map"
            | "filter"
            | "range"
            | "set"
            | "dict"
            | "str"
            | "int"
            | "float"
            | "bool"
            | "bytes"
            | "tuple"
            | "len"
            | "max"
            | "min"
            | "sum"
            | "abs"
            | "all"
            | "any"
            | "print"
            | "open"
            | "next"
            | "iter"
            | "vars"
            | "zip"
            | "object"
            | "property"
            | "super"
            | "staticmethod"
            | "classmethod"
            | "compile"
            | "exec"
            | "eval"
            | "license"
            | "credits"
            | "copyright"
    )
}
/// Translates the textual Rust type of a client-constructor parameter into the
/// Python annotation used in the stub.
///
/// Only exact spellings are recognised; anything else maps to `Any`.
fn constructor_rust_type_to_python(rust_type: &str) -> &str {
    const TEXT: &[&str] = &["String", "&str", "&'static str", "std::string::String"];
    const BINARY: &[&str] = &["bytes::Bytes", "Vec<u8>", "&[u8]"];
    const INTEGERS: &[&str] = &[
        "i8", "i16", "i32", "i64", "i128", "isize", "u8", "u16", "u32", "u64", "u128", "usize",
    ];
    const FLOATS: &[&str] = &["f32", "f64"];

    if TEXT.contains(&rust_type) {
        "str"
    } else if BINARY.contains(&rust_type) {
        "bytes"
    } else if rust_type == "bool" {
        "bool"
    } else if INTEGERS.contains(&rust_type) {
        "int"
    } else if FLOATS.contains(&rust_type) {
        "float"
    } else if rust_type == "()" {
        "None"
    } else {
        "Any"
    }
}
/// Computes the annotation of an `__init__` parameter of type `ty`.
///
/// * A named enum with data variants is annotated with its own name.
/// * Any other named enum additionally accepts `str`.
/// * The same applies inside `Optional`, with `| None` appended.
/// * Everything else uses `python_type(ty)` unchanged.
fn constructor_param_type(ty: &TypeRef, api: &ApiSurface) -> String {
    use crate::codegen::generators::enum_has_data_variants;

    // Split `ty` into the (possibly enum) name it refers to and whether that
    // name was wrapped in `Optional`.
    let (candidate, nullable) = match ty {
        TypeRef::Named(name) => (Some(name), false),
        TypeRef::Optional(inner) => match inner.as_ref() {
            TypeRef::Named(name) => (Some(name), true),
            _ => (None, true),
        },
        _ => (None, false),
    };
    let Some(name) = candidate else {
        return python_type(ty);
    };

    let carries_data = api
        .enums
        .iter()
        .any(|e| e.name.as_str() == name.as_str() && enum_has_data_variants(e));
    let is_enum = carries_data || api.enums.iter().any(|e| e.name.as_str() == name.as_str());

    match (carries_data, is_enum, nullable) {
        (true, _, false) => name.clone(),
        (true, _, true) => format!("{} | None", name),
        (false, true, false) => format!("{} | str", name),
        (false, true, true) => format!("{} | str | None", name),
        (false, false, _) => python_type(ty),
    }
}
/// Generates the full text of the Python `.pyi` stub for `api`.
///
/// Output order: the hash header, an optional `from typing import …` line
/// listing only the typing names that occur in the body, then non-opaque
/// types, opaque non-capsule types, enums, free functions (minus
/// `exclude_functions`), service entrypoints and finally the trait-bridge
/// register/unregister/clear functions.
pub fn gen_stubs(
    api: &ApiSurface,
    trait_bridges: &[TraitBridgeConfig],
    config: &ResolvedCrateConfig,
    exclude_functions: &ahash::AHashSet<String>,
) -> String {
    use heck::ToSnakeCase as _;
    use std::collections::{HashMap, HashSet};

    let mut out: Vec<String> = hash::header(CommentStyle::Hash).lines().map(str::to_string).collect();
    out.push(String::new());

    let python_cfg = config.python.as_ref();

    // Parameter names that are supplied through a trait bridge.
    let bridge_param_names: HashSet<&str> = trait_bridges
        .iter()
        .filter_map(|bridge| bridge.param_name.as_deref())
        .collect();

    // options type -> (kwarg name, optional type alias) for bridges bound via
    // an options field.
    let mut options_field_bridges: HashMap<&str, (&str, Option<&str>)> = HashMap::new();
    for bridge in trait_bridges {
        if bridge.bind_via == crate::core::config::BridgeBinding::OptionsField {
            if let (Some(options_type), Some(param_name)) =
                (bridge.options_type.as_deref(), bridge.param_name.as_deref())
            {
                options_field_bridges.insert(options_type, (param_name, bridge.type_alias.as_deref()));
            }
        }
    }

    // (owner type, method name) -> streamed item type, for streaming adapters.
    let mut streaming_return_types: HashMap<(Option<String>, String), String> = HashMap::new();
    for adapter in config
        .adapters
        .iter()
        .filter(|a| matches!(a.pattern, AdapterPattern::Streaming))
    {
        let item = adapter.item_type.as_deref().unwrap_or("Any").to_string();
        streaming_return_types.insert((adapter.owner_type.clone(), adapter.name.clone()), item);
    }

    let capsule_names: HashSet<&str> = match python_cfg {
        Some(py) => py.capsule_types.keys().map(String::as_str).collect(),
        None => HashSet::new(),
    };

    let emit_docstrings = match python_cfg.and_then(|py| py.stubs.as_ref()) {
        Some(stubs) => stubs.emit_docstrings,
        None => false,
    };

    let exposed_types: Vec<_> = api
        .types
        .iter()
        .filter(|typ| !(typ.is_trait || typ.binding_excluded))
        .collect();

    let mut body: Vec<String> = Vec::new();

    // Non-opaque types, each followed by a blank line.
    for typ in exposed_types.iter().copied().filter(|typ| !typ.is_opaque) {
        body.push(gen_type_stub(
            typ,
            api,
            config,
            &capsule_names,
            &options_field_bridges,
            emit_docstrings,
            &streaming_return_types,
        ));
        body.push(String::new());
    }

    // Opaque types that are not capsules, followed by one blank line for the
    // whole group.
    let opaque_types: Vec<_> = exposed_types
        .iter()
        .copied()
        .filter(|typ| typ.is_opaque && !capsule_names.contains(typ.name.as_str()))
        .collect();
    if !opaque_types.is_empty() {
        body.extend(opaque_types.iter().map(|typ| {
            let ctor = config.client_constructors.get(&typ.name);
            gen_opaque_type_stub(typ, &capsule_names, &streaming_return_types, ctor)
        }));
        body.push(String::new());
    }

    for enum_def in &api.enums {
        body.push(gen_enum_stub(enum_def, emit_docstrings));
        body.push(String::new());
    }

    body.extend(
        api.functions
            .iter()
            .filter(|func| !exclude_functions.contains(&func.name))
            .map(|func| {
                gen_function_stub(
                    func,
                    &bridge_param_names,
                    &capsule_names,
                    &options_field_bridges,
                    &streaming_return_types,
                )
            }),
    );

    // One module-level function per service entrypoint.
    for service in &api.services {
        let service_snake = service.name.to_snake_case();
        for ep in &service.entrypoints {
            let return_annot = if matches!(ep.return_type, TypeRef::Unit) {
                "None"
            } else {
                "Any"
            };
            body.push(format!(
                "def {}_{}(registrations: list[Any]) -> {}: ...",
                service_snake, ep.method, return_annot
            ));
        }
    }

    for bridge in trait_bridges {
        let registry_fns = [
            bridge
                .register_fn
                .as_deref()
                .map(|name| format!("def {name}(backend: object) -> None: ...")),
            bridge
                .unregister_fn
                .as_deref()
                .map(|name| format!("def {name}(name: str) -> None: ...")),
            bridge
                .clear_fn
                .as_deref()
                .map(|name| format!("def {name}() -> None: ...")),
        ];
        body.extend(registry_fns.into_iter().flatten());
    }

    // Import only the typing names the body actually mentions.
    let body_text = body.join("\n");
    let mut typing_names: Vec<&str> = Vec::new();
    for candidate in ["Any", "AsyncIterator", "Literal", "TypeAlias", "TypedDict"] {
        if contains_word(&body_text, candidate) {
            typing_names.push(candidate);
        }
    }
    if !typing_names.is_empty() {
        out.push(format!("from typing import {}", typing_names.join(", ")));
        out.push(String::new());
    }

    out.extend(body);
    out.join("\n")
}
/// Returns `true` when `word` occurs in `text` as a whole word.
///
/// An occurrence counts when the bytes immediately before and after it are
/// not ASCII alphanumerics or `_` (start and end of `text` count as
/// boundaries).
///
/// The scan advances one whole character at a time and uses checked slicing,
/// so an empty `word` or a `word` starting with a multi-byte character cannot
/// cause an out-of-range or mid-character slice panic.
fn contains_word(text: &str, word: &str) -> bool {
    let bytes = text.as_bytes();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    let mut start = 0;
    // `get` yields `None` once `start` runs past the end, ending the scan.
    while let Some(idx) = text.get(start..).and_then(|tail| tail.find(word)) {
        let pos = start + idx;
        let end = pos + word.len();
        let before_ok = pos == 0 || !is_ident(bytes[pos - 1]);
        let after_ok = end == bytes.len() || !is_ident(bytes[end]);
        if before_ok && after_ok {
            return true;
        }
        // `pos` is a char boundary (it is a match position); step over exactly
        // one character so the next slice starts on a boundary too.
        start = pos + text[pos..].chars().next().map_or(1, char::len_utf8);
    }
    false
}
/// Replaces every whole-word occurrence of a capsule type name in `type_str`
/// with `Any`.
///
/// When the name is directly followed by ` | None`, the optional marker is
/// absorbed into the replacement (`Foo | None` becomes `Any`).
///
/// Matching is word-bounded on both sides, so a capsule named `Client` does
/// not touch `HttpClient | None` or `ClientConfig`. Every occurrence is
/// replaced, including strings that mix forms such as
/// `dict[str, Foo] | list[Foo]`. Empty names are ignored.
fn substitute_capsule_type(type_str: &str, capsule_names: &std::collections::HashSet<&str>) -> String {
    const OPTIONAL_SUFFIX: &str = " | None";
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let mut result = type_str.to_string();
    for name in capsule_names.iter().copied().filter(|n| !n.is_empty()) {
        let bytes = result.as_bytes();
        let mut out = String::with_capacity(result.len());
        // Everything before `copied` has already been written to `out`.
        let mut copied = 0usize;
        // Where the next search begins; always a char boundary or past the end.
        let mut search = 0usize;
        while let Some(idx) = result.get(search..).and_then(|tail| tail.find(name)) {
            let pos = search + idx;
            let end = pos + name.len();
            let before_ok = pos == 0 || !is_ident(bytes[pos - 1]);
            let after_ok = end == bytes.len() || !is_ident(bytes[end]);
            if before_ok && after_ok {
                out.push_str(&result[copied..pos]);
                out.push_str("Any");
                // Absorb a directly following ` | None`, but only when `None`
                // is itself a whole word.
                let resume = match result[end..].strip_prefix(OPTIONAL_SUFFIX) {
                    Some(rest) if rest.bytes().next().map_or(true, |b| !is_ident(b)) => {
                        end + OPTIONAL_SUFFIX.len()
                    }
                    _ => end,
                };
                copied = resume;
                search = resume;
            } else {
                // Not a whole word: step past one character and keep looking.
                search = pos + result[pos..].chars().next().map_or(1, char::len_utf8);
            }
        }
        out.push_str(&result[copied..]);
        result = out;
    }
    result
}
/// Generates the stub class for an opaque type.
///
/// With neither a configured constructor nor any methods the class collapses
/// to `class Name: ...`. Otherwise the class contains an optional `__init__`
/// built from `ctor`, then instance methods, then static methods.
fn gen_opaque_type_stub(
    typ: &TypeDef,
    capsule_names: &std::collections::HashSet<&str>,
    streaming_return_types: &std::collections::HashMap<(Option<String>, String), String>,
    ctor: Option<&ClientConstructorConfig>,
) -> String {
    // Nothing to put in the body: emit the one-line form.
    if ctor.is_none() && typ.methods.is_empty() {
        return format!("class {}: ...", typ.name);
    }

    let mut out = vec![format!("class {}:", typ.name)];

    if let Some(ctor_cfg) = ctor {
        let rendered: Vec<String> = ctor_cfg
            .params
            .iter()
            .map(|p| format!("{}: {}", p.name, constructor_rust_type_to_python(&p.ty)))
            .collect();
        let one_liner = format!(" def __init__(self, {}) -> None: ...", rendered.join(", "));
        if one_liner.len() > 100 {
            // Too long for one line: one parameter per line via the template.
            let mut multi = String::from(" def __init__(\n self,\n");
            for entry in &rendered {
                multi.push_str(&crate::backends::pyo3::template_env::render(
                    "stub_wrapped_param_line.jinja",
                    minijinja::context! { param => entry },
                ));
            }
            multi.push_str(" ) -> None: ...");
            out.push(multi);
        } else {
            out.push(one_liner);
        }
    }

    // Instance methods first, static methods after, each in declaration order.
    let instance_methods = typ.methods.iter().filter(|m| !m.is_static);
    let static_methods = typ.methods.iter().filter(|m| m.is_static);
    for method in instance_methods.chain(static_methods) {
        out.push(gen_method_stub(
            method,
            method.is_static,
            capsule_names,
            Some(&typ.name),
            streaming_return_types,
        ));
    }

    out.join("\n")
}
/// Generates the stub class for a non-opaque type: class header, optional
/// docstring, one annotated attribute per binding field (with optional
/// docstring), `__init__`, instance methods, then static methods.
fn gen_type_stub(
    typ: &TypeDef,
    api: &ApiSurface,
    config: &ResolvedCrateConfig,
    capsule_names: &std::collections::HashSet<&str>,
    options_field_bridges: &std::collections::HashMap<&str, (&str, Option<&str>)>,
    emit_docstrings: bool,
    streaming_return_types: &std::collections::HashMap<(Option<String>, String), String>,
) -> String {
    let mut out = vec![format!("class {}:", typ.name)];
    if emit_docstrings {
        out.extend(pyi_docstring(&typ.doc, " "));
    }

    for field in binding_fields(&typ.fields) {
        let base_type = python_type(&field.ty);
        // A field is nullable when it is optional, or when it is a `Duration`
        // on a type that has a default.
        let nullable = field.optional || (typ.has_default && matches!(field.ty, TypeRef::Duration));
        let annotation = if nullable && !base_type.contains("| None") {
            format!("{} | None", base_type)
        } else {
            base_type
        };
        let exposed_name = match config.resolve_field_name(Language::Python, &typ.name, &field.name) {
            Some(renamed) => renamed,
            None => field.name.clone(),
        };
        out.push(format!(" {exposed_name}: {annotation}"));
        if emit_docstrings {
            out.extend(pyi_docstring(&field.doc, " "));
        }
    }

    out.push(gen_type_init_stub(typ, api, config, options_field_bridges));

    // Instance methods first, static methods after, each in declaration order.
    let instance_methods = typ.methods.iter().filter(|m| !m.is_static);
    let static_methods = typ.methods.iter().filter(|m| m.is_static);
    for method in instance_methods.chain(static_methods) {
        out.push(gen_method_stub(
            method,
            method.is_static,
            capsule_names,
            Some(&typ.name),
            streaming_return_types,
        ));
    }

    out.join("\n")
}
/// Generates the `__init__` stub for a non-opaque type.
///
/// Binding fields without a `cfg` gate become parameters. Required fields come
/// first; optional ones follow as `name: T | None = None`. When `typ` has a
/// default, every field is treated as optional; otherwise a field is optional
/// when flagged so or when its type is `Duration`. If `options_field_bridges`
/// has an entry for this type, an extra optional keyword is appended for it.
///
/// The signature goes on one line when it is at most 100 bytes long and no
/// parameter is named after a Python builtin; otherwise it is wrapped one
/// parameter per line, using the `noqa` template for builtin-named parameters.
fn gen_type_init_stub(
    typ: &TypeDef,
    api: &ApiSurface,
    config: &ResolvedCrateConfig,
    options_field_bridges: &std::collections::HashMap<&str, (&str, Option<&str>)>,
) -> String {
    let (required, optional): (Vec<_>, Vec<_>) =
        binding_fields(&typ.fields).filter(|f| f.cfg.is_none()).partition(|f| {
            if typ.has_default {
                return false;
            }
            let is_optional_duration = matches!(f.ty, TypeRef::Duration) && !f.optional;
            !f.optional && !is_optional_duration
        });

    let mut params: Vec<String> = required
        .iter()
        .map(|f| {
            let param_type = constructor_param_type(&f.ty, api);
            let param_name = config
                .resolve_field_name(Language::Python, &typ.name, &f.name)
                .unwrap_or_else(|| f.name.clone());
            format!("{param_name}: {param_type}")
        })
        .collect();
    params.extend(optional.iter().map(|f| {
        let type_str = constructor_param_type(&f.ty, api);
        let param_type = if type_str.ends_with("| None") {
            type_str
        } else {
            format!("{} | None", type_str)
        };
        let param_name = config
            .resolve_field_name(Language::Python, &typ.name, &f.name)
            .unwrap_or_else(|| f.name.clone());
        format!("{param_name}: {param_type} = None")
    }));

    if let Some((kwarg_name, type_alias)) = options_field_bridges.get(typ.name.as_str()) {
        let visitor_type = type_alias.unwrap_or("object");
        params.push(format!("{kwarg_name}: {visitor_type} | None = None"));
    }

    // Each entry is `name: type[ = None]`; the text before the first `:` is
    // the parameter name.
    let has_builtin_param = params
        .iter()
        .any(|p| is_python_builtin_name(p.split(':').next().unwrap_or("").trim()));

    let single_line = format!(" def __init__(self, {}) -> None: ...", params.join(", "));
    if single_line.len() <= 100 && !has_builtin_param {
        return single_line;
    }

    let mut wrapped = String::from(" def __init__(\n");
    wrapped.push_str(" self,\n");
    for param in &params {
        let name = param.split(':').next().unwrap_or("").trim();
        let template = if is_python_builtin_name(name) {
            "stub_param_wrapped_noqa.jinja"
        } else {
            "stub_param_wrapped.jinja"
        };
        wrapped.push_str(&crate::backends::pyo3::template_env::render(
            template,
            minijinja::context! { param => param, indent => " " },
        ));
    }
    wrapped.push_str(" ) -> None: ...");
    wrapped
}
/// Generates the stub for a single method of a class.
///
/// Required parameters come first; optional ones follow as
/// `name: T | None = None`. Capsule type names in parameter and return
/// annotations are replaced via `substitute_capsule_type`. If
/// `streaming_return_types` has an entry for `(owner_type, method.name)`, the
/// return annotation is `AsyncIterator[item]` instead.
///
/// `is_static` selects the `@staticmethod` form (no `self`). The signature
/// stays on one line when the `def` line is at most 100 bytes long and no
/// parameter is named after a Python builtin; otherwise parameters are wrapped
/// one per line, using the `noqa` template for builtin-named ones.
fn gen_method_stub(
    method: &MethodDef,
    is_static: bool,
    capsule_names: &std::collections::HashSet<&str>,
    owner_type: Option<&str>,
    streaming_return_types: &std::collections::HashMap<(Option<String>, String), String>,
) -> String {
    let (required, optional): (Vec<_>, Vec<_>) = method.params.iter().partition(|p| !p.optional);
    let mut params: Vec<String> = required
        .iter()
        .map(|p| {
            let param_type = substitute_capsule_type(&python_type(&p.ty), capsule_names);
            format!("{}: {}", p.name, param_type)
        })
        .collect();
    params.extend(optional.iter().map(|p| {
        let type_str = substitute_capsule_type(&python_type(&p.ty), capsule_names);
        let param_type = if type_str.ends_with("| None") {
            type_str
        } else {
            format!("{} | None", type_str)
        };
        format!("{}: {} = None", p.name, param_type)
    }));

    let streaming_key = (owner_type.map(str::to_string), method.name.clone());
    let return_type = match streaming_return_types.get(&streaming_key) {
        Some(item_type) => format!("AsyncIterator[{item_type}]"),
        None => substitute_capsule_type(&python_type(&method.return_type), capsule_names),
    };

    let indent = " ";
    let safe_name = python_safe_name(&method.name);
    let def_kw = if method.is_async { "async def" } else { "def" };
    // Each entry is `name: type[ = None]`; the text before the first `:` is
    // the parameter name.
    let has_builtin_param = params
        .iter()
        .any(|p| is_python_builtin_name(p.split(':').next().unwrap_or("").trim()));

    // Renders `prefix`, a newline, one templated line per parameter, `suffix`.
    let emit_params_wrapped = |prefix: &str, suffix: &str| -> String {
        let mut wrapped = format!("{prefix}\n");
        for param in &params {
            let name = param.split(':').next().unwrap_or("").trim();
            let template = if is_python_builtin_name(name) {
                "stub_param_method_wrapped_noqa.jinja"
            } else {
                "stub_param_method_wrapped.jinja"
            };
            wrapped.push_str(&crate::backends::pyo3::template_env::render(
                template,
                minijinja::context! { indent => indent, param => param },
            ));
        }
        wrapped.push_str(suffix);
        wrapped
    };

    let suffix = format!("{}) -> {}: ...", indent, return_type);

    if is_static {
        if params.is_empty() {
            return format!(
                "{}@staticmethod\n{}{} {}() -> {}: ...",
                indent, indent, def_kw, safe_name, return_type
            );
        }
        // Only the `def` line counts towards the length limit, not the
        // decorator line.
        let def_line = format!(
            "{}{} {}({}) -> {}: ...",
            indent,
            def_kw,
            safe_name,
            params.join(", "),
            return_type
        );
        if def_line.len() <= 100 && !has_builtin_param {
            format!("{}@staticmethod\n{}", indent, def_line)
        } else {
            let prefix = format!("{}@staticmethod\n{}{} {}(", indent, indent, def_kw, safe_name);
            emit_params_wrapped(&prefix, &suffix)
        }
    } else if params.is_empty() {
        format!("{}{} {}(self) -> {}: ...", indent, def_kw, safe_name, return_type)
    } else {
        let single_line = format!(
            "{}{} {}(self, {}) -> {}: ...",
            indent,
            def_kw,
            safe_name,
            params.join(", "),
            return_type
        );
        if single_line.len() <= 100 && !has_builtin_param {
            single_line
        } else {
            let prefix = format!("{}{} {}(\n{} self,", indent, def_kw, safe_name, indent);
            emit_params_wrapped(&prefix, &suffix)
        }
    }
}
/// Converts a variant name to its Python enum member name: the name is put in
/// SHOUTY_SNAKE_CASE and then passed through
/// `crate::core::keywords::python_str_enum_ident`.
fn to_python_enum_variant(name: &str) -> String {
    use crate::core::keywords::python_str_enum_ident;
    use heck::ToShoutySnakeCase as _;

    let shouty = name.to_shouty_snake_case();
    python_str_enum_ident(&shouty)
}
/// Generates the stub for an enum.
///
/// Enums with data variants are delegated to `gen_data_enum_typeddicts`. Other
/// enums become a class with one `NAME: EnumName = ...` attribute per variant
/// and an `__init__` accepting `int | str`; docstrings are added only when
/// `emit_docstrings` is set.
fn gen_enum_stub(enum_def: &EnumDef, emit_docstrings: bool) -> String {
    use crate::codegen::generators::enum_has_data_variants;

    let mut out: Vec<String> = Vec::new();

    if enum_has_data_variants(enum_def) {
        gen_data_enum_typeddicts(&mut out, enum_def);
        return out.join("\n");
    }

    out.push(format!("class {}:", enum_def.name));
    if emit_docstrings {
        out.extend(pyi_docstring(&enum_def.doc, " "));
    }
    for variant in &enum_def.variants {
        let member = to_python_enum_variant(&variant.name);
        out.push(format!(" {}: {} = ...", member, enum_def.name));
        if emit_docstrings {
            out.extend(pyi_docstring(&variant.doc, " "));
        }
    }
    out.push(" def __init__(self, value: int | str) -> None: ...".to_string());

    out.join("\n")
}
/// Appends the stub lines for an enum with data variants to `lines`.
///
/// Each variant becomes a `{Enum}{Variant}Variant(TypedDict)` class holding
/// the tag field as a `Literal` of the variant's wire value, followed by its
/// own fields, then a blank line. The enum itself is emitted last, as a class
/// exposing the tag field plus `__str__` and `__repr__`.
///
/// The tag field name is `enum_def.serde_tag`, defaulting to `type`.
fn gen_data_enum_typeddicts(lines: &mut Vec<String>, enum_def: &EnumDef) {
    let tag_field = enum_def.serde_tag.as_deref().unwrap_or("type");
    let rename_all = enum_def.serde_rename_all.as_deref();

    for variant in &enum_def.variants {
        let tag_value =
            crate::codegen::naming::wire_variant_value(&variant.name, variant.serde_rename.as_deref(), rename_all);
        lines.push(format!("class {}{}Variant(TypedDict):", enum_def.name, variant.name));
        // The tag line is always emitted, so the class body is never empty.
        lines.push(format!(" {}: Literal[\"{}\"]", tag_field, tag_value));
        for field in &variant.fields {
            let field_type = python_type(&field.ty);
            let field_type = if field.optional && !field_type.contains("| None") {
                format!("{} | None", field_type)
            } else {
                field_type
            };
            lines.push(format!(" {}: {}", python_safe_name(&field.name), field_type));
        }
        lines.push(String::new());
    }

    lines.push(format!("class {}:", enum_def.name));
    lines.push(format!(" {}: str", tag_field));
    lines.push(" def __str__(self) -> str: ... # noqa: PYI029".to_string());
    lines.push(" def __repr__(self) -> str: ... # noqa: PYI029".to_string());
}
/// Generates the stub for a module-level function.
///
/// Required parameters come first; optional ones follow as
/// `name: T | None = None`. Parameters named in `bridge_param_names` are
/// annotated `object`; all others use `python_type` with capsule names
/// replaced. If any parameter's type (directly or inside `Optional`) has an
/// entry in `options_field_bridges`, one extra optional keyword is appended
/// for the first such parameter. If `streaming_return_types` has an entry for
/// `(None, func.name)`, the return annotation is `AsyncIterator[item]`.
///
/// The signature stays on one line when it is at most 100 bytes long and no
/// parameter is named after a Python builtin; otherwise parameters are wrapped
/// one per line, using the `noqa` template for builtin-named ones.
fn gen_function_stub(
    func: &FunctionDef,
    bridge_param_names: &std::collections::HashSet<&str>,
    capsule_names: &std::collections::HashSet<&str>,
    options_field_bridges: &std::collections::HashMap<&str, (&str, Option<&str>)>,
    streaming_return_types: &std::collections::HashMap<(Option<String>, String), String>,
) -> String {
    let (required, optional): (Vec<_>, Vec<_>) = func.params.iter().partition(|p| !p.optional);
    let mut params: Vec<String> = required
        .iter()
        .map(|p| {
            let param_type = if bridge_param_names.contains(p.name.as_str()) {
                "object".to_string()
            } else {
                substitute_capsule_type(&python_type(&p.ty), capsule_names)
            };
            format!("{}: {}", p.name, param_type)
        })
        .collect();
    params.extend(optional.iter().map(|p| {
        let type_str = if bridge_param_names.contains(p.name.as_str()) {
            "object".to_string()
        } else {
            substitute_capsule_type(&python_type(&p.ty), capsule_names)
        };
        let param_type = if type_str.ends_with("| None") {
            type_str
        } else {
            format!("{} | None", type_str)
        };
        format!("{}: {} = None", p.name, param_type)
    }));

    // First parameter whose (possibly optional) named type has an
    // options-field bridge registered.
    let bridge_kwarg = func.params.iter().find_map(|p| {
        let type_name = match &p.ty {
            TypeRef::Named(n) => Some(n.as_str()),
            TypeRef::Optional(inner) => match inner.as_ref() {
                TypeRef::Named(n) => Some(n.as_str()),
                _ => None,
            },
            _ => None,
        }?;
        let (kwarg_name, type_alias) = options_field_bridges.get(type_name)?;
        Some((*kwarg_name, *type_alias))
    });
    if let Some((kwarg_name, type_alias)) = bridge_kwarg {
        let visitor_type = type_alias.unwrap_or("object");
        params.push(format!("{kwarg_name}: {visitor_type} | None = None"));
    }

    let streaming_key = (None::<String>, func.name.clone());
    let return_type = match streaming_return_types.get(&streaming_key) {
        Some(item_type) => format!("AsyncIterator[{item_type}]"),
        None => substitute_capsule_type(&python_type(&func.return_type), capsule_names),
    };

    let safe_name = python_safe_name(&func.name);
    let def_kw = if func.is_async { "async def" } else { "def" };
    // Each entry is `name: type[ = None]`; the text before the first `:` is
    // the parameter name.
    let has_builtin_param = params
        .iter()
        .any(|p| is_python_builtin_name(p.split(':').next().unwrap_or("").trim()));

    let single_line = format!(
        "{} {}({}) -> {}: ...",
        def_kw,
        safe_name,
        params.join(", "),
        return_type
    );
    if single_line.len() <= 100 && !has_builtin_param {
        return single_line;
    }

    let mut wrapped = format!("{} {}(\n", def_kw, safe_name);
    for param in &params {
        let name = param.split(':').next().unwrap_or("").trim();
        let template = if is_python_builtin_name(name) {
            "stub_param_wrapped_noqa.jinja"
        } else {
            "stub_param_wrapped.jinja"
        };
        wrapped.push_str(&crate::backends::pyo3::template_env::render(
            template,
            minijinja::context! { param => param, indent => " " },
        ));
    }
    wrapped.push_str(&crate::backends::pyo3::template_env::render(
        "stub_method_signature_end.jinja",
        minijinja::context! { return_type => &return_type },
    ));
    wrapped
}