Skip to main content

alef_codegen/generators/
trait_bridge.rs

//! Shared trait bridge code generation.
//!
//! Generates wrapper structs that allow foreign language objects (Python, JS, etc.)
//! to implement Rust traits via FFI. Each backend implements [`TraitBridgeGenerator`]
//! to provide language-specific dispatch logic; the shared functions in this module
//! handle the structural boilerplate.
7
8use alef_core::config::TraitBridgeConfig;
9use alef_core::ir::{FunctionDef, MethodDef, ParamDef, PrimitiveType, TypeDef, TypeRef};
10use heck::ToSnakeCase;
11use std::collections::HashMap;
12use std::fmt::Write;
13
14/// Everything needed to generate a trait bridge for one trait.
15pub struct TraitBridgeSpec<'a> {
16    /// The trait definition from the IR.
17    pub trait_def: &'a TypeDef,
18    /// Bridge configuration from `alef.toml`.
19    pub bridge_config: &'a TraitBridgeConfig,
20    /// Core crate import path (e.g., `"kreuzberg"`).
21    pub core_import: &'a str,
22    /// Language-specific prefix for the wrapper type (e.g., `"Python"`, `"Js"`, `"Wasm"`).
23    pub wrapper_prefix: &'a str,
24    /// Map of type name → fully-qualified Rust path for qualifying `Named` types.
25    pub type_paths: HashMap<String, String>,
26    /// The crate's error type name (e.g., `"KreuzbergError"`). Defaults to `"Error"`.
27    pub error_type: String,
28    /// Error constructor pattern. `{msg}` is replaced with the message expression.
29    pub error_constructor: String,
30}
31
32impl<'a> TraitBridgeSpec<'a> {
33    /// Fully qualified error type path (e.g., `"kreuzberg::KreuzbergError"`).
34    ///
35    /// If `error_type` already looks fully-qualified (contains `::`) or is a generic
36    /// type expression (contains `<`), it is returned as-is without prefixing
37    /// `core_import`. This lets backends specify rich error types like
38    /// `"Box<dyn std::error::Error + Send + Sync>"` directly.
39    pub fn error_path(&self) -> String {
40        if self.error_type.contains("::") || self.error_type.contains('<') {
41            self.error_type.clone()
42        } else {
43            format!("{}::{}", self.core_import, self.error_type)
44        }
45    }
46
47    /// Generate an error construction expression from a message expression.
48    pub fn make_error(&self, msg_expr: &str) -> String {
49        self.error_constructor.replace("{msg}", msg_expr)
50    }
51
52    /// Wrapper struct name: `{prefix}{TraitName}Bridge` (e.g., `PythonOcrBackendBridge`).
53    pub fn wrapper_name(&self) -> String {
54        format!("{}{}Bridge", self.wrapper_prefix, self.trait_def.name)
55    }
56
57    /// Snake-case version of the trait name (e.g., `"ocr_backend"`).
58    pub fn trait_snake(&self) -> String {
59        self.trait_def.name.to_snake_case()
60    }
61
62    /// Full Rust path to the trait (e.g., `kreuzberg::OcrBackend`).
63    pub fn trait_path(&self) -> String {
64        self.trait_def.rust_path.replace('-', "_")
65    }
66
67    /// Methods that are required (no default impl) — must be provided by the foreign object.
68    pub fn required_methods(&self) -> Vec<&'a MethodDef> {
69        self.trait_def.methods.iter().filter(|m| !m.has_default_impl).collect()
70    }
71
72    /// Methods that have a default impl — optional on the foreign object.
73    pub fn optional_methods(&self) -> Vec<&'a MethodDef> {
74        self.trait_def.methods.iter().filter(|m| m.has_default_impl).collect()
75    }
76}
77
78/// Backend-specific trait bridge generation.
79///
80/// Each binding backend (PyO3, NAPI-RS, wasm-bindgen, etc.) implements this trait
81/// to provide the language-specific parts of bridge codegen. The shared functions
82/// in this module call these methods to fill in the backend-dependent pieces.
83pub trait TraitBridgeGenerator {
84    /// The type of the wrapped foreign object (e.g., `"Py<PyAny>"`, `"ThreadsafeFunction"`).
85    fn foreign_object_type(&self) -> &str;
86
87    /// Additional `use` imports needed for the bridge code.
88    fn bridge_imports(&self) -> Vec<String>;
89
90    /// Generate the body of a synchronous method bridge.
91    ///
92    /// The returned string is inserted inside the trait impl method. It should
93    /// call through to the foreign object and convert the result.
94    fn gen_sync_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;
95
96    /// Generate the body of an async method bridge.
97    ///
98    /// The returned string is the body of a `Box::pin(async move { ... })` block.
99    fn gen_async_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;
100
101    /// Generate the constructor body that validates and wraps the foreign object.
102    ///
103    /// Should check that the foreign object provides all required methods and
104    /// return `Self { ... }` on success.
105    fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String;
106
107    /// Generate the complete registration function including attributes, signature, and body.
108    ///
109    /// Each backend needs different function signatures (PyO3 takes `py: Python`,
110    /// NAPI takes `#[napi]` with JS params, FFI takes `extern "C"` with raw pointers),
111    /// so the generator owns the full function.
112    fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String;
113}
114
// ---------------------------------------------------------------------------
// Shared generation functions
// ---------------------------------------------------------------------------
118
119/// Generate the wrapper struct holding the foreign object and cached fields.
120///
121/// Produces a struct like:
122/// ```ignore
123/// pub struct PythonOcrBackendBridge {
124///     inner: Py<PyAny>,
125///     cached_name: String,
126/// }
127/// ```
128pub fn gen_bridge_wrapper_struct(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
129    let wrapper = spec.wrapper_name();
130    let foreign_type = generator.foreign_object_type();
131    let mut out = String::with_capacity(512);
132
133    writeln!(
134        out,
135        "/// Wrapper that bridges a foreign {prefix} object to the `{trait_name}` trait.",
136        prefix = spec.wrapper_prefix,
137        trait_name = spec.trait_def.name,
138    )
139    .ok();
140    writeln!(out, "pub struct {wrapper} {{").ok();
141    writeln!(out, "    inner: {foreign_type},").ok();
142    writeln!(out, "    cached_name: String,").ok();
143    write!(out, "}}").ok();
144    out
145}
146
147/// Generate `impl SuperTrait for Wrapper` when the bridge config specifies a super-trait.
148///
149/// Forwards `name()`, `version()`, `initialize()`, and `shutdown()` to the
150/// foreign object, using `cached_name` for `name()`.
151///
152/// The super-trait path is derived from the config's `super_trait` field. If it
153/// contains `::`, it's used as-is; otherwise it's qualified as `{core_import}::{super_trait}`.
154pub fn gen_bridge_plugin_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
155    let super_trait_name = spec.bridge_config.super_trait.as_deref()?;
156
157    let wrapper = spec.wrapper_name();
158    let core_import = spec.core_import;
159
160    // Derive the fully-qualified super-trait path
161    let super_trait_path = if super_trait_name.contains("::") {
162        super_trait_name.to_string()
163    } else {
164        format!("{core_import}::{super_trait_name}")
165    };
166
167    // Build synthetic MethodDefs for the Plugin methods and delegate to the generator
168    // for the actual call bodies. The Plugin trait interface is well-known: name(),
169    // version(), initialize(), shutdown().
170    let mut out = String::with_capacity(1024);
171    writeln!(out, "impl {super_trait_path} for {wrapper} {{").ok();
172
173    // name() -> &str — uses cached field
174    writeln!(out, "    fn name(&self) -> &str {{").ok();
175    writeln!(out, "        &self.cached_name").ok();
176    writeln!(out, "    }}").ok();
177    writeln!(out).ok();
178
179    let error_path = spec.error_path();
180
181    // version() -> String — delegate to foreign object
182    writeln!(out, "    fn version(&self) -> String {{").ok();
183    let version_method = MethodDef {
184        name: "version".to_string(),
185        params: vec![],
186        return_type: alef_core::ir::TypeRef::String,
187        is_async: false,
188        is_static: false,
189        error_type: None,
190        doc: String::new(),
191        receiver: Some(alef_core::ir::ReceiverKind::Ref),
192        sanitized: false,
193        trait_source: None,
194        returns_ref: false,
195        returns_cow: false,
196        return_newtype_wrapper: None,
197        has_default_impl: false,
198    };
199    let version_body = generator.gen_sync_method_body(&version_method, spec);
200    for line in version_body.lines() {
201        writeln!(out, "        {}", line.trim_start()).ok();
202    }
203    writeln!(out, "    }}").ok();
204    writeln!(out).ok();
205
206    // initialize() -> Result<(), ErrorType>
207    writeln!(
208        out,
209        "    fn initialize(&self) -> std::result::Result<(), {error_path}> {{"
210    )
211    .ok();
212    let init_method = MethodDef {
213        name: "initialize".to_string(),
214        params: vec![],
215        return_type: alef_core::ir::TypeRef::Unit,
216        is_async: false,
217        is_static: false,
218        error_type: Some(error_path.clone()),
219        doc: String::new(),
220        receiver: Some(alef_core::ir::ReceiverKind::Ref),
221        sanitized: false,
222        trait_source: None,
223        returns_ref: false,
224        returns_cow: false,
225        return_newtype_wrapper: None,
226        has_default_impl: true,
227    };
228    let init_body = generator.gen_sync_method_body(&init_method, spec);
229    for line in init_body.lines() {
230        writeln!(out, "        {}", line.trim_start()).ok();
231    }
232    writeln!(out, "    }}").ok();
233    writeln!(out).ok();
234
235    // shutdown() -> Result<(), ErrorType>
236    writeln!(
237        out,
238        "    fn shutdown(&self) -> std::result::Result<(), {error_path}> {{"
239    )
240    .ok();
241    let shutdown_method = MethodDef {
242        name: "shutdown".to_string(),
243        params: vec![],
244        return_type: alef_core::ir::TypeRef::Unit,
245        is_async: false,
246        is_static: false,
247        error_type: Some(error_path.clone()),
248        doc: String::new(),
249        receiver: Some(alef_core::ir::ReceiverKind::Ref),
250        sanitized: false,
251        trait_source: None,
252        returns_ref: false,
253        returns_cow: false,
254        return_newtype_wrapper: None,
255        has_default_impl: true,
256    };
257    let shutdown_body = generator.gen_sync_method_body(&shutdown_method, spec);
258    for line in shutdown_body.lines() {
259        writeln!(out, "        {}", line.trim_start()).ok();
260    }
261    writeln!(out, "    }}").ok();
262    write!(out, "}}").ok();
263    Some(out)
264}
265
266/// Generate `impl Trait for Wrapper` dispatching each method through the generator.
267///
268/// Every method on the trait (including those with `has_default_impl`) gets a
269/// generated body that forwards to the foreign object.
270pub fn gen_bridge_trait_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
271    let wrapper = spec.wrapper_name();
272    let trait_path = spec.trait_path();
273    let mut out = String::with_capacity(2048);
274
275    // Add #[async_trait] when the trait has async methods (needed for async_trait macro compatibility)
276    let has_async_methods = spec
277        .trait_def
278        .methods
279        .iter()
280        .any(|m| m.is_async && m.trait_source.is_none());
281    if has_async_methods {
282        writeln!(out, "#[async_trait::async_trait]").ok();
283    }
284    writeln!(out, "impl {trait_path} for {wrapper} {{").ok();
285
286    // Filter out methods inherited from super-traits (they're handled by gen_bridge_plugin_impl)
287    let own_methods: Vec<_> = spec
288        .trait_def
289        .methods
290        .iter()
291        .filter(|m| m.trait_source.is_none())
292        .collect();
293
294    for (i, method) in own_methods.iter().enumerate() {
295        if i > 0 {
296            writeln!(out).ok();
297        }
298
299        // Build the method signature
300        let async_kw = if method.is_async { "async " } else { "" };
301        let receiver = match &method.receiver {
302            Some(alef_core::ir::ReceiverKind::Ref) => "&self",
303            Some(alef_core::ir::ReceiverKind::RefMut) => "&mut self",
304            Some(alef_core::ir::ReceiverKind::Owned) => "self",
305            None => "",
306        };
307
308        // Build params (excluding self), using format_param_type to respect is_ref/is_mut
309        let params: Vec<String> = method
310            .params
311            .iter()
312            .map(|p| format!("{}: {}", p.name, format_param_type(p, &spec.type_paths)))
313            .collect();
314
315        let all_params = if receiver.is_empty() {
316            params.join(", ")
317        } else if params.is_empty() {
318            receiver.to_string()
319        } else {
320            format!("{}, {}", receiver, params.join(", "))
321        };
322
323        // Return type — override the IR's error type with the configured crate error type
324        // so the impl matches the actual trait definition (the IR may extract a different
325        // error type like anyhow::Error from re-exports or type alias resolution).
326        let error_override = method.error_type.as_ref().map(|_| spec.error_path());
327        let ret = format_return_type(&method.return_type, error_override.as_deref(), &spec.type_paths);
328
329        writeln!(out, "    {async_kw}fn {}({all_params}) -> {ret} {{", method.name).ok();
330
331        // Generate body: async methods use Box::pin, sync methods call directly
332        let body = if method.is_async {
333            generator.gen_async_method_body(method, spec)
334        } else {
335            generator.gen_sync_method_body(method, spec)
336        };
337
338        for line in body.lines() {
339            writeln!(out, "        {line}").ok();
340        }
341        writeln!(out, "    }}").ok();
342    }
343
344    write!(out, "}}").ok();
345    out
346}
347
348/// Generate the `register_xxx()` function that wraps a foreign object and
349/// inserts it into the plugin registry.
350///
351/// Returns `None` when `bridge_config.register_fn` is absent (per-call bridge pattern).
352/// The generator owns the full function (attributes, signature, body) because each
353/// backend needs different signatures.
354pub fn gen_bridge_registration_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
355    spec.bridge_config.register_fn.as_deref()?;
356    Some(generator.gen_registration_fn(spec))
357}
358
/// Result of trait bridge generation: imports (to be added via `builder.add_import`)
/// and the code body (to be added via `builder.add_item`).
///
/// Derives the standard traits so callers can inspect, compare and copy results
/// (for example in tests) without hand-written impls.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BridgeOutput {
    /// Import paths (e.g., `"std::sync::Arc"`) — callers should add via `builder.add_import()`.
    pub imports: Vec<String>,
    /// The generated code (struct, impls, registration fn).
    pub code: String,
}
367
368/// Generate the complete trait bridge code block: struct, impls, and
369/// optionally a registration function.
370///
371/// Returns [`BridgeOutput`] with imports separated from code so callers can
372/// route imports through `builder.add_import()` (which deduplicates).
373pub fn gen_bridge_all(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> BridgeOutput {
374    let imports = generator.bridge_imports();
375    let mut out = String::with_capacity(4096);
376
377    // Wrapper struct
378    out.push_str(&gen_bridge_wrapper_struct(spec, generator));
379    writeln!(out).ok();
380    writeln!(out).ok();
381
382    // Constructor (impl block with new())
383    out.push_str(&generator.gen_constructor(spec));
384    writeln!(out).ok();
385    writeln!(out).ok();
386
387    // Plugin super-trait impl (if applicable)
388    if let Some(plugin_impl) = gen_bridge_plugin_impl(spec, generator) {
389        out.push_str(&plugin_impl);
390        writeln!(out).ok();
391        writeln!(out).ok();
392    }
393
394    // Trait impl
395    out.push_str(&gen_bridge_trait_impl(spec, generator));
396
397    // Registration function — only when register_fn is configured
398    if let Some(reg_fn_code) = gen_bridge_registration_fn(spec, generator) {
399        writeln!(out).ok();
400        writeln!(out).ok();
401        out.push_str(&reg_fn_code);
402    }
403
404    BridgeOutput { imports, code: out }
405}
406
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
410
411/// Format a `TypeRef` as a Rust type string for use in trait method signatures.
412///
413/// `type_paths` qualifies `Named` types with their full Rust path (e.g., `"Config"` →
414/// `"kreuzberg::Config"`). If a name isn't in `type_paths`, it's used as-is.
415pub fn format_type_ref(ty: &alef_core::ir::TypeRef, type_paths: &HashMap<String, String>) -> String {
416    use alef_core::ir::{PrimitiveType, TypeRef};
417    match ty {
418        TypeRef::Primitive(p) => match p {
419            PrimitiveType::Bool => "bool",
420            PrimitiveType::U8 => "u8",
421            PrimitiveType::U16 => "u16",
422            PrimitiveType::U32 => "u32",
423            PrimitiveType::U64 => "u64",
424            PrimitiveType::I8 => "i8",
425            PrimitiveType::I16 => "i16",
426            PrimitiveType::I32 => "i32",
427            PrimitiveType::I64 => "i64",
428            PrimitiveType::F32 => "f32",
429            PrimitiveType::F64 => "f64",
430            PrimitiveType::Usize => "usize",
431            PrimitiveType::Isize => "isize",
432        }
433        .to_string(),
434        TypeRef::String => "String".to_string(),
435        TypeRef::Char => "char".to_string(),
436        TypeRef::Bytes => "Vec<u8>".to_string(),
437        TypeRef::Optional(inner) => format!("Option<{}>", format_type_ref(inner, type_paths)),
438        TypeRef::Vec(inner) => format!("Vec<{}>", format_type_ref(inner, type_paths)),
439        TypeRef::Map(k, v) => format!(
440            "std::collections::HashMap<{}, {}>",
441            format_type_ref(k, type_paths),
442            format_type_ref(v, type_paths)
443        ),
444        TypeRef::Named(name) => type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone()),
445        TypeRef::Path => "std::path::PathBuf".to_string(),
446        TypeRef::Unit => "()".to_string(),
447        TypeRef::Json => "serde_json::Value".to_string(),
448        TypeRef::Duration => "std::time::Duration".to_string(),
449    }
450}
451
452/// Format a return type, wrapping in `Result` when an error type is present.
453pub fn format_return_type(
454    ty: &alef_core::ir::TypeRef,
455    error_type: Option<&str>,
456    type_paths: &HashMap<String, String>,
457) -> String {
458    let inner = format_type_ref(ty, type_paths);
459    match error_type {
460        Some(err) => format!("std::result::Result<{inner}, {err}>"),
461        None => inner,
462    }
463}
464
465/// Format a parameter type, respecting `is_ref`, `is_mut`, and `optional` from the IR.
466///
467/// Unlike [`format_type_ref`], this function produces reference types when the
468/// original Rust parameter was a `&T` or `&mut T`, and wraps in `Option<>` when
469/// `param.optional` is true:
470/// - `String + is_ref` → `&str`
471/// - `String + is_ref + optional` → `Option<&str>`
472/// - `Bytes + is_ref` → `&[u8]`
473/// - `Path + is_ref` → `&std::path::Path`
474/// - `Vec<T> + is_ref` → `&[T]`
475/// - `Named(n) + is_ref` → `&{qualified_name}`
476pub fn format_param_type(param: &ParamDef, type_paths: &HashMap<String, String>) -> String {
477    use alef_core::ir::TypeRef;
478    let base = if param.is_ref {
479        let mutability = if param.is_mut { "mut " } else { "" };
480        match &param.ty {
481            TypeRef::String => format!("&{mutability}str"),
482            TypeRef::Bytes => format!("&{mutability}[u8]"),
483            TypeRef::Path => format!("&{mutability}std::path::Path"),
484            TypeRef::Vec(inner) => format!("&{mutability}[{}]", format_type_ref(inner, type_paths)),
485            TypeRef::Named(name) => {
486                let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
487                format!("&{mutability}{qualified}")
488            }
489            TypeRef::Optional(inner) => {
490                // Preserve the Option wrapper but apply the ref transformation to the inner type.
491                // e.g. Option<String> + is_ref → Option<&str>
492                //      Option<Vec<T>> + is_ref → Option<&[T]>
493                let inner_type_str = match inner.as_ref() {
494                    TypeRef::String => format!("&{mutability}str"),
495                    TypeRef::Bytes => format!("&{mutability}[u8]"),
496                    TypeRef::Path => format!("&{mutability}std::path::Path"),
497                    TypeRef::Vec(v) => format!("&{mutability}[{}]", format_type_ref(v, type_paths)),
498                    TypeRef::Named(name) => {
499                        let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
500                        format!("&{mutability}{qualified}")
501                    }
502                    // Primitives and other Copy types: pass by value inside Option
503                    other => format_type_ref(other, type_paths),
504                };
505                // Already wrapped in Option — return directly to avoid double-wrapping below.
506                return format!("Option<{inner_type_str}>");
507            }
508            // All other types are Copy/small — pass by value even when is_ref is set
509            other => format_type_ref(other, type_paths),
510        }
511    } else {
512        format_type_ref(&param.ty, type_paths)
513    };
514
515    // Wrap in Option<> when the parameter is optional (e.g. `title: Option<&str>`).
516    // The TypeRef::Optional arm above returns early, so this only fires for the
517    // `optional: true` IR flag pattern where ty is the unwrapped inner type.
518    if param.optional {
519        format!("Option<{base}>")
520    } else {
521        base
522    }
523}
524
// ---------------------------------------------------------------------------
// Shared helpers — used by all backend trait_bridge modules.
// ---------------------------------------------------------------------------
528
529/// Map a Rust primitive to its type string.
530pub fn prim(p: &PrimitiveType) -> &'static str {
531    use PrimitiveType::*;
532    match p {
533        Bool => "bool",
534        U8 => "u8",
535        U16 => "u16",
536        U32 => "u32",
537        U64 => "u64",
538        I8 => "i8",
539        I16 => "i16",
540        I32 => "i32",
541        I64 => "i64",
542        F32 => "f32",
543        F64 => "f64",
544        Usize => "usize",
545        Isize => "isize",
546    }
547}
548
549/// Map a `TypeRef` to its Rust source type string for use in trait bridge method
550/// signatures. `ci` is the core import path (e.g. `"kreuzberg"`), `tp` maps
551/// type names to fully-qualified paths.
552pub fn bridge_param_type(ty: &TypeRef, ci: &str, is_ref: bool, tp: &HashMap<String, String>) -> String {
553    match ty {
554        TypeRef::Bytes if is_ref => "&[u8]".into(),
555        TypeRef::Bytes => "Vec<u8>".into(),
556        TypeRef::String if is_ref => "&str".into(),
557        TypeRef::String => "String".into(),
558        TypeRef::Path if is_ref => "&std::path::Path".into(),
559        TypeRef::Path => "std::path::PathBuf".into(),
560        TypeRef::Named(n) => {
561            let qualified = tp.get(n).cloned().unwrap_or_else(|| format!("{ci}::{n}"));
562            if is_ref { format!("&{qualified}") } else { qualified }
563        }
564        TypeRef::Vec(inner) => format!("Vec<{}>", bridge_param_type(inner, ci, false, tp)),
565        TypeRef::Optional(inner) => format!("Option<{}>", bridge_param_type(inner, ci, false, tp)),
566        TypeRef::Primitive(p) => prim(p).into(),
567        TypeRef::Unit => "()".into(),
568        TypeRef::Char => "char".into(),
569        TypeRef::Map(k, v) => format!(
570            "std::collections::HashMap<{}, {}>",
571            bridge_param_type(k, ci, false, tp),
572            bridge_param_type(v, ci, false, tp)
573        ),
574        TypeRef::Json => "serde_json::Value".into(),
575        TypeRef::Duration => "std::time::Duration".into(),
576    }
577}
578
579/// Map a visitor method parameter type to the correct Rust type string, handling
580/// IR quirks:
581/// - `ty=String, optional=true, is_ref=true` → `Option<&str>` (IR collapses `Option<&str>`)
582/// - `ty=Vec<T>, is_ref=true` → `&[T]` (IR collapses `&[T]`)
583/// - Everything else delegates to [`bridge_param_type`].
584pub fn visitor_param_type(ty: &TypeRef, is_ref: bool, optional: bool, tp: &HashMap<String, String>) -> String {
585    if optional && matches!(ty, TypeRef::String) && is_ref {
586        return "Option<&str>".to_string();
587    }
588    if is_ref {
589        if let TypeRef::Vec(inner) = ty {
590            let inner_str = bridge_param_type(inner, "", false, tp);
591            return format!("&[{inner_str}]");
592        }
593    }
594    bridge_param_type(ty, "", is_ref, tp)
595}
596
597/// Find the first function parameter that matches a trait bridge configuration
598/// (by type alias or parameter name).
599pub fn find_bridge_param<'a>(
600    func: &FunctionDef,
601    bridges: &'a [TraitBridgeConfig],
602) -> Option<(usize, &'a TraitBridgeConfig)> {
603    for (idx, param) in func.params.iter().enumerate() {
604        let named = match &param.ty {
605            TypeRef::Named(n) => Some(n.as_str()),
606            TypeRef::Optional(inner) => {
607                if let TypeRef::Named(n) = inner.as_ref() {
608                    Some(n.as_str())
609                } else {
610                    None
611                }
612            }
613            _ => None,
614        };
615        for bridge in bridges {
616            if let Some(type_name) = named {
617                if bridge.type_alias.as_deref() == Some(type_name) {
618                    return Some((idx, bridge));
619                }
620            }
621            if bridge.param_name.as_deref() == Some(param.name.as_str()) {
622                return Some((idx, bridge));
623            }
624        }
625    }
626    None
627}
628
/// Convert a snake_case string to camelCase.
///
/// Every `_` is removed, and the first character after a run of underscores is
/// ASCII-uppercased; the leading segment is kept unchanged. As a result a leading
/// underscore capitalizes the first letter (`"_foo"` → `"Foo"`) and trailing
/// underscores disappear (`"foo_"` → `"foo"`).
pub fn to_camel_case(s: &str) -> String {
    let mut segments = s.split('_');
    let mut camel = String::with_capacity(s.len());

    if let Some(head) = segments.next() {
        camel.push_str(head);
    }
    for segment in segments {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            camel.push(first.to_ascii_uppercase());
            camel.push_str(chars.as_str());
        }
    }
    camel
}
645
646#[cfg(test)]
647mod tests {
648    use super::*;
649    use alef_core::config::TraitBridgeConfig;
650    use alef_core::ir::{MethodDef, ParamDef, PrimitiveType, ReceiverKind, TypeDef, TypeRef};
651
652    // ---------------------------------------------------------------------------
653    // Test helpers
654    // ---------------------------------------------------------------------------
655
656    fn make_trait_bridge_config(super_trait: Option<&str>, register_fn: Option<&str>) -> TraitBridgeConfig {
657        TraitBridgeConfig {
658            trait_name: "OcrBackend".to_string(),
659            super_trait: super_trait.map(str::to_string),
660            registry_getter: None,
661            register_fn: register_fn.map(str::to_string),
662            type_alias: None,
663            param_name: None,
664            register_extra_args: None,
665            exclude_languages: Vec::new(),
666        }
667    }
668
669    fn make_type_def(name: &str, rust_path: &str, methods: Vec<MethodDef>) -> TypeDef {
670        TypeDef {
671            name: name.to_string(),
672            rust_path: rust_path.to_string(),
673            original_rust_path: rust_path.to_string(),
674            fields: vec![],
675            methods,
676            is_opaque: true,
677            is_clone: false,
678            is_copy: false,
679            doc: String::new(),
680            cfg: None,
681            is_trait: true,
682            has_default: false,
683            has_stripped_cfg_fields: false,
684            is_return_type: false,
685            serde_rename_all: None,
686            has_serde: false,
687            super_traits: vec![],
688        }
689    }
690
691    fn make_method(
692        name: &str,
693        params: Vec<ParamDef>,
694        return_type: TypeRef,
695        is_async: bool,
696        has_default_impl: bool,
697        trait_source: Option<&str>,
698        error_type: Option<&str>,
699    ) -> MethodDef {
700        MethodDef {
701            name: name.to_string(),
702            params,
703            return_type,
704            is_async,
705            is_static: false,
706            error_type: error_type.map(str::to_string),
707            doc: String::new(),
708            receiver: Some(ReceiverKind::Ref),
709            sanitized: false,
710            trait_source: trait_source.map(str::to_string),
711            returns_ref: false,
712            returns_cow: false,
713            return_newtype_wrapper: None,
714            has_default_impl,
715        }
716    }
717
718    fn make_param(name: &str, ty: TypeRef, is_ref: bool) -> ParamDef {
719        ParamDef {
720            name: name.to_string(),
721            ty,
722            optional: false,
723            default: None,
724            sanitized: false,
725            typed_default: None,
726            is_ref,
727            is_mut: false,
728            newtype_wrapper: None,
729            original_type: None,
730        }
731    }
732
733    fn make_spec<'a>(
734        trait_def: &'a TypeDef,
735        bridge_config: &'a TraitBridgeConfig,
736        wrapper_prefix: &'a str,
737        type_paths: HashMap<String, String>,
738    ) -> TraitBridgeSpec<'a> {
739        TraitBridgeSpec {
740            trait_def,
741            bridge_config,
742            core_import: "mylib",
743            wrapper_prefix,
744            type_paths,
745            error_type: "MyError".to_string(),
746            error_constructor: "MyError::from({msg})".to_string(),
747        }
748    }
749
750    // ---------------------------------------------------------------------------
751    // Mock backend
752    // ---------------------------------------------------------------------------
753
754    struct MockBridgeGenerator;
755
756    impl TraitBridgeGenerator for MockBridgeGenerator {
757        fn foreign_object_type(&self) -> &str {
758            "Py<PyAny>"
759        }
760
761        fn bridge_imports(&self) -> Vec<String> {
762            vec!["pyo3::prelude::*".to_string(), "pyo3::types::PyString".to_string()]
763        }
764
765        fn gen_sync_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
766            format!("// sync body for {}", method.name)
767        }
768
769        fn gen_async_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
770            format!("// async body for {}", method.name)
771        }
772
773        fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String {
774            format!(
775                "impl {} {{\n    pub fn new(obj: Py<PyAny>) -> Self {{ Self {{ inner: obj, cached_name: String::new() }} }}\n}}",
776                spec.wrapper_name()
777            )
778        }
779
780        fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String {
781            let fn_name = spec.bridge_config.register_fn.as_deref().unwrap_or("register");
782            format!("pub fn {fn_name}(obj: Py<PyAny>) {{ /* register */ }}")
783        }
784    }
785
786    // ---------------------------------------------------------------------------
787    // TraitBridgeSpec helpers
788    // ---------------------------------------------------------------------------
789
790    #[test]
791    fn test_wrapper_name() {
792        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
793        let config = make_trait_bridge_config(None, None);
794        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
795        assert_eq!(spec.wrapper_name(), "PyOcrBackendBridge");
796    }
797
798    #[test]
799    fn test_trait_snake() {
800        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
801        let config = make_trait_bridge_config(None, None);
802        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
803        assert_eq!(spec.trait_snake(), "ocr_backend");
804    }
805
806    #[test]
807    fn test_trait_path_replaces_hyphens() {
808        let trait_def = make_type_def("OcrBackend", "my-lib::OcrBackend", vec![]);
809        let config = make_trait_bridge_config(None, None);
810        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
811        assert_eq!(spec.trait_path(), "my_lib::OcrBackend");
812    }
813
814    #[test]
815    fn test_required_methods_filters_no_default_impl() {
816        let methods = vec![
817            make_method("process", vec![], TypeRef::String, false, false, None, None),
818            make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
819            make_method("detect", vec![], TypeRef::String, false, false, None, None),
820        ];
821        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
822        let config = make_trait_bridge_config(None, None);
823        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
824        let required = spec.required_methods();
825        assert_eq!(required.len(), 2);
826        assert!(required.iter().any(|m| m.name == "process"));
827        assert!(required.iter().any(|m| m.name == "detect"));
828    }
829
830    #[test]
831    fn test_optional_methods_filters_has_default_impl() {
832        let methods = vec![
833            make_method("process", vec![], TypeRef::String, false, false, None, None),
834            make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
835            make_method("shutdown", vec![], TypeRef::Unit, false, true, None, None),
836        ];
837        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
838        let config = make_trait_bridge_config(None, None);
839        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
840        let optional = spec.optional_methods();
841        assert_eq!(optional.len(), 2);
842        assert!(optional.iter().any(|m| m.name == "initialize"));
843        assert!(optional.iter().any(|m| m.name == "shutdown"));
844    }
845
846    #[test]
847    fn test_error_path() {
848        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
849        let config = make_trait_bridge_config(None, None);
850        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
851        assert_eq!(spec.error_path(), "mylib::MyError");
852    }
853
854    // ---------------------------------------------------------------------------
855    // format_type_ref
856    // ---------------------------------------------------------------------------
857
858    #[test]
859    fn test_format_type_ref_primitives() {
860        let paths = HashMap::new();
861        let cases: Vec<(TypeRef, &str)> = vec![
862            (TypeRef::Primitive(PrimitiveType::Bool), "bool"),
863            (TypeRef::Primitive(PrimitiveType::U8), "u8"),
864            (TypeRef::Primitive(PrimitiveType::U16), "u16"),
865            (TypeRef::Primitive(PrimitiveType::U32), "u32"),
866            (TypeRef::Primitive(PrimitiveType::U64), "u64"),
867            (TypeRef::Primitive(PrimitiveType::I8), "i8"),
868            (TypeRef::Primitive(PrimitiveType::I16), "i16"),
869            (TypeRef::Primitive(PrimitiveType::I32), "i32"),
870            (TypeRef::Primitive(PrimitiveType::I64), "i64"),
871            (TypeRef::Primitive(PrimitiveType::F32), "f32"),
872            (TypeRef::Primitive(PrimitiveType::F64), "f64"),
873            (TypeRef::Primitive(PrimitiveType::Usize), "usize"),
874            (TypeRef::Primitive(PrimitiveType::Isize), "isize"),
875        ];
876        for (ty, expected) in cases {
877            assert_eq!(format_type_ref(&ty, &paths), expected, "mismatch for {expected}");
878        }
879    }
880
881    #[test]
882    fn test_format_type_ref_string() {
883        assert_eq!(format_type_ref(&TypeRef::String, &HashMap::new()), "String");
884    }
885
886    #[test]
887    fn test_format_type_ref_char() {
888        assert_eq!(format_type_ref(&TypeRef::Char, &HashMap::new()), "char");
889    }
890
891    #[test]
892    fn test_format_type_ref_bytes() {
893        assert_eq!(format_type_ref(&TypeRef::Bytes, &HashMap::new()), "Vec<u8>");
894    }
895
896    #[test]
897    fn test_format_type_ref_path() {
898        assert_eq!(format_type_ref(&TypeRef::Path, &HashMap::new()), "std::path::PathBuf");
899    }
900
901    #[test]
902    fn test_format_type_ref_unit() {
903        assert_eq!(format_type_ref(&TypeRef::Unit, &HashMap::new()), "()");
904    }
905
906    #[test]
907    fn test_format_type_ref_json() {
908        assert_eq!(format_type_ref(&TypeRef::Json, &HashMap::new()), "serde_json::Value");
909    }
910
911    #[test]
912    fn test_format_type_ref_duration() {
913        assert_eq!(
914            format_type_ref(&TypeRef::Duration, &HashMap::new()),
915            "std::time::Duration"
916        );
917    }
918
919    #[test]
920    fn test_format_type_ref_optional() {
921        let ty = TypeRef::Optional(Box::new(TypeRef::String));
922        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<String>");
923    }
924
925    #[test]
926    fn test_format_type_ref_optional_nested() {
927        let ty = TypeRef::Optional(Box::new(TypeRef::Optional(Box::new(TypeRef::Primitive(
928            PrimitiveType::U32,
929        )))));
930        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<Option<u32>>");
931    }
932
933    #[test]
934    fn test_format_type_ref_vec() {
935        let ty = TypeRef::Vec(Box::new(TypeRef::Primitive(PrimitiveType::U8)));
936        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<u8>");
937    }
938
939    #[test]
940    fn test_format_type_ref_vec_nested() {
941        let ty = TypeRef::Vec(Box::new(TypeRef::Vec(Box::new(TypeRef::String))));
942        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<Vec<String>>");
943    }
944
945    #[test]
946    fn test_format_type_ref_map() {
947        let ty = TypeRef::Map(
948            Box::new(TypeRef::String),
949            Box::new(TypeRef::Primitive(PrimitiveType::I64)),
950        );
951        assert_eq!(
952            format_type_ref(&ty, &HashMap::new()),
953            "std::collections::HashMap<String, i64>"
954        );
955    }
956
957    #[test]
958    fn test_format_type_ref_map_nested_value() {
959        let ty = TypeRef::Map(
960            Box::new(TypeRef::String),
961            Box::new(TypeRef::Vec(Box::new(TypeRef::String))),
962        );
963        assert_eq!(
964            format_type_ref(&ty, &HashMap::new()),
965            "std::collections::HashMap<String, Vec<String>>"
966        );
967    }
968
969    #[test]
970    fn test_format_type_ref_named_without_type_paths() {
971        let ty = TypeRef::Named("Config".to_string());
972        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Config");
973    }
974
975    #[test]
976    fn test_format_type_ref_named_with_type_paths() {
977        let ty = TypeRef::Named("Config".to_string());
978        let mut paths = HashMap::new();
979        paths.insert("Config".to_string(), "mylib::Config".to_string());
980        assert_eq!(format_type_ref(&ty, &paths), "mylib::Config");
981    }
982
983    #[test]
984    fn test_format_type_ref_named_not_in_type_paths_falls_back_to_name() {
985        let ty = TypeRef::Named("Unknown".to_string());
986        let mut paths = HashMap::new();
987        paths.insert("Other".to_string(), "mylib::Other".to_string());
988        assert_eq!(format_type_ref(&ty, &paths), "Unknown");
989    }
990
991    // ---------------------------------------------------------------------------
992    // format_param_type
993    // ---------------------------------------------------------------------------
994
995    #[test]
996    fn test_format_param_type_string_ref() {
997        let param = make_param("input", TypeRef::String, true);
998        assert_eq!(format_param_type(&param, &HashMap::new()), "&str");
999    }
1000
1001    #[test]
1002    fn test_format_param_type_string_owned() {
1003        let param = make_param("input", TypeRef::String, false);
1004        assert_eq!(format_param_type(&param, &HashMap::new()), "String");
1005    }
1006
1007    #[test]
1008    fn test_format_param_type_bytes_ref() {
1009        let param = make_param("data", TypeRef::Bytes, true);
1010        assert_eq!(format_param_type(&param, &HashMap::new()), "&[u8]");
1011    }
1012
1013    #[test]
1014    fn test_format_param_type_bytes_owned() {
1015        let param = make_param("data", TypeRef::Bytes, false);
1016        assert_eq!(format_param_type(&param, &HashMap::new()), "Vec<u8>");
1017    }
1018
1019    #[test]
1020    fn test_format_param_type_path_ref() {
1021        let param = make_param("path", TypeRef::Path, true);
1022        assert_eq!(format_param_type(&param, &HashMap::new()), "&std::path::Path");
1023    }
1024
1025    #[test]
1026    fn test_format_param_type_path_owned() {
1027        let param = make_param("path", TypeRef::Path, false);
1028        assert_eq!(format_param_type(&param, &HashMap::new()), "std::path::PathBuf");
1029    }
1030
1031    #[test]
1032    fn test_format_param_type_vec_ref() {
1033        let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), true);
1034        assert_eq!(format_param_type(&param, &HashMap::new()), "&[String]");
1035    }
1036
1037    #[test]
1038    fn test_format_param_type_vec_owned() {
1039        let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), false);
1040        assert_eq!(format_param_type(&param, &HashMap::new()), "Vec<String>");
1041    }
1042
1043    #[test]
1044    fn test_format_param_type_named_ref_with_type_paths() {
1045        let mut paths = HashMap::new();
1046        paths.insert("Config".to_string(), "mylib::Config".to_string());
1047        let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
1048        assert_eq!(format_param_type(&param, &paths), "&mylib::Config");
1049    }
1050
1051    #[test]
1052    fn test_format_param_type_named_ref_without_type_paths() {
1053        let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
1054        assert_eq!(format_param_type(&param, &HashMap::new()), "&Config");
1055    }
1056
1057    #[test]
1058    fn test_format_param_type_primitive_ref_passes_by_value() {
1059        // Copy types like u32 are passed by value even when is_ref is set
1060        let param = make_param("count", TypeRef::Primitive(PrimitiveType::U32), true);
1061        assert_eq!(format_param_type(&param, &HashMap::new()), "u32");
1062    }
1063
1064    #[test]
1065    fn test_format_param_type_unit_ref_passes_by_value() {
1066        let param = make_param("nothing", TypeRef::Unit, true);
1067        assert_eq!(format_param_type(&param, &HashMap::new()), "()");
1068    }
1069
1070    // ---------------------------------------------------------------------------
1071    // format_return_type
1072    // ---------------------------------------------------------------------------
1073
1074    #[test]
1075    fn test_format_return_type_without_error() {
1076        let result = format_return_type(&TypeRef::String, None, &HashMap::new());
1077        assert_eq!(result, "String");
1078    }
1079
1080    #[test]
1081    fn test_format_return_type_with_error() {
1082        let result = format_return_type(&TypeRef::String, Some("MyError"), &HashMap::new());
1083        assert_eq!(result, "std::result::Result<String, MyError>");
1084    }
1085
1086    #[test]
1087    fn test_format_return_type_unit_with_error() {
1088        let result = format_return_type(&TypeRef::Unit, Some("Box<dyn std::error::Error>"), &HashMap::new());
1089        assert_eq!(result, "std::result::Result<(), Box<dyn std::error::Error>>");
1090    }
1091
1092    #[test]
1093    fn test_format_return_type_named_with_type_paths_and_error() {
1094        let mut paths = HashMap::new();
1095        paths.insert("Output".to_string(), "mylib::Output".to_string());
1096        let result = format_return_type(&TypeRef::Named("Output".to_string()), Some("mylib::MyError"), &paths);
1097        assert_eq!(result, "std::result::Result<mylib::Output, mylib::MyError>");
1098    }
1099
1100    // ---------------------------------------------------------------------------
1101    // gen_bridge_wrapper_struct
1102    // ---------------------------------------------------------------------------
1103
1104    #[test]
1105    fn test_gen_bridge_wrapper_struct_contains_struct_name() {
1106        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1107        let config = make_trait_bridge_config(None, None);
1108        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1109        let generator = MockBridgeGenerator;
1110        let result = gen_bridge_wrapper_struct(&spec, &generator);
1111        assert!(
1112            result.contains("pub struct PyOcrBackendBridge"),
1113            "missing struct declaration in:\n{result}"
1114        );
1115    }
1116
1117    #[test]
1118    fn test_gen_bridge_wrapper_struct_contains_inner_field() {
1119        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1120        let config = make_trait_bridge_config(None, None);
1121        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1122        let generator = MockBridgeGenerator;
1123        let result = gen_bridge_wrapper_struct(&spec, &generator);
1124        assert!(result.contains("inner: Py<PyAny>"), "missing inner field in:\n{result}");
1125    }
1126
1127    #[test]
1128    fn test_gen_bridge_wrapper_struct_contains_cached_name() {
1129        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1130        let config = make_trait_bridge_config(None, None);
1131        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1132        let generator = MockBridgeGenerator;
1133        let result = gen_bridge_wrapper_struct(&spec, &generator);
1134        assert!(
1135            result.contains("cached_name: String"),
1136            "missing cached_name field in:\n{result}"
1137        );
1138    }
1139
1140    // ---------------------------------------------------------------------------
1141    // gen_bridge_plugin_impl
1142    // ---------------------------------------------------------------------------
1143
1144    #[test]
1145    fn test_gen_bridge_plugin_impl_returns_none_when_no_super_trait() {
1146        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1147        let config = make_trait_bridge_config(None, None);
1148        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1149        let generator = MockBridgeGenerator;
1150        assert!(gen_bridge_plugin_impl(&spec, &generator).is_none());
1151    }
1152
1153    #[test]
1154    fn test_gen_bridge_plugin_impl_returns_some_when_super_trait_configured() {
1155        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1156        let config = make_trait_bridge_config(Some("Plugin"), None);
1157        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1158        let generator = MockBridgeGenerator;
1159        assert!(gen_bridge_plugin_impl(&spec, &generator).is_some());
1160    }
1161
1162    #[test]
1163    fn test_gen_bridge_plugin_impl_uses_qualified_super_trait_path() {
1164        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1165        let config = make_trait_bridge_config(Some("Plugin"), None);
1166        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1167        let generator = MockBridgeGenerator;
1168        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1169        assert!(
1170            result.contains("impl mylib::Plugin for PyOcrBackendBridge"),
1171            "missing qualified super-trait path in:\n{result}"
1172        );
1173    }
1174
1175    #[test]
1176    fn test_gen_bridge_plugin_impl_uses_already_qualified_super_trait_path() {
1177        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1178        let config = make_trait_bridge_config(Some("other_crate::Plugin"), None);
1179        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1180        let generator = MockBridgeGenerator;
1181        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1182        assert!(
1183            result.contains("impl other_crate::Plugin for PyOcrBackendBridge"),
1184            "wrong super-trait path in:\n{result}"
1185        );
1186    }
1187
1188    #[test]
1189    fn test_gen_bridge_plugin_impl_contains_name_fn() {
1190        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1191        let config = make_trait_bridge_config(Some("Plugin"), None);
1192        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1193        let generator = MockBridgeGenerator;
1194        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1195        assert!(
1196            result.contains("fn name(") && result.contains("cached_name"),
1197            "missing name() using cached_name in:\n{result}"
1198        );
1199    }
1200
1201    #[test]
1202    fn test_gen_bridge_plugin_impl_contains_version_fn() {
1203        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1204        let config = make_trait_bridge_config(Some("Plugin"), None);
1205        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1206        let generator = MockBridgeGenerator;
1207        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1208        assert!(result.contains("fn version("), "missing version() in:\n{result}");
1209    }
1210
1211    #[test]
1212    fn test_gen_bridge_plugin_impl_contains_initialize_fn() {
1213        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1214        let config = make_trait_bridge_config(Some("Plugin"), None);
1215        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1216        let generator = MockBridgeGenerator;
1217        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1218        assert!(result.contains("fn initialize("), "missing initialize() in:\n{result}");
1219    }
1220
1221    #[test]
1222    fn test_gen_bridge_plugin_impl_contains_shutdown_fn() {
1223        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1224        let config = make_trait_bridge_config(Some("Plugin"), None);
1225        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1226        let generator = MockBridgeGenerator;
1227        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1228        assert!(result.contains("fn shutdown("), "missing shutdown() in:\n{result}");
1229    }
1230
1231    // ---------------------------------------------------------------------------
1232    // gen_bridge_trait_impl
1233    // ---------------------------------------------------------------------------
1234
1235    #[test]
1236    fn test_gen_bridge_trait_impl_includes_impl_header() {
1237        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1238        let config = make_trait_bridge_config(None, None);
1239        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1240        let generator = MockBridgeGenerator;
1241        let result = gen_bridge_trait_impl(&spec, &generator);
1242        assert!(
1243            result.contains("impl mylib::OcrBackend for PyOcrBackendBridge"),
1244            "missing impl header in:\n{result}"
1245        );
1246    }
1247
1248    #[test]
1249    fn test_gen_bridge_trait_impl_includes_method_signatures() {
1250        let methods = vec![make_method(
1251            "process",
1252            vec![],
1253            TypeRef::String,
1254            false,
1255            false,
1256            None,
1257            None,
1258        )];
1259        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1260        let config = make_trait_bridge_config(None, None);
1261        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1262        let generator = MockBridgeGenerator;
1263        let result = gen_bridge_trait_impl(&spec, &generator);
1264        assert!(result.contains("fn process("), "missing method signature in:\n{result}");
1265    }
1266
1267    #[test]
1268    fn test_gen_bridge_trait_impl_includes_method_body_from_generator() {
1269        let methods = vec![make_method(
1270            "process",
1271            vec![],
1272            TypeRef::String,
1273            false,
1274            false,
1275            None,
1276            None,
1277        )];
1278        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1279        let config = make_trait_bridge_config(None, None);
1280        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1281        let generator = MockBridgeGenerator;
1282        let result = gen_bridge_trait_impl(&spec, &generator);
1283        assert!(
1284            result.contains("// sync body for process"),
1285            "missing sync method body in:\n{result}"
1286        );
1287    }
1288
1289    #[test]
1290    fn test_gen_bridge_trait_impl_async_method_uses_async_body() {
1291        let methods = vec![make_method(
1292            "process_async",
1293            vec![],
1294            TypeRef::String,
1295            true,
1296            false,
1297            None,
1298            None,
1299        )];
1300        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1301        let config = make_trait_bridge_config(None, None);
1302        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1303        let generator = MockBridgeGenerator;
1304        let result = gen_bridge_trait_impl(&spec, &generator);
1305        assert!(
1306            result.contains("// async body for process_async"),
1307            "missing async method body in:\n{result}"
1308        );
1309        assert!(
1310            result.contains("async fn process_async("),
1311            "missing async keyword in method signature in:\n{result}"
1312        );
1313    }
1314
1315    #[test]
1316    fn test_gen_bridge_trait_impl_filters_trait_source_methods() {
1317        // Methods with trait_source set come from super-traits and should be excluded
1318        let methods = vec![
1319            make_method("own_method", vec![], TypeRef::String, false, false, None, None),
1320            make_method(
1321                "inherited_method",
1322                vec![],
1323                TypeRef::String,
1324                false,
1325                false,
1326                Some("other_crate::OtherTrait"),
1327                None,
1328            ),
1329        ];
1330        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1331        let config = make_trait_bridge_config(None, None);
1332        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1333        let generator = MockBridgeGenerator;
1334        let result = gen_bridge_trait_impl(&spec, &generator);
1335        assert!(
1336            result.contains("fn own_method("),
1337            "own method should be present in:\n{result}"
1338        );
1339        assert!(
1340            !result.contains("fn inherited_method("),
1341            "inherited method should be filtered out in:\n{result}"
1342        );
1343    }
1344
1345    #[test]
1346    fn test_gen_bridge_trait_impl_method_with_params() {
1347        let params = vec![
1348            make_param("input", TypeRef::String, true),
1349            make_param("count", TypeRef::Primitive(PrimitiveType::U32), false),
1350        ];
1351        let methods = vec![make_method(
1352            "process",
1353            params,
1354            TypeRef::String,
1355            false,
1356            false,
1357            None,
1358            None,
1359        )];
1360        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1361        let config = make_trait_bridge_config(None, None);
1362        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1363        let generator = MockBridgeGenerator;
1364        let result = gen_bridge_trait_impl(&spec, &generator);
1365        assert!(result.contains("input: &str"), "missing &str param in:\n{result}");
1366        assert!(result.contains("count: u32"), "missing u32 param in:\n{result}");
1367    }
1368
1369    #[test]
1370    fn test_gen_bridge_trait_impl_return_type_with_error() {
1371        let methods = vec![make_method(
1372            "process",
1373            vec![],
1374            TypeRef::String,
1375            false,
1376            false,
1377            None,
1378            Some("MyError"),
1379        )];
1380        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1381        let config = make_trait_bridge_config(None, None);
1382        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1383        let generator = MockBridgeGenerator;
1384        let result = gen_bridge_trait_impl(&spec, &generator);
1385        assert!(
1386            result.contains("-> std::result::Result<String, mylib::MyError>"),
1387            "missing std::result::Result return type in:\n{result}"
1388        );
1389    }
1390
1391    // ---------------------------------------------------------------------------
1392    // gen_bridge_registration_fn
1393    // ---------------------------------------------------------------------------
1394
1395    #[test]
1396    fn test_gen_bridge_registration_fn_returns_none_without_register_fn() {
1397        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1398        let config = make_trait_bridge_config(None, None);
1399        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1400        let generator = MockBridgeGenerator;
1401        assert!(gen_bridge_registration_fn(&spec, &generator).is_none());
1402    }
1403
1404    #[test]
1405    fn test_gen_bridge_registration_fn_returns_some_with_register_fn() {
1406        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1407        let config = make_trait_bridge_config(None, Some("register_ocr_backend"));
1408        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1409        let generator = MockBridgeGenerator;
1410        let result = gen_bridge_registration_fn(&spec, &generator);
1411        assert!(result.is_some());
1412        let code = result.unwrap();
1413        assert!(
1414            code.contains("register_ocr_backend"),
1415            "missing register fn name in:\n{code}"
1416        );
1417    }
1418
1419    // ---------------------------------------------------------------------------
1420    // gen_bridge_all
1421    // ---------------------------------------------------------------------------
1422
1423    #[test]
1424    fn test_gen_bridge_all_includes_imports() {
1425        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1426        let config = make_trait_bridge_config(None, None);
1427        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1428        let generator = MockBridgeGenerator;
1429        let output = gen_bridge_all(&spec, &generator);
1430        assert!(output.imports.contains(&"pyo3::prelude::*".to_string()));
1431        assert!(output.imports.contains(&"pyo3::types::PyString".to_string()));
1432    }
1433
1434    #[test]
1435    fn test_gen_bridge_all_includes_wrapper_struct() {
1436        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1437        let config = make_trait_bridge_config(None, None);
1438        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1439        let generator = MockBridgeGenerator;
1440        let output = gen_bridge_all(&spec, &generator);
1441        assert!(
1442            output.code.contains("pub struct PyOcrBackendBridge"),
1443            "missing struct in:\n{}",
1444            output.code
1445        );
1446    }
1447
1448    #[test]
1449    fn test_gen_bridge_all_includes_constructor() {
1450        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1451        let config = make_trait_bridge_config(None, None);
1452        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1453        let generator = MockBridgeGenerator;
1454        let output = gen_bridge_all(&spec, &generator);
1455        assert!(
1456            output.code.contains("pub fn new("),
1457            "missing constructor in:\n{}",
1458            output.code
1459        );
1460    }
1461
1462    #[test]
1463    fn test_gen_bridge_all_includes_trait_impl() {
1464        let methods = vec![make_method(
1465            "process",
1466            vec![],
1467            TypeRef::String,
1468            false,
1469            false,
1470            None,
1471            None,
1472        )];
1473        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1474        let config = make_trait_bridge_config(None, None);
1475        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1476        let generator = MockBridgeGenerator;
1477        let output = gen_bridge_all(&spec, &generator);
1478        assert!(
1479            output.code.contains("impl mylib::OcrBackend for PyOcrBackendBridge"),
1480            "missing trait impl in:\n{}",
1481            output.code
1482        );
1483    }
1484
1485    #[test]
1486    fn test_gen_bridge_all_includes_plugin_impl_when_super_trait_set() {
1487        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1488        let config = make_trait_bridge_config(Some("Plugin"), None);
1489        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1490        let generator = MockBridgeGenerator;
1491        let output = gen_bridge_all(&spec, &generator);
1492        assert!(
1493            output.code.contains("impl mylib::Plugin for PyOcrBackendBridge"),
1494            "missing plugin impl in:\n{}",
1495            output.code
1496        );
1497    }
1498
1499    #[test]
1500    fn test_gen_bridge_all_no_plugin_impl_when_no_super_trait() {
1501        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1502        let config = make_trait_bridge_config(None, None);
1503        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1504        let generator = MockBridgeGenerator;
1505        let output = gen_bridge_all(&spec, &generator);
1506        assert!(
1507            !output.code.contains("fn name(") || !output.code.contains("cached_name"),
1508            "unexpected plugin impl present without super_trait"
1509        );
1510    }
1511
1512    #[test]
1513    fn test_gen_bridge_all_includes_registration_fn_when_configured() {
1514        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1515        let config = make_trait_bridge_config(None, Some("register_ocr_backend"));
1516        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1517        let generator = MockBridgeGenerator;
1518        let output = gen_bridge_all(&spec, &generator);
1519        assert!(
1520            output.code.contains("register_ocr_backend"),
1521            "missing registration fn in:\n{}",
1522            output.code
1523        );
1524    }
1525
1526    #[test]
1527    fn test_gen_bridge_all_no_registration_fn_when_absent() {
1528        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1529        let config = make_trait_bridge_config(None, None);
1530        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1531        let generator = MockBridgeGenerator;
1532        let output = gen_bridge_all(&spec, &generator);
1533        assert!(
1534            !output.code.contains("register_ocr_backend"),
1535            "unexpected registration fn present:\n{}",
1536            output.code
1537        );
1538    }
1539
1540    #[test]
1541    fn test_gen_bridge_all_ordering_struct_before_trait_impl() {
1542        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1543        let config = make_trait_bridge_config(None, None);
1544        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1545        let generator = MockBridgeGenerator;
1546        let output = gen_bridge_all(&spec, &generator);
1547        let struct_pos = output.code.find("pub struct PyOcrBackendBridge").unwrap();
1548        let impl_pos = output
1549            .code
1550            .find("impl mylib::OcrBackend for PyOcrBackendBridge")
1551            .unwrap();
1552        assert!(struct_pos < impl_pos, "struct should appear before trait impl");
1553    }
1554}