Skip to main content

alef_codegen/generators/
trait_bridge.rs

1//! Shared trait bridge code generation.
2//!
3//! Generates wrapper structs that allow foreign language objects (Python, JS, etc.)
4//! to implement Rust traits via FFI. Each backend implements [`TraitBridgeGenerator`]
5//! to provide language-specific dispatch logic; the shared functions in this module
6//! handle the structural boilerplate.
7
8use alef_core::config::TraitBridgeConfig;
9use alef_core::ir::{MethodDef, TypeDef};
10use heck::ToSnakeCase;
11use std::fmt::Write;
12
/// Everything needed to generate a trait bridge for one trait.
///
/// The spec only borrows its inputs for `'a`; it owns no data itself. It is
/// passed alongside a [`TraitBridgeGenerator`] to the shared `gen_bridge_*`
/// functions in this module.
pub struct TraitBridgeSpec<'a> {
    /// The trait definition from the IR. Its `name`, `rust_path`, `methods` and
    /// `super_traits` drive the generated wrapper and impl blocks.
    pub trait_def: &'a TypeDef,
    /// Bridge configuration from `alef.toml`. Its `register_fn` decides whether
    /// a registration function is emitted.
    pub bridge_config: &'a TraitBridgeConfig,
    /// Core crate import path (e.g., `"kreuzberg"`); used as the path prefix of
    /// the `Plugin` trait in the generated `impl {core_import}::Plugin for ...`.
    pub core_import: &'a str,
    /// Language-specific prefix for the wrapper type (e.g., `"Python"`, `"Js"`, `"Wasm"`).
    pub wrapper_prefix: &'a str,
}
24
25impl<'a> TraitBridgeSpec<'a> {
26    /// Wrapper struct name: `{prefix}{TraitName}Bridge` (e.g., `PythonOcrBackendBridge`).
27    pub fn wrapper_name(&self) -> String {
28        format!("{}{}Bridge", self.wrapper_prefix, self.trait_def.name)
29    }
30
31    /// Snake-case version of the trait name (e.g., `"ocr_backend"`).
32    pub fn trait_snake(&self) -> String {
33        self.trait_def.name.to_snake_case()
34    }
35
36    /// Full Rust path to the trait (e.g., `kreuzberg::OcrBackend`).
37    pub fn trait_path(&self) -> String {
38        self.trait_def.rust_path.replace('-', "_")
39    }
40
41    /// Methods that are required (no default impl) — must be provided by the foreign object.
42    pub fn required_methods(&self) -> Vec<&'a MethodDef> {
43        self.trait_def.methods.iter().filter(|m| !m.has_default_impl).collect()
44    }
45
46    /// Methods that have a default impl — optional on the foreign object.
47    pub fn optional_methods(&self) -> Vec<&'a MethodDef> {
48        self.trait_def.methods.iter().filter(|m| m.has_default_impl).collect()
49    }
50}
51
/// Backend-specific trait bridge generation.
///
/// Each binding backend (PyO3, NAPI-RS, wasm-bindgen, etc.) implements this trait
/// to provide the language-specific parts of bridge codegen. The shared functions
/// in this module call these methods to fill in the backend-dependent pieces.
///
/// All `gen_*` methods return Rust source text. The shared functions split that
/// text with `str::lines()` and re-indent every line before inserting it, so the
/// returned strings should be written without outer indentation.
pub trait TraitBridgeGenerator {
    /// The type of the wrapped foreign object (e.g., `"Py<PyAny>"`, `"ThreadsafeFunction"`).
    ///
    /// Emitted verbatim as the type of the wrapper struct's `inner` field.
    fn foreign_object_type(&self) -> &str;

    /// Additional `use` imports needed for the bridge code.
    ///
    /// Each entry is emitted as `use {entry};`, so entries are bare paths
    /// without the `use` keyword or a trailing semicolon.
    fn bridge_imports(&self) -> Vec<String>;

    /// Generate the body of a synchronous method bridge.
    ///
    /// The returned string is inserted inside the trait impl method. It should
    /// call through to the foreign object and convert the result.
    ///
    /// Besides the trait's own non-async methods, this is also called with
    /// synthetic `version`, `initialize` and `shutdown` method definitions when
    /// the `Plugin` super-trait impl is generated.
    fn gen_sync_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;

