trait_ffi/
lib.rs

use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{ItemImpl, ItemTrait, parse::Parser, parse_macro_input, spanned::Spanned};

/// Returns early from a proc-macro entry point with a compile error carrying
/// `$msg`, reported at span `$span` (rendered via `syn::Error::to_compile_error`).
macro_rules! bail {
    ($span:expr, $msg:expr) => {{
        let error = syn::Error::new($span, $msg);
        return error.to_compile_error().into();
    }};
}

/// Returns the `CARGO_PKG_NAME` environment variable, read at macro-expansion
/// time from the compiler invocation's environment.
///
/// Falls back to `"unknown"` when the variable is absent or not valid Unicode.
fn get_crate_name() -> String {
    match std::env::var("CARGO_PKG_NAME") {
        Ok(package_name) => package_name,
        Err(_) => String::from("unknown"),
    }
}

16fn parse_def_extern_trait_args(args: TokenStream) -> Result<String, String> {
17    if args.is_empty() {
18        return Ok("rust".to_string()); // 默认使用 Rust ABI
19    }
20
21    let args_str = args.to_string();
22    let mut abi = None;
23
24    // 简单解析 abi="value" 形式
25    let parts: Vec<&str> = args_str.split(',').collect();
26
27    for part in parts {
28        let part = part.trim();
29        if part.starts_with("abi") {
30            if let Some(start) = part.find('"') {
31                if let Some(end) = part.rfind('"') {
32                    if start < end {
33                        abi = Some(part[start + 1..end].to_string());
34                    }
35                }
36            }
37        }
38    }
39
40    let abi = abi.unwrap_or_else(|| "rust".to_string());
41
42    if abi != "c" && abi != "rust" {
43        return Err("Invalid abi parameter. Supported values: \"c\", \"rust\"".to_string());
44    }
45
46    Ok(abi)
47}
48
49#[proc_macro_attribute]
50pub fn def_extern_trait(args: TokenStream, input: TokenStream) -> TokenStream {
51    let abi = match parse_def_extern_trait_args(args) {
52        Ok(abi) => abi,
53        Err(error_msg) => {
54            bail!(Span::call_site(), error_msg);
55        }
56    };
57
58    let input = parse_macro_input!(input as ItemTrait);
59    let vis = input.vis.clone();
60    let mod_name = format_ident!("{}", input.ident.to_string().to_lowercase());
61    let crate_name_str = get_crate_name();
62    let prefix = make_prefix(&crate_name_str);
63
64    let mut fn_list = vec![];
65
66    for item in &input.items {
67        if let syn::TraitItem::Fn(func) = item {
68            let fn_name = func.sig.ident.clone();
69            let extern_fn_name = format_ident!("{}{}", prefix, func.sig.ident);
70            let attrs = &func.attrs;
71            let inputs = &func.sig.inputs;
72            let output = &func.sig.output;
73
74            // 生成参数名和类型
75            let mut param_names = vec![];
76            let mut param_types = vec![];
77
78            for input in inputs {
79                if let syn::FnArg::Typed(pat_type) = input {
80                    param_names.push(&pat_type.pat);
81                    param_types.push(&pat_type.ty);
82                }
83            }
84
85            let extern_abi = if abi == "rust" { "Rust" } else { "C" };
86
87            fn_list.push(quote! {
88                #(#attrs)*
89                pub fn #fn_name(#inputs) #output {
90                    unsafe extern #extern_abi {
91                        fn #extern_fn_name(#inputs) #output;
92                    }
93                    unsafe{ #extern_fn_name(#(#param_names),*) }
94                }
95            });
96        } else {
97            bail!(
98                item.span(),
99                "Only function items are allowed in extern traits"
100            );
101        }
102    }
103
104    let crate_name = format_ident!("{}", crate_name_str.replace("-", "_"));
105
106    let warn_fn_name = format_ident!(
107        "Trait_{}_in_crate_{}_need_impl",
108        input.ident,
109        crate_name_str.replace("-", "_")
110    );
111
112    let generated_macro = quote! {
113        #[macro_export]
114        macro_rules! impl_trait {
115            (impl $trait:ident for $type:ty { $($body:tt)* }) => {
116                #[#crate_name::impl_extern_trait(name = #crate_name_str, abi = #abi)]
117                impl $trait for $type {
118                    $($body)*
119                }
120
121                #[allow(snake_case)]
122                #[unsafe(no_mangle)]
123                extern "C" fn #warn_fn_name() { }
124            };
125        }
126    };
127
128    quote! {
129        pub use trait_ffi::impl_extern_trait;
130
131        #input
132
133        #vis mod #mod_name {
134            use super::*;
135            pub fn ____checker_do_not_use(){
136                unsafe extern "C" {
137                    fn #warn_fn_name();
138                }
139                unsafe { #warn_fn_name() };
140            }
141            #(#fn_list)*
142        }
143
144        #generated_macro
145    }
146    .into()
147}
148
/// Builds the symbol prefix shared by the declaring and implementing sides.
///
/// The prefix has the form `__<name>_`, with `name` lower-cased and every `-`
/// replaced by `_`.
fn make_prefix(name: &str) -> String {
    let normalized = name.to_lowercase().replace('-', "_");
    ["__", normalized.as_str(), "_"].concat()
}

