use crate::codegen::generators::trait_bridge::{TraitBridgeGenerator, TraitBridgeSpec, host_function_path};
use crate::core::ir::{MethodDef, TypeRef};
use std::collections::HashMap;
/// Returns the symbol under which a generated host function is exported.
///
/// The mapping is currently the identity: the exported symbol is exactly
/// `fn_name`, returned as an owned `String`.
fn exported_pyfunction_symbol(fn_name: &str) -> String {
    String::from(fn_name)
}
/// Trait-bridge code generator for the PyO3 (Python) backend.
///
/// Holds the configuration needed to spell Rust type paths in the generated
/// bridge code.
#[derive(Debug, Clone)]
pub struct Pyo3BridgeGenerator {
    /// Path prefix used for `Named` types that have no entry in `type_paths`
    /// (rendered as `{core_import}::{name}`).
    pub core_import: String,
    /// Map from IR type name to the Rust path used for that type.
    pub type_paths: HashMap<String, String>,
    /// Error type name for this backend.
    // NOTE(review): not read by the code in this file — confirm where it is consumed.
    pub error_type: String,
}
impl TraitBridgeGenerator for Pyo3BridgeGenerator {
    /// Rust type used to hold the wrapped Python object.
    fn foreign_object_type(&self) -> &str {
        "Py<PyAny>"
    }

    /// `use` paths required by the generated bridge code.
    fn bridge_imports(&self) -> Vec<String> {
        vec!["pyo3::prelude::*".to_string(), "std::sync::Arc".to_string()]
    }

    /// Generates the body of a synchronous trait method that calls the Python
    /// method of the same name on `self.inner`.
    ///
    /// Unit-returning and value-returning methods are rendered from separate
    /// templates. For value-returning methods, the template receives the Rust
    /// extraction type, plus an `is_named` flag for `Named` return types.
    fn gen_sync_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String {
        let name = &method.name;
        let has_error = method.error_type.is_some();
        let py_args = self.sync_py_args(method);
        // `call_method0` takes no arguments; `call_method1` takes an argument tuple.
        let call = if py_args.is_empty() {
            format!("self.inner.bind(py).call_method0(\"{name}\")")
        } else {
            format!("self.inner.bind(py).call_method1(\"{name}\", ({py_args}))")
        };
        // `{{}}` becomes `{}` in the emitted code, so the generated `format!`
        // interpolates the plugin name and the error at runtime.
        let error_expr = spec.make_error(&format!(
            "format!(\"Plugin '{{}}' method '{name}' failed: {{}}\", self.cached_name, e)"
        ));
        if matches!(method.return_type, TypeRef::Unit) {
            crate::backends::pyo3::template_env::render(
                "trait_bridge/sync_method_unit_return.jinja",
                minijinja::context! {
                    method_name => name,
                    call => call,
                    has_error => has_error,
                    error_expr => error_expr,
                },
            )
        } else {
            let ext = self.extract_ty(&method.return_type);
            let is_named = self.is_named(&method.return_type);
            crate::backends::pyo3::template_env::render(
                "trait_bridge/sync_method_non_unit_return.jinja",
                minijinja::context! {
                    method_name => name,
                    call => call,
                    is_named => is_named,
                    extract_ty => ext,
                    has_error => has_error,
                    error_expr => error_expr,
                },
            )
        }
    }

