use heck::{ToPascalCase, ToSnakeCase};
use std::fmt::Write;
/// Generates the Java (Panama FFM) side of a plugin trait bridge.
///
/// The returned source contains, in order:
/// 1. `public interface {Trait}`: the methods a Java implementation provides.
/// 2. `final class {Trait}Bridge`: allocates one upcall stub per vtable slot and lays them out
///    as a C vtable. Lifecycle slots come first when `has_super_trait` is set, then
///    `trait_methods` in order, then a NULL `free_user_data` slot.
/// 3. A static registry map plus `register{Trait}` / `unregister{Trait}`. These call
///    `NativeLib.{PREFIX}_REGISTER_{TRAIT}` / `NativeLib.{PREFIX}_UNREGISTER_{TRAIT}`.
///
/// NOTE(review): the registry and the register/unregister functions use `private static` /
/// `public static` modifiers. The output therefore appears to be spliced into an enclosing
/// Java class body; confirm this against the caller.
///
/// # Parameters
/// - `trait_name`: the trait name in any case. It is converted to PascalCase for Java type
///   names and to SCREAMING_SNAKE_CASE for constants.
/// - `prefix`: the library prefix, upper-cased to form the `NativeLib` handle names.
/// - `trait_methods`: `(java_method_name, java_return_type)` pairs. Results of non-`void` /
///   `Void` methods are serialized to JSON with Jackson before being handed to native code.
/// - `has_super_trait`: when true, adds the `name` / `version` / `initialize` / `shutdown`
///   lifecycle methods and vtable slots. When false, the interface has no `name()`, so
///   `register{Trait}` takes the registration name as an explicit first parameter.
pub fn gen_trait_bridge(
    trait_name: &str,
    prefix: &str,
    trait_methods: &[(&str, &str)],
    has_super_trait: bool,
) -> String {
    let names = TraitBridgeNames {
        pascal: trait_name.to_pascal_case(),
        upper_snake: trait_name.to_snake_case().to_uppercase(),
        prefix_upper: prefix.to_uppercase(),
    };
    let mut out = String::with_capacity(4096);
    emit_trait_bridge_interface(&mut out, &names.pascal, trait_methods, has_super_trait);
    writeln!(out).ok();
    emit_trait_bridge_class(&mut out, &names.pascal, trait_methods, has_super_trait);
    writeln!(out).ok();
    emit_trait_bridge_registry(&mut out, &names, has_super_trait);
    out
}

/// Lifecycle slots contributed by the plugin super-trait, in vtable order.
///
/// Each entry is `(handler suffix, interface method, returns a C int status)`. Entries that
/// do not return an int status return a C string pointer (`MemorySegment`).
const TRAIT_BRIDGE_LIFECYCLE_SLOTS: [(&str, &str, bool); 4] = [
    ("Name", "name", false),
    ("Version", "version", false),
    ("Initialize", "initialize", true),
    ("Shutdown", "shutdown", true),
];

/// Identifiers derived once from the trait name and the library prefix.
struct TraitBridgeNames {
    /// Java type name, e.g. `MyPlugin`.
    pascal: String,
    /// Constant-style trait name, e.g. `MY_PLUGIN`.
    upper_snake: String,
    /// Upper-cased library prefix used in `NativeLib` handle names.
    prefix_upper: String,
}

/// Returns true for Java return types whose value is discarded instead of JSON-serialized.
fn trait_bridge_is_void(return_type: &str) -> bool {
    matches!(return_type, "void" | "Void")
}

