use alef_core::config::TraitBridgeConfig;
use alef_core::ir::{MethodDef, TypeDef, TypeRef};
use heck::ToSnakeCase;
/// Maps an IR type to the Zig type used for it in a C-ABI vtable slot.
///
/// Primitives map to the same-width Zig numeric type, except `Bool`, which is
/// passed as `i32`. `Unit` becomes `void`, `Duration` travels as `i64`, and
/// every other type falls back to a C string pointer (`[*c]const u8`).
fn vtable_param_type(ty: &TypeRef) -> &'static str {
    use alef_core::ir::PrimitiveType as Prim;

    let prim = match ty {
        TypeRef::Primitive(prim) => prim,
        TypeRef::Unit => return "void",
        TypeRef::Duration => return "i64",
        _ => return "[*c]const u8",
    };

    // Deliberately exhaustive: a newly added primitive must get an explicit mapping.
    match prim {
        Prim::Bool | Prim::I32 => "i32",
        Prim::U8 => "u8",
        Prim::U16 => "u16",
        Prim::U32 => "u32",
        Prim::U64 => "u64",
        Prim::I8 => "i8",
        Prim::I16 => "i16",
        Prim::I64 => "i64",
        Prim::F32 => "f32",
        Prim::F64 => "f64",
        Prim::Usize => "usize",
        Prim::Isize => "isize",
    }
}
/// Zig return type of the vtable slot for `method`.
///
/// Fallible methods always return an `i32` status code; their payload is
/// delivered through out-parameters. Infallible methods return the mapped
/// type of their declared return type.
fn vtable_return_type(method: &MethodDef) -> String {
    let zig_ty = match method.error_type {
        Some(_) => "i32",
        None => vtable_param_type(&method.return_type),
    };
    zig_ty.to_owned()
}
fn trait_snake(trait_name: &str) -> String {
trait_name.to_snake_case()
}
/// C-ABI parameter list, as `(name, zig_type)` pairs, for the vtable thunk of `method`.
///
/// Layout, in order:
/// - the opaque `ud` pointer;
/// - each declared parameter, where a `Bytes` parameter is split into a
///   `<name>_ptr` / `<name>_len` pair;
/// - `out_result`, for non-`Unit` returns;
/// - `out_error`, for fallible methods.
fn vtable_c_params(method: &MethodDef) -> Vec<(String, String)> {
    let out_slot = || "?*?[*c]u8".to_string();
    let mut list: Vec<(String, String)> = Vec::new();
    list.push(("ud".to_string(), "?*anyopaque".to_string()));

    for param in &method.params {
        match param.ty {
            TypeRef::Bytes => {
                list.push((format!("{}_ptr", param.name), "[*c]const u8".to_string()));
                list.push((format!("{}_len", param.name), "usize".to_string()));
            }
            ref other => list.push((param.name.clone(), vtable_param_type(other).to_string())),
        }
    }

    // `out_result` always precedes `out_error` when both are present.
    if !matches!(method.return_type, TypeRef::Unit) {
        list.push(("out_result".to_string(), out_slot()));
    }
    if method.error_type.is_some() {
        list.push(("out_error".to_string(), out_slot()));
    }
    list
}
pub fn emit_make_vtable(trait_name: &str, has_super_trait: bool, trait_def: &TypeDef, out: &mut String) {
let snake = trait_snake(trait_name);
out.push_str(&crate::template_env::render(
"vtable_header_doc.jinja",
minijinja::context! {
trait_name => trait_name,
snake => &snake,
},
));
out.push_str(&crate::template_env::render(
"vtable_impl_method.jinja",
minijinja::context! {
snake => &snake,
trait_name => trait_name,
},
));
out.push_str(&crate::template_env::render(
"vtable_make_fn_header.jinja",
minijinja::context! {
trait_name => trait_name,
},
));
if has_super_trait {
out.push_str(&crate::template_env::render(
"vtable_field_name_fn.jinja",
minijinja::context! {},
));
out.push_str(&crate::template_env::render(
"vtable_field_version_fn.jinja",
minijinja::context! {},
));
out.push_str(&crate::template_env::render(
"vtable_field_initialize_fn.jinja",
minijinja::context! {},
));
out.push_str(&crate::template_env::render(
"vtable_field_shutdown_fn.jinja",
minijinja::context! {},
));
}
for method in &trait_def.methods {
let method_snake = method.name.to_snake_case();
let c_params = vtable_c_params(method);
let ret = vtable_return_type(method);
let params_str = c_params
.iter()
.map(|(name, ty)| format!("{name}: {ty}"))
.collect::<Vec<_>>()
.join(", ");
out.push_str(&crate::template_env::render(
"vtable_instance_field.jinja",
minijinja::context! {
method_snake => &method_snake,
params_str => ¶ms_str,
ret => &ret,
},
));
out.push_str(" const self: *T = @ptrCast(@alignCast(ud));\n");
let mut call_args: Vec<String> = Vec::new();
for p in &method.params {
if matches!(p.ty, TypeRef::Bytes) {
out.push_str(&crate::template_env::render(
"thunk_bytes_slice.jinja",
minijinja::context! {
slice_name => format!("{}_slice", p.name),
ptr_name => format!("{}_ptr", p.name),
len_name => format!("{}_len", p.name),
},
));
call_args.push(format!("{}_slice", p.name));
} else {
call_args.push(p.name.clone());
}
}
let args_str = call_args.join(", ");
let ok_binding = if method.params.iter().any(|p| p.name == "value") {
"ok_value"
} else {
"value"
};
if method.error_type.is_some() {
let has_result_out = !matches!(method.return_type, TypeRef::Unit);
out.push_str(&crate::template_env::render(
"thunk_fn_signature.jinja",
minijinja::context! {
method_snake => &method_snake,
args_str => &args_str,
ok_binding => &ok_binding,
},
));
let mut success_path_diverges = false;
if has_result_out {
match &method.return_type {
TypeRef::Primitive(_) | TypeRef::Unit => {
out.push_str(&crate::template_env::render(
"thunk_result_assign.jinja",
minijinja::context! {
ok_binding => &ok_binding,
},
));
}
_ => {
out.push_str(&crate::template_env::render(
"thunk_if_fallible.jinja",
minijinja::context! {
ok_binding => &ok_binding,
},
));
success_path_diverges = true;
}
}
} else {
out.push_str(&crate::template_env::render(
"thunk_if_ok_result.jinja",
minijinja::context! {
ok_binding => &ok_binding,
},
));
}
if !success_path_diverges {
out.push_str(" return 0;\n");
}
out.push_str(" } else |err| {\n");
out.push_str(" _ = err;\n");
out.push_str(" if (out_error) |ptr| ptr.* = null; // caller checks error code\n");
out.push_str(" return 1;\n");
out.push_str(" }\n");
} else {
if !matches!(method.return_type, TypeRef::Unit) {
out.push_str(" _ = out_result;\n");
}
match &method.return_type {
TypeRef::Unit => {
out.push_str(&crate::template_env::render(
"thunk_if_error.jinja",
minijinja::context! {
method_snake => &method_snake,
args_str => &args_str,
},
));
}
TypeRef::Primitive(_) => {
out.push_str(&crate::template_env::render(
"thunk_infallible_return.jinja",
minijinja::context! {
method_snake => &method_snake,
args_str => &args_str,
},
));
}
_ => {
out.push_str(&crate::template_env::render(
"thunk_infallible_return.jinja",
minijinja::context! {
method_snake => &method_snake,
args_str => &args_str,
},
));
}
}
}
out.push_str(" }\n");
out.push_str(" }.thunk,\n");
out.push('\n');
}
out.push_str(&crate::template_env::render(
"vtable_free_user_data.jinja",
minijinja::context! {},
));
out.push_str(" };\n");
out.push_str("}\n");
}
pub fn emit_trait_bridge(prefix: &str, bridge_cfg: &TraitBridgeConfig, trait_def: &TypeDef, out: &mut String) {
let trait_name = &trait_def.name;
let snake = trait_snake(trait_name);
let has_super_trait = bridge_cfg.super_trait.is_some();
out.push_str(&crate::template_env::render(
"trait_vtable_header.jinja",
minijinja::context! {
trait_name => trait_name,
snake => &snake,
},
));
out.push_str(&crate::template_env::render(
"trait_struct_header.jinja",
minijinja::context! {
trait_name => trait_name,
},
));
if has_super_trait {
out.push_str(" /// Return the plugin name into `out_name` (heap-allocated, caller frees).\n");
out.push_str(
" name_fn: ?*const fn (user_data: ?*anyopaque, out_name: ?*?[*c]u8) callconv(.C) void = null,\n",
);
out.push('\n');
out.push_str(" /// Return the plugin version into `out_version` (heap-allocated, caller frees).\n");
out.push_str(
" version_fn: ?*const fn (user_data: ?*anyopaque, out_version: ?*?[*c]u8) callconv(.C) void = null,\n",
);
out.push('\n');
out.push_str(" /// Initialise the plugin; return 0 on success, non-zero on error.\n");
out.push_str(
" initialize_fn: ?*const fn (user_data: ?*anyopaque, out_error: ?*?[*c]u8) callconv(.C) i32 = null,\n",
);
out.push('\n');
out.push_str(" /// Shut down the plugin; return 0 on success, non-zero on error.\n");
out.push_str(
" shutdown_fn: ?*const fn (user_data: ?*anyopaque, out_error: ?*?[*c]u8) callconv(.C) i32 = null,\n",
);
out.push('\n');
}
for method in &trait_def.methods {
if !method.doc.is_empty() {
for line in method.doc.lines() {
out.push_str(&crate::template_env::render(
"trait_method_doc.jinja",
minijinja::context! {
line => line,
},
));
}
}
let ret = vtable_return_type(method);
let method_snake = method.name.to_snake_case();
let mut params = vec!["user_data: ?*anyopaque".to_string()];
for p in &method.params {
let ty = vtable_param_type(&p.ty);
if matches!(p.ty, TypeRef::Bytes) {
params.push(format!("{}_ptr: [*c]const u8", p.name));
params.push(format!("{}_len: usize", p.name));
} else {
params.push(format!("{}: {ty}", p.name));
}
}
if method.error_type.is_some() {
if !matches!(method.return_type, TypeRef::Unit) {
params.push("out_result: ?*?[*c]u8".to_string());
}
params.push("out_error: ?*?[*c]u8".to_string());
} else if !matches!(method.return_type, TypeRef::Unit) {
params.push("out_result: ?*?[*c]u8".to_string());
}
let params_str = params.join(", ");
out.push_str(&crate::template_env::render(
"trait_method_signature.jinja",
minijinja::context! {
method_snake => &method_snake,
params_str => ¶ms_str,
ret => &ret,
},
));
}
out.push_str(" /// Called by the Rust runtime when the bridge is dropped.\n");
out.push_str(" /// Use this to release any Zig-side state held via `user_data`.\n");
out.push_str(" free_user_data: ?*const fn (user_data: ?*anyopaque) callconv(.C) void = null,\n");
out.push_str("};\n");
out.push('\n');
let c_register = format!("c.{prefix}_register_{snake}");
let c_unregister = format!("c.{prefix}_unregister_{snake}");
out.push_str(&crate::template_env::render(
"register_fn_doc1.jinja",
minijinja::context! {
trait_name => trait_name,
snake => &snake,
},
));
out.push_str(&crate::template_env::render(
"register_fn_signature.jinja",
minijinja::context! {
snake => &snake,
trait_name => trait_name,
},
));
out.push_str(&crate::template_env::render(
"register_fn_body.jinja",
minijinja::context! {
c_register => &c_register,
},
));
out.push_str("}\n");
out.push('\n');
out.push_str(&crate::template_env::render(
"unregister_fn_doc.jinja",
minijinja::context! {
trait_name => trait_name,
},
));
out.push_str(&crate::template_env::render(
"unregister_fn_signature.jinja",
minijinja::context! {
snake => &snake,
},
));
out.push_str(&crate::template_env::render(
"unregister_fn_body.jinja",
minijinja::context! {
c_unregister => &c_unregister,
},
));
out.push_str("}\n");
out.push('\n');
emit_make_vtable(trait_name, has_super_trait, trait_def, out);
}
#[cfg(test)]
mod tests {
    use super::*;
    use alef_core::ir::{FieldDef, MethodDef, ParamDef, PrimitiveType, ReceiverKind, TypeRef};

