Skip to main content

alef_codegen/
shared.rs

1use ahash::AHashSet;
2use alef_core::ir::{DefaultValue, FieldDef, MethodDef, ParamDef, PrimitiveType, ReceiverKind, TypeRef};
3use std::collections::HashMap;
4
5/// Returns true if this parameter is required but must be promoted to optional
6/// because it follows an optional parameter in the list.
7/// PyO3 requires that required params come before all optional params.
8pub fn is_promoted_optional(params: &[ParamDef], idx: usize) -> bool {
9    if params[idx].optional {
10        return false; // naturally optional
11    }
12    // Check if any earlier param is optional
13    params[..idx].iter().any(|p| p.optional)
14}
15
16/// Check if a free function can be auto-delegated to the core crate.
17/// Opaque Named params are allowed (unwrapped via Arc). Non-opaque Named params are not
18/// (require From impls that may not exist for types with sanitized fields).
19pub fn can_auto_delegate_function(func: &alef_core::ir::FunctionDef, opaque_types: &AHashSet<String>) -> bool {
20    !func.sanitized
21        && func
22            .params
23            .iter()
24            .all(|p| !p.sanitized && is_delegatable_param(&p.ty, opaque_types) && !is_named_ref_param(p, opaque_types))
25        && is_delegatable_return(&func.return_type)
26}
27
28/// Check if all params and return type are delegatable.
29/// For opaque types, skip methods with RefMut receiver (cannot borrow Arc mutably).
30pub fn can_auto_delegate(method: &MethodDef, opaque_types: &AHashSet<String>) -> bool {
31    // Skip RefMut methods on opaque types (Arc doesn't allow mutable access)
32    if matches!(method.receiver, Some(ReceiverKind::RefMut)) && method.trait_source.is_none() {
33        return false;
34    }
35    !method.sanitized
36        && method
37            .params
38            .iter()
39            .all(|p| !p.sanitized && is_delegatable_param(&p.ty, opaque_types) && !is_named_ref_param(p, opaque_types))
40        && is_delegatable_return(&method.return_type)
41}
42
43/// A Named param with is_ref=true needs a let-binding (can't inline .into() + borrow).
44/// A Vec<String> param with is_ref=true needs conversion to Vec<&str>.
45fn is_named_ref_param(p: &alef_core::ir::ParamDef, opaque_types: &AHashSet<String>) -> bool {
46    if !p.is_ref {
47        return false;
48    }
49    match &p.ty {
50        TypeRef::Named(name) => !opaque_types.contains(name.as_str()),
51        TypeRef::Vec(inner) => matches!(inner.as_ref(), TypeRef::String | TypeRef::Char),
52        _ => false,
53    }
54}
55
56/// A param type is delegatable if it's simple, or a Named type (opaque → Arc unwrap, non-opaque → .into()).
57pub fn is_delegatable_param(ty: &TypeRef, _opaque_types: &AHashSet<String>) -> bool {
58    match ty {
59        TypeRef::Primitive(_)
60        | TypeRef::String
61        | TypeRef::Char
62        | TypeRef::Bytes
63        | TypeRef::Path
64        | TypeRef::Unit
65        | TypeRef::Duration => true,
66        TypeRef::Named(_) => true, // Opaque: &*param.inner; non-opaque: .into()
67        TypeRef::Optional(inner) | TypeRef::Vec(inner) => is_delegatable_param(inner, _opaque_types),
68        TypeRef::Map(k, v) => is_delegatable_param(k, _opaque_types) && is_delegatable_param(v, _opaque_types),
69        TypeRef::Json => false,
70    }
71}
72
73/// Return types are more permissive — Named types work via .into() (core→binding From exists).
74pub fn is_delegatable_return(ty: &TypeRef) -> bool {
75    match ty {
76        TypeRef::Primitive(_)
77        | TypeRef::String
78        | TypeRef::Char
79        | TypeRef::Bytes
80        | TypeRef::Path
81        | TypeRef::Unit
82        | TypeRef::Duration => true,
83        TypeRef::Named(_) => true, // core→binding From impl generated for all convertible types
84        TypeRef::Optional(inner) | TypeRef::Vec(inner) => is_delegatable_return(inner),
85        TypeRef::Map(k, v) => is_delegatable_return(k) && is_delegatable_return(v),
86        TypeRef::Json => false,
87    }
88}
89
90/// A type is delegatable if it can cross the binding boundary without From impls.
91/// Named types are NOT delegatable as function params (may lack From impls).
92/// For opaque methods, Named types are handled separately via Arc wrap/unwrap.
93pub fn is_delegatable_type(ty: &TypeRef) -> bool {
94    match ty {
95        TypeRef::Primitive(_)
96        | TypeRef::String
97        | TypeRef::Char
98        | TypeRef::Bytes
99        | TypeRef::Path
100        | TypeRef::Unit
101        | TypeRef::Duration => true,
102        TypeRef::Named(_) => false, // Requires From impl which may not exist
103        TypeRef::Optional(inner) | TypeRef::Vec(inner) => is_delegatable_type(inner),
104        TypeRef::Map(k, v) => is_delegatable_type(k) && is_delegatable_type(v),
105        TypeRef::Json => false,
106    }
107}
108
109/// Check if a type is delegatable in the opaque method context.
110/// Opaque methods can handle Named params via Arc unwrap and Named returns via Arc wrap.
111pub fn is_opaque_delegatable_type(ty: &TypeRef) -> bool {
112    match ty {
113        TypeRef::Primitive(_)
114        | TypeRef::String
115        | TypeRef::Char
116        | TypeRef::Bytes
117        | TypeRef::Path
118        | TypeRef::Unit
119        | TypeRef::Duration => true,
120        TypeRef::Named(_) => true, // Opaque: Arc unwrap/wrap. Non-opaque: .into()
121        TypeRef::Optional(inner) | TypeRef::Vec(inner) => is_opaque_delegatable_type(inner),
122        TypeRef::Map(k, v) => is_opaque_delegatable_type(k) && is_opaque_delegatable_type(v),
123        TypeRef::Json => false,
124    }
125}
126
127/// Check if a type is "simple" — can be passed without any conversion.
128pub fn is_simple_type(ty: &TypeRef) -> bool {
129    match ty {
130        TypeRef::Primitive(_)
131        | TypeRef::String
132        | TypeRef::Char
133        | TypeRef::Bytes
134        | TypeRef::Path
135        | TypeRef::Unit
136        | TypeRef::Duration => true,
137        TypeRef::Optional(inner) | TypeRef::Vec(inner) => is_simple_type(inner),
138        TypeRef::Map(k, v) => is_simple_type(k) && is_simple_type(v),
139        TypeRef::Named(_) | TypeRef::Json => false,
140    }
141}
142
143/// Partition methods into (instance, static).
144pub fn partition_methods(methods: &[MethodDef]) -> (Vec<&MethodDef>, Vec<&MethodDef>) {
145    let instance: Vec<_> = methods.iter().filter(|m| m.receiver.is_some()).collect();
146    let statics: Vec<_> = methods.iter().filter(|m| m.receiver.is_none()).collect();
147    (instance, statics)
148}
149
150/// Build a constructor parameter list string.
151/// Returns (param_list, signature_with_defaults, field_assignments).
152/// If param_list exceeds 100 chars, uses multiline format with trailing commas.
153pub fn constructor_parts(fields: &[FieldDef], type_mapper: &dyn Fn(&TypeRef) -> String) -> (String, String, String) {
154    constructor_parts_with_renames(fields, type_mapper, None)
155}
156
157/// Like `constructor_parts` but with optional field renames for keyword escaping.
158/// `field_renames` maps original field name → binding field name (e.g. "class" → "class_").
159/// Parameters keep the original name (valid in Rust), struct literal uses the renamed field.
160pub fn constructor_parts_with_renames(
161    fields: &[FieldDef],
162    type_mapper: &dyn Fn(&TypeRef) -> String,
163    field_renames: Option<&HashMap<String, String>>,
164) -> (String, String, String) {
165    // Sort fields: required first, then optional.
166    // Many FFI frameworks (PyO3, NAPI) require required params before optional ones.
167    // Skip cfg-gated fields — they depend on features that may not be enabled.
168    let mut sorted_fields: Vec<&FieldDef> = fields.iter().filter(|f| f.cfg.is_none()).collect();
169    sorted_fields.sort_by_key(|f| f.optional as u8);
170
171    let params: Vec<String> = sorted_fields
172        .iter()
173        .map(|f| {
174            let ty = if f.optional {
175                format!("Option<{}>", type_mapper(&f.ty))
176            } else {
177                type_mapper(&f.ty)
178            };
179            format!("{}: {}", f.name, ty)
180        })
181        .collect();
182
183    let defaults: Vec<String> = sorted_fields
184        .iter()
185        .map(|f| {
186            if f.optional {
187                format!("{}=None", f.name)
188            } else {
189                f.name.clone()
190            }
191        })
192        .collect();
193
194    // Assignments keep original field order (for struct literal), excluding cfg-gated.
195    // When a field is renamed, emit `renamed: original` instead of just `original`.
196    let assignments: Vec<String> = fields
197        .iter()
198        .filter(|f| f.cfg.is_none())
199        .map(|f| {
200            if let Some(renames) = field_renames {
201                if let Some(renamed) = renames.get(&f.name) {
202                    return format!("{}: {}", renamed, f.name);
203                }
204            }
205            f.name.clone()
206        })
207        .collect();
208
209    // Format param_list with line wrapping if needed
210    let single_line = params.join(", ");
211    let param_list = if single_line.len() > 100 {
212        format!("\n        {},\n    ", params.join(",\n        "))
213    } else {
214        single_line
215    };
216
217    (param_list, defaults.join(", "), assignments.join(", "))
218}
219
220/// Build a function parameter list.
221pub fn function_params(params: &[ParamDef], type_mapper: &dyn Fn(&TypeRef) -> String) -> String {
222    // After the first optional param, all subsequent params must also be optional
223    // to satisfy PyO3's signature constraint (required params can't follow optional ones).
224    let mut seen_optional = false;
225    params
226        .iter()
227        .map(|p| {
228            if p.optional {
229                seen_optional = true;
230            }
231            let ty = if p.optional || seen_optional {
232                format!("Option<{}>", type_mapper(&p.ty))
233            } else {
234                type_mapper(&p.ty)
235            };
236            format!("{}: {}", p.name, ty)
237        })
238        .collect::<Vec<_>>()
239        .join(", ")
240}
241
242/// Build a function signature defaults string (for pyo3 signature etc.).
243pub fn function_sig_defaults(params: &[ParamDef]) -> String {
244    // After the first optional param, all subsequent params must also carry a default
245    // to satisfy PyO3's signature constraint (required params can't follow optional ones).
246    // For optional params and Named/non-primitive promoted params: use `=None`.
247    // For promoted non-optional primitive params: use a type-appropriate zero/false default
248    // so PyO3 does not wrap the Rust type in Option<T> (which would cause a `?` unwrap error).
249    let mut seen_optional = false;
250    params
251        .iter()
252        .map(|p| {
253            if p.optional {
254                seen_optional = true;
255            }
256            if p.optional {
257                format!("{}=None", p.name)
258            } else if seen_optional {
259                // Promoted non-optional param: emit a type-appropriate default instead of None
260                // so PyO3 keeps the Rust parameter type as T (not Option<T>).
261                let default = match &p.ty {
262                    TypeRef::Primitive(PrimitiveType::Bool) => "false",
263                    TypeRef::Primitive(_) => "0",
264                    _ => "None",
265                };
266                format!("{}={}", p.name, default)
267            } else {
268                p.name.clone()
269            }
270        })
271        .collect::<Vec<_>>()
272        .join(", ")
273}
274
275/// Format a DefaultValue as Rust code for the target language.
276/// Used by backends generating config constructors with defaults.
277pub fn format_default_value(default: &DefaultValue) -> String {
278    match default {
279        DefaultValue::BoolLiteral(b) => format!("{}", b),
280        DefaultValue::StringLiteral(s) => format!("\"{}\".to_string()", s.escape_default()),
281        DefaultValue::IntLiteral(i) => format!("{}", i),
282        DefaultValue::FloatLiteral(f) => {
283            let s = format!("{}", f);
284            // Ensure the literal is a valid Rust float (must contain '.' or 'e'/'E')
285            if s.contains('.') || s.contains('e') || s.contains('E') {
286                s
287            } else {
288                format!("{s}.0")
289            }
290        }
291        DefaultValue::EnumVariant(v) => v.clone(),
292        DefaultValue::Empty => "Default::default()".to_string(),
293        DefaultValue::None => "None".to_string(),
294    }
295}
296
297/// Generate constructor parameter and assignment lists for types with has_default.
298/// All fields become Option<T> with None defaults for optional fields,
299/// or unwrap_or_else with actual defaults for required fields.
300///
301/// Returns (param_list, signature_defaults, assignments).
302/// This is used by PyO3 and similar backends that need signature annotations.
303/// Like `config_constructor_parts` but with extra options.
304/// When `option_duration_on_defaults` is true, non-optional Duration fields are stored
305/// as `Option<u64>` in the binding struct, so the constructor assignment is a passthrough
306/// (the From conversion will handle the None → core default mapping).
307pub fn config_constructor_parts_with_options(
308    fields: &[FieldDef],
309    type_mapper: &dyn Fn(&TypeRef) -> String,
310    option_duration_on_defaults: bool,
311) -> (String, String, String) {
312    config_constructor_parts_inner(fields, type_mapper, option_duration_on_defaults, None)
313}
314
315/// Like `config_constructor_parts_with_options` but with field renames for keyword escaping.
316pub fn config_constructor_parts_with_renames(
317    fields: &[FieldDef],
318    type_mapper: &dyn Fn(&TypeRef) -> String,
319    option_duration_on_defaults: bool,
320    field_renames: Option<&HashMap<String, String>>,
321) -> (String, String, String) {
322    config_constructor_parts_inner(fields, type_mapper, option_duration_on_defaults, field_renames)
323}
324
325pub fn config_constructor_parts(
326    fields: &[FieldDef],
327    type_mapper: &dyn Fn(&TypeRef) -> String,
328) -> (String, String, String) {
329    config_constructor_parts_inner(fields, type_mapper, false, None)
330}
331
332fn config_constructor_parts_inner(
333    fields: &[FieldDef],
334    type_mapper: &dyn Fn(&TypeRef) -> String,
335    option_duration_on_defaults: bool,
336    field_renames: Option<&HashMap<String, String>>,
337) -> (String, String, String) {
338    let mut sorted_fields: Vec<&FieldDef> = fields.iter().filter(|f| f.cfg.is_none()).collect();
339    sorted_fields.sort_by_key(|f| f.optional as u8);
340
341    let params: Vec<String> = sorted_fields
342        .iter()
343        .map(|f| {
344            let ty = type_mapper(&f.ty);
345            // All fields become Option<T>, but avoid Option<Option<T>> for already-optional fields.
346            // When f.ty is TypeRef::Optional(X), type_mapper already returns "Option<X>".
347            // Wrapping it again would yield Option<Option<X>>, making `None` ambiguous in PyO3
348            // signatures (E0283: type annotations needed).
349            if matches!(f.ty, TypeRef::Optional(_)) {
350                format!("{}: {}", f.name, ty)
351            } else {
352                format!("{}: Option<{}>", f.name, ty)
353            }
354        })
355        .collect();
356
357    // All fields have None default in signature
358    let defaults = sorted_fields
359        .iter()
360        .map(|f| format!("{}=None", f.name))
361        .collect::<Vec<_>>()
362        .join(", ");
363
364    // Assignments use unwrap_or_else with the typed default.
365    // `binding_name` is the struct field name (possibly renamed for keyword escaping),
366    // `f.name` is the original name used as the constructor parameter.
367    let assignments: Vec<String> = fields
368        .iter()
369        .filter(|f| f.cfg.is_none())
370        .map(|f| {
371            let binding_name = field_renames
372                .and_then(|r| r.get(&f.name))
373                .map_or_else(|| f.name.as_str(), |s| s.as_str());
374            // Duration fields on has_default types are stored as Option<u64> when
375            // option_duration_on_defaults is set — treat them as passthrough.
376            if option_duration_on_defaults && matches!(f.ty, TypeRef::Duration) {
377                return format!("{}: {}", binding_name, f.name);
378            }
379            if f.optional || matches!(&f.ty, TypeRef::Optional(_)) {
380                // Optional fields: passthrough (both param and field are Option<T>)
381                format!("{}: {}", binding_name, f.name)
382            } else if let Some(ref typed_default) = f.typed_default {
383                // For EnumVariant and Empty defaults, use unwrap_or_default()
384                // because we can't generate qualified Rust paths here.
385                match typed_default {
386                    DefaultValue::EnumVariant(_) | DefaultValue::Empty => {
387                        format!("{}: {}.unwrap_or_default()", binding_name, f.name)
388                    }
389                    _ => {
390                        let default_val = format_default_value(typed_default);
391                        // Use unwrap_or() for Copy literals (bool, int, float) to avoid
392                        // clippy::unnecessary_lazy_evaluations; use unwrap_or_else for heap types.
393                        match typed_default {
394                            DefaultValue::BoolLiteral(_)
395                            | DefaultValue::IntLiteral(_)
396                            | DefaultValue::FloatLiteral(_) => {
397                                format!("{}: {}.unwrap_or({})", binding_name, f.name, default_val)
398                            }
399                            _ => {
400                                format!("{}: {}.unwrap_or_else(|| {})", binding_name, f.name, default_val)
401                            }
402                        }
403                    }
404                }
405            } else {
406                // All binding types should impl Default (enums default to first variant,
407                // structs default via From<CoreType::default()>). unwrap_or_default() works.
408                format!("{}: {}.unwrap_or_default()", binding_name, f.name)
409            }
410        })
411        .collect();
412
413    let single_line = params.join(", ");
414    let param_list = if single_line.len() > 100 {
415        format!("\n        {},\n    ", params.join(",\n        "))
416    } else {
417        single_line
418    };
419
420    (param_list, defaults, assignments.join(", "))
421}