Skip to main content

alef_backend_zig/
trait_bridge.rs

1//! Zig trait-bridge code generation.
2//!
3//! Emits one Zig extern struct (vtable) and one registration wrapper function
4//! per configured `[[trait_bridges]]` entry.  The Zig consumer fills in the
5//! struct with `callconv(.C)` function pointers and calls `register_*`.
6//!
7//! # C symbol convention
8//!
9//! The generated `register_{trait_snake}` shim calls
10//! `c.{prefix}_register_{trait_snake}` — the symbol exposed by the
11//! `kreuzberg-ffi` C layer (pattern: `{crate_prefix}_register_{trait_snake}`).
12//! If the actual symbol differs, override the generated call site.
13//!
14//! # `TraitBridgeGenerator` implementation
15//!
16//! [`ZigTraitBridgeGenerator`] implements the shared [`TraitBridgeGenerator`]
17//! trait so that the shared codegen driver can invoke the Zig-specific
18//! `gen_unregistration_fn` and `gen_clear_fn` overrides.  The other required
19//! methods are stubs — Zig code is produced through the standalone
20//! [`emit_trait_bridge`] free function, not the shared driver.
21
22use alef_codegen::generators::trait_bridge::{TraitBridgeGenerator, TraitBridgeSpec};
23use alef_core::config::{BridgeBinding, TraitBridgeConfig};
24use alef_core::ir::{MethodDef, TypeDef, TypeRef};
25use heck::ToSnakeCase;
26use std::fmt::Write as _;
27
28/// Zig type string to use for a vtable slot parameter or return type.
29///
30/// All string/complex types collapse to `[*c]const u8` (C string pointer) since
31/// the vtable slots use the raw C ABI — not the Zig-friendly wrapper layer.
32fn vtable_param_type(ty: &TypeRef) -> &'static str {
33    match ty {
34        TypeRef::Primitive(p) => {
35            use alef_core::ir::PrimitiveType::*;
36            match p {
37                Bool => "i32",
38                U8 => "u8",
39                U16 => "u16",
40                U32 => "u32",
41                U64 => "u64",
42                I8 => "i8",
43                I16 => "i16",
44                I32 => "i32",
45                I64 => "i64",
46                F32 => "f32",
47                F64 => "f64",
48                Usize => "usize",
49                Isize => "isize",
50            }
51        }
52        TypeRef::Unit => "void",
53        TypeRef::Duration => "i64",
54        // All string/path/complex types become C string pointers at the C ABI boundary.
55        _ => "[*c]const u8",
56    }
57}
58
59/// Zig return type for a vtable slot.
60///
61/// Fallible methods always return `i32` (0 = success, non-zero = error).
62/// Unit infallible methods return `void`.  Other infallible returns use the
63/// primitive mapping.
64fn vtable_return_type(method: &MethodDef) -> String {
65    if method.error_type.is_some() {
66        "i32".to_string()
67    } else {
68        vtable_param_type(&method.return_type).to_string()
69    }
70}
71
72/// Build a snake_case trait name from a PascalCase trait name.
73///
74/// Uses `heck::ToSnakeCase`, matching the pattern used by Go/C# backends.
75fn trait_snake(trait_name: &str) -> String {
76    trait_name.to_snake_case()
77}
78
79/// Emit a Zig param name for the C-ABI slot, expanding `Bytes` to ptr+len.
80///
81/// Returns a list of `(c_param_name, c_param_type)` pairs.
82fn vtable_c_params(method: &MethodDef) -> Vec<(String, String)> {
83    let mut params = vec![("ud".to_string(), "?*anyopaque".to_string())];
84    for p in &method.params {
85        if matches!(p.ty, TypeRef::Bytes) {
86            params.push((format!("{}_ptr", p.name), "[*c]const u8".to_string()));
87            params.push((format!("{}_len", p.name), "usize".to_string()));
88        } else {
89            params.push((p.name.clone(), vtable_param_type(&p.ty).to_string()));
90        }
91    }
92    if method.error_type.is_some() {
93        if !matches!(method.return_type, TypeRef::Unit) {
94            params.push(("out_result".to_string(), "?*?[*c]u8".to_string()));
95        }
96        params.push(("out_error".to_string(), "?*?[*c]u8".to_string()));
97    } else if !matches!(method.return_type, TypeRef::Unit) {
98        params.push(("out_result".to_string(), "?*?[*c]u8".to_string()));
99    }
100    params
101}
102
103/// Emit a `make_{trait_snake}_vtable(comptime T: type, instance: *T) I{Trait}` helper.
104///
105/// The helper builds `callconv(.C)` thunks for every vtable slot so the consumer
106/// only needs to write plain Zig methods on their type.
107///
108/// # Limitations
109///
110/// - Methods returning non-unit values through `out_result` use `unreachable` for
111///   the conversion path when the type cannot be expressed as a direct C primitive
112///   (complex types are documented as requiring manual implementation).
113/// - Lifecycle slots (`name_fn`, `version_fn`, `initialize_fn`, `shutdown_fn`) are
114///   emitted with `unreachable` bodies as stubs — the consumer overrides the
115///   relevant field in the returned vtable if needed.
116pub fn emit_make_vtable(trait_name: &str, has_super_trait: bool, trait_def: &TypeDef, out: &mut String) {
117    let snake = trait_snake(trait_name);
118
119    out.push_str(&crate::template_env::render(
120        "vtable_header_doc.jinja",
121        minijinja::context! {
122            trait_name => trait_name,
123            snake => &snake,
124        },
125    ));
126    out.push_str(&crate::template_env::render(
127        "vtable_impl_method.jinja",
128        minijinja::context! {
129            snake => &snake,
130            trait_name => trait_name,
131        },
132    ));
133    out.push_str(&crate::template_env::render(
134        "vtable_make_fn_header.jinja",
135        minijinja::context! {
136            trait_name => trait_name,
137        },
138    ));
139
140    // Lifecycle stubs when super_trait is present
141    if has_super_trait {
142        out.push_str(&crate::template_env::render(
143            "vtable_field_name_fn.jinja",
144            minijinja::context! {},
145        ));
146        out.push_str(&crate::template_env::render(
147            "vtable_field_version_fn.jinja",
148            minijinja::context! {},
149        ));
150        out.push_str(&crate::template_env::render(
151            "vtable_field_initialize_fn.jinja",
152            minijinja::context! {},
153        ));
154        out.push_str(&crate::template_env::render(
155            "vtable_field_shutdown_fn.jinja",
156            minijinja::context! {},
157        ));
158    }
159
160    // Per-method thunks
161    for method in &trait_def.methods {
162        let method_snake = method.name.to_snake_case();
163        let c_params = vtable_c_params(method);
164        let ret = vtable_return_type(method);
165
166        // Build the thunk parameter list string
167        let params_str = c_params
168            .iter()
169            .map(|(name, ty)| format!("{name}: {ty}"))
170            .collect::<Vec<_>>()
171            .join(", ");
172
173        out.push_str(&crate::template_env::render(
174            "vtable_instance_field.jinja",
175            minijinja::context! {
176                method_snake => &method_snake,
177                params_str => &params_str,
178                ret => &ret,
179            },
180        ));
181
182        // Cast user_data to *T
183        out.push_str("                const self: *T = @ptrCast(@alignCast(ud));\n");
184
185        // Reconstruct Bytes slices and build forwarding arg list
186        let mut call_args: Vec<String> = Vec::new();
187        for p in &method.params {
188            if matches!(p.ty, TypeRef::Bytes) {
189                out.push_str(&crate::template_env::render(
190                    "thunk_bytes_slice.jinja",
191                    minijinja::context! {
192                        slice_name => format!("{}_slice", p.name),
193                        ptr_name => format!("{}_ptr", p.name),
194                        len_name => format!("{}_len", p.name),
195                    },
196                ));
197                call_args.push(format!("{}_slice", p.name));
198            } else {
199                call_args.push(p.name.clone());
200            }
201        }
202
203        let args_str = call_args.join(", ");
204
205        // Pick a capture name for the success branch that won't collide with method
206        // params. Methods can have a param literally called `result`; using that as
207        // the unwrap binding shadows the outer scope (zig 0.16+ flags this).
208        let ok_binding = if method.params.iter().any(|p| p.name == "value") {
209            "ok_value"
210        } else {
211            "value"
212        };
213
214        if method.error_type.is_some() {
215            // Fallible method: call returns error union, write out_result/out_error
216            let has_result_out = !matches!(method.return_type, TypeRef::Unit);
217            out.push_str(&crate::template_env::render(
218                "thunk_fn_signature.jinja",
219                minijinja::context! {
220                    method_snake => &method_snake,
221                    args_str => &args_str,
222                    ok_binding => &ok_binding,
223                },
224            ));
225            // Write result via out_result pointer — for complex types this is unreachable.
226            // `unreachable` diverges, so any code after it (including `return 0;`) would
227            // be flagged "unreachable code" by zig 0.16+; only emit the trailing return
228            // when the success path actually flows through.
229            let mut success_path_diverges = false;
230            if has_result_out {
231                match &method.return_type {
232                    TypeRef::Primitive(_) | TypeRef::Unit => {
233                        out.push_str(&crate::template_env::render(
234                            "thunk_result_assign.jinja",
235                            minijinja::context! {
236                                ok_binding => &ok_binding,
237                            },
238                        ));
239                    }
240                    _ => {
241                        // String/Bytes/complex: cannot safely convert without allocator context
242                        out.push_str(&crate::template_env::render(
243                            "thunk_if_fallible.jinja",
244                            minijinja::context! {
245                                ok_binding => &ok_binding,
246                            },
247                        ));
248                        success_path_diverges = true;
249                    }
250                }
251            } else {
252                // Unit return on success — discard the captured Void to silence unused-variable.
253                out.push_str(&crate::template_env::render(
254                    "thunk_if_ok_result.jinja",
255                    minijinja::context! {
256                        ok_binding => &ok_binding,
257                    },
258                ));
259            }
260            if !success_path_diverges {
261                out.push_str("                    return 0;\n");
262            }
263            out.push_str("                } else |err| {\n");
264            out.push_str("                    _ = err;\n");
265            out.push_str("                    if (out_error) |ptr| ptr.* = null; // caller checks error code\n");
266            out.push_str("                    return 1;\n");
267            out.push_str("                }\n");
268        } else {
269            // Infallible non-Unit methods get an `out_result` param "for uniformity"
270            // (see vtable_c_params), but the body returns the value directly via the
271            // function return type — so the param is unused. Discard it so zig 0.16+
272            // doesn't flag "unused function parameter".
273            if !matches!(method.return_type, TypeRef::Unit) {
274                out.push_str("                _ = out_result;\n");
275            }
276            match &method.return_type {
277                TypeRef::Unit => {
278                    out.push_str(&crate::template_env::render(
279                        "thunk_if_error.jinja",
280                        minijinja::context! {
281                            method_snake => &method_snake,
282                            args_str => &args_str,
283                        },
284                    ));
285                }
286                TypeRef::Primitive(_) => {
287                    out.push_str(&crate::template_env::render(
288                        "thunk_infallible_return.jinja",
289                        minijinja::context! {
290                            method_snake => &method_snake,
291                            args_str => &args_str,
292                        },
293                    ));
294                }
295                _ => {
296                    // Non-unit infallible non-primitive: pass through (e.g., [*c]const u8)
297                    out.push_str(&crate::template_env::render(
298                        "thunk_infallible_return.jinja",
299                        minijinja::context! {
300                            method_snake => &method_snake,
301                            args_str => &args_str,
302                        },
303                    ));
304                }
305            }
306        }
307
308        out.push_str("            }\n");
309        out.push_str("        }.thunk,\n");
310        out.push('\n');
311    }
312
313    // free_user_data stub — does nothing by default; caller overrides if needed
314    out.push_str(&crate::template_env::render(
315        "vtable_free_user_data.jinja",
316        minijinja::context! {},
317    ));
318
319    out.push_str("    };\n");
320    out.push_str("}\n");
321}
322
323/// Emit the vtable extern struct and registration shim for a single trait bridge.
324///
325/// `prefix` is the C FFI prefix (e.g., `"kreuzberg"`).
326/// `bridge_cfg` is the trait bridge configuration entry.
327/// `trait_def` is the IR type definition for the trait (must have `is_trait = true`).
328/// `out` is the output buffer to append to.
329pub fn emit_trait_bridge(prefix: &str, bridge_cfg: &TraitBridgeConfig, trait_def: &TypeDef, out: &mut String) {
330    let trait_name = &trait_def.name;
331    let snake = trait_snake(trait_name);
332    let has_super_trait = bridge_cfg.super_trait.is_some();
333
334    // -------------------------------------------------------------------------
335    // Vtable struct: I{Trait}
336    // -------------------------------------------------------------------------
337    out.push_str(&crate::template_env::render(
338        "trait_vtable_header.jinja",
339        minijinja::context! {
340            trait_name => trait_name,
341            snake => &snake,
342        },
343    ));
344    out.push_str(&crate::template_env::render(
345        "trait_struct_header.jinja",
346        minijinja::context! {
347            trait_name => trait_name,
348        },
349    ));
350
351    // Plugin lifecycle slots — always present when a super_trait is configured.
352    if has_super_trait {
353        out.push_str("    /// Return the plugin name into `out_name` (heap-allocated, caller frees).\n");
354        out.push_str(
355            "    name_fn: ?*const fn (user_data: ?*anyopaque, out_name: ?*?[*c]u8) callconv(.C) void = null,\n",
356        );
357        out.push('\n');
358
359        out.push_str("    /// Return the plugin version into `out_version` (heap-allocated, caller frees).\n");
360        out.push_str(
361            "    version_fn: ?*const fn (user_data: ?*anyopaque, out_version: ?*?[*c]u8) callconv(.C) void = null,\n",
362        );
363        out.push('\n');
364
365        out.push_str("    /// Initialise the plugin; return 0 on success, non-zero on error.\n");
366        out.push_str(
367            "    initialize_fn: ?*const fn (user_data: ?*anyopaque, out_error: ?*?[*c]u8) callconv(.C) i32 = null,\n",
368        );
369        out.push('\n');
370
371        out.push_str("    /// Shut down the plugin; return 0 on success, non-zero on error.\n");
372        out.push_str(
373            "    shutdown_fn: ?*const fn (user_data: ?*anyopaque, out_error: ?*?[*c]u8) callconv(.C) i32 = null,\n",
374        );
375        out.push('\n');
376    }
377
378    // Trait method slots
379    for method in &trait_def.methods {
380        if !method.doc.is_empty() {
381            out.push_str(&crate::template_env::render(
382                "trait_method_doc_lines.jinja",
383                minijinja::context! {
384                    method_doc_lines => method.doc.lines().collect::<Vec<_>>(),
385                },
386            ));
387        }
388
389        let ret = vtable_return_type(method);
390        let method_snake = method.name.to_snake_case();
391
392        // Build the parameter list: user_data first, then method params.
393        let mut params = vec!["user_data: ?*anyopaque".to_string()];
394        for p in &method.params {
395            let ty = vtable_param_type(&p.ty);
396            // Bytes expand to two args (ptr + len)
397            if matches!(p.ty, TypeRef::Bytes) {
398                params.push(format!("{}_ptr: [*c]const u8", p.name));
399                params.push(format!("{}_len: usize", p.name));
400            } else {
401                params.push(format!("{}: {ty}", p.name));
402            }
403        }
404
405        // Fallible methods get out-result and out-error pointers.
406        if method.error_type.is_some() {
407            if !matches!(method.return_type, TypeRef::Unit) {
408                params.push("out_result: ?*?[*c]u8".to_string());
409            }
410            params.push("out_error: ?*?[*c]u8".to_string());
411        } else if !matches!(method.return_type, TypeRef::Unit) {
412            // Infallible non-void: return via out_result too for uniformity
413            params.push("out_result: ?*?[*c]u8".to_string());
414        }
415
416        let params_str = params.join(", ");
417        out.push_str(&crate::template_env::render(
418            "trait_method_signature.jinja",
419            minijinja::context! {
420                method_snake => &method_snake,
421                params_str => &params_str,
422                ret => &ret,
423            },
424        ));
425    }
426
427    // free_user_data — always last; called by Rust Drop to release the Zig-side handle.
428    out.push_str("    /// Called by the Rust runtime when the bridge is dropped.\n");
429    out.push_str("    /// Use this to release any Zig-side state held via `user_data`.\n");
430    out.push_str("    free_user_data: ?*const fn (user_data: ?*anyopaque) callconv(.C) void = null,\n");
431
432    out.push_str("};\n");
433    out.push('\n');
434
435    // -------------------------------------------------------------------------
436    // Registration / unregistration shims (function-param binding only).
437    //
438    // When `bind_via = "options_field"` the bridge is wired to a field on a
439    // configured options struct (e.g. `ConversionOptions.visitor`); there is
440    // no `{prefix}_register_{trait}` / `{prefix}_unregister_{trait}` C
441    // symbol to call. Emitting the shims unconditionally would produce code
442    // that fails to link. Options-field bridges instead consume the C
443    // vtable directly via a small `..._handle_from_vtable` helper (see
444    // below).
445    // -------------------------------------------------------------------------
446    if matches!(bridge_cfg.bind_via, BridgeBinding::FunctionParam) {
447        let c_register = format!("c.{prefix}_register_{snake}");
448        let c_unregister = format!("c.{prefix}_unregister_{snake}");
449
450        out.push_str(&crate::template_env::render(
451            "register_fn_doc1.jinja",
452            minijinja::context! {
453                trait_name => trait_name,
454                snake => &snake,
455            },
456        ));
457        out.push_str(&crate::template_env::render(
458            "register_fn_signature.jinja",
459            minijinja::context! {
460                snake => &snake,
461                trait_name => trait_name,
462            },
463        ));
464        out.push_str(&crate::template_env::render(
465            "register_fn_body.jinja",
466            minijinja::context! {
467                c_register => &c_register,
468            },
469        ));
470        out.push_str("}\n");
471        out.push('\n');
472
473        out.push_str(&crate::template_env::render(
474            "unregister_fn_doc.jinja",
475            minijinja::context! {
476                trait_name => trait_name,
477            },
478        ));
479        out.push_str(&crate::template_env::render(
480            "unregister_fn_signature.jinja",
481            minijinja::context! {
482                snake => &snake,
483            },
484        ));
485        out.push_str(&crate::template_env::render(
486            "unregister_fn_body.jinja",
487            minijinja::context! {
488                c_unregister => &c_unregister,
489            },
490        ));
491        out.push_str("}\n");
492        out.push('\n');
493
494        // ---------------------------------------------------------------
495        // Clear wrapper (registry-wide reset).
496        //
497        // The Zig wrapper is named after `bridge_cfg.clear_fn` verbatim
498        // (e.g. `clear_ocr_backends` — pluralised by convention to signal
499        // multi-removal). The underlying C FFI symbol follows the singular
500        // trait-snake naming used elsewhere in `kreuzberg-ffi`
501        // (`kreuzberg_clear_ocr_backend`), so derive `c_clear` from
502        // `trait_snake` rather than from `clear_fn`.
503        // ---------------------------------------------------------------
504        if let Some(clear_fn) = bridge_cfg.clear_fn.as_deref() {
505            let c_clear = format!("c.{prefix}_clear_{snake}");
506
507            out.push_str(&crate::template_env::render(
508                "clear_fn_doc.jinja",
509                minijinja::context! {
510                    trait_name => trait_name,
511                },
512            ));
513            out.push_str(&crate::template_env::render(
514                "clear_fn_signature.jinja",
515                minijinja::context! {
516                    clear_fn => clear_fn,
517                },
518            ));
519            out.push_str(&crate::template_env::render(
520                "clear_fn_body.jinja",
521                minijinja::context! {
522                    c_clear => &c_clear,
523                },
524            ));
525            out.push_str("}\n");
526            out.push('\n');
527        }
528    } else {
529        // Options-field binding: emit a vtable -> handle helper that wraps the
530        // C callbacks struct into the trait-object handle expected by the
531        // generated `ConversionOptionsBuilder.{field}` setter. The upstream
532        // FFI must export `{prefix}_{trait_snake}_handle_from_callbacks` with
533        // the standard `extern "C" fn(*const T) -> *mut Handle` shape.
534        let ctor_fn = format!("c.{prefix}_{snake}_handle_from_callbacks");
535        let handle_type = bridge_cfg.type_alias.as_deref().unwrap_or("VisitorHandle").to_string();
536        let _ = writeln!(
537            out,
538            "/// Wrap a `I{trait_name}` vtable into a `{handle_type}` suitable for the"
539        );
540        let _ = writeln!(
541            out,
542            "/// generated options-field setter (e.g. `ConversionOptionsBuilder.visitor`)."
543        );
544        let _ = writeln!(
545            out,
546            "/// The returned handle owns the vtable's function pointers and must be"
547        );
548        let _ = writeln!(
549            out,
550            "/// released with the matching `{prefix}_visitor_handle_free` once the"
551        );
552        let _ = writeln!(out, "/// containing options object is no longer needed.");
553        let _ = writeln!(
554            out,
555            "pub fn {snake}_handle_from_vtable(callbacks: c.HTMHtmVisitorCallbacks) ?{handle_type} {{"
556        );
557        let _ = writeln!(out, "    var _cb = callbacks;");
558        let _ = writeln!(out, "    return @ptrCast({ctor_fn}(&_cb));");
559        let _ = writeln!(out, "}}");
560        let _ = writeln!(out);
561    }
562
563    // -------------------------------------------------------------------------
564    // Comptime vtable builder: make_{trait_snake}_vtable
565    // -------------------------------------------------------------------------
566    emit_make_vtable(trait_name, has_super_trait, trait_def, out);
567}
568
569// ---------------------------------------------------------------------------
570// TraitBridgeGenerator implementation for the Zig backend
571// ---------------------------------------------------------------------------
572
/// Zig-specific [`TraitBridgeGenerator`] implementation.
///
/// Carries the FFI symbol prefix (e.g., `"kreuzberg"`) used when deriving the
/// C symbol for `unregister_*` and `clear_*` wrappers.
///
/// The required trait methods that produce *Rust* source (`gen_sync_method_body`,
/// `gen_async_method_body`, `gen_constructor`, `gen_registration_fn`) return
/// empty strings because Zig bridge code is produced by the standalone
/// [`emit_trait_bridge`] free function, not the shared driver.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ZigTraitBridgeGenerator {
    /// FFI symbol prefix (e.g., `"kreuzberg"`).
    pub prefix: String,
}