    fn make_trait_def(name: &str, methods: Vec<MethodDef>) -> TypeDef {
        TypeDef {
            name: name.to_string(),
            rust_path: format!("demo::{name}"),
            original_rust_path: String::new(),
            fields: Vec::<FieldDef>::new(),
            methods,
            is_opaque: true,
            is_clone: false,
            is_copy: false,
            is_trait: true,
            has_default: false,
            has_stripped_cfg_fields: false,
            is_return_type: false,
            serde_rename_all: None,
            has_serde: false,
            super_traits: vec![],
            doc: String::new(),
            cfg: None,
        }
    }

    fn make_method(
        name: &str,
        params: Vec<ParamDef>,
        return_type: TypeRef,
        error_type: Option<&str>,
    ) -> MethodDef {
        MethodDef {
            name: name.to_string(),
            params,
            return_type,
            is_async: false,
            is_static: false,
            error_type: error_type.map(|s| s.to_string()),
            doc: String::new(),
            receiver: Some(ReceiverKind::Ref),
            sanitized: false,
            trait_source: None,
            returns_ref: false,
            returns_cow: false,
            return_newtype_wrapper: None,
            has_default_impl: false,
        }
    }

    fn make_param(name: &str, ty: TypeRef) -> ParamDef {
        ParamDef {
            name: name.to_string(),
            ty,
            optional: false,
            default: None,
            sanitized: false,
            typed_default: None,
            is_ref: false,
            is_mut: false,
            newtype_wrapper: None,
            original_type: None,
        }
    }

