Skip to main content

alef_codegen/generators/
trait_bridge.rs

1//! Shared trait bridge code generation.
2//!
3//! Generates wrapper structs that allow foreign language objects (Python, JS, etc.)
4//! to implement Rust traits via FFI. Each backend implements [`TraitBridgeGenerator`]
5//! to provide language-specific dispatch logic; the shared functions in this module
6//! handle the structural boilerplate.
7
8use alef_core::config::TraitBridgeConfig;
9use alef_core::ir::{FunctionDef, MethodDef, ParamDef, PrimitiveType, TypeDef, TypeRef};
10use heck::ToSnakeCase;
11use std::collections::HashMap;
12use std::fmt::Write;
13
14/// Everything needed to generate a trait bridge for one trait.
15pub struct TraitBridgeSpec<'a> {
16    /// The trait definition from the IR.
17    pub trait_def: &'a TypeDef,
18    /// Bridge configuration from `alef.toml`.
19    pub bridge_config: &'a TraitBridgeConfig,
20    /// Core crate import path (e.g., `"kreuzberg"`).
21    pub core_import: &'a str,
22    /// Language-specific prefix for the wrapper type (e.g., `"Python"`, `"Js"`, `"Wasm"`).
23    pub wrapper_prefix: &'a str,
24    /// Map of type name → fully-qualified Rust path for qualifying `Named` types.
25    pub type_paths: HashMap<String, String>,
26    /// The crate's error type name (e.g., `"KreuzbergError"`). Defaults to `"Error"`.
27    pub error_type: String,
28    /// Error constructor pattern. `{msg}` is replaced with the message expression.
29    pub error_constructor: String,
30}
31
32impl<'a> TraitBridgeSpec<'a> {
33    /// Fully qualified error type path (e.g., `"kreuzberg::KreuzbergError"`).
34    pub fn error_path(&self) -> String {
35        format!("{}::{}", self.core_import, self.error_type)
36    }
37
38    /// Generate an error construction expression from a message expression.
39    pub fn make_error(&self, msg_expr: &str) -> String {
40        self.error_constructor.replace("{msg}", msg_expr)
41    }
42
43    /// Wrapper struct name: `{prefix}{TraitName}Bridge` (e.g., `PythonOcrBackendBridge`).
44    pub fn wrapper_name(&self) -> String {
45        format!("{}{}Bridge", self.wrapper_prefix, self.trait_def.name)
46    }
47
48    /// Snake-case version of the trait name (e.g., `"ocr_backend"`).
49    pub fn trait_snake(&self) -> String {
50        self.trait_def.name.to_snake_case()
51    }
52
53    /// Full Rust path to the trait (e.g., `kreuzberg::OcrBackend`).
54    pub fn trait_path(&self) -> String {
55        self.trait_def.rust_path.replace('-', "_")
56    }
57
58    /// Methods that are required (no default impl) — must be provided by the foreign object.
59    pub fn required_methods(&self) -> Vec<&'a MethodDef> {
60        self.trait_def.methods.iter().filter(|m| !m.has_default_impl).collect()
61    }
62
63    /// Methods that have a default impl — optional on the foreign object.
64    pub fn optional_methods(&self) -> Vec<&'a MethodDef> {
65        self.trait_def.methods.iter().filter(|m| m.has_default_impl).collect()
66    }
67}
68
69/// Backend-specific trait bridge generation.
70///
71/// Each binding backend (PyO3, NAPI-RS, wasm-bindgen, etc.) implements this trait
72/// to provide the language-specific parts of bridge codegen. The shared functions
73/// in this module call these methods to fill in the backend-dependent pieces.
74pub trait TraitBridgeGenerator {
75    /// The type of the wrapped foreign object (e.g., `"Py<PyAny>"`, `"ThreadsafeFunction"`).
76    fn foreign_object_type(&self) -> &str;
77
78    /// Additional `use` imports needed for the bridge code.
79    fn bridge_imports(&self) -> Vec<String>;
80
81    /// Generate the body of a synchronous method bridge.
82    ///
83    /// The returned string is inserted inside the trait impl method. It should
84    /// call through to the foreign object and convert the result.
85    fn gen_sync_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;
86
87    /// Generate the body of an async method bridge.
88    ///
89    /// The returned string is the body of a `Box::pin(async move { ... })` block.
90    fn gen_async_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;
91
92    /// Generate the constructor body that validates and wraps the foreign object.
93    ///
94    /// Should check that the foreign object provides all required methods and
95    /// return `Self { ... }` on success.
96    fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String;
97
98    /// Generate the complete registration function including attributes, signature, and body.
99    ///
100    /// Each backend needs different function signatures (PyO3 takes `py: Python`,
101    /// NAPI takes `#[napi]` with JS params, FFI takes `extern "C"` with raw pointers),
102    /// so the generator owns the full function.
103    fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String;
104}
105
106// ---------------------------------------------------------------------------
107// Shared generation functions
108// ---------------------------------------------------------------------------
109
110/// Generate the wrapper struct holding the foreign object and cached fields.
111///
112/// Produces a struct like:
113/// ```ignore
114/// pub struct PythonOcrBackendBridge {
115///     inner: Py<PyAny>,
116///     cached_name: String,
117/// }
118/// ```
119pub fn gen_bridge_wrapper_struct(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
120    let wrapper = spec.wrapper_name();
121    let foreign_type = generator.foreign_object_type();
122    let mut out = String::with_capacity(512);
123
124    writeln!(
125        out,
126        "/// Wrapper that bridges a foreign {prefix} object to the `{trait_name}` trait.",
127        prefix = spec.wrapper_prefix,
128        trait_name = spec.trait_def.name,
129    )
130    .ok();
131    writeln!(out, "pub struct {wrapper} {{").ok();
132    writeln!(out, "    inner: {foreign_type},").ok();
133    writeln!(out, "    cached_name: String,").ok();
134    write!(out, "}}").ok();
135    out
136}
137
138/// Generate `impl SuperTrait for Wrapper` when the bridge config specifies a super-trait.
139///
140/// Forwards `name()`, `version()`, `initialize()`, and `shutdown()` to the
141/// foreign object, using `cached_name` for `name()`.
142///
143/// The super-trait path is derived from the config's `super_trait` field. If it
144/// contains `::`, it's used as-is; otherwise it's qualified as `{core_import}::{super_trait}`.
145pub fn gen_bridge_plugin_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
146    let super_trait_name = spec.bridge_config.super_trait.as_deref()?;
147
148    let wrapper = spec.wrapper_name();
149    let core_import = spec.core_import;
150
151    // Derive the fully-qualified super-trait path
152    let super_trait_path = if super_trait_name.contains("::") {
153        super_trait_name.to_string()
154    } else {
155        format!("{core_import}::{super_trait_name}")
156    };
157
158    // Build synthetic MethodDefs for the Plugin methods and delegate to the generator
159    // for the actual call bodies. The Plugin trait interface is well-known: name(),
160    // version(), initialize(), shutdown().
161    let mut out = String::with_capacity(1024);
162    writeln!(out, "impl {super_trait_path} for {wrapper} {{").ok();
163
164    // name() -> &str — uses cached field
165    writeln!(out, "    fn name(&self) -> &str {{").ok();
166    writeln!(out, "        &self.cached_name").ok();
167    writeln!(out, "    }}").ok();
168    writeln!(out).ok();
169
170    let error_path = spec.error_path();
171
172    // version() -> String — delegate to foreign object
173    writeln!(out, "    fn version(&self) -> String {{").ok();
174    let version_method = MethodDef {
175        name: "version".to_string(),
176        params: vec![],
177        return_type: alef_core::ir::TypeRef::String,
178        is_async: false,
179        is_static: false,
180        error_type: None,
181        doc: String::new(),
182        receiver: Some(alef_core::ir::ReceiverKind::Ref),
183        sanitized: false,
184        trait_source: None,
185        returns_ref: false,
186        returns_cow: false,
187        return_newtype_wrapper: None,
188        has_default_impl: false,
189    };
190    let version_body = generator.gen_sync_method_body(&version_method, spec);
191    for line in version_body.lines() {
192        writeln!(out, "        {}", line.trim_start()).ok();
193    }
194    writeln!(out, "    }}").ok();
195    writeln!(out).ok();
196
197    // initialize() -> Result<(), ErrorType>
198    writeln!(
199        out,
200        "    fn initialize(&self) -> std::result::Result<(), {error_path}> {{"
201    )
202    .ok();
203    let init_method = MethodDef {
204        name: "initialize".to_string(),
205        params: vec![],
206        return_type: alef_core::ir::TypeRef::Unit,
207        is_async: false,
208        is_static: false,
209        error_type: Some(error_path.clone()),
210        doc: String::new(),
211        receiver: Some(alef_core::ir::ReceiverKind::Ref),
212        sanitized: false,
213        trait_source: None,
214        returns_ref: false,
215        returns_cow: false,
216        return_newtype_wrapper: None,
217        has_default_impl: true,
218    };
219    let init_body = generator.gen_sync_method_body(&init_method, spec);
220    for line in init_body.lines() {
221        writeln!(out, "        {}", line.trim_start()).ok();
222    }
223    writeln!(out, "    }}").ok();
224    writeln!(out).ok();
225
226    // shutdown() -> Result<(), ErrorType>
227    writeln!(
228        out,
229        "    fn shutdown(&self) -> std::result::Result<(), {error_path}> {{"
230    )
231    .ok();
232    let shutdown_method = MethodDef {
233        name: "shutdown".to_string(),
234        params: vec![],
235        return_type: alef_core::ir::TypeRef::Unit,
236        is_async: false,
237        is_static: false,
238        error_type: Some(error_path.clone()),
239        doc: String::new(),
240        receiver: Some(alef_core::ir::ReceiverKind::Ref),
241        sanitized: false,
242        trait_source: None,
243        returns_ref: false,
244        returns_cow: false,
245        return_newtype_wrapper: None,
246        has_default_impl: true,
247    };
248    let shutdown_body = generator.gen_sync_method_body(&shutdown_method, spec);
249    for line in shutdown_body.lines() {
250        writeln!(out, "        {}", line.trim_start()).ok();
251    }
252    writeln!(out, "    }}").ok();
253    write!(out, "}}").ok();
254    Some(out)
255}
256
257/// Generate `impl Trait for Wrapper` dispatching each method through the generator.
258///
259/// Every method on the trait (including those with `has_default_impl`) gets a
260/// generated body that forwards to the foreign object.
261pub fn gen_bridge_trait_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
262    let wrapper = spec.wrapper_name();
263    let trait_path = spec.trait_path();
264    let mut out = String::with_capacity(2048);
265
266    // Add #[async_trait] when the trait has async methods (needed for async_trait macro compatibility)
267    let has_async_methods = spec
268        .trait_def
269        .methods
270        .iter()
271        .any(|m| m.is_async && m.trait_source.is_none());
272    if has_async_methods {
273        writeln!(out, "#[async_trait::async_trait]").ok();
274    }
275    writeln!(out, "impl {trait_path} for {wrapper} {{").ok();
276
277    // Filter out methods inherited from super-traits (they're handled by gen_bridge_plugin_impl)
278    let own_methods: Vec<_> = spec
279        .trait_def
280        .methods
281        .iter()
282        .filter(|m| m.trait_source.is_none())
283        .collect();
284
285    for (i, method) in own_methods.iter().enumerate() {
286        if i > 0 {
287            writeln!(out).ok();
288        }
289
290        // Build the method signature
291        let async_kw = if method.is_async { "async " } else { "" };
292        let receiver = match &method.receiver {
293            Some(alef_core::ir::ReceiverKind::Ref) => "&self",
294            Some(alef_core::ir::ReceiverKind::RefMut) => "&mut self",
295            Some(alef_core::ir::ReceiverKind::Owned) => "self",
296            None => "",
297        };
298
299        // Build params (excluding self), using format_param_type to respect is_ref/is_mut
300        let params: Vec<String> = method
301            .params
302            .iter()
303            .map(|p| format!("{}: {}", p.name, format_param_type(p, &spec.type_paths)))
304            .collect();
305
306        let all_params = if receiver.is_empty() {
307            params.join(", ")
308        } else if params.is_empty() {
309            receiver.to_string()
310        } else {
311            format!("{}, {}", receiver, params.join(", "))
312        };
313
314        // Return type — override the IR's error type with the configured crate error type
315        // so the impl matches the actual trait definition (the IR may extract a different
316        // error type like anyhow::Error from re-exports or type alias resolution).
317        let error_override = method.error_type.as_ref().map(|_| spec.error_path());
318        let ret = format_return_type(&method.return_type, error_override.as_deref(), &spec.type_paths);
319
320        writeln!(out, "    {async_kw}fn {}({all_params}) -> {ret} {{", method.name).ok();
321
322        // Generate body: async methods use Box::pin, sync methods call directly
323        let body = if method.is_async {
324            generator.gen_async_method_body(method, spec)
325        } else {
326            generator.gen_sync_method_body(method, spec)
327        };
328
329        for line in body.lines() {
330            writeln!(out, "        {line}").ok();
331        }
332        writeln!(out, "    }}").ok();
333    }
334
335    write!(out, "}}").ok();
336    out
337}
338
339/// Generate the `register_xxx()` function that wraps a foreign object and
340/// inserts it into the plugin registry.
341///
342/// Returns `None` when `bridge_config.register_fn` is absent (per-call bridge pattern).
343/// The generator owns the full function (attributes, signature, body) because each
344/// backend needs different signatures.
345pub fn gen_bridge_registration_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
346    spec.bridge_config.register_fn.as_deref()?;
347    Some(generator.gen_registration_fn(spec))
348}
349
/// Result of trait bridge generation. It holds imports, which callers add via
/// `builder.add_import`, and the code body, which callers add via `builder.add_item`.
///
/// Derives the common traits so callers and tests can inspect, copy and compare
/// the generated output.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BridgeOutput {
    /// Import paths (e.g., `"std::sync::Arc"`) — callers should add via `builder.add_import()`.
    pub imports: Vec<String>,
    /// The generated code (struct, impls, registration fn).
    pub code: String,
}
358
359/// Generate the complete trait bridge code block: struct, impls, and
360/// optionally a registration function.
361///
362/// Returns [`BridgeOutput`] with imports separated from code so callers can
363/// route imports through `builder.add_import()` (which deduplicates).
364pub fn gen_bridge_all(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> BridgeOutput {
365    let imports = generator.bridge_imports();
366    let mut out = String::with_capacity(4096);
367
368    // Wrapper struct
369    out.push_str(&gen_bridge_wrapper_struct(spec, generator));
370    writeln!(out).ok();
371    writeln!(out).ok();
372
373    // Constructor (impl block with new())
374    out.push_str(&generator.gen_constructor(spec));
375    writeln!(out).ok();
376    writeln!(out).ok();
377
378    // Plugin super-trait impl (if applicable)
379    if let Some(plugin_impl) = gen_bridge_plugin_impl(spec, generator) {
380        out.push_str(&plugin_impl);
381        writeln!(out).ok();
382        writeln!(out).ok();
383    }
384
385    // Trait impl
386    out.push_str(&gen_bridge_trait_impl(spec, generator));
387
388    // Registration function — only when register_fn is configured
389    if let Some(reg_fn_code) = gen_bridge_registration_fn(spec, generator) {
390        writeln!(out).ok();
391        writeln!(out).ok();
392        out.push_str(&reg_fn_code);
393    }
394
395    BridgeOutput { imports, code: out }
396}
397
398// ---------------------------------------------------------------------------
399// Helpers
400// ---------------------------------------------------------------------------
401
402/// Format a `TypeRef` as a Rust type string for use in trait method signatures.
403///
404/// `type_paths` qualifies `Named` types with their full Rust path (e.g., `"Config"` →
405/// `"kreuzberg::Config"`). If a name isn't in `type_paths`, it's used as-is.
406pub fn format_type_ref(ty: &alef_core::ir::TypeRef, type_paths: &HashMap<String, String>) -> String {
407    use alef_core::ir::{PrimitiveType, TypeRef};
408    match ty {
409        TypeRef::Primitive(p) => match p {
410            PrimitiveType::Bool => "bool",
411            PrimitiveType::U8 => "u8",
412            PrimitiveType::U16 => "u16",
413            PrimitiveType::U32 => "u32",
414            PrimitiveType::U64 => "u64",
415            PrimitiveType::I8 => "i8",
416            PrimitiveType::I16 => "i16",
417            PrimitiveType::I32 => "i32",
418            PrimitiveType::I64 => "i64",
419            PrimitiveType::F32 => "f32",
420            PrimitiveType::F64 => "f64",
421            PrimitiveType::Usize => "usize",
422            PrimitiveType::Isize => "isize",
423        }
424        .to_string(),
425        TypeRef::String => "String".to_string(),
426        TypeRef::Char => "char".to_string(),
427        TypeRef::Bytes => "Vec<u8>".to_string(),
428        TypeRef::Optional(inner) => format!("Option<{}>", format_type_ref(inner, type_paths)),
429        TypeRef::Vec(inner) => format!("Vec<{}>", format_type_ref(inner, type_paths)),
430        TypeRef::Map(k, v) => format!(
431            "std::collections::HashMap<{}, {}>",
432            format_type_ref(k, type_paths),
433            format_type_ref(v, type_paths)
434        ),
435        TypeRef::Named(name) => type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone()),
436        TypeRef::Path => "std::path::PathBuf".to_string(),
437        TypeRef::Unit => "()".to_string(),
438        TypeRef::Json => "serde_json::Value".to_string(),
439        TypeRef::Duration => "std::time::Duration".to_string(),
440    }
441}
442
443/// Format a return type, wrapping in `Result` when an error type is present.
444pub fn format_return_type(
445    ty: &alef_core::ir::TypeRef,
446    error_type: Option<&str>,
447    type_paths: &HashMap<String, String>,
448) -> String {
449    let inner = format_type_ref(ty, type_paths);
450    match error_type {
451        Some(err) => format!("std::result::Result<{inner}, {err}>"),
452        None => inner,
453    }
454}
455
456/// Format a parameter type, respecting `is_ref`, `is_mut`, and `optional` from the IR.
457///
458/// Unlike [`format_type_ref`], this function produces reference types when the
459/// original Rust parameter was a `&T` or `&mut T`, and wraps in `Option<>` when
460/// `param.optional` is true:
461/// - `String + is_ref` → `&str`
462/// - `String + is_ref + optional` → `Option<&str>`
463/// - `Bytes + is_ref` → `&[u8]`
464/// - `Path + is_ref` → `&std::path::Path`
465/// - `Vec<T> + is_ref` → `&[T]`
466/// - `Named(n) + is_ref` → `&{qualified_name}`
467pub fn format_param_type(param: &ParamDef, type_paths: &HashMap<String, String>) -> String {
468    use alef_core::ir::TypeRef;
469    let base = if param.is_ref {
470        let mutability = if param.is_mut { "mut " } else { "" };
471        match &param.ty {
472            TypeRef::String => format!("&{mutability}str"),
473            TypeRef::Bytes => format!("&{mutability}[u8]"),
474            TypeRef::Path => format!("&{mutability}std::path::Path"),
475            TypeRef::Vec(inner) => format!("&{mutability}[{}]", format_type_ref(inner, type_paths)),
476            TypeRef::Named(name) => {
477                let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
478                format!("&{mutability}{qualified}")
479            }
480            TypeRef::Optional(inner) => {
481                // Preserve the Option wrapper but apply the ref transformation to the inner type.
482                // e.g. Option<String> + is_ref → Option<&str>
483                //      Option<Vec<T>> + is_ref → Option<&[T]>
484                let inner_type_str = match inner.as_ref() {
485                    TypeRef::String => format!("&{mutability}str"),
486                    TypeRef::Bytes => format!("&{mutability}[u8]"),
487                    TypeRef::Path => format!("&{mutability}std::path::Path"),
488                    TypeRef::Vec(v) => format!("&{mutability}[{}]", format_type_ref(v, type_paths)),
489                    TypeRef::Named(name) => {
490                        let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
491                        format!("&{mutability}{qualified}")
492                    }
493                    // Primitives and other Copy types: pass by value inside Option
494                    other => format_type_ref(other, type_paths),
495                };
496                // Already wrapped in Option — return directly to avoid double-wrapping below.
497                return format!("Option<{inner_type_str}>");
498            }
499            // All other types are Copy/small — pass by value even when is_ref is set
500            other => format_type_ref(other, type_paths),
501        }
502    } else {
503        format_type_ref(&param.ty, type_paths)
504    };
505
506    // Wrap in Option<> when the parameter is optional (e.g. `title: Option<&str>`).
507    // The TypeRef::Optional arm above returns early, so this only fires for the
508    // `optional: true` IR flag pattern where ty is the unwrapped inner type.
509    if param.optional {
510        format!("Option<{base}>")
511    } else {
512        base
513    }
514}
515
516// ---------------------------------------------------------------------------
517// Shared helpers — used by all backend trait_bridge modules.
518// ---------------------------------------------------------------------------
519
520/// Map a Rust primitive to its type string.
521pub fn prim(p: &PrimitiveType) -> &'static str {
522    use PrimitiveType::*;
523    match p {
524        Bool => "bool",
525        U8 => "u8",
526        U16 => "u16",
527        U32 => "u32",
528        U64 => "u64",
529        I8 => "i8",
530        I16 => "i16",
531        I32 => "i32",
532        I64 => "i64",
533        F32 => "f32",
534        F64 => "f64",
535        Usize => "usize",
536        Isize => "isize",
537    }
538}
539
540/// Map a `TypeRef` to its Rust source type string for use in trait bridge method
541/// signatures. `ci` is the core import path (e.g. `"kreuzberg"`), `tp` maps
542/// type names to fully-qualified paths.
543pub fn bridge_param_type(ty: &TypeRef, ci: &str, is_ref: bool, tp: &HashMap<String, String>) -> String {
544    match ty {
545        TypeRef::Bytes if is_ref => "&[u8]".into(),
546        TypeRef::Bytes => "Vec<u8>".into(),
547        TypeRef::String if is_ref => "&str".into(),
548        TypeRef::String => "String".into(),
549        TypeRef::Path if is_ref => "&std::path::Path".into(),
550        TypeRef::Path => "std::path::PathBuf".into(),
551        TypeRef::Named(n) => {
552            let qualified = tp.get(n).cloned().unwrap_or_else(|| format!("{ci}::{n}"));
553            if is_ref { format!("&{qualified}") } else { qualified }
554        }
555        TypeRef::Vec(inner) => format!("Vec<{}>", bridge_param_type(inner, ci, false, tp)),
556        TypeRef::Optional(inner) => format!("Option<{}>", bridge_param_type(inner, ci, false, tp)),
557        TypeRef::Primitive(p) => prim(p).into(),
558        TypeRef::Unit => "()".into(),
559        TypeRef::Char => "char".into(),
560        TypeRef::Map(k, v) => format!(
561            "std::collections::HashMap<{}, {}>",
562            bridge_param_type(k, ci, false, tp),
563            bridge_param_type(v, ci, false, tp)
564        ),
565        TypeRef::Json => "serde_json::Value".into(),
566        TypeRef::Duration => "std::time::Duration".into(),
567    }
568}
569
570/// Map a visitor method parameter type to the correct Rust type string, handling
571/// IR quirks:
572/// - `ty=String, optional=true, is_ref=true` → `Option<&str>` (IR collapses `Option<&str>`)
573/// - `ty=Vec<T>, is_ref=true` → `&[T]` (IR collapses `&[T]`)
574/// - Everything else delegates to [`bridge_param_type`].
575pub fn visitor_param_type(ty: &TypeRef, is_ref: bool, optional: bool, tp: &HashMap<String, String>) -> String {
576    if optional && matches!(ty, TypeRef::String) && is_ref {
577        return "Option<&str>".to_string();
578    }
579    if is_ref {
580        if let TypeRef::Vec(inner) = ty {
581            let inner_str = bridge_param_type(inner, "", false, tp);
582            return format!("&[{inner_str}]");
583        }
584    }
585    bridge_param_type(ty, "", is_ref, tp)
586}
587
588/// Find the first function parameter that matches a trait bridge configuration
589/// (by type alias or parameter name).
590pub fn find_bridge_param<'a>(
591    func: &FunctionDef,
592    bridges: &'a [TraitBridgeConfig],
593) -> Option<(usize, &'a TraitBridgeConfig)> {
594    for (idx, param) in func.params.iter().enumerate() {
595        let named = match &param.ty {
596            TypeRef::Named(n) => Some(n.as_str()),
597            TypeRef::Optional(inner) => {
598                if let TypeRef::Named(n) = inner.as_ref() {
599                    Some(n.as_str())
600                } else {
601                    None
602                }
603            }
604            _ => None,
605        };
606        for bridge in bridges {
607            if let Some(type_name) = named {
608                if bridge.type_alias.as_deref() == Some(type_name) {
609                    return Some((idx, bridge));
610                }
611            }
612            if bridge.param_name.as_deref() == Some(param.name.as_str()) {
613                return Some((idx, bridge));
614            }
615        }
616    }
617    None
618}
619
/// Convert a snake_case identifier to camelCase.
///
/// Underscores are dropped, and the character following each run of underscores is
/// ASCII-uppercased. All other characters, including the first and any existing
/// capitals, are kept as-is (e.g. `"foo_bar"` → `"fooBar"`, `"_x"` → `"X"`).
pub fn to_camel_case(s: &str) -> String {
    let mut segments = s.split('_');
    let mut camel = String::with_capacity(s.len());

    // The leading segment is copied verbatim.
    if let Some(head) = segments.next() {
        camel.push_str(head);
    }
    // Every later segment gets its first character uppercased; empty segments
    // (from consecutive or trailing underscores) contribute nothing.
    for segment in segments {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            camel.push(first.to_ascii_uppercase());
            camel.push_str(chars.as_str());
        }
    }
    camel
}
636
637#[cfg(test)]
638mod tests {
639    use super::*;
640    use alef_core::config::TraitBridgeConfig;
641    use alef_core::ir::{MethodDef, ParamDef, PrimitiveType, ReceiverKind, TypeDef, TypeRef};
642
643    // ---------------------------------------------------------------------------
644    // Test helpers
645    // ---------------------------------------------------------------------------
646
647    fn make_trait_bridge_config(super_trait: Option<&str>, register_fn: Option<&str>) -> TraitBridgeConfig {
648        TraitBridgeConfig {
649            trait_name: "OcrBackend".to_string(),
650            super_trait: super_trait.map(str::to_string),
651            registry_getter: None,
652            register_fn: register_fn.map(str::to_string),
653            type_alias: None,
654            param_name: None,
655            register_extra_args: None,
656            exclude_languages: Vec::new(),
657        }
658    }
659
660    fn make_type_def(name: &str, rust_path: &str, methods: Vec<MethodDef>) -> TypeDef {
661        TypeDef {
662            name: name.to_string(),
663            rust_path: rust_path.to_string(),
664            original_rust_path: rust_path.to_string(),
665            fields: vec![],
666            methods,
667            is_opaque: true,
668            is_clone: false,
669            doc: String::new(),
670            cfg: None,
671            is_trait: true,
672            has_default: false,
673            has_stripped_cfg_fields: false,
674            is_return_type: false,
675            serde_rename_all: None,
676            has_serde: false,
677            super_traits: vec![],
678        }
679    }
680
681    fn make_method(
682        name: &str,
683        params: Vec<ParamDef>,
684        return_type: TypeRef,
685        is_async: bool,
686        has_default_impl: bool,
687        trait_source: Option<&str>,
688        error_type: Option<&str>,
689    ) -> MethodDef {
690        MethodDef {
691            name: name.to_string(),
692            params,
693            return_type,
694            is_async,
695            is_static: false,
696            error_type: error_type.map(str::to_string),
697            doc: String::new(),
698            receiver: Some(ReceiverKind::Ref),
699            sanitized: false,
700            trait_source: trait_source.map(str::to_string),
701            returns_ref: false,
702            returns_cow: false,
703            return_newtype_wrapper: None,
704            has_default_impl,
705        }
706    }
707
708    fn make_param(name: &str, ty: TypeRef, is_ref: bool) -> ParamDef {
709        ParamDef {
710            name: name.to_string(),
711            ty,
712            optional: false,
713            default: None,
714            sanitized: false,
715            typed_default: None,
716            is_ref,
717            is_mut: false,
718            newtype_wrapper: None,
719            original_type: None,
720        }
721    }
722
723    fn make_spec<'a>(
724        trait_def: &'a TypeDef,
725        bridge_config: &'a TraitBridgeConfig,
726        wrapper_prefix: &'a str,
727        type_paths: HashMap<String, String>,
728    ) -> TraitBridgeSpec<'a> {
729        TraitBridgeSpec {
730            trait_def,
731            bridge_config,
732            core_import: "mylib",
733            wrapper_prefix,
734            type_paths,
735            error_type: "MyError".to_string(),
736            error_constructor: "MyError::from({msg})".to_string(),
737        }
738    }
739
740    // ---------------------------------------------------------------------------
741    // Mock backend
742    // ---------------------------------------------------------------------------
743
744    struct MockBridgeGenerator;
745
746    impl TraitBridgeGenerator for MockBridgeGenerator {
747        fn foreign_object_type(&self) -> &str {
748            "Py<PyAny>"
749        }
750
751        fn bridge_imports(&self) -> Vec<String> {
752            vec!["pyo3::prelude::*".to_string(), "pyo3::types::PyString".to_string()]
753        }
754
755        fn gen_sync_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
756            format!("// sync body for {}", method.name)
757        }
758
759        fn gen_async_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
760            format!("// async body for {}", method.name)
761        }
762
763        fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String {
764            format!(
765                "impl {} {{\n    pub fn new(obj: Py<PyAny>) -> Self {{ Self {{ inner: obj, cached_name: String::new() }} }}\n}}",
766                spec.wrapper_name()
767            )
768        }
769
770        fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String {
771            let fn_name = spec.bridge_config.register_fn.as_deref().unwrap_or("register");
772            format!("pub fn {fn_name}(obj: Py<PyAny>) {{ /* register */ }}")
773        }
774    }
775
776    // ---------------------------------------------------------------------------
777    // TraitBridgeSpec helpers
778    // ---------------------------------------------------------------------------
779
780    #[test]
781    fn test_wrapper_name() {
782        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
783        let config = make_trait_bridge_config(None, None);
784        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
785        assert_eq!(spec.wrapper_name(), "PyOcrBackendBridge");
786    }
787
788    #[test]
789    fn test_trait_snake() {
790        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
791        let config = make_trait_bridge_config(None, None);
792        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
793        assert_eq!(spec.trait_snake(), "ocr_backend");
794    }
795
796    #[test]
797    fn test_trait_path_replaces_hyphens() {
798        let trait_def = make_type_def("OcrBackend", "my-lib::OcrBackend", vec![]);
799        let config = make_trait_bridge_config(None, None);
800        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
801        assert_eq!(spec.trait_path(), "my_lib::OcrBackend");
802    }
803
804    #[test]
805    fn test_required_methods_filters_no_default_impl() {
806        let methods = vec![
807            make_method("process", vec![], TypeRef::String, false, false, None, None),
808            make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
809            make_method("detect", vec![], TypeRef::String, false, false, None, None),
810        ];
811        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
812        let config = make_trait_bridge_config(None, None);
813        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
814        let required = spec.required_methods();
815        assert_eq!(required.len(), 2);
816        assert!(required.iter().any(|m| m.name == "process"));
817        assert!(required.iter().any(|m| m.name == "detect"));
818    }
819
820    #[test]
821    fn test_optional_methods_filters_has_default_impl() {
822        let methods = vec![
823            make_method("process", vec![], TypeRef::String, false, false, None, None),
824            make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
825            make_method("shutdown", vec![], TypeRef::Unit, false, true, None, None),
826        ];
827        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
828        let config = make_trait_bridge_config(None, None);
829        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
830        let optional = spec.optional_methods();
831        assert_eq!(optional.len(), 2);
832        assert!(optional.iter().any(|m| m.name == "initialize"));
833        assert!(optional.iter().any(|m| m.name == "shutdown"));
834    }
835
836    #[test]
837    fn test_error_path() {
838        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
839        let config = make_trait_bridge_config(None, None);
840        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
841        assert_eq!(spec.error_path(), "mylib::MyError");
842    }
843
844    // ---------------------------------------------------------------------------
845    // format_type_ref
846    // ---------------------------------------------------------------------------
847
848    #[test]
849    fn test_format_type_ref_primitives() {
850        let paths = HashMap::new();
851        let cases: Vec<(TypeRef, &str)> = vec![
852            (TypeRef::Primitive(PrimitiveType::Bool), "bool"),
853            (TypeRef::Primitive(PrimitiveType::U8), "u8"),
854            (TypeRef::Primitive(PrimitiveType::U16), "u16"),
855            (TypeRef::Primitive(PrimitiveType::U32), "u32"),
856            (TypeRef::Primitive(PrimitiveType::U64), "u64"),
857            (TypeRef::Primitive(PrimitiveType::I8), "i8"),
858            (TypeRef::Primitive(PrimitiveType::I16), "i16"),
859            (TypeRef::Primitive(PrimitiveType::I32), "i32"),
860            (TypeRef::Primitive(PrimitiveType::I64), "i64"),
861            (TypeRef::Primitive(PrimitiveType::F32), "f32"),
862            (TypeRef::Primitive(PrimitiveType::F64), "f64"),
863            (TypeRef::Primitive(PrimitiveType::Usize), "usize"),
864            (TypeRef::Primitive(PrimitiveType::Isize), "isize"),
865        ];
866        for (ty, expected) in cases {
867            assert_eq!(format_type_ref(&ty, &paths), expected, "mismatch for {expected}");
868        }
869    }
870
871    #[test]
872    fn test_format_type_ref_string() {
873        assert_eq!(format_type_ref(&TypeRef::String, &HashMap::new()), "String");
874    }
875
876    #[test]
877    fn test_format_type_ref_char() {
878        assert_eq!(format_type_ref(&TypeRef::Char, &HashMap::new()), "char");
879    }
880
881    #[test]
882    fn test_format_type_ref_bytes() {
883        assert_eq!(format_type_ref(&TypeRef::Bytes, &HashMap::new()), "Vec<u8>");
884    }
885
886    #[test]
887    fn test_format_type_ref_path() {
888        assert_eq!(format_type_ref(&TypeRef::Path, &HashMap::new()), "std::path::PathBuf");
889    }
890
891    #[test]
892    fn test_format_type_ref_unit() {
893        assert_eq!(format_type_ref(&TypeRef::Unit, &HashMap::new()), "()");
894    }
895
896    #[test]
897    fn test_format_type_ref_json() {
898        assert_eq!(format_type_ref(&TypeRef::Json, &HashMap::new()), "serde_json::Value");
899    }
900
901    #[test]
902    fn test_format_type_ref_duration() {
903        assert_eq!(
904            format_type_ref(&TypeRef::Duration, &HashMap::new()),
905            "std::time::Duration"
906        );
907    }
908
909    #[test]
910    fn test_format_type_ref_optional() {
911        let ty = TypeRef::Optional(Box::new(TypeRef::String));
912        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<String>");
913    }
914
915    #[test]
916    fn test_format_type_ref_optional_nested() {
917        let ty = TypeRef::Optional(Box::new(TypeRef::Optional(Box::new(TypeRef::Primitive(
918            PrimitiveType::U32,
919        )))));
920        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<Option<u32>>");
921    }
922
923    #[test]
924    fn test_format_type_ref_vec() {
925        let ty = TypeRef::Vec(Box::new(TypeRef::Primitive(PrimitiveType::U8)));
926        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<u8>");
927    }
928
929    #[test]
930    fn test_format_type_ref_vec_nested() {
931        let ty = TypeRef::Vec(Box::new(TypeRef::Vec(Box::new(TypeRef::String))));
932        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<Vec<String>>");
933    }
934
935    #[test]
936    fn test_format_type_ref_map() {
937        let ty = TypeRef::Map(
938            Box::new(TypeRef::String),
939            Box::new(TypeRef::Primitive(PrimitiveType::I64)),
940        );
941        assert_eq!(
942            format_type_ref(&ty, &HashMap::new()),
943            "std::collections::HashMap<String, i64>"
944        );
945    }
946
947    #[test]
948    fn test_format_type_ref_map_nested_value() {
949        let ty = TypeRef::Map(
950            Box::new(TypeRef::String),
951            Box::new(TypeRef::Vec(Box::new(TypeRef::String))),
952        );
953        assert_eq!(
954            format_type_ref(&ty, &HashMap::new()),
955            "std::collections::HashMap<String, Vec<String>>"
956        );
957    }
958
959    #[test]
960    fn test_format_type_ref_named_without_type_paths() {
961        let ty = TypeRef::Named("Config".to_string());
962        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Config");
963    }
964
965    #[test]
966    fn test_format_type_ref_named_with_type_paths() {
967        let ty = TypeRef::Named("Config".to_string());
968        let mut paths = HashMap::new();
969        paths.insert("Config".to_string(), "mylib::Config".to_string());
970        assert_eq!(format_type_ref(&ty, &paths), "mylib::Config");
971    }
972
973    #[test]
974    fn test_format_type_ref_named_not_in_type_paths_falls_back_to_name() {
975        let ty = TypeRef::Named("Unknown".to_string());
976        let mut paths = HashMap::new();
977        paths.insert("Other".to_string(), "mylib::Other".to_string());
978        assert_eq!(format_type_ref(&ty, &paths), "Unknown");
979    }
980
981    // ---------------------------------------------------------------------------
982    // format_param_type
983    // ---------------------------------------------------------------------------
984
985    #[test]
986    fn test_format_param_type_string_ref() {
987        let param = make_param("input", TypeRef::String, true);
988        assert_eq!(format_param_type(&param, &HashMap::new()), "&str");
989    }
990
991    #[test]
992    fn test_format_param_type_string_owned() {
993        let param = make_param("input", TypeRef::String, false);
994        assert_eq!(format_param_type(&param, &HashMap::new()), "String");
995    }
996
997    #[test]
998    fn test_format_param_type_bytes_ref() {
999        let param = make_param("data", TypeRef::Bytes, true);
1000        assert_eq!(format_param_type(&param, &HashMap::new()), "&[u8]");
1001    }
1002
1003    #[test]
1004    fn test_format_param_type_bytes_owned() {
1005        let param = make_param("data", TypeRef::Bytes, false);
1006        assert_eq!(format_param_type(&param, &HashMap::new()), "Vec<u8>");
1007    }
1008
1009    #[test]
1010    fn test_format_param_type_path_ref() {
1011        let param = make_param("path", TypeRef::Path, true);
1012        assert_eq!(format_param_type(&param, &HashMap::new()), "&std::path::Path");
1013    }
1014
1015    #[test]
1016    fn test_format_param_type_path_owned() {
1017        let param = make_param("path", TypeRef::Path, false);
1018        assert_eq!(format_param_type(&param, &HashMap::new()), "std::path::PathBuf");
1019    }
1020
1021    #[test]
1022    fn test_format_param_type_vec_ref() {
1023        let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), true);
1024        assert_eq!(format_param_type(&param, &HashMap::new()), "&[String]");
1025    }
1026
1027    #[test]
1028    fn test_format_param_type_vec_owned() {
1029        let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), false);
1030        assert_eq!(format_param_type(&param, &HashMap::new()), "Vec<String>");
1031    }
1032
1033    #[test]
1034    fn test_format_param_type_named_ref_with_type_paths() {
1035        let mut paths = HashMap::new();
1036        paths.insert("Config".to_string(), "mylib::Config".to_string());
1037        let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
1038        assert_eq!(format_param_type(&param, &paths), "&mylib::Config");
1039    }
1040
1041    #[test]
1042    fn test_format_param_type_named_ref_without_type_paths() {
1043        let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
1044        assert_eq!(format_param_type(&param, &HashMap::new()), "&Config");
1045    }
1046
1047    #[test]
1048    fn test_format_param_type_primitive_ref_passes_by_value() {
1049        // Copy types like u32 are passed by value even when is_ref is set
1050        let param = make_param("count", TypeRef::Primitive(PrimitiveType::U32), true);
1051        assert_eq!(format_param_type(&param, &HashMap::new()), "u32");
1052    }
1053
1054    #[test]
1055    fn test_format_param_type_unit_ref_passes_by_value() {
1056        let param = make_param("nothing", TypeRef::Unit, true);
1057        assert_eq!(format_param_type(&param, &HashMap::new()), "()");
1058    }
1059
1060    // ---------------------------------------------------------------------------
1061    // format_return_type
1062    // ---------------------------------------------------------------------------
1063
1064    #[test]
1065    fn test_format_return_type_without_error() {
1066        let result = format_return_type(&TypeRef::String, None, &HashMap::new());
1067        assert_eq!(result, "String");
1068    }
1069
1070    #[test]
1071    fn test_format_return_type_with_error() {
1072        let result = format_return_type(&TypeRef::String, Some("MyError"), &HashMap::new());
1073        assert_eq!(result, "std::result::Result<String, MyError>");
1074    }
1075
1076    #[test]
1077    fn test_format_return_type_unit_with_error() {
1078        let result = format_return_type(&TypeRef::Unit, Some("Box<dyn std::error::Error>"), &HashMap::new());
1079        assert_eq!(result, "std::result::Result<(), Box<dyn std::error::Error>>");
1080    }
1081
1082    #[test]
1083    fn test_format_return_type_named_with_type_paths_and_error() {
1084        let mut paths = HashMap::new();
1085        paths.insert("Output".to_string(), "mylib::Output".to_string());
1086        let result = format_return_type(&TypeRef::Named("Output".to_string()), Some("mylib::MyError"), &paths);
1087        assert_eq!(result, "std::result::Result<mylib::Output, mylib::MyError>");
1088    }
1089
1090    // ---------------------------------------------------------------------------
1091    // gen_bridge_wrapper_struct
1092    // ---------------------------------------------------------------------------
1093
1094    #[test]
1095    fn test_gen_bridge_wrapper_struct_contains_struct_name() {
1096        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1097        let config = make_trait_bridge_config(None, None);
1098        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1099        let generator = MockBridgeGenerator;
1100        let result = gen_bridge_wrapper_struct(&spec, &generator);
1101        assert!(
1102            result.contains("pub struct PyOcrBackendBridge"),
1103            "missing struct declaration in:\n{result}"
1104        );
1105    }
1106
1107    #[test]
1108    fn test_gen_bridge_wrapper_struct_contains_inner_field() {
1109        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1110        let config = make_trait_bridge_config(None, None);
1111        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1112        let generator = MockBridgeGenerator;
1113        let result = gen_bridge_wrapper_struct(&spec, &generator);
1114        assert!(result.contains("inner: Py<PyAny>"), "missing inner field in:\n{result}");
1115    }
1116
1117    #[test]
1118    fn test_gen_bridge_wrapper_struct_contains_cached_name() {
1119        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1120        let config = make_trait_bridge_config(None, None);
1121        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1122        let generator = MockBridgeGenerator;
1123        let result = gen_bridge_wrapper_struct(&spec, &generator);
1124        assert!(
1125            result.contains("cached_name: String"),
1126            "missing cached_name field in:\n{result}"
1127        );
1128    }
1129
1130    // ---------------------------------------------------------------------------
1131    // gen_bridge_plugin_impl
1132    // ---------------------------------------------------------------------------
1133
1134    #[test]
1135    fn test_gen_bridge_plugin_impl_returns_none_when_no_super_trait() {
1136        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1137        let config = make_trait_bridge_config(None, None);
1138        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1139        let generator = MockBridgeGenerator;
1140        assert!(gen_bridge_plugin_impl(&spec, &generator).is_none());
1141    }
1142
1143    #[test]
1144    fn test_gen_bridge_plugin_impl_returns_some_when_super_trait_configured() {
1145        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1146        let config = make_trait_bridge_config(Some("Plugin"), None);
1147        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1148        let generator = MockBridgeGenerator;
1149        assert!(gen_bridge_plugin_impl(&spec, &generator).is_some());
1150    }
1151
1152    #[test]
1153    fn test_gen_bridge_plugin_impl_uses_qualified_super_trait_path() {
1154        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1155        let config = make_trait_bridge_config(Some("Plugin"), None);
1156        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1157        let generator = MockBridgeGenerator;
1158        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1159        assert!(
1160            result.contains("impl mylib::Plugin for PyOcrBackendBridge"),
1161            "missing qualified super-trait path in:\n{result}"
1162        );
1163    }
1164
1165    #[test]
1166    fn test_gen_bridge_plugin_impl_uses_already_qualified_super_trait_path() {
1167        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1168        let config = make_trait_bridge_config(Some("other_crate::Plugin"), None);
1169        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1170        let generator = MockBridgeGenerator;
1171        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1172        assert!(
1173            result.contains("impl other_crate::Plugin for PyOcrBackendBridge"),
1174            "wrong super-trait path in:\n{result}"
1175        );
1176    }
1177
1178    #[test]
1179    fn test_gen_bridge_plugin_impl_contains_name_fn() {
1180        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1181        let config = make_trait_bridge_config(Some("Plugin"), None);
1182        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1183        let generator = MockBridgeGenerator;
1184        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1185        assert!(
1186            result.contains("fn name(") && result.contains("cached_name"),
1187            "missing name() using cached_name in:\n{result}"
1188        );
1189    }
1190
1191    #[test]
1192    fn test_gen_bridge_plugin_impl_contains_version_fn() {
1193        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1194        let config = make_trait_bridge_config(Some("Plugin"), None);
1195        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1196        let generator = MockBridgeGenerator;
1197        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1198        assert!(result.contains("fn version("), "missing version() in:\n{result}");
1199    }
1200
1201    #[test]
1202    fn test_gen_bridge_plugin_impl_contains_initialize_fn() {
1203        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1204        let config = make_trait_bridge_config(Some("Plugin"), None);
1205        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1206        let generator = MockBridgeGenerator;
1207        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1208        assert!(result.contains("fn initialize("), "missing initialize() in:\n{result}");
1209    }
1210
1211    #[test]
1212    fn test_gen_bridge_plugin_impl_contains_shutdown_fn() {
1213        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1214        let config = make_trait_bridge_config(Some("Plugin"), None);
1215        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1216        let generator = MockBridgeGenerator;
1217        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1218        assert!(result.contains("fn shutdown("), "missing shutdown() in:\n{result}");
1219    }
1220
1221    // ---------------------------------------------------------------------------
1222    // gen_bridge_trait_impl
1223    // ---------------------------------------------------------------------------
1224
1225    #[test]
1226    fn test_gen_bridge_trait_impl_includes_impl_header() {
1227        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1228        let config = make_trait_bridge_config(None, None);
1229        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1230        let generator = MockBridgeGenerator;
1231        let result = gen_bridge_trait_impl(&spec, &generator);
1232        assert!(
1233            result.contains("impl mylib::OcrBackend for PyOcrBackendBridge"),
1234            "missing impl header in:\n{result}"
1235        );
1236    }
1237
1238    #[test]
1239    fn test_gen_bridge_trait_impl_includes_method_signatures() {
1240        let methods = vec![make_method(
1241            "process",
1242            vec![],
1243            TypeRef::String,
1244            false,
1245            false,
1246            None,
1247            None,
1248        )];
1249        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1250        let config = make_trait_bridge_config(None, None);
1251        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1252        let generator = MockBridgeGenerator;
1253        let result = gen_bridge_trait_impl(&spec, &generator);
1254        assert!(result.contains("fn process("), "missing method signature in:\n{result}");
1255    }
1256
1257    #[test]
1258    fn test_gen_bridge_trait_impl_includes_method_body_from_generator() {
1259        let methods = vec![make_method(
1260            "process",
1261            vec![],
1262            TypeRef::String,
1263            false,
1264            false,
1265            None,
1266            None,
1267        )];
1268        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1269        let config = make_trait_bridge_config(None, None);
1270        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1271        let generator = MockBridgeGenerator;
1272        let result = gen_bridge_trait_impl(&spec, &generator);
1273        assert!(
1274            result.contains("// sync body for process"),
1275            "missing sync method body in:\n{result}"
1276        );
1277    }
1278
1279    #[test]
1280    fn test_gen_bridge_trait_impl_async_method_uses_async_body() {
1281        let methods = vec![make_method(
1282            "process_async",
1283            vec![],
1284            TypeRef::String,
1285            true,
1286            false,
1287            None,
1288            None,
1289        )];
1290        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1291        let config = make_trait_bridge_config(None, None);
1292        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1293        let generator = MockBridgeGenerator;
1294        let result = gen_bridge_trait_impl(&spec, &generator);
1295        assert!(
1296            result.contains("// async body for process_async"),
1297            "missing async method body in:\n{result}"
1298        );
1299        assert!(
1300            result.contains("async fn process_async("),
1301            "missing async keyword in method signature in:\n{result}"
1302        );
1303    }
1304
1305    #[test]
1306    fn test_gen_bridge_trait_impl_filters_trait_source_methods() {
1307        // Methods with trait_source set come from super-traits and should be excluded
1308        let methods = vec![
1309            make_method("own_method", vec![], TypeRef::String, false, false, None, None),
1310            make_method(
1311                "inherited_method",
1312                vec![],
1313                TypeRef::String,
1314                false,
1315                false,
1316                Some("other_crate::OtherTrait"),
1317                None,
1318            ),
1319        ];
1320        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1321        let config = make_trait_bridge_config(None, None);
1322        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1323        let generator = MockBridgeGenerator;
1324        let result = gen_bridge_trait_impl(&spec, &generator);
1325        assert!(
1326            result.contains("fn own_method("),
1327            "own method should be present in:\n{result}"
1328        );
1329        assert!(
1330            !result.contains("fn inherited_method("),
1331            "inherited method should be filtered out in:\n{result}"
1332        );
1333    }
1334
1335    #[test]
1336    fn test_gen_bridge_trait_impl_method_with_params() {
1337        let params = vec![
1338            make_param("input", TypeRef::String, true),
1339            make_param("count", TypeRef::Primitive(PrimitiveType::U32), false),
1340        ];
1341        let methods = vec![make_method(
1342            "process",
1343            params,
1344            TypeRef::String,
1345            false,
1346            false,
1347            None,
1348            None,
1349        )];
1350        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1351        let config = make_trait_bridge_config(None, None);
1352        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1353        let generator = MockBridgeGenerator;
1354        let result = gen_bridge_trait_impl(&spec, &generator);
1355        assert!(result.contains("input: &str"), "missing &str param in:\n{result}");
1356        assert!(result.contains("count: u32"), "missing u32 param in:\n{result}");
1357    }
1358
1359    #[test]
1360    fn test_gen_bridge_trait_impl_return_type_with_error() {
1361        let methods = vec![make_method(
1362            "process",
1363            vec![],
1364            TypeRef::String,
1365            false,
1366            false,
1367            None,
1368            Some("MyError"),
1369        )];
1370        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1371        let config = make_trait_bridge_config(None, None);
1372        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1373        let generator = MockBridgeGenerator;
1374        let result = gen_bridge_trait_impl(&spec, &generator);
1375        assert!(
1376            result.contains("-> std::result::Result<String, mylib::MyError>"),
1377            "missing std::result::Result return type in:\n{result}"
1378        );
1379    }
1380
1381    // ---------------------------------------------------------------------------
1382    // gen_bridge_registration_fn
1383    // ---------------------------------------------------------------------------
1384
1385    #[test]
1386    fn test_gen_bridge_registration_fn_returns_none_without_register_fn() {
1387        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1388        let config = make_trait_bridge_config(None, None);
1389        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1390        let generator = MockBridgeGenerator;
1391        assert!(gen_bridge_registration_fn(&spec, &generator).is_none());
1392    }
1393
1394    #[test]
1395    fn test_gen_bridge_registration_fn_returns_some_with_register_fn() {
1396        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1397        let config = make_trait_bridge_config(None, Some("register_ocr_backend"));
1398        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1399        let generator = MockBridgeGenerator;
1400        let result = gen_bridge_registration_fn(&spec, &generator);
1401        assert!(result.is_some());
1402        let code = result.unwrap();
1403        assert!(
1404            code.contains("register_ocr_backend"),
1405            "missing register fn name in:\n{code}"
1406        );
1407    }
1408
1409    // ---------------------------------------------------------------------------
1410    // gen_bridge_all
1411    // ---------------------------------------------------------------------------
1412
1413    #[test]
1414    fn test_gen_bridge_all_includes_imports() {
1415        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1416        let config = make_trait_bridge_config(None, None);
1417        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1418        let generator = MockBridgeGenerator;
1419        let output = gen_bridge_all(&spec, &generator);
1420        assert!(output.imports.contains(&"pyo3::prelude::*".to_string()));
1421        assert!(output.imports.contains(&"pyo3::types::PyString".to_string()));
1422    }
1423
1424    #[test]
1425    fn test_gen_bridge_all_includes_wrapper_struct() {
1426        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1427        let config = make_trait_bridge_config(None, None);
1428        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1429        let generator = MockBridgeGenerator;
1430        let output = gen_bridge_all(&spec, &generator);
1431        assert!(
1432            output.code.contains("pub struct PyOcrBackendBridge"),
1433            "missing struct in:\n{}",
1434            output.code
1435        );
1436    }
1437
1438    #[test]
1439    fn test_gen_bridge_all_includes_constructor() {
1440        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1441        let config = make_trait_bridge_config(None, None);
1442        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1443        let generator = MockBridgeGenerator;
1444        let output = gen_bridge_all(&spec, &generator);
1445        assert!(
1446            output.code.contains("pub fn new("),
1447            "missing constructor in:\n{}",
1448            output.code
1449        );
1450    }
1451
1452    #[test]
1453    fn test_gen_bridge_all_includes_trait_impl() {
1454        let methods = vec![make_method(
1455            "process",
1456            vec![],
1457            TypeRef::String,
1458            false,
1459            false,
1460            None,
1461            None,
1462        )];
1463        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1464        let config = make_trait_bridge_config(None, None);
1465        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1466        let generator = MockBridgeGenerator;
1467        let output = gen_bridge_all(&spec, &generator);
1468        assert!(
1469            output.code.contains("impl mylib::OcrBackend for PyOcrBackendBridge"),
1470            "missing trait impl in:\n{}",
1471            output.code
1472        );
1473    }
1474
1475    #[test]
1476    fn test_gen_bridge_all_includes_plugin_impl_when_super_trait_set() {
1477        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1478        let config = make_trait_bridge_config(Some("Plugin"), None);
1479        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1480        let generator = MockBridgeGenerator;
1481        let output = gen_bridge_all(&spec, &generator);
1482        assert!(
1483            output.code.contains("impl mylib::Plugin for PyOcrBackendBridge"),
1484            "missing plugin impl in:\n{}",
1485            output.code
1486        );
1487    }
1488
1489    #[test]
1490    fn test_gen_bridge_all_no_plugin_impl_when_no_super_trait() {
1491        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1492        let config = make_trait_bridge_config(None, None);
1493        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1494        let generator = MockBridgeGenerator;
1495        let output = gen_bridge_all(&spec, &generator);
1496        assert!(
1497            !output.code.contains("fn name(") || !output.code.contains("cached_name"),
1498            "unexpected plugin impl present without super_trait"
1499        );
1500    }
1501
1502    #[test]
1503    fn test_gen_bridge_all_includes_registration_fn_when_configured() {
1504        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1505        let config = make_trait_bridge_config(None, Some("register_ocr_backend"));
1506        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1507        let generator = MockBridgeGenerator;
1508        let output = gen_bridge_all(&spec, &generator);
1509        assert!(
1510            output.code.contains("register_ocr_backend"),
1511            "missing registration fn in:\n{}",
1512            output.code
1513        );
1514    }
1515
1516    #[test]
1517    fn test_gen_bridge_all_no_registration_fn_when_absent() {
1518        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1519        let config = make_trait_bridge_config(None, None);
1520        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1521        let generator = MockBridgeGenerator;
1522        let output = gen_bridge_all(&spec, &generator);
1523        assert!(
1524            !output.code.contains("register_ocr_backend"),
1525            "unexpected registration fn present:\n{}",
1526            output.code
1527        );
1528    }
1529
1530    #[test]
1531    fn test_gen_bridge_all_ordering_struct_before_trait_impl() {
1532        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1533        let config = make_trait_bridge_config(None, None);
1534        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1535        let generator = MockBridgeGenerator;
1536        let output = gen_bridge_all(&spec, &generator);
1537        let struct_pos = output.code.find("pub struct PyOcrBackendBridge").unwrap();
1538        let impl_pos = output
1539            .code
1540            .find("impl mylib::OcrBackend for PyOcrBackendBridge")
1541            .unwrap();
1542        assert!(struct_pos < impl_pos, "struct should appear before trait impl");
1543    }
1544}