Skip to main content

alef_backend_zig/
trait_bridge.rs

1//! Zig trait-bridge code generation.
2//!
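//! Emits, per configured `[[trait_bridges]]` entry, one Zig extern struct
//! (vtable), the binding shims (`register_*` / `unregister_*` wrappers for
//! function-param bridges, or a `*_handle_from_vtable` helper for options-field
//! bridges) and a `make_*_vtable` builder.  The Zig consumer fills in the
//! struct with `callconv(.C)` function pointers and calls `register_*`.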
6//!
7//! # C symbol convention
8//!
9//! The generated `register_{trait_snake}` shim calls
10//! `c.{prefix}_register_{trait_snake}` — the symbol exposed by the
11//! `kreuzberg-ffi` C layer (pattern: `{crate_prefix}_register_{trait_snake}`).
12//! If the actual symbol differs, override the generated call site.
13//!
14//! # `TraitBridgeGenerator` implementation
15//!
16//! [`ZigTraitBridgeGenerator`] implements the shared [`TraitBridgeGenerator`]
17//! trait so that the shared codegen driver can invoke the Zig-specific
18//! `gen_unregistration_fn` and `gen_clear_fn` overrides.  The other required
19//! methods are stubs — Zig code is produced through the standalone
20//! [`emit_trait_bridge`] free function, not the shared driver.
21
22use alef_codegen::generators::trait_bridge::{TraitBridgeGenerator, TraitBridgeSpec};
23use alef_core::config::{BridgeBinding, TraitBridgeConfig};
24use alef_core::ir::{MethodDef, TypeDef, TypeRef};
25use heck::ToSnakeCase;
26use std::fmt::Write as _;
27
28/// Zig type string to use for a vtable slot parameter or return type.
29///
30/// All string/complex types collapse to `[*c]const u8` (C string pointer) since
31/// the vtable slots use the raw C ABI — not the Zig-friendly wrapper layer.
32fn vtable_param_type(ty: &TypeRef) -> &'static str {
33    match ty {
34        TypeRef::Primitive(p) => {
35            use alef_core::ir::PrimitiveType::*;
36            match p {
37                Bool => "i32",
38                U8 => "u8",
39                U16 => "u16",
40                U32 => "u32",
41                U64 => "u64",
42                I8 => "i8",
43                I16 => "i16",
44                I32 => "i32",
45                I64 => "i64",
46                F32 => "f32",
47                F64 => "f64",
48                Usize => "usize",
49                Isize => "isize",
50            }
51        }
52        TypeRef::Unit => "void",
53        TypeRef::Duration => "i64",
54        // All string/path/complex types become C string pointers at the C ABI boundary.
55        _ => "[*c]const u8",
56    }
57}
58
59/// Zig return type for a vtable slot.
60///
61/// Fallible methods always return `i32` (0 = success, non-zero = error).
62/// Unit infallible methods return `void`.  Other infallible returns use the
63/// primitive mapping.
64fn vtable_return_type(method: &MethodDef) -> String {
65    if method.error_type.is_some() {
66        "i32".to_string()
67    } else {
68        vtable_param_type(&method.return_type).to_string()
69    }
70}
71
72/// Build a snake_case trait name from a PascalCase trait name.
73///
74/// Uses `heck::ToSnakeCase`, matching the pattern used by Go/C# backends.
75fn trait_snake(trait_name: &str) -> String {
76    trait_name.to_snake_case()
77}
78
79/// Emit a Zig param name for the C-ABI slot, expanding `Bytes` to ptr+len.
80///
81/// Returns a list of `(c_param_name, c_param_type)` pairs.
82fn vtable_c_params(method: &MethodDef) -> Vec<(String, String)> {
83    let mut params = vec![("ud".to_string(), "?*anyopaque".to_string())];
84    for p in &method.params {
85        if matches!(p.ty, TypeRef::Bytes) {
86            params.push((format!("{}_ptr", p.name), "[*c]const u8".to_string()));
87            params.push((format!("{}_len", p.name), "usize".to_string()));
88        } else {
89            params.push((p.name.clone(), vtable_param_type(&p.ty).to_string()));
90        }
91    }
92    if method.error_type.is_some() {
93        if !matches!(method.return_type, TypeRef::Unit) {
94            params.push(("out_result".to_string(), "?*?[*c]u8".to_string()));
95        }
96        params.push(("out_error".to_string(), "?*?[*c]u8".to_string()));
97    } else if !matches!(method.return_type, TypeRef::Unit) {
98        params.push(("out_result".to_string(), "?*?[*c]u8".to_string()));
99    }
100    params
101}
102
103/// Emit a `make_{trait_snake}_vtable(comptime T: type, instance: *T) I{Trait}` helper.
104///
105/// The helper builds `callconv(.C)` thunks for every vtable slot so the consumer
106/// only needs to write plain Zig methods on their type.
107///
108/// # Limitations
109///
110/// - Methods returning non-unit values through `out_result` use `unreachable` for
111///   the conversion path when the type cannot be expressed as a direct C primitive
112///   (complex types are documented as requiring manual implementation).
113/// - Lifecycle slots (`name_fn`, `version_fn`, `initialize_fn`, `shutdown_fn`) are
114///   emitted with `unreachable` bodies as stubs — the consumer overrides the
115///   relevant field in the returned vtable if needed.
116pub fn emit_make_vtable(trait_name: &str, has_super_trait: bool, trait_def: &TypeDef, out: &mut String) {
117    let snake = trait_snake(trait_name);
118
119    out.push_str(&crate::template_env::render(
120        "vtable_header_doc.jinja",
121        minijinja::context! {
122            trait_name => trait_name,
123            snake => &snake,
124        },
125    ));
126    out.push_str(&crate::template_env::render(
127        "vtable_impl_method.jinja",
128        minijinja::context! {
129            snake => &snake,
130            trait_name => trait_name,
131        },
132    ));
133    out.push_str(&crate::template_env::render(
134        "vtable_make_fn_header.jinja",
135        minijinja::context! {
136            trait_name => trait_name,
137        },
138    ));
139
140    // Lifecycle stubs when super_trait is present
141    if has_super_trait {
142        out.push_str(&crate::template_env::render(
143            "vtable_field_name_fn.jinja",
144            minijinja::context! {},
145        ));
146        out.push_str(&crate::template_env::render(
147            "vtable_field_version_fn.jinja",
148            minijinja::context! {},
149        ));
150        out.push_str(&crate::template_env::render(
151            "vtable_field_initialize_fn.jinja",
152            minijinja::context! {},
153        ));
154        out.push_str(&crate::template_env::render(
155            "vtable_field_shutdown_fn.jinja",
156            minijinja::context! {},
157        ));
158    }
159
160    // Per-method thunks
161    for method in &trait_def.methods {
162        let method_snake = method.name.to_snake_case();
163        let c_params = vtable_c_params(method);
164        let ret = vtable_return_type(method);
165
166        // Build the thunk parameter list string
167        let params_str = c_params
168            .iter()
169            .map(|(name, ty)| format!("{name}: {ty}"))
170            .collect::<Vec<_>>()
171            .join(", ");
172
173        out.push_str(&crate::template_env::render(
174            "vtable_instance_field.jinja",
175            minijinja::context! {
176                method_snake => &method_snake,
177                params_str => &params_str,
178                ret => &ret,
179            },
180        ));
181
182        // Cast user_data to *T
183        out.push_str("                const self: *T = @ptrCast(@alignCast(ud));\n");
184
185        // Reconstruct Bytes slices and build forwarding arg list
186        let mut call_args: Vec<String> = Vec::new();
187        for p in &method.params {
188            if matches!(p.ty, TypeRef::Bytes) {
189                out.push_str(&crate::template_env::render(
190                    "thunk_bytes_slice.jinja",
191                    minijinja::context! {
192                        slice_name => format!("{}_slice", p.name),
193                        ptr_name => format!("{}_ptr", p.name),
194                        len_name => format!("{}_len", p.name),
195                    },
196                ));
197                call_args.push(format!("{}_slice", p.name));
198            } else {
199                call_args.push(p.name.clone());
200            }
201        }
202
203        let args_str = call_args.join(", ");
204
205        // Pick a capture name for the success branch that won't collide with method
206        // params. Methods can have a param literally called `result`; using that as
207        // the unwrap binding shadows the outer scope (zig 0.16+ flags this).
208        let ok_binding = if method.params.iter().any(|p| p.name == "value") {
209            "ok_value"
210        } else {
211            "value"
212        };
213
214        if method.error_type.is_some() {
215            // Fallible method: call returns error union, write out_result/out_error
216            let has_result_out = !matches!(method.return_type, TypeRef::Unit);
217            out.push_str(&crate::template_env::render(
218                "thunk_fn_signature.jinja",
219                minijinja::context! {
220                    method_snake => &method_snake,
221                    args_str => &args_str,
222                    ok_binding => &ok_binding,
223                },
224            ));
225            // Write result via out_result pointer — for complex types this is unreachable.
226            // `unreachable` diverges, so any code after it (including `return 0;`) would
227            // be flagged "unreachable code" by zig 0.16+; only emit the trailing return
228            // when the success path actually flows through.
229            let mut success_path_diverges = false;
230            if has_result_out {
231                match &method.return_type {
232                    TypeRef::Primitive(_) | TypeRef::Unit => {
233                        out.push_str(&crate::template_env::render(
234                            "thunk_result_assign.jinja",
235                            minijinja::context! {
236                                ok_binding => &ok_binding,
237                            },
238                        ));
239                    }
240                    _ => {
241                        // String/Bytes/complex: cannot safely convert without allocator context
242                        out.push_str(&crate::template_env::render(
243                            "thunk_if_fallible.jinja",
244                            minijinja::context! {
245                                ok_binding => &ok_binding,
246                            },
247                        ));
248                        success_path_diverges = true;
249                    }
250                }
251            } else {
252                // Unit return on success — discard the captured Void to silence unused-variable.
253                out.push_str(&crate::template_env::render(
254                    "thunk_if_ok_result.jinja",
255                    minijinja::context! {
256                        ok_binding => &ok_binding,
257                    },
258                ));
259            }
260            if !success_path_diverges {
261                out.push_str("                    return 0;\n");
262            }
263            out.push_str("                } else |err| {\n");
264            out.push_str("                    _ = err;\n");
265            out.push_str("                    if (out_error) |ptr| ptr.* = null; // caller checks error code\n");
266            out.push_str("                    return 1;\n");
267            out.push_str("                }\n");
268        } else {
269            // Infallible non-Unit methods get an `out_result` param "for uniformity"
270            // (see vtable_c_params), but the body returns the value directly via the
271            // function return type — so the param is unused. Discard it so zig 0.16+
272            // doesn't flag "unused function parameter".
273            if !matches!(method.return_type, TypeRef::Unit) {
274                out.push_str("                _ = out_result;\n");
275            }
276            match &method.return_type {
277                TypeRef::Unit => {
278                    out.push_str(&crate::template_env::render(
279                        "thunk_if_error.jinja",
280                        minijinja::context! {
281                            method_snake => &method_snake,
282                            args_str => &args_str,
283                        },
284                    ));
285                }
286                TypeRef::Primitive(_) => {
287                    out.push_str(&crate::template_env::render(
288                        "thunk_infallible_return.jinja",
289                        minijinja::context! {
290                            method_snake => &method_snake,
291                            args_str => &args_str,
292                        },
293                    ));
294                }
295                _ => {
296                    // Non-unit infallible non-primitive: pass through (e.g., [*c]const u8)
297                    out.push_str(&crate::template_env::render(
298                        "thunk_infallible_return.jinja",
299                        minijinja::context! {
300                            method_snake => &method_snake,
301                            args_str => &args_str,
302                        },
303                    ));
304                }
305            }
306        }
307
308        out.push_str("            }\n");
309        out.push_str("        }.thunk,\n");
310        out.push('\n');
311    }
312
313    // free_user_data stub — does nothing by default; caller overrides if needed
314    out.push_str(&crate::template_env::render(
315        "vtable_free_user_data.jinja",
316        minijinja::context! {},
317    ));
318
319    out.push_str("    };\n");
320    out.push_str("}\n");
321}
322
323/// Emit the vtable extern struct and registration shim for a single trait bridge.
324///
325/// `prefix` is the C FFI prefix (e.g., `"kreuzberg"`).
326/// `bridge_cfg` is the trait bridge configuration entry.
327/// `trait_def` is the IR type definition for the trait (must have `is_trait = true`).
328/// `out` is the output buffer to append to.
329pub fn emit_trait_bridge(prefix: &str, bridge_cfg: &TraitBridgeConfig, trait_def: &TypeDef, out: &mut String) {
330    let trait_name = &trait_def.name;
331    let snake = trait_snake(trait_name);
332    let has_super_trait = bridge_cfg.super_trait.is_some();
333
334    // -------------------------------------------------------------------------
335    // Vtable struct: I{Trait}
336    // -------------------------------------------------------------------------
337    out.push_str(&crate::template_env::render(
338        "trait_vtable_header.jinja",
339        minijinja::context! {
340            trait_name => trait_name,
341            snake => &snake,
342        },
343    ));
344    out.push_str(&crate::template_env::render(
345        "trait_struct_header.jinja",
346        minijinja::context! {
347            trait_name => trait_name,
348        },
349    ));
350
351    // Plugin lifecycle slots — always present when a super_trait is configured.
352    if has_super_trait {
353        out.push_str("    /// Return the plugin name into `out_name` (heap-allocated, caller frees).\n");
354        out.push_str(
355            "    name_fn: ?*const fn (user_data: ?*anyopaque, out_name: ?*?[*c]u8) callconv(.C) void = null,\n",
356        );
357        out.push('\n');
358
359        out.push_str("    /// Return the plugin version into `out_version` (heap-allocated, caller frees).\n");
360        out.push_str(
361            "    version_fn: ?*const fn (user_data: ?*anyopaque, out_version: ?*?[*c]u8) callconv(.C) void = null,\n",
362        );
363        out.push('\n');
364
365        out.push_str("    /// Initialise the plugin; return 0 on success, non-zero on error.\n");
366        out.push_str(
367            "    initialize_fn: ?*const fn (user_data: ?*anyopaque, out_error: ?*?[*c]u8) callconv(.C) i32 = null,\n",
368        );
369        out.push('\n');
370
371        out.push_str("    /// Shut down the plugin; return 0 on success, non-zero on error.\n");
372        out.push_str(
373            "    shutdown_fn: ?*const fn (user_data: ?*anyopaque, out_error: ?*?[*c]u8) callconv(.C) i32 = null,\n",
374        );
375        out.push('\n');
376    }
377
378    // Trait method slots
379    for method in &trait_def.methods {
380        if !method.doc.is_empty() {
381            out.push_str(&crate::template_env::render(
382                "trait_method_doc_lines.jinja",
383                minijinja::context! {
384                    method_doc_lines => method.doc.lines().collect::<Vec<_>>(),
385                },
386            ));
387        }
388
389        let ret = vtable_return_type(method);
390        let method_snake = method.name.to_snake_case();
391
392        // Build the parameter list: user_data first, then method params.
393        let mut params = vec!["user_data: ?*anyopaque".to_string()];
394        for p in &method.params {
395            let ty = vtable_param_type(&p.ty);
396            // Bytes expand to two args (ptr + len)
397            if matches!(p.ty, TypeRef::Bytes) {
398                params.push(format!("{}_ptr: [*c]const u8", p.name));
399                params.push(format!("{}_len: usize", p.name));
400            } else {
401                params.push(format!("{}: {ty}", p.name));
402            }
403        }
404
405        // Fallible methods get out-result and out-error pointers.
406        if method.error_type.is_some() {
407            if !matches!(method.return_type, TypeRef::Unit) {
408                params.push("out_result: ?*?[*c]u8".to_string());
409            }
410            params.push("out_error: ?*?[*c]u8".to_string());
411        } else if !matches!(method.return_type, TypeRef::Unit) {
412            // Infallible non-void: return via out_result too for uniformity
413            params.push("out_result: ?*?[*c]u8".to_string());
414        }
415
416        let params_str = params.join(", ");
417        out.push_str(&crate::template_env::render(
418            "trait_method_signature.jinja",
419            minijinja::context! {
420                method_snake => &method_snake,
421                params_str => &params_str,
422                ret => &ret,
423            },
424        ));
425    }
426
427    // free_user_data — always last; called by Rust Drop to release the Zig-side handle.
428    out.push_str("    /// Called by the Rust runtime when the bridge is dropped.\n");
429    out.push_str("    /// Use this to release any Zig-side state held via `user_data`.\n");
430    out.push_str("    free_user_data: ?*const fn (user_data: ?*anyopaque) callconv(.C) void = null,\n");
431
432    out.push_str("};\n");
433    out.push('\n');
434
435    // -------------------------------------------------------------------------
436    // Registration / unregistration shims (function-param binding only).
437    //
438    // When `bind_via = "options_field"` the bridge is wired to a field on a
439    // configured options struct (e.g. `ConversionOptions.visitor`); there is
440    // no `{prefix}_register_{trait}` / `{prefix}_unregister_{trait}` C
441    // symbol to call. Emitting the shims unconditionally would produce code
442    // that fails to link. Options-field bridges instead consume the C
443    // vtable directly via a small `..._handle_from_vtable` helper (see
444    // below).
445    // -------------------------------------------------------------------------
446    if matches!(bridge_cfg.bind_via, BridgeBinding::FunctionParam) {
447        let c_register = format!("c.{prefix}_register_{snake}");
448        let c_unregister = format!("c.{prefix}_unregister_{snake}");
449
450        out.push_str(&crate::template_env::render(
451            "register_fn_doc1.jinja",
452            minijinja::context! {
453                trait_name => trait_name,
454                snake => &snake,
455            },
456        ));
457        out.push_str(&crate::template_env::render(
458            "register_fn_signature.jinja",
459            minijinja::context! {
460                snake => &snake,
461                trait_name => trait_name,
462            },
463        ));
464        out.push_str(&crate::template_env::render(
465            "register_fn_body.jinja",
466            minijinja::context! {
467                c_register => &c_register,
468            },
469        ));
470        out.push_str("}\n");
471        out.push('\n');
472
473        out.push_str(&crate::template_env::render(
474            "unregister_fn_doc.jinja",
475            minijinja::context! {
476                trait_name => trait_name,
477            },
478        ));
479        out.push_str(&crate::template_env::render(
480            "unregister_fn_signature.jinja",
481            minijinja::context! {
482                snake => &snake,
483            },
484        ));
485        out.push_str(&crate::template_env::render(
486            "unregister_fn_body.jinja",
487            minijinja::context! {
488                c_unregister => &c_unregister,
489            },
490        ));
491        out.push_str("}\n");
492        out.push('\n');
493    } else {
494        // Options-field binding: emit a vtable -> handle helper that wraps the
495        // C callbacks struct into the trait-object handle expected by the
496        // generated `ConversionOptionsBuilder.{field}` setter. The upstream
497        // FFI must export `{prefix}_{trait_snake}_handle_from_callbacks` with
498        // the standard `extern "C" fn(*const T) -> *mut Handle` shape.
499        let ctor_fn = format!("c.{prefix}_{snake}_handle_from_callbacks");
500        let handle_type = bridge_cfg.type_alias.as_deref().unwrap_or("VisitorHandle").to_string();
501        let _ = writeln!(
502            out,
503            "/// Wrap a `I{trait_name}` vtable into a `{handle_type}` suitable for the"
504        );
505        let _ = writeln!(
506            out,
507            "/// generated options-field setter (e.g. `ConversionOptionsBuilder.visitor`)."
508        );
509        let _ = writeln!(
510            out,
511            "/// The returned handle owns the vtable's function pointers and must be"
512        );
513        let _ = writeln!(
514            out,
515            "/// released with the matching `{prefix}_visitor_handle_free` once the"
516        );
517        let _ = writeln!(out, "/// containing options object is no longer needed.");
518        let _ = writeln!(
519            out,
520            "pub fn {snake}_handle_from_vtable(callbacks: c.HTMHtmVisitorCallbacks) ?{handle_type} {{"
521        );
522        let _ = writeln!(out, "    var _cb = callbacks;");
523        let _ = writeln!(out, "    return @ptrCast({ctor_fn}(&_cb));");
524        let _ = writeln!(out, "}}");
525        let _ = writeln!(out);
526    }
527
528    // -------------------------------------------------------------------------
529    // Comptime vtable builder: make_{trait_snake}_vtable
530    // -------------------------------------------------------------------------
531    emit_make_vtable(trait_name, has_super_trait, trait_def, out);
532}
533
534// ---------------------------------------------------------------------------
535// TraitBridgeGenerator implementation for the Zig backend
536// ---------------------------------------------------------------------------
537
/// Zig-specific [`TraitBridgeGenerator`] implementation.
///
/// Carries the FFI symbol prefix (e.g., `"kreuzberg"`) used when deriving the
/// C symbol for `unregister_*` and `clear_*` wrappers.
///
/// The required trait methods that produce *Rust* source (`gen_sync_method_body`,
/// `gen_async_method_body`, `gen_constructor`, `gen_registration_fn`) return
/// empty strings because Zig bridge code is produced by the standalone
/// [`emit_trait_bridge`] free function, not the shared driver.
#[derive(Debug, Clone)]
pub struct ZigTraitBridgeGenerator {
    /// FFI symbol prefix (e.g., `"kreuzberg"`).
    pub prefix: String,
}

impl ZigTraitBridgeGenerator {
    /// Construct a new generator for the given FFI symbol prefix.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, …); the
    /// value is stored unchanged in [`ZigTraitBridgeGenerator::prefix`].
    #[must_use]
    pub fn new(prefix: impl Into<String>) -> Self {
        Self { prefix: prefix.into() }
    }
}
558
559impl TraitBridgeGenerator for ZigTraitBridgeGenerator {
560    // ------------------------------------------------------------------
561    // Stub methods — Zig bridge code is emitted by `emit_trait_bridge`.
562    // ------------------------------------------------------------------
563
564    fn foreign_object_type(&self) -> &str {
565        ""
566    }
567
568    fn bridge_imports(&self) -> Vec<String> {
569        Vec::new()
570    }
571
572    fn gen_sync_method_body(&self, _method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
573        String::new()
574    }
575
576    fn gen_async_method_body(&self, _method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
577        String::new()
578    }
579
580    fn gen_constructor(&self, _spec: &TraitBridgeSpec) -> String {
581        String::new()
582    }
583
584    fn gen_registration_fn(&self, _spec: &TraitBridgeSpec) -> String {
585        String::new()
586    }
587
588    // ------------------------------------------------------------------
589    // Zig-specific overrides
590    // ------------------------------------------------------------------
591
592    /// Emit a Zig wrapper that calls `c.{prefix}_{unregister_fn}(name, out_error)`.
593    ///
594    /// Returns an empty string when `spec.bridge_config.unregister_fn` is `None`.
595    fn gen_unregistration_fn(&self, spec: &TraitBridgeSpec) -> String {
596        let Some(unregister_fn) = spec.bridge_config.unregister_fn.as_deref() else {
597            return String::new();
598        };
599        let c_unregister = format!("c.{}_{}", self.prefix, unregister_fn);
600
601        let mut out = String::new();
602        out.push_str(&crate::template_env::render(
603            "unregister_fn_doc.jinja",
604            minijinja::context! {
605                trait_name => spec.trait_def.name.as_str(),
606            },
607        ));
608        // Emit the signature directly: the configured `unregister_fn` is the
609        // complete Zig function name, not just the trait-snake suffix.
610        out.push_str(&crate::template_env::render(
611            "unregister_fn_configured_signature.jinja",
612            minijinja::context! {
613                unregister_fn => unregister_fn,
614            },
615        ));
616        out.push_str(&crate::template_env::render(
617            "unregister_fn_body.jinja",
618            minijinja::context! {
619                c_unregister => &c_unregister,
620            },
621        ));
622        out.push_str("}\n");
623        out
624    }
625
626    /// Emit a Zig wrapper that calls `c.{prefix}_{clear_fn}(out_error)`.
627    ///
628    /// Returns an empty string when `spec.bridge_config.clear_fn` is `None`.
629    fn gen_clear_fn(&self, spec: &TraitBridgeSpec) -> String {
630        let Some(clear_fn) = spec.bridge_config.clear_fn.as_deref() else {
631            return String::new();
632        };
633        let c_clear = format!("c.{}_{}", self.prefix, clear_fn);
634
635        let mut out = String::new();
636        out.push_str(&crate::template_env::render(
637            "clear_fn_doc.jinja",
638            minijinja::context! {
639                trait_name => spec.trait_def.name.as_str(),
640            },
641        ));
642        out.push_str(&crate::template_env::render(
643            "clear_fn_signature.jinja",
644            minijinja::context! {
645                clear_fn => clear_fn,
646            },
647        ));
648        out.push_str(&crate::template_env::render(
649            "clear_fn_body.jinja",
650            minijinja::context! {
651                c_clear => &c_clear,
652            },
653        ));
654        out.push_str("}\n");
655        out
656    }
657}
658
659#[cfg(test)]
660mod tests {
661    use super::*;
662    use alef_core::ir::{FieldDef, MethodDef, ParamDef, PrimitiveType, ReceiverKind, TypeRef};
663
664    fn make_trait_def(name: &str, methods: Vec<MethodDef>) -> TypeDef {
665        TypeDef {
666            name: name.to_string(),
667            rust_path: format!("demo::{name}"),
668            original_rust_path: String::new(),
669            fields: Vec::<FieldDef>::new(),
670            methods,
671            is_opaque: true,
672            is_clone: false,
673            is_copy: false,
674            is_trait: true,
675            has_default: false,
676            has_stripped_cfg_fields: false,
677            is_return_type: false,
678            serde_rename_all: None,
679            has_serde: false,
680            super_traits: vec![],
681            doc: String::new(),
682            cfg: None,
683            binding_excluded: false,
684            binding_exclusion_reason: None,
685        }
686    }
687
688    fn make_method(name: &str, params: Vec<ParamDef>, return_type: TypeRef, error_type: Option<&str>) -> MethodDef {
689        MethodDef {
690            name: name.to_string(),
691            params,
692            return_type,
693            is_async: false,
694            is_static: false,
695            error_type: error_type.map(|s| s.to_string()),
696            doc: String::new(),
697            receiver: Some(ReceiverKind::Ref),
698            sanitized: false,
699            trait_source: None,
700            returns_ref: false,
701            returns_cow: false,
702            return_newtype_wrapper: None,
703            has_default_impl: false,
704            binding_excluded: false,
705            binding_exclusion_reason: None,
706        }
707    }
708
709    fn make_param(name: &str, ty: TypeRef) -> ParamDef {
710        ParamDef {
711            name: name.to_string(),
712            ty,
713            optional: false,
714            default: None,
715            sanitized: false,
716            typed_default: None,
717            is_ref: false,
718            is_mut: false,
719            newtype_wrapper: None,
720            original_type: None,
721        }
722    }
723
724    fn make_bridge_cfg(trait_name: &str, super_trait: Option<&str>) -> TraitBridgeConfig {
725        TraitBridgeConfig {
726            trait_name: trait_name.to_string(),
727            super_trait: super_trait.map(|s| s.to_string()),
728            registry_getter: None,
729            register_fn: None,
730
731            unregister_fn: None,
732
733            clear_fn: None,
734            type_alias: None,
735            param_name: None,
736            register_extra_args: None,
737            exclude_languages: vec![],
738            bind_via: alef_core::config::BridgeBinding::FunctionParam,
739            options_type: None,
740            options_field: None,
741            context_type: None,
742            result_type: None,
743            ffi_skip_methods: Vec::new(),
744        }
745    }
746
747    #[test]
748    fn single_method_trait_emits_vtable_and_register() {
749        let trait_def = make_trait_def(
750            "Validator",
751            vec![make_method(
752                "validate",
753                vec![make_param("input", TypeRef::String)],
754                TypeRef::Primitive(PrimitiveType::Bool),
755                None,
756            )],
757        );
758        let bridge_cfg = make_bridge_cfg("Validator", None);
759
760        let mut out = String::new();
761        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
762
763        // Vtable struct
764        assert!(
765            out.contains("pub const IValidator = extern struct {"),
766            "missing vtable struct: {out}"
767        );
768        // Method slot present
769        assert!(out.contains("validate:"), "missing validate slot: {out}");
770        // user_data first arg
771        assert!(out.contains("user_data: ?*anyopaque"), "missing user_data: {out}");
772        // callconv(.C) present
773        assert!(out.contains("callconv(.C)"), "missing callconv: {out}");
774        // free_user_data slot
775        assert!(out.contains("free_user_data:"), "missing free_user_data: {out}");
776        // Registration shim
777        assert!(out.contains("pub fn register_validator("), "missing register fn: {out}");
778        assert!(out.contains("c.demo_register_validator("), "wrong C symbol: {out}");
779        // Unregistration shim
780        assert!(
781            out.contains("pub fn unregister_validator("),
782            "missing unregister fn: {out}"
783        );
784        assert!(
785            out.contains("c.demo_unregister_validator("),
786            "wrong unregister C symbol: {out}"
787        );
788        // No plugin lifecycle when no super_trait
789        assert!(
790            !out.contains("name_fn:"),
791            "should not emit name_fn without super_trait: {out}"
792        );
793    }
794
795    #[test]
796    fn multi_method_trait_with_super_trait_emits_lifecycle_slots() {
797        let trait_def = make_trait_def(
798            "OcrBackend",
799            vec![
800                make_method(
801                    "process_image",
802                    vec![
803                        make_param("image_bytes", TypeRef::Bytes),
804                        make_param("config", TypeRef::String),
805                    ],
806                    TypeRef::String,
807                    Some("OcrError"),
808                ),
809                make_method(
810                    "supports_language",
811                    vec![make_param("lang", TypeRef::String)],
812                    TypeRef::Primitive(PrimitiveType::Bool),
813                    None,
814                ),
815            ],
816        );
817        let bridge_cfg = make_bridge_cfg("OcrBackend", Some("kreuzberg::plugins::Plugin"));
818
819        let mut out = String::new();
820        emit_trait_bridge("kreuzberg", &bridge_cfg, &trait_def, &mut out);
821
822        // Struct name
823        assert!(
824            out.contains("pub const IOcrBackend = extern struct {"),
825            "missing vtable: {out}"
826        );
827        // Plugin lifecycle slots emitted
828        assert!(out.contains("name_fn:"), "missing name_fn: {out}");
829        assert!(out.contains("version_fn:"), "missing version_fn: {out}");
830        assert!(out.contains("initialize_fn:"), "missing initialize_fn: {out}");
831        assert!(out.contains("shutdown_fn:"), "missing shutdown_fn: {out}");
832        // Trait method slots
833        assert!(out.contains("process_image:"), "missing process_image slot: {out}");
834        assert!(
835            out.contains("supports_language:"),
836            "missing supports_language slot: {out}"
837        );
838        // Bytes param expands to ptr + len
839        assert!(out.contains("image_bytes_ptr:"), "missing bytes ptr expansion: {out}");
840        assert!(out.contains("image_bytes_len:"), "missing bytes len expansion: {out}");
841        // Fallible method gets out_error
842        assert!(
843            out.contains("out_error:"),
844            "missing out_error for fallible method: {out}"
845        );
846        // C symbols use kreuzberg prefix
847        assert!(
848            out.contains("c.kreuzberg_register_ocr_backend("),
849            "wrong register symbol: {out}"
850        );
851        assert!(
852            out.contains("c.kreuzberg_unregister_ocr_backend("),
853            "wrong unregister symbol: {out}"
854        );
855        // Registration shim signature
856        assert!(
857            out.contains("pub fn register_ocr_backend("),
858            "missing register_ocr_backend fn: {out}"
859        );
860    }
861
862    // -----------------------------------------------------------------
863    // make_*_vtable tests
864    // -----------------------------------------------------------------
865
866    #[test]
867    fn make_vtable_emits_comptime_function_and_thunk() {
868        let trait_def = make_trait_def(
869            "Validator",
870            vec![make_method(
871                "validate",
872                vec![make_param("input", TypeRef::String)],
873                TypeRef::Primitive(PrimitiveType::Bool),
874                None,
875            )],
876        );
877        let bridge_cfg = make_bridge_cfg("Validator", None);
878
879        let mut out = String::new();
880        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
881
882        // Helper function declaration
883        assert!(
884            out.contains("pub fn make_validator_vtable(comptime T: type, instance: *T)"),
885            "missing make_validator_vtable: {out}"
886        );
887        // Returns the vtable type
888        assert!(out.contains("IValidator{"), "missing vtable literal: {out}");
889        // Thunk casts user_data
890        assert!(out.contains("@ptrCast(@alignCast(ud))"), "missing @ptrCast cast: {out}");
891        // callconv(.C) in thunk
892        assert!(out.contains("callconv(.C)"), "missing callconv(.C) in thunk: {out}");
893        // validate thunk field
894        assert!(out.contains(".validate ="), "missing .validate thunk field: {out}");
895        // free_user_data thunk
896        assert!(
897            out.contains(".free_user_data ="),
898            "missing .free_user_data thunk: {out}"
899        );
900        // No lifecycle stubs without super_trait
901        assert!(
902            !out.contains(".name_fn ="),
903            "must not emit .name_fn without super_trait: {out}"
904        );
905    }
906
907    #[test]
908    fn make_vtable_with_super_trait_emits_lifecycle_stubs() {
909        let trait_def = make_trait_def("OcrBackend", vec![]);
910        let bridge_cfg = make_bridge_cfg("OcrBackend", Some("kreuzberg::Plugin"));
911
912        let mut out = String::new();
913        emit_trait_bridge("kreuzberg", &bridge_cfg, &trait_def, &mut out);
914
915        assert!(
916            out.contains("pub fn make_ocr_backend_vtable(comptime T: type, instance: *T)"),
917            "missing make_ocr_backend_vtable: {out}"
918        );
919        assert!(out.contains(".name_fn ="), "missing .name_fn stub: {out}");
920        assert!(out.contains(".version_fn ="), "missing .version_fn stub: {out}");
921        assert!(out.contains(".initialize_fn ="), "missing .initialize_fn stub: {out}");
922        assert!(out.contains(".shutdown_fn ="), "missing .shutdown_fn stub: {out}");
923    }
924
925    #[test]
926    fn make_vtable_bytes_param_reconstructs_slice_in_thunk() {
927        let trait_def = make_trait_def(
928            "Processor",
929            vec![make_method(
930                "process",
931                vec![make_param("data", TypeRef::Bytes)],
932                TypeRef::Unit,
933                None,
934            )],
935        );
936        let bridge_cfg = make_bridge_cfg("Processor", None);
937
938        let mut out = String::new();
939        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
940
941        // Thunk receives ptr+len params
942        assert!(out.contains("data_ptr: [*c]const u8"), "missing data_ptr param: {out}");
943        assert!(out.contains("data_len: usize"), "missing data_len param: {out}");
944        // Thunk reconstructs slice
945        assert!(
946            out.contains("data_ptr[0..data_len]"),
947            "thunk must reconstruct slice from ptr+len: {out}"
948        );
949        // Thunk calls self.process with the slice
950        assert!(
951            out.contains("self.process(data_slice)"),
952            "thunk must call self.process: {out}"
953        );
954    }
955
956    #[test]
957    fn make_vtable_fallible_method_returns_i32_error_code() {
958        let trait_def = make_trait_def(
959            "Parser",
960            vec![make_method("parse", vec![], TypeRef::Unit, Some("ParseError"))],
961        );
962        let bridge_cfg = make_bridge_cfg("Parser", None);
963
964        let mut out = String::new();
965        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
966
967        // Thunk returns i32 (fallible → i32 return)
968        assert!(
969            out.contains("callconv(.C) i32"),
970            "fallible thunk must return i32: {out}"
971        );
972        // Returns 0 on success
973        assert!(out.contains("return 0;"), "must return 0 on success: {out}");
974        // Returns 1 on error
975        assert!(out.contains("return 1;"), "must return 1 on error: {out}");
976        // Error branch writes to out_error
977        assert!(out.contains("out_error"), "must write to out_error: {out}");
978    }
979
980    #[test]
981    fn make_vtable_primitive_return_passes_through() {
982        let trait_def = make_trait_def(
983            "Counter",
984            vec![make_method(
985                "count",
986                vec![],
987                TypeRef::Primitive(PrimitiveType::I32),
988                None,
989            )],
990        );
991        let bridge_cfg = make_bridge_cfg("demo", None);
992
993        let mut out = String::new();
994        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
995
996        // Infallible primitive method: thunk returns the value directly
997        assert!(
998            out.contains("return self.count()"),
999            "primitive return must be forwarded directly: {out}"
1000        );
1001    }
1002
1003    // -----------------------------------------------------------------
1004    // ZigTraitBridgeGenerator tests
1005    // -----------------------------------------------------------------
1006
1007    fn make_spec<'a>(trait_def: &'a TypeDef, bridge_cfg: &'a TraitBridgeConfig) -> TraitBridgeSpec<'a> {
1008        use alef_codegen::generators::trait_bridge::TraitBridgeSpec;
1009        use std::collections::HashMap;
1010        TraitBridgeSpec {
1011            trait_def,
1012            bridge_config: bridge_cfg,
1013            core_import: "kreuzberg",
1014            wrapper_prefix: "Zig",
1015            type_paths: HashMap::new(),
1016            error_type: "KreuzbergError".to_string(),
1017            error_constructor: "KreuzbergError::msg({msg})".to_string(),
1018        }
1019    }
1020
1021    #[test]
1022    fn gen_unregistration_fn_emits_wrapper_when_configured() {
1023        let trait_def = make_trait_def("OcrBackend", vec![]);
1024        let mut bridge_cfg = make_bridge_cfg("OcrBackend", None);
1025        bridge_cfg.unregister_fn = Some("unregister_ocr_backend".to_string());
1026
1027        let generator = ZigTraitBridgeGenerator::new("kreuzberg");
1028        let spec = make_spec(&trait_def, &bridge_cfg);
1029        let out = generator.gen_unregistration_fn(&spec);
1030
1031        assert!(!out.is_empty(), "expected non-empty output when unregister_fn is set");
1032        assert!(
1033            out.contains("pub fn unregister_ocr_backend("),
1034            "wrong function name: {out}"
1035        );
1036        assert!(
1037            out.contains("c.kreuzberg_unregister_ocr_backend("),
1038            "wrong C symbol: {out}"
1039        );
1040        assert!(
1041            out.contains("out_error: ?*?[*c]u8") || out.contains("out_error"),
1042            "missing out_error param: {out}"
1043        );
1044        assert!(out.contains("return "), "missing return statement: {out}");
1045        assert!(out.ends_with("}\n"), "missing closing brace: {out}");
1046    }
1047
1048    #[test]
1049    fn gen_unregistration_fn_returns_empty_when_not_configured() {
1050        let trait_def = make_trait_def("OcrBackend", vec![]);
1051        let bridge_cfg = make_bridge_cfg("OcrBackend", None); // unregister_fn is None
1052
1053        let generator = ZigTraitBridgeGenerator::new("kreuzberg");
1054        let spec = make_spec(&trait_def, &bridge_cfg);
1055        let out = generator.gen_unregistration_fn(&spec);
1056
1057        assert!(
1058            out.is_empty(),
1059            "expected empty output when unregister_fn is None, got: {out}"
1060        );
1061    }
1062
1063    #[test]
1064    fn gen_clear_fn_emits_wrapper_when_configured() {
1065        let trait_def = make_trait_def("OcrBackend", vec![]);
1066        let mut bridge_cfg = make_bridge_cfg("OcrBackend", None);
1067        bridge_cfg.clear_fn = Some("clear_ocr_backends".to_string());
1068
1069        let generator = ZigTraitBridgeGenerator::new("kreuzberg");
1070        let spec = make_spec(&trait_def, &bridge_cfg);
1071        let out = generator.gen_clear_fn(&spec);
1072
1073        assert!(!out.is_empty(), "expected non-empty output when clear_fn is set");
1074        assert!(out.contains("pub fn clear_ocr_backends("), "wrong function name: {out}");
1075        assert!(out.contains("c.kreuzberg_clear_ocr_backends("), "wrong C symbol: {out}");
1076        assert!(
1077            out.contains("out_error: ?*?[*c]u8") || out.contains("out_error"),
1078            "missing out_error param: {out}"
1079        );
1080        assert!(out.contains("return "), "missing return statement: {out}");
1081        assert!(out.ends_with("}\n"), "missing closing brace: {out}");
1082    }
1083
1084    #[test]
1085    fn gen_clear_fn_returns_empty_when_not_configured() {
1086        let trait_def = make_trait_def("OcrBackend", vec![]);
1087        let bridge_cfg = make_bridge_cfg("OcrBackend", None); // clear_fn is None
1088
1089        let generator = ZigTraitBridgeGenerator::new("kreuzberg");
1090        let spec = make_spec(&trait_def, &bridge_cfg);
1091        let out = generator.gen_clear_fn(&spec);
1092
1093        assert!(
1094            out.is_empty(),
1095            "expected empty output when clear_fn is None, got: {out}"
1096        );
1097    }
1098
1099    #[test]
1100    fn gen_unregistration_fn_uses_snake_case_function_name_verbatim() {
1101        // The configured `unregister_fn` name is used as-is (not re-derived from the trait).
1102        let trait_def = make_trait_def("DocumentExtractor", vec![]);
1103        let mut bridge_cfg = make_bridge_cfg("DocumentExtractor", None);
1104        bridge_cfg.unregister_fn = Some("unregister_extractor".to_string());
1105
1106        let generator = ZigTraitBridgeGenerator::new("demo");
1107        let spec = make_spec(&trait_def, &bridge_cfg);
1108        let out = generator.gen_unregistration_fn(&spec);
1109
1110        assert!(
1111            out.contains("pub fn unregister_extractor("),
1112            "must use configured fn name verbatim: {out}"
1113        );
1114        assert!(
1115            out.contains("c.demo_unregister_extractor("),
1116            "must use configured fn name in C symbol: {out}"
1117        );
1118    }
1119
1120    #[test]
1121    fn gen_clear_fn_uses_configured_fn_name_verbatim() {
1122        let trait_def = make_trait_def("DocumentExtractor", vec![]);
1123        let mut bridge_cfg = make_bridge_cfg("DocumentExtractor", None);
1124        bridge_cfg.clear_fn = Some("clear_all_extractors".to_string());
1125
1126        let generator = ZigTraitBridgeGenerator::new("demo");
1127        let spec = make_spec(&trait_def, &bridge_cfg);
1128        let out = generator.gen_clear_fn(&spec);
1129
1130        assert!(
1131            out.contains("pub fn clear_all_extractors("),
1132            "must use configured fn name verbatim: {out}"
1133        );
1134        assert!(
1135            out.contains("c.demo_clear_all_extractors("),
1136            "must use configured fn name in C symbol: {out}"
1137        );
1138    }
1139}