    fn make_bridge_cfg(trait_name: &str, super_trait: Option<&str>) -> TraitBridgeConfig {
        TraitBridgeConfig {
            trait_name: trait_name.to_string(),
            super_trait: super_trait.map(|s| s.to_string()),
            registry_getter: None,
            register_fn: None,
            unregister_fn: None,
            clear_fn: None,
            type_alias: None,
            param_name: None,
            register_extra_args: None,
            exclude_languages: vec![],
            bind_via: alef_core::config::BridgeBinding::FunctionParam,
            options_type: None,
            options_field: None,
        }
    }

    #[test]
    fn vtable_param_type_maps_abi_types() {
        assert_eq!(vtable_param_type(&TypeRef::Primitive(PrimitiveType::Bool)), "i32");
        assert_eq!(vtable_param_type(&TypeRef::Primitive(PrimitiveType::U64)), "u64");
        assert_eq!(vtable_param_type(&TypeRef::Primitive(PrimitiveType::F64)), "f64");
        assert_eq!(vtable_param_type(&TypeRef::Unit), "void");
        assert_eq!(vtable_param_type(&TypeRef::Duration), "i64");
        assert_eq!(vtable_param_type(&TypeRef::String), "[*c]const u8");
    }

    #[test]
    fn vtable_return_type_uses_status_code_for_fallible_methods() {
        let fallible = make_method(
            "parse",
            vec![],
            TypeRef::Primitive(PrimitiveType::U8),
            Some("ParseError"),
        );
        assert_eq!(vtable_return_type(&fallible), "i32");
        let infallible = make_method("count", vec![], TypeRef::Primitive(PrimitiveType::U8), None);
        assert_eq!(vtable_return_type(&infallible), "u8");
    }

