// alef_codegen/generators/functions.rs

1use crate::generators::binding_helpers::{
2    gen_async_body, gen_call_args, gen_call_args_with_let_bindings, gen_named_let_bindings, gen_serde_let_bindings,
3    gen_unimplemented_body, has_named_params,
4};
5use crate::generators::{AdapterBodies, AsyncPattern, RustBindingConfig};
6use crate::shared::{function_params, function_sig_defaults};
7use crate::type_mapper::TypeMapper;
8use ahash::AHashSet;
9use alef_core::ir::{ApiSurface, FunctionDef, TypeRef};
10use std::fmt::Write;
11
12/// Generate a free function.
13pub fn gen_function(
14    func: &FunctionDef,
15    mapper: &dyn TypeMapper,
16    cfg: &RustBindingConfig,
17    adapter_bodies: &AdapterBodies,
18    opaque_types: &AHashSet<String>,
19) -> String {
20    let map_fn = |ty: &alef_core::ir::TypeRef| mapper.map_type(ty);
21    let params = function_params(&func.params, &map_fn);
22    let return_type = mapper.map_type(&func.return_type);
23    let ret = mapper.wrap_return(&return_type, func.error_type.is_some());
24
25    // Use let-binding pattern for non-opaque Named params so core fns can take &CoreType
26    let use_let_bindings = has_named_params(&func.params, opaque_types);
27    let call_args = if use_let_bindings {
28        gen_call_args_with_let_bindings(&func.params, opaque_types)
29    } else {
30        gen_call_args(&func.params, opaque_types)
31    };
32    let let_bindings = if use_let_bindings {
33        gen_named_let_bindings(&func.params, opaque_types)
34    } else {
35        String::new()
36    };
37    let core_import = cfg.core_import;
38
39    // Use the function's rust_path for correct module path resolution
40    let core_fn_path = {
41        let path = func.rust_path.replace('-', "_");
42        if path.starts_with(core_import) {
43            path
44        } else {
45            format!("{core_import}::{}", func.name)
46        }
47    };
48
49    let can_delegate = crate::shared::can_auto_delegate_function(func, opaque_types);
50
51    // Backend-specific error conversion string for serde bindings
52    let serde_err_conv = match cfg.async_pattern {
53        AsyncPattern::Pyo3FutureIntoPy => ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))",
54        AsyncPattern::NapiNativeAsync => ".map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string()))",
55        AsyncPattern::WasmNativeAsync => ".map_err(|e| JsValue::from_str(&e.to_string()))",
56        _ => ".map_err(|e| e.to_string())",
57    };
58
59    // Generate the body based on async pattern
60    let body = if !can_delegate {
61        // Check if an adapter provides the body
62        if let Some(adapter_body) = adapter_bodies.get(&func.name) {
63            adapter_body.clone()
64        } else if cfg.has_serde && use_let_bindings && func.error_type.is_some() {
65            // Serde-based param conversion: serialize binding types to JSON, deserialize to core types.
66            // This handles Named params (e.g., ProcessConfig) that lack binding→core From impls.
67            let serde_bindings =
68                gen_serde_let_bindings(&func.params, opaque_types, core_import, serde_err_conv, "    ");
69            let core_call = format!("{core_fn_path}({call_args})");
70
71            // Determine return wrapping strategy (same as delegatable case)
72            let returns_ref = func.returns_ref;
73            let wrap_return = |expr: &str| -> String {
74                match &func.return_type {
75                    TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
76                        if returns_ref {
77                            format!("{name} {{ inner: Arc::new({expr}.clone()) }}")
78                        } else {
79                            format!("{name} {{ inner: Arc::new({expr}) }}")
80                        }
81                    }
82                    TypeRef::Named(_name) => {
83                        if returns_ref {
84                            format!("{expr}.clone().into()")
85                        } else {
86                            format!("{expr}.into()")
87                        }
88                    }
89                    TypeRef::String | TypeRef::Bytes => format!("{expr}.into()"),
90                    TypeRef::Path => format!("{expr}.to_string_lossy().to_string()"),
91                    TypeRef::Json => format!("{expr}.to_string()"),
92                    _ => expr.to_string(),
93                }
94            };
95
96            if matches!(func.return_type, TypeRef::Unit) {
97                // Unit return with error: avoid let_unit_value
98                format!("{serde_bindings}{core_call}{serde_err_conv}?;\n    Ok(())")
99            } else {
100                let wrapped = wrap_return("val");
101                if wrapped == "val" {
102                    format!("{serde_bindings}{core_call}{serde_err_conv}")
103                } else {
104                    format!("{serde_bindings}{core_call}.map(|val| {wrapped}){serde_err_conv}")
105                }
106            }
107        } else {
108            // Function can't be auto-delegated — return a default/error based on return type
109            gen_unimplemented_body(
110                &func.return_type,
111                &func.name,
112                func.error_type.is_some(),
113                cfg,
114                &func.params,
115            )
116        }
117    } else if func.is_async {
118        let core_call = format!("{core_fn_path}({call_args})");
119        // In async contexts (future_into_py, etc.), the compiler often can't infer the
120        // target type for .into(). Use explicit From::from() / collect::<Vec<T>>() instead.
121        let return_wrap = match &func.return_type {
122            TypeRef::Named(n) if opaque_types.contains(n.as_str()) => {
123                format!("{n} {{ inner: Arc::new(result) }}")
124            }
125            TypeRef::Named(_) => {
126                format!("{return_type}::from(result)")
127            }
128            TypeRef::Vec(inner) => match inner.as_ref() {
129                TypeRef::Named(n) if opaque_types.contains(n.as_str()) => {
130                    format!("result.into_iter().map(|v| {n} {{ inner: Arc::new(v) }}).collect::<Vec<_>>()")
131                }
132                TypeRef::Named(_) => {
133                    let inner_mapped = mapper.map_type(inner);
134                    format!("result.into_iter().map({inner_mapped}::from).collect::<Vec<_>>()")
135                }
136                _ => "result".to_string(),
137            },
138            TypeRef::Unit => "result".to_string(),
139            _ => super::binding_helpers::wrap_return(
140                "result",
141                &func.return_type,
142                "",
143                opaque_types,
144                false,
145                func.returns_ref,
146            ),
147        };
148        let async_body = gen_async_body(
149            &core_call,
150            cfg,
151            func.error_type.is_some(),
152            &return_wrap,
153            false,
154            "",
155            matches!(func.return_type, TypeRef::Unit),
156        );
157        format!("{let_bindings}{async_body}")
158    } else {
159        let core_call = format!("{core_fn_path}({call_args})");
160
161        // Determine return wrapping strategy
162        let returns_ref = func.returns_ref;
163        let wrap_return = |expr: &str| -> String {
164            match &func.return_type {
165                // Opaque type return: wrap in Arc
166                TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
167                    if returns_ref {
168                        format!("{name} {{ inner: Arc::new({expr}.clone()) }}")
169                    } else {
170                        format!("{name} {{ inner: Arc::new({expr}) }}")
171                    }
172                }
173                // Non-opaque Named: use .into() if From impl exists
174                TypeRef::Named(_name) => {
175                    if returns_ref {
176                        format!("{expr}.clone().into()")
177                    } else {
178                        format!("{expr}.into()")
179                    }
180                }
181                // String/Bytes: .into() handles &str→String etc.
182                TypeRef::String | TypeRef::Bytes => format!("{expr}.into()"),
183                // Path: PathBuf→String needs to_string_lossy
184                TypeRef::Path => format!("{expr}.to_string_lossy().to_string()"),
185                // Json: serde_json::Value to string
186                TypeRef::Json => format!("{expr}.to_string()"),
187                // Optional with opaque inner
188                TypeRef::Optional(inner) => match inner.as_ref() {
189                    TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
190                        if returns_ref {
191                            format!("{expr}.map(|v| {name} {{ inner: Arc::new(v.clone()) }})")
192                        } else {
193                            format!("{expr}.map(|v| {name} {{ inner: Arc::new(v) }})")
194                        }
195                    }
196                    TypeRef::Named(_) => {
197                        if returns_ref {
198                            format!("{expr}.map(|v| v.clone().into())")
199                        } else {
200                            format!("{expr}.map(Into::into)")
201                        }
202                    }
203                    TypeRef::String | TypeRef::Bytes | TypeRef::Path => {
204                        format!("{expr}.map(Into::into)")
205                    }
206                    _ => expr.to_string(),
207                },
208                // Vec<Named>: map each element through Into
209                TypeRef::Vec(inner) => match inner.as_ref() {
210                    TypeRef::Named(name) if opaque_types.contains(name.as_str()) => {
211                        if returns_ref {
212                            format!("{expr}.into_iter().map(|v| {name} {{ inner: Arc::new(v.clone()) }}).collect()")
213                        } else {
214                            format!("{expr}.into_iter().map(|v| {name} {{ inner: Arc::new(v) }}).collect()")
215                        }
216                    }
217                    TypeRef::Named(_) => {
218                        if returns_ref {
219                            format!("{expr}.into_iter().map(|v| v.clone().into()).collect()")
220                        } else {
221                            format!("{expr}.into_iter().map(Into::into).collect()")
222                        }
223                    }
224                    TypeRef::String | TypeRef::Bytes | TypeRef::Path => {
225                        format!("{expr}.into_iter().map(Into::into).collect()")
226                    }
227                    _ => expr.to_string(),
228                },
229                _ => expr.to_string(),
230            }
231        };
232
233        if func.error_type.is_some() {
234            // Backend-specific error conversion
235            let err_conv = match cfg.async_pattern {
236                AsyncPattern::Pyo3FutureIntoPy => {
237                    ".map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))"
238                }
239                AsyncPattern::NapiNativeAsync => {
240                    ".map_err(|e| napi::Error::new(napi::Status::GenericFailure, e.to_string()))"
241                }
242                AsyncPattern::WasmNativeAsync => ".map_err(|e| JsValue::from_str(&e.to_string()))",
243                _ => ".map_err(|e| e.to_string())",
244            };
245            let wrapped = wrap_return("val");
246            if wrapped == "val" {
247                format!("{core_call}{err_conv}")
248            } else {
249                format!("{core_call}.map(|val| {wrapped}){err_conv}")
250            }
251        } else {
252            wrap_return(&core_call)
253        }
254    };
255
256    // Prepend let bindings for non-opaque Named params (sync non-adapter case)
257    let body = if !let_bindings.is_empty() && can_delegate && !func.is_async {
258        format!("{let_bindings}{body}")
259    } else {
260        body
261    };
262
263    // Wrap long signature if necessary
264    let async_kw = if func.is_async { "async " } else { "" };
265    let func_needs_py = func.is_async && cfg.async_pattern == AsyncPattern::Pyo3FutureIntoPy;
266
267    // For async PyO3 free functions, override return type and add lifetime generic.
268    let ret = if func_needs_py {
269        "PyResult<Bound<'py, PyAny>>".to_string()
270    } else {
271        ret
272    };
273    let func_lifetime = if func_needs_py { "<'py>" } else { "" };
274
275    let (func_sig, _params_formatted) = if params.len() > 100 {
276        let wrapped_params = func
277            .params
278            .iter()
279            .map(|p| {
280                let ty = if p.optional {
281                    format!("Option<{}>", mapper.map_type(&p.ty))
282                } else {
283                    mapper.map_type(&p.ty)
284                };
285                format!("{}: {}", p.name, ty)
286            })
287            .collect::<Vec<_>>()
288            .join(",\n    ");
289
290        // For async PyO3, we need special signature handling
291        if func_needs_py {
292            (
293                format!(
294                    "pub fn {}{func_lifetime}(py: Python<'py>,\n    {}\n) -> {ret}",
295                    func.name,
296                    wrapped_params,
297                    ret = ret
298                ),
299                "",
300            )
301        } else {
302            (
303                format!(
304                    "pub {async_kw}fn {}(\n    {}\n) -> {ret}",
305                    func.name,
306                    wrapped_params,
307                    ret = ret
308                ),
309                "",
310            )
311        }
312    } else if func_needs_py {
313        (
314            format!(
315                "pub fn {}{func_lifetime}(py: Python<'py>, {params}) -> {ret}",
316                func.name
317            ),
318            "",
319        )
320    } else {
321        (format!("pub {async_kw}fn {}({params}) -> {ret}", func.name), "")
322    };
323
324    let mut out = String::with_capacity(1024);
325    // Per-item clippy suppression: too_many_arguments when >7 params (including py)
326    let total_params = func.params.len() + if func_needs_py { 1 } else { 0 };
327    if total_params > 7 {
328        writeln!(out, "#[allow(clippy::too_many_arguments)]").ok();
329    }
330    // Per-item clippy suppression: missing_errors_doc for Result-returning functions
331    if func.error_type.is_some() {
332        writeln!(out, "#[allow(clippy::missing_errors_doc)]").ok();
333    }
334    let attr_inner = cfg
335        .function_attr
336        .trim_start_matches('#')
337        .trim_start_matches('[')
338        .trim_end_matches(']');
339    writeln!(out, "#[{attr_inner}]").ok();
340    if cfg.needs_signature {
341        let sig = function_sig_defaults(&func.params);
342        writeln!(out, "{}{}{}", cfg.signature_prefix, sig, cfg.signature_suffix).ok();
343    }
344    write!(out, "{} {{\n    {body}\n}}", func_sig,).ok();
345    out
346}
347
348/// Collect all unique trait import paths from opaque types' methods.
349///
350/// Returns a deduplicated, sorted list of trait paths (e.g. `["liter_llm::LlmClient"]`)
351/// that need to be imported in generated binding code so that trait methods can be called.
352pub fn collect_trait_imports(api: &ApiSurface) -> Vec<String> {
353    let mut traits: AHashSet<String> = AHashSet::new();
354    for typ in &api.types {
355        if !typ.is_opaque {
356            continue;
357        }
358        for method in &typ.methods {
359            if let Some(ref trait_path) = method.trait_source {
360                traits.insert(trait_path.clone());
361            }
362        }
363    }
364    let mut sorted: Vec<String> = traits.into_iter().collect();
365    sorted.sort();
366    sorted
367}