Skip to main content

alef_codegen/generators/
functions.rs

1use crate::generators::binding_helpers::{
2    gen_async_body, gen_call_args, gen_call_args_with_let_bindings, gen_named_let_bindings, gen_serde_let_bindings,
3    gen_unimplemented_body, has_named_params,
4};
5use crate::generators::{AdapterBodies, AsyncPattern, RustBindingConfig};
6use crate::shared::{function_params, function_sig_defaults};
7use crate::type_mapper::TypeMapper;
8use ahash::{AHashMap, AHashSet};
9use alef_core::ir::{ApiSurface, FunctionDef, TypeRef};
10use std::fmt::Write;
11
12/// Generate a free function.
13pub fn gen_function(
14    func: &FunctionDef,
15    mapper: &dyn TypeMapper,
16    cfg: &RustBindingConfig,
17    adapter_bodies: &AdapterBodies,
18    opaque_types: &AHashSet<String>,
19) -> String {
20    let map_fn = |ty: &alef_core::ir::TypeRef| mapper.map_type(ty);
21    let params = function_params(&func.params, &map_fn);
22    let return_type = mapper.map_type(&func.return_type);
23    let ret = mapper.wrap_return(&return_type, func.error_type.is_some());
24
25    // Use let-binding pattern for non-opaque Named params so core fns can take &CoreType
26    let use_let_bindings = has_named_params(&func.params, opaque_types);
27    let call_args = if use_let_bindings {
28        gen_call_args_with_let_bindings(&func.params, opaque_types)
29    } else {
30        gen_call_args(&func.params, opaque_types)
31    };
32    let core_import = cfg.core_import;
33    let let_bindings = if use_let_bindings {
34        gen_named_let_bindings(&func.params, opaque_types, core_import)
35    } else {
36        String::new()
37    };
38
39    // Use the function's rust_path for correct module path resolution
40    let core_fn_path = {
41        let path = func.rust_path.replace('-', "_");
42        if path.starts_with(core_import) {
43            path
44        } else {
45            format!("{core_import}::{}", func.name)
46        }
47    };
48
49    let can_delegate = crate::shared::can_auto_delegate_function(func, opaque_types);
50
51    // Backend-specific error conversion string for serde bindings
52    let serde_err_conv = match cfg.async_pattern {
53        AsyncPattern::Pyo3FutureIntoPy => ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))",
54        AsyncPattern::NapiNativeAsync => ".map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string()))",
55        AsyncPattern::WasmNativeAsync => ".map_err(|e| JsValue::from_str(&e.to_string()))",
56        _ => ".map_err(|e| e.to_string())",
57    };
58
59    // Generate the body based on async pattern
60    let body = if !can_delegate {
61        // Check if an adapter provides the body
62        if let Some(adapter_body) = adapter_bodies.get(&func.name) {
63            adapter_body.clone()
64        } else if cfg.has_serde && use_let_bindings && func.error_type.is_some() {
65            // MARKER_SERDE_PATH
66            // Serde-based param conversion: serialize binding types to JSON, deserialize to core types.
67            // This handles Named params (e.g., ProcessConfig) that lack binding→core From impls.
68            // For async functions with Pyo3FutureIntoPy, serde bindings use indented format.
69            let is_async_pyo3 = func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy;
70            let (serde_indent, serde_err_async) = if is_async_pyo3 {
71                (
72                    "        ",
73                    ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))",
74                )
75            } else {
76                ("    ", serde_err_conv)
77            };
78            let serde_bindings =
79                gen_serde_let_bindings(&func.params, opaque_types, core_import, serde_err_async, serde_indent);
80            let core_call = format!("{core_fn_path}({call_args})");
81
82            // Determine return wrapping strategy for serde async (uses explicit types to avoid E0283)
83            let returns_ref = func.returns_ref;
84            let wrap_return = |expr: &str| -> String {
85                match &func.return_type {
86                    TypeRef::Vec(inner) => {
87                        // Vec<T>: check if elements need conversion
88                        match inner.as_ref() {
89                            TypeRef::Named(_) => {
90                                // Vec<Named>: convert each element using Into::into
91                                format!("{expr}.into_iter().map(Into::into).collect()")
92                            }
93                            _ => expr.to_string(),
94                        }
95                    }
96                    TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
97                        if returns_ref {
98                            format!("{name} {{ inner: Arc::new({expr}.clone()) }}")
99                        } else {
100                            format!("{name} {{ inner: Arc::new({expr}) }}")
101                        }
102                    }
103                    TypeRef::Named(_) => {
104                        // Use explicit type with ::from() to avoid E0283 type inference issues in async context
105                        if returns_ref {
106                            format!("{return_type}::from({expr}.clone())")
107                        } else {
108                            format!("{return_type}::from({expr})")
109                        }
110                    }
111                    // String/Bytes are identity across all backends (String->String,
112                    // Vec<u8>->Vec<u8>) — no .into() needed for owned values.
113                    TypeRef::String | TypeRef::Bytes => expr.to_string(),
114                    TypeRef::Path => format!("{expr}.to_string_lossy().to_string()"),
115                    TypeRef::Json => format!("{expr}.to_string()"),
116                    _ => expr.to_string(),
117                }
118            };
119
120            if is_async_pyo3 {
121                // Async serde path: wrap everything in future_into_py
122                let is_unit = matches!(func.return_type, TypeRef::Unit);
123                let wrapped = wrap_return("result");
124                let core_await = format!(
125                    "{core_call}.await\n            .map_err(|e| PyErr::new::<PyRuntimeError, _>(e.to_string()))?"
126                );
127                let inner_body = if is_unit {
128                    format!("{serde_bindings}{core_await};\n            Ok(())")
129                } else {
130                    // When wrapped contains type conversions like .into() or ::from(),
131                    // bind to a variable to help type inference for the generic future_into_py.
132                    // This avoids E0283 "type annotations needed".
133                    if wrapped.contains(".into()") || wrapped.contains("::from(") || wrapped.contains("Into::into") {
134                        // Add explicit type annotation to help type inference
135                        format!(
136                            "{serde_bindings}let result = {core_await};\n            let wrapped_result: {return_type} = {wrapped};\n            Ok(wrapped_result)"
137                        )
138                    } else {
139                        format!("{serde_bindings}let result = {core_await};\n            Ok({wrapped})")
140                    }
141                };
142                format!("pyo3_async_runtimes::tokio::future_into_py(py, async move {{\n{inner_body}\n        }})")
143            } else if func.is_async {
144                // Async serde path for other backends (NAPI, etc.): use gen_async_body
145                let is_unit = matches!(func.return_type, TypeRef::Unit);
146                let wrapped = wrap_return("result");
147                let async_body = gen_async_body(
148                    &core_call,
149                    cfg,
150                    func.error_type.is_some(),
151                    &wrapped,
152                    false,
153                    "",
154                    is_unit,
155                    Some(&return_type),
156                );
157                format!("{serde_bindings}{async_body}")
158            } else if matches!(func.return_type, TypeRef::Unit) {
159                // Unit return with error: avoid let_unit_value
160                let await_kw = if func.is_async { ".await" } else { "" };
161                let debug_marker = if func.is_async { "/*ASYNC_UNIT*/ " } else { "" };
162                format!("{serde_bindings}{debug_marker}{core_call}{await_kw}{serde_err_conv}?;\n    Ok(())")
163            } else {
164                let wrapped = wrap_return("val");
165                let await_kw = if func.is_async { ".await" } else { "" };
166                if wrapped == "val" {
167                    format!("{serde_bindings}{core_call}{await_kw}{serde_err_conv}")
168                } else if wrapped == "val.into()" {
169                    format!("{serde_bindings}{core_call}{await_kw}.map(Into::into){serde_err_conv}")
170                } else if let Some(type_path) = wrapped.strip_suffix("::from(val)") {
171                    format!("{serde_bindings}{core_call}{await_kw}.map({type_path}::from){serde_err_conv}")
172                } else {
173                    format!("{serde_bindings}{core_call}{await_kw}.map(|val| {wrapped}){serde_err_conv}")
174                }
175            }
176        } else if func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy {
177            // Async function that can't be auto-delegated — wrap unimplemented body in future_into_py
178            let suppress = if func.params.is_empty() {
179                String::new()
180            } else {
181                let names: Vec<&str> = func.params.iter().map(|p| p.name.as_str()).collect();
182                format!("let _ = ({});\n        ", names.join(", "))
183            };
184            format!(
185                "{suppress}Err(pyo3::exceptions::PyNotImplementedError::new_err(\"not implemented: {}\"))",
186                func.name
187            )
188        } else {
189            // Function can't be auto-delegated — return a default/error based on return type
190            gen_unimplemented_body(
191                &func.return_type,
192                &func.name,
193                func.error_type.is_some(),
194                cfg,
195                &func.params,
196                opaque_types,
197            )
198        }
199    } else if func.is_async {
200        // MARKER_DELEGATE_ASYNC
201        let core_call = format!("{core_fn_path}({call_args})");
202        // In async contexts (future_into_py, etc.), the compiler often can't infer the
203        // target type for .into(). Use explicit From::from() / collect::<Vec<T>>() instead.
204        let return_wrap = match &func.return_type {
205            TypeRef::Named(n) if opaque_types.contains(n.as_str()) => {
206                format!("{n} {{ inner: Arc::new(result) }}")
207            }
208            TypeRef::Named(_) => {
209                format!("{return_type}::from(result)")
210            }
211            TypeRef::Vec(inner) => match inner.as_ref() {
212                TypeRef::Named(n) if opaque_types.contains(n.as_str()) => {
213                    format!("result.into_iter().map(|v| {n} {{ inner: Arc::new(v) }}).collect::<Vec<_>>()")
214                }
215                TypeRef::Named(_) => {
216                    let inner_mapped = mapper.map_type(inner);
217                    format!("result.into_iter().map({inner_mapped}::from).collect::<Vec<_>>()")
218                }
219                _ => "result".to_string(),
220            },
221            TypeRef::Unit => "result".to_string(),
222            _ => super::binding_helpers::wrap_return(
223                "result",
224                &func.return_type,
225                "",
226                opaque_types,
227                false,
228                func.returns_ref,
229                false,
230            ),
231        };
232        let async_body = gen_async_body(
233            &core_call,
234            cfg,
235            func.error_type.is_some(),
236            &return_wrap,
237            false,
238            "",
239            matches!(func.return_type, TypeRef::Unit),
240            Some(&return_type),
241        );
242        format!("{let_bindings}{async_body}")
243    } else {
244        let core_call = format!("{core_fn_path}({call_args})");
245
246        // Determine return wrapping strategy
247        let returns_ref = func.returns_ref;
248        let wrap_return = |expr: &str| -> String {
249            match &func.return_type {
250                // Opaque type return: wrap in Arc
251                TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
252                    if returns_ref {
253                        format!("{name} {{ inner: Arc::new({expr}.clone()) }}")
254                    } else {
255                        format!("{name} {{ inner: Arc::new({expr}) }}")
256                    }
257                }
258                // Non-opaque Named: use .into() if From impl exists
259                TypeRef::Named(_name) => {
260                    if returns_ref {
261                        format!("{expr}.clone().into()")
262                    } else {
263                        format!("{expr}.into()")
264                    }
265                }
266                // String/Bytes: .into() handles &str→String, skip for owned
267                TypeRef::String | TypeRef::Bytes => {
268                    if returns_ref {
269                        format!("{expr}.into()")
270                    } else {
271                        expr.to_string()
272                    }
273                }
274                // Path: PathBuf→String needs to_string_lossy
275                TypeRef::Path => format!("{expr}.to_string_lossy().to_string()"),
276                // Json: serde_json::Value to string
277                TypeRef::Json => format!("{expr}.to_string()"),
278                // Optional with opaque inner
279                TypeRef::Optional(inner) => match inner.as_ref() {
280                    TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
281                        if returns_ref {
282                            format!("{expr}.map(|v| {name} {{ inner: Arc::new(v.clone()) }})")
283                        } else {
284                            format!("{expr}.map(|v| {name} {{ inner: Arc::new(v) }})")
285                        }
286                    }
287                    TypeRef::Named(_) => {
288                        if returns_ref {
289                            format!("{expr}.map(|v| v.clone().into())")
290                        } else {
291                            format!("{expr}.map(Into::into)")
292                        }
293                    }
294                    TypeRef::Path => {
295                        format!("{expr}.map(|v| v.to_string_lossy().to_string())")
296                    }
297                    TypeRef::String | TypeRef::Bytes => {
298                        if returns_ref {
299                            format!("{expr}.map(Into::into)")
300                        } else {
301                            expr.to_string()
302                        }
303                    }
304                    TypeRef::Vec(vi) => match vi.as_ref() {
305                        TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
306                            format!("{expr}.map(|v| v.into_iter().map(|x| {name} {{ inner: Arc::new(x) }}).collect())")
307                        }
308                        TypeRef::Named(_) => {
309                            format!("{expr}.map(|v| v.into_iter().map(Into::into).collect())")
310                        }
311                        _ => expr.to_string(),
312                    },
313                    _ => expr.to_string(),
314                },
315                // Vec<Named>: map each element through Into
316                TypeRef::Vec(inner) => match inner.as_ref() {
317                    TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
318                        if returns_ref {
319                            format!("{expr}.into_iter().map(|v| {name} {{ inner: Arc::new(v.clone()) }}).collect()")
320                        } else {
321                            format!("{expr}.into_iter().map(|v| {name} {{ inner: Arc::new(v) }}).collect()")
322                        }
323                    }
324                    TypeRef::Named(_) => {
325                        if returns_ref {
326                            format!("{expr}.into_iter().map(|v| v.clone().into()).collect()")
327                        } else {
328                            format!("{expr}.into_iter().map(Into::into).collect()")
329                        }
330                    }
331                    TypeRef::Path => {
332                        format!("{expr}.into_iter().map(|v| v.to_string_lossy().to_string()).collect()")
333                    }
334                    TypeRef::String | TypeRef::Bytes => {
335                        if returns_ref {
336                            format!("{expr}.into_iter().map(Into::into).collect()")
337                        } else {
338                            expr.to_string()
339                        }
340                    }
341                    _ => expr.to_string(),
342                },
343                _ => expr.to_string(),
344            }
345        };
346
347        if func.error_type.is_some() {
348            // Backend-specific error conversion
349            let err_conv = match cfg.async_pattern {
350                AsyncPattern::Pyo3FutureIntoPy => {
351                    ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))"
352                }
353                AsyncPattern::NapiNativeAsync => {
354                    ".map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string()))"
355                }
356                AsyncPattern::WasmNativeAsync => ".map_err(|e| JsValue::from_str(&e.to_string()))",
357                _ => ".map_err(|e| e.to_string())",
358            };
359            let wrapped = wrap_return("val");
360            if wrapped == "val" {
361                format!("{core_call}{err_conv}")
362            } else if wrapped == "val.into()" {
363                format!("{core_call}.map(Into::into){err_conv}")
364            } else if let Some(type_path) = wrapped.strip_suffix("::from(val)") {
365                format!("{core_call}.map({type_path}::from){err_conv}")
366            } else {
367                format!("{core_call}.map(|val| {wrapped}){err_conv}")
368            }
369        } else {
370            wrap_return(&core_call)
371        }
372    };
373
374    // Prepend let bindings for non-opaque Named params (sync delegate case).
375    // Only prepend when can_delegate is true — the !can_delegate serde path does its own bindings.
376    // However, always prepend Vec<String> ref bindings (names_refs) since serde path doesn't handle them.
377    let body = if !let_bindings.is_empty() && !func.is_async {
378        if can_delegate {
379            format!("{let_bindings}{body}")
380        } else {
381            // For the !can_delegate path, only prepend Vec<String>+is_ref bindings (names_refs)
382            // since serde bindings handle Named type conversions.
383            let vec_str_bindings: String = func.params.iter().filter(|p| {
384                p.is_ref && matches!(&p.ty, TypeRef::Vec(inner) if matches!(inner.as_ref(), TypeRef::String | TypeRef::Char))
385            }).map(|p| {
386                // Handle both Vec<String> and Option<Vec<String>> parameters.
387                // When p.optional=true, p.ty is the inner type (Vec<String>), so we need to unwrap first.
388                if p.optional {
389                    format!("let {}_refs: Vec<&str> = {}.as_ref().map(|v| v.iter().map(|s| s.as_str()).collect()).unwrap_or_default();\n    ", p.name, p.name)
390                } else {
391                    format!("let {}_refs: Vec<&str> = {}.iter().map(|s| s.as_str()).collect();\n    ", p.name, p.name)
392                }
393            }).collect();
394            if !vec_str_bindings.is_empty() {
395                format!("{vec_str_bindings}{body}")
396            } else {
397                body
398            }
399        }
400    } else {
401        body
402    };
403
404    // Wrap long signature if necessary
405    let async_kw = if func.is_async { "async " } else { "" };
406    let func_needs_py = func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy;
407
408    // For async PyO3 free functions, override return type and add lifetime generic.
409    let ret = if func_needs_py {
410        "PyResult<Bound<'py, PyAny>>".to_string()
411    } else {
412        ret
413    };
414    let func_lifetime = if func_needs_py { "<'py>" } else { "" };
415
416    let (func_sig, _params_formatted) = if params.len() > 100 {
417        // When formatting for long signatures, promote optional params like function_params() does
418        let mut seen_optional = false;
419        let wrapped_params = func
420            .params
421            .iter()
422            .map(|p| {
423                if p.optional {
424                    seen_optional = true;
425                }
426                let ty = if p.optional || seen_optional {
427                    format!("Option<{}>", mapper.map_type(&p.ty))
428                } else {
429                    mapper.map_type(&p.ty)
430                };
431                format!("{}: {}", p.name, ty)
432            })
433            .collect::<Vec<_>>()
434            .join(",\n    ");
435
436        // For async PyO3, we need special signature handling
437        if func_needs_py {
438            (
439                format!(
440                    "pub fn {}{func_lifetime}(py: Python<'py>,\n    {}\n) -> {ret}",
441                    func.name,
442                    wrapped_params,
443                    ret = ret
444                ),
445                "",
446            )
447        } else {
448            (
449                format!(
450                    "pub {async_kw}fn {}(\n    {}\n) -> {ret}",
451                    func.name,
452                    wrapped_params,
453                    ret = ret
454                ),
455                "",
456            )
457        }
458    } else if func_needs_py {
459        (
460            format!(
461                "pub fn {}{func_lifetime}(py: Python<'py>, {params}) -> {ret}",
462                func.name
463            ),
464            "",
465        )
466    } else {
467        (format!("pub {async_kw}fn {}({params}) -> {ret}", func.name), "")
468    };
469
470    let mut out = String::with_capacity(1024);
471    // Per-item clippy suppression: too_many_arguments when >7 params (including py)
472    let total_params = func.params.len() + if func_needs_py { 1 } else { 0 };
473    if total_params > 7 {
474        writeln!(out, "#[allow(clippy::too_many_arguments)]").ok();
475    }
476    // Per-item clippy suppression: missing_errors_doc for Result-returning functions
477    if func.error_type.is_some() {
478        writeln!(out, "#[allow(clippy::missing_errors_doc)]").ok();
479    }
480    let attr_inner = cfg
481        .function_attr
482        .trim_start_matches('#')
483        .trim_start_matches('[')
484        .trim_end_matches(']');
485    writeln!(out, "#[{attr_inner}]").ok();
486    if cfg.needs_signature {
487        let sig = function_sig_defaults(&func.params);
488        writeln!(out, "{}{}{}", cfg.signature_prefix, sig, cfg.signature_suffix).ok();
489    }
490    write!(out, "{} {{\n    {body}\n}}", func_sig,).ok();
491    out
492}
493
494/// Collect all unique trait import paths from types' methods.
495///
496/// Returns a deduplicated, sorted list of trait paths (e.g. `["liter_llm::LlmClient"]`)
497/// that need to be imported in generated binding code so that trait methods can be called.
498/// Both opaque and non-opaque types are scanned because non-opaque wrapper types also
499/// delegate trait method calls to their inner core type.
500pub fn collect_trait_imports(api: &ApiSurface) -> Vec<String> {
501    // Collect all trait paths, then deduplicate by last segment (trait name).
502    // When two paths resolve to the same trait name (e.g. `spikard_core::Dependency`
503    // and `spikard_core::di::Dependency`), only one import is needed. Keep the
504    // shorter (public re-export) path to avoid E0252 duplicate-import errors.
505    let mut traits: AHashSet<String> = AHashSet::new();
506    for typ in api.types.iter().filter(|typ| !typ.is_trait) {
507        for method in &typ.methods {
508            if let Some(ref trait_path) = method.trait_source {
509                traits.insert(trait_path.clone());
510            }
511        }
512    }
513
514    // Deduplicate by last path segment: keep the shortest path for each trait name.
515    let mut by_name: AHashMap<String, String> = AHashMap::new();
516    for path in traits {
517        let name = path.split("::").last().unwrap_or(&path).to_string();
518        let entry = by_name.entry(name).or_insert_with(|| path.clone());
519        // Prefer shorter paths (public re-exports are shorter than internal paths)
520        if path.len() < entry.len() {
521            *entry = path;
522        }
523    }
524
525    let mut sorted: Vec<String> = by_name.into_values().collect();
526    sorted.sort();
527    sorted
528}
529
530/// Check if any type has methods from trait impls whose trait_source could not be resolved.
531///
532/// When true, the binding crate should add a glob import of the core crate (e.g.
533/// `use kreuzberg::*`) to bring all publicly exported traits into scope.
534/// This handles traits defined in private submodules that are re-exported.
535pub fn has_unresolved_trait_methods(api: &ApiSurface) -> bool {
536    // Count method names that appear on multiple non-trait types but lack trait_source.
537    // Such methods likely come from trait impls whose trait path could not be resolved
538    // (e.g. traits defined in private modules but re-exported via `pub use`).
539    let mut method_counts: AHashMap<&str, (usize, usize)> = AHashMap::new(); // (total, with_source)
540    for typ in api.types.iter().filter(|typ| !typ.is_trait) {
541        if typ.is_trait {
542            continue;
543        }
544        for method in &typ.methods {
545            let entry = method_counts.entry(&method.name).or_insert((0, 0));
546            entry.0 += 1;
547            if method.trait_source.is_some() {
548                entry.1 += 1;
549            }
550        }
551    }
552    // A method appearing on 3+ types without trait_source on any is almost certainly a trait method
553    method_counts
554        .values()
555        .any(|&(total, with_source)| total >= 3 && with_source == 0)
556}
557
558/// Collect explicit type and enum names from the API surface for named imports.
559///
560/// Returns a sorted, deduplicated list of type and enum names that should be
561/// imported from the core crate. This replaces glob imports (`use core::*`)
562/// which can cause name conflicts with local binding definitions (e.g. a
563/// `convert` function or `Result` type alias from the core crate shadowing
564/// the binding's own `convert` wrapper or `std::result::Result`).
565///
566/// Only struct/enum names are included — functions and type aliases are
567/// intentionally excluded because they are the source of conflicts.
568pub fn collect_explicit_core_imports(api: &ApiSurface) -> Vec<String> {
569    let mut names = std::collections::BTreeSet::new();
570    for typ in api.types.iter().filter(|typ| !typ.is_trait) {
571        names.insert(typ.name.clone());
572    }
573    for e in &api.enums {
574        names.insert(e.name.clone());
575    }
576    names.into_iter().collect()
577}