use crate::backends::go::type_map::go_type;
use crate::core::config::TraitBridgeConfig;
use crate::core::hash::{self, CommentStyle};
use crate::core::ir::{ApiSurface, MethodDef, TypeDef, TypeRef};
use heck::ToPascalCase;
use std::collections::HashSet;
/// Returns a copy of `ty` in which every `TypeRef::Named` whose name is listed
/// in `excluded` has been replaced by `TypeRef::Json`.
///
/// `Optional`, `Vec` and `Map` wrappers are rebuilt with their inner types
/// rewritten recursively; every other variant is cloned unchanged.
fn substitute_excluded_types(ty: &TypeRef, excluded: &HashSet<&str>) -> TypeRef {
    // Rewrites a nested type and boxes the result for the wrapper variants.
    let rewrite = |nested: &TypeRef| Box::new(substitute_excluded_types(nested, excluded));

    match ty {
        TypeRef::Named(name) => {
            if excluded.contains(name.as_str()) {
                TypeRef::Json
            } else {
                ty.clone()
            }
        }
        TypeRef::Optional(wrapped) => TypeRef::Optional(rewrite(wrapped)),
        TypeRef::Vec(element) => TypeRef::Vec(rewrite(element)),
        TypeRef::Map(key, value) => TypeRef::Map(rewrite(key), rewrite(value)),
        _ => ty.clone(),
    }
}
/// Clones `method` and passes each parameter type and the return type through
/// [`substitute_excluded_types`], so excluded named types become `TypeRef::Json`.
///
/// The input method is left untouched; the rewritten clone is returned.
fn method_with_excluded_substituted(method: &MethodDef, excluded: &HashSet<&str>) -> MethodDef {
    let mut rewritten = method.clone();
    rewritten
        .params
        .iter_mut()
        .for_each(|param| param.ty = substitute_excluded_types(&param.ty, excluded));
    rewritten.return_type = substitute_excluded_types(&rewritten.return_type, excluded);
    rewritten
}
/// Generates the Go source file that holds every trait bridge for `api`.
///
/// The output is assembled in this order:
/// 1. the generated-file header and the `package_and_cgo.jinja` preamble;
/// 2. inside the cgo preamble, one extern trampoline declaration per bridged
///    trait method plus the lifecycle and free-user-data declarations;
/// 3. the closing `*/`, `import "C"` and the Go import list;
/// 4. the Go bridge for each trait, produced by [`gen_trait_bridge`].
///
/// A bridge is emitted only when its trait is present in `api.types` and its
/// `exclude_languages` does not contain `"go"`. The same filtered list drives
/// both the preamble declarations and the Go code, so a bridge excluded from
/// Go no longer leaves unused extern declarations in the preamble.
///
/// Named types listed in `api.excluded_type_paths` or flagged
/// `binding_excluded` are substituted with `TypeRef::Json` in signatures.
#[allow(clippy::too_many_arguments)] // each argument is an independent piece of generator configuration
pub fn gen_trait_bridges_file(
    api: &ApiSurface,
    config: &crate::core::config::ResolvedCrateConfig,
    pkg_name: &str,
    ffi_prefix: &str,
    ffi_header: &str,
    ffi_crate_dir: &str,
    to_root: &str,
    crate_name: &str,
) -> String {
    use crate::backends::go::template_env::render;

    // Lifecycle methods every bridged trait exposes in addition to its own.
    const LIFECYCLE_METHODS: [&str; 4] = ["Name", "Version", "Initialize", "Shutdown"];

    let mut out = String::with_capacity(16_384);

    let excluded_named_types: HashSet<&str> = api
        .excluded_type_paths
        .keys()
        .map(|s| s.as_str())
        .chain(api.types.iter().filter(|t| t.binding_excluded).map(|t| t.name.as_str()))
        .collect();

    // Resolve once which bridges apply to Go, paired with their trait
    // definition, so the preamble and the Go body can never disagree.
    let go_bridges: Vec<_> = config
        .trait_bridges
        .iter()
        .filter(|cfg| !cfg.exclude_languages.iter().any(|lang| lang == "go"))
        .filter_map(|cfg| {
            api.types
                .iter()
                .find(|t| t.name == cfg.trait_name)
                .map(|trait_def| (cfg, trait_def))
        })
        .collect();

    out.push_str(&hash::header(CommentStyle::DoubleSlash));
    out.push_str(&render(
        "package_and_cgo.jinja",
        minijinja::context! {
            pkg_name => pkg_name,
            to_root => to_root,
            ffi_crate_dir => ffi_crate_dir,
            ffi_header => ffi_header,
        },
    ));
    out.push('\n');

    // Extern declarations that live inside the cgo preamble comment.
    for (bridge_cfg, trait_def) in &go_bridges {
        let pascal = bridge_cfg.trait_name.to_pascal_case();
        for method in trait_def
            .methods
            .iter()
            .filter(|m| !bridge_cfg.ffi_skip_methods.contains(&m.name))
        {
            let export_name = format!("go{}{}", &pascal, method.name.to_pascal_case());
            let method_substituted = method_with_excluded_substituted(method, &excluded_named_types);
            let c_sig = c_trampoline_signature(&export_name, &method_substituted);
            out.push_str(&render(
                "extern_trampoline_decl.jinja",
                minijinja::context! {
                    export_name => export_name,
                    c_sig => c_sig,
                },
            ));
        }
        for lifecycle in LIFECYCLE_METHODS {
            out.push_str(&render(
                "plugin_trampoline_decl.jinja",
                minijinja::context! {
                    pascal => &pascal,
                    method => lifecycle,
                },
            ));
        }
        out.push_str(&render(
            "plugin_free_user_data_extern.jinja",
            minijinja::context! {
                pascal => &pascal,
            },
        ));
    }

    // Close the cgo preamble and emit the Go imports.
    out.push_str("*/\n");
    out.push_str("import \"C\"\n");
    out.push('\n');
    out.push_str("import (\n");
    out.push_str("\t\"encoding/base64\"\n");
    out.push_str("\t\"encoding/json\"\n");
    out.push_str("\t\"fmt\"\n");
    out.push_str("\t\"runtime/cgo\"\n");
    out.push_str("\t\"unsafe\"\n");
    out.push_str(")\n");
    out.push('\n');

    // Go-side interface, trampolines and registration functions per bridge.
    for (bridge_cfg, trait_def) in &go_bridges {
        gen_trait_bridge(
            &mut out,
            trait_def,
            bridge_cfg,
            ffi_prefix,
            crate_name,
            &excluded_named_types,
        );
        out.push('\n');
    }

    out
}
fn gen_trait_bridge(
out: &mut String,
trait_def: &TypeDef,
bridge_cfg: &TraitBridgeConfig,
ffi_prefix: &str,
crate_name: &str,
excluded_named_types: &HashSet<&str>,
) {
let trait_name = &trait_def.name;
let trait_snake = heck::AsSnakeCase(trait_name).to_string();
let trait_pascal = trait_name.to_pascal_case();
let crate_normalized = crate_name.replace('-', "_");
let crate_upper = crate_normalized.to_uppercase();
let crate_pascal = crate_normalized.to_pascal_case();
let c_vtable_struct = format!("{}{}{}{}", crate_upper, crate_pascal, trait_pascal, "VTable");
out.push_str(&crate::backends::go::template_env::render(
"trait_interface_header.jinja",
minijinja::context! {
name => trait_name,
},
));
out.push_str(&crate::backends::go::template_env::render(
"plugin_method_signature.jinja",
minijinja::context! {
doc => "Name returns the plugin name.",
method => "Name",
return_type => "string",
},
));
out.push_str(&crate::backends::go::template_env::render(
"plugin_method_signature.jinja",
minijinja::context! {
doc => "Version returns the plugin version.",
method => "Version",
return_type => "string",
},
));
out.push_str(&crate::backends::go::template_env::render(
"plugin_method_signature.jinja",
minijinja::context! {
doc => "Initialize is called when the plugin is loaded.",
method => "Initialize",
return_type => "error",
},
));
out.push_str(&crate::backends::go::template_env::render(
"plugin_method_signature.jinja",
minijinja::context! {
doc => "Shutdown is called when the plugin is unloaded.",
method => "Shutdown",
return_type => "error",
},
));
for method in trait_def
.methods
.iter()
.filter(|m| !bridge_cfg.ffi_skip_methods.contains(&m.name))
{
let method_substituted = method_with_excluded_substituted(method, excluded_named_types);
gen_interface_method(out, &method_substituted);
}
out.push_str("}\n");
out.push('\n');
for method in trait_def
.methods
.iter()
.filter(|m| !bridge_cfg.ffi_skip_methods.contains(&m.name))
{
let export_name = format!("go{}{}", &trait_pascal, method.name.to_pascal_case());
out.push_str(&crate::backends::go::template_env::render(
"export_marker.jinja",
minijinja::context! {
name => &export_name,
},
));
out.push('\n');
let method_substituted = method_with_excluded_substituted(method, excluded_named_types);
gen_trampoline(out, trait_name, &trait_pascal, &method_substituted);
}
gen_plugin_trampolines(out, trait_name, &trait_pascal);
out.push_str(&crate::backends::go::template_env::render(
"register_function_header.jinja",
minijinja::context! {
name => trait_name,
},
));
out.push_str(&crate::backends::go::template_env::render(
"vtable_struct_init.jinja",
minijinja::context! {
c_vtable_struct => &c_vtable_struct,
},
));
for method in trait_def
.methods
.iter()
.filter(|m| !bridge_cfg.ffi_skip_methods.contains(&m.name))
{
let export_name = format!("go{}{}", &trait_pascal, method.name.to_pascal_case());
out.push_str(&crate::backends::go::template_env::render(
"register_vtable_method_field.jinja",
minijinja::context! {
method_name => &method.name,
export_name => export_name,
},
));
}
out.push_str(&crate::backends::go::template_env::render(
"plugin_trampoline_lifecycle.jinja",
minijinja::context! {
field => "name_fn",
pascal => &trait_pascal,
method => "Name",
},
));
out.push_str(&crate::backends::go::template_env::render(
"plugin_trampoline_lifecycle.jinja",
minijinja::context! {
field => "version_fn",
pascal => &trait_pascal,
method => "Version",
},
));
out.push_str(&crate::backends::go::template_env::render(
"plugin_trampoline_lifecycle.jinja",
minijinja::context! {
field => "initialize_fn",
pascal => &trait_pascal,
method => "Initialize",
},
));
out.push_str(&crate::backends::go::template_env::render(
"plugin_trampoline_lifecycle.jinja",
minijinja::context! {
field => "shutdown_fn",
pascal => &trait_pascal,
method => "Shutdown",
},
));
out.push_str(&crate::backends::go::template_env::render(
"vtable_free_user_data_field.jinja",
minijinja::context! {
pascal => &trait_pascal,
},
));
out.push_str("\t}\n");
out.push('\n');
out.push_str(&crate::backends::go::template_env::render(
"register_c_call.jinja",
minijinja::context! {
c_function => format!("{}_register_{}", ffi_prefix, trait_snake),
trait_name => trait_name,
},
));
out.push_str("}\n");
out.push('\n');
out.push_str(&crate::backends::go::template_env::render(
"unregister_function_header.jinja",
minijinja::context! {
name => trait_name,
},
));
out.push_str(&crate::backends::go::template_env::render(
"unregister_c_call.jinja",
minijinja::context! {
c_function => format!("{}_unregister_{}", ffi_prefix, trait_snake),
trait_name => trait_name,
},
));
out.push_str("}\n");
let unregister_block = gen_unregistration_fn(bridge_cfg, ffi_prefix, trait_name);
if !unregister_block.is_empty() {
out.push('\n');
out.push_str(&unregister_block);
}
let clear_block = gen_clear_fn(bridge_cfg, ffi_prefix, trait_name);
if !clear_block.is_empty() {
out.push('\n');
out.push_str(&clear_block);
}
}
/// Renders a custom-named Go unregister wrapper for `trait_name`.
///
/// Returns an empty string when `bridge_cfg.unregister_fn` is `None`, or when
/// the configured name equals the snake_case form of `Unregister<TraitName>`.
/// Otherwise the wrapper is named with the PascalCase form of the configured
/// name and calls `<ffi_prefix>_unregister_<trait_snake>`.
fn gen_unregistration_fn(bridge_cfg: &TraitBridgeConfig, ffi_prefix: &str, trait_name: &str) -> String {
    use crate::backends::go::template_env::render;

    let fn_name = match bridge_cfg.unregister_fn.as_deref() {
        Some(configured) => configured,
        None => return String::new(),
    };

    // Skip the wrapper when the configured name is just the standard one.
    let standard_snake_name = heck::AsSnakeCase(format!("Unregister{}", trait_name)).to_string();
    if fn_name == standard_snake_name {
        return String::new();
    }

    let header = render(
        "unregister_fn_header.jinja",
        minijinja::context! {
            fn_name => fn_name.to_pascal_case(),
            trait_name => trait_name,
        },
    );
    let c_call = render(
        "unregister_c_call.jinja",
        minijinja::context! {
            c_function => format!("{}_unregister_{}", ffi_prefix, heck::AsSnakeCase(trait_name)),
            trait_name => trait_name,
        },
    );
    format!("{header}{c_call}}}\n")
}
/// Renders the Go wrapper that clears all registered implementations of
/// `trait_name`.
///
/// Returns an empty string when `bridge_cfg.clear_fn` is `None`. Otherwise the
/// wrapper is named with the PascalCase form of the configured name and calls
/// `<ffi_prefix>_clear_<trait_snake>`.
fn gen_clear_fn(bridge_cfg: &TraitBridgeConfig, ffi_prefix: &str, trait_name: &str) -> String {
    use crate::backends::go::template_env::render;

    let fn_name = match bridge_cfg.clear_fn.as_deref() {
        Some(configured) => configured,
        None => return String::new(),
    };

    let header = render(
        "clear_function_header.jinja",
        minijinja::context! {
            fn_name => fn_name.to_pascal_case(),
            name => trait_name,
        },
    );
    let c_call = render(
        "clear_c_call.jinja",
        minijinja::context! {
            c_function => format!("{}_clear_{}", ffi_prefix, heck::AsSnakeCase(trait_name)),
            trait_name => trait_name,
        },
    );
    format!("{header}{c_call}}}\n")
}
fn gen_interface_method(out: &mut String, method: &MethodDef) {
let mut params = Vec::new();
for p in &method.params {
let go_type = rust_to_go_type(&p.ty);
params.push(format!("{} {}", p.name, go_type));
}
let return_type = if method.error_type.is_some() {
match &method.return_type {
TypeRef::Unit => "error".to_string(),
_ => {
let ret = rust_to_go_type(&method.return_type);
format!("({}, error)", ret)
}
}
} else {
rust_to_go_type(&method.return_type)
};
let params_str = params.join(", ");
out.push_str(&crate::backends::go::template_env::render(
"trait_interface_method.jinja",
minijinja::context! {
doc => &method.name,
method_name => method.name.to_pascal_case(),
params => params_str,
return_type => return_type,
},
));
out.push('\n');
}
/// Appends the Go trampoline for one bridged trait method to `out`.
///
/// The generated function takes `userData`, one C-typed parameter per method
/// parameter, an `outResult` slot when the method returns a value, and an
/// `outError` slot. Its body:
/// 1. recovers the `cgo.Handle` from `userData` and type-asserts it to
///    `trait_name`, returning 1 when the assertion fails;
/// 2. converts each C parameter to its Go form via [`gen_param_conversion`];
/// 3. calls the implementation and, for methods with an error type, writes the
///    error message to `outError` and returns 1 on failure;
/// 4. JSON-encodes a non-unit result into `outResult`;
/// 5. returns 0.
///
/// A failing `json.Marshal` is reported through `outError` with return code 1.
/// The previous output discarded that error and reported success with an
/// empty result string.
fn gen_trampoline(out: &mut String, trait_name: &str, trait_pascal: &str, method: &MethodDef) {
    use crate::backends::go::template_env::render;

    let method_pascal = method.name.to_pascal_case();
    let export_name = format!("go{}{}", trait_pascal, method_pascal);
    let returns_value = !matches!(method.return_type, TypeRef::Unit);
    let fallible = method.error_type.is_some();

    // Trampoline parameter list: handle, method params, then output slots.
    let mut params = vec!["userData unsafe.Pointer".to_string()];
    params.extend(
        method
            .params
            .iter()
            .map(|p| format!("{} {}", p.name, rust_to_c_type(&p.ty))),
    );
    if returns_value {
        params.push("outResult **C.char".to_string());
    }
    params.push("outError **C.char".to_string());

    out.push_str(&render(
        "trampoline_signature.jinja",
        minijinja::context! {
            name => export_name,
            params => params,
        },
    ));
    out.push('\n');

    // Recover the Go implementation behind the opaque handle.
    out.push_str("\thandle := cgo.Handle(uintptr(unsafe.Pointer(userData)))\n");
    out.push_str(&render(
        "handle_type_assertion.jinja",
        minijinja::context! {
            type_name => trait_name,
        },
    ));
    out.push('\n');
    out.push_str("\tif !ok {\n");
    out.push_str("\t\treturn 1 // error: invalid handle\n");
    out.push_str("\t}\n");
    out.push('\n');

    for p in &method.params {
        gen_param_conversion(out, p);
    }
    let call_args = method
        .params
        .iter()
        .map(|p| format!("go{}", capitalize(&p.name)))
        .collect::<Vec<_>>()
        .join(", ");

    out.push_str("\t// Call the method\n");
    // NOTE(review): an infallible method with a unit return still uses the
    // "result" call template, as before — confirm the template handles that.
    let call_template = match (fallible, returns_value) {
        (true, false) => "impl_method_call_err.jinja",
        (true, true) => "impl_method_call_result_err.jinja",
        (false, _) => "impl_method_call_result.jinja",
    };
    out.push_str(&render(
        call_template,
        minijinja::context! {
            method => method_pascal,
            args => call_args,
        },
    ));
    out.push('\n');

    if fallible {
        out.push_str("\tif err != nil {\n");
        out.push_str("\t\tcErr := C.CString(err.Error())\n");
        out.push_str("\t\t*outError = cErr\n");
        out.push_str("\t\treturn 1\n");
        out.push_str("\t}\n");
    }

    if returns_value {
        // `marshalErr` is a fresh name so the declaration is valid whether or
        // not the call template already introduced `err`.
        out.push_str("\tjsonBytes, marshalErr := json.Marshal(result)\n");
        out.push_str("\tif marshalErr != nil {\n");
        out.push_str("\t\tcErr := C.CString(marshalErr.Error())\n");
        out.push_str("\t\t*outError = cErr\n");
        out.push_str("\t\treturn 1\n");
        out.push_str("\t}\n");
        out.push_str("\tcResult := C.CString(string(jsonBytes))\n");
        out.push_str("\t*outResult = cResult\n");
    }

    out.push_str("\treturn 0 // success\n");
    out.push_str("}\n");
    out.push('\n');
}
/// Appends the fixed plugin trampolines for a bridged trait to `out`.
///
/// Five functions are written, each preceded by an export marker: `Name`,
/// `Version`, `Initialize`, `Shutdown` and `FreeUserData`. The first four
/// recover the `cgo.Handle` from `userData`, type-assert it to `trait_name`
/// (returning 1 on failure) and then run a method-specific body: `Name` and
/// `Version` copy the returned string into `outResult`, while `Initialize`
/// and `Shutdown` copy a non-nil error message into `outError` and return 1.
/// `FreeUserData` deletes the handle.
fn gen_plugin_trampolines(out: &mut String, trait_name: &str, trait_pascal: &str) {
    use crate::backends::go::template_env::render;

    const RESULT_PARAMS: &str = "userData unsafe.Pointer, outResult **C.char, outError **C.char";
    const ERROR_PARAMS: &str = "userData unsafe.Pointer, outError **C.char";

    // (Go method, trampoline parameter list, body between the handle check
    // and the final `return 0`).
    let trampolines: [(&str, &str, &str); 4] = [
        (
            "Name",
            RESULT_PARAMS,
            "\tname := impl.Name()\n\tcName := C.CString(name)\n\t*outResult = cName\n",
        ),
        (
            "Version",
            RESULT_PARAMS,
            "\tversion := impl.Version()\n\tcVersion := C.CString(version)\n\t*outResult = cVersion\n",
        ),
        (
            "Initialize",
            ERROR_PARAMS,
            "\terr := impl.Initialize()\n\tif err != nil {\n\t\tcErr := C.CString(err.Error())\n\t\t*outError = cErr\n\t\treturn 1\n\t}\n",
        ),
        (
            "Shutdown",
            ERROR_PARAMS,
            "\terr := impl.Shutdown()\n\tif err != nil {\n\t\tcErr := C.CString(err.Error())\n\t\t*outError = cErr\n\t\treturn 1\n\t}\n",
        ),
    ];

    for (method, params, body) in trampolines {
        out.push_str(&render(
            "export_marker.jinja",
            minijinja::context! {
                name => format!("go{trait_pascal}{method}"),
            },
        ));
        out.push('\n');
        out.push_str(&render(
            "plugin_method_trampoline_header.jinja",
            minijinja::context! {
                pascal => trait_pascal,
                method => method,
                params => params,
            },
        ));
        out.push('\n');
        out.push_str("\thandle := cgo.Handle(uintptr(unsafe.Pointer(userData)))\n");
        out.push_str(&render(
            "handle_type_assertion.jinja",
            minijinja::context! {
                type_name => trait_name,
            },
        ));
        out.push('\n');
        out.push_str("\tif !ok {\n\t\treturn 1\n\t}\n");
        out.push_str(body);
        out.push_str("\treturn 0\n}\n");
        out.push('\n');
    }

    // The callback that deletes the handle.
    out.push_str(&render(
        "export_marker.jinja",
        minijinja::context! {
            name => format!("go{trait_pascal}FreeUserData"),
        },
    ));
    out.push('\n');
    out.push_str(&render(
        "plugin_free_user_data_func.jinja",
        minijinja::context! {
            pascal => trait_pascal,
        },
    ));
    out.push('\n');
    out.push_str("\tcgo.Handle(uintptr(unsafe.Pointer(userData))).Delete()\n");
    out.push_str("}\n");
    out.push('\n');
}
/// Builds the C parameter list of a trampoline declaration for `method`.
///
/// The list starts with `void* user_data`, continues with one
/// `<c type> <name>` entry per method parameter, adds `char** out_result`
/// when the return type is not unit, and always ends with `char** out_error`.
/// Entries are joined with `", "`. `_export_name` is not used.
fn c_trampoline_signature(_export_name: &str, method: &MethodDef) -> String {
    let user_data = std::iter::once("void* user_data".to_string());
    let method_params = method
        .params
        .iter()
        .map(|p| format!("{} {}", rust_to_plain_c_type(&p.ty), p.name));
    let result_slot = if matches!(method.return_type, TypeRef::Unit) {
        None
    } else {
        Some("char** out_result".to_string())
    };
    let error_slot = std::iter::once("char** out_error".to_string());

    user_data
        .chain(method_params)
        .chain(result_slot)
        .chain(error_slot)
        .collect::<Vec<_>>()
        .join(", ")
}
/// Maps an IR type to the plain C type name used in trampoline declarations.
///
/// Primitives map to fixed-width C types (`Bool` is carried as `int32_t`),
/// `Bytes` to `uint8_t*`, `Unit` to `void` and `Duration` to `uint64_t`.
/// Every other type, including strings, paths, optionals, collections and
/// named types, is passed as `char*`.
fn rust_to_plain_c_type(ty: &TypeRef) -> String {
    use crate::core::ir::PrimitiveType as Prim;

    let c_name = match ty {
        TypeRef::Primitive(prim) => match prim {
            Prim::Bool | Prim::I32 => "int32_t",
            Prim::U8 => "uint8_t",
            Prim::U16 => "uint16_t",
            Prim::U32 => "uint32_t",
            Prim::U64 => "uint64_t",
            Prim::I8 => "int8_t",
            Prim::I16 => "int16_t",
            Prim::I64 => "int64_t",
            Prim::F32 => "float",
            Prim::F64 => "double",
            Prim::Usize => "size_t",
            Prim::Isize => "intptr_t",
        },
        TypeRef::Bytes => "uint8_t*",
        TypeRef::Unit => "void",
        TypeRef::Duration => "uint64_t",
        // String, Char, Path, Optional, Vec, Map, Named and any other
        // variant all cross the boundary as a C string.
        _ => "char*",
    };
    c_name.to_owned()
}
/// Returns the Go type name for `ty` as an owned `String`, delegating the
/// mapping itself to `go_type`.
fn rust_to_go_type(ty: &TypeRef) -> String {
    String::from(go_type(ty))
}
/// Maps an IR type to the cgo type name (`C.`-qualified) used in the Go
/// trampoline signature.
///
/// Primitives map to fixed-width cgo types (`Bool` is carried as
/// `C.int32_t`), `Bytes` to `*C.uint8_t`, `Unit` to `C.void` and `Duration`
/// to `C.uint64_t`. Every other type, including strings, paths, optionals,
/// collections and named types, is passed as `*C.char`.
fn rust_to_c_type(ty: &TypeRef) -> String {
    use crate::core::ir::PrimitiveType as Prim;

    let cgo_name = match ty {
        TypeRef::Primitive(prim) => match prim {
            Prim::Bool | Prim::I32 => "C.int32_t",
            Prim::U8 => "C.uint8_t",
            Prim::U16 => "C.uint16_t",
            Prim::U32 => "C.uint32_t",
            Prim::U64 => "C.uint64_t",
            Prim::I8 => "C.int8_t",
            Prim::I16 => "C.int16_t",
            Prim::I64 => "C.int64_t",
            Prim::F32 => "C.float",
            Prim::F64 => "C.double",
            Prim::Usize => "C.size_t",
            Prim::Isize => "C.intptr_t",
        },
        TypeRef::Bytes => "*C.uint8_t",
        TypeRef::Unit => "C.void",
        TypeRef::Duration => "C.uint64_t",
        // String, Char, Path, Optional, Vec, Map, Named and any other
        // variant all cross the boundary as a C string.
        _ => "*C.char",
    };
    cgo_name.to_owned()
}
fn gen_param_conversion(out: &mut String, param: &crate::core::ir::ParamDef) {
let var_name = format!("go{}", capitalize(¶m.name));
match ¶m.ty {
TypeRef::String | TypeRef::Char | TypeRef::Path => {
out.push_str(&crate::backends::go::template_env::render(
"go_string_cast.jinja",
minijinja::context! {
name => capitalize(¶m.name),
param => param.name.as_str(),
},
));
out.push('\n');
}
TypeRef::Bytes => {
out.push_str(&crate::backends::go::template_env::render(
"var_bytes_decl.jinja",
minijinja::context! {
var_name => &var_name,
},
));
out.push_str(&crate::backends::go::template_env::render(
"if_nil_check.jinja",
minijinja::context! {
param => param.name.as_str(),
},
));
out.push_str("\t\tvar b64str string\n");
out.push_str(&crate::backends::go::template_env::render(
"json_unmarshal_unsafe.jinja",
minijinja::context! {
param => param.name.as_str(),
},
));
out.push('\n');
out.push_str("\t\tif decoded, err := base64.StdEncoding.DecodeString(b64str); err == nil {\n");
out.push_str(&crate::backends::go::template_env::render(
"var_assign.jinja",
minijinja::context! {
var => &var_name,
expr => "decoded",
},
));
out.push_str("\t\t}\n");
out.push_str("\t}\n");
out.push('\n');
}
TypeRef::Vec(_) => {
let go_type = rust_to_go_type(¶m.ty);
out.push_str(&crate::backends::go::template_env::render(
"var_type_decl.jinja",
minijinja::context! {
var_name => &var_name,
type_name => &go_type,
},
));
out.push_str(&crate::backends::go::template_env::render(
"if_nil_check.jinja",
minijinja::context! {
param => param.name.as_str(),
},
));
out.push_str(&crate::backends::go::template_env::render(
"json_unmarshal_simple.jinja",
minijinja::context! {
param => param.name.as_str(),
var_name => &var_name,
},
));
out.push('\n');
out.push_str("\t}\n");
out.push('\n');
}
TypeRef::Named(_) => {
let go_type = rust_to_go_type(¶m.ty);
out.push_str(&crate::backends::go::template_env::render(
"var_type_decl.jinja",
minijinja::context! {
var_name => &var_name,
type_name => &go_type,
},
));
out.push_str(&crate::backends::go::template_env::render(
"if_nil_check.jinja",
minijinja::context! {
param => param.name.as_str(),
},
));
out.push_str(&crate::backends::go::template_env::render(
"json_unmarshal_simple.jinja",
minijinja::context! {
param => param.name.as_str(),
var_name => &var_name,
},
));
out.push('\n');
out.push_str("\t}\n");
out.push('\n');
}
TypeRef::Map(_, _) => {
let go_type = rust_to_go_type(¶m.ty);
out.push_str(&crate::backends::go::template_env::render(
"var_type_decl.jinja",
minijinja::context! {
var_name => &var_name,
type_name => &go_type,
},
));
out.push_str(&crate::backends::go::template_env::render(
"if_nil_check.jinja",
minijinja::context! {
param => param.name.as_str(),
},
));
out.push_str("\t\tvar rawData interface{}\n");
out.push_str(&crate::backends::go::template_env::render(
"json_unmarshal_rawdata.jinja",
minijinja::context! {
param => param.name.as_str(),
},
));
out.push('\n');
out.push_str("\t\tif m, ok := rawData.(map[string]interface{}); ok {\n");
out.push_str(&crate::backends::go::template_env::render(
"var_assign_m.jinja",
minijinja::context! {
var => &var_name,
},
));
out.push('\n');
out.push_str("\t\t}\n");
out.push_str("\t}\n");
out.push('\n');
}
TypeRef::Optional(_) => {
let go_type = rust_to_go_type(¶m.ty);
out.push_str(&crate::backends::go::template_env::render(
"var_type_decl.jinja",
minijinja::context! {
var_name => &var_name,
type_name => &go_type,
},
));
out.push_str(&crate::backends::go::template_env::render(
"if_nil_check.jinja",
minijinja::context! {
param => param.name.as_str(),
},
));
out.push_str("\t\tvar rawData interface{}\n");
out.push_str(&crate::backends::go::template_env::render(
"json_unmarshal_rawdata.jinja",
minijinja::context! {
param => param.name.as_str(),
},
));
out.push('\n');
out.push_str("\t\tif m, ok := rawData.(map[string]interface{}); ok {\n");
out.push_str(&crate::backends::go::template_env::render(
"var_assign_m.jinja",
minijinja::context! {
var => &var_name,
},
));
out.push('\n');
out.push_str("\t\t}\n");
out.push_str("\t}\n");
out.push('\n');
}
TypeRef::Json => {
out.push_str(&format!("\tvar {var_name} json.RawMessage\n"));
out.push_str(&crate::backends::go::template_env::render(
"if_nil_check.jinja",
minijinja::context! {
param => param.name.as_str(),
},
));
out.push_str(&format!(
"\t\t{var_name} = json.RawMessage(C.GoString({}))\n",
param.name
));
out.push_str("\t}\n");
out.push('\n');
}
TypeRef::Primitive(p) => {
use crate::core::ir::PrimitiveType::*;
let cast = match p {
Bool => format!("{} != 0", param.name),
_ => {
let go_type = match p {
U8 => "uint8",
U16 => "uint16",
U32 => "uint32",
U64 => "uint64",
I8 => "int8",
I16 => "int16",
I32 => "int32",
I64 => "int64",
F32 => "float32",
F64 => "float64",
Usize => "uint",
Isize => "int",
_ => "",
};
format!("{}({})", go_type, param.name)
}
};
out.push_str(&crate::backends::go::template_env::render(
"var_assign_cast.jinja",
minijinja::context! {
var_name => &var_name,
cast => &cast,
},
));
out.push('\n');
out.push('\n');
}
_ => {
out.push_str(&crate::backends::go::template_env::render(
"var_assign_cast.jinja",
minijinja::context! {
var_name => &var_name,
cast => param.name.as_str(),
},
));
out.push('\n');
out.push('\n');
}
}
}
/// Returns `s` with its first character upper-cased and the rest unchanged.
///
/// An empty input yields an empty string. The first character may expand to
/// several characters when upper-cased (for example `ß` becomes `SS`).
fn capitalize(s: &str) -> String {
    let mut rest = s.chars();
    let Some(first) = rest.next() else {
        return String::new();
    };
    let mut capitalized = String::with_capacity(s.len());
    capitalized.extend(first.to_uppercase());
    capitalized.push_str(rest.as_str());
    capitalized
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Composes a vtable struct name the same way the tests always have:
    /// upper-cased crate, PascalCase crate, PascalCase trait, then `VTable`.
    fn vtable_struct_name(crate_name: &str, trait_name: &str) -> String {
        format!(
            "{}{}{}{}",
            crate_name.to_uppercase(),
            crate_name.to_pascal_case(),
            trait_name.to_pascal_case(),
            "VTable"
        )
    }

    /// Bridge config for `OcrBackend` with the given optional wrapper names.
    fn ocr_bridge(unregister_fn: Option<&str>, clear_fn: Option<&str>) -> TraitBridgeConfig {
        TraitBridgeConfig {
            trait_name: "OcrBackend".to_string(),
            unregister_fn: unregister_fn.map(str::to_string),
            clear_fn: clear_fn.map(str::to_string),
            ..Default::default()
        }
    }

    /// Builds the excluded-type set from a list of names.
    fn excluded_set<'a>(names: &[&'a str]) -> HashSet<&'a str> {
        names.iter().copied().collect()
    }

    /// Shorthand for a `TypeRef::Named` with an owned name.
    fn named(name: &str) -> TypeRef {
        TypeRef::Named(name.to_string())
    }

    #[test]
    fn test_vtable_struct_name_derivation() {
        let c_vtable_struct = vtable_struct_name("kreuzberg", "OcrBackend");
        assert_eq!(c_vtable_struct, "KREUZBERGKreuzbergOcrBackendVTable");
    }

    #[test]
    fn test_register_function_name_format() {
        let register_fn = format!("{}_register_{}", "kreuzberg", heck::AsSnakeCase("OcrBackend"));
        assert_eq!(register_fn, "kreuzberg_register_ocr_backend");
    }

    #[test]
    fn test_unregister_function_name_format() {
        let unregister_fn = format!("{}_unregister_{}", "kreuzberg", heck::AsSnakeCase("PostProcessor"));
        assert_eq!(unregister_fn, "kreuzberg_unregister_post_processor");
    }

    #[test]
    fn test_vtable_struct_name_multiple_traits() {
        let test_cases = [
            ("kreuzberg", "OcrBackend", "KREUZBERGKreuzbergOcrBackendVTable"),
            ("kreuzberg", "PostProcessor", "KREUZBERGKreuzbergPostProcessorVTable"),
            ("kreuzberg", "Validator", "KREUZBERGKreuzbergValidatorVTable"),
            (
                "kreuzberg",
                "EmbeddingBackend",
                "KREUZBERGKreuzbergEmbeddingBackendVTable",
            ),
        ];
        for (crate_name, trait_name, expected_struct) in test_cases {
            assert_eq!(
                vtable_struct_name(crate_name, trait_name),
                expected_struct,
                "Mismatch for crate={}, trait={}",
                crate_name,
                trait_name
            );
        }
    }

    #[test]
    fn gen_unregistration_fn_returns_empty_when_none() {
        let result = gen_unregistration_fn(&ocr_bridge(None, None), "kreuzberg", "OcrBackend");
        assert!(result.is_empty(), "expected empty output when unregister_fn is None");
    }

    #[test]
    fn gen_unregistration_fn_emits_wrapper_when_set() {
        let cfg = ocr_bridge(Some("remove_ocr_backend"), None);
        let result = gen_unregistration_fn(&cfg, "kreuzberg", "OcrBackend");
        assert!(
            !result.is_empty(),
            "expected non-empty output when unregister_fn is set"
        );
        assert!(
            result.contains("func RemoveOcrBackend(name string) error"),
            "generated function signature not found in:\n{result}"
        );
        assert!(
            result.contains("C.kreuzberg_unregister_ocr_backend"),
            "C call not found in:\n{result}"
        );
    }

    #[test]
    fn gen_clear_fn_returns_empty_when_none() {
        let result = gen_clear_fn(&ocr_bridge(None, None), "kreuzberg", "OcrBackend");
        assert!(result.is_empty(), "expected empty output when clear_fn is None");
    }

    #[test]
    fn gen_clear_fn_emits_wrapper_when_set() {
        let cfg = ocr_bridge(None, Some("clear_ocr_backends"));
        let result = gen_clear_fn(&cfg, "kreuzberg", "OcrBackend");
        assert!(!result.is_empty(), "expected non-empty output when clear_fn is set");
        assert!(
            result.contains("func ClearOcrBackends() error"),
            "generated function signature not found in:\n{result}"
        );
        assert!(
            result.contains("C.kreuzberg_clear_ocr_backend"),
            "C call not found in:\n{result}"
        );
    }

    #[test]
    fn substitute_excluded_types_replaces_excluded_named_with_json() {
        let excluded = excluded_set(&["InternalDocument"]);
        let result = substitute_excluded_types(&named("InternalDocument"), &excluded);
        assert!(matches!(result, TypeRef::Json), "expected Json, got {:?}", result);
    }

    #[test]
    fn substitute_excluded_types_leaves_non_excluded_named_intact() {
        let excluded = excluded_set(&[]);
        match substitute_excluded_types(&named("ExtractionConfig"), &excluded) {
            TypeRef::Named(ref n) => assert_eq!(n, "ExtractionConfig"),
            other => panic!("expected Named, got {:?}", other),
        }
    }

    #[test]
    fn substitute_excluded_types_recurses_into_optional_vec_map() {
        let excluded = excluded_set(&["X", "Y", "Z"]);

        let optional = TypeRef::Optional(Box::new(named("X")));
        match substitute_excluded_types(&optional, &excluded) {
            TypeRef::Optional(inner) => assert!(matches!(*inner, TypeRef::Json)),
            other => panic!("expected Optional<Json>, got {:?}", other),
        }

        let list = TypeRef::Vec(Box::new(named("Y")));
        match substitute_excluded_types(&list, &excluded) {
            TypeRef::Vec(inner) => assert!(matches!(*inner, TypeRef::Json)),
            other => panic!("expected Vec<Json>, got {:?}", other),
        }

        let map = TypeRef::Map(Box::new(TypeRef::String), Box::new(named("Z")));
        match substitute_excluded_types(&map, &excluded) {
            TypeRef::Map(k, v) => {
                assert!(matches!(*k, TypeRef::String));
                assert!(matches!(*v, TypeRef::Json));
            }
            other => panic!("expected Map<String, Json>, got {:?}", other),
        }
    }

    #[test]
    fn substitute_excluded_types_passes_through_primitives_and_other_atoms() {
        let excluded = excluded_set(&[]);
        assert!(matches!(
            substitute_excluded_types(&TypeRef::String, &excluded),
            TypeRef::String
        ));
        assert!(matches!(
            substitute_excluded_types(&TypeRef::Bytes, &excluded),
            TypeRef::Bytes
        ));
        assert!(matches!(
            substitute_excluded_types(&TypeRef::Unit, &excluded),
            TypeRef::Unit
        ));
    }
}