use crate::backends::swift::naming::bridge_protocol_name;
use crate::core::config::TraitBridgeConfig;
use crate::core::ir::{TypeDef, TypeRef};
use heck::{ToLowerCamelCase, ToSnakeCase};
use std::collections::HashSet;
/// Generate one Swift source file per trait bridge that targets Swift.
///
/// A bridge is skipped when its config lists `"swift"` in `exclude_languages`
/// or when it is not bound via `BridgeBinding::FunctionParam`.
///
/// Returns `(filename, contents)` pairs, where the filename is the bridge's
/// protocol name (from `bridge_protocol_name`) with a `.swift` extension.
pub fn gen_trait_bridge_files(
    bridges: &[(String, &TraitBridgeConfig, &TypeDef)],
    exclude_types: &HashSet<String>,
) -> Vec<(String, String)> {
    use crate::core::config::BridgeBinding;

    bridges
        .iter()
        // Honour per-bridge opt-outs for the Swift backend.
        .filter(|(_, cfg, _)| !cfg.exclude_languages.iter().any(|lang| lang == "swift"))
        // Only function-parameter bindings get a standalone bridge file.
        .filter(|(_, cfg, _)| matches!(cfg.bind_via, BridgeBinding::FunctionParam))
        .map(|(trait_name, cfg, trait_def)| {
            let body = gen_single_trait_bridge_file(trait_name, cfg, trait_def, exclude_types);
            let file_name = format!("{}.swift", bridge_protocol_name(trait_name));
            (file_name, body)
        })
        .collect()
}
fn gen_single_trait_bridge_file(
trait_name: &str,
_bridge_cfg: &TraitBridgeConfig,
trait_def: &TypeDef,
exclude_types: &HashSet<String>,
) -> String {
let mut out = String::new();
out.push_str("// Generated by alef. Do not edit by hand.\n");
out.push_str("// swift-format-ignore-file\n");
out.push_str("// This file contains generated FFI glue for trait bridge registration.\n\n");
out.push_str("import Foundation\n");
out.push_str("import RustBridge\n\n");
let protocol = bridge_protocol_name(trait_name);
out.push_str(&format!(
"/// Protocol for outbound `{trait_name}` implementations.\n\
/// Conform your Swift class or struct to this protocol to implement\n\
/// a Rust trait from the host side.\n\
public protocol {protocol}: AnyObject {{\n"
));
for method in &trait_def.methods {
if method.has_default_impl {
continue;
}
let method_camel = method.name.to_lower_camel_case();
let params_sig = swift_method_params_native(&method.params, exclude_types);
let return_type = swift_return_type(&method.return_type, exclude_types);
let throws = if method.error_type.is_some() { " throws" } else { "" };
let async_kw = if method.is_async { " async" } else { "" };
out.push_str(&format!(
" func {method_camel}({params_sig}){async_kw}{throws} -> {return_type}\n"
));
}
out.push_str("}\n\n");
out.push_str(&format!(
"/// Internal adapter wrapping a `{protocol}` conformer.\n\
/// Marshals Swift types and trait calls to/from the C boundary.\n\
/// Excluded/internal types are serialised to/from JSON strings.\n\
final class Swift{trait_name}Adapter {{\n\
\x20 private let bridge: any {protocol}\n\n"
));
out.push_str(&format!(
" init(bridge: any {protocol}) {{\n\
\x20\x20\x20\x20self.bridge = bridge\n\
\x20 }}\n\n"
));
for method in &trait_def.methods {
if method.has_default_impl {
continue;
}
let method_camel = method.name.to_lower_camel_case();
let params_sig = swift_method_params(&method.params, exclude_types);
let return_type = if method.error_type.is_some() {
"String".to_string()
} else {
swift_return_type(&method.return_type, exclude_types)
};
let async_kw = if method.is_async { " async" } else { "" };
let throws_kw = if method.error_type.is_some() { " throws" } else { "" };
out.push_str(&format!(
" func {method_camel}Call({params_sig}){async_kw}{throws_kw} -> {return_type} {{\n"
));
let call_args = build_adapter_call_args(method);
let call_args_str = call_args.join(", ");
if method.error_type.is_some() {
let try_await = if method.is_async { "try await " } else { "try " };
out.push_str(&format!(
" do {{\n\
\x20\x20\x20\x20let result = {try_await}self.bridge.{method_camel}({call_args_str})\n"
));
if matches!(method.return_type, TypeRef::Unit) {
out.push_str(
" return marshal_ok_result(Empty())\n\
\x20\x20\x20\x20} catch {\n\
\x20\x20\x20\x20\x20\x20\x20\x20return marshal_error_result(error)\n\
\x20\x20\x20\x20}\n",
);
} else {
match &method.return_type {
TypeRef::String => {
out.push_str(
" return marshal_ok_result(result)\n\
\x20\x20\x20\x20} catch {\n\
\x20\x20\x20\x20\x20\x20\x20\x20return marshal_error_result(error)\n\
\x20\x20\x20\x20}\n",
);
}
TypeRef::Unit => {
out.push_str(
" return marshal_ok_result(Empty())\n\
\x20\x20\x20\x20} catch {\n\
\x20\x20\x20\x20\x20\x20\x20\x20return marshal_error_result(error)\n\
\x20\x20\x20\x20}\n",
);
}
TypeRef::Primitive(_) | TypeRef::Bytes | TypeRef::Char => {
out.push_str(
" return marshal_ok_result(result)\n\
\x20\x20\x20\x20} catch {\n\
\x20\x20\x20\x20\x20\x20\x20\x20return marshal_error_result(error)\n\
\x20\x20\x20\x20}\n",
);
}
TypeRef::Named(name) if exclude_types.contains(name) => {
out.push_str(
" let encodedData = try marshal_encode_excluded(result)\n\
\x20\x20\x20\x20if let jsonString = String(data: encodedData, encoding: .utf8) {\n\
\x20\x20\x20\x20\x20\x20\x20\x20return \"{\\\"ok\\\": \\(jsonString)}\"\n\
\x20\x20\x20\x20}\n\
\x20\x20\x20\x20return \"{\\\"ok\\\": null}\"\n\
\x20\x20\x20\x20} catch {\n\
\x20\x20\x20\x20\x20\x20\x20\x20return marshal_error_result(error)\n\
\x20\x20\x20\x20}\n",
);
}
_ => {
out.push_str(
" return marshal_ok_result(try JSONEncoder().encode(result))\n\
\x20\x20\x20\x20} catch {\n\
\x20\x20\x20\x20\x20\x20\x20\x20return marshal_error_result(error)\n\
\x20\x20\x20\x20}\n",
);
}
}
}
} else if method.is_async {
out.push_str(&format!(
" let result = await self.bridge.{method_camel}({call_args_str})\n"
));
out.push_str(" return result\n");
} else {
out.push_str(&format!(
" let result = self.bridge.{method_camel}({call_args_str})\n"
));
out.push_str(" return result\n");
}
out.push_str(" }\n\n");
}
out.push_str("}\n\n");
out.push_str("// MARK: - Marshalling helpers\n\n");
out.push_str(
"private struct Empty: Codable {}\n\n\
private func marshal_ok_result<T: Encodable>(_ value: T) -> String {\n\
\x20\x20\x20\x20let encoder = JSONEncoder()\n\
\x20\x20\x20\x20if let data = try? encoder.encode(value),\n\
\x20\x20\x20\x20 let jsonString = String(data: data, encoding: .utf8) {\n\
\x20\x20\x20\x20\x20\x20\x20\x20return \"{\\\"ok\\\": \\(jsonString)}\"\n\
\x20\x20\x20\x20}\n\
\x20\x20\x20\x20return \"{\\\"ok\\\": null}\"\n\
}\n\n\
private func marshal_encode_excluded<T: Encodable>(_ value: T) throws -> Data {\n\
\x20\x20\x20\x20let encoder = JSONEncoder()\n\
\x20\x20\x20\x20return try encoder.encode(value)\n\
}\n\n\
private func marshal_error_result(_ error: any Error) -> String {\n\
\x20\x20\x20\x20let errorString = String(describing: error)\n\
\x20\x20\x20\x20let encoder = JSONEncoder()\n\
\x20\x20\x20\x20if let data = try? encoder.encode(errorString),\n\
\x20\x20\x20\x20 let jsonString = String(data: data, encoding: .utf8) {\n\
\x20\x20\x20\x20\x20\x20\x20\x20return \"{\\\"err\\\": \\(jsonString)}\"\n\
\x20\x20\x20\x20}\n\
\x20\x20\x20\x20return \"{\\\"err\\\": \\\"unknown error\\\"}\"\n\
}\n\n",
);
out
}
/// Render the comma-separated Swift parameter list (`label: Type, …`) used by the
/// public protocol.
///
/// Labels are the snake_case form of the Rust parameter names. Types come from
/// `swift_type_name_native`, so named types keep their original names. An empty
/// parameter list renders as an empty string.
fn swift_method_params_native(params: &[crate::core::ir::ParamDef], exclude_types: &HashSet<String>) -> String {
    let mut rendered = String::new();
    for (position, param) in params.iter().enumerate() {
        if position > 0 {
            rendered.push_str(", ");
        }
        rendered.push_str(&param.name.to_snake_case());
        rendered.push_str(": ");
        rendered.push_str(&swift_type_name_native(&param.ty, exclude_types));
    }
    rendered
}
/// Render the comma-separated Swift parameter list (`label: Type, …`) used by the
/// adapter's `<method>Call` functions.
///
/// Labels are the snake_case form of the Rust parameter names. Types come from
/// `swift_type_name`, so names listed in `exclude_types` become `String`. An empty
/// parameter list renders as an empty string.
fn swift_method_params(params: &[crate::core::ir::ParamDef], exclude_types: &HashSet<String>) -> String {
    let pieces: Vec<String> = params
        .iter()
        .map(|param| {
            let label = param.name.to_snake_case();
            let swift_ty = swift_type_name(&param.ty, exclude_types);
            format!("{label}: {swift_ty}")
        })
        .collect();
    pieces.join(", ")
}
/// Map an IR type to the Swift type name used in the public protocol.
///
/// Unlike `swift_type_name`, a `TypeRef::Named` is always emitted verbatim, even
/// when it appears in `exclude_types`. Collections and optionals recurse into
/// their element types.
// `exclude_types` is only forwarded to the recursive calls. It is kept so the
// signature matches `swift_type_name`, which clippy would otherwise flag.
#[allow(clippy::only_used_in_recursion)]
fn swift_type_name_native(ty: &TypeRef, exclude_types: &HashSet<String>) -> String {
    use crate::core::ir::PrimitiveType as Prim;

    // Composite and named types return early; all other arms yield a fixed Swift name.
    let leaf: &str = match ty {
        TypeRef::Named(name) => return name.clone(),
        TypeRef::Vec(inner) => {
            return format!("[{}]", swift_type_name_native(inner, exclude_types));
        }
        TypeRef::Map(key, value) => {
            let key_ty = swift_type_name_native(key, exclude_types);
            let value_ty = swift_type_name_native(value, exclude_types);
            return format!("[{key_ty}: {value_ty}]");
        }
        TypeRef::Optional(inner) => {
            return format!("{}?", swift_type_name_native(inner, exclude_types));
        }
        TypeRef::Primitive(prim) => match prim {
            Prim::Bool => "Bool",
            Prim::I8 => "Int8",
            Prim::I16 => "Int16",
            Prim::I32 => "Int32",
            Prim::I64 => "Int64",
            Prim::U8 => "UInt8",
            Prim::U16 => "UInt16",
            Prim::U32 => "UInt32",
            Prim::U64 => "UInt64",
            Prim::Usize | Prim::Isize => "Int",
            Prim::F32 => "Float",
            Prim::F64 => "Double",
        },
        TypeRef::String | TypeRef::Json => "String",
        TypeRef::Bytes => "Data",
        TypeRef::Path => "URL",
        TypeRef::Char => "Character",
        TypeRef::Unit => "Void",
        TypeRef::Duration => "TimeInterval",
    };
    leaf.to_string()
}
/// Map an IR type to the Swift type used at the adapter/marshalling layer.
///
/// This matches the native mapping, except that any `TypeRef::Named` listed in
/// `exclude_types` becomes `String`. The rule applies recursively inside `Vec`,
/// `Map` and `Optional`.
fn swift_type_name(ty: &TypeRef, exclude_types: &HashSet<String>) -> String {
    use crate::core::ir::PrimitiveType;

    match ty {
        TypeRef::Primitive(primitive) => {
            let swift = match primitive {
                PrimitiveType::Bool => "Bool",
                PrimitiveType::I8 => "Int8",
                PrimitiveType::I16 => "Int16",
                PrimitiveType::I32 => "Int32",
                PrimitiveType::I64 => "Int64",
                PrimitiveType::U8 => "UInt8",
                PrimitiveType::U16 => "UInt16",
                PrimitiveType::U32 => "UInt32",
                PrimitiveType::U64 => "UInt64",
                PrimitiveType::Usize | PrimitiveType::Isize => "Int",
                PrimitiveType::F32 => "Float",
                PrimitiveType::F64 => "Double",
            };
            swift.to_owned()
        }
        // Excluded types have no Swift counterpart; they travel as strings.
        TypeRef::Named(name) if exclude_types.contains(name) => "String".to_owned(),
        TypeRef::Named(name) => name.clone(),
        TypeRef::Vec(inner) => {
            let element = swift_type_name(inner, exclude_types);
            format!("[{element}]")
        }
        TypeRef::Map(key, value) => {
            let key_ty = swift_type_name(key, exclude_types);
            let value_ty = swift_type_name(value, exclude_types);
            format!("[{key_ty}: {value_ty}]")
        }
        TypeRef::Optional(inner) => {
            let wrapped = swift_type_name(inner, exclude_types);
            format!("{wrapped}?")
        }
        TypeRef::String | TypeRef::Json => "String".to_owned(),
        TypeRef::Bytes => "Data".to_owned(),
        TypeRef::Path => "URL".to_owned(),
        TypeRef::Char => "Character".to_owned(),
        TypeRef::Unit => "Void".to_owned(),
        TypeRef::Duration => "TimeInterval".to_owned(),
    }
}
/// Swift type for a method's return position.
///
/// This delegates to `swift_type_name`, so excluded named types come back as `String`.
fn swift_return_type(ty: &TypeRef, exclude_types: &HashSet<String>) -> String {
    swift_type_name(ty, exclude_types)
}
/// Build the labelled arguments the adapter forwards to the conformer.
///
/// Each parameter becomes `label: label`, using the snake_case parameter name on
/// both sides. This matches the labels emitted by the parameter-list helpers.
fn build_adapter_call_args(method: &crate::core::ir::MethodDef) -> Vec<String> {
    let mut args = Vec::with_capacity(method.params.len());
    for param in &method.params {
        let label = param.name.to_snake_case();
        args.push(format!("{label}: {label}"));
    }
    args
}
// Unit tests for the Swift trait-bridge generator. They check the returned file
// list and perform substring checks on the generated Swift text.
#[cfg(test)]
mod tests {
use super::*;
use crate::core::config::BridgeBinding;
// Build a trait `TypeDef` (`is_trait: true`) with no fields and no methods;
// individual tests push methods onto it as needed.
fn make_trait_def(name: &str) -> TypeDef {
TypeDef {
name: name.to_string(),
rust_path: format!("testcrate::{}", name),
original_rust_path: String::new(),
fields: vec![],
methods: vec![],
is_opaque: false,
is_clone: true,
is_copy: false,
is_trait: true,
has_default: false,
has_stripped_cfg_fields: false,
is_return_type: false,
serde_rename_all: None,
has_serde: false,
super_traits: vec![],
doc: String::new(),
cfg: None,
binding_excluded: false,
binding_exclusion_reason: None,
}
}
// Build a bridge config bound via `FunctionParam` with no excluded languages, so
// `gen_trait_bridge_files` emits a file for it by default.
fn make_bridge_cfg(trait_name: &str) -> TraitBridgeConfig {
TraitBridgeConfig {
trait_name: trait_name.to_string(),
param_name: None,
type_alias: None,
exclude_languages: vec![],
super_trait: None,
registry_getter: None,
register_fn: Some(format!("register{}", trait_name)),
unregister_fn: None,
clear_fn: None,
register_extra_args: None,
bind_via: BridgeBinding::FunctionParam,
options_type: None,
options_field: None,
context_type: None,
result_type: None,
ffi_skip_methods: Vec::new(),
}
}
// A FunctionParam bridge yields exactly one file, named after the protocol.
#[test]
fn test_trait_bridge_protocol_generated() {
let trait_def = make_trait_def("OcrBackend");
let bridge_cfg = make_bridge_cfg("OcrBackend");
let bridges = vec![("OcrBackend".to_string(), &bridge_cfg, &trait_def)];
let exclude_types = HashSet::new();
let files = gen_trait_bridge_files(&bridges, &exclude_types);
assert_eq!(files.len(), 1);
assert_eq!(files[0].0, "SwiftOcrBackendBridge.swift");
assert!(files[0].1.contains("protocol SwiftOcrBackendBridge"));
}
// Listing "swift" in `exclude_languages` suppresses the file entirely.
#[test]
fn test_trait_bridge_excludes_swift_language() {
let trait_def = make_trait_def("OcrBackend");
let mut bridge_cfg = make_bridge_cfg("OcrBackend");
bridge_cfg.exclude_languages = vec!["swift".to_string()];
let bridges = vec![("OcrBackend".to_string(), &bridge_cfg, &trait_def)];
let exclude_types = HashSet::new();
let files = gen_trait_bridge_files(&bridges, &exclude_types);
assert!(files.is_empty());
}
// Bindings other than FunctionParam (here `OptionsField`) produce no file.
#[test]
fn test_trait_bridge_skips_non_function_param() {
let trait_def = make_trait_def("OcrBackend");
let mut bridge_cfg = make_bridge_cfg("OcrBackend");
bridge_cfg.bind_via = BridgeBinding::OptionsField;
let bridges = vec![("OcrBackend".to_string(), &bridge_cfg, &trait_def)];
let exclude_types = HashSet::new();
let files = gen_trait_bridge_files(&bridges, &exclude_types);
assert!(files.is_empty());
}
// Spot-check the built-in IR -> Swift type mappings of `swift_type_name`.
#[test]
fn test_swift_type_mapping() {
use crate::core::ir::PrimitiveType;
let exclude_types = HashSet::new();
assert_eq!(swift_type_name(&TypeRef::String, &exclude_types), "String");
assert_eq!(swift_type_name(&TypeRef::Bytes, &exclude_types), "Data");
assert_eq!(swift_type_name(&TypeRef::Unit, &exclude_types), "Void");
assert_eq!(
swift_type_name(&TypeRef::Primitive(PrimitiveType::I32), &exclude_types),
"Int32"
);
assert_eq!(swift_type_name(&TypeRef::Duration, &exclude_types), "TimeInterval");
}
// Excluded named types map to `String`; other named types keep their name.
#[test]
fn test_swift_marshals_excluded_types_as_json() {
let mut exclude_types = HashSet::new();
exclude_types.insert("InternalDocument".to_string());
assert_eq!(
swift_type_name(&TypeRef::Named("InternalDocument".to_string()), &exclude_types),
"String",
"Excluded types should be marshalled as JSON strings"
);
assert_eq!(
swift_type_name(&TypeRef::Named("ExtractionResult".to_string()), &exclude_types),
"ExtractionResult",
"Non-excluded types should keep their original names"
);
}
// The bridge file carries the protocol and the adapter class, but no
// `register…` function.
#[test]
fn test_no_register_fn_in_trait_bridge_file() {
let trait_def = make_trait_def("DocumentExtractor");
let bridge_cfg = make_bridge_cfg("DocumentExtractor");
let bridges = vec![("DocumentExtractor".to_string(), &bridge_cfg, &trait_def)];
let exclude_types = HashSet::new();
let files = gen_trait_bridge_files(&bridges, &exclude_types);
assert_eq!(files.len(), 1);
let content = &files[0].1;
assert!(
content.contains("protocol SwiftDocumentExtractorBridge"),
"protocol must be emitted, got:\n{content}"
);
assert!(
content.contains("final class SwiftDocumentExtractorAdapter"),
"adapter class must be emitted, got:\n{content}"
);
assert!(
!content.contains("public func registerDocumentExtractor("),
"register function must NOT be emitted in the bridge file (would use wrong Adapter type), got:\n{content}"
);
}
// A fallible method returning an excluded type is surfaced as `throws -> String`
// on both the protocol method and the adapter `…Call` method.
#[test]
fn test_excluded_type_in_method_becomes_string() {
use crate::core::ir::{MethodDef, ParamDef, ReceiverKind};
let mut trait_def = make_trait_def("DocumentExtractor");
trait_def.methods.push(MethodDef {
name: "extract_bytes".to_string(),
params: vec![ParamDef {
name: "content".to_string(),
ty: TypeRef::Bytes,
optional: false,
default: None,
sanitized: false,
typed_default: None,
is_ref: false,
is_mut: false,
newtype_wrapper: None,
original_type: None,
map_is_ahash: false,
map_key_is_cow: false,
}],
return_type: TypeRef::Named("InternalDocument".to_string()),
is_async: false,
is_static: false,
error_type: Some("Error".to_string()),
doc: String::new(),
receiver: Some(ReceiverKind::Ref),
sanitized: false,
trait_source: None,
returns_ref: false,
returns_cow: false,
return_newtype_wrapper: None,
has_default_impl: false,
binding_excluded: false,
binding_exclusion_reason: None,
});
let bridge_cfg = make_bridge_cfg("DocumentExtractor");
let bridges = vec![("DocumentExtractor".to_string(), &bridge_cfg, &trait_def)];
let mut exclude_types = HashSet::new();
exclude_types.insert("InternalDocument".to_string());
let files = gen_trait_bridge_files(&bridges, &exclude_types);
assert_eq!(files.len(), 1);
let content = &files[0].1;
assert!(
content.contains("func extractBytes(content: Data) throws -> String"),
"protocol method must marshal excluded type to String, got:\n{content}"
);
assert!(
content.contains("func extractBytesCall(content: Data) throws -> String"),
"adapter method must marshal to String, got:\n{content}"
);
// NOTE(review): the marshalling-helpers section always defines
// `marshal_encode_excluded`, so this check passes even if the method body
// never calls it.
assert!(
content.contains("marshal_encode_excluded"),
"marshal_encode_excluded helper must be present, got:\n{content}"
);
}
}