Skip to main content

alef_codegen/generators/
functions.rs

1use crate::generators::binding_helpers::{
2    gen_async_body, gen_call_args, gen_call_args_with_let_bindings, gen_named_let_bindings, gen_serde_let_bindings,
3    gen_unimplemented_body, has_named_params,
4};
5use crate::generators::{AdapterBodies, AsyncPattern, RustBindingConfig};
6use crate::shared::{function_params, function_sig_defaults};
7use crate::type_mapper::TypeMapper;
8use ahash::{AHashMap, AHashSet};
9use alef_core::ir::{ApiSurface, FunctionDef, TypeRef};
10use std::fmt::Write;
11
12/// Generate a free function.
13pub fn gen_function(
14    func: &FunctionDef,
15    mapper: &dyn TypeMapper,
16    cfg: &RustBindingConfig,
17    adapter_bodies: &AdapterBodies,
18    opaque_types: &AHashSet<String>,
19) -> String {
20    let map_fn = |ty: &alef_core::ir::TypeRef| mapper.map_type(ty);
21    let params = function_params(&func.params, &map_fn);
22    let return_type = mapper.map_type(&func.return_type);
23    let ret = mapper.wrap_return(&return_type, func.error_type.is_some());
24
25    // Use let-binding pattern for non-opaque Named params so core fns can take &CoreType
26    let use_let_bindings = has_named_params(&func.params, opaque_types);
27    let call_args = if use_let_bindings {
28        gen_call_args_with_let_bindings(&func.params, opaque_types)
29    } else {
30        gen_call_args(&func.params, opaque_types)
31    };
32    let core_import = cfg.core_import;
33    let let_bindings = if use_let_bindings {
34        gen_named_let_bindings(&func.params, opaque_types, core_import)
35    } else {
36        String::new()
37    };
38
39    // Use the function's rust_path for correct module path resolution
40    let core_fn_path = {
41        let path = func.rust_path.replace('-', "_");
42        if path.starts_with(core_import) {
43            path
44        } else {
45            format!("{core_import}::{}", func.name)
46        }
47    };
48
49    let can_delegate = crate::shared::can_auto_delegate_function(func, opaque_types);
50
51    // Backend-specific error conversion string for serde bindings
52    let serde_err_conv = match cfg.async_pattern {
53        AsyncPattern::Pyo3FutureIntoPy => ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))",
54        AsyncPattern::NapiNativeAsync => ".map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string()))",
55        AsyncPattern::WasmNativeAsync => ".map_err(|e| JsValue::from_str(&e.to_string()))",
56        _ => ".map_err(|e| e.to_string())",
57    };
58
59    // Generate the body based on async pattern
60    let body = if !can_delegate {
61        // Check if an adapter provides the body
62        if let Some(adapter_body) = adapter_bodies.get(&func.name) {
63            adapter_body.clone()
64        } else if cfg.has_serde && use_let_bindings && func.error_type.is_some() {
65            // MARKER_SERDE_PATH
66            // Serde-based param conversion: serialize binding types to JSON, deserialize to core types.
67            // This handles Named params (e.g., ProcessConfig) that lack binding→core From impls.
68            // For async functions with Pyo3FutureIntoPy, serde bindings use indented format.
69            let is_async_pyo3 = func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy;
70            let (serde_indent, serde_err_async) = if is_async_pyo3 {
71                (
72                    "        ",
73                    ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))",
74                )
75            } else {
76                ("    ", serde_err_conv)
77            };
78            let serde_bindings =
79                gen_serde_let_bindings(&func.params, opaque_types, core_import, serde_err_async, serde_indent);
80            let core_call = format!("{core_fn_path}({call_args})");
81
82            // Determine return wrapping strategy for serde async (uses explicit types to avoid E0283)
83            let returns_ref = func.returns_ref;
84            let wrap_return = |expr: &str| -> String {
85                match &func.return_type {
86                    TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
87                        if returns_ref {
88                            format!("{name} {{ inner: Arc::new({expr}.clone()) }}")
89                        } else {
90                            format!("{name} {{ inner: Arc::new({expr}) }}")
91                        }
92                    }
93                    TypeRef::Named(_) => {
94                        // Use explicit type with ::from() to avoid E0283 type inference issues in async context
95                        if returns_ref {
96                            format!("{return_type}::from({expr}.clone())")
97                        } else {
98                            format!("{return_type}::from({expr})")
99                        }
100                    }
101                    TypeRef::String | TypeRef::Bytes => format!("{expr}.into()"),
102                    TypeRef::Path => format!("{expr}.to_string_lossy().to_string()"),
103                    TypeRef::Json => format!("{expr}.to_string()"),
104                    _ => expr.to_string(),
105                }
106            };
107
108            if is_async_pyo3 {
109                // Async serde path: wrap everything in future_into_py
110                let is_unit = matches!(func.return_type, TypeRef::Unit);
111                let wrapped = wrap_return("result");
112                let core_await = format!(
113                    "{core_call}.await\n            .map_err(|e| PyErr::new::<PyRuntimeError, _>(e.to_string()))?"
114                );
115                let inner_body = if is_unit {
116                    format!("{serde_bindings}{core_await};\n            Ok(())")
117                } else {
118                    // When wrapped contains type conversions like .into() or ::from(),
119                    // bind to a variable to help type inference for the generic future_into_py.
120                    // This avoids E0283 "type annotations needed".
121                    if wrapped.contains(".into()") || wrapped.contains("::from(") {
122                        // Add explicit type annotation to help type inference
123                        format!(
124                            "{serde_bindings}let result = {core_await};\n            let wrapped_result: {return_type} = {wrapped};\n            Ok(wrapped_result)"
125                        )
126                    } else {
127                        format!("{serde_bindings}let result = {core_await};\n            Ok({wrapped})")
128                    }
129                };
130                format!("pyo3_async_runtimes::tokio::future_into_py(py, async move {{\n{inner_body}\n        }})")
131            } else if func.is_async {
132                // Async serde path for other backends (NAPI, etc.): use gen_async_body
133                let is_unit = matches!(func.return_type, TypeRef::Unit);
134                let wrapped = wrap_return("result");
135                let async_body = gen_async_body(
136                    &core_call,
137                    cfg,
138                    func.error_type.is_some(),
139                    &wrapped,
140                    false,
141                    "",
142                    is_unit,
143                    Some(&return_type),
144                );
145                format!("{serde_bindings}{async_body}")
146            } else if matches!(func.return_type, TypeRef::Unit) {
147                // Unit return with error: avoid let_unit_value
148                let await_kw = if func.is_async { ".await" } else { "" };
149                let debug_marker = if func.is_async { "/*ASYNC_UNIT*/ " } else { "" };
150                format!("{serde_bindings}{debug_marker}{core_call}{await_kw}{serde_err_conv}?;\n    Ok(())")
151            } else {
152                let wrapped = wrap_return("val");
153                let await_kw = if func.is_async { ".await" } else { "" };
154                if wrapped == "val" {
155                    format!("{serde_bindings}{core_call}{await_kw}{serde_err_conv}")
156                } else {
157                    format!("{serde_bindings}{core_call}{await_kw}.map(|val| {wrapped}){serde_err_conv}")
158                }
159            }
160        } else if func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy {
161            // Async function that can't be auto-delegated — wrap unimplemented body in future_into_py
162            let suppress = if func.params.is_empty() {
163                String::new()
164            } else {
165                let names: Vec<&str> = func.params.iter().map(|p| p.name.as_str()).collect();
166                format!("let _ = ({});\n        ", names.join(", "))
167            };
168            format!(
169                "{suppress}Err(pyo3::exceptions::PyNotImplementedError::new_err(\"not implemented: {}\"))",
170                func.name
171            )
172        } else {
173            // Function can't be auto-delegated — return a default/error based on return type
174            gen_unimplemented_body(
175                &func.return_type,
176                &func.name,
177                func.error_type.is_some(),
178                cfg,
179                &func.params,
180            )
181        }
182    } else if func.is_async {
183        // MARKER_DELEGATE_ASYNC
184        let core_call = format!("{core_fn_path}({call_args})");
185        // In async contexts (future_into_py, etc.), the compiler often can't infer the
186        // target type for .into(). Use explicit From::from() / collect::<Vec<T>>() instead.
187        let return_wrap = match &func.return_type {
188            TypeRef::Named(n) if opaque_types.contains(n.as_str()) => {
189                format!("{n} {{ inner: Arc::new(result) }}")
190            }
191            TypeRef::Named(_) => {
192                format!("{return_type}::from(result)")
193            }
194            TypeRef::Vec(inner) => match inner.as_ref() {
195                TypeRef::Named(n) if opaque_types.contains(n.as_str()) => {
196                    format!("result.into_iter().map(|v| {n} {{ inner: Arc::new(v) }}).collect::<Vec<_>>()")
197                }
198                TypeRef::Named(_) => {
199                    let inner_mapped = mapper.map_type(inner);
200                    format!("result.into_iter().map({inner_mapped}::from).collect::<Vec<_>>()")
201                }
202                _ => "result".to_string(),
203            },
204            TypeRef::Unit => "result".to_string(),
205            _ => super::binding_helpers::wrap_return(
206                "result",
207                &func.return_type,
208                "",
209                opaque_types,
210                false,
211                func.returns_ref,
212                false,
213            ),
214        };
215        let async_body = gen_async_body(
216            &core_call,
217            cfg,
218            func.error_type.is_some(),
219            &return_wrap,
220            false,
221            "",
222            matches!(func.return_type, TypeRef::Unit),
223            Some(&return_type),
224        );
225        format!("{let_bindings}{async_body}")
226    } else {
227        let core_call = format!("{core_fn_path}({call_args})");
228
229        // Determine return wrapping strategy
230        let returns_ref = func.returns_ref;
231        let wrap_return = |expr: &str| -> String {
232            match &func.return_type {
233                // Opaque type return: wrap in Arc
234                TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
235                    if returns_ref {
236                        format!("{name} {{ inner: Arc::new({expr}.clone()) }}")
237                    } else {
238                        format!("{name} {{ inner: Arc::new({expr}) }}")
239                    }
240                }
241                // Non-opaque Named: use .into() if From impl exists
242                TypeRef::Named(_name) => {
243                    if returns_ref {
244                        format!("{expr}.clone().into()")
245                    } else {
246                        format!("{expr}.into()")
247                    }
248                }
249                // String/Bytes: .into() handles &str→String etc.
250                TypeRef::String | TypeRef::Bytes => format!("{expr}.into()"),
251                // Path: PathBuf→String needs to_string_lossy
252                TypeRef::Path => format!("{expr}.to_string_lossy().to_string()"),
253                // Json: serde_json::Value to string
254                TypeRef::Json => format!("{expr}.to_string()"),
255                // Optional with opaque inner
256                TypeRef::Optional(inner) => match inner.as_ref() {
257                    TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
258                        if returns_ref {
259                            format!("{expr}.map(|v| {name} {{ inner: Arc::new(v.clone()) }})")
260                        } else {
261                            format!("{expr}.map(|v| {name} {{ inner: Arc::new(v) }})")
262                        }
263                    }
264                    TypeRef::Named(_) => {
265                        if returns_ref {
266                            format!("{expr}.map(|v| v.clone().into())")
267                        } else {
268                            format!("{expr}.map(Into::into)")
269                        }
270                    }
271                    TypeRef::String | TypeRef::Bytes | TypeRef::Path => {
272                        format!("{expr}.map(Into::into)")
273                    }
274                    TypeRef::Vec(vi) => match vi.as_ref() {
275                        TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
276                            format!("{expr}.map(|v| v.into_iter().map(|x| {name} {{ inner: Arc::new(x) }}).collect())")
277                        }
278                        TypeRef::Named(_) => {
279                            format!("{expr}.map(|v| v.into_iter().map(Into::into).collect())")
280                        }
281                        _ => expr.to_string(),
282                    },
283                    _ => expr.to_string(),
284                },
285                // Vec<Named>: map each element through Into
286                TypeRef::Vec(inner) => match inner.as_ref() {
287                    TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
288                        if returns_ref {
289                            format!("{expr}.into_iter().map(|v| {name} {{ inner: Arc::new(v.clone()) }}).collect()")
290                        } else {
291                            format!("{expr}.into_iter().map(|v| {name} {{ inner: Arc::new(v) }}).collect()")
292                        }
293                    }
294                    TypeRef::Named(_) => {
295                        if returns_ref {
296                            format!("{expr}.into_iter().map(|v| v.clone().into()).collect()")
297                        } else {
298                            format!("{expr}.into_iter().map(Into::into).collect()")
299                        }
300                    }
301                    TypeRef::String | TypeRef::Bytes | TypeRef::Path => {
302                        format!("{expr}.into_iter().map(Into::into).collect()")
303                    }
304                    _ => expr.to_string(),
305                },
306                _ => expr.to_string(),
307            }
308        };
309
310        if func.error_type.is_some() {
311            // Backend-specific error conversion
312            let err_conv = match cfg.async_pattern {
313                AsyncPattern::Pyo3FutureIntoPy => {
314                    ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))"
315                }
316                AsyncPattern::NapiNativeAsync => {
317                    ".map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string()))"
318                }
319                AsyncPattern::WasmNativeAsync => ".map_err(|e| JsValue::from_str(&e.to_string()))",
320                _ => ".map_err(|e| e.to_string())",
321            };
322            let wrapped = wrap_return("val");
323            if wrapped == "val" {
324                format!("{core_call}{err_conv}")
325            } else {
326                format!("{core_call}.map(|val| {wrapped}){err_conv}")
327            }
328        } else {
329            wrap_return(&core_call)
330        }
331    };
332
333    // Prepend let bindings for non-opaque Named params (sync non-adapter case)
334    let body = if !let_bindings.is_empty() && can_delegate && !func.is_async {
335        format!("{let_bindings}{body}")
336    } else {
337        body
338    };
339
340    // Wrap long signature if necessary
341    let async_kw = if func.is_async { "async " } else { "" };
342    let func_needs_py = func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy;
343
344    // For async PyO3 free functions, override return type and add lifetime generic.
345    let ret = if func_needs_py {
346        "PyResult<Bound<'py, PyAny>>".to_string()
347    } else {
348        ret
349    };
350    let func_lifetime = if func_needs_py { "<'py>" } else { "" };
351
352    let (func_sig, _params_formatted) = if params.len() > 100 {
353        // When formatting for long signatures, promote optional params like function_params() does
354        let mut seen_optional = false;
355        let wrapped_params = func
356            .params
357            .iter()
358            .map(|p| {
359                if p.optional {
360                    seen_optional = true;
361                }
362                let ty = if p.optional || seen_optional {
363                    format!("Option<{}>", mapper.map_type(&p.ty))
364                } else {
365                    mapper.map_type(&p.ty)
366                };
367                format!("{}: {}", p.name, ty)
368            })
369            .collect::<Vec<_>>()
370            .join(",\n    ");
371
372        // For async PyO3, we need special signature handling
373        if func_needs_py {
374            (
375                format!(
376                    "pub fn {}{func_lifetime}(py: Python<'py>,\n    {}\n) -> {ret}",
377                    func.name,
378                    wrapped_params,
379                    ret = ret
380                ),
381                "",
382            )
383        } else {
384            (
385                format!(
386                    "pub {async_kw}fn {}(\n    {}\n) -> {ret}",
387                    func.name,
388                    wrapped_params,
389                    ret = ret
390                ),
391                "",
392            )
393        }
394    } else if func_needs_py {
395        (
396            format!(
397                "pub fn {}{func_lifetime}(py: Python<'py>, {params}) -> {ret}",
398                func.name
399            ),
400            "",
401        )
402    } else {
403        (format!("pub {async_kw}fn {}({params}) -> {ret}", func.name), "")
404    };
405
406    let mut out = String::with_capacity(1024);
407    // Per-item clippy suppression: too_many_arguments when >7 params (including py)
408    let total_params = func.params.len() + if func_needs_py { 1 } else { 0 };
409    if total_params > 7 {
410        writeln!(out, "#[allow(clippy::too_many_arguments)]").ok();
411    }
412    // Per-item clippy suppression: missing_errors_doc for Result-returning functions
413    if func.error_type.is_some() {
414        writeln!(out, "#[allow(clippy::missing_errors_doc)]").ok();
415    }
416    let attr_inner = cfg
417        .function_attr
418        .trim_start_matches('#')
419        .trim_start_matches('[')
420        .trim_end_matches(']');
421    writeln!(out, "#[{attr_inner}]").ok();
422    if cfg.needs_signature {
423        let sig = function_sig_defaults(&func.params);
424        writeln!(out, "{}{}{}", cfg.signature_prefix, sig, cfg.signature_suffix).ok();
425    }
426    write!(out, "{} {{\n    {body}\n}}", func_sig,).ok();
427    out
428}
429
430/// Collect all unique trait import paths from types' methods.
431///
432/// Returns a deduplicated, sorted list of trait paths (e.g. `["liter_llm::LlmClient"]`)
433/// that need to be imported in generated binding code so that trait methods can be called.
434/// Both opaque and non-opaque types are scanned because non-opaque wrapper types also
435/// delegate trait method calls to their inner core type.
436pub fn collect_trait_imports(api: &ApiSurface) -> Vec<String> {
437    let mut traits: AHashSet<String> = AHashSet::new();
438    for typ in api.types.iter().filter(|typ| !typ.is_trait) {
439        for method in &typ.methods {
440            if let Some(ref trait_path) = method.trait_source {
441                traits.insert(trait_path.clone());
442            }
443        }
444    }
445    let mut sorted: Vec<String> = traits.into_iter().collect();
446    sorted.sort();
447    sorted
448}
449
450/// Check if any type has methods from trait impls whose trait_source could not be resolved.
451///
452/// When true, the binding crate should add a glob import of the core crate (e.g.
453/// `use kreuzberg::*`) to bring all publicly exported traits into scope.
454/// This handles traits defined in private submodules that are re-exported.
455pub fn has_unresolved_trait_methods(api: &ApiSurface) -> bool {
456    // Count method names that appear on multiple non-trait types but lack trait_source.
457    // Such methods likely come from trait impls whose trait path could not be resolved
458    // (e.g. traits defined in private modules but re-exported via `pub use`).
459    let mut method_counts: AHashMap<&str, (usize, usize)> = AHashMap::new(); // (total, with_source)
460    for typ in api.types.iter().filter(|typ| !typ.is_trait) {
461        if typ.is_trait {
462            continue;
463        }
464        for method in &typ.methods {
465            let entry = method_counts.entry(&method.name).or_insert((0, 0));
466            entry.0 += 1;
467            if method.trait_source.is_some() {
468                entry.1 += 1;
469            }
470        }
471    }
472    // A method appearing on 3+ types without trait_source on any is almost certainly a trait method
473    method_counts
474        .values()
475        .any(|&(total, with_source)| total >= 3 && with_source == 0)
476}
477
478/// Collect explicit type and enum names from the API surface for named imports.
479///
480/// Returns a sorted, deduplicated list of type and enum names that should be
481/// imported from the core crate. This replaces glob imports (`use core::*`)
482/// which can cause name conflicts with local binding definitions (e.g. a
483/// `convert` function or `Result` type alias from the core crate shadowing
484/// the binding's own `convert` wrapper or `std::result::Result`).
485///
486/// Only struct/enum names are included — functions and type aliases are
487/// intentionally excluded because they are the source of conflicts.
488pub fn collect_explicit_core_imports(api: &ApiSurface) -> Vec<String> {
489    let mut names = std::collections::BTreeSet::new();
490    for typ in api.types.iter().filter(|typ| !typ.is_trait) {
491        names.insert(typ.name.clone());
492    }
493    for e in &api.enums {
494        names.insert(e.name.clone());
495    }
496    names.into_iter().collect()
497}