1use crate::generators::binding_helpers::{
2 gen_async_body, gen_call_args, gen_call_args_cfg, gen_call_args_with_let_bindings, gen_named_let_bindings,
3 gen_named_let_bindings_by_ref, gen_serde_let_bindings, gen_unimplemented_body, has_named_params,
4};
5use crate::generators::{AdapterBodies, AsyncPattern, RustBindingConfig};
6use crate::shared::{function_params, function_sig_defaults};
7use crate::type_mapper::TypeMapper;
8use ahash::{AHashMap, AHashSet};
9use alef_core::ir::{ApiSurface, FunctionDef, TypeRef};
10
/// Heuristic check used before wrapping an opaque value: returns `true` when
/// `expr` (ignoring surrounding whitespace) is exactly `self.inner`, or starts
/// with `self.inner.as_ref()` or `self.inner.clone()`. Callers use a `true`
/// result to place `expr` into `inner` without another `Arc::new(..)`.
fn expr_is_already_arc(expr: &str) -> bool {
    const INNER_PREFIXES: [&str; 2] = ["self.inner.as_ref()", "self.inner.clone()"];
    let expr = expr.trim();
    expr == "self.inner" || INNER_PREFIXES.iter().any(|prefix| expr.starts_with(prefix))
}
21
/// Builds the Rust expression that stores `val` in an opaque handle's `inner`
/// field: `Arc::new(std::sync::Mutex::new(val))` when the type `name` is listed
/// in `mutex_types`, and `Arc::new(val)` otherwise.
fn arc_wrap_expr(val: &str, name: &str, mutex_types: &AHashSet<String>) -> String {
    let payload = match mutex_types.contains(name) {
        true => format!("std::sync::Mutex::new({val})"),
        false => val.to_string(),
    };
    format!("Arc::new({payload})")
}
31
32pub fn gen_function(
34 func: &FunctionDef,
35 mapper: &dyn TypeMapper,
36 cfg: &RustBindingConfig,
37 adapter_bodies: &AdapterBodies,
38 opaque_types: &AHashSet<String>,
39) -> String {
40 gen_function_with_mutex(func, mapper, cfg, adapter_bodies, opaque_types, &AHashSet::new())
41}
42
43pub fn gen_function_with_mutex(
47 func: &FunctionDef,
48 mapper: &dyn TypeMapper,
49 cfg: &RustBindingConfig,
50 adapter_bodies: &AdapterBodies,
51 opaque_types: &AHashSet<String>,
52 mutex_types: &AHashSet<String>,
53) -> String {
54 let map_fn = |ty: &alef_core::ir::TypeRef| mapper.map_type(ty);
55 let params = if cfg.named_non_opaque_params_by_ref {
62 let mut seen_optional = false;
63 func.params
64 .iter()
65 .enumerate()
66 .map(|(idx, p)| {
67 if p.optional {
68 seen_optional = true;
69 }
70 let promoted = seen_optional && !p.optional && crate::shared::is_promoted_optional(&func.params, idx);
71 let ty = match &p.ty {
72 TypeRef::Named(n) if !opaque_types.contains(n.as_str()) => {
73 if p.optional || seen_optional || promoted {
74 format!("Nullable<&{}>", map_fn(&p.ty))
75 } else {
76 format!("&{}", map_fn(&p.ty))
77 }
78 }
79 _ => {
80 if p.optional || seen_optional {
81 format!("Option<{}>", map_fn(&p.ty))
82 } else {
83 map_fn(&p.ty)
84 }
85 }
86 };
87 format!("{}: {}", p.name, ty)
88 })
89 .collect::<Vec<_>>()
90 .join(", ")
91 } else {
92 function_params(&func.params, &map_fn)
93 };
94 let return_type = mapper.map_type(&func.return_type);
95 let ret = mapper.wrap_return(&return_type, func.error_type.is_some());
96
97 let effective_params: std::borrow::Cow<[alef_core::ir::ParamDef]> = if cfg.named_non_opaque_params_by_ref {
101 let modified: Vec<alef_core::ir::ParamDef> = func
102 .params
103 .iter()
104 .map(|p| {
105 if matches!(&p.ty, TypeRef::Named(n) if !opaque_types.contains(n.as_str())) {
106 alef_core::ir::ParamDef {
107 is_ref: true,
108 ..p.clone()
109 }
110 } else {
111 p.clone()
112 }
113 })
114 .collect();
115 std::borrow::Cow::Owned(modified)
116 } else {
117 std::borrow::Cow::Borrowed(&func.params)
118 };
119 let use_let_bindings = has_named_params(&effective_params, opaque_types);
120 let call_args = if use_let_bindings {
121 gen_call_args_with_let_bindings(&effective_params, opaque_types)
122 } else if cfg.cast_uints_to_i32 || cfg.cast_large_ints_to_f64 {
123 gen_call_args_cfg(
124 &effective_params,
125 opaque_types,
126 cfg.cast_uints_to_i32,
127 cfg.cast_large_ints_to_f64,
128 )
129 } else {
130 gen_call_args(&effective_params, opaque_types)
131 };
132 let core_import = cfg.core_import;
133 let let_bindings = if use_let_bindings {
134 if cfg.named_non_opaque_params_by_ref {
135 gen_named_let_bindings_by_ref(&func.params, opaque_types, core_import)
137 } else {
138 gen_named_let_bindings(&func.params, opaque_types, core_import)
139 }
140 } else {
141 String::new()
142 };
143
144 let core_fn_path = {
146 let path = func.rust_path.replace('-', "_");
147 if path.starts_with(core_import) {
148 path
149 } else {
150 format!("{core_import}::{}", func.name)
151 }
152 };
153
154 let can_delegate = crate::shared::can_auto_delegate_function(func, opaque_types)
155 || can_delegate_with_named_let_bindings(func, opaque_types);
156
157 let serde_err_conv = match cfg.async_pattern {
159 AsyncPattern::Pyo3FutureIntoPy => ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))",
160 AsyncPattern::NapiNativeAsync => ".map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string()))",
161 AsyncPattern::WasmNativeAsync => ".map_err(|e| JsValue::from_str(&e.to_string()))",
162 AsyncPattern::TokioBlockOn => {
163 ".map_err(|e| extendr_api::Error::Other(e.to_string().replace(\":\", \"_\").replace(\"/\", \"_\").replace(\"-\", \"_\").chars().take(255).collect::<String>()))"
164 }
165 _ => ".map_err(|e| e.to_string())",
166 };
167
168 let body = if !can_delegate {
170 if let Some(adapter_body) = adapter_bodies.get(&func.name) {
172 adapter_body.clone()
173 } else if cfg.has_serde && use_let_bindings && func.error_type.is_some() {
174 let is_async_pyo3 = func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy;
179 let (serde_indent, serde_err_async) = if is_async_pyo3 {
180 (
181 " ",
182 ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))",
183 )
184 } else {
185 (" ", serde_err_conv)
186 };
187 let serde_bindings =
188 gen_serde_let_bindings(&func.params, opaque_types, core_import, serde_err_async, serde_indent);
189 let core_call = format!("{core_fn_path}({call_args})");
190
191 let returns_ref = func.returns_ref;
193 let wrap_return = |expr: &str| -> String {
194 match &func.return_type {
195 TypeRef::Vec(inner) => {
196 match inner.as_ref() {
198 TypeRef::Named(_) => {
199 format!("{expr}.into_iter().map(Into::into).collect()")
201 }
202 _ => expr.to_string(),
203 }
204 }
205 TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
206 let mapped_name = mapper.named(name);
207 if returns_ref {
208 format!("{mapped_name} {{ inner: Arc::new({expr}.clone()) }}")
209 } else {
210 format!("{mapped_name} {{ inner: Arc::new({expr}) }}")
211 }
212 }
213 TypeRef::Named(_) => {
214 if returns_ref {
216 format!("{return_type}::from({expr}.clone())")
217 } else {
218 format!("{return_type}::from({expr})")
219 }
220 }
221 TypeRef::String | TypeRef::Bytes => expr.to_string(),
224 TypeRef::Path => format!("{expr}.to_string_lossy().to_string()"),
225 TypeRef::Json => format!("{expr}.to_string()"),
226 _ => expr.to_string(),
227 }
228 };
229
230 if is_async_pyo3 {
231 let is_unit = matches!(func.return_type, TypeRef::Unit);
233 let wrapped = wrap_return("result");
234 let core_await = format!(
235 "{core_call}.await\n .map_err(|e| PyErr::new::<PyRuntimeError, _>(e.to_string()))?"
236 );
237 let inner_body = if is_unit {
238 format!("{serde_bindings}{core_await};\n Ok(())")
239 } else {
240 if wrapped.contains(".into()") || wrapped.contains("::from(") || wrapped.contains("Into::into") {
244 format!(
246 "{serde_bindings}let result = {core_await};\n let wrapped_result: {return_type} = {wrapped};\n Ok(wrapped_result)"
247 )
248 } else {
249 format!("{serde_bindings}let result = {core_await};\n Ok({wrapped})")
250 }
251 };
252 format!("pyo3_async_runtimes::tokio::future_into_py(py, async move {{\n{inner_body}\n }})")
253 } else if func.is_async {
254 let is_unit = matches!(func.return_type, TypeRef::Unit);
256 let wrapped = wrap_return("result");
257 let async_body = gen_async_body(
258 &core_call,
259 cfg,
260 func.error_type.is_some(),
261 &wrapped,
262 false,
263 "",
264 is_unit,
265 Some(&return_type),
266 );
267 format!("{serde_bindings}{async_body}")
268 } else if matches!(func.return_type, TypeRef::Unit) {
269 let await_kw = if func.is_async { ".await" } else { "" };
271 let debug_marker = if func.is_async { "/*ASYNC_UNIT*/ " } else { "" };
272 format!("{serde_bindings}{debug_marker}{core_call}{await_kw}{serde_err_conv}?;\n Ok(())")
273 } else {
274 let wrapped = wrap_return("val");
275 let await_kw = if func.is_async { ".await" } else { "" };
276 if wrapped == "val" {
277 format!("{serde_bindings}{core_call}{await_kw}{serde_err_conv}")
278 } else if wrapped == "val.into()" {
279 format!("{serde_bindings}{core_call}{await_kw}.map(Into::into){serde_err_conv}")
280 } else if let Some(type_path) = wrapped.strip_suffix("::from(val)") {
281 format!("{serde_bindings}{core_call}{await_kw}.map({type_path}::from){serde_err_conv}")
282 } else {
283 format!("{serde_bindings}{core_call}{await_kw}.map(|val| {wrapped}){serde_err_conv}")
284 }
285 }
286 } else if func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy {
287 let suppress = if func.params.is_empty() {
289 String::new()
290 } else {
291 let names: Vec<&str> = func.params.iter().map(|p| p.name.as_str()).collect();
292 format!("let _ = ({});\n ", names.join(", "))
293 };
294 format!(
295 "{suppress}Err(pyo3::exceptions::PyNotImplementedError::new_err(\"not implemented: {}\"))",
296 func.name
297 )
298 } else {
299 gen_unimplemented_body(
301 &func.return_type,
302 &func.name,
303 func.error_type.is_some(),
304 cfg,
305 &func.params,
306 opaque_types,
307 )
308 }
309 } else if func.is_async {
310 let core_call = format!("{core_fn_path}({call_args})");
312 let return_wrap = match &func.return_type {
315 TypeRef::Named(n) if opaque_types.contains(n.as_str()) => {
316 let mapped_n = mapper.named(n);
317 let wrap = arc_wrap_expr("result", n, mutex_types);
318 format!("{mapped_n} {{ inner: {wrap} }}")
319 }
320 TypeRef::Named(_) => {
321 format!("{return_type}::from(result)")
322 }
323 TypeRef::Vec(inner) => match inner.as_ref() {
324 TypeRef::Named(n) if opaque_types.contains(n.as_str()) => {
325 let mapped_n = mapper.named(n);
326 let wrap = arc_wrap_expr("v", n, mutex_types);
327 format!("result.into_iter().map(|v| {mapped_n} {{ inner: {wrap} }}).collect::<Vec<_>>()")
328 }
329 TypeRef::Named(_) => {
330 let inner_mapped = mapper.map_type(inner);
331 format!("result.into_iter().map({inner_mapped}::from).collect::<Vec<_>>()")
332 }
333 _ => "result".to_string(),
334 },
335 TypeRef::Unit => "result".to_string(),
336 _ => super::binding_helpers::wrap_return(
337 "result",
338 &func.return_type,
339 "",
340 opaque_types,
341 false,
342 func.returns_ref,
343 false,
344 ),
345 };
346 let async_body = gen_async_body(
347 &core_call,
348 cfg,
349 func.error_type.is_some(),
350 &return_wrap,
351 false,
352 "",
353 matches!(func.return_type, TypeRef::Unit),
354 Some(&return_type),
355 );
356 format!("{let_bindings}{async_body}")
357 } else {
358 let core_call = format!("{core_fn_path}({call_args})");
359
360 let returns_ref = func.returns_ref;
362 let wrap_return = |expr: &str| -> String {
363 match &func.return_type {
364 TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
367 let mapped_name = mapper.named(name);
368 if expr_is_already_arc(expr) {
369 format!("{mapped_name} {{ inner: {expr} }}")
370 } else if returns_ref {
371 let wrap = arc_wrap_expr(&format!("{expr}.clone()"), name, mutex_types);
372 format!("{mapped_name} {{ inner: {wrap} }}")
373 } else {
374 let wrap = arc_wrap_expr(expr, name, mutex_types);
375 format!("{mapped_name} {{ inner: {wrap} }}")
376 }
377 }
378 TypeRef::Named(_name) => {
380 if returns_ref {
381 format!("{expr}.clone().into()")
382 } else {
383 format!("{expr}.into()")
384 }
385 }
386 TypeRef::String | TypeRef::Bytes => {
388 if returns_ref {
389 format!("{expr}.into()")
390 } else {
391 expr.to_string()
392 }
393 }
394 TypeRef::Path => format!("{expr}.to_string_lossy().to_string()"),
396 TypeRef::Json => format!("{expr}.to_string()"),
398 TypeRef::Optional(inner) => match inner.as_ref() {
400 TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
401 let mapped_name = mapper.named(name);
402 if returns_ref {
403 let wrap = arc_wrap_expr("v.clone()", name, mutex_types);
404 format!("{expr}.map(|v| {mapped_name} {{ inner: {wrap} }})")
405 } else {
406 let wrap = arc_wrap_expr("v", name, mutex_types);
407 format!("{expr}.map(|v| {mapped_name} {{ inner: {wrap} }})")
408 }
409 }
410 TypeRef::Named(_) => {
411 if returns_ref {
412 format!("{expr}.map(|v| v.clone().into())")
413 } else {
414 format!("{expr}.map(Into::into)")
415 }
416 }
417 TypeRef::Path => {
418 format!("{expr}.map(|v| v.to_string_lossy().to_string())")
419 }
420 TypeRef::String | TypeRef::Bytes => {
421 if returns_ref {
422 format!("{expr}.map(Into::into)")
423 } else {
424 expr.to_string()
425 }
426 }
427 TypeRef::Vec(vi) => match vi.as_ref() {
428 TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
429 let mapped_name = mapper.named(name);
430 let wrap = arc_wrap_expr("x", name, mutex_types);
431 format!(
432 "{expr}.map(|v| v.into_iter().map(|x| {mapped_name} {{ inner: {wrap} }}).collect())"
433 )
434 }
435 TypeRef::Named(_) => {
436 format!("{expr}.map(|v| v.into_iter().map(Into::into).collect())")
437 }
438 _ => expr.to_string(),
439 },
440 _ => expr.to_string(),
441 },
442 TypeRef::Vec(inner) => match inner.as_ref() {
444 TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
445 let mapped_name = mapper.named(name);
446 if returns_ref {
447 let wrap = arc_wrap_expr("v.clone()", name, mutex_types);
448 format!("{expr}.into_iter().map(|v| {mapped_name} {{ inner: {wrap} }}).collect()")
449 } else {
450 let wrap = arc_wrap_expr("v", name, mutex_types);
451 format!("{expr}.into_iter().map(|v| {mapped_name} {{ inner: {wrap} }}).collect()")
452 }
453 }
454 TypeRef::Named(_) => {
455 if returns_ref {
456 format!("{expr}.into_iter().map(|v| v.clone().into()).collect()")
457 } else {
458 format!("{expr}.into_iter().map(Into::into).collect()")
459 }
460 }
461 TypeRef::Path => {
462 format!("{expr}.into_iter().map(|v| v.to_string_lossy().to_string()).collect()")
463 }
464 TypeRef::String => {
465 if returns_ref {
466 format!("{expr}.iter().map(|s| s.to_string()).collect()")
469 } else {
470 expr.to_string()
471 }
472 }
473 TypeRef::Bytes => {
474 if returns_ref {
475 format!("{expr}.iter().map(|b| b.to_vec()).collect()")
476 } else {
477 expr.to_string()
478 }
479 }
480 _ => expr.to_string(),
481 },
482 _ => expr.to_string(),
483 }
484 };
485
486 if func.error_type.is_some() {
487 let err_conv = match cfg.async_pattern {
489 AsyncPattern::Pyo3FutureIntoPy => {
490 ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))"
491 }
492 AsyncPattern::NapiNativeAsync => {
493 ".map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string()))"
494 }
495 AsyncPattern::WasmNativeAsync => ".map_err(|e| JsValue::from_str(&e.to_string()))",
496 AsyncPattern::TokioBlockOn => {
497 ".map_err(|e| extendr_api::Error::Other(e.to_string().replace(\":\", \"_\").replace(\"/\", \"_\").replace(\"-\", \"_\").chars().take(255).collect::<String>()))"
498 }
499 _ => ".map_err(|e| e.to_string())",
500 };
501 let wrapped = wrap_return("val");
502 if wrapped == "val" {
503 format!("{core_call}{err_conv}")
504 } else if wrapped == "val.into()" {
505 format!("{core_call}.map(Into::into){err_conv}")
506 } else if let Some(type_path) = wrapped.strip_suffix("::from(val)") {
507 format!("{core_call}.map({type_path}::from){err_conv}")
508 } else {
509 format!("{core_call}.map(|val| {wrapped}){err_conv}")
510 }
511 } else {
512 wrap_return(&core_call)
513 }
514 };
515
516 let body = if !let_bindings.is_empty() && !func.is_async {
520 if can_delegate {
521 format!("{let_bindings}{body}")
522 } else {
523 let vec_str_bindings: String = func.params.iter().filter(|p| {
526 p.is_ref && matches!(&p.ty, TypeRef::Vec(inner) if matches!(inner.as_ref(), TypeRef::String | TypeRef::Char))
527 }).map(|p| {
528 if p.optional {
531 format!("let {}_refs: Vec<&str> = {}.as_ref().map(|v| v.iter().map(|s| s.as_str()).collect()).unwrap_or_default();\n ", p.name, p.name)
532 } else {
533 format!("let {}_refs: Vec<&str> = {}.iter().map(|s| s.as_str()).collect();\n ", p.name, p.name)
534 }
535 }).collect();
536 if !vec_str_bindings.is_empty() {
537 format!("{vec_str_bindings}{body}")
538 } else {
539 body
540 }
541 }
542 } else {
543 body
544 };
545
546 let async_kw = if func.is_async && cfg.async_pattern != AsyncPattern::TokioBlockOn {
550 "async "
551 } else {
552 ""
553 };
554 let func_needs_py = func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy;
555
556 let ret = if func_needs_py {
558 "PyResult<Bound<'py, PyAny>>".to_string()
559 } else {
560 ret
561 };
562 let func_lifetime = if func_needs_py { "<'py>" } else { "" };
563
564 let (func_sig, _params_formatted) = if params.len() > 100 {
565 let mut seen_optional = false;
567 let wrapped_params = func
568 .params
569 .iter()
570 .map(|p| {
571 if p.optional {
572 seen_optional = true;
573 }
574 let ty = if p.optional || seen_optional {
575 format!("Option<{}>", mapper.map_type(&p.ty))
576 } else {
577 mapper.map_type(&p.ty)
578 };
579 format!("{}: {}", p.name, ty)
580 })
581 .collect::<Vec<_>>()
582 .join(",\n ");
583
584 if func_needs_py {
586 (
587 format!(
588 "pub fn {}{func_lifetime}(py: Python<'py>,\n {}\n) -> {ret}",
589 func.name,
590 wrapped_params,
591 ret = ret
592 ),
593 "",
594 )
595 } else {
596 (
597 format!(
598 "pub {async_kw}fn {}(\n {}\n) -> {ret}",
599 func.name,
600 wrapped_params,
601 ret = ret
602 ),
603 "",
604 )
605 }
606 } else if func_needs_py {
607 (
608 format!(
609 "pub fn {}{func_lifetime}(py: Python<'py>, {params}) -> {ret}",
610 func.name
611 ),
612 "",
613 )
614 } else {
615 (format!("pub {async_kw}fn {}({params}) -> {ret}", func.name), "")
616 };
617
618 let total_params = func.params.len() + if func_needs_py { 1 } else { 0 };
619 let sig_defaults = if cfg.needs_signature {
620 function_sig_defaults(&func.params)
621 } else {
622 String::new()
623 };
624 let attr_inner = cfg
625 .function_attr
626 .trim_start_matches('#')
627 .trim_start_matches('[')
628 .trim_end_matches(']');
629
630 crate::template_env::render(
631 "generators/functions/function_definition.jinja",
632 minijinja::context! {
633 has_too_many_arguments => total_params > 7,
634 has_missing_errors_doc => func.error_type.is_some(),
635 attr_inner => attr_inner,
636 needs_signature => cfg.needs_signature,
637 signature_prefix => cfg.signature_prefix,
638 sig_defaults => sig_defaults,
639 signature_suffix => cfg.signature_suffix,
640 func_sig => func_sig,
641 body => body,
642 },
643 )
644}
645
646fn can_delegate_with_named_let_bindings(func: &FunctionDef, opaque_types: &AHashSet<String>) -> bool {
647 !func.sanitized
648 && func
649 .params
650 .iter()
651 .all(|p| !p.sanitized && crate::shared::is_delegatable_param(&p.ty, opaque_types))
652 && crate::shared::is_delegatable_return(&func.return_type)
653}
654
655pub fn collect_trait_imports(api: &ApiSurface) -> Vec<String> {
662 let mut traits: AHashSet<String> = AHashSet::new();
667 for typ in api.types.iter().filter(|typ| !typ.is_trait) {
668 for method in &typ.methods {
669 if let Some(ref trait_path) = method.trait_source {
670 traits.insert(trait_path.clone());
671 }
672 }
673 }
674
675 let mut by_name: AHashMap<String, String> = AHashMap::new();
677 for path in traits {
678 let name = path.split("::").last().unwrap_or(&path).to_string();
679 let entry = by_name.entry(name).or_insert_with(|| path.clone());
680 if path.len() < entry.len() {
682 *entry = path;
683 }
684 }
685
686 let mut sorted: Vec<String> = by_name.into_values().collect();
687 sorted.sort();
688 sorted
689}
690
691pub fn has_unresolved_trait_methods(api: &ApiSurface) -> bool {
697 let mut method_counts: AHashMap<&str, (usize, usize)> = AHashMap::new(); for typ in api.types.iter().filter(|typ| !typ.is_trait) {
702 if typ.is_trait {
703 continue;
704 }
705 for method in &typ.methods {
706 let entry = method_counts.entry(&method.name).or_insert((0, 0));
707 entry.0 += 1;
708 if method.trait_source.is_some() {
709 entry.1 += 1;
710 }
711 }
712 }
713 method_counts
715 .values()
716 .any(|&(total, with_source)| total >= 3 && with_source == 0)
717}
718
719pub fn collect_explicit_core_imports(api: &ApiSurface) -> Vec<String> {
730 let mut names = std::collections::BTreeSet::new();
731 for typ in api.types.iter().filter(|typ| !typ.is_trait) {
732 names.insert(typ.name.clone());
733 }
734 for e in &api.enums {
735 names.insert(e.name.clone());
736 }
737 names.into_iter().collect()
738}