    #[test]
    fn vtable_c_params_orders_inputs_then_out_params() {
        let method = make_method(
            "run",
            vec![
                make_param("data", TypeRef::Bytes),
                make_param("n", TypeRef::Primitive(PrimitiveType::U32)),
            ],
            TypeRef::String,
            Some("RunError"),
        );
        let params = vtable_c_params(&method);
        let got: Vec<(&str, &str)> = params.iter().map(|(n, t)| (n.as_str(), t.as_str())).collect();
        assert_eq!(
            got,
            vec![
                ("ud", "?*anyopaque"),
                ("data_ptr", "[*c]const u8"),
                ("data_len", "usize"),
                ("n", "u32"),
                ("out_result", "?*?[*c]u8"),
                ("out_error", "?*?[*c]u8"),
            ]
        );

        let unit = make_method("ping", vec![], TypeRef::Unit, None);
        let names: Vec<String> = vtable_c_params(&unit).into_iter().map(|(n, _)| n).collect();
        assert_eq!(names, vec!["ud".to_string()]);
    }

    #[test]
    fn single_method_trait_emits_vtable_and_register() {
        let trait_def = make_trait_def(
            "Validator",
            vec![make_method(
                "validate",
                vec![make_param("input", TypeRef::String)],
                TypeRef::Primitive(PrimitiveType::Bool),
                None,
            )],
        );
        let bridge_cfg = make_bridge_cfg("Validator", None);
        let mut out = String::new();
        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
        assert!(
            out.contains("pub const IValidator = extern struct {"),
            "missing vtable struct: {out}"
        );
        assert!(out.contains("validate:"), "missing validate slot: {out}");
        assert!(out.contains("user_data: ?*anyopaque"), "missing user_data: {out}");
        assert!(out.contains("callconv(.C)"), "missing callconv: {out}");
        assert!(out.contains("free_user_data:"), "missing free_user_data: {out}");
        assert!(out.contains("pub fn register_validator("), "missing register fn: {out}");
        assert!(out.contains("c.demo_register_validator("), "wrong C symbol: {out}");
        assert!(
            out.contains("pub fn unregister_validator("),
            "missing unregister fn: {out}"
        );
        assert!(
            out.contains("c.demo_unregister_validator("),
            "wrong unregister C symbol: {out}"
        );
        assert!(
            !out.contains("name_fn:"),
            "should not emit name_fn without super_trait: {out}"
        );
    }

