mod classes;
mod enums;
mod functions;
mod protocol;
use crate::backends::pyo3::gen_bindings::enums::sanitize_python_doc;
use crate::backends::pyo3::type_map::python_type;
use crate::core::config::{AdapterPattern, ResolvedCrateConfig, TraitBridgeConfig};
use crate::core::hash::{self, CommentStyle};
use crate::core::ir::{ApiSurface, TypeRef};
/// Lookup for trait bridges bound via an options field, keyed by the options type name.
///
/// Each value is `(param_name, type_alias, trait_name)`. `trait_name` is only `Some` when a
/// type with the bridge's trait name exists in the API surface (see how `gen_stubs` builds it).
type OptionsFieldBridges<'a> = std::collections::HashMap<&'a str, (&'a str, Option<&'a str>, Option<&'a str>)>;
use classes::{gen_opaque_type_stub, gen_type_stub};
use enums::gen_enum_stub;
use functions::gen_function_stub;
use protocol::gen_visitor_protocol_stub;
/// Renders the first paragraph of a Rust doc comment as a one-line `.pyi` docstring.
///
/// The paragraph starts at the first non-blank line and ends at the next blank line. A line
/// counts as blank when it is empty after trimming whitespace and any leading `///` marker,
/// so `///`-only separators, whitespace-only lines and `\r\n` line endings all terminate it.
/// The paragraph's lines are trimmed, joined with single spaces, passed through
/// `sanitize_python_doc`, and escaped so the result is a valid triple-quoted Python literal.
///
/// `indent` is prepended verbatim. Returns `None` when no non-blank text remains.
pub(super) fn pyi_docstring(doc: &str, indent: &str) -> Option<String> {
    let joined = doc
        .lines()
        .map(|line| line.trim().trim_start_matches("///").trim())
        .skip_while(|line| line.is_empty())
        .take_while(|line| !line.is_empty())
        .collect::<Vec<_>>()
        .join(" ");
    if joined.is_empty() {
        return None;
    }
    let sanitized = sanitize_python_doc(&joined);
    // Double backslashes first so the escapes introduced below are not doubled themselves.
    let mut body = sanitized.replace('\\', "\\\\");
    // A trailing `"` would fuse with the closing `"""` (e.g. `"""say "hi""""` ends the literal
    // after `hi` and leaves a stray quote). Set it aside and re-append it escaped.
    let trailing_quote = body.ends_with('"');
    if trailing_quote {
        body.pop();
    }
    // Any embedded `"""` would otherwise terminate the literal early.
    let mut body = body.replace("\"\"\"", "\\\"\\\"\\\"");
    if trailing_quote {
        body.push_str("\\\"");
    }
    Some(format!("{indent}\"\"\"{body}\"\"\""))
}
/// Returns the Python identifier form of `name` as produced by
/// `crate::core::keywords::python_ident` (presumably keyword-escaped; see that helper).
pub(super) fn python_safe_name(name: &str) -> String {
    use crate::core::keywords::python_ident;
    python_ident(name)
}
/// Reports whether `name` is one of a fixed set of Python builtin names.
///
/// The set covers common builtin functions and types plus `license`, `credits` and
/// `copyright`. Comparison is exact and case-sensitive.
pub fn is_python_builtin_name(name: &str) -> bool {
    matches!(
        name,
        "id" | "type"
            | "input"
            | "hash"
            | "format"
            | "dir"
            | "help"
            | "list"
            | "map"
            | "filter"
            | "range"
            | "set"
            | "dict"
            | "str"
            | "int"
            | "float"
            | "bool"
            | "bytes"
            | "tuple"
            | "len"
            | "max"
            | "min"
            | "sum"
            | "abs"
            | "all"
            | "any"
            | "print"
            | "open"
            | "next"
            | "iter"
            | "vars"
            | "zip"
            | "object"
            | "property"
            | "super"
            | "staticmethod"
            | "classmethod"
            | "compile"
            | "exec"
            | "eval"
            | "license"
            | "credits"
            | "copyright"
    )
}
/// Translates a Rust type, spelled exactly as in a constructor signature, into a Python
/// annotation.
///
/// Known string, byte, boolean, integer, float and unit spellings map to `str`, `bytes`,
/// `bool`, `int`, `float` and `None` respectively. Anything else becomes `Any`, because the
/// lookup is an exact string comparison.
pub(super) fn constructor_rust_type_to_python(rust_type: &str) -> &str {
    // Python annotation paired with every Rust spelling that maps to it.
    const PYTHON_BY_RUST: &[(&str, &[&str])] = &[
        ("str", &["String", "&str", "&'static str", "std::string::String"]),
        ("bytes", &["bytes::Bytes", "Vec<u8>", "&[u8]"]),
        ("bool", &["bool"]),
        (
            "int",
            &[
                "i8", "i16", "i32", "i64", "i128", "isize", "u8", "u16", "u32", "u64", "u128", "usize",
            ],
        ),
        ("float", &["f32", "f64"]),
        ("None", &["()"]),
    ];
    for &(python, spellings) in PYTHON_BY_RUST {
        if spellings.iter().any(|&spelling| spelling == rust_type) {
            return python;
        }
    }
    "Any"
}
/// Python annotation for a constructor parameter of type `ty`.
///
/// Parameters whose type is a named enum of `api`, directly or wrapped in a single
/// `Optional`, get special annotations:
/// - enums with data variants: `Name` / `Name | None`;
/// - other enums, which additionally accept a plain string: `Name | str` / `Name | str | None`.
///
/// Every other type falls back to `python_type(ty)`.
pub(super) fn constructor_param_type(ty: &TypeRef, api: &ApiSurface) -> String {
    use crate::codegen::generators::enum_has_data_variants;
    // Peel at most one `Optional` layer to reach a directly named type.
    let (name, optional) = match ty {
        TypeRef::Named(name) => (name, false),
        TypeRef::Optional(inner) => match inner.as_ref() {
            TypeRef::Named(name) => (name, true),
            _ => return python_type(ty),
        },
        _ => return python_type(ty),
    };
    // Scan the enum list directly instead of cloning every enum name into two sets per call.
    // A data-variant enum takes precedence if several enums share the same name.
    let is_data_enum = api
        .enums
        .iter()
        .any(|e| e.name == *name && enum_has_data_variants(e));
    let is_enum = is_data_enum || api.enums.iter().any(|e| e.name == *name);
    if !is_enum {
        return python_type(ty);
    }
    let mut annotation = name.clone();
    if !is_data_enum {
        annotation.push_str(" | str");
    }
    if optional {
        annotation.push_str(" | None");
    }
    annotation
}
/// Generates the complete `.pyi` stub module text for `api`.
///
/// Output, in order:
/// 1. the generated-file header from `hash::header(CommentStyle::Hash)`, then a blank line;
/// 2. when used, a `from typing import ...` line (plus a blank line). It names only those of
///    `Any`, `AsyncIterator`, `Literal`, `Protocol`, `TypeAlias`, `TypedDict` that occur as
///    whole words anywhere in the body;
/// 3. the body, made of:
///    - protocol stubs for `OptionsField` trait bridges;
///    - non-opaque type stubs;
///    - opaque type stubs (capsule types excluded);
///    - enum stubs;
///    - function stubs, skipping names in `exclude_functions`;
///    - service entrypoint functions;
///    - trait-bridge register/unregister/clear functions.
///
/// Types flagged `is_trait` or `binding_excluded` are skipped. Lines are joined with `\n`.
/// NOTE(review): no trailing newline is appended, so the file ends with a newline only when
/// the last body line happens to be empty.
pub fn gen_stubs(
    api: &ApiSurface,
    trait_bridges: &[TraitBridgeConfig],
    config: &ResolvedCrateConfig,
    exclude_functions: &ahash::AHashSet<String>,
) -> String {
    let header = hash::header(CommentStyle::Hash);
    let mut header_lines: Vec<String> = header.lines().map(str::to_string).collect();
    header_lines.push("".to_string());
    // Every parameter name that some trait bridge binds to; handed to the function stubs.
    let bridge_param_names: std::collections::HashSet<&str> =
        trait_bridges.iter().filter_map(|b| b.param_name.as_deref()).collect();
    // Options type -> (param name, type alias, trait name). Only `OptionsField` bridges that
    // declare both an options type and a param name are included. The trait name is kept
    // only when a type of that name exists in `api.types`.
    let options_field_bridges: OptionsFieldBridges<'_> = trait_bridges
        .iter()
        .filter(|b| b.bind_via == crate::core::config::BridgeBinding::OptionsField)
        .filter_map(|b| {
            let options_type = b.options_type.as_deref()?;
            let param_name = b.param_name.as_deref()?;
            let type_alias = b.type_alias.as_deref();
            let trait_name = if api.types.iter().any(|t| t.name == b.trait_name) {
                Some(b.trait_name.as_str())
            } else {
                None
            };
            Some((options_type, (param_name, type_alias, trait_name)))
        })
        .collect();
    // (owner type, adapter name) -> streamed item type for streaming adapters.
    // The item type defaults to "Any" when the adapter does not declare one.
    let streaming_return_types: std::collections::HashMap<(Option<String>, String), String> = config
        .adapters
        .iter()
        .filter(|a| matches!(a.pattern, AdapterPattern::Streaming))
        .map(|a| {
            let item = a.item_type.as_deref().unwrap_or("Any").to_string();
            ((a.owner_type.clone(), a.name.clone()), item)
        })
        .collect();
    // Names of capsule types from the Python config; empty when there is no Python section.
    let capsule_names: std::collections::HashSet<&str> = config
        .python
        .as_ref()
        .map(|p| p.capsule_types.keys().map(String::as_str).collect())
        .unwrap_or_default();
    // Docstrings are emitted only when `python.stubs.emit_docstrings` is explicitly enabled.
    let emit_docstrings = config
        .python
        .as_ref()
        .and_then(|p| p.stubs.as_ref())
        .is_some_and(|s| s.emit_docstrings);
    let (opaque, non_opaque): (Vec<_>, Vec<_>) = api
        .types
        .iter()
        .filter(|typ| !typ.is_trait && !typ.binding_excluded)
        .partition(|typ| typ.is_opaque);
    let mut body_lines: Vec<String> = Vec::new();
    // Protocol stubs for options-field bridges, each followed by a blank line.
    for bridge in trait_bridges {
        if bridge.bind_via != crate::core::config::BridgeBinding::OptionsField {
            continue;
        }
        if let Some(stub) = gen_visitor_protocol_stub(bridge, api, &capsule_names, emit_docstrings) {
            body_lines.push(stub);
            body_lines.push("".to_string());
        }
    }
    for typ in &non_opaque {
        body_lines.push(gen_type_stub(
            typ,
            api,
            config,
            &capsule_names,
            &options_field_bridges,
            emit_docstrings,
            &streaming_return_types,
        ));
        body_lines.push("".to_string());
    }
    // Opaque types that are capsule types get no stub of their own here.
    let opaque_non_capsule: Vec<_> = opaque
        .iter()
        .filter(|typ| !capsule_names.contains(typ.name.as_str()))
        .collect();
    // Opaque stubs are emitted back to back, with a single blank line after the group.
    if !opaque_non_capsule.is_empty() {
        for typ in &opaque_non_capsule {
            let ctor = config.client_constructors.get(&typ.name);
            body_lines.push(gen_opaque_type_stub(typ, &capsule_names, &streaming_return_types, ctor));
        }
        body_lines.push("".to_string());
    }
    for enum_def in &api.enums {
        body_lines.push(gen_enum_stub(enum_def, emit_docstrings));
        body_lines.push("".to_string());
    }
    for func in api.functions.iter().filter(|f| !exclude_functions.contains(&f.name)) {
        body_lines.push(gen_function_stub(
            func,
            &bridge_param_names,
            &capsule_names,
            &options_field_bridges,
            &streaming_return_types,
        ));
    }
    // One module-level function per service entrypoint, named `<service_snake>_<method>`.
    // It takes `registrations: list[Any]` and returns `None` for unit results, else `Any`.
    {
        use heck::ToSnakeCase as _;
        for service in &api.services {
            let service_snake = service.name.to_snake_case();
            for ep in &service.entrypoints {
                let func_name = format!("{service_snake}_{}", ep.method);
                let return_annot = match &ep.return_type {
                    TypeRef::Unit => "None".to_string(),
                    _ => "Any".to_string(),
                };
                body_lines.push(format!(
                    "def {func_name}(registrations: list[Any]) -> {return_annot}: ..."
                ));
            }
        }
    }
    // Optional registry helpers declared per trait bridge.
    for bridge in trait_bridges {
        if let Some(register_fn) = bridge.register_fn.as_deref() {
            body_lines.push(format!("def {register_fn}(backend: object) -> None: ..."));
        }
        if let Some(unregister_fn) = bridge.unregister_fn.as_deref() {
            body_lines.push(format!("def {unregister_fn}(name: str) -> None: ..."));
        }
        if let Some(clear_fn) = bridge.clear_fn.as_deref() {
            body_lines.push(format!("def {clear_fn}() -> None: ..."));
        }
    }
    // Import only the `typing` names the body actually mentions. The scan is textual, so a
    // word such as `Any` inside an emitted docstring also triggers the import (harmless).
    let body_joined = body_lines.join("\n");
    let used_typing: Vec<&str> = ["Any", "AsyncIterator", "Literal", "Protocol", "TypeAlias", "TypedDict"]
        .iter()
        .copied()
        .filter(|name| contains_word(&body_joined, name))
        .collect();
    let mut lines = header_lines;
    if !used_typing.is_empty() {
        lines.push(format!("from typing import {}", used_typing.join(", ")));
        lines.push("".to_string());
    }
    lines.extend(body_lines);
    lines.join("\n")
}
/// Returns `true` if `word` occurs in `text` as a whole identifier.
///
/// A match counts only when it is not immediately preceded or followed by an ASCII
/// alphanumeric byte or `_`. Every candidate position is examined, including ones that
/// overlap a rejected match. An empty `word` never matches.
fn contains_word(text: &str, word: &str) -> bool {
    // An empty needle "matches" at every offset. Stepping past the last one would slice
    // beyond the end of `text`, so treat it as no match.
    let Some(first) = word.chars().next() else {
        return false;
    };
    let bytes = text.as_bytes();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    // Step over a rejected match by one whole char so the next slice starts on a UTF-8
    // boundary even when `word` begins with a multi-byte character.
    let step = first.len_utf8();
    let mut start = 0;
    while let Some(idx) = text[start..].find(word) {
        let pos = start + idx;
        let end = pos + word.len();
        let before_ok = pos == 0 || !is_ident(bytes[pos - 1]);
        let after_ok = end == bytes.len() || !is_ident(bytes[end]);
        if before_ok && after_ok {
            return true;
        }
        start = pos + step;
    }
    false
}
/// Rewrites a Python type annotation so that every capsule type becomes `Any`.
///
/// For each capsule name `X`, a whole-word `X | None` collapses to plain `Any`, which already
/// admits `None`. Every remaining whole-word `X` is then replaced by `Any`, which also turns
/// `list[X]` into `list[Any]`. Whole-word matching means a capsule named `Client` leaves
/// `HttpClient` or `ClientConfig` untouched.
pub(super) fn substitute_capsule_type(type_str: &str, capsule_names: &std::collections::HashSet<&str>) -> String {
    let mut result = type_str.to_string();
    for name in capsule_names {
        result = replace_whole_word(&result, &format!("{name} | None"), "Any");
        result = replace_whole_word(&result, name, "Any");
    }
    result
}

/// Replaces each occurrence of `word` in `text` that is not adjacent to an ASCII
/// alphanumeric byte or `_` with `replacement`. An empty `word` leaves `text` unchanged.
fn replace_whole_word(text: &str, word: &str, replacement: &str) -> String {
    let Some(first) = word.chars().next() else {
        return text.to_string();
    };
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    let bytes = text.as_bytes();
    // Advance past a rejected match by a whole char to stay on UTF-8 boundaries.
    let step = first.len_utf8();
    let mut out = String::with_capacity(text.len());
    // `copied`: prefix of `text` already written to `out`; `search`: where to look next.
    let mut copied = 0usize;
    let mut search = 0usize;
    while let Some(idx) = text[search..].find(word) {
        let pos = search + idx;
        let end = pos + word.len();
        let before_ok = pos == 0 || !is_ident(bytes[pos - 1]);
        let after_ok = end == bytes.len() || !is_ident(bytes[end]);
        if before_ok && after_ok {
            out.push_str(&text[copied..pos]);
            out.push_str(replacement);
            copied = end;
            search = end;
        } else {
            search = pos + step;
        }
    }
    out.push_str(&text[copied..]);
    out
}