use crate::backends::swift::naming::bridge_protocol_name;
use crate::core::config::TraitBridgeConfig;
use crate::core::ir::{TypeDef, TypeRef};
use heck::{ToLowerCamelCase, ToSnakeCase};
use std::collections::HashSet;
/// Generates one Swift source file per trait bridge that should be exposed to Swift.
///
/// A bridge is skipped when its `exclude_languages` contains `"swift"` or when it is
/// bound through anything other than `BridgeBinding::FunctionParam`.
///
/// Returns `(file_name, file_contents)` pairs. Each file is named
/// `<protocol>.swift`, where `<protocol>` comes from `bridge_protocol_name`.
pub fn gen_trait_bridge_files(
    bridges: &[(String, &TraitBridgeConfig, &TypeDef)],
    exclude_types: &HashSet<String>,
) -> Vec<(String, String)> {
    bridges
        .iter()
        .filter(|(_, cfg, _)| {
            let swift_excluded = cfg.exclude_languages.iter().any(|lang| lang == "swift");
            let bound_as_param =
                matches!(cfg.bind_via, crate::core::config::BridgeBinding::FunctionParam);
            !swift_excluded && bound_as_param
        })
        .map(|(trait_name, cfg, trait_def)| {
            let contents = gen_single_trait_bridge_file(trait_name, cfg, trait_def, exclude_types);
            let file_name = format!("{}.swift", bridge_protocol_name(trait_name));
            (file_name, contents)
        })
        .collect()
}
fn gen_single_trait_bridge_file(
trait_name: &str,
bridge_cfg: &TraitBridgeConfig,
trait_def: &TypeDef,
exclude_types: &HashSet<String>,
) -> String {
let mut out = String::new();
out.push_str("// Generated by alef. Do not edit by hand.\n");
out.push_str("// swift-format-ignore-file\n");
out.push_str("// This file contains generated FFI glue for trait bridge registration.\n\n");
out.push_str("import Foundation\n");
out.push_str("import RustBridge\n\n");
let protocol = bridge_protocol_name(trait_name);
out.push_str(&format!(
"/// Protocol for outbound `{trait_name}` implementations.\n\
/// Conform your Swift class or struct to this protocol to implement\n\
/// a Rust trait from the host side.\n\
public protocol {protocol}: AnyObject {{\n"
));
for method in &trait_def.methods {
if method.has_default_impl {
continue;
}
let method_camel = method.name.to_lower_camel_case();
let params_sig = swift_method_params(&method.params, exclude_types);
let return_type = swift_return_type(&method.return_type, exclude_types);
let throws = if method.error_type.is_some() { " throws" } else { "" };
let async_kw = if method.is_async { " async" } else { "" };
out.push_str(&format!(
" func {method_camel}({params_sig}){async_kw}{throws} -> {return_type}\n"
));
}
out.push_str("}\n\n");
out.push_str(&format!(
"/// Internal adapter wrapping a `{protocol}` conformer.\n\
/// Marshals Swift types and trait calls to/from the C boundary.\n\
/// Excluded/internal types are serialised to/from JSON strings.\n\
final class Swift{trait_name}Adapter {{\n\
\x20 private let bridge: any {protocol}\n\n"
));
out.push_str(&format!(
" init(bridge: any {protocol}) {{\n\
\x20\x20\x20\x20self.bridge = bridge\n\
\x20 }}\n\n"
));
for method in &trait_def.methods {
if method.has_default_impl {
continue;
}
let method_camel = method.name.to_lower_camel_case();
let params_sig = swift_method_params(&method.params, exclude_types);
let return_type = swift_return_type(&method.return_type, exclude_types);
out.push_str(&format!(
" func {method_camel}Call({params_sig}) -> {return_type} {{\n"
));
let (call_args, call_expr) = build_adapter_call_expr(method, exclude_types);
let call_args_str = call_args.join(", ");
if method.error_type.is_some() {
out.push_str(&format!(
" do {{\n\
\x20\x20\x20\x20\x20\x20\x20\x20let result = try self.bridge.{method_camel}({call_args_str})\n"
));
out.push_str(&format!(
" return marshal_ok_result({call_expr})\n\
\x20\x20\x20\x20}} catch {{\n\
\x20\x20\x20\x20\x20\x20\x20\x20return marshal_error_result(error)\n\
\x20\x20\x20\x20}}\n"
));
} else if method.is_async {
out.push_str(&format!(
" let result = await self.bridge.{method_camel}({call_args_str})\n"
));
out.push_str(&format!(
" return {call_expr}\n"
));
} else {
out.push_str(&format!(
" let result = self.bridge.{method_camel}({call_args_str})\n"
));
out.push_str(&format!(
" return {call_expr}\n"
));
}
out.push_str(" }\n\n");
}
out.push_str("}\n\n");
out.push_str("// MARK: - Marshalling helpers\n\n");
out.push_str(
"private func marshal_ok_result<T: Encodable>(_ value: T) -> String {\n\
\x20\x20\x20\x20let encoder = JSONEncoder()\n\
\x20\x20\x20\x20if let data = try? encoder.encode(value),\n\
\x20\x20\x20\x20 let jsonString = String(data: data, encoding: .utf8) {\n\
\x20\x20\x20\x20\x20\x20\x20\x20return \"{\\\"ok\\\": \\(jsonString)}\"\n\
\x20\x20\x20\x20}\n\
\x20\x20\x20\x20return \"{\\\"ok\\\": null}\"\n\
}\n\n\
private func marshal_error_result(_ error: any Error) -> String {\n\
\x20\x20\x20\x20let errorString = String(describing: error)\n\
\x20\x20\x20\x20let encoder = JSONEncoder()\n\
\x20\x20\x20\x20if let data = try? encoder.encode(errorString),\n\
\x20\x20\x20\x20 let jsonString = String(data: data, encoding: .utf8) {\n\
\x20\x20\x20\x20\x20\x20\x20\x20return \"{\\\"err\\\": \\(jsonString)}\"\n\
\x20\x20\x20\x20}\n\
\x20\x20\x20\x20return \"{\\\"err\\\": \\\"unknown error\\\"}\"\n\
}\n\n"
);
if let Some(register_fn) = bridge_cfg.register_fn.as_deref() {
let camel = register_fn.to_lower_camel_case();
out.push_str(&format!(
"/// Register an outbound `{trait_name}` plugin.\n\
/// Pass an instance conforming to `{protocol}`.\n\
public func {camel}(_ bridge: any {protocol}) throws {{\n\
\x20 let adapter = Swift{trait_name}Adapter(bridge: bridge)\n\
\x20 // Call into Rust to register the adapter\n\
\x20 try RustBridge.{camel}(adapter)\n\
}}\n"
));
}
out
}
/// Renders a Swift parameter list without the surrounding parentheses,
/// e.g. `file_name: String, page: Int32`.
///
/// Parameter names are converted to snake_case. Types are mapped with
/// `swift_type_name`, so excluded types appear as `String`. An empty slice
/// yields an empty string.
fn swift_method_params(params: &[crate::core::ir::ParamDef], exclude_types: &HashSet<String>) -> String {
    let mut rendered = String::new();
    for (index, param) in params.iter().enumerate() {
        if index > 0 {
            rendered.push_str(", ");
        }
        rendered.push_str(&param.name.to_snake_case());
        rendered.push_str(": ");
        rendered.push_str(&swift_type_name(&param.ty, exclude_types));
    }
    rendered
}
/// Maps an IR type to its Swift spelling in generated signatures.
///
/// - Named types listed in `exclude_types` become `String`, as does `Json`.
/// - `Vec`, `Map` and `Optional` recurse into their element types.
/// - `Usize` and `Isize` both map to Swift's `Int`.
fn swift_type_name(ty: &TypeRef, exclude_types: &HashSet<String>) -> String {
    use crate::core::ir::PrimitiveType as Prim;

    match ty {
        TypeRef::Primitive(prim) => {
            let spelled = match prim {
                Prim::Bool => "Bool",
                Prim::I8 => "Int8",
                Prim::I16 => "Int16",
                Prim::I32 => "Int32",
                Prim::I64 => "Int64",
                Prim::U8 => "UInt8",
                Prim::U16 => "UInt16",
                Prim::U32 => "UInt32",
                Prim::U64 => "UInt64",
                Prim::Usize | Prim::Isize => "Int",
                Prim::F32 => "Float",
                Prim::F64 => "Double",
            };
            spelled.to_string()
        }
        TypeRef::Named(name) if exclude_types.contains(name) => String::from("String"),
        TypeRef::Named(name) => name.clone(),
        TypeRef::String | TypeRef::Json => String::from("String"),
        TypeRef::Bytes => String::from("Data"),
        TypeRef::Path => String::from("URL"),
        TypeRef::Char => String::from("Character"),
        TypeRef::Unit => String::from("Void"),
        TypeRef::Duration => String::from("TimeInterval"),
        TypeRef::Vec(inner) => format!("[{}]", swift_type_name(inner, exclude_types)),
        TypeRef::Map(key, value) => format!(
            "[{}: {}]",
            swift_type_name(key, exclude_types),
            swift_type_name(value, exclude_types)
        ),
        TypeRef::Optional(inner) => format!("{}?", swift_type_name(inner, exclude_types)),
    }
}
/// Swift spelling of a method's return type.
///
/// Delegates to `swift_type_name` without changes, so excluded types become
/// `String` and the unit type becomes `Void`.
fn swift_return_type(ty: &TypeRef, exclude_types: &HashSet<String>) -> String {
    swift_type_name(ty, exclude_types)
}
fn build_adapter_call_expr(
method: &crate::core::ir::MethodDef,
exclude_types: &HashSet<String>,
) -> (Vec<String>, String) {
let call_args: Vec<String> = method
.params
.iter()
.map(|p| p.name.to_snake_case())
.collect();
let return_expr = match &method.return_type {
TypeRef::Named(name) if exclude_types.contains(name) => {
"try JSONEncoder().encode(result)...".to_string() }
TypeRef::String | TypeRef::Bytes | TypeRef::Primitive(_) | TypeRef::Unit => {
"result".to_string()
}
_ => "result".to_string(), };
(call_args, return_expr)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::config::BridgeBinding;

    /// Builds a bare trait `TypeDef` with no fields or methods.
    fn make_trait_def(name: &str) -> TypeDef {
        TypeDef {
            name: name.to_string(),
            rust_path: format!("testcrate::{name}"),
            original_rust_path: String::new(),
            fields: Vec::new(),
            methods: Vec::new(),
            is_opaque: false,
            is_clone: true,
            is_copy: false,
            is_trait: true,
            has_default: false,
            has_stripped_cfg_fields: false,
            is_return_type: false,
            serde_rename_all: None,
            has_serde: false,
            super_traits: Vec::new(),
            doc: String::new(),
            cfg: None,
            binding_excluded: false,
            binding_exclusion_reason: None,
        }
    }

    /// Builds a `FunctionParam` bridge config whose register function is
    /// `register<trait_name>`.
    fn make_bridge_cfg(trait_name: &str) -> TraitBridgeConfig {
        TraitBridgeConfig {
            trait_name: trait_name.to_string(),
            param_name: None,
            type_alias: None,
            exclude_languages: Vec::new(),
            super_trait: None,
            registry_getter: None,
            register_fn: Some(format!("register{trait_name}")),
            unregister_fn: None,
            clear_fn: None,
            register_extra_args: None,
            bind_via: BridgeBinding::FunctionParam,
            options_type: None,
            options_field: None,
            context_type: None,
            result_type: None,
            ffi_skip_methods: Vec::new(),
        }
    }

    /// Runs the generator over a single bridge with no excluded types.
    fn generate(
        trait_name: &str,
        trait_def: &TypeDef,
        bridge_cfg: &TraitBridgeConfig,
    ) -> Vec<(String, String)> {
        let bridges = vec![(trait_name.to_string(), bridge_cfg, trait_def)];
        gen_trait_bridge_files(&bridges, &HashSet::new())
    }

    #[test]
    fn test_trait_bridge_protocol_generated() {
        let trait_def = make_trait_def("OcrBackend");
        let bridge_cfg = make_bridge_cfg("OcrBackend");
        let files = generate("OcrBackend", &trait_def, &bridge_cfg);
        assert_eq!(files.len(), 1);
        let (file_name, contents) = &files[0];
        assert_eq!(file_name.as_str(), "SwiftOcrBackendBridge.swift");
        assert!(contents.contains("protocol SwiftOcrBackendBridge"));
    }

    #[test]
    fn test_trait_bridge_excludes_swift_language() {
        let trait_def = make_trait_def("OcrBackend");
        let bridge_cfg = TraitBridgeConfig {
            exclude_languages: vec!["swift".to_string()],
            ..make_bridge_cfg("OcrBackend")
        };
        assert!(generate("OcrBackend", &trait_def, &bridge_cfg).is_empty());
    }

    #[test]
    fn test_trait_bridge_skips_non_function_param() {
        let trait_def = make_trait_def("OcrBackend");
        let bridge_cfg = TraitBridgeConfig {
            bind_via: BridgeBinding::OptionsField,
            ..make_bridge_cfg("OcrBackend")
        };
        assert!(generate("OcrBackend", &trait_def, &bridge_cfg).is_empty());
    }

    #[test]
    fn test_swift_type_mapping() {
        use crate::core::ir::PrimitiveType;
        let exclude_types = HashSet::new();
        let cases = [
            (TypeRef::String, "String"),
            (TypeRef::Bytes, "Data"),
            (TypeRef::Unit, "Void"),
            (TypeRef::Primitive(PrimitiveType::I32), "Int32"),
            (TypeRef::Duration, "TimeInterval"),
        ];
        for (ty, expected) in &cases {
            assert_eq!(swift_type_name(ty, &exclude_types), *expected);
        }
    }

    #[test]
    fn test_swift_marshals_excluded_types_as_json() {
        let exclude_types: HashSet<String> = ["InternalDocument".to_string()].into_iter().collect();
        assert_eq!(
            swift_type_name(&TypeRef::Named("InternalDocument".to_string()), &exclude_types),
            "String",
            "Excluded types should be marshalled as JSON strings"
        );
        assert_eq!(
            swift_type_name(&TypeRef::Named("ExtractionResult".to_string()), &exclude_types),
            "ExtractionResult",
            "Non-excluded types should keep their original names"
        );
    }
}