    /// Generate the body of an async method bridge.
    ///
    /// The returned string is the body of a `Box::pin(async move { ... })` block.
    ///
    /// NOTE(review): `gen_bridge_trait_impl` inserts this string directly inside
    /// an `async fn` and does not add a `Box::pin` wrapper itself — confirm which
    /// side is expected to supply it.
    fn gen_async_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;

    /// Generate the constructor body that validates and wraps the foreign object.
    ///
    /// Should check that the foreign object provides all required methods and
    /// return `Self { ... }` on success.
    ///
    /// None of the shared functions visible in this module call this method;
    /// presumably backends invoke it from their own code — TODO confirm.
    fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String;

    /// Generate the body of the registration function.
    ///
    /// This is the code inside the `register_xxx()` function that wraps the
    /// foreign object, builds the bridge, and inserts it into the registry.
    /// The shared code places it inside
    /// `pub fn {register_fn}() -> Result<(), String> { ... }`.
    fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String;

    /// The attribute placed on the registration function (e.g., `"#[pyfunction]"`, `"#[napi]"`).
    ///
    /// Emitted on its own line immediately before the function signature.
    fn registration_fn_attr(&self) -> &str;
}
90
91// ---------------------------------------------------------------------------
92// Shared generation functions
93// ---------------------------------------------------------------------------
94
95/// Generate the wrapper struct holding the foreign object and cached fields.
96///
97/// Produces a struct like:
98/// ```ignore
99/// pub struct PythonOcrBackendBridge {
100///     inner: Py<PyAny>,
101///     cached_name: String,
102/// }
103/// ```
104pub fn gen_bridge_wrapper_struct(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
105    let wrapper = spec.wrapper_name();
106    let foreign_type = generator.foreign_object_type();
107    let mut out = String::with_capacity(512);
108
109    writeln!(
110        out,
111        "/// Wrapper that bridges a foreign {prefix} object to the `{trait_name}` trait.",
112        prefix = spec.wrapper_prefix,
113        trait_name = spec.trait_def.name,
114    )
115    .ok();
116    writeln!(out, "pub struct {wrapper} {{").ok();
117    writeln!(out, "    inner: {foreign_type},").ok();
118    writeln!(out, "    cached_name: String,").ok();
119    write!(out, "}}").ok();
120    out
121}
122
123/// Generate `impl Plugin for Wrapper` when the trait has `Plugin` as a super-trait.
124///
125/// Forwards `name()`, `version()`, `initialize()`, and `shutdown()` to the
126/// foreign object, using `cached_name` for `name()`.
127pub fn gen_bridge_plugin_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
128    if !spec.trait_def.super_traits.iter().any(|s| s == "Plugin") {
129        return None;
130    }
131
132    let wrapper = spec.wrapper_name();
133    let core_import = spec.core_import;
134
135    // Build synthetic MethodDefs for the Plugin methods and delegate to the generator
136    // for the actual call bodies. Use a simplified approach: generate the impl block
137    // with hardcoded structure since Plugin's interface is well-known.
138    let mut out = String::with_capacity(1024);
139    writeln!(out, "impl {core_import}::Plugin for {wrapper} {{").ok();
140
141    // name() -> &str — uses cached field
142    writeln!(out, "    fn name(&self) -> &str {{").ok();
143    writeln!(out, "        &self.cached_name").ok();
144    writeln!(out, "    }}").ok();
145    writeln!(out).ok();
146
147    // version() -> &str — delegate to foreign object
148    writeln!(out, "    fn version(&self) -> &str {{").ok();
149    // Build a synthetic method for version
150    let version_method = MethodDef {
151        name: "version".to_string(),
152        params: vec![],
153        return_type: alef_core::ir::TypeRef::String,
154        is_async: false,
155        is_static: false,
156        error_type: None,
157        doc: String::new(),
158        receiver: Some(alef_core::ir::ReceiverKind::Ref),
159        sanitized: false,
160        trait_source: None,
161        returns_ref: true,
162        returns_cow: false,
163        return_newtype_wrapper: None,
164        has_default_impl: false,
165    };
166    let version_body = generator.gen_sync_method_body(&version_method, spec);
167    for line in version_body.lines() {
168        writeln!(out, "        {line}").ok();
169    }
170    writeln!(out, "    }}").ok();
171    writeln!(out).ok();
172
173    // initialize() -> Result<()> — delegate to foreign object
174    writeln!(
175        out,
176        "    fn initialize(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {{"
177    )
178    .ok();
179    let init_method = MethodDef {
180        name: "initialize".to_string(),
181        params: vec![],
182        return_type: alef_core::ir::TypeRef::Unit,
183        is_async: false,
184        is_static: false,
185        error_type: Some("Box<dyn std::error::Error + Send + Sync>".to_string()),
186        doc: String::new(),
187        receiver: Some(alef_core::ir::ReceiverKind::Ref),
188        sanitized: false,
189        trait_source: None,
190        returns_ref: false,
191        returns_cow: false,
192        return_newtype_wrapper: None,
193        has_default_impl: true,
194    };
195    let init_body = generator.gen_sync_method_body(&init_method, spec);
196    for line in init_body.lines() {
197        writeln!(out, "        {line}").ok();
198    }
199    writeln!(out, "    }}").ok();
200    writeln!(out).ok();
201
202    // shutdown() -> Result<()> — delegate to foreign object
203    writeln!(
204        out,
205        "    fn shutdown(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {{"
206    )
207    .ok();
208    let shutdown_method = MethodDef {
209        name: "shutdown".to_string(),
210        params: vec![],
211        return_type: alef_core::ir::TypeRef::Unit,
212        is_async: false,
213        is_static: false,
214        error_type: Some("Box<dyn std::error::Error + Send + Sync>".to_string()),
215        doc: String::new(),
216        receiver: Some(alef_core::ir::ReceiverKind::Ref),
217        sanitized: false,
218        trait_source: None,
219        returns_ref: false,
220        returns_cow: false,
221        return_newtype_wrapper: None,
222        has_default_impl: true,
223    };
224    let shutdown_body = generator.gen_sync_method_body(&shutdown_method, spec);
225    for line in shutdown_body.lines() {
226        writeln!(out, "        {line}").ok();
227    }
228    writeln!(out, "    }}").ok();
229    write!(out, "}}").ok();
230    Some(out)
231}
232
233/// Generate `impl Trait for Wrapper` dispatching each method through the generator.
234///
235/// Every method on the trait (including those with `has_default_impl`) gets a
236/// generated body that forwards to the foreign object.
237pub fn gen_bridge_trait_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
238    let wrapper = spec.wrapper_name();
239    let trait_path = spec.trait_path();
240    let mut out = String::with_capacity(2048);
241
242    writeln!(out, "impl {trait_path} for {wrapper} {{").ok();
243
244    for (i, method) in spec.trait_def.methods.iter().enumerate() {
245        if i > 0 {
246            writeln!(out).ok();
247        }
248
249        // Build the method signature
250        let async_kw = if method.is_async { "async " } else { "" };
251        let receiver = match &method.receiver {
252            Some(alef_core::ir::ReceiverKind::Ref) => "&self",
253            Some(alef_core::ir::ReceiverKind::RefMut) => "&mut self",
254            Some(alef_core::ir::ReceiverKind::Owned) => "self",
255            None => "",
256        };
257
258        // Build params (excluding self)
259        let params: Vec<String> = method
260            .params
261            .iter()
262            .map(|p| format!("{}: {}", p.name, format_type_ref(&p.ty)))
263            .collect();
264
265        let all_params = if receiver.is_empty() {
266            params.join(", ")
267        } else if params.is_empty() {
268            receiver.to_string()
269        } else {
270            format!("{}, {}", receiver, params.join(", "))
271        };
272
273        // Return type
274        let ret = format_return_type(&method.return_type, method.error_type.as_deref());
275
276        writeln!(out, "    {async_kw}fn {}({all_params}) -> {ret} {{", method.name).ok();
277
278        // Generate body: async methods use Box::pin, sync methods call directly
279        let body = if method.is_async {
280            generator.gen_async_method_body(method, spec)
281        } else {
282            generator.gen_sync_method_body(method, spec)
283        };
284
285        for line in body.lines() {
286            writeln!(out, "        {line}").ok();
287        }
288        writeln!(out, "    }}").ok();
289    }
290
291    write!(out, "}}").ok();
292    out
293}
294
295/// Generate the `register_xxx()` function that wraps a foreign object and
296/// inserts it into the plugin registry.
297///
298/// Returns `None` when `bridge_config.register_fn` is absent (per-call bridge pattern).
299pub fn gen_bridge_registration_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
300    let reg_fn = spec.bridge_config.register_fn.as_deref()?;
301    let attr = generator.registration_fn_attr();
302    let mut out = String::with_capacity(1024);
303
304    writeln!(out, "{attr}").ok();
305    writeln!(out, "pub fn {reg_fn}() -> Result<(), String> {{").ok();
306
307    let body = generator.gen_registration_fn(spec);
308    for line in body.lines() {
309        writeln!(out, "    {line}").ok();
310    }
311
312    writeln!(out, "}}").ok();
313    Some(out)
314}
315
316/// Generate the complete trait bridge code block: imports, struct, impls, and
317/// optionally a registration function.
318///
319/// The registration function is only emitted when `bridge_config.register_fn` is set.
320/// Bridges without a `register_fn` use the per-call visitor pattern instead.
321pub fn gen_bridge_all(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
322    let mut out = String::with_capacity(4096);
323
324    // Imports
325    let imports = generator.bridge_imports();
326    for imp in &imports {
327        writeln!(out, "use {imp};").ok();
328    }
329    if !imports.is_empty() {
330        writeln!(out).ok();
331    }
332
333    // Wrapper struct
334    out.push_str(&gen_bridge_wrapper_struct(spec, generator));
335    writeln!(out).ok();
336    writeln!(out).ok();
337
338    // Plugin super-trait impl (if applicable)
339    if let Some(plugin_impl) = gen_bridge_plugin_impl(spec, generator) {
340        out.push_str(&plugin_impl);
341        writeln!(out).ok();
342        writeln!(out).ok();
343    }
344
345    // Trait impl
346    out.push_str(&gen_bridge_trait_impl(spec, generator));
347
348    // Registration function — only when register_fn is configured
349    if let Some(reg_fn_code) = gen_bridge_registration_fn(spec, generator) {
350        writeln!(out).ok();
351        writeln!(out).ok();
352        out.push_str(&reg_fn_code);
353    }
354
355    out
356}
357
358// ---------------------------------------------------------------------------
359// Helpers
360// ---------------------------------------------------------------------------
361
362/// Format a `TypeRef` as a Rust type string for use in trait method signatures.
363fn format_type_ref(ty: &alef_core::ir::TypeRef) -> String {
364    use alef_core::ir::{PrimitiveType, TypeRef};
365    match ty {
366        TypeRef::Primitive(p) => match p {
367            PrimitiveType::Bool => "bool",
368            PrimitiveType::U8 => "u8",
369            PrimitiveType::U16 => "u16",
370            PrimitiveType::U32 => "u32",
371            PrimitiveType::U64 => "u64",
372            PrimitiveType::I8 => "i8",
373            PrimitiveType::I16 => "i16",
374            PrimitiveType::I32 => "i32",
375            PrimitiveType::I64 => "i64",
376            PrimitiveType::F32 => "f32",
377            PrimitiveType::F64 => "f64",
378            PrimitiveType::Usize => "usize",
379            PrimitiveType::Isize => "isize",
380        }
381        .to_string(),
382        TypeRef::String => "String".to_string(),
383        TypeRef::Char => "char".to_string(),
384        TypeRef::Bytes => "Vec<u8>".to_string(),
385        TypeRef::Optional(inner) => format!("Option<{}>", format_type_ref(inner)),
386        TypeRef::Vec(inner) => format!("Vec<{}>", format_type_ref(inner)),
387        TypeRef::Map(k, v) => format!(
388            "std::collections::HashMap<{}, {}>",
389            format_type_ref(k),
390            format_type_ref(v)
391        ),
392        TypeRef::Named(name) => name.clone(),
393        TypeRef::Path => "std::path::PathBuf".to_string(),
394        TypeRef::Unit => "()".to_string(),
395        TypeRef::Json => "serde_json::Value".to_string(),
396        TypeRef::Duration => "std::time::Duration".to_string(),
397    }
398}
399
400/// Format a return type, wrapping in `Result` when an error type is present.
401fn format_return_type(ty: &alef_core::ir::TypeRef, error_type: Option<&str>) -> String {
402    let inner = format_type_ref(ty);
403    match error_type {
404        Some(err) => format!("Result<{inner}, {err}>"),
405        None => inner,
406    }
407}