    #[test]
    fn multi_method_trait_with_super_trait_emits_lifecycle_slots() {
        let trait_def = make_trait_def(
            "OcrBackend",
            vec![
                make_method(
                    "process_image",
                    vec![
                        make_param("image_bytes", TypeRef::Bytes),
                        make_param("config", TypeRef::String),
                    ],
                    TypeRef::String,
                    Some("OcrError"),
                ),
                make_method(
                    "supports_language",
                    vec![make_param("lang", TypeRef::String)],
                    TypeRef::Primitive(PrimitiveType::Bool),
                    None,
                ),
            ],
        );
        let bridge_cfg = make_bridge_cfg("OcrBackend", Some("kreuzberg::plugins::Plugin"));
        let mut out = String::new();
        emit_trait_bridge("kreuzberg", &bridge_cfg, &trait_def, &mut out);
        assert!(
            out.contains("pub const IOcrBackend = extern struct {"),
            "missing vtable: {out}"
        );
        assert!(out.contains("name_fn:"), "missing name_fn: {out}");
        assert!(out.contains("version_fn:"), "missing version_fn: {out}");
        assert!(out.contains("initialize_fn:"), "missing initialize_fn: {out}");
        assert!(out.contains("shutdown_fn:"), "missing shutdown_fn: {out}");
        assert!(out.contains("process_image:"), "missing process_image slot: {out}");
        assert!(
            out.contains("supports_language:"),
            "missing supports_language slot: {out}"
        );
        assert!(out.contains("image_bytes_ptr:"), "missing bytes ptr expansion: {out}");
        assert!(out.contains("image_bytes_len:"), "missing bytes len expansion: {out}");
        assert!(
            out.contains("out_error:"),
            "missing out_error for fallible method: {out}"
        );
        assert!(
            out.contains("c.kreuzberg_register_ocr_backend("),
            "wrong register symbol: {out}"
        );
        assert!(
            out.contains("c.kreuzberg_unregister_ocr_backend("),
            "wrong unregister symbol: {out}"
        );
        assert!(
            out.contains("pub fn register_ocr_backend("),
            "missing register_ocr_backend fn: {out}"
        );
    }