/// Emits `public interface {Trait}` with the lifecycle methods (if any) followed by one
/// zero-argument method per entry of `trait_methods`.
fn emit_trait_bridge_interface(
    out: &mut String,
    pascal: &str,
    trait_methods: &[(&str, &str)],
    has_super_trait: bool,
) {
    writeln!(out, "/**").ok();
    writeln!(out, " * Bridge trait for {pascal} plugin system.").ok();
    writeln!(out, " *").ok();
    writeln!(out, " * Implementations provide methods that are called via upcall stubs").ok();
    writeln!(out, " * into the C vtable during registration.").ok();
    writeln!(out, " */").ok();
    writeln!(out, "public interface {pascal} {{").ok();
    writeln!(out).ok();
    if has_super_trait {
        for line in [
            "    /** Return the plugin name. */",
            "    String name();",
            "",
            "    /** Return the plugin version. */",
            "    String version();",
            "",
            "    /** Initialize the plugin. */",
            "    void initialize() throws Exception;",
            "",
            "    /** Shut down the plugin. */",
            "    void shutdown() throws Exception;",
            "",
        ] {
            out.push_str(line);
            out.push('\n');
        }
    }
    for (method_name, return_type) in trait_methods {
        writeln!(out, "    /** Trait method: {method_name}. */").ok();
        writeln!(out, "    {return_type} {method_name}();").ok();
        writeln!(out).ok();
    }
    writeln!(out, "}}").ok();
}

/// Emits `final class {Trait}Bridge`, an `AutoCloseable` owner of one arena.
///
/// That arena holds the vtable, its upcall stubs, and every string handed back to native
/// code. `close()` releases all of them at once.
fn emit_trait_bridge_class(
    out: &mut String,
    pascal: &str,
    trait_methods: &[(&str, &str)],
    has_super_trait: bool,
) {
    let lifecycle: &[(&str, &str, bool)] = if has_super_trait {
        &TRAIT_BRIDGE_LIFECYCLE_SLOTS
    } else {
        &[]
    };
    let super_slots = lifecycle.len();
    let trait_slots = trait_methods.len();
    // One extra slot for the trailing free_user_data callback (left NULL).
    let vtable_fields = super_slots + trait_slots + 1;
    // Jackson is only referenced when at least one handler serializes a result.
    let needs_json = trait_methods
        .iter()
        .any(|(_, return_type)| !trait_bridge_is_void(return_type));

    writeln!(out, "/**").ok();
    writeln!(out, " * Allocates Panama FFM upcall stubs for a {pascal} trait implementation").ok();
    writeln!(out, " * and assembles the C vtable in native memory.").ok();
    writeln!(out, " */").ok();
    writeln!(out, "final class {pascal}Bridge implements AutoCloseable {{").ok();
    writeln!(out).ok();
    writeln!(out, "    private static final Linker LINKER = Linker.nativeLinker();").ok();
    writeln!(out, "    private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();").ok();
    if needs_json {
        // Build one mapper for the whole class, not one per upcall; ObjectMapper is
        // thread-safe once configured.
        writeln!(out, "    private static final com.fasterxml.jackson.databind.ObjectMapper JSON_MAPPER =").ok();
        writeln!(out, "            new com.fasterxml.jackson.databind.ObjectMapper();").ok();
    }
    writeln!(out).ok();
    writeln!(
        out,
        "    // C vtable: {vtable_fields} fields ({super_slots} plugin methods + {trait_slots} trait methods + free_user_data)"
    )
    .ok();
    writeln!(
        out,
        "    private static final long VTABLE_SIZE = (long) ValueLayout.ADDRESS.byteSize() * {vtable_fields}L;"
    )
    .ok();
    writeln!(out).ok();
    writeln!(out, "    private final Arena arena;").ok();
    writeln!(out, "    private final MemorySegment vtable;").ok();
    writeln!(out, "    private final {pascal} impl;").ok();
    writeln!(out).ok();
    writeln!(out, "    {pascal}Bridge(final {pascal} impl) {{").ok();
    writeln!(out, "        this.impl = impl;").ok();
    writeln!(out, "        // Shared, not confined: native code may call the upcalls (which allocate in this").ok();
    writeln!(out, "        // arena) from any thread, and close() runs on whichever thread unregisters.").ok();
    writeln!(out, "        this.arena = Arena.ofShared();").ok();
    writeln!(out, "        this.vtable = arena.allocate(VTABLE_SIZE);").ok();
    writeln!(out).ok();
    writeln!(out, "        try {{").ok();
    writeln!(out, "            long offset = 0L;").ok();
    writeln!(out).ok();
    for &(suffix, _, returns_int) in lifecycle {
        emit_trait_bridge_stub(out, suffix, returns_int);
    }
    for (method_name, _) in trait_methods {
        emit_trait_bridge_stub(out, &method_name.to_pascal_case(), false);
    }
    writeln!(out, "            vtable.set(ValueLayout.ADDRESS, offset, MemorySegment.NULL);").ok();
    writeln!(out).ok();
    // Only LOOKUP.bind throws ReflectiveOperationException, and javac rejects catching a
    // checked exception the try body cannot throw.
    let caught = if super_slots + trait_slots > 0 {
        "ReflectiveOperationException | RuntimeException"
    } else {
        "RuntimeException"
    };
    writeln!(out, "        }} catch ({caught} e) {{").ok();
    writeln!(out, "            arena.close();").ok();
    writeln!(out, "            throw new RuntimeException(\"Failed to create trait bridge stubs\", e);").ok();
    writeln!(out, "        }}").ok();
    writeln!(out, "    }}").ok();
    writeln!(out).ok();
    writeln!(out, "    MemorySegment vtableSegment() {{").ok();
    writeln!(out, "        return vtable;").ok();
    writeln!(out, "    }}").ok();
    writeln!(out).ok();
    emit_trait_bridge_handlers(out, lifecycle, trait_methods);
    writeln!(out, "    @Override").ok();
    writeln!(out, "    public void close() {{").ok();
    writeln!(out, "        arena.close();").ok();
    writeln!(out, "    }}").ok();
    writeln!(out, "}}").ok();
}

