trait_ffi/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use convert_case::Casing;
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::{format_ident, quote};
7use syn::{Ident, ItemImpl, ItemTrait, parse_macro_input, spanned::Spanned};
8
/// Returns early from a proc-macro entry point with a compile error.
///
/// `$span` is the `proc_macro2::Span` to attach the error to and `$msg` is any
/// displayable message; the error is converted into the `TokenStream` returned by
/// the enclosing function.
macro_rules! bail {
    ($span:expr, $msg:expr) => {
        return syn::Error::new($span, $msg).to_compile_error().into();
    };
}
14
/// Name of the crate being compiled, or `"unknown"` when `CARGO_PKG_NAME` is not set.
///
/// The variable is read from the process environment when the macro expands; Cargo
/// sets it for the crate it is building.
fn get_crate_name() -> String {
    match std::env::var("CARGO_PKG_NAME") {
        Ok(name) => name,
        Err(_) => String::from("unknown"),
    }
}
18
/// Version of the crate being compiled, or `"0.1.0"` when `CARGO_PKG_VERSION` is not set.
///
/// The variable is read from the process environment when the macro expands; Cargo
/// sets it for the crate it is building.
fn get_crate_version() -> String {
    match std::env::var("CARGO_PKG_VERSION") {
        Ok(version) => version,
        Err(_) => String::from("0.1.0"),
    }
}
22
23fn prefix_version() -> String {
24    let version = lenient_semver::parse(&get_crate_version()).unwrap();
25    let major = version.major;
26    let minor = version.minor;
27    if major == 0 {
28        format!("0_{minor}")
29    } else {
30        major.to_string()
31    }
32}
33
34fn extern_fn_name(crate_name: &str, fn_name: &Ident) -> Ident {
35    let crate_name = crate_name.to_lowercase().replace("-", "_");
36    // let version = prefix_version();
37
38    format_ident!("__{crate_name}_{fn_name}")
39}
40
41fn parse_def_extern_trait_args(
42    args: TokenStream,
43) -> Result<(String, bool, Option<String>), String> {
44    if args.is_empty() {
45        return Ok(("rust".to_string(), false, None)); // 默认使用 Rust ABI,默认生成 impl_trait! 宏,无自定义模块路径
46    }
47
48    let args_str = args.to_string();
49    let mut abi = None;
50    let mut not_def_impl = false;
51    let mut mod_path = None;
52
53    // 简单解析 abi="value"、not_def_impl 和 mod_path="value" 形式
54    let parts: Vec<&str> = args_str.split(',').collect();
55
56    for part in parts {
57        let part = part.trim();
58        if part.starts_with("abi")
59            && let Some(start) = part.find('"')
60            && let Some(end) = part.rfind('"')
61            && start < end
62        {
63            abi = Some(part[start + 1..end].to_string());
64        } else if part.starts_with("mod_path")
65            && let Some(start) = part.find('"')
66            && let Some(end) = part.rfind('"')
67            && start < end
68        {
69            mod_path = Some(part[start + 1..end].to_string());
70        } else if part == "not_def_impl" {
71            not_def_impl = true;
72        }
73    }
74
75    let abi = abi.unwrap_or_else(|| "rust".to_string());
76
77    if abi != "c" && abi != "rust" {
78        return Err("Invalid abi parameter. Supported values: \"c\", \"rust\"".to_string());
79    }
80
81    Ok((abi, not_def_impl, mod_path))
82}
83
84/// Defines an extern trait that can be called across FFI boundaries.
85///
86/// This macro converts a regular Rust trait into a trait that can be called through FFI.
87/// It generates:
88/// 1. The original trait definition
89/// 2. A module containing wrapper functions that call external implementations
90/// 3. Optionally, a helper macro `impl_trait!` for implementing the trait (unless `not_def_impl` is specified)
91/// 4. A checker function to ensure the trait is properly implemented
92///
93/// # Arguments
94/// - `abi`: Optional parameter specifying ABI type ("c" or "rust"), defaults to "rust"
95/// - `not_def_impl`: Optional parameter to skip generating the `impl_trait!` macro
96///
97/// # Example
98/// ```rust
99/// #[def_extern_trait(abi = "c")]
100/// trait Calculator {
101///     fn add(&self, a: i32, b: i32) -> i32;
102///     fn multiply(&self, a: i32, b: i32) -> i32;
103/// }
104///
105/// // Skip generating impl_trait! macro
106/// #[def_extern_trait(abi = "c", not_def_impl)]
107/// trait Calculator2 {
108///     fn add(&self, a: i32, b: i32) -> i32;
109/// }
110/// ```
111///
112/// This will generate a `calculator` module containing functions that can call external implementations.
113#[proc_macro_attribute]
114pub fn def_extern_trait(args: TokenStream, input: TokenStream) -> TokenStream {
115    let (abi, not_def_impl, _mod_path) = match parse_def_extern_trait_args(args) {
116        Ok((abi, not_def_impl, mod_path)) => (abi, not_def_impl, mod_path),
117        Err(error_msg) => {
118            bail!(Span::call_site(), error_msg);
119        }
120    };
121
122    let input = parse_macro_input!(input as ItemTrait);
123    let vis = input.vis.clone();
124    let mod_name = format_ident!(
125        "{}",
126        input.ident.to_string().to_case(convert_case::Case::Snake)
127    );
128    let crate_name_str = get_crate_name();
129
130    let mut fn_list = vec![];
131    let crate_name = format_ident!("{}", crate_name_str.replace("-", "_"));
132    let mut crate_path_tokens = quote! { #crate_name };
133    if let Some(mod_path) = _mod_path {
134        // 解析 mod_path 并生成路径tokens
135        let path_segments: Vec<&str> = mod_path.split("::").collect();
136        let path_idents: Vec<proc_macro2::Ident> = path_segments
137            .iter()
138            .map(|segment| format_ident!("{}", segment))
139            .collect();
140        crate_path_tokens = quote! { #crate_name::#(#path_idents)::* };
141    }
142
143    let crate_name_version = format!("{}_{}", crate_name_str, prefix_version());
144
145    for item in &input.items {
146        if let syn::TraitItem::Fn(func) = item {
147            let fn_name = func.sig.ident.clone();
148            let extern_fn_name = extern_fn_name(&crate_name_version, &fn_name);
149
150            let attrs = &func.attrs;
151            let inputs = &func.sig.inputs;
152            let output = &func.sig.output;
153            let generics = &func.sig.generics;
154
155            let mut param_names = vec![];
156            let mut param_types = vec![];
157
158            for input in inputs {
159                if let syn::FnArg::Typed(pat_type) = input {
160                    param_names.push(&pat_type.pat);
161                    param_types.push(&pat_type.ty);
162                }
163            }
164
165            let extern_abi = if abi == "rust" { "Rust" } else { "C" };
166
167            fn_list.push(quote! {
168                #(#attrs)*
169                pub fn #fn_name #generics (#inputs) #output {
170                    unsafe extern #extern_abi {
171                        fn #extern_fn_name #generics (#inputs) #output;
172                    }
173                    unsafe{ #extern_fn_name(#(#param_names),*) }
174                }
175            });
176        } else {
177            bail!(
178                item.span(),
179                "Only function items are allowed in extern traits"
180            );
181        }
182    }
183
184    let warn_fn_name = format_ident!(
185        "Trait_{}_in_crate_{}_{}_need_impl",
186        input.ident,
187        crate_name_str.replace("-", "_"),
188        prefix_version()
189    );
190
191    let generated_macro = if not_def_impl {
192        quote! {}
193    } else {
194        quote! {
195            pub use trait_ffi::impl_extern_trait;
196
197            #[macro_export]
198            macro_rules! impl_trait {
199                (impl $trait:ident for $type:ty { $($body:tt)* }) => {
200                    #[#crate_path_tokens::impl_extern_trait(name = #crate_name_version, abi = #abi)]
201                    impl $trait for $type {
202                        $($body)*
203                    }
204
205                    #[allow(snake_case)]
206                    #[unsafe(no_mangle)]
207                    extern "C" fn #warn_fn_name() { }
208                };
209            }
210        }
211    };
212
213    quote! {
214        #input
215
216        #vis mod #mod_name {
217            use super::*;
218            /// `trait-ffi` generated.
219            pub fn ____checker_do_not_use(){
220                unsafe extern "C" {
221                    fn #warn_fn_name();
222                }
223                unsafe { #warn_fn_name() };
224            }
225            #(#fn_list)*
226        }
227
228        #generated_macro
229    }
230    .into()
231}
232
233fn parse_extern_trait_args(args: TokenStream) -> Result<(String, String), String> {
234    if args.is_empty() {
235        return Err(
236            "Missing parameters. Usage: #[impl_extern_trait(name=\"crate_name\", abi=\"c\")]"
237                .to_string(),
238        );
239    }
240
241    let args_str = args.to_string();
242    let mut name = None;
243    let mut abi = None;
244
245    let parts: Vec<&str> = args_str.split(',').collect();
246
247    for part in parts {
248        let part = part.trim();
249        if part.starts_with("name") {
250            if let Some(start) = part.find('"')
251                && let Some(end) = part.rfind('"')
252                && start < end
253            {
254                name = Some(part[start + 1..end].to_string());
255            }
256        } else if part.starts_with("abi")
257            && let Some(start) = part.find('"')
258            && let Some(end) = part.rfind('"')
259            && start < end
260        {
261            abi = Some(part[start + 1..end].to_string());
262        }
263    }
264
265    let name = name.ok_or_else(|| {
266        "Missing name parameter. Usage: #[impl_extern_trait(name=\"crate_name\", abi=\"c\")]"
267            .to_string()
268    })?;
269    let abi = abi.unwrap_or_else(|| "c".to_string());
270
271    if abi != "c" && abi != "rust" {
272        return Err("Invalid abi parameter. Supported values: \"c\", \"rust\"".to_string());
273    }
274
275    Ok((name, abi))
276}
277
278/// Implements an extern trait for a type and generates corresponding C function exports.
279///
280/// This macro takes a trait implementation and generates extern "C" functions that can be
281/// called from other languages. Each method in the trait implementation gets a corresponding
282/// extern function with a mangled name based on the crate name and version.
283///
284/// # Arguments
285/// - `name`: The name of the crate that defines the extern trait
286/// - `abi`: The ABI to use for the extern functions ("c" or "rust"), defaults to "c"
287///
288/// # Example
289/// ```rust
290/// struct Calculator;
291///
292/// #[impl_extern_trait(name = "calculator_crate", abi = "c")]
293/// impl MyTrait for Calculator {
294///     fn add(&self, a: i32, b: i32) -> i32 {
295///         a + b
296///     }
297/// }
298/// ```
299///
300/// This will generate extern "C" functions that can be called from other languages.
301#[proc_macro_attribute]
302pub fn impl_extern_trait(args: TokenStream, input: TokenStream) -> TokenStream {
303    let (crate_name_str, abi) = match parse_extern_trait_args(args) {
304        Ok((name, abi)) => (name, abi),
305        Err(error_msg) => {
306            bail!(Span::call_site(), error_msg);
307        }
308    };
309    let input = parse_macro_input!(input as ItemImpl);
310    let mut extern_fn_list = vec![];
311
312    let struct_name = input.self_ty.clone();
313    let trait_name = input.clone().trait_.unwrap().1;
314
315    for item in &input.items {
316        if let syn::ImplItem::Fn(func) = item {
317            let fn_name_raw = &func.sig.ident;
318            let fn_name = extern_fn_name(&crate_name_str, fn_name_raw);
319
320            let inputs = &func.sig.inputs;
321            let output = &func.sig.output;
322            let generics = &func.sig.generics;
323
324            let extern_abi = if abi == "rust" { "Rust" } else { "C" };
325
326            let mut param_names = vec![];
327            let mut param_types = vec![];
328
329            for input in inputs {
330                if let syn::FnArg::Typed(pat_type) = input {
331                    param_names.push(&pat_type.pat);
332                    param_types.push(&pat_type.ty);
333                }
334            }
335
336            extern_fn_list.push(quote! {
337                /// `trait-ffi` generated extern function.
338                #[unsafe(no_mangle)]
339                pub extern #extern_abi fn #fn_name #generics (#inputs) #output {
340                    <#struct_name as #trait_name>::#fn_name_raw(#(#param_names),*)
341                }
342            });
343        }
344    }
345
346    quote! {
347        #input
348        #(#extern_fn_list)*
349    }
350    .into()
351}