    /// Generates the body of an asynchronous trait method that forwards to the
    /// Python method of the same name.
    ///
    /// First, per-parameter cloning code is rendered from
    /// `async_param_cloning.jinja`. Then one of three templates is chosen by
    /// return type:
    /// - `Named`: receives the formatted return type plus JSON-serialization
    ///   and deserialization error expressions.
    /// - `Unit`.
    /// - Any other type: receives the Rust extraction type.
    ///
    /// All three templates receive an error expression for a failed `spawn_blocking`.
    fn gen_async_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String {
        let name = &method.name;
        let params: Vec<minijinja::Value> = method
            .params
            .iter()
            .map(|p| {
                // Only bytes, paths and named types carry a type marker into the
                // cloning template; every other type is passed as "".
                let ty = match &p.ty {
                    TypeRef::Bytes => "Bytes",
                    TypeRef::Path => "Path",
                    TypeRef::Named(n) => n.as_str(),
                    _ => "",
                };
                minijinja::context! {
                    name => &p.name,
                    ty => ty,
                    ty_is_named => self.is_named(&p.ty),
                    is_ref => p.is_ref,
                }
            })
            .collect();
        let param_cloning = crate::backends::pyo3::template_env::render(
            "trait_bridge/async_param_cloning.jinja",
            minijinja::context! {
                params => params,
            },
        );
        let py_args = self.async_py_args(method);
        let call = if py_args.is_empty() {
            format!("obj.call_method0(\"{name}\")")
        } else {
            format!("obj.call_method1(\"{name}\", ({py_args}))")
        };
        // Unlike the sync body, these expressions refer to a local `cached_name`
        // rather than `self.cached_name`.
        let error_expr = spec.make_error(&format!(
            "format!(\"Plugin '{{}}' method '{name}' failed: {{}}\", cached_name, e)"
        ));
        let json_error_expr =
            spec.make_error("format!(\"Plugin '{}': JSON serialization failed: {}\", cached_name, e)");
        let deserialize_error_expr =
            spec.make_error("format!(\"Plugin '{}': deserialization failed: {}\", cached_name, e)");
        let spawn_error_expr = spec.make_error("format!(\"spawn_blocking failed: {}\", e)");
        if self.is_named(&method.return_type) {
            let return_type =
                crate::codegen::generators::trait_bridge::format_type_ref(&method.return_type, &spec.type_paths);
            crate::backends::pyo3::template_env::render(
                "trait_bridge/async_method_named_return.jinja",
                minijinja::context! {
                    method_name => name,
                    call => call,
                    param_cloning => param_cloning,
                    return_type => return_type,
                    error_expr => error_expr,
                    json_error_expr => json_error_expr,
                    deserialize_error_expr => deserialize_error_expr,
                    spawn_error_expr => spawn_error_expr,
                },
            )
        } else if matches!(method.return_type, TypeRef::Unit) {
            crate::backends::pyo3::template_env::render(
                "trait_bridge/async_method_unit_return.jinja",
                minijinja::context! {
                    method_name => name,
                    call => call,
                    param_cloning => param_cloning,
                    error_expr => error_expr,
                    spawn_error_expr => spawn_error_expr,
                },
            )
        } else {
            let ext = self.extract_ty(&method.return_type);
            crate::backends::pyo3::template_env::render(
                "trait_bridge/async_method_non_unit_return.jinja",
                minijinja::context! {
                    method_name => name,
                    call => call,
                    extract_ty => ext,
                    param_cloning => param_cloning,
                    error_expr => error_expr,
                    spawn_error_expr => spawn_error_expr,
                },
            )
        }
    }

    /// Renders the wrapper's constructor.
    ///
    /// The template receives the wrapper name and the trait's required methods.
    // NOTE(review): presumably used to check the Python object exposes them — confirm in the template.
    fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String {
        let wrapper = spec.wrapper_name();
        let required_methods = spec.required_methods();
        crate::backends::pyo3::template_env::render(
            "trait_bridge/constructor.jinja",
            minijinja::context! {
                wrapper => wrapper,
                required_methods => required_methods,
            },
        )
    }

    /// Renders the unregistration host function.
    ///
    /// Returns an empty string when no `unregister_fn` is configured.
    fn gen_unregistration_fn(&self, spec: &TraitBridgeSpec) -> String {
        let Some(unregister_fn) = spec.bridge_config.unregister_fn.as_deref() else {
            return String::new();
        };
        let host_path = host_function_path(spec, unregister_fn);
        let host_symbol = exported_pyfunction_symbol(unregister_fn);
        crate::backends::pyo3::template_env::render(
            "trait_bridge/unregistration_fn.jinja",
            minijinja::context! {
                unregister_fn => unregister_fn,
                host_symbol => host_symbol,
                host_path => host_path,
            },
        )
    }

    /// Renders the "clear registry" host function.
    ///
    /// Returns an empty string when no `clear_fn` is configured.
    fn gen_clear_fn(&self, spec: &TraitBridgeSpec) -> String {
        let Some(clear_fn) = spec.bridge_config.clear_fn.as_deref() else {
            return String::new();
        };
        let host_path = host_function_path(spec, clear_fn);
        let host_symbol = exported_pyfunction_symbol(clear_fn);
        crate::backends::pyo3::template_env::render(
            "trait_bridge/clear_fn.jinja",
            minijinja::context! {
                clear_fn => clear_fn,
                host_symbol => host_symbol,
                host_path => host_path,
            },
        )
    }