153fn parse_extern_trait_args(args: TokenStream) -> Result<(String, String), String> {
154    if args.is_empty() {
155        return Err(
156            "Missing parameters. Usage: #[impl_extern_trait(name=\"crate_name\", abi=\"c\")]"
157                .to_string(),
158        );
159    }
160
161    let args_str = args.to_string();
162    let mut name = None;
163    let mut abi = None;
164
165    // 简单解析 name="value", abi="value" 形式
166    let parts: Vec<&str> = args_str.split(',').collect();
167
168    for part in parts {
169        let part = part.trim();
170        if part.starts_with("name") {
171            if let Some(start) = part.find('"') {
172                if let Some(end) = part.rfind('"') {
173                    if start < end {
174                        name = Some(part[start + 1..end].to_string());
175                    }
176                }
177            }
178        } else if part.starts_with("abi") {
179            if let Some(start) = part.find('"') {
180                if let Some(end) = part.rfind('"') {
181                    if start < end {
182                        abi = Some(part[start + 1..end].to_string());
183                    }
184                }
185            }
186        }
187    }
188
189    let name = name.ok_or_else(|| {
190        "Missing name parameter. Usage: #[impl_extern_trait(name=\"crate_name\", abi=\"c\")]"
191            .to_string()
192    })?;
193    let abi = abi.unwrap_or_else(|| "c".to_string());
194
195    if abi != "c" && abi != "rust" {
196        return Err("Invalid abi parameter. Supported values: \"c\", \"rust\"".to_string());
197    }
198
199    Ok((name, abi))
200}
201
202#[proc_macro_attribute]
203pub fn impl_extern_trait(args: TokenStream, input: TokenStream) -> TokenStream {
204    let (crate_name_str, abi) = match parse_extern_trait_args(args) {
205        Ok((name, abi)) => (name, abi),
206        Err(error_msg) => {
207            bail!(Span::call_site(), error_msg);
208        }
209    };
210    let input = parse_macro_input!(input as ItemImpl);
211    let mut extern_fn_list = vec![];
212
213    let prefix = make_prefix(&crate_name_str);
214
215    let struct_name = input.self_ty.clone();
216    let trait_name = input.clone().trait_.unwrap().1;
217
218    for item in &input.items {
219        if let syn::ImplItem::Fn(func) = item {
220            let fn_name_raw = &func.sig.ident;
221            let fn_name = format_ident!("{prefix}{fn_name_raw}");
222            let inputs = &func.sig.inputs;
223            let output = &func.sig.output;
224
225            let extern_abi = if abi == "rust" { "Rust" } else { "C" };
226
227            let mut param_names = vec![];
228            let mut param_types = vec![];
229
230            for input in inputs {
231                if let syn::FnArg::Typed(pat_type) = input {
232                    param_names.push(&pat_type.pat);
233                    param_types.push(&pat_type.ty);
234                }
235            }
236
237            extern_fn_list.push(quote! {
238                #[unsafe(no_mangle)]
239                pub extern #extern_abi fn #fn_name(#inputs) #output {
240                    <#struct_name as #trait_name>::#fn_name_raw(#(#param_names),*)
241                }
242            });
243        }
244    }
245
246    quote! {
247        #input
248        #(#extern_fn_list)*
249    }
250    .into()
251}