/// Emits one upcall stub bound to `handle{suffix}` and stores it in the next vtable slot.
///
/// Every slot is described by a `FunctionDescriptor` with a single `ADDRESS` parameter, so
/// the bound Java handler must accept one `MemorySegment`. `Linker.upcallStub` rejects a
/// target whose `MethodType` differs from the type derived from the descriptor.
fn emit_trait_bridge_stub(out: &mut String, suffix: &str, returns_int: bool) {
    let (java_ret, c_ret) = if returns_int {
        ("int.class", "ValueLayout.JAVA_INT")
    } else {
        ("MemorySegment.class", "ValueLayout.ADDRESS")
    };
    writeln!(out, "            var stub{suffix} = LINKER.upcallStub(LOOKUP.bind(this, \"handle{suffix}\",").ok();
    writeln!(out, "                    MethodType.methodType({java_ret}, MemorySegment.class)),").ok();
    writeln!(out, "                FunctionDescriptor.of({c_ret}, ValueLayout.ADDRESS),").ok();
    writeln!(out, "                arena);").ok();
    writeln!(out, "            vtable.set(ValueLayout.ADDRESS, offset, stub{suffix});").ok();
    writeln!(out, "            offset += ValueLayout.ADDRESS.byteSize();").ok();
    writeln!(out).ok();
}

/// Emits the private upcall targets bound by `emit_trait_bridge_stub`.
///
/// Each handler receives the C pointer argument (unused here). Handlers never let a Java
/// exception escape into native code: string handlers return `NULL` on failure and status
/// handlers return `1`.
fn emit_trait_bridge_handlers(
    out: &mut String,
    lifecycle: &[(&str, &str, bool)],
    trait_methods: &[(&str, &str)],
) {
    writeln!(out, "    // --- Upcall handlers (return MemorySegment pointing to allocated strings) ---").ok();
    writeln!(out, "    // Strings handed to native code are allocated in the bridge arena and stay valid until close().").ok();
    writeln!(out).ok();
    for &(suffix, method, returns_int) in lifecycle {
        if returns_int {
            writeln!(out, "    private int handle{suffix}(final MemorySegment userData) {{").ok();
            writeln!(out, "        try {{").ok();
            writeln!(out, "            impl.{method}();").ok();
            writeln!(out, "            return 0;").ok();
            writeln!(out, "        }} catch (Throwable e) {{").ok();
            writeln!(out, "            return 1;").ok();
        } else {
            writeln!(out, "    private MemorySegment handle{suffix}(final MemorySegment userData) {{").ok();
            writeln!(out, "        try {{").ok();
            writeln!(out, "            String {method} = impl.{method}();").ok();
            writeln!(out, "            return arena.allocateFrom({method});").ok();
            writeln!(out, "        }} catch (Throwable e) {{").ok();
            writeln!(out, "            return MemorySegment.NULL;").ok();
        }
        writeln!(out, "        }}").ok();
        writeln!(out, "    }}").ok();
        writeln!(out).ok();
    }
    for (method_name, return_type) in trait_methods {
        let suffix = method_name.to_pascal_case();
        writeln!(out, "    private MemorySegment handle{suffix}(final MemorySegment userData) {{").ok();
        writeln!(out, "        try {{").ok();
        if trait_bridge_is_void(return_type) {
            writeln!(out, "            impl.{method_name}();").ok();
            writeln!(out, "            return MemorySegment.NULL;").ok();
        } else {
            writeln!(out, "            {return_type} result = impl.{method_name}();").ok();
            writeln!(out, "            String json = JSON_MAPPER.writeValueAsString(result);").ok();
            writeln!(out, "            return arena.allocateFrom(json);").ok();
        }
        writeln!(out, "        }} catch (Throwable e) {{").ok();
        writeln!(out, "            return MemorySegment.NULL;").ok();
        writeln!(out, "        }}").ok();
        writeln!(out, "    }}").ok();
        writeln!(out).ok();
    }
}

