Skip to main content

alef_codegen/generators/
trait_bridge.rs

//! Shared trait bridge code generation.
//!
//! Generates wrapper structs that allow foreign language objects (Python, JS, etc.)
//! to implement Rust traits via FFI. Each backend implements [`TraitBridgeGenerator`]
//! to provide language-specific dispatch logic; the shared functions in this module
//! handle the structural boilerplate.
7
8use alef_core::config::TraitBridgeConfig;
9use alef_core::ir::{MethodDef, ParamDef, TypeDef};
10use heck::ToSnakeCase;
11use std::collections::HashMap;
12use std::fmt::Write;
13
14/// Everything needed to generate a trait bridge for one trait.
15pub struct TraitBridgeSpec<'a> {
16    /// The trait definition from the IR.
17    pub trait_def: &'a TypeDef,
18    /// Bridge configuration from `alef.toml`.
19    pub bridge_config: &'a TraitBridgeConfig,
20    /// Core crate import path (e.g., `"kreuzberg"`).
21    pub core_import: &'a str,
22    /// Language-specific prefix for the wrapper type (e.g., `"Python"`, `"Js"`, `"Wasm"`).
23    pub wrapper_prefix: &'a str,
24    /// Map of type name → fully-qualified Rust path for qualifying `Named` types.
25    pub type_paths: HashMap<String, String>,
26}
27
28impl<'a> TraitBridgeSpec<'a> {
29    /// Wrapper struct name: `{prefix}{TraitName}Bridge` (e.g., `PythonOcrBackendBridge`).
30    pub fn wrapper_name(&self) -> String {
31        format!("{}{}Bridge", self.wrapper_prefix, self.trait_def.name)
32    }
33
34    /// Snake-case version of the trait name (e.g., `"ocr_backend"`).
35    pub fn trait_snake(&self) -> String {
36        self.trait_def.name.to_snake_case()
37    }
38
39    /// Full Rust path to the trait (e.g., `kreuzberg::OcrBackend`).
40    pub fn trait_path(&self) -> String {
41        self.trait_def.rust_path.replace('-', "_")
42    }
43
44    /// Methods that are required (no default impl) — must be provided by the foreign object.
45    pub fn required_methods(&self) -> Vec<&'a MethodDef> {
46        self.trait_def.methods.iter().filter(|m| !m.has_default_impl).collect()
47    }
48
49    /// Methods that have a default impl — optional on the foreign object.
50    pub fn optional_methods(&self) -> Vec<&'a MethodDef> {
51        self.trait_def.methods.iter().filter(|m| m.has_default_impl).collect()
52    }
53}
54
55/// Backend-specific trait bridge generation.
56///
57/// Each binding backend (PyO3, NAPI-RS, wasm-bindgen, etc.) implements this trait
58/// to provide the language-specific parts of bridge codegen. The shared functions
59/// in this module call these methods to fill in the backend-dependent pieces.
60pub trait TraitBridgeGenerator {
61    /// The type of the wrapped foreign object (e.g., `"Py<PyAny>"`, `"ThreadsafeFunction"`).
62    fn foreign_object_type(&self) -> &str;
63
64    /// Additional `use` imports needed for the bridge code.
65    fn bridge_imports(&self) -> Vec<String>;
66
67    /// Generate the body of a synchronous method bridge.
68    ///
69    /// The returned string is inserted inside the trait impl method. It should
70    /// call through to the foreign object and convert the result.
71    fn gen_sync_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;
72
73    /// Generate the body of an async method bridge.
74    ///
75    /// The returned string is the body of a `Box::pin(async move { ... })` block.
76    fn gen_async_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;
77
78    /// Generate the constructor body that validates and wraps the foreign object.
79    ///
80    /// Should check that the foreign object provides all required methods and
81    /// return `Self { ... }` on success.
82    fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String;
83
84    /// Generate the complete registration function including attributes, signature, and body.
85    ///
86    /// Each backend needs different function signatures (PyO3 takes `py: Python`,
87    /// NAPI takes `#[napi]` with JS params, FFI takes `extern "C"` with raw pointers),
88    /// so the generator owns the full function.
89    fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String;
90}
91
// ---------------------------------------------------------------------------
// Shared generation functions
// ---------------------------------------------------------------------------
95
96/// Generate the wrapper struct holding the foreign object and cached fields.
97///
98/// Produces a struct like:
99/// ```ignore
100/// pub struct PythonOcrBackendBridge {
101///     inner: Py<PyAny>,
102///     cached_name: String,
103/// }
104/// ```
105pub fn gen_bridge_wrapper_struct(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
106    let wrapper = spec.wrapper_name();
107    let foreign_type = generator.foreign_object_type();
108    let mut out = String::with_capacity(512);
109
110    writeln!(
111        out,
112        "/// Wrapper that bridges a foreign {prefix} object to the `{trait_name}` trait.",
113        prefix = spec.wrapper_prefix,
114        trait_name = spec.trait_def.name,
115    )
116    .ok();
117    writeln!(out, "pub struct {wrapper} {{").ok();
118    writeln!(out, "    inner: {foreign_type},").ok();
119    writeln!(out, "    cached_name: String,").ok();
120    write!(out, "}}").ok();
121    out
122}
123
124/// Generate `impl SuperTrait for Wrapper` when the bridge config specifies a super-trait.
125///
126/// Forwards `name()`, `version()`, `initialize()`, and `shutdown()` to the
127/// foreign object, using `cached_name` for `name()`.
128///
129/// The super-trait path is derived from the config's `super_trait` field. If it
130/// contains `::`, it's used as-is; otherwise it's qualified as `{core_import}::{super_trait}`.
131pub fn gen_bridge_plugin_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
132    let super_trait_name = spec.bridge_config.super_trait.as_deref()?;
133
134    let wrapper = spec.wrapper_name();
135    let core_import = spec.core_import;
136
137    // Derive the fully-qualified super-trait path
138    let super_trait_path = if super_trait_name.contains("::") {
139        super_trait_name.to_string()
140    } else {
141        format!("{core_import}::{super_trait_name}")
142    };
143
144    // Build synthetic MethodDefs for the Plugin methods and delegate to the generator
145    // for the actual call bodies. The Plugin trait interface is well-known: name(),
146    // version(), initialize(), shutdown().
147    let mut out = String::with_capacity(1024);
148    writeln!(out, "impl {super_trait_path} for {wrapper} {{").ok();
149
150    // name() -> &str — uses cached field
151    writeln!(out, "    fn name(&self) -> &str {{").ok();
152    writeln!(out, "        &self.cached_name").ok();
153    writeln!(out, "    }}").ok();
154    writeln!(out).ok();
155
156    // version() -> &str — delegate to foreign object
157    writeln!(out, "    fn version(&self) -> &str {{").ok();
158    // Build a synthetic method for version
159    let version_method = MethodDef {
160        name: "version".to_string(),
161        params: vec![],
162        return_type: alef_core::ir::TypeRef::String,
163        is_async: false,
164        is_static: false,
165        error_type: None,
166        doc: String::new(),
167        receiver: Some(alef_core::ir::ReceiverKind::Ref),
168        sanitized: false,
169        trait_source: None,
170        returns_ref: true,
171        returns_cow: false,
172        return_newtype_wrapper: None,
173        has_default_impl: false,
174    };
175    let version_body = generator.gen_sync_method_body(&version_method, spec);
176    for line in version_body.lines() {
177        writeln!(out, "        {}", line.trim_start()).ok();
178    }
179    writeln!(out, "    }}").ok();
180    writeln!(out).ok();
181
182    // initialize() -> Result<()> — delegate to foreign object
183    writeln!(
184        out,
185        "    fn initialize(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {{"
186    )
187    .ok();
188    let init_method = MethodDef {
189        name: "initialize".to_string(),
190        params: vec![],
191        return_type: alef_core::ir::TypeRef::Unit,
192        is_async: false,
193        is_static: false,
194        error_type: Some("Box<dyn std::error::Error + Send + Sync>".to_string()),
195        doc: String::new(),
196        receiver: Some(alef_core::ir::ReceiverKind::Ref),
197        sanitized: false,
198        trait_source: None,
199        returns_ref: false,
200        returns_cow: false,
201        return_newtype_wrapper: None,
202        has_default_impl: true,
203    };
204    let init_body = generator.gen_sync_method_body(&init_method, spec);
205    for line in init_body.lines() {
206        writeln!(out, "        {}", line.trim_start()).ok();
207    }
208    writeln!(out, "    }}").ok();
209    writeln!(out).ok();
210
211    // shutdown() -> Result<()> — delegate to foreign object
212    writeln!(
213        out,
214        "    fn shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {{"
215    )
216    .ok();
217    let shutdown_method = MethodDef {
218        name: "shutdown".to_string(),
219        params: vec![],
220        return_type: alef_core::ir::TypeRef::Unit,
221        is_async: false,
222        is_static: false,
223        error_type: Some("Box<dyn std::error::Error + Send + Sync>".to_string()),
224        doc: String::new(),
225        receiver: Some(alef_core::ir::ReceiverKind::Ref),
226        sanitized: false,
227        trait_source: None,
228        returns_ref: false,
229        returns_cow: false,
230        return_newtype_wrapper: None,
231        has_default_impl: true,
232    };
233    let shutdown_body = generator.gen_sync_method_body(&shutdown_method, spec);
234    for line in shutdown_body.lines() {
235        writeln!(out, "        {}", line.trim_start()).ok();
236    }
237    writeln!(out, "    }}").ok();
238    write!(out, "}}").ok();
239    Some(out)
240}
241
242/// Generate `impl Trait for Wrapper` dispatching each method through the generator.
243///
244/// Every method on the trait (including those with `has_default_impl`) gets a
245/// generated body that forwards to the foreign object.
246pub fn gen_bridge_trait_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
247    let wrapper = spec.wrapper_name();
248    let trait_path = spec.trait_path();
249    let mut out = String::with_capacity(2048);
250
251    writeln!(out, "impl {trait_path} for {wrapper} {{").ok();
252
253    // Filter out methods inherited from super-traits (they're handled by gen_bridge_plugin_impl)
254    let own_methods: Vec<_> = spec
255        .trait_def
256        .methods
257        .iter()
258        .filter(|m| m.trait_source.is_none())
259        .collect();
260
261    for (i, method) in own_methods.iter().enumerate() {
262        if i > 0 {
263            writeln!(out).ok();
264        }
265
266        // Build the method signature
267        let async_kw = if method.is_async { "async " } else { "" };
268        let receiver = match &method.receiver {
269            Some(alef_core::ir::ReceiverKind::Ref) => "&self",
270            Some(alef_core::ir::ReceiverKind::RefMut) => "&mut self",
271            Some(alef_core::ir::ReceiverKind::Owned) => "self",
272            None => "",
273        };
274
275        // Build params (excluding self), using format_param_type to respect is_ref/is_mut
276        let params: Vec<String> = method
277            .params
278            .iter()
279            .map(|p| format!("{}: {}", p.name, format_param_type(p, &spec.type_paths)))
280            .collect();
281
282        let all_params = if receiver.is_empty() {
283            params.join(", ")
284        } else if params.is_empty() {
285            receiver.to_string()
286        } else {
287            format!("{}, {}", receiver, params.join(", "))
288        };
289
290        // Return type
291        let ret = format_return_type(&method.return_type, method.error_type.as_deref(), &spec.type_paths);
292
293        writeln!(out, "    {async_kw}fn {}({all_params}) -> {ret} {{", method.name).ok();
294
295        // Generate body: async methods use Box::pin, sync methods call directly
296        let body = if method.is_async {
297            generator.gen_async_method_body(method, spec)
298        } else {
299            generator.gen_sync_method_body(method, spec)
300        };
301
302        for line in body.lines() {
303            writeln!(out, "        {line}").ok();
304        }
305        writeln!(out, "    }}").ok();
306    }
307
308    write!(out, "}}").ok();
309    out
310}
311
312/// Generate the `register_xxx()` function that wraps a foreign object and
313/// inserts it into the plugin registry.
314///
315/// Returns `None` when `bridge_config.register_fn` is absent (per-call bridge pattern).
316/// The generator owns the full function (attributes, signature, body) because each
317/// backend needs different signatures.
318pub fn gen_bridge_registration_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
319    spec.bridge_config.register_fn.as_deref()?;
320    Some(generator.gen_registration_fn(spec))
321}
322
323/// Generate the complete trait bridge code block: imports, struct, impls, and
324/// optionally a registration function.
325///
326/// The registration function is only emitted when `bridge_config.register_fn` is set.
327/// Bridges without a `register_fn` use the per-call visitor pattern instead.
328pub fn gen_bridge_all(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
329    let mut out = String::with_capacity(4096);
330
331    // Imports
332    let imports = generator.bridge_imports();
333    for imp in &imports {
334        writeln!(out, "use {imp};").ok();
335    }
336    if !imports.is_empty() {
337        writeln!(out).ok();
338    }
339
340    // Wrapper struct
341    out.push_str(&gen_bridge_wrapper_struct(spec, generator));
342    writeln!(out).ok();
343    writeln!(out).ok();
344
345    // Constructor (impl block with new())
346    out.push_str(&generator.gen_constructor(spec));
347    writeln!(out).ok();
348    writeln!(out).ok();
349
350    // Plugin super-trait impl (if applicable)
351    if let Some(plugin_impl) = gen_bridge_plugin_impl(spec, generator) {
352        out.push_str(&plugin_impl);
353        writeln!(out).ok();
354        writeln!(out).ok();
355    }
356
357    // Trait impl
358    out.push_str(&gen_bridge_trait_impl(spec, generator));
359
360    // Registration function — only when register_fn is configured
361    if let Some(reg_fn_code) = gen_bridge_registration_fn(spec, generator) {
362        writeln!(out).ok();
363        writeln!(out).ok();
364        out.push_str(&reg_fn_code);
365    }
366
367    out
368}
369
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
373
374/// Format a `TypeRef` as a Rust type string for use in trait method signatures.
375///
376/// `type_paths` qualifies `Named` types with their full Rust path (e.g., `"Config"` →
377/// `"kreuzberg::Config"`). If a name isn't in `type_paths`, it's used as-is.
378pub fn format_type_ref(ty: &alef_core::ir::TypeRef, type_paths: &HashMap<String, String>) -> String {
379    use alef_core::ir::{PrimitiveType, TypeRef};
380    match ty {
381        TypeRef::Primitive(p) => match p {
382            PrimitiveType::Bool => "bool",
383            PrimitiveType::U8 => "u8",
384            PrimitiveType::U16 => "u16",
385            PrimitiveType::U32 => "u32",
386            PrimitiveType::U64 => "u64",
387            PrimitiveType::I8 => "i8",
388            PrimitiveType::I16 => "i16",
389            PrimitiveType::I32 => "i32",
390            PrimitiveType::I64 => "i64",
391            PrimitiveType::F32 => "f32",
392            PrimitiveType::F64 => "f64",
393            PrimitiveType::Usize => "usize",
394            PrimitiveType::Isize => "isize",
395        }
396        .to_string(),
397        TypeRef::String => "String".to_string(),
398        TypeRef::Char => "char".to_string(),
399        TypeRef::Bytes => "Vec<u8>".to_string(),
400        TypeRef::Optional(inner) => format!("Option<{}>", format_type_ref(inner, type_paths)),
401        TypeRef::Vec(inner) => format!("Vec<{}>", format_type_ref(inner, type_paths)),
402        TypeRef::Map(k, v) => format!(
403            "std::collections::HashMap<{}, {}>",
404            format_type_ref(k, type_paths),
405            format_type_ref(v, type_paths)
406        ),
407        TypeRef::Named(name) => type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone()),
408        TypeRef::Path => "std::path::PathBuf".to_string(),
409        TypeRef::Unit => "()".to_string(),
410        TypeRef::Json => "serde_json::Value".to_string(),
411        TypeRef::Duration => "std::time::Duration".to_string(),
412    }
413}
414
415/// Format a return type, wrapping in `Result` when an error type is present.
416pub fn format_return_type(
417    ty: &alef_core::ir::TypeRef,
418    error_type: Option<&str>,
419    type_paths: &HashMap<String, String>,
420) -> String {
421    let inner = format_type_ref(ty, type_paths);
422    match error_type {
423        Some(err) => format!("Result<{inner}, {err}>"),
424        None => inner,
425    }
426}
427
428/// Format a parameter type, respecting `is_ref` and `is_mut` from the IR.
429///
430/// Unlike [`format_type_ref`], this function produces reference types when the
431/// original Rust parameter was a `&T` or `&mut T`:
432/// - `String + is_ref` → `&str`
433/// - `Bytes + is_ref` → `&[u8]`
434/// - `Path + is_ref` → `&std::path::Path`
435/// - `Vec<T> + is_ref` → `&[T]`
436/// - `Named(n) + is_ref` → `&{qualified_name}`
437pub fn format_param_type(param: &ParamDef, type_paths: &HashMap<String, String>) -> String {
438    use alef_core::ir::TypeRef;
439    if param.is_ref {
440        match &param.ty {
441            TypeRef::String => "&str".to_string(),
442            TypeRef::Bytes => "&[u8]".to_string(),
443            TypeRef::Path => "&std::path::Path".to_string(),
444            TypeRef::Vec(inner) => format!("&[{}]", format_type_ref(inner, type_paths)),
445            TypeRef::Named(name) => {
446                let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
447                format!("&{qualified}")
448            }
449            // All other types are Copy/small — pass by value even when is_ref is set
450            other => format_type_ref(other, type_paths),
451        }
452    } else {
453        format_type_ref(&param.ty, type_paths)
454    }
455}