impl ZigTraitBridgeGenerator {
    /// Construct a new generator for the given FFI symbol prefix.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, …).
    pub fn new(prefix: impl Into<String>) -> Self {
        Self { prefix: prefix.into() }
    }
}
593
594impl TraitBridgeGenerator for ZigTraitBridgeGenerator {
595    // ------------------------------------------------------------------
596    // Stub methods — Zig bridge code is emitted by `emit_trait_bridge`.
597    // ------------------------------------------------------------------
598
599    fn foreign_object_type(&self) -> &str {
600        ""
601    }
602
603    fn bridge_imports(&self) -> Vec<String> {
604        Vec::new()
605    }
606
607    fn gen_sync_method_body(&self, _method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
608        String::new()
609    }
610
611    fn gen_async_method_body(&self, _method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
612        String::new()
613    }
614
615    fn gen_constructor(&self, _spec: &TraitBridgeSpec) -> String {
616        String::new()
617    }
618
619    fn gen_registration_fn(&self, _spec: &TraitBridgeSpec) -> String {
620        String::new()
621    }
622
623    // ------------------------------------------------------------------
624    // Zig-specific overrides
625    // ------------------------------------------------------------------
626
627    /// Emit a Zig wrapper that calls `c.{prefix}_{unregister_fn}(name, out_error)`.
628    ///
629    /// Returns an empty string when `spec.bridge_config.unregister_fn` is `None`.
630    fn gen_unregistration_fn(&self, spec: &TraitBridgeSpec) -> String {
631        let Some(unregister_fn) = spec.bridge_config.unregister_fn.as_deref() else {
632            return String::new();
633        };
634        let c_unregister = format!("c.{}_{}", self.prefix, unregister_fn);
635
636        let mut out = String::new();
637        out.push_str(&crate::template_env::render(
638            "unregister_fn_doc.jinja",
639            minijinja::context! {
640                trait_name => spec.trait_def.name.as_str(),
641            },
642        ));
643        // Emit the signature directly: the configured `unregister_fn` is the
644        // complete Zig function name, not just the trait-snake suffix.
645        out.push_str(&crate::template_env::render(
646            "unregister_fn_configured_signature.jinja",
647            minijinja::context! {
648                unregister_fn => unregister_fn,
649            },
650        ));
651        out.push_str(&crate::template_env::render(
652            "unregister_fn_body.jinja",
653            minijinja::context! {
654                c_unregister => &c_unregister,
655            },
656        ));
657        out.push_str("}\n");
658        out
659    }
660
661    /// Emit a Zig wrapper that calls `c.{prefix}_{clear_fn}(out_error)`.
662    ///
663    /// Returns an empty string when `spec.bridge_config.clear_fn` is `None`.
664    fn gen_clear_fn(&self, spec: &TraitBridgeSpec) -> String {
665        let Some(clear_fn) = spec.bridge_config.clear_fn.as_deref() else {
666            return String::new();
667        };
668        let c_clear = format!("c.{}_{}", self.prefix, clear_fn);
669
670        let mut out = String::new();
671        out.push_str(&crate::template_env::render(
672            "clear_fn_doc.jinja",
673            minijinja::context! {
674                trait_name => spec.trait_def.name.as_str(),
675            },
676        ));
677        out.push_str(&crate::template_env::render(
678            "clear_fn_signature.jinja",
679            minijinja::context! {
680                clear_fn => clear_fn,
681            },
682        ));
683        out.push_str(&crate::template_env::render(
684            "clear_fn_body.jinja",
685            minijinja::context! {
686                c_clear => &c_clear,
687            },
688        ));
689        out.push_str("}\n");
690        out
691    }
692}
693
694#[cfg(test)]
695mod tests {
696    use super::*;
697    use alef_core::ir::{FieldDef, MethodDef, ParamDef, PrimitiveType, ReceiverKind, TypeRef};
698
699    fn make_trait_def(name: &str, methods: Vec<MethodDef>) -> TypeDef {
700        TypeDef {
701            name: name.to_string(),
702            rust_path: format!("demo::{name}"),
703            original_rust_path: String::new(),
704            fields: Vec::<FieldDef>::new(),
705            methods,
706            is_opaque: true,
707            is_clone: false,
708            is_copy: false,
709            is_trait: true,
710            has_default: false,
711            has_stripped_cfg_fields: false,
712            is_return_type: false,
713            serde_rename_all: None,
714            has_serde: false,
715            super_traits: vec![],
716            doc: String::new(),
717            cfg: None,
718            binding_excluded: false,
719            binding_exclusion_reason: None,
720        }
721    }
722
723    fn make_method(name: &str, params: Vec<ParamDef>, return_type: TypeRef, error_type: Option<&str>) -> MethodDef {
724        MethodDef {
725            name: name.to_string(),
726            params,
727            return_type,
728            is_async: false,
729            is_static: false,
730            error_type: error_type.map(|s| s.to_string()),
731            doc: String::new(),
732            receiver: Some(ReceiverKind::Ref),
733            sanitized: false,
734            trait_source: None,
735            returns_ref: false,
736            returns_cow: false,
737            return_newtype_wrapper: None,
738            has_default_impl: false,
739            binding_excluded: false,
740            binding_exclusion_reason: None,
741        }
742    }
743
744    fn make_param(name: &str, ty: TypeRef) -> ParamDef {
745        ParamDef {
746            name: name.to_string(),
747            ty,
748            optional: false,
749            default: None,
750            sanitized: false,
751            typed_default: None,
752            is_ref: false,
753            is_mut: false,
754            newtype_wrapper: None,
755            original_type: None,
756        }
757    }
758
759    fn make_bridge_cfg(trait_name: &str, super_trait: Option<&str>) -> TraitBridgeConfig {
760        TraitBridgeConfig {
761            trait_name: trait_name.to_string(),
762            super_trait: super_trait.map(|s| s.to_string()),
763            registry_getter: None,
764            register_fn: None,
765
766            unregister_fn: None,
767
768            clear_fn: None,
769            type_alias: None,
770            param_name: None,
771            register_extra_args: None,
772            exclude_languages: vec![],
773            bind_via: alef_core::config::BridgeBinding::FunctionParam,
774            options_type: None,
775            options_field: None,
776            context_type: None,
777            result_type: None,
778            ffi_skip_methods: Vec::new(),
779        }
780    }
781
782    #[test]
783    fn single_method_trait_emits_vtable_and_register() {
784        let trait_def = make_trait_def(
785            "Validator",
786            vec![make_method(
787                "validate",
788                vec![make_param("input", TypeRef::String)],
789                TypeRef::Primitive(PrimitiveType::Bool),
790                None,
791            )],
792        );
793        let bridge_cfg = make_bridge_cfg("Validator", None);
794
795        let mut out = String::new();
796        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
797
798        // Vtable struct
799        assert!(
800            out.contains("pub const IValidator = extern struct {"),
801            "missing vtable struct: {out}"
802        );
803        // Method slot present
804        assert!(out.contains("validate:"), "missing validate slot: {out}");
805        // user_data first arg
806        assert!(out.contains("user_data: ?*anyopaque"), "missing user_data: {out}");
807        // callconv(.C) present
808        assert!(out.contains("callconv(.C)"), "missing callconv: {out}");
809        // free_user_data slot
810        assert!(out.contains("free_user_data:"), "missing free_user_data: {out}");
811        // Registration shim
812        assert!(out.contains("pub fn register_validator("), "missing register fn: {out}");
813        assert!(out.contains("c.demo_register_validator("), "wrong C symbol: {out}");
814        // Unregistration shim
815        assert!(
816            out.contains("pub fn unregister_validator("),
817            "missing unregister fn: {out}"
818        );
819        assert!(
820            out.contains("c.demo_unregister_validator("),
821            "wrong unregister C symbol: {out}"
822        );
823        // No plugin lifecycle when no super_trait
824        assert!(
825            !out.contains("name_fn:"),
826            "should not emit name_fn without super_trait: {out}"
827        );
828    }
829
830    #[test]
831    fn emit_trait_bridge_emits_clear_fn_when_configured() {
832        let trait_def = make_trait_def(
833            "OcrBackend",
834            vec![make_method(
835                "process",
836                vec![make_param("input", TypeRef::String)],
837                TypeRef::String,
838                Some("OcrError"),
839            )],
840        );
841        let mut bridge_cfg = make_bridge_cfg("OcrBackend", Some("kreuzberg::plugins::Plugin"));
842        bridge_cfg.clear_fn = Some("clear_ocr_backends".to_string());
843
844        let mut out = String::new();
845        emit_trait_bridge("kreuzberg", &bridge_cfg, &trait_def, &mut out);
846
847        assert!(
848            out.contains("pub fn clear_ocr_backends(out_error: ?*?[*c]u8) i32"),
849            "missing clear_ocr_backends signature: {out}"
850        );
851        // C symbol uses the singular trait-snake suffix to match kreuzberg-ffi naming.
852        assert!(
853            out.contains("c.kreuzberg_clear_ocr_backend(out_error)"),
854            "wrong C symbol target for clear wrapper: {out}"
855        );
856        // Doc comment present.
857        assert!(
858            out.contains("/// Remove ALL registered `OcrBackend` plugins"),
859            "missing clear doc comment: {out}"
860        );
861    }
862
863    #[test]
864    fn emit_trait_bridge_omits_clear_fn_when_not_configured() {
865        let trait_def = make_trait_def(
866            "OcrBackend",
867            vec![make_method(
868                "process",
869                vec![make_param("input", TypeRef::String)],
870                TypeRef::String,
871                Some("OcrError"),
872            )],
873        );
874        let bridge_cfg = make_bridge_cfg("OcrBackend", Some("kreuzberg::plugins::Plugin"));
875        // clear_fn left as None.
876
877        let mut out = String::new();
878        emit_trait_bridge("kreuzberg", &bridge_cfg, &trait_def, &mut out);
879
880        assert!(
881            !out.contains("pub fn clear_"),
882            "should not emit any clear_* fn when clear_fn is None: {out}"
883        );
884    }
885
886    #[test]
887    fn multi_method_trait_with_super_trait_emits_lifecycle_slots() {
888        let trait_def = make_trait_def(
889            "OcrBackend",
890            vec![
891                make_method(
892                    "process_image",
893                    vec![
894                        make_param("image_bytes", TypeRef::Bytes),
895                        make_param("config", TypeRef::String),
896                    ],
897                    TypeRef::String,
898                    Some("OcrError"),
899                ),
900                make_method(
901                    "supports_language",
902                    vec![make_param("lang", TypeRef::String)],
903                    TypeRef::Primitive(PrimitiveType::Bool),
904                    None,
905                ),
906            ],
907        );
908        let bridge_cfg = make_bridge_cfg("OcrBackend", Some("kreuzberg::plugins::Plugin"));
909
910        let mut out = String::new();
911        emit_trait_bridge("kreuzberg", &bridge_cfg, &trait_def, &mut out);
912
913        // Struct name
914        assert!(
915            out.contains("pub const IOcrBackend = extern struct {"),
916            "missing vtable: {out}"
917        );
918        // Plugin lifecycle slots emitted
919        assert!(out.contains("name_fn:"), "missing name_fn: {out}");
920        assert!(out.contains("version_fn:"), "missing version_fn: {out}");
921        assert!(out.contains("initialize_fn:"), "missing initialize_fn: {out}");
922        assert!(out.contains("shutdown_fn:"), "missing shutdown_fn: {out}");
923        // Trait method slots
924        assert!(out.contains("process_image:"), "missing process_image slot: {out}");
925        assert!(
926            out.contains("supports_language:"),
927            "missing supports_language slot: {out}"
928        );
929        // Bytes param expands to ptr + len
930        assert!(out.contains("image_bytes_ptr:"), "missing bytes ptr expansion: {out}");
931        assert!(out.contains("image_bytes_len:"), "missing bytes len expansion: {out}");
932        // Fallible method gets out_error
933        assert!(
934            out.contains("out_error:"),
935            "missing out_error for fallible method: {out}"
936        );
937        // C symbols use kreuzberg prefix
938        assert!(
939            out.contains("c.kreuzberg_register_ocr_backend("),
940            "wrong register symbol: {out}"
941        );
942        assert!(
943            out.contains("c.kreuzberg_unregister_ocr_backend("),
944            "wrong unregister symbol: {out}"
945        );
946        // Registration shim signature
947        assert!(
948            out.contains("pub fn register_ocr_backend("),
949            "missing register_ocr_backend fn: {out}"
950        );
951    }
952
953    // -----------------------------------------------------------------
954    // make_*_vtable tests
955    // -----------------------------------------------------------------
956
957    #[test]
958    fn make_vtable_emits_comptime_function_and_thunk() {
959        let trait_def = make_trait_def(
960            "Validator",
961            vec![make_method(
962                "validate",
963                vec![make_param("input", TypeRef::String)],
964                TypeRef::Primitive(PrimitiveType::Bool),
965                None,
966            )],
967        );
968        let bridge_cfg = make_bridge_cfg("Validator", None);
969
970        let mut out = String::new();
971        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
972
973        // Helper function declaration
974        assert!(
975            out.contains("pub fn make_validator_vtable(comptime T: type, instance: *T)"),
976            "missing make_validator_vtable: {out}"
977        );
978        // Returns the vtable type
979        assert!(out.contains("IValidator{"), "missing vtable literal: {out}");
980        // Thunk casts user_data
981        assert!(out.contains("@ptrCast(@alignCast(ud))"), "missing @ptrCast cast: {out}");
982        // callconv(.C) in thunk
983        assert!(out.contains("callconv(.C)"), "missing callconv(.C) in thunk: {out}");
984        // validate thunk field
985        assert!(out.contains(".validate ="), "missing .validate thunk field: {out}");
986        // free_user_data thunk
987        assert!(
988            out.contains(".free_user_data ="),
989            "missing .free_user_data thunk: {out}"
990        );
991        // No lifecycle stubs without super_trait
992        assert!(
993            !out.contains(".name_fn ="),
994            "must not emit .name_fn without super_trait: {out}"
995        );
996    }
997
998    #[test]
999    fn make_vtable_with_super_trait_emits_lifecycle_stubs() {
1000        let trait_def = make_trait_def("OcrBackend", vec![]);
1001        let bridge_cfg = make_bridge_cfg("OcrBackend", Some("kreuzberg::Plugin"));
1002
1003        let mut out = String::new();
1004        emit_trait_bridge("kreuzberg", &bridge_cfg, &trait_def, &mut out);
1005
1006        assert!(
1007            out.contains("pub fn make_ocr_backend_vtable(comptime T: type, instance: *T)"),
1008            "missing make_ocr_backend_vtable: {out}"
1009        );
1010        assert!(out.contains(".name_fn ="), "missing .name_fn stub: {out}");
1011        assert!(out.contains(".version_fn ="), "missing .version_fn stub: {out}");
1012        assert!(out.contains(".initialize_fn ="), "missing .initialize_fn stub: {out}");
1013        assert!(out.contains(".shutdown_fn ="), "missing .shutdown_fn stub: {out}");
1014    }
1015
1016    #[test]
1017    fn make_vtable_bytes_param_reconstructs_slice_in_thunk() {
1018        let trait_def = make_trait_def(
1019            "Processor",
1020            vec![make_method(
1021                "process",
1022                vec![make_param("data", TypeRef::Bytes)],
1023                TypeRef::Unit,
1024                None,
1025            )],
1026        );
1027        let bridge_cfg = make_bridge_cfg("Processor", None);
1028
1029        let mut out = String::new();
1030        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
1031
1032        // Thunk receives ptr+len params
1033        assert!(out.contains("data_ptr: [*c]const u8"), "missing data_ptr param: {out}");
1034        assert!(out.contains("data_len: usize"), "missing data_len param: {out}");
1035        // Thunk reconstructs slice
1036        assert!(
1037            out.contains("data_ptr[0..data_len]"),
1038            "thunk must reconstruct slice from ptr+len: {out}"
1039        );
1040        // Thunk calls self.process with the slice
1041        assert!(
1042            out.contains("self.process(data_slice)"),
1043            "thunk must call self.process: {out}"
1044        );
1045    }
1046
1047    #[test]
1048    fn make_vtable_fallible_method_returns_i32_error_code() {
1049        let trait_def = make_trait_def(
1050            "Parser",
1051            vec![make_method("parse", vec![], TypeRef::Unit, Some("ParseError"))],
1052        );
1053        let bridge_cfg = make_bridge_cfg("Parser", None);
1054
1055        let mut out = String::new();
1056        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
1057
1058        // Thunk returns i32 (fallible → i32 return)
1059        assert!(
1060            out.contains("callconv(.C) i32"),
1061            "fallible thunk must return i32: {out}"
1062        );
1063        // Returns 0 on success
1064        assert!(out.contains("return 0;"), "must return 0 on success: {out}");
1065        // Returns 1 on error
1066        assert!(out.contains("return 1;"), "must return 1 on error: {out}");
1067        // Error branch writes to out_error
1068        assert!(out.contains("out_error"), "must write to out_error: {out}");
1069    }
1070
1071    #[test]
1072    fn make_vtable_primitive_return_passes_through() {
1073        let trait_def = make_trait_def(
1074            "Counter",
1075            vec![make_method(
1076                "count",
1077                vec![],
1078                TypeRef::Primitive(PrimitiveType::I32),
1079                None,
1080            )],
1081        );
1082        let bridge_cfg = make_bridge_cfg("demo", None);
1083
1084        let mut out = String::new();
1085        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
1086
1087        // Infallible primitive method: thunk returns the value directly
1088        assert!(
1089            out.contains("return self.count()"),
1090            "primitive return must be forwarded directly: {out}"
1091        );
1092    }
1093
1094    // -----------------------------------------------------------------
1095    // ZigTraitBridgeGenerator tests
1096    // -----------------------------------------------------------------
1097
1098    fn make_spec<'a>(trait_def: &'a TypeDef, bridge_cfg: &'a TraitBridgeConfig) -> TraitBridgeSpec<'a> {
1099        use alef_codegen::generators::trait_bridge::TraitBridgeSpec;
1100        use std::collections::HashMap;
1101        TraitBridgeSpec {
1102            trait_def,
1103            bridge_config: bridge_cfg,
1104            core_import: "kreuzberg",
1105            wrapper_prefix: "Zig",
1106            type_paths: HashMap::new(),
1107            error_type: "KreuzbergError".to_string(),
1108            error_constructor: "KreuzbergError::msg({msg})".to_string(),
1109        }
1110    }
1111
1112    #[test]
1113    fn gen_unregistration_fn_emits_wrapper_when_configured() {
1114        let trait_def = make_trait_def("OcrBackend", vec![]);
1115        let mut bridge_cfg = make_bridge_cfg("OcrBackend", None);
1116        bridge_cfg.unregister_fn = Some("unregister_ocr_backend".to_string());
1117
1118        let generator = ZigTraitBridgeGenerator::new("kreuzberg");
1119        let spec = make_spec(&trait_def, &bridge_cfg);
1120        let out = generator.gen_unregistration_fn(&spec);
1121
1122        assert!(!out.is_empty(), "expected non-empty output when unregister_fn is set");
1123        assert!(
1124            out.contains("pub fn unregister_ocr_backend("),
1125            "wrong function name: {out}"
1126        );
1127        assert!(
1128            out.contains("c.kreuzberg_unregister_ocr_backend("),
1129            "wrong C symbol: {out}"
1130        );
1131        assert!(
1132            out.contains("out_error: ?*?[*c]u8") || out.contains("out_error"),
1133            "missing out_error param: {out}"
1134        );
1135        assert!(out.contains("return "), "missing return statement: {out}");
1136        assert!(out.ends_with("}\n"), "missing closing brace: {out}");
1137    }
1138
1139    #[test]
1140    fn gen_unregistration_fn_returns_empty_when_not_configured() {
1141        let trait_def = make_trait_def("OcrBackend", vec![]);
1142        let bridge_cfg = make_bridge_cfg("OcrBackend", None); // unregister_fn is None
1143
1144        let generator = ZigTraitBridgeGenerator::new("kreuzberg");
1145        let spec = make_spec(&trait_def, &bridge_cfg);
1146        let out = generator.gen_unregistration_fn(&spec);
1147
1148        assert!(
1149            out.is_empty(),
1150            "expected empty output when unregister_fn is None, got: {out}"
1151        );
1152    }
1153
1154    #[test]
1155    fn gen_clear_fn_emits_wrapper_when_configured() {
1156        let trait_def = make_trait_def("OcrBackend", vec![]);
1157        let mut bridge_cfg = make_bridge_cfg("OcrBackend", None);
1158        bridge_cfg.clear_fn = Some("clear_ocr_backends".to_string());
1159
1160        let generator = ZigTraitBridgeGenerator::new("kreuzberg");
1161        let spec = make_spec(&trait_def, &bridge_cfg);
1162        let out = generator.gen_clear_fn(&spec);
1163
1164        assert!(!out.is_empty(), "expected non-empty output when clear_fn is set");
1165        assert!(out.contains("pub fn clear_ocr_backends("), "wrong function name: {out}");
1166        assert!(out.contains("c.kreuzberg_clear_ocr_backends("), "wrong C symbol: {out}");
1167        assert!(
1168            out.contains("out_error: ?*?[*c]u8") || out.contains("out_error"),
1169            "missing out_error param: {out}"
1170        );
1171        assert!(out.contains("return "), "missing return statement: {out}");
1172        assert!(out.ends_with("}\n"), "missing closing brace: {out}");
1173    }
1174
1175    #[test]
1176    fn gen_clear_fn_returns_empty_when_not_configured() {
1177        let trait_def = make_trait_def("OcrBackend", vec![]);
1178        let bridge_cfg = make_bridge_cfg("OcrBackend", None); // clear_fn is None
1179
1180        let generator = ZigTraitBridgeGenerator::new("kreuzberg");
1181        let spec = make_spec(&trait_def, &bridge_cfg);
1182        let out = generator.gen_clear_fn(&spec);
1183
1184        assert!(
1185            out.is_empty(),
1186            "expected empty output when clear_fn is None, got: {out}"
1187        );
1188    }
1189
1190    #[test]
1191    fn gen_unregistration_fn_uses_snake_case_function_name_verbatim() {
1192        // The configured `unregister_fn` name is used as-is (not re-derived from the trait).
1193        let trait_def = make_trait_def("DocumentExtractor", vec![]);
1194        let mut bridge_cfg = make_bridge_cfg("DocumentExtractor", None);
1195        bridge_cfg.unregister_fn = Some("unregister_extractor".to_string());
1196
1197        let generator = ZigTraitBridgeGenerator::new("demo");
1198        let spec = make_spec(&trait_def, &bridge_cfg);
1199        let out = generator.gen_unregistration_fn(&spec);
1200
1201        assert!(
1202            out.contains("pub fn unregister_extractor("),
1203            "must use configured fn name verbatim: {out}"
1204        );
1205        assert!(
1206            out.contains("c.demo_unregister_extractor("),
1207            "must use configured fn name in C symbol: {out}"
1208        );
1209    }
1210
1211    #[test]
1212    fn gen_clear_fn_uses_configured_fn_name_verbatim() {
1213        let trait_def = make_trait_def("DocumentExtractor", vec![]);
1214        let mut bridge_cfg = make_bridge_cfg("DocumentExtractor", None);
1215        bridge_cfg.clear_fn = Some("clear_all_extractors".to_string());
1216
1217        let generator = ZigTraitBridgeGenerator::new("demo");
1218        let spec = make_spec(&trait_def, &bridge_cfg);
1219        let out = generator.gen_clear_fn(&spec);
1220
1221        assert!(
1222            out.contains("pub fn clear_all_extractors("),
1223            "must use configured fn name verbatim: {out}"
1224        );
1225        assert!(
1226            out.contains("c.demo_clear_all_extractors("),
1227            "must use configured fn name in C symbol: {out}"
1228        );
1229    }
1230}