/// Emits the `rc != 0` branch shared by register and unregister.
///
/// The branch reads the optional native error string from `outErr`, falling back to
/// `"{fallback} (rc=N)"`, and throws a `RuntimeException` prefixed with `context`.
fn emit_trait_bridge_rc_check(out: &mut String, fallback: &str, context: &str) {
    writeln!(out, "        if (rc != 0) {{").ok();
    writeln!(out, "            MemorySegment errPtr = outErr.get(ValueLayout.ADDRESS, 0);").ok();
    writeln!(
        out,
        "            String msg = errPtr.equals(MemorySegment.NULL) ? \"{fallback} (rc=\" + rc + \")\" : errPtr.reinterpret(Long.MAX_VALUE).getString(0);"
    )
    .ok();
    writeln!(out, "            throw new RuntimeException(\"{context}: \" + msg);").ok();
    writeln!(out, "        }}").ok();
}

/// Emits the static bridge registry plus `register{Trait}` and `unregister{Trait}`.
///
/// With a plugin super-trait, the registration name comes from `impl.name()`. Without one,
/// the interface has no `name()` method, so `register{Trait}` takes the name as a parameter.
/// Temporary C strings and the error out-pointer live in a per-call arena that is always
/// closed.
fn emit_trait_bridge_registry(out: &mut String, names: &TraitBridgeNames, has_super_trait: bool) {
    let TraitBridgeNames {
        pascal,
        upper_snake,
        prefix_upper,
    } = names;

    writeln!(out, "/** Registry of live {pascal} bridges — keeps upcall stubs and arenas alive. */").ok();
    writeln!(
        out,
        "private static final java.util.concurrent.ConcurrentHashMap<String, {pascal}Bridge> {upper_snake}_BRIDGES"
    )
    .ok();
    writeln!(out, "        = new java.util.concurrent.ConcurrentHashMap<>();").ok();
    writeln!(out).ok();

    writeln!(out, "/** Register a {pascal} implementation via Panama FFM upcall stubs. */").ok();
    if has_super_trait {
        writeln!(out, "public static void register{pascal}(final {pascal} impl) throws Exception {{").ok();
        // Read the name once so the native registration and the registry key agree.
        writeln!(out, "    final String name = impl.name();").ok();
    } else {
        writeln!(
            out,
            "public static void register{pascal}(final String name, final {pascal} impl) throws Exception {{"
        )
        .ok();
    }
    writeln!(out, "    var bridge = new {pascal}Bridge(impl);").ok();
    writeln!(out, "    try (var callArena = Arena.ofConfined()) {{").ok();
    writeln!(out, "        var nameCs = callArena.allocateFrom(name);").ok();
    writeln!(out, "        MemorySegment outErr = callArena.allocate(ValueLayout.ADDRESS);").ok();
    writeln!(
        out,
        "        int rc = (int) NativeLib.{prefix_upper}_REGISTER_{upper_snake}.invoke(nameCs, bridge.vtableSegment(), MemorySegment.NULL, outErr);"
    )
    .ok();
    emit_trait_bridge_rc_check(out, "registration failed", &format!("register{pascal}"));
    // `.invoke(...)` on the NativeLib handle (presumably a MethodHandle) is declared
    // `throws Throwable`; narrow it to the declared `throws Exception`.
    writeln!(out, "    }} catch (Exception | Error e) {{").ok();
    writeln!(out, "        bridge.close();").ok();
    writeln!(out, "        throw e;").ok();
    writeln!(out, "    }} catch (Throwable t) {{").ok();
    writeln!(out, "        bridge.close();").ok();
    writeln!(out, "        throw new RuntimeException(t);").ok();
    writeln!(out, "    }}").ok();
    writeln!(out, "    {upper_snake}_BRIDGES.put(name, bridge);").ok();
    writeln!(out, "}}").ok();
    writeln!(out).ok();

    writeln!(out, "/** Unregister a {pascal} implementation. */").ok();
    writeln!(out, "public static void unregister{pascal}(String name) throws Exception {{").ok();
    writeln!(out, "    try (var callArena = Arena.ofConfined()) {{").ok();
    writeln!(out, "        var nameCs = callArena.allocateFrom(name);").ok();
    writeln!(out, "        MemorySegment outErr = callArena.allocate(ValueLayout.ADDRESS);").ok();
    writeln!(
        out,
        "        int rc = (int) NativeLib.{prefix_upper}_UNREGISTER_{upper_snake}.invoke(nameCs, outErr);"
    )
    .ok();
    emit_trait_bridge_rc_check(out, "unregistration failed", &format!("unregister{pascal}"));
    writeln!(out, "    }} catch (Exception | Error e) {{").ok();
    writeln!(out, "        throw e;").ok();
    writeln!(out, "    }} catch (Throwable t) {{").ok();
    writeln!(out, "        throw new RuntimeException(t);").ok();
    writeln!(out, "    }}").ok();
    // Close only after native code has dropped its reference to the vtable.
    writeln!(out, "    {pascal}Bridge old = {upper_snake}_BRIDGES.remove(name);").ok();
    writeln!(out, "    if (old != null) {{ old.close(); }}").ok();
    writeln!(out, "}}").ok();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails on the first `needle` that does not occur anywhere in `java`.
    fn assert_contains_all(java: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(java.contains(needle), "generated code is missing `{}`", needle);
        }
    }

    #[test]
    fn test_gen_trait_bridge_basic() {
        let methods = [("doWork", "String"), ("getStatus", "int")];
        let java = gen_trait_bridge("MyPlugin", "mylib", &methods, true);
        assert_contains_all(
            &java,
            &[
                "public interface MyPlugin",
                "String name()",
                "String version()",
                "void initialize()",
                "void shutdown()",
                "doWork",
                "getStatus",
                "MyPluginBridge",
                "registerMyPlugin",
                "unregisterMyPlugin",
            ],
        );
    }

    #[test]
    fn test_gen_trait_bridge_vtable_stubs() {
        let java = gen_trait_bridge("Handler", "lib", &[], true);
        assert_contains_all(
            &java,
            &[
                "LINKER.upcallStub",
                "handleName",
                "handleVersion",
                "handleInitialize",
                "handleShutdown",
            ],
        );
    }

    #[test]
    fn test_gen_trait_bridge_lifecycle_methods() {
        let java = gen_trait_bridge("Processor", "pfx", &[("process", "Object")], true);
        assert_contains_all(
            &java,
            &[
                "String name()",
                "String version()",
                "void initialize()",
                "void shutdown()",
            ],
        );
    }

    #[test]
    fn test_gen_trait_bridge_no_super_trait_omits_lifecycle() {
        let java = gen_trait_bridge("Transformer", "lib", &[("transform", "String")], false);
        let has_lifecycle_decl = java.contains("String name()");
        let has_lifecycle_handler = java.contains("handleName");
        assert!(!has_lifecycle_decl, "no lifecycle methods without super_trait");
        assert!(!has_lifecycle_handler, "no lifecycle handlers without super_trait");
        assert!(java.contains("transform"), "trait method must still be emitted");
    }
}