Skip to main content

alef_codegen/generators/
trait_bridge.rs

1//! Shared trait bridge code generation.
2//!
3//! Generates wrapper structs that allow foreign language objects (Python, JS, etc.)
4//! to implement Rust traits via FFI. Each backend implements [`TraitBridgeGenerator`]
5//! to provide language-specific dispatch logic; the shared functions in this module
6//! handle the structural boilerplate.
7
8use alef_core::config::TraitBridgeConfig;
9use alef_core::ir::{FunctionDef, MethodDef, ParamDef, PrimitiveType, TypeDef, TypeRef};
10use heck::ToSnakeCase;
11use std::collections::HashMap;
12use std::fmt::Write;
13
14/// Everything needed to generate a trait bridge for one trait.
15pub struct TraitBridgeSpec<'a> {
16    /// The trait definition from the IR.
17    pub trait_def: &'a TypeDef,
18    /// Bridge configuration from `alef.toml`.
19    pub bridge_config: &'a TraitBridgeConfig,
20    /// Core crate import path (e.g., `"kreuzberg"`).
21    pub core_import: &'a str,
22    /// Language-specific prefix for the wrapper type (e.g., `"Python"`, `"Js"`, `"Wasm"`).
23    pub wrapper_prefix: &'a str,
24    /// Map of type name → fully-qualified Rust path for qualifying `Named` types.
25    pub type_paths: HashMap<String, String>,
26    /// The crate's error type name (e.g., `"KreuzbergError"`). Defaults to `"Error"`.
27    pub error_type: String,
28    /// Error constructor pattern. `{msg}` is replaced with the message expression.
29    pub error_constructor: String,
30}
31
32impl<'a> TraitBridgeSpec<'a> {
33    /// Fully qualified error type path (e.g., `"kreuzberg::KreuzbergError"`).
34    ///
35    /// If `error_type` already looks fully-qualified (contains `::`) or is a generic
36    /// type expression (contains `<`), it is returned as-is without prefixing
37    /// `core_import`. This lets backends specify rich error types like
38    /// `"Box<dyn std::error::Error + Send + Sync>"` directly.
39    pub fn error_path(&self) -> String {
40        if self.error_type.contains("::") || self.error_type.contains('<') {
41            self.error_type.clone()
42        } else {
43            format!("{}::{}", self.core_import, self.error_type)
44        }
45    }
46
47    /// Generate an error construction expression from a message expression.
48    pub fn make_error(&self, msg_expr: &str) -> String {
49        self.error_constructor.replace("{msg}", msg_expr)
50    }
51
52    /// Wrapper struct name: `{prefix}{TraitName}Bridge` (e.g., `PythonOcrBackendBridge`).
53    pub fn wrapper_name(&self) -> String {
54        format!("{}{}Bridge", self.wrapper_prefix, self.trait_def.name)
55    }
56
57    /// Snake-case version of the trait name (e.g., `"ocr_backend"`).
58    pub fn trait_snake(&self) -> String {
59        self.trait_def.name.to_snake_case()
60    }
61
62    /// Full Rust path to the trait (e.g., `kreuzberg::OcrBackend`).
63    pub fn trait_path(&self) -> String {
64        self.trait_def.rust_path.replace('-', "_")
65    }
66
67    /// Methods that are required (no default impl) — must be provided by the foreign object.
68    pub fn required_methods(&self) -> Vec<&'a MethodDef> {
69        self.trait_def.methods.iter().filter(|m| !m.has_default_impl).collect()
70    }
71
72    /// Methods that have a default impl — optional on the foreign object.
73    pub fn optional_methods(&self) -> Vec<&'a MethodDef> {
74        self.trait_def.methods.iter().filter(|m| m.has_default_impl).collect()
75    }
76}
77
78/// Backend-specific trait bridge generation.
79///
80/// Each binding backend (PyO3, NAPI-RS, wasm-bindgen, etc.) implements this trait
81/// to provide the language-specific parts of bridge codegen. The shared functions
82/// in this module call these methods to fill in the backend-dependent pieces.
83pub trait TraitBridgeGenerator {
84    /// The type of the wrapped foreign object (e.g., `"Py<PyAny>"`, `"ThreadsafeFunction"`).
85    fn foreign_object_type(&self) -> &str;
86
87    /// Additional `use` imports needed for the bridge code.
88    fn bridge_imports(&self) -> Vec<String>;
89
90    /// Generate the body of a synchronous method bridge.
91    ///
92    /// The returned string is inserted inside the trait impl method. It should
93    /// call through to the foreign object and convert the result.
94    fn gen_sync_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;
95
96    /// Generate the body of an async method bridge.
97    ///
98    /// The returned string is the body of a `Box::pin(async move { ... })` block.
99    fn gen_async_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;
100
101    /// Generate the constructor body that validates and wraps the foreign object.
102    ///
103    /// Should check that the foreign object provides all required methods and
104    /// return `Self { ... }` on success.
105    fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String;
106
107    /// Generate the complete registration function including attributes, signature, and body.
108    ///
109    /// Each backend needs different function signatures (PyO3 takes `py: Python`,
110    /// NAPI takes `#[napi]` with JS params, FFI takes `extern "C"` with raw pointers),
111    /// so the generator owns the full function.
112    fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String;
113
114    /// Whether the `#[async_trait]` macro should require `Send` on its futures.
115    ///
116    /// Returns `true` (default) for most targets. WASM is single-threaded so its
117    /// trait bounds don't include `Send`; implementors should return `false` there.
118    fn async_trait_is_send(&self) -> bool {
119        true
120    }
121}
122
123// ---------------------------------------------------------------------------
124// Shared generation functions
125// ---------------------------------------------------------------------------
126
127/// Generate the wrapper struct holding the foreign object and cached fields.
128///
129/// Produces a struct like:
130/// ```ignore
131/// pub struct PythonOcrBackendBridge {
132///     inner: Py<PyAny>,
133///     cached_name: String,
134/// }
135/// ```
136pub fn gen_bridge_wrapper_struct(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
137    let wrapper = spec.wrapper_name();
138    let foreign_type = generator.foreign_object_type();
139    let mut out = String::with_capacity(512);
140
141    writeln!(
142        out,
143        "/// Wrapper that bridges a foreign {prefix} object to the `{trait_name}` trait.",
144        prefix = spec.wrapper_prefix,
145        trait_name = spec.trait_def.name,
146    )
147    .ok();
148    writeln!(out, "pub struct {wrapper} {{").ok();
149    writeln!(out, "    inner: {foreign_type},").ok();
150    writeln!(out, "    cached_name: String,").ok();
151    write!(out, "}}").ok();
152    out
153}
154
155/// Generate `impl SuperTrait for Wrapper` when the bridge config specifies a super-trait.
156///
157/// Forwards `name()`, `version()`, `initialize()`, and `shutdown()` to the
158/// foreign object, using `cached_name` for `name()`.
159///
160/// The super-trait path is derived from the config's `super_trait` field. If it
161/// contains `::`, it's used as-is; otherwise it's qualified as `{core_import}::{super_trait}`.
162pub fn gen_bridge_plugin_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
163    let super_trait_name = spec.bridge_config.super_trait.as_deref()?;
164
165    let wrapper = spec.wrapper_name();
166    let core_import = spec.core_import;
167
168    // Derive the fully-qualified super-trait path
169    let super_trait_path = if super_trait_name.contains("::") {
170        super_trait_name.to_string()
171    } else {
172        format!("{core_import}::{super_trait_name}")
173    };
174
175    // Build synthetic MethodDefs for the Plugin methods and delegate to the generator
176    // for the actual call bodies. The Plugin trait interface is well-known: name(),
177    // version(), initialize(), shutdown().
178    let mut out = String::with_capacity(1024);
179    writeln!(out, "impl {super_trait_path} for {wrapper} {{").ok();
180
181    // name() -> &str — uses cached field
182    writeln!(out, "    fn name(&self) -> &str {{").ok();
183    writeln!(out, "        &self.cached_name").ok();
184    writeln!(out, "    }}").ok();
185    writeln!(out).ok();
186
187    let error_path = spec.error_path();
188
189    // version() -> String — delegate to foreign object
190    writeln!(out, "    fn version(&self) -> String {{").ok();
191    let version_method = MethodDef {
192        name: "version".to_string(),
193        params: vec![],
194        return_type: alef_core::ir::TypeRef::String,
195        is_async: false,
196        is_static: false,
197        error_type: None,
198        doc: String::new(),
199        receiver: Some(alef_core::ir::ReceiverKind::Ref),
200        sanitized: false,
201        trait_source: None,
202        returns_ref: false,
203        returns_cow: false,
204        return_newtype_wrapper: None,
205        has_default_impl: false,
206    };
207    let version_body = generator.gen_sync_method_body(&version_method, spec);
208    for line in version_body.lines() {
209        writeln!(out, "        {}", line.trim_start()).ok();
210    }
211    writeln!(out, "    }}").ok();
212    writeln!(out).ok();
213
214    // initialize() -> Result<(), ErrorType>
215    writeln!(
216        out,
217        "    fn initialize(&self) -> std::result::Result<(), {error_path}> {{"
218    )
219    .ok();
220    let init_method = MethodDef {
221        name: "initialize".to_string(),
222        params: vec![],
223        return_type: alef_core::ir::TypeRef::Unit,
224        is_async: false,
225        is_static: false,
226        error_type: Some(error_path.clone()),
227        doc: String::new(),
228        receiver: Some(alef_core::ir::ReceiverKind::Ref),
229        sanitized: false,
230        trait_source: None,
231        returns_ref: false,
232        returns_cow: false,
233        return_newtype_wrapper: None,
234        has_default_impl: true,
235    };
236    let init_body = generator.gen_sync_method_body(&init_method, spec);
237    for line in init_body.lines() {
238        writeln!(out, "        {}", line.trim_start()).ok();
239    }
240    writeln!(out, "    }}").ok();
241    writeln!(out).ok();
242
243    // shutdown() -> Result<(), ErrorType>
244    writeln!(
245        out,
246        "    fn shutdown(&self) -> std::result::Result<(), {error_path}> {{"
247    )
248    .ok();
249    let shutdown_method = MethodDef {
250        name: "shutdown".to_string(),
251        params: vec![],
252        return_type: alef_core::ir::TypeRef::Unit,
253        is_async: false,
254        is_static: false,
255        error_type: Some(error_path.clone()),
256        doc: String::new(),
257        receiver: Some(alef_core::ir::ReceiverKind::Ref),
258        sanitized: false,
259        trait_source: None,
260        returns_ref: false,
261        returns_cow: false,
262        return_newtype_wrapper: None,
263        has_default_impl: true,
264    };
265    let shutdown_body = generator.gen_sync_method_body(&shutdown_method, spec);
266    for line in shutdown_body.lines() {
267        writeln!(out, "        {}", line.trim_start()).ok();
268    }
269    writeln!(out, "    }}").ok();
270    write!(out, "}}").ok();
271    Some(out)
272}
273
274/// Generate `impl Trait for Wrapper` dispatching each method through the generator.
275///
276/// Every method on the trait (including those with `has_default_impl`) gets a
277/// generated body that forwards to the foreign object.
278pub fn gen_bridge_trait_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
279    let wrapper = spec.wrapper_name();
280    let trait_path = spec.trait_path();
281    let mut out = String::with_capacity(2048);
282
283    // Add #[async_trait] when the trait has async methods (needed for async_trait macro compatibility).
284    // On non-Send targets (e.g. WASM), use `(?Send)` to drop the `Send` bound on futures.
285    let has_async_methods = spec
286        .trait_def
287        .methods
288        .iter()
289        .any(|m| m.is_async && m.trait_source.is_none());
290    if has_async_methods {
291        if generator.async_trait_is_send() {
292            writeln!(out, "#[async_trait::async_trait]").ok();
293        } else {
294            writeln!(out, "#[async_trait::async_trait(?Send)]").ok();
295        }
296    }
297    writeln!(out, "impl {trait_path} for {wrapper} {{").ok();
298
299    // Filter out methods inherited from super-traits (they're handled by gen_bridge_plugin_impl)
300    let own_methods: Vec<_> = spec
301        .trait_def
302        .methods
303        .iter()
304        .filter(|m| m.trait_source.is_none())
305        .collect();
306
307    for (i, method) in own_methods.iter().enumerate() {
308        if i > 0 {
309            writeln!(out).ok();
310        }
311
312        // Build the method signature
313        let async_kw = if method.is_async { "async " } else { "" };
314        let receiver = match &method.receiver {
315            Some(alef_core::ir::ReceiverKind::Ref) => "&self",
316            Some(alef_core::ir::ReceiverKind::RefMut) => "&mut self",
317            Some(alef_core::ir::ReceiverKind::Owned) => "self",
318            None => "",
319        };
320
321        // Build params (excluding self), using format_param_type to respect is_ref/is_mut
322        let params: Vec<String> = method
323            .params
324            .iter()
325            .map(|p| format!("{}: {}", p.name, format_param_type(p, &spec.type_paths)))
326            .collect();
327
328        let all_params = if receiver.is_empty() {
329            params.join(", ")
330        } else if params.is_empty() {
331            receiver.to_string()
332        } else {
333            format!("{}, {}", receiver, params.join(", "))
334        };
335
336        // Return type — override the IR's error type with the configured crate error type
337        // so the impl matches the actual trait definition (the IR may extract a different
338        // error type like anyhow::Error from re-exports or type alias resolution).
339        let error_override = method.error_type.as_ref().map(|_| spec.error_path());
340        let ret = format_return_type(&method.return_type, error_override.as_deref(), &spec.type_paths);
341
342        writeln!(out, "    {async_kw}fn {}({all_params}) -> {ret} {{", method.name).ok();
343
344        // Generate body: async methods use Box::pin, sync methods call directly
345        let body = if method.is_async {
346            generator.gen_async_method_body(method, spec)
347        } else {
348            generator.gen_sync_method_body(method, spec)
349        };
350
351        for line in body.lines() {
352            writeln!(out, "        {line}").ok();
353        }
354        writeln!(out, "    }}").ok();
355    }
356
357    write!(out, "}}").ok();
358    out
359}
360
361/// Generate the `register_xxx()` function that wraps a foreign object and
362/// inserts it into the plugin registry.
363///
364/// Returns `None` when `bridge_config.register_fn` is absent (per-call bridge pattern).
365/// The generator owns the full function (attributes, signature, body) because each
366/// backend needs different signatures.
367pub fn gen_bridge_registration_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
368    spec.bridge_config.register_fn.as_deref()?;
369    Some(generator.gen_registration_fn(spec))
370}
371
/// Result of trait bridge generation: imports (to be added via `builder.add_import`)
/// and the code body (to be added via `builder.add_item`).
///
/// Derives the standard traits so callers and tests can print, compare, clone
/// and default-construct it without hand-written impls.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BridgeOutput {
    /// Import paths (e.g., `"std::sync::Arc"`) — callers should add via `builder.add_import()`.
    pub imports: Vec<String>,
    /// The generated code (struct, impls, registration fn).
    pub code: String,
}
380
381/// Generate the complete trait bridge code block: struct, impls, and
382/// optionally a registration function.
383///
384/// Returns [`BridgeOutput`] with imports separated from code so callers can
385/// route imports through `builder.add_import()` (which deduplicates).
386pub fn gen_bridge_all(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> BridgeOutput {
387    let imports = generator.bridge_imports();
388    let mut out = String::with_capacity(4096);
389
390    // Wrapper struct
391    out.push_str(&gen_bridge_wrapper_struct(spec, generator));
392    writeln!(out).ok();
393    writeln!(out).ok();
394
395    // Constructor (impl block with new())
396    out.push_str(&generator.gen_constructor(spec));
397    writeln!(out).ok();
398    writeln!(out).ok();
399
400    // Plugin super-trait impl (if applicable)
401    if let Some(plugin_impl) = gen_bridge_plugin_impl(spec, generator) {
402        out.push_str(&plugin_impl);
403        writeln!(out).ok();
404        writeln!(out).ok();
405    }
406
407    // Trait impl
408    out.push_str(&gen_bridge_trait_impl(spec, generator));
409
410    // Registration function — only when register_fn is configured
411    if let Some(reg_fn_code) = gen_bridge_registration_fn(spec, generator) {
412        writeln!(out).ok();
413        writeln!(out).ok();
414        out.push_str(&reg_fn_code);
415    }
416
417    BridgeOutput { imports, code: out }
418}
419
420// ---------------------------------------------------------------------------
421// Helpers
422// ---------------------------------------------------------------------------
423
424/// Format a `TypeRef` as a Rust type string for use in trait method signatures.
425///
426/// `type_paths` qualifies `Named` types with their full Rust path (e.g., `"Config"` →
427/// `"kreuzberg::Config"`). If a name isn't in `type_paths`, it's used as-is.
428pub fn format_type_ref(ty: &alef_core::ir::TypeRef, type_paths: &HashMap<String, String>) -> String {
429    use alef_core::ir::{PrimitiveType, TypeRef};
430    match ty {
431        TypeRef::Primitive(p) => match p {
432            PrimitiveType::Bool => "bool",
433            PrimitiveType::U8 => "u8",
434            PrimitiveType::U16 => "u16",
435            PrimitiveType::U32 => "u32",
436            PrimitiveType::U64 => "u64",
437            PrimitiveType::I8 => "i8",
438            PrimitiveType::I16 => "i16",
439            PrimitiveType::I32 => "i32",
440            PrimitiveType::I64 => "i64",
441            PrimitiveType::F32 => "f32",
442            PrimitiveType::F64 => "f64",
443            PrimitiveType::Usize => "usize",
444            PrimitiveType::Isize => "isize",
445        }
446        .to_string(),
447        TypeRef::String => "String".to_string(),
448        TypeRef::Char => "char".to_string(),
449        TypeRef::Bytes => "Vec<u8>".to_string(),
450        TypeRef::Optional(inner) => format!("Option<{}>", format_type_ref(inner, type_paths)),
451        TypeRef::Vec(inner) => format!("Vec<{}>", format_type_ref(inner, type_paths)),
452        TypeRef::Map(k, v) => format!(
453            "std::collections::HashMap<{}, {}>",
454            format_type_ref(k, type_paths),
455            format_type_ref(v, type_paths)
456        ),
457        TypeRef::Named(name) => type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone()),
458        TypeRef::Path => "std::path::PathBuf".to_string(),
459        TypeRef::Unit => "()".to_string(),
460        TypeRef::Json => "serde_json::Value".to_string(),
461        TypeRef::Duration => "std::time::Duration".to_string(),
462    }
463}
464
465/// Format a return type, wrapping in `Result` when an error type is present.
466pub fn format_return_type(
467    ty: &alef_core::ir::TypeRef,
468    error_type: Option<&str>,
469    type_paths: &HashMap<String, String>,
470) -> String {
471    let inner = format_type_ref(ty, type_paths);
472    match error_type {
473        Some(err) => format!("std::result::Result<{inner}, {err}>"),
474        None => inner,
475    }
476}
477
478/// Format a parameter type, respecting `is_ref`, `is_mut`, and `optional` from the IR.
479///
480/// Unlike [`format_type_ref`], this function produces reference types when the
481/// original Rust parameter was a `&T` or `&mut T`, and wraps in `Option<>` when
482/// `param.optional` is true:
483/// - `String + is_ref` → `&str`
484/// - `String + is_ref + optional` → `Option<&str>`
485/// - `Bytes + is_ref` → `&[u8]`
486/// - `Path + is_ref` → `&std::path::Path`
487/// - `Vec<T> + is_ref` → `&[T]`
488/// - `Named(n) + is_ref` → `&{qualified_name}`
489pub fn format_param_type(param: &ParamDef, type_paths: &HashMap<String, String>) -> String {
490    use alef_core::ir::TypeRef;
491    let base = if param.is_ref {
492        let mutability = if param.is_mut { "mut " } else { "" };
493        match &param.ty {
494            TypeRef::String => format!("&{mutability}str"),
495            TypeRef::Bytes => format!("&{mutability}[u8]"),
496            TypeRef::Path => format!("&{mutability}std::path::Path"),
497            TypeRef::Vec(inner) => format!("&{mutability}[{}]", format_type_ref(inner, type_paths)),
498            TypeRef::Named(name) => {
499                let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
500                format!("&{mutability}{qualified}")
501            }
502            TypeRef::Optional(inner) => {
503                // Preserve the Option wrapper but apply the ref transformation to the inner type.
504                // e.g. Option<String> + is_ref → Option<&str>
505                //      Option<Vec<T>> + is_ref → Option<&[T]>
506                let inner_type_str = match inner.as_ref() {
507                    TypeRef::String => format!("&{mutability}str"),
508                    TypeRef::Bytes => format!("&{mutability}[u8]"),
509                    TypeRef::Path => format!("&{mutability}std::path::Path"),
510                    TypeRef::Vec(v) => format!("&{mutability}[{}]", format_type_ref(v, type_paths)),
511                    TypeRef::Named(name) => {
512                        let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
513                        format!("&{mutability}{qualified}")
514                    }
515                    // Primitives and other Copy types: pass by value inside Option
516                    other => format_type_ref(other, type_paths),
517                };
518                // Already wrapped in Option — return directly to avoid double-wrapping below.
519                return format!("Option<{inner_type_str}>");
520            }
521            // All other types are Copy/small — pass by value even when is_ref is set
522            other => format_type_ref(other, type_paths),
523        }
524    } else {
525        format_type_ref(&param.ty, type_paths)
526    };
527
528    // Wrap in Option<> when the parameter is optional (e.g. `title: Option<&str>`).
529    // The TypeRef::Optional arm above returns early, so this only fires for the
530    // `optional: true` IR flag pattern where ty is the unwrapped inner type.
531    if param.optional {
532        format!("Option<{base}>")
533    } else {
534        base
535    }
536}
537
538// ---------------------------------------------------------------------------
539// Shared helpers — used by all backend trait_bridge modules.
540// ---------------------------------------------------------------------------
541
542/// Map a Rust primitive to its type string.
543pub fn prim(p: &PrimitiveType) -> &'static str {
544    use PrimitiveType::*;
545    match p {
546        Bool => "bool",
547        U8 => "u8",
548        U16 => "u16",
549        U32 => "u32",
550        U64 => "u64",
551        I8 => "i8",
552        I16 => "i16",
553        I32 => "i32",
554        I64 => "i64",
555        F32 => "f32",
556        F64 => "f64",
557        Usize => "usize",
558        Isize => "isize",
559    }
560}
561
562/// Map a `TypeRef` to its Rust source type string for use in trait bridge method
563/// signatures. `ci` is the core import path (e.g. `"kreuzberg"`), `tp` maps
564/// type names to fully-qualified paths.
565pub fn bridge_param_type(ty: &TypeRef, ci: &str, is_ref: bool, tp: &HashMap<String, String>) -> String {
566    match ty {
567        TypeRef::Bytes if is_ref => "&[u8]".into(),
568        TypeRef::Bytes => "Vec<u8>".into(),
569        TypeRef::String if is_ref => "&str".into(),
570        TypeRef::String => "String".into(),
571        TypeRef::Path if is_ref => "&std::path::Path".into(),
572        TypeRef::Path => "std::path::PathBuf".into(),
573        TypeRef::Named(n) => {
574            let qualified = tp.get(n).cloned().unwrap_or_else(|| format!("{ci}::{n}"));
575            if is_ref { format!("&{qualified}") } else { qualified }
576        }
577        TypeRef::Vec(inner) => format!("Vec<{}>", bridge_param_type(inner, ci, false, tp)),
578        TypeRef::Optional(inner) => format!("Option<{}>", bridge_param_type(inner, ci, false, tp)),
579        TypeRef::Primitive(p) => prim(p).into(),
580        TypeRef::Unit => "()".into(),
581        TypeRef::Char => "char".into(),
582        TypeRef::Map(k, v) => format!(
583            "std::collections::HashMap<{}, {}>",
584            bridge_param_type(k, ci, false, tp),
585            bridge_param_type(v, ci, false, tp)
586        ),
587        TypeRef::Json => "serde_json::Value".into(),
588        TypeRef::Duration => "std::time::Duration".into(),
589    }
590}
591
592/// Map a visitor method parameter type to the correct Rust type string, handling
593/// IR quirks:
594/// - `ty=String, optional=true, is_ref=true` → `Option<&str>` (IR collapses `Option<&str>`)
595/// - `ty=Vec<T>, is_ref=true` → `&[T]` (IR collapses `&[T]`)
596/// - Everything else delegates to [`bridge_param_type`].
597pub fn visitor_param_type(ty: &TypeRef, is_ref: bool, optional: bool, tp: &HashMap<String, String>) -> String {
598    if optional && matches!(ty, TypeRef::String) && is_ref {
599        return "Option<&str>".to_string();
600    }
601    if is_ref {
602        if let TypeRef::Vec(inner) = ty {
603            let inner_str = bridge_param_type(inner, "", false, tp);
604            return format!("&[{inner_str}]");
605        }
606    }
607    bridge_param_type(ty, "", is_ref, tp)
608}
609
610/// Find the first function parameter that matches a trait bridge configuration
611/// (by type alias or parameter name).
612pub fn find_bridge_param<'a>(
613    func: &FunctionDef,
614    bridges: &'a [TraitBridgeConfig],
615) -> Option<(usize, &'a TraitBridgeConfig)> {
616    for (idx, param) in func.params.iter().enumerate() {
617        let named = match &param.ty {
618            TypeRef::Named(n) => Some(n.as_str()),
619            TypeRef::Optional(inner) => {
620                if let TypeRef::Named(n) = inner.as_ref() {
621                    Some(n.as_str())
622                } else {
623                    None
624                }
625            }
626            _ => None,
627        };
628        for bridge in bridges {
629            if let Some(type_name) = named {
630                if bridge.type_alias.as_deref() == Some(type_name) {
631                    return Some((idx, bridge));
632                }
633            }
634            if bridge.param_name.as_deref() == Some(param.name.as_str()) {
635                return Some((idx, bridge));
636            }
637        }
638    }
639    None
640}
641
/// Convert a snake_case string to camelCase.
///
/// Underscores are removed and the character following each run of underscores
/// is ASCII-uppercased; all other characters are copied unchanged (so a leading
/// underscore capitalizes the first letter, e.g. `"_x"` → `"X"`).
pub fn to_camel_case(s: &str) -> String {
    let mut segments = s.split('_');
    let mut camel = String::with_capacity(s.len());
    if let Some(head) = segments.next() {
        camel.push_str(head);
    }
    for segment in segments {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            camel.push(first.to_ascii_uppercase());
            camel.push_str(chars.as_str());
        }
    }
    camel
}
658
659#[cfg(test)]
660mod tests {
661    use super::*;
662    use alef_core::config::TraitBridgeConfig;
663    use alef_core::ir::{MethodDef, ParamDef, PrimitiveType, ReceiverKind, TypeDef, TypeRef};
664
665    // ---------------------------------------------------------------------------
666    // Test helpers
667    // ---------------------------------------------------------------------------
668
669    fn make_trait_bridge_config(super_trait: Option<&str>, register_fn: Option<&str>) -> TraitBridgeConfig {
670        TraitBridgeConfig {
671            trait_name: "OcrBackend".to_string(),
672            super_trait: super_trait.map(str::to_string),
673            registry_getter: None,
674            register_fn: register_fn.map(str::to_string),
675            type_alias: None,
676            param_name: None,
677            register_extra_args: None,
678            exclude_languages: Vec::new(),
679        }
680    }
681
682    fn make_type_def(name: &str, rust_path: &str, methods: Vec<MethodDef>) -> TypeDef {
683        TypeDef {
684            name: name.to_string(),
685            rust_path: rust_path.to_string(),
686            original_rust_path: rust_path.to_string(),
687            fields: vec![],
688            methods,
689            is_opaque: true,
690            is_clone: false,
691            is_copy: false,
692            doc: String::new(),
693            cfg: None,
694            is_trait: true,
695            has_default: false,
696            has_stripped_cfg_fields: false,
697            is_return_type: false,
698            serde_rename_all: None,
699            has_serde: false,
700            super_traits: vec![],
701        }
702    }
703
704    fn make_method(
705        name: &str,
706        params: Vec<ParamDef>,
707        return_type: TypeRef,
708        is_async: bool,
709        has_default_impl: bool,
710        trait_source: Option<&str>,
711        error_type: Option<&str>,
712    ) -> MethodDef {
713        MethodDef {
714            name: name.to_string(),
715            params,
716            return_type,
717            is_async,
718            is_static: false,
719            error_type: error_type.map(str::to_string),
720            doc: String::new(),
721            receiver: Some(ReceiverKind::Ref),
722            sanitized: false,
723            trait_source: trait_source.map(str::to_string),
724            returns_ref: false,
725            returns_cow: false,
726            return_newtype_wrapper: None,
727            has_default_impl,
728        }
729    }
730
731    fn make_param(name: &str, ty: TypeRef, is_ref: bool) -> ParamDef {
732        ParamDef {
733            name: name.to_string(),
734            ty,
735            optional: false,
736            default: None,
737            sanitized: false,
738            typed_default: None,
739            is_ref,
740            is_mut: false,
741            newtype_wrapper: None,
742            original_type: None,
743        }
744    }
745
746    fn make_spec<'a>(
747        trait_def: &'a TypeDef,
748        bridge_config: &'a TraitBridgeConfig,
749        wrapper_prefix: &'a str,
750        type_paths: HashMap<String, String>,
751    ) -> TraitBridgeSpec<'a> {
752        TraitBridgeSpec {
753            trait_def,
754            bridge_config,
755            core_import: "mylib",
756            wrapper_prefix,
757            type_paths,
758            error_type: "MyError".to_string(),
759            error_constructor: "MyError::from({msg})".to_string(),
760        }
761    }
762
763    // ---------------------------------------------------------------------------
764    // Mock backend
765    // ---------------------------------------------------------------------------
766
767    struct MockBridgeGenerator;
768
769    impl TraitBridgeGenerator for MockBridgeGenerator {
770        fn foreign_object_type(&self) -> &str {
771            "Py<PyAny>"
772        }
773
774        fn bridge_imports(&self) -> Vec<String> {
775            vec!["pyo3::prelude::*".to_string(), "pyo3::types::PyString".to_string()]
776        }
777
778        fn gen_sync_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
779            format!("// sync body for {}", method.name)
780        }
781
782        fn gen_async_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
783            format!("// async body for {}", method.name)
784        }
785
786        fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String {
787            format!(
788                "impl {} {{\n    pub fn new(obj: Py<PyAny>) -> Self {{ Self {{ inner: obj, cached_name: String::new() }} }}\n}}",
789                spec.wrapper_name()
790            )
791        }
792
793        fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String {
794            let fn_name = spec.bridge_config.register_fn.as_deref().unwrap_or("register");
795            format!("pub fn {fn_name}(obj: Py<PyAny>) {{ /* register */ }}")
796        }
797    }
798
799    // ---------------------------------------------------------------------------
800    // TraitBridgeSpec helpers
801    // ---------------------------------------------------------------------------
802
803    #[test]
804    fn test_wrapper_name() {
805        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
806        let config = make_trait_bridge_config(None, None);
807        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
808        assert_eq!(spec.wrapper_name(), "PyOcrBackendBridge");
809    }
810
811    #[test]
812    fn test_trait_snake() {
813        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
814        let config = make_trait_bridge_config(None, None);
815        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
816        assert_eq!(spec.trait_snake(), "ocr_backend");
817    }
818
819    #[test]
820    fn test_trait_path_replaces_hyphens() {
821        let trait_def = make_type_def("OcrBackend", "my-lib::OcrBackend", vec![]);
822        let config = make_trait_bridge_config(None, None);
823        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
824        assert_eq!(spec.trait_path(), "my_lib::OcrBackend");
825    }
826
827    #[test]
828    fn test_required_methods_filters_no_default_impl() {
829        let methods = vec![
830            make_method("process", vec![], TypeRef::String, false, false, None, None),
831            make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
832            make_method("detect", vec![], TypeRef::String, false, false, None, None),
833        ];
834        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
835        let config = make_trait_bridge_config(None, None);
836        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
837        let required = spec.required_methods();
838        assert_eq!(required.len(), 2);
839        assert!(required.iter().any(|m| m.name == "process"));
840        assert!(required.iter().any(|m| m.name == "detect"));
841    }
842
843    #[test]
844    fn test_optional_methods_filters_has_default_impl() {
845        let methods = vec![
846            make_method("process", vec![], TypeRef::String, false, false, None, None),
847            make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
848            make_method("shutdown", vec![], TypeRef::Unit, false, true, None, None),
849        ];
850        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
851        let config = make_trait_bridge_config(None, None);
852        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
853        let optional = spec.optional_methods();
854        assert_eq!(optional.len(), 2);
855        assert!(optional.iter().any(|m| m.name == "initialize"));
856        assert!(optional.iter().any(|m| m.name == "shutdown"));
857    }
858
859    #[test]
860    fn test_error_path() {
861        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
862        let config = make_trait_bridge_config(None, None);
863        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
864        assert_eq!(spec.error_path(), "mylib::MyError");
865    }
866
867    // ---------------------------------------------------------------------------
868    // format_type_ref
869    // ---------------------------------------------------------------------------
870
871    #[test]
872    fn test_format_type_ref_primitives() {
873        let paths = HashMap::new();
874        let cases: Vec<(TypeRef, &str)> = vec![
875            (TypeRef::Primitive(PrimitiveType::Bool), "bool"),
876            (TypeRef::Primitive(PrimitiveType::U8), "u8"),
877            (TypeRef::Primitive(PrimitiveType::U16), "u16"),
878            (TypeRef::Primitive(PrimitiveType::U32), "u32"),
879            (TypeRef::Primitive(PrimitiveType::U64), "u64"),
880            (TypeRef::Primitive(PrimitiveType::I8), "i8"),
881            (TypeRef::Primitive(PrimitiveType::I16), "i16"),
882            (TypeRef::Primitive(PrimitiveType::I32), "i32"),
883            (TypeRef::Primitive(PrimitiveType::I64), "i64"),
884            (TypeRef::Primitive(PrimitiveType::F32), "f32"),
885            (TypeRef::Primitive(PrimitiveType::F64), "f64"),
886            (TypeRef::Primitive(PrimitiveType::Usize), "usize"),
887            (TypeRef::Primitive(PrimitiveType::Isize), "isize"),
888        ];
889        for (ty, expected) in cases {
890            assert_eq!(format_type_ref(&ty, &paths), expected, "mismatch for {expected}");
891        }
892    }
893
894    #[test]
895    fn test_format_type_ref_string() {
896        assert_eq!(format_type_ref(&TypeRef::String, &HashMap::new()), "String");
897    }
898
899    #[test]
900    fn test_format_type_ref_char() {
901        assert_eq!(format_type_ref(&TypeRef::Char, &HashMap::new()), "char");
902    }
903
904    #[test]
905    fn test_format_type_ref_bytes() {
906        assert_eq!(format_type_ref(&TypeRef::Bytes, &HashMap::new()), "Vec<u8>");
907    }
908
909    #[test]
910    fn test_format_type_ref_path() {
911        assert_eq!(format_type_ref(&TypeRef::Path, &HashMap::new()), "std::path::PathBuf");
912    }
913
914    #[test]
915    fn test_format_type_ref_unit() {
916        assert_eq!(format_type_ref(&TypeRef::Unit, &HashMap::new()), "()");
917    }
918
919    #[test]
920    fn test_format_type_ref_json() {
921        assert_eq!(format_type_ref(&TypeRef::Json, &HashMap::new()), "serde_json::Value");
922    }
923
924    #[test]
925    fn test_format_type_ref_duration() {
926        assert_eq!(
927            format_type_ref(&TypeRef::Duration, &HashMap::new()),
928            "std::time::Duration"
929        );
930    }
931
932    #[test]
933    fn test_format_type_ref_optional() {
934        let ty = TypeRef::Optional(Box::new(TypeRef::String));
935        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<String>");
936    }
937
938    #[test]
939    fn test_format_type_ref_optional_nested() {
940        let ty = TypeRef::Optional(Box::new(TypeRef::Optional(Box::new(TypeRef::Primitive(
941            PrimitiveType::U32,
942        )))));
943        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<Option<u32>>");
944    }
945
946    #[test]
947    fn test_format_type_ref_vec() {
948        let ty = TypeRef::Vec(Box::new(TypeRef::Primitive(PrimitiveType::U8)));
949        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<u8>");
950    }
951
952    #[test]
953    fn test_format_type_ref_vec_nested() {
954        let ty = TypeRef::Vec(Box::new(TypeRef::Vec(Box::new(TypeRef::String))));
955        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<Vec<String>>");
956    }
957
958    #[test]
959    fn test_format_type_ref_map() {
960        let ty = TypeRef::Map(
961            Box::new(TypeRef::String),
962            Box::new(TypeRef::Primitive(PrimitiveType::I64)),
963        );
964        assert_eq!(
965            format_type_ref(&ty, &HashMap::new()),
966            "std::collections::HashMap<String, i64>"
967        );
968    }
969
970    #[test]
971    fn test_format_type_ref_map_nested_value() {
972        let ty = TypeRef::Map(
973            Box::new(TypeRef::String),
974            Box::new(TypeRef::Vec(Box::new(TypeRef::String))),
975        );
976        assert_eq!(
977            format_type_ref(&ty, &HashMap::new()),
978            "std::collections::HashMap<String, Vec<String>>"
979        );
980    }
981
982    #[test]
983    fn test_format_type_ref_named_without_type_paths() {
984        let ty = TypeRef::Named("Config".to_string());
985        assert_eq!(format_type_ref(&ty, &HashMap::new()), "Config");
986    }
987
988    #[test]
989    fn test_format_type_ref_named_with_type_paths() {
990        let ty = TypeRef::Named("Config".to_string());
991        let mut paths = HashMap::new();
992        paths.insert("Config".to_string(), "mylib::Config".to_string());
993        assert_eq!(format_type_ref(&ty, &paths), "mylib::Config");
994    }
995
996    #[test]
997    fn test_format_type_ref_named_not_in_type_paths_falls_back_to_name() {
998        let ty = TypeRef::Named("Unknown".to_string());
999        let mut paths = HashMap::new();
1000        paths.insert("Other".to_string(), "mylib::Other".to_string());
1001        assert_eq!(format_type_ref(&ty, &paths), "Unknown");
1002    }
1003
1004    // ---------------------------------------------------------------------------
1005    // format_param_type
1006    // ---------------------------------------------------------------------------
1007
1008    #[test]
1009    fn test_format_param_type_string_ref() {
1010        let param = make_param("input", TypeRef::String, true);
1011        assert_eq!(format_param_type(&param, &HashMap::new()), "&str");
1012    }
1013
1014    #[test]
1015    fn test_format_param_type_string_owned() {
1016        let param = make_param("input", TypeRef::String, false);
1017        assert_eq!(format_param_type(&param, &HashMap::new()), "String");
1018    }
1019
1020    #[test]
1021    fn test_format_param_type_bytes_ref() {
1022        let param = make_param("data", TypeRef::Bytes, true);
1023        assert_eq!(format_param_type(&param, &HashMap::new()), "&[u8]");
1024    }
1025
1026    #[test]
1027    fn test_format_param_type_bytes_owned() {
1028        let param = make_param("data", TypeRef::Bytes, false);
1029        assert_eq!(format_param_type(&param, &HashMap::new()), "Vec<u8>");
1030    }
1031
1032    #[test]
1033    fn test_format_param_type_path_ref() {
1034        let param = make_param("path", TypeRef::Path, true);
1035        assert_eq!(format_param_type(&param, &HashMap::new()), "&std::path::Path");
1036    }
1037
1038    #[test]
1039    fn test_format_param_type_path_owned() {
1040        let param = make_param("path", TypeRef::Path, false);
1041        assert_eq!(format_param_type(&param, &HashMap::new()), "std::path::PathBuf");
1042    }
1043
1044    #[test]
1045    fn test_format_param_type_vec_ref() {
1046        let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), true);
1047        assert_eq!(format_param_type(&param, &HashMap::new()), "&[String]");
1048    }
1049
1050    #[test]
1051    fn test_format_param_type_vec_owned() {
1052        let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), false);
1053        assert_eq!(format_param_type(&param, &HashMap::new()), "Vec<String>");
1054    }
1055
1056    #[test]
1057    fn test_format_param_type_named_ref_with_type_paths() {
1058        let mut paths = HashMap::new();
1059        paths.insert("Config".to_string(), "mylib::Config".to_string());
1060        let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
1061        assert_eq!(format_param_type(&param, &paths), "&mylib::Config");
1062    }
1063
1064    #[test]
1065    fn test_format_param_type_named_ref_without_type_paths() {
1066        let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
1067        assert_eq!(format_param_type(&param, &HashMap::new()), "&Config");
1068    }
1069
1070    #[test]
1071    fn test_format_param_type_primitive_ref_passes_by_value() {
1072        // Copy types like u32 are passed by value even when is_ref is set
1073        let param = make_param("count", TypeRef::Primitive(PrimitiveType::U32), true);
1074        assert_eq!(format_param_type(&param, &HashMap::new()), "u32");
1075    }
1076
1077    #[test]
1078    fn test_format_param_type_unit_ref_passes_by_value() {
1079        let param = make_param("nothing", TypeRef::Unit, true);
1080        assert_eq!(format_param_type(&param, &HashMap::new()), "()");
1081    }
1082
1083    // ---------------------------------------------------------------------------
1084    // format_return_type
1085    // ---------------------------------------------------------------------------
1086
1087    #[test]
1088    fn test_format_return_type_without_error() {
1089        let result = format_return_type(&TypeRef::String, None, &HashMap::new());
1090        assert_eq!(result, "String");
1091    }
1092
1093    #[test]
1094    fn test_format_return_type_with_error() {
1095        let result = format_return_type(&TypeRef::String, Some("MyError"), &HashMap::new());
1096        assert_eq!(result, "std::result::Result<String, MyError>");
1097    }
1098
1099    #[test]
1100    fn test_format_return_type_unit_with_error() {
1101        let result = format_return_type(&TypeRef::Unit, Some("Box<dyn std::error::Error>"), &HashMap::new());
1102        assert_eq!(result, "std::result::Result<(), Box<dyn std::error::Error>>");
1103    }
1104
1105    #[test]
1106    fn test_format_return_type_named_with_type_paths_and_error() {
1107        let mut paths = HashMap::new();
1108        paths.insert("Output".to_string(), "mylib::Output".to_string());
1109        let result = format_return_type(&TypeRef::Named("Output".to_string()), Some("mylib::MyError"), &paths);
1110        assert_eq!(result, "std::result::Result<mylib::Output, mylib::MyError>");
1111    }
1112
1113    // ---------------------------------------------------------------------------
1114    // gen_bridge_wrapper_struct
1115    // ---------------------------------------------------------------------------
1116
1117    #[test]
1118    fn test_gen_bridge_wrapper_struct_contains_struct_name() {
1119        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1120        let config = make_trait_bridge_config(None, None);
1121        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1122        let generator = MockBridgeGenerator;
1123        let result = gen_bridge_wrapper_struct(&spec, &generator);
1124        assert!(
1125            result.contains("pub struct PyOcrBackendBridge"),
1126            "missing struct declaration in:\n{result}"
1127        );
1128    }
1129
1130    #[test]
1131    fn test_gen_bridge_wrapper_struct_contains_inner_field() {
1132        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1133        let config = make_trait_bridge_config(None, None);
1134        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1135        let generator = MockBridgeGenerator;
1136        let result = gen_bridge_wrapper_struct(&spec, &generator);
1137        assert!(result.contains("inner: Py<PyAny>"), "missing inner field in:\n{result}");
1138    }
1139
1140    #[test]
1141    fn test_gen_bridge_wrapper_struct_contains_cached_name() {
1142        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1143        let config = make_trait_bridge_config(None, None);
1144        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1145        let generator = MockBridgeGenerator;
1146        let result = gen_bridge_wrapper_struct(&spec, &generator);
1147        assert!(
1148            result.contains("cached_name: String"),
1149            "missing cached_name field in:\n{result}"
1150        );
1151    }
1152
1153    // ---------------------------------------------------------------------------
1154    // gen_bridge_plugin_impl
1155    // ---------------------------------------------------------------------------
1156
1157    #[test]
1158    fn test_gen_bridge_plugin_impl_returns_none_when_no_super_trait() {
1159        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1160        let config = make_trait_bridge_config(None, None);
1161        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1162        let generator = MockBridgeGenerator;
1163        assert!(gen_bridge_plugin_impl(&spec, &generator).is_none());
1164    }
1165
1166    #[test]
1167    fn test_gen_bridge_plugin_impl_returns_some_when_super_trait_configured() {
1168        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1169        let config = make_trait_bridge_config(Some("Plugin"), None);
1170        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1171        let generator = MockBridgeGenerator;
1172        assert!(gen_bridge_plugin_impl(&spec, &generator).is_some());
1173    }
1174
1175    #[test]
1176    fn test_gen_bridge_plugin_impl_uses_qualified_super_trait_path() {
1177        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1178        let config = make_trait_bridge_config(Some("Plugin"), None);
1179        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1180        let generator = MockBridgeGenerator;
1181        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1182        assert!(
1183            result.contains("impl mylib::Plugin for PyOcrBackendBridge"),
1184            "missing qualified super-trait path in:\n{result}"
1185        );
1186    }
1187
1188    #[test]
1189    fn test_gen_bridge_plugin_impl_uses_already_qualified_super_trait_path() {
1190        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1191        let config = make_trait_bridge_config(Some("other_crate::Plugin"), None);
1192        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1193        let generator = MockBridgeGenerator;
1194        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1195        assert!(
1196            result.contains("impl other_crate::Plugin for PyOcrBackendBridge"),
1197            "wrong super-trait path in:\n{result}"
1198        );
1199    }
1200
1201    #[test]
1202    fn test_gen_bridge_plugin_impl_contains_name_fn() {
1203        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1204        let config = make_trait_bridge_config(Some("Plugin"), None);
1205        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1206        let generator = MockBridgeGenerator;
1207        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1208        assert!(
1209            result.contains("fn name(") && result.contains("cached_name"),
1210            "missing name() using cached_name in:\n{result}"
1211        );
1212    }
1213
1214    #[test]
1215    fn test_gen_bridge_plugin_impl_contains_version_fn() {
1216        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1217        let config = make_trait_bridge_config(Some("Plugin"), None);
1218        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1219        let generator = MockBridgeGenerator;
1220        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1221        assert!(result.contains("fn version("), "missing version() in:\n{result}");
1222    }
1223
1224    #[test]
1225    fn test_gen_bridge_plugin_impl_contains_initialize_fn() {
1226        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1227        let config = make_trait_bridge_config(Some("Plugin"), None);
1228        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1229        let generator = MockBridgeGenerator;
1230        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1231        assert!(result.contains("fn initialize("), "missing initialize() in:\n{result}");
1232    }
1233
1234    #[test]
1235    fn test_gen_bridge_plugin_impl_contains_shutdown_fn() {
1236        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1237        let config = make_trait_bridge_config(Some("Plugin"), None);
1238        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1239        let generator = MockBridgeGenerator;
1240        let result = gen_bridge_plugin_impl(&spec, &generator).unwrap();
1241        assert!(result.contains("fn shutdown("), "missing shutdown() in:\n{result}");
1242    }
1243
1244    // ---------------------------------------------------------------------------
1245    // gen_bridge_trait_impl
1246    // ---------------------------------------------------------------------------
1247
1248    #[test]
1249    fn test_gen_bridge_trait_impl_includes_impl_header() {
1250        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1251        let config = make_trait_bridge_config(None, None);
1252        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1253        let generator = MockBridgeGenerator;
1254        let result = gen_bridge_trait_impl(&spec, &generator);
1255        assert!(
1256            result.contains("impl mylib::OcrBackend for PyOcrBackendBridge"),
1257            "missing impl header in:\n{result}"
1258        );
1259    }
1260
1261    #[test]
1262    fn test_gen_bridge_trait_impl_includes_method_signatures() {
1263        let methods = vec![make_method(
1264            "process",
1265            vec![],
1266            TypeRef::String,
1267            false,
1268            false,
1269            None,
1270            None,
1271        )];
1272        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1273        let config = make_trait_bridge_config(None, None);
1274        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1275        let generator = MockBridgeGenerator;
1276        let result = gen_bridge_trait_impl(&spec, &generator);
1277        assert!(result.contains("fn process("), "missing method signature in:\n{result}");
1278    }
1279
1280    #[test]
1281    fn test_gen_bridge_trait_impl_includes_method_body_from_generator() {
1282        let methods = vec![make_method(
1283            "process",
1284            vec![],
1285            TypeRef::String,
1286            false,
1287            false,
1288            None,
1289            None,
1290        )];
1291        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1292        let config = make_trait_bridge_config(None, None);
1293        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1294        let generator = MockBridgeGenerator;
1295        let result = gen_bridge_trait_impl(&spec, &generator);
1296        assert!(
1297            result.contains("// sync body for process"),
1298            "missing sync method body in:\n{result}"
1299        );
1300    }
1301
1302    #[test]
1303    fn test_gen_bridge_trait_impl_async_method_uses_async_body() {
1304        let methods = vec![make_method(
1305            "process_async",
1306            vec![],
1307            TypeRef::String,
1308            true,
1309            false,
1310            None,
1311            None,
1312        )];
1313        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1314        let config = make_trait_bridge_config(None, None);
1315        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1316        let generator = MockBridgeGenerator;
1317        let result = gen_bridge_trait_impl(&spec, &generator);
1318        assert!(
1319            result.contains("// async body for process_async"),
1320            "missing async method body in:\n{result}"
1321        );
1322        assert!(
1323            result.contains("async fn process_async("),
1324            "missing async keyword in method signature in:\n{result}"
1325        );
1326    }
1327
1328    #[test]
1329    fn test_gen_bridge_trait_impl_filters_trait_source_methods() {
1330        // Methods with trait_source set come from super-traits and should be excluded
1331        let methods = vec![
1332            make_method("own_method", vec![], TypeRef::String, false, false, None, None),
1333            make_method(
1334                "inherited_method",
1335                vec![],
1336                TypeRef::String,
1337                false,
1338                false,
1339                Some("other_crate::OtherTrait"),
1340                None,
1341            ),
1342        ];
1343        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1344        let config = make_trait_bridge_config(None, None);
1345        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1346        let generator = MockBridgeGenerator;
1347        let result = gen_bridge_trait_impl(&spec, &generator);
1348        assert!(
1349            result.contains("fn own_method("),
1350            "own method should be present in:\n{result}"
1351        );
1352        assert!(
1353            !result.contains("fn inherited_method("),
1354            "inherited method should be filtered out in:\n{result}"
1355        );
1356    }
1357
1358    #[test]
1359    fn test_gen_bridge_trait_impl_method_with_params() {
1360        let params = vec![
1361            make_param("input", TypeRef::String, true),
1362            make_param("count", TypeRef::Primitive(PrimitiveType::U32), false),
1363        ];
1364        let methods = vec![make_method(
1365            "process",
1366            params,
1367            TypeRef::String,
1368            false,
1369            false,
1370            None,
1371            None,
1372        )];
1373        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1374        let config = make_trait_bridge_config(None, None);
1375        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1376        let generator = MockBridgeGenerator;
1377        let result = gen_bridge_trait_impl(&spec, &generator);
1378        assert!(result.contains("input: &str"), "missing &str param in:\n{result}");
1379        assert!(result.contains("count: u32"), "missing u32 param in:\n{result}");
1380    }
1381
1382    #[test]
1383    fn test_gen_bridge_trait_impl_return_type_with_error() {
1384        let methods = vec![make_method(
1385            "process",
1386            vec![],
1387            TypeRef::String,
1388            false,
1389            false,
1390            None,
1391            Some("MyError"),
1392        )];
1393        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1394        let config = make_trait_bridge_config(None, None);
1395        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1396        let generator = MockBridgeGenerator;
1397        let result = gen_bridge_trait_impl(&spec, &generator);
1398        assert!(
1399            result.contains("-> std::result::Result<String, mylib::MyError>"),
1400            "missing std::result::Result return type in:\n{result}"
1401        );
1402    }
1403
1404    // ---------------------------------------------------------------------------
1405    // gen_bridge_registration_fn
1406    // ---------------------------------------------------------------------------
1407
1408    #[test]
1409    fn test_gen_bridge_registration_fn_returns_none_without_register_fn() {
1410        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1411        let config = make_trait_bridge_config(None, None);
1412        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1413        let generator = MockBridgeGenerator;
1414        assert!(gen_bridge_registration_fn(&spec, &generator).is_none());
1415    }
1416
1417    #[test]
1418    fn test_gen_bridge_registration_fn_returns_some_with_register_fn() {
1419        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1420        let config = make_trait_bridge_config(None, Some("register_ocr_backend"));
1421        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1422        let generator = MockBridgeGenerator;
1423        let result = gen_bridge_registration_fn(&spec, &generator);
1424        assert!(result.is_some());
1425        let code = result.unwrap();
1426        assert!(
1427            code.contains("register_ocr_backend"),
1428            "missing register fn name in:\n{code}"
1429        );
1430    }
1431
1432    // ---------------------------------------------------------------------------
1433    // gen_bridge_all
1434    // ---------------------------------------------------------------------------
1435
1436    #[test]
1437    fn test_gen_bridge_all_includes_imports() {
1438        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1439        let config = make_trait_bridge_config(None, None);
1440        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1441        let generator = MockBridgeGenerator;
1442        let output = gen_bridge_all(&spec, &generator);
1443        assert!(output.imports.contains(&"pyo3::prelude::*".to_string()));
1444        assert!(output.imports.contains(&"pyo3::types::PyString".to_string()));
1445    }
1446
1447    #[test]
1448    fn test_gen_bridge_all_includes_wrapper_struct() {
1449        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1450        let config = make_trait_bridge_config(None, None);
1451        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1452        let generator = MockBridgeGenerator;
1453        let output = gen_bridge_all(&spec, &generator);
1454        assert!(
1455            output.code.contains("pub struct PyOcrBackendBridge"),
1456            "missing struct in:\n{}",
1457            output.code
1458        );
1459    }
1460
1461    #[test]
1462    fn test_gen_bridge_all_includes_constructor() {
1463        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1464        let config = make_trait_bridge_config(None, None);
1465        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1466        let generator = MockBridgeGenerator;
1467        let output = gen_bridge_all(&spec, &generator);
1468        assert!(
1469            output.code.contains("pub fn new("),
1470            "missing constructor in:\n{}",
1471            output.code
1472        );
1473    }
1474
1475    #[test]
1476    fn test_gen_bridge_all_includes_trait_impl() {
1477        let methods = vec![make_method(
1478            "process",
1479            vec![],
1480            TypeRef::String,
1481            false,
1482            false,
1483            None,
1484            None,
1485        )];
1486        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1487        let config = make_trait_bridge_config(None, None);
1488        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1489        let generator = MockBridgeGenerator;
1490        let output = gen_bridge_all(&spec, &generator);
1491        assert!(
1492            output.code.contains("impl mylib::OcrBackend for PyOcrBackendBridge"),
1493            "missing trait impl in:\n{}",
1494            output.code
1495        );
1496    }
1497
1498    #[test]
1499    fn test_gen_bridge_all_includes_plugin_impl_when_super_trait_set() {
1500        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1501        let config = make_trait_bridge_config(Some("Plugin"), None);
1502        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1503        let generator = MockBridgeGenerator;
1504        let output = gen_bridge_all(&spec, &generator);
1505        assert!(
1506            output.code.contains("impl mylib::Plugin for PyOcrBackendBridge"),
1507            "missing plugin impl in:\n{}",
1508            output.code
1509        );
1510    }
1511
1512    #[test]
1513    fn test_gen_bridge_all_no_plugin_impl_when_no_super_trait() {
1514        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1515        let config = make_trait_bridge_config(None, None);
1516        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1517        let generator = MockBridgeGenerator;
1518        let output = gen_bridge_all(&spec, &generator);
1519        assert!(
1520            !output.code.contains("fn name(") || !output.code.contains("cached_name"),
1521            "unexpected plugin impl present without super_trait"
1522        );
1523    }
1524
1525    #[test]
1526    fn test_gen_bridge_all_includes_registration_fn_when_configured() {
1527        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1528        let config = make_trait_bridge_config(None, Some("register_ocr_backend"));
1529        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1530        let generator = MockBridgeGenerator;
1531        let output = gen_bridge_all(&spec, &generator);
1532        assert!(
1533            output.code.contains("register_ocr_backend"),
1534            "missing registration fn in:\n{}",
1535            output.code
1536        );
1537    }
1538
1539    #[test]
1540    fn test_gen_bridge_all_no_registration_fn_when_absent() {
1541        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1542        let config = make_trait_bridge_config(None, None);
1543        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1544        let generator = MockBridgeGenerator;
1545        let output = gen_bridge_all(&spec, &generator);
1546        assert!(
1547            !output.code.contains("register_ocr_backend"),
1548            "unexpected registration fn present:\n{}",
1549            output.code
1550        );
1551    }
1552
1553    #[test]
1554    fn test_gen_bridge_all_ordering_struct_before_trait_impl() {
1555        let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1556        let config = make_trait_bridge_config(None, None);
1557        let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1558        let generator = MockBridgeGenerator;
1559        let output = gen_bridge_all(&spec, &generator);
1560        let struct_pos = output.code.find("pub struct PyOcrBackendBridge").unwrap();
1561        let impl_pos = output
1562            .code
1563            .find("impl mylib::OcrBackend for PyOcrBackendBridge")
1564            .unwrap();
1565        assert!(struct_pos < impl_pos, "struct should appear before trait impl");
1566    }
1567}