use super::{csharp_file_header, is_bridge_param, pinvoke_param_type, pinvoke_return_type};
use alef_codegen::naming::to_csharp_name;
use alef_core::config::TraitBridgeConfig;
use alef_core::ir::{ApiSurface, FunctionDef, MethodDef, TypeRef};
use heck::{ToLowerCamelCase, ToPascalCase, ToSnakeCase};
use std::collections::HashSet;
#[allow(clippy::too_many_arguments)]
pub(super) fn gen_native_methods(
api: &ApiSurface,
namespace: &str,
lib_name: &str,
prefix: &str,
bridge_param_names: &HashSet<String>,
bridge_type_aliases: &HashSet<String>,
has_visitor_callbacks: bool,
trait_bridges: &[TraitBridgeConfig],
streaming_methods: &HashSet<String>,
exclude_functions: &HashSet<String>,
) -> String {
let mut out = csharp_file_header();
out.push_str("using System;\n");
out.push_str("using System.Runtime.InteropServices;\n\n");
out.push_str(&format!("namespace {};\n\n", namespace));
out.push_str("internal static partial class NativeMethods\n{\n");
out.push_str(&format!(" private const string LibName = \"{}\";\n\n", lib_name));
let mut emitted: HashSet<String> = HashSet::new();
let enum_names: HashSet<String> = api.enums.iter().map(|e| e.name.clone()).collect();
let mut opaque_param_types: HashSet<String> = HashSet::new();
let mut opaque_return_types: HashSet<String> = HashSet::new();
for func in api.functions.iter().filter(|f| !exclude_functions.contains(&f.name)) {
for param in &func.params {
if let TypeRef::Named(name) = ¶m.ty {
opaque_param_types.insert(name.clone());
}
}
if let TypeRef::Named(name) = &func.return_type {
if !enum_names.contains(name) {
opaque_return_types.insert(name.clone());
}
}
}
for typ in api.types.iter().filter(|typ| !typ.is_trait) {
for method in &typ.methods {
if streaming_methods.contains(&method.name) {
continue;
}
for param in &method.params {
if let TypeRef::Named(name) = ¶m.ty {
opaque_param_types.insert(name.clone());
}
}
if let TypeRef::Named(name) = &method.return_type {
if !enum_names.contains(name) {
opaque_return_types.insert(name.clone());
}
}
}
}
let true_opaque_types: HashSet<String> = api
.types
.iter()
.filter(|t| t.is_opaque)
.map(|t| t.name.clone())
.collect();
let mut sorted_true_opaque_types: Vec<&String> = true_opaque_types.iter().collect();
sorted_true_opaque_types.sort();
for type_name in sorted_true_opaque_types {
let snake = type_name.to_snake_case();
let free_entry = format!("{prefix}_{snake}_free");
let free_cs = format!("{}Free", type_name.to_pascal_case());
if emitted.insert(free_entry.clone()) {
out.push_str(&format!(
" [DllImport(LibName, CallingConvention = CallingConvention.Cdecl, EntryPoint = \"{free_entry}\")]\n"
));
out.push_str(&format!(" internal static extern void {free_cs}(IntPtr ptr);\n\n"));
}
}
let mut sorted_param_types: Vec<&String> = opaque_param_types.iter().collect();
sorted_param_types.sort();
for type_name in sorted_param_types {
let snake = type_name.to_snake_case();
if !true_opaque_types.contains(type_name) {
let from_json_entry = format!("{prefix}_{snake}_from_json");
let from_json_cs = format!("{}FromJson", type_name.to_pascal_case());
if emitted.insert(from_json_entry.clone()) {
out.push_str(&format!(
" [DllImport(LibName, CallingConvention = CallingConvention.Cdecl, EntryPoint = \"{from_json_entry}\")]\n"
));
out.push_str(&format!(
" internal static extern IntPtr {from_json_cs}([MarshalAs(UnmanagedType.LPUTF8Str)] string json);\n\n"
));
}
}
let free_entry = format!("{prefix}_{snake}_free");
let free_cs = format!("{}Free", type_name.to_pascal_case());
if emitted.insert(free_entry.clone()) {
out.push_str(&format!(
" [DllImport(LibName, CallingConvention = CallingConvention.Cdecl, EntryPoint = \"{free_entry}\")]\n"
));
out.push_str(&format!(" internal static extern void {free_cs}(IntPtr ptr);\n\n"));
}
}
let mut sorted_return_types: Vec<&String> = opaque_return_types.iter().collect();
sorted_return_types.sort();
for type_name in sorted_return_types {
let snake = type_name.to_snake_case();
if !true_opaque_types.contains(type_name) {
let to_json_entry = format!("{prefix}_{snake}_to_json");
let to_json_cs = format!("{}ToJson", type_name.to_pascal_case());
if emitted.insert(to_json_entry.clone()) {
out.push_str(&format!(
" [DllImport(LibName, CallingConvention = CallingConvention.Cdecl, EntryPoint = \"{to_json_entry}\")]\n"
));
out.push_str(&format!(
" internal static extern IntPtr {to_json_cs}(IntPtr ptr);\n\n"
));
}
}
let free_entry = format!("{prefix}_{snake}_free");
let free_cs = format!("{}Free", type_name.to_pascal_case());
if emitted.insert(free_entry.clone()) {
out.push_str(&format!(
" [DllImport(LibName, CallingConvention = CallingConvention.Cdecl, EntryPoint = \"{free_entry}\")]\n"
));
out.push_str(&format!(" internal static extern void {free_cs}(IntPtr ptr);\n\n"));
}
}
for func in api.functions.iter().filter(|f| !exclude_functions.contains(&f.name)) {
let c_func_name = format!("{}_{}", prefix, func.name.to_lowercase());
if emitted.insert(c_func_name.clone()) {
out.push_str(&gen_pinvoke_for_func(
&c_func_name,
func,
bridge_param_names,
bridge_type_aliases,
));
}
}
for typ in api.types.iter().filter(|typ| !typ.is_trait) {
let type_snake = typ.name.to_snake_case();
for method in &typ.methods {
if streaming_methods.contains(&method.name) {
continue;
}
let c_method_name = format!("{}_{}_{}", prefix, type_snake, method.name.to_lowercase());
let cs_method_name = format!("{}{}", typ.name.to_pascal_case(), to_csharp_name(&method.name));
if emitted.insert(c_method_name.clone()) {
out.push_str(&gen_pinvoke_for_method(&c_method_name, &cs_method_name, method));
}
}
}
out.push_str(&format!(
" [DllImport(LibName, CallingConvention = CallingConvention.Cdecl, EntryPoint = \"{prefix}_last_error_code\")]\n"
));
out.push_str(" internal static extern int LastErrorCode();\n\n");
out.push_str(&format!(
" [DllImport(LibName, CallingConvention = CallingConvention.Cdecl, EntryPoint = \"{prefix}_last_error_context\")]\n"
));
out.push_str(" internal static extern IntPtr LastErrorContext();\n\n");
out.push_str(&format!(
" [DllImport(LibName, CallingConvention = CallingConvention.Cdecl, EntryPoint = \"{prefix}_free_string\")]\n"
));
out.push_str(" internal static extern void FreeString(IntPtr ptr);\n");
if has_visitor_callbacks {
out.push('\n');
let visitor_bridge = trait_bridges
.iter()
.find(|b| b.bind_via == alef_core::config::BridgeBinding::OptionsField);
if let Some(bridge) = visitor_bridge {
out.push_str(&crate::gen_visitor::gen_native_methods_visitor(
namespace,
lib_name,
prefix,
&bridge.trait_name,
bridge.options_field.as_deref().unwrap_or("visitor"),
));
}
}
if !trait_bridges.is_empty() {
let trait_defs: Vec<_> = api.types.iter().filter(|t| t.is_trait).collect();
let bridges: Vec<_> = trait_bridges
.iter()
.filter_map(|config| {
let trait_name = config.trait_name.clone();
trait_defs
.iter()
.find(|t| t.name == trait_name)
.map(|trait_def| (trait_name, config, *trait_def))
})
.collect();
if !bridges.is_empty() {
out.push('\n');
out.push_str(&crate::trait_bridge::gen_native_methods_trait_bridges(
namespace, prefix, &bridges,
));
}
}
out.push_str("}\n");
out
}
/// Renders the P/Invoke declaration for a free function exported under the C symbol `c_name`.
///
/// The C# method is named `to_csharp_name(&func.name)` and returns
/// `pinvoke_return_type(&func.return_type)`. Parameters for which `is_bridge_param` is true
/// are left out of the signature. The remaining parameters are emitted one per line in
/// declaration order, with names converted to lowerCamelCase. Parameters whose P/Invoke type
/// is `string` are marshalled as UTF-8 (`UnmanagedType.LPUTF8Str`).
pub(super) fn gen_pinvoke_for_func(
    c_name: &str,
    func: &FunctionDef,
    bridge_param_names: &HashSet<String>,
    bridge_type_aliases: &HashSet<String>,
) -> String {
    let mut out = format!(
        " [DllImport(LibName, CallingConvention = CallingConvention.Cdecl, EntryPoint = \"{c_name}\")]\n"
    );
    out.push_str(" internal static extern ");
    out.push_str(pinvoke_return_type(&func.return_type));
    out.push(' ');
    out.push_str(&to_csharp_name(&func.name));
    out.push('(');

    let param_lines: Vec<String> = func
        .params
        .iter()
        .filter(|p| !is_bridge_param(p, bridge_param_names, bridge_type_aliases))
        .map(|param| {
            let pinvoke_ty = pinvoke_param_type(&param.ty);
            let marshal = if pinvoke_ty == "string" {
                "[MarshalAs(UnmanagedType.LPUTF8Str)] "
            } else {
                ""
            };
            format!(" {marshal}{pinvoke_ty} {}", param.name.to_lower_camel_case())
        })
        .collect();

    if param_lines.is_empty() {
        out.push_str(");\n\n");
    } else {
        // One parameter per line; `join` places the commas between entries only.
        out.push('\n');
        out.push_str(&param_lines.join(",\n"));
        out.push_str("\n );\n\n");
    }
    out
}
/// Renders the P/Invoke declaration for a type method exported under the C symbol `c_name`
/// and declared in C# as `cs_name`.
///
/// A method that is not static and has a receiver takes the object handle as a leading
/// `IntPtr handle` argument. The method's own parameters follow in declaration order, with
/// names converted to lowerCamelCase. Parameters whose P/Invoke type is `string` are
/// marshalled as UTF-8 (`UnmanagedType.LPUTF8Str`).
pub(super) fn gen_pinvoke_for_method(c_name: &str, cs_name: &str, method: &MethodDef) -> String {
    let mut out = format!(
        " [DllImport(LibName, CallingConvention = CallingConvention.Cdecl, EntryPoint = \"{c_name}\")]\n"
    );
    out.push_str(" internal static extern ");
    out.push_str(pinvoke_return_type(&method.return_type));
    out.push(' ');
    out.push_str(cs_name);
    out.push('(');

    // Static methods, and methods without a receiver, take no handle argument.
    let has_receiver = !method.is_static && method.receiver.is_some();
    let receiver_line = has_receiver.then(|| " IntPtr handle".to_string());
    let param_lines: Vec<String> = receiver_line
        .into_iter()
        .chain(method.params.iter().map(|param| {
            let pinvoke_ty = pinvoke_param_type(&param.ty);
            let marshal = if pinvoke_ty == "string" {
                "[MarshalAs(UnmanagedType.LPUTF8Str)] "
            } else {
                ""
            };
            format!(" {marshal}{pinvoke_ty} {}", param.name.to_lower_camel_case())
        }))
        .collect();

    if param_lines.is_empty() {
        out.push_str(");\n\n");
    } else {
        // One argument per line; `join` places the commas between entries only.
        out.push('\n');
        out.push_str(&param_lines.join(",\n"));
        out.push_str("\n );\n\n");
    }
    out
}