Skip to main content

alef_codegen/generators/
functions.rs

1use crate::generators::binding_helpers::{
2    gen_async_body, gen_call_args, gen_call_args_with_let_bindings, gen_named_let_bindings, gen_serde_let_bindings,
3    gen_unimplemented_body, has_named_params,
4};
5use crate::generators::{AdapterBodies, AsyncPattern, RustBindingConfig};
6use crate::shared::{function_params, function_sig_defaults};
7use crate::type_mapper::TypeMapper;
8use ahash::AHashSet;
9use alef_core::ir::{ApiSurface, FunctionDef, TypeRef};
10use std::fmt::Write;
11
12/// Generate a free function.
13pub fn gen_function(
14    func: &FunctionDef,
15    mapper: &dyn TypeMapper,
16    cfg: &RustBindingConfig,
17    adapter_bodies: &AdapterBodies,
18    opaque_types: &AHashSet<String>,
19) -> String {
20    let map_fn = |ty: &alef_core::ir::TypeRef| mapper.map_type(ty);
21    let params = function_params(&func.params, &map_fn);
22    let return_type = mapper.map_type(&func.return_type);
23    let ret = mapper.wrap_return(&return_type, func.error_type.is_some());
24
25    // Use let-binding pattern for non-opaque Named params so core fns can take &CoreType
26    let use_let_bindings = has_named_params(&func.params, opaque_types);
27    let call_args = if use_let_bindings {
28        gen_call_args_with_let_bindings(&func.params, opaque_types)
29    } else {
30        gen_call_args(&func.params, opaque_types)
31    };
32    let let_bindings = if use_let_bindings {
33        gen_named_let_bindings(&func.params, opaque_types)
34    } else {
35        String::new()
36    };
37    let core_import = cfg.core_import;
38
39    // Use the function's rust_path for correct module path resolution
40    let core_fn_path = {
41        let path = func.rust_path.replace('-', "_");
42        if path.starts_with(core_import) {
43            path
44        } else {
45            format!("{core_import}::{}", func.name)
46        }
47    };
48
49    let can_delegate = crate::shared::can_auto_delegate_function(func, opaque_types);
50
51    // Backend-specific error conversion string for serde bindings
52    let serde_err_conv = match cfg.async_pattern {
53        AsyncPattern::Pyo3FutureIntoPy => ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))",
54        AsyncPattern::NapiNativeAsync => ".map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string()))",
55        AsyncPattern::WasmNativeAsync => ".map_err(|e| JsValue::from_str(&e.to_string()))",
56        _ => ".map_err(|e| e.to_string())",
57    };
58
59    // Generate the body based on async pattern
60    let body = if !can_delegate {
61        // Check if an adapter provides the body
62        if let Some(adapter_body) = adapter_bodies.get(&func.name) {
63            adapter_body.clone()
64        } else if cfg.has_serde && use_let_bindings && func.error_type.is_some() {
65            // Serde-based param conversion: serialize binding types to JSON, deserialize to core types.
66            // This handles Named params (e.g., ProcessConfig) that lack binding→core From impls.
67            // For async functions with Pyo3FutureIntoPy, serde bindings use indented format.
68            let is_async_pyo3 = func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy;
69            let (serde_indent, serde_err_async) = if is_async_pyo3 {
70                (
71                    "        ",
72                    ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))",
73                )
74            } else {
75                ("    ", serde_err_conv)
76            };
77            let serde_bindings =
78                gen_serde_let_bindings(&func.params, opaque_types, core_import, serde_err_async, serde_indent);
79            let core_call = format!("{core_fn_path}({call_args})");
80
81            // Determine return wrapping strategy (same as delegatable case)
82            let returns_ref = func.returns_ref;
83            let wrap_return = |expr: &str| -> String {
84                match &func.return_type {
85                    TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
86                        if returns_ref {
87                            format!("{name} {{ inner: Arc::new({expr}.clone()) }}")
88                        } else {
89                            format!("{name} {{ inner: Arc::new({expr}) }}")
90                        }
91                    }
92                    TypeRef::Named(_name) => {
93                        if returns_ref {
94                            format!("{expr}.clone().into()")
95                        } else {
96                            format!("{expr}.into()")
97                        }
98                    }
99                    TypeRef::String | TypeRef::Bytes => format!("{expr}.into()"),
100                    TypeRef::Path => format!("{expr}.to_string_lossy().to_string()"),
101                    TypeRef::Json => format!("{expr}.to_string()"),
102                    _ => expr.to_string(),
103                }
104            };
105
106            if is_async_pyo3 {
107                // Async serde path: wrap everything in future_into_py
108                let is_unit = matches!(func.return_type, TypeRef::Unit);
109                let wrapped = wrap_return("result");
110                let core_await = format!(
111                    "{core_call}.await\n            .map_err(|e| PyErr::new::<PyRuntimeError, _>(e.to_string()))?"
112                );
113                let inner_body = if is_unit {
114                    format!("{serde_bindings}{core_await};\n            Ok(())")
115                } else {
116                    format!("{serde_bindings}let result = {core_await};\n            Ok({wrapped})")
117                };
118                format!("pyo3_async_runtimes::tokio::future_into_py(py, async move {{\n{inner_body}\n        }})")
119            } else if matches!(func.return_type, TypeRef::Unit) {
120                // Unit return with error: avoid let_unit_value
121                format!("{serde_bindings}{core_call}{serde_err_conv}?;\n    Ok(())")
122            } else {
123                let wrapped = wrap_return("val");
124                if wrapped == "val" {
125                    format!("{serde_bindings}{core_call}{serde_err_conv}")
126                } else {
127                    format!("{serde_bindings}{core_call}.map(|val| {wrapped}){serde_err_conv}")
128                }
129            }
130        } else {
131            // Function can't be auto-delegated — return a default/error based on return type
132            gen_unimplemented_body(
133                &func.return_type,
134                &func.name,
135                func.error_type.is_some(),
136                cfg,
137                &func.params,
138            )
139        }
140    } else if func.is_async {
141        let core_call = format!("{core_fn_path}({call_args})");
142        // In async contexts (future_into_py, etc.), the compiler often can't infer the
143        // target type for .into(). Use explicit From::from() / collect::<Vec<T>>() instead.
144        let return_wrap = match &func.return_type {
145            TypeRef::Named(n) if opaque_types.contains(n.as_str()) => {
146                format!("{n} {{ inner: Arc::new(result) }}")
147            }
148            TypeRef::Named(_) => {
149                format!("{return_type}::from(result)")
150            }
151            TypeRef::Vec(inner) => match inner.as_ref() {
152                TypeRef::Named(n) if opaque_types.contains(n.as_str()) => {
153                    format!("result.into_iter().map(|v| {n} {{ inner: Arc::new(v) }}).collect::<Vec<_>>()")
154                }
155                TypeRef::Named(_) => {
156                    let inner_mapped = mapper.map_type(inner);
157                    format!("result.into_iter().map({inner_mapped}::from).collect::<Vec<_>>()")
158                }
159                _ => "result".to_string(),
160            },
161            TypeRef::Unit => "result".to_string(),
162            _ => super::binding_helpers::wrap_return(
163                "result",
164                &func.return_type,
165                "",
166                opaque_types,
167                false,
168                func.returns_ref,
169                false,
170            ),
171        };
172        let async_body = gen_async_body(
173            &core_call,
174            cfg,
175            func.error_type.is_some(),
176            &return_wrap,
177            false,
178            "",
179            matches!(func.return_type, TypeRef::Unit),
180        );
181        format!("{let_bindings}{async_body}")
182    } else {
183        let core_call = format!("{core_fn_path}({call_args})");
184
185        // Determine return wrapping strategy
186        let returns_ref = func.returns_ref;
187        let wrap_return = |expr: &str| -> String {
188            match &func.return_type {
189                // Opaque type return: wrap in Arc
190                TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
191                    if returns_ref {
192                        format!("{name} {{ inner: Arc::new({expr}.clone()) }}")
193                    } else {
194                        format!("{name} {{ inner: Arc::new({expr}) }}")
195                    }
196                }
197                // Non-opaque Named: use .into() if From impl exists
198                TypeRef::Named(_name) => {
199                    if returns_ref {
200                        format!("{expr}.clone().into()")
201                    } else {
202                        format!("{expr}.into()")
203                    }
204                }
205                // String/Bytes: .into() handles &str→String etc.
206                TypeRef::String | TypeRef::Bytes => format!("{expr}.into()"),
207                // Path: PathBuf→String needs to_string_lossy
208                TypeRef::Path => format!("{expr}.to_string_lossy().to_string()"),
209                // Json: serde_json::Value to string
210                TypeRef::Json => format!("{expr}.to_string()"),
211                // Optional with opaque inner
212                TypeRef::Optional(inner) => match inner.as_ref() {
213                    TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
214                        if returns_ref {
215                            format!("{expr}.map(|v| {name} {{ inner: Arc::new(v.clone()) }})")
216                        } else {
217                            format!("{expr}.map(|v| {name} {{ inner: Arc::new(v) }})")
218                        }
219                    }
220                    TypeRef::Named(_) => {
221                        if returns_ref {
222                            format!("{expr}.map(|v| v.clone().into())")
223                        } else {
224                            format!("{expr}.map(Into::into)")
225                        }
226                    }
227                    TypeRef::String | TypeRef::Bytes | TypeRef::Path => {
228                        format!("{expr}.map(Into::into)")
229                    }
230                    _ => expr.to_string(),
231                },
232                // Vec<Named>: map each element through Into
233                TypeRef::Vec(inner) => match inner.as_ref() {
234                    TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
235                        if returns_ref {
236                            format!("{expr}.into_iter().map(|v| {name} {{ inner: Arc::new(v.clone()) }}).collect()")
237                        } else {
238                            format!("{expr}.into_iter().map(|v| {name} {{ inner: Arc::new(v) }}).collect()")
239                        }
240                    }
241                    TypeRef::Named(_) => {
242                        if returns_ref {
243                            format!("{expr}.into_iter().map(|v| v.clone().into()).collect()")
244                        } else {
245                            format!("{expr}.into_iter().map(Into::into).collect()")
246                        }
247                    }
248                    TypeRef::String | TypeRef::Bytes | TypeRef::Path => {
249                        format!("{expr}.into_iter().map(Into::into).collect()")
250                    }
251                    _ => expr.to_string(),
252                },
253                _ => expr.to_string(),
254            }
255        };
256
257        if func.error_type.is_some() {
258            // Backend-specific error conversion
259            let err_conv = match cfg.async_pattern {
260                AsyncPattern::Pyo3FutureIntoPy => {
261                    ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))"
262                }
263                AsyncPattern::NapiNativeAsync => {
264                    ".map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string()))"
265                }
266                AsyncPattern::WasmNativeAsync => ".map_err(|e| JsValue::from_str(&e.to_string()))",
267                _ => ".map_err(|e| e.to_string())",
268            };
269            let wrapped = wrap_return("val");
270            if wrapped == "val" {
271                format!("{core_call}{err_conv}")
272            } else {
273                format!("{core_call}.map(|val| {wrapped}){err_conv}")
274            }
275        } else {
276            wrap_return(&core_call)
277        }
278    };
279
280    // Prepend let bindings for non-opaque Named params (sync non-adapter case)
281    let body = if !let_bindings.is_empty() && can_delegate && !func.is_async {
282        format!("{let_bindings}{body}")
283    } else {
284        body
285    };
286
287    // Wrap long signature if necessary
288    let async_kw = if func.is_async { "async " } else { "" };
289    let func_needs_py = func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy;
290
291    // For async PyO3 free functions, override return type and add lifetime generic.
292    let ret = if func_needs_py {
293        "PyResult<Bound<'py, PyAny>>".to_string()
294    } else {
295        ret
296    };
297    let func_lifetime = if func_needs_py { "<'py>" } else { "" };
298
299    let (func_sig, _params_formatted) = if params.len() > 100 {
300        let wrapped_params = func
301            .params
302            .iter()
303            .map(|p| {
304                let ty = if p.optional {
305                    format!("Option<{}>", mapper.map_type(&p.ty))
306                } else {
307                    mapper.map_type(&p.ty)
308                };
309                format!("{}: {}", p.name, ty)
310            })
311            .collect::<Vec<_>>()
312            .join(",\n    ");
313
314        // For async PyO3, we need special signature handling
315        if func_needs_py {
316            (
317                format!(
318                    "pub fn {}{func_lifetime}(py: Python<'py>,\n    {}\n) -> {ret}",
319                    func.name,
320                    wrapped_params,
321                    ret = ret
322                ),
323                "",
324            )
325        } else {
326            (
327                format!(
328                    "pub {async_kw}fn {}(\n    {}\n) -> {ret}",
329                    func.name,
330                    wrapped_params,
331                    ret = ret
332                ),
333                "",
334            )
335        }
336    } else if func_needs_py {
337        (
338            format!(
339                "pub fn {}{func_lifetime}(py: Python<'py>, {params}) -> {ret}",
340                func.name
341            ),
342            "",
343        )
344    } else {
345        (format!("pub {async_kw}fn {}({params}) -> {ret}", func.name), "")
346    };
347
348    let mut out = String::with_capacity(1024);
349    // Per-item clippy suppression: too_many_arguments when >7 params (including py)
350    let total_params = func.params.len() + if func_needs_py { 1 } else { 0 };
351    if total_params > 7 {
352        writeln!(out, "#[allow(clippy::too_many_arguments)]").ok();
353    }
354    // Per-item clippy suppression: missing_errors_doc for Result-returning functions
355    if func.error_type.is_some() {
356        writeln!(out, "#[allow(clippy::missing_errors_doc)]").ok();
357    }
358    let attr_inner = cfg
359        .function_attr
360        .trim_start_matches('#')
361        .trim_start_matches('[')
362        .trim_end_matches(']');
363    writeln!(out, "#[{attr_inner}]").ok();
364    if cfg.needs_signature {
365        let sig = function_sig_defaults(&func.params);
366        writeln!(out, "{}{}{}", cfg.signature_prefix, sig, cfg.signature_suffix).ok();
367    }
368    write!(out, "{} {{\n    {body}\n}}", func_sig,).ok();
369    out
370}
371
372/// Collect all unique trait import paths from opaque types' methods.
373///
374/// Returns a deduplicated, sorted list of trait paths (e.g. `["liter_llm::LlmClient"]`)
375/// that need to be imported in generated binding code so that trait methods can be called.
376pub fn collect_trait_imports(api: &ApiSurface) -> Vec<String> {
377    let mut traits: AHashSet<String> = AHashSet::new();
378    for typ in &api.types {
379        if !typ.is_opaque {
380            continue;
381        }
382        for method in &typ.methods {
383            if let Some(ref trait_path) = method.trait_source {
384                traits.insert(trait_path.clone());
385            }
386        }
387    }
388    let mut sorted: Vec<String> = traits.into_iter().collect();
389    sorted.sort();
390    sorted
391}