    #[test]
    fn make_vtable_emits_comptime_function_and_thunk() {
        let trait_def = make_trait_def(
            "Validator",
            vec![make_method(
                "validate",
                vec![make_param("input", TypeRef::String)],
                TypeRef::Primitive(PrimitiveType::Bool),
                None,
            )],
        );
        let bridge_cfg = make_bridge_cfg("Validator", None);
        let mut out = String::new();
        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
        assert!(
            out.contains("pub fn make_validator_vtable(comptime T: type, instance: *T)"),
            "missing make_validator_vtable: {out}"
        );
        assert!(out.contains("IValidator{"), "missing vtable literal: {out}");
        assert!(out.contains("@ptrCast(@alignCast(ud))"), "missing @ptrCast cast: {out}");
        assert!(out.contains("callconv(.C)"), "missing callconv(.C) in thunk: {out}");
        assert!(out.contains(".validate ="), "missing .validate thunk field: {out}");
        assert!(
            out.contains(".free_user_data ="),
            "missing .free_user_data thunk: {out}"
        );
        assert!(
            !out.contains(".name_fn ="),
            "must not emit .name_fn without super_trait: {out}"
        );
    }

    #[test]
    fn make_vtable_with_super_trait_emits_lifecycle_stubs() {
        let trait_def = make_trait_def("OcrBackend", vec![]);
        let bridge_cfg = make_bridge_cfg("OcrBackend", Some("kreuzberg::Plugin"));
        let mut out = String::new();
        emit_trait_bridge("kreuzberg", &bridge_cfg, &trait_def, &mut out);
        assert!(
            out.contains("pub fn make_ocr_backend_vtable(comptime T: type, instance: *T)"),
            "missing make_ocr_backend_vtable: {out}"
        );
        assert!(out.contains(".name_fn ="), "missing .name_fn stub: {out}");
        assert!(out.contains(".version_fn ="), "missing .version_fn stub: {out}");
        assert!(out.contains(".initialize_fn ="), "missing .initialize_fn stub: {out}");
        assert!(out.contains(".shutdown_fn ="), "missing .shutdown_fn stub: {out}");
    }

    #[test]
    fn make_vtable_bytes_param_reconstructs_slice_in_thunk() {
        let trait_def = make_trait_def(
            "Processor",
            vec![make_method(
                "process",
                vec![make_param("data", TypeRef::Bytes)],
                TypeRef::Unit,
                None,
            )],
        );
        let bridge_cfg = make_bridge_cfg("Processor", None);
        let mut out = String::new();
        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
        assert!(out.contains("data_ptr: [*c]const u8"), "missing data_ptr param: {out}");
        assert!(out.contains("data_len: usize"), "missing data_len param: {out}");
        assert!(
            out.contains("data_ptr[0..data_len]"),
            "thunk must reconstruct slice from ptr+len: {out}"
        );
        assert!(
            out.contains("self.process(data_slice)"),
            "thunk must call self.process: {out}"
        );
    }

    #[test]
    fn make_vtable_fallible_method_returns_i32_error_code() {
        let trait_def = make_trait_def(
            "Parser",
            vec![make_method("parse", vec![], TypeRef::Unit, Some("ParseError"))],
        );
        let bridge_cfg = make_bridge_cfg("Parser", None);
        let mut out = String::new();
        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
        assert!(
            out.contains("callconv(.C) i32"),
            "fallible thunk must return i32: {out}"
        );
        assert!(out.contains("return 0;"), "must return 0 on success: {out}");
        assert!(out.contains("return 1;"), "must return 1 on error: {out}");
        assert!(out.contains("out_error"), "must write to out_error: {out}");
    }

    #[test]
    fn make_vtable_primitive_return_passes_through() {
        let trait_def = make_trait_def(
            "Counter",
            vec![make_method(
                "count",
                vec![],
                TypeRef::Primitive(PrimitiveType::I32),
                None,
            )],
        );
        // The bridge config names the same trait as the definition it configures.
        let bridge_cfg = make_bridge_cfg("Counter", None);
        let mut out = String::new();
        emit_trait_bridge("demo", &bridge_cfg, &trait_def, &mut out);
        assert!(
            out.contains("return self.count()"),
            "primitive return must be forwarded directly: {out}"
        );
    }
}