use crate::generators::binding_helpers::{
gen_async_body, gen_call_args, gen_call_args_with_let_bindings, gen_named_let_bindings, gen_serde_let_bindings,
gen_unimplemented_body, has_named_params,
};
use crate::generators::{AdapterBodies, AsyncPattern, RustBindingConfig};
use crate::shared::{function_params, function_sig_defaults};
use crate::type_mapper::TypeMapper;
use ahash::{AHashMap, AHashSet};
use alef_core::ir::{ApiSurface, FunctionDef, TypeRef};
use std::fmt::Write;
pub fn gen_function(
func: &FunctionDef,
mapper: &dyn TypeMapper,
cfg: &RustBindingConfig,
adapter_bodies: &AdapterBodies,
opaque_types: &AHashSet<String>,
) -> String {
let map_fn = |ty: &alef_core::ir::TypeRef| mapper.map_type(ty);
let params = function_params(&func.params, &map_fn);
let return_type = mapper.map_type(&func.return_type);
let ret = mapper.wrap_return(&return_type, func.error_type.is_some());
let use_let_bindings = has_named_params(&func.params, opaque_types);
let call_args = if use_let_bindings {
gen_call_args_with_let_bindings(&func.params, opaque_types)
} else {
gen_call_args(&func.params, opaque_types)
};
let let_bindings = if use_let_bindings {
gen_named_let_bindings(&func.params, opaque_types)
} else {
String::new()
};
let core_import = cfg.core_import;
let core_fn_path = {
let path = func.rust_path.replace('-', "_");
if path.starts_with(core_import) {
path
} else {
format!("{core_import}::{}", func.name)
}
};
let can_delegate = crate::shared::can_auto_delegate_function(func, opaque_types);
let serde_err_conv = match cfg.async_pattern {
AsyncPattern::Pyo3FutureIntoPy => ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))",
AsyncPattern::NapiNativeAsync => ".map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string()))",
AsyncPattern::WasmNativeAsync => ".map_err(|e| JsValue::from_str(&e.to_string()))",
_ => ".map_err(|e| e.to_string())",
};
let body = if !can_delegate {
if let Some(adapter_body) = adapter_bodies.get(&func.name) {
adapter_body.clone()
} else if cfg.has_serde && use_let_bindings && func.error_type.is_some() {
let is_async_pyo3 = func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy;
let (serde_indent, serde_err_async) = if is_async_pyo3 {
(
" ",
".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))",
)
} else {
(" ", serde_err_conv)
};
let serde_bindings =
gen_serde_let_bindings(&func.params, opaque_types, core_import, serde_err_async, serde_indent);
let core_call = format!("{core_fn_path}({call_args})");
let returns_ref = func.returns_ref;
let wrap_return = |expr: &str| -> String {
match &func.return_type {
TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
if returns_ref {
format!("{name} {{ inner: Arc::new({expr}.clone()) }}")
} else {
format!("{name} {{ inner: Arc::new({expr}) }}")
}
}
TypeRef::Named(_name) => {
if returns_ref {
format!("{expr}.clone().into()")
} else {
format!("{expr}.into()")
}
}
TypeRef::String | TypeRef::Bytes => format!("{expr}.into()"),
TypeRef::Path => format!("{expr}.to_string_lossy().to_string()"),
TypeRef::Json => format!("{expr}.to_string()"),
_ => expr.to_string(),
}
};
if is_async_pyo3 {
let is_unit = matches!(func.return_type, TypeRef::Unit);
let wrapped = wrap_return("result");
let core_await = format!(
"{core_call}.await\n .map_err(|e| PyErr::new::<PyRuntimeError, _>(e.to_string()))?"
);
let inner_body = if is_unit {
format!("{serde_bindings}{core_await};\n Ok(())")
} else {
format!("{serde_bindings}let result = {core_await};\n Ok({wrapped})")
};
format!("pyo3_async_runtimes::tokio::future_into_py(py, async move {{\n{inner_body}\n }})")
} else if matches!(func.return_type, TypeRef::Unit) {
format!("{serde_bindings}{core_call}{serde_err_conv}?;\n Ok(())")
} else {
let wrapped = wrap_return("val");
if wrapped == "val" {
format!("{serde_bindings}{core_call}{serde_err_conv}")
} else {
format!("{serde_bindings}{core_call}.map(|val| {wrapped}){serde_err_conv}")
}
}
} else {
gen_unimplemented_body(
&func.return_type,
&func.name,
func.error_type.is_some(),
cfg,
&func.params,
)
}
} else if func.is_async {
let core_call = format!("{core_fn_path}({call_args})");
let return_wrap = match &func.return_type {
TypeRef::Named(n) if opaque_types.contains(n.as_str()) => {
format!("{n} {{ inner: Arc::new(result) }}")
}
TypeRef::Named(_) => {
format!("{return_type}::from(result)")
}
TypeRef::Vec(inner) => match inner.as_ref() {
TypeRef::Named(n) if opaque_types.contains(n.as_str()) => {
format!("result.into_iter().map(|v| {n} {{ inner: Arc::new(v) }}).collect::<Vec<_>>()")
}
TypeRef::Named(_) => {
let inner_mapped = mapper.map_type(inner);
format!("result.into_iter().map({inner_mapped}::from).collect::<Vec<_>>()")
}
_ => "result".to_string(),
},
TypeRef::Unit => "result".to_string(),
_ => super::binding_helpers::wrap_return(
"result",
&func.return_type,
"",
opaque_types,
false,
func.returns_ref,
false,
),
};
let async_body = gen_async_body(
&core_call,
cfg,
func.error_type.is_some(),
&return_wrap,
false,
"",
matches!(func.return_type, TypeRef::Unit),
);
format!("{let_bindings}{async_body}")
} else {
let core_call = format!("{core_fn_path}({call_args})");
let returns_ref = func.returns_ref;
let wrap_return = |expr: &str| -> String {
match &func.return_type {
TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
if returns_ref {
format!("{name} {{ inner: Arc::new({expr}.clone()) }}")
} else {
format!("{name} {{ inner: Arc::new({expr}) }}")
}
}
TypeRef::Named(_name) => {
if returns_ref {
format!("{expr}.clone().into()")
} else {
format!("{expr}.into()")
}
}
TypeRef::String | TypeRef::Bytes => format!("{expr}.into()"),
TypeRef::Path => format!("{expr}.to_string_lossy().to_string()"),
TypeRef::Json => format!("{expr}.to_string()"),
TypeRef::Optional(inner) => match inner.as_ref() {
TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
if returns_ref {
format!("{expr}.map(|v| {name} {{ inner: Arc::new(v.clone()) }})")
} else {
format!("{expr}.map(|v| {name} {{ inner: Arc::new(v) }})")
}
}
TypeRef::Named(_) => {
if returns_ref {
format!("{expr}.map(|v| v.clone().into())")
} else {
format!("{expr}.map(Into::into)")
}
}
TypeRef::String | TypeRef::Bytes | TypeRef::Path => {
format!("{expr}.map(Into::into)")
}
_ => expr.to_string(),
},
TypeRef::Vec(inner) => match inner.as_ref() {
TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
if returns_ref {
format!("{expr}.into_iter().map(|v| {name} {{ inner: Arc::new(v.clone()) }}).collect()")
} else {
format!("{expr}.into_iter().map(|v| {name} {{ inner: Arc::new(v) }}).collect()")
}
}
TypeRef::Named(_) => {
if returns_ref {
format!("{expr}.into_iter().map(|v| v.clone().into()).collect()")
} else {
format!("{expr}.into_iter().map(Into::into).collect()")
}
}
TypeRef::String | TypeRef::Bytes | TypeRef::Path => {
format!("{expr}.into_iter().map(Into::into).collect()")
}
_ => expr.to_string(),
},
_ => expr.to_string(),
}
};
if func.error_type.is_some() {
let err_conv = match cfg.async_pattern {
AsyncPattern::Pyo3FutureIntoPy => {
".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))"
}
AsyncPattern::NapiNativeAsync => {
".map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string()))"
}
AsyncPattern::WasmNativeAsync => ".map_err(|e| JsValue::from_str(&e.to_string()))",
_ => ".map_err(|e| e.to_string())",
};
let wrapped = wrap_return("val");
if wrapped == "val" {
format!("{core_call}{err_conv}")
} else {
format!("{core_call}.map(|val| {wrapped}){err_conv}")
}
} else {
wrap_return(&core_call)
}
};
let body = if !let_bindings.is_empty() && can_delegate && !func.is_async {
format!("{let_bindings}{body}")
} else {
body
};
let async_kw = if func.is_async { "async " } else { "" };
let func_needs_py = func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy;
let ret = if func_needs_py {
"PyResult<Bound<'py, PyAny>>".to_string()
} else {
ret
};
let func_lifetime = if func_needs_py { "<'py>" } else { "" };
let (func_sig, _params_formatted) = if params.len() > 100 {
let wrapped_params = func
.params
.iter()
.map(|p| {
let ty = if p.optional {
format!("Option<{}>", mapper.map_type(&p.ty))
} else {
mapper.map_type(&p.ty)
};
format!("{}: {}", p.name, ty)
})
.collect::<Vec<_>>()
.join(",\n ");
if func_needs_py {
(
format!(
"pub fn {}{func_lifetime}(py: Python<'py>,\n {}\n) -> {ret}",
func.name,
wrapped_params,
ret = ret
),
"",
)
} else {
(
format!(
"pub {async_kw}fn {}(\n {}\n) -> {ret}",
func.name,
wrapped_params,
ret = ret
),
"",
)
}
} else if func_needs_py {
(
format!(
"pub fn {}{func_lifetime}(py: Python<'py>, {params}) -> {ret}",
func.name
),
"",
)
} else {
(format!("pub {async_kw}fn {}({params}) -> {ret}", func.name), "")
};
let mut out = String::with_capacity(1024);
let total_params = func.params.len() + if func_needs_py { 1 } else { 0 };
if total_params > 7 {
writeln!(out, "#[allow(clippy::too_many_arguments)]").ok();
}
if func.error_type.is_some() {
writeln!(out, "#[allow(clippy::missing_errors_doc)]").ok();
}
let attr_inner = cfg
.function_attr
.trim_start_matches('#')
.trim_start_matches('[')
.trim_end_matches(']');
writeln!(out, "#[{attr_inner}]").ok();
if cfg.needs_signature {
let sig = function_sig_defaults(&func.params);
writeln!(out, "{}{}{}", cfg.signature_prefix, sig, cfg.signature_suffix).ok();
}
write!(out, "{} {{\n {body}\n}}", func_sig,).ok();
out
}
pub fn collect_trait_imports(api: &ApiSurface) -> Vec<String> {
let mut traits: AHashSet<String> = AHashSet::new();
for typ in &api.types {
for method in &typ.methods {
if let Some(ref trait_path) = method.trait_source {
traits.insert(trait_path.clone());
}
}
}
let mut sorted: Vec<String> = traits.into_iter().collect();
sorted.sort();
sorted
}
/// Heuristically reports whether the API contains shared methods whose trait is unknown.
///
/// Methods on non-trait types are grouped by name. Returns `true` if some method name
/// occurs on at least `SHARED_METHOD_THRESHOLD` (3) such types and none of those
/// occurrences carries a `trait_source`.
pub fn has_unresolved_trait_methods(api: &ApiSurface) -> bool {
    // Minimum number of occurrences before a method name is considered "shared".
    const SHARED_METHOD_THRESHOLD: usize = 3;

    // method name -> (total occurrences, occurrences with a known `trait_source`)
    let mut method_counts: AHashMap<&str, (usize, usize)> = AHashMap::new();
    for typ in &api.types {
        if typ.is_trait {
            continue;
        }
        for method in &typ.methods {
            let entry = method_counts.entry(&method.name).or_insert((0, 0));
            entry.0 += 1;
            if method.trait_source.is_some() {
                entry.1 += 1;
            }
        }
    }
    method_counts
        .values()
        .any(|&(total, with_source)| total >= SHARED_METHOD_THRESHOLD && with_source == 0)
}