    /// Renders the registration host function.
    ///
    /// Returns an empty string unless both `register_fn` and `registry_getter`
    /// are configured. Required method names are passed to the template as a
    /// comma-separated list of quoted string literals, e.g. `"a", "b"`.
    fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String {
        let Some(register_fn) = spec.bridge_config.register_fn.as_deref() else {
            return String::new();
        };
        let Some(registry_getter) = spec.bridge_config.registry_getter.as_deref() else {
            return String::new();
        };
        let wrapper = spec.wrapper_name();
        let trait_path = spec.trait_path();
        let req_methods: Vec<&MethodDef> = spec.required_methods();
        let required_methods_str = req_methods
            .iter()
            .map(|m| format!("\"{}\"", m.name))
            .collect::<Vec<_>>()
            .join(", ");
        // Extra arguments are appended after the existing ones, hence the leading ", ".
        let register_extra_args = spec
            .bridge_config
            .register_extra_args
            .as_deref()
            .map(|a| format!(", {a}"))
            .unwrap_or_default();
        crate::backends::pyo3::template_env::render(
            "trait_bridge/registration_fn.jinja",
            minijinja::context! {
                register_fn => register_fn,
                wrapper => wrapper,
                trait_path => trait_path,
                registry_getter => registry_getter,
                register_extra_args => register_extra_args,
                has_required_methods => !req_methods.is_empty(),
                required_methods_str => required_methods_str,
            },
        )
    }
}
impl Pyo3BridgeGenerator {
    /// Maps an IR type to the Rust type that generated code extracts from a
    /// Python return value.
    ///
    /// How types are spelled:
    /// - `Named` types are looked up in `type_paths`, with `-` replaced by `_`
    ///   (presumably because the paths may contain hyphenated crate names). If
    ///   there is no entry, they fall back to `{core_import}::{name}`.
    /// - `Json` is extracted as `String`.
    /// - `Duration` is extracted as `u64`.
    fn extract_ty(&self, ty: &TypeRef) -> String {
        match ty {
            TypeRef::Primitive(p) => self.prim(p).to_string(),
            TypeRef::String | TypeRef::Path | TypeRef::Char => "String".into(),
            TypeRef::Bytes => "Vec<u8>".into(),
            TypeRef::Vec(inner) => format!("Vec<{}>", self.extract_ty(inner)),
            TypeRef::Optional(inner) => format!("Option<{}>", self.extract_ty(inner)),
            TypeRef::Named(name) => self
                .type_paths
                .get(name.as_str())
                .map(|p| p.replace('-', "_"))
                .unwrap_or_else(|| format!("{}::{}", self.core_import, name)),
            TypeRef::Unit => "()".into(),
            TypeRef::Map(k, v) => format!(
                "std::collections::HashMap<{}, {}>",
                self.extract_ty(k),
                self.extract_ty(v)
            ),
            TypeRef::Json => "String".into(),
            TypeRef::Duration => "u64".into(),
        }
    }

    /// Rust spelling of a primitive IR type.
    fn prim(&self, p: &crate::core::ir::PrimitiveType) -> &'static str {
        use crate::core::ir::PrimitiveType::*;
        match p {
            Bool => "bool",
            U8 => "u8",
            U16 => "u16",
            U32 => "u32",
            U64 => "u64",
            I8 => "i8",
            I16 => "i16",
            I32 => "i32",
            I64 => "i64",
            F32 => "f32",
            F64 => "f64",
            Usize => "usize",
            Isize => "isize",
        }
    }

    /// Renders the argument tuple contents for a synchronous Python call.
    ///
    /// Only reference parameters are converted:
    /// - Byte slices become `PyBytes`.
    /// - Paths become `&str`.
    /// - Named types are serialized to JSON strings.
    ///
    /// All other parameters are passed by name.
    fn sync_py_args(&self, method: &MethodDef) -> String {
        let args: Vec<String> = method
            .params
            .iter()
            .map(|p| match (&p.ty, p.is_ref) {
                (TypeRef::Bytes, true) => format!("pyo3::types::PyBytes::new(py, {})", p.name),
                // NOTE(review): the generated code passes "" for a non-UTF-8 path
                // (`to_str()` returns `None`) — confirm this silent fallback is intended.
                (TypeRef::Path, true) => format!("{}.to_str().unwrap_or_default()", p.name),
                // NOTE(review): the generated code passes "" if serialization fails,
                // so the Python side cannot tell an error from an empty payload.
                (TypeRef::Named(_), true) => {
                    format!("serde_json::to_string({}).unwrap_or_default()", p.name)
                }
                _ => p.name.clone(),
            })
            .collect();
        Self::py_tuple_args(&args)
    }

    /// Renders the argument tuple contents for an asynchronous Python call.
    ///
    /// Reference parameters are read from locals named after the parameter:
    /// - Byte slices use `&{name}`.
    /// - Paths use `{name}_str`.
    /// - Named types use `{name}_json`.
    ///
    /// These locals are presumably introduced by the param-cloning template —
    /// confirm there.
    fn async_py_args(&self, method: &MethodDef) -> String {
        let args: Vec<String> = method
            .params
            .iter()
            .map(|p| match (&p.ty, p.is_ref) {
                (TypeRef::Bytes, true) => format!("pyo3::types::PyBytes::new(py, &{})", p.name),
                (TypeRef::Path, true) => format!("{}_str.as_str()", p.name),
                (TypeRef::Named(_), true) => format!("{}_json.as_str()", p.name),
                _ => p.name.clone(),
            })
            .collect();
        Self::py_tuple_args(&args)
    }

    /// Joins rendered arguments into the contents of a Rust tuple expression.
    ///
    /// A one-element tuple needs a trailing comma (`(x,)`); without it, `(x)`
    /// is just a parenthesized expression rather than a tuple. An empty slice
    /// yields an empty string, which callers use to select `call_method0`.
    fn py_tuple_args(args: &[String]) -> String {
        match args {
            [only] => format!("{only},"),
            _ => args.join(", "),
        }
    }

    /// Whether `ty` is a user-defined (`Named`) IR type.
    fn is_named(&self, ty: &TypeRef) -> bool {
        matches!(ty, TypeRef::